use crate::error::{McpError, McpResult};
use aimdb_client::AimxClient;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::debug;
/// Arguments accepted by `list_records`, deserialized from the caller's JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ListRecordsParams {
/// Optional socket path; handed unchanged to `super::resolve_socket_path`.
socket_path: Option<String>,
}
/// Arguments accepted by `get_record`, deserialized from the caller's JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GetRecordParams {
/// Optional socket path; handed unchanged to `super::resolve_socket_path`.
socket_path: Option<String>,
/// Name of the record to read; required, passed to the client's `get_record`.
record_name: String,
}
/// Arguments accepted by `set_record`, deserialized from the caller's JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SetRecordParams {
/// Optional socket path; handed unchanged to `super::resolve_socket_path`.
socket_path: Option<String>,
/// Name of the record to write; required, passed to the client's `set_record`.
record_name: String,
/// Arbitrary JSON payload forwarded as the new value of the record.
value: Value,
}
/// Arguments accepted by `drain_record`, deserialized from the caller's JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DrainRecordParams {
/// Optional socket path; handed unchanged to `super::resolve_socket_path`.
socket_path: Option<String>,
/// Name of the record to drain; required.
record_name: String,
/// Optional limit: when present `drain_record_with_limit` is used, otherwise
/// the plain `drain_record` call. Left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
limit: Option<u32>,
}
/// Per-record entry in the JSON array returned by `list_records`.
///
/// Every field is copied from the same-named field of the entries returned by
/// the client's `list_records` call; no value is transformed on the way.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RecordInfo {
record_key: String,
name: String,
type_id: String,
buffer_type: String,
/// Left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
buffer_capacity: Option<usize>,
producer_count: usize,
consumer_count: usize,
writable: bool,
// NOTE(review): carried as an opaque string; the timestamp format is not
// visible from this file.
created_at: String,
/// Left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
last_update: Option<String>,
outbound_connector_count: usize,
}
/// Lists the records reported by the AimDB instance behind a socket.
///
/// `args` is deserialized into `ListRecordsParams`; when `args` is `None`,
/// `Value::Null` is deserialized instead. A connection is taken from the
/// shared pool when one is available, otherwise a fresh `AimxClient` is
/// opened. The result is a JSON array of `RecordInfo` objects.
///
/// # Errors
///
/// - `McpError::InvalidParams` if the arguments cannot be deserialized.
/// - Any error returned by `super::resolve_socket_path`.
/// - `McpError::Client` if connecting or listing the records fails.
/// - `McpError::Internal` if the result cannot be serialized to JSON.
pub async fn list_records(args: Option<Value>) -> McpResult<Value> {
    debug!("📋 list_records called with args: {:?}", args.as_ref());

    let raw_args = args.unwrap_or(Value::Null);
    let params: ListRecordsParams = serde_json::from_value(raw_args)
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    let socket_path = super::resolve_socket_path(params.socket_path)?;
    debug!("🔌 Connecting to {}", socket_path);

    // Prefer a pooled connection; fall back to a direct one without a pool.
    let connection = match super::connection_pool() {
        Some(pool) => pool.get_connection(&socket_path).await,
        None => AimxClient::connect(&socket_path).await,
    };
    let mut client = connection.map_err(McpError::Client)?;

    let records = client.list_records().await.map_err(McpError::Client)?;
    debug!("✅ Found {} record(s)", records.len());

    // Copy each entry field-for-field into the serializable output shape.
    let mut record_infos: Vec<RecordInfo> = Vec::with_capacity(records.len());
    for entry in records {
        record_infos.push(RecordInfo {
            record_key: entry.record_key,
            name: entry.name,
            type_id: entry.type_id,
            buffer_type: entry.buffer_type,
            buffer_capacity: entry.buffer_capacity,
            producer_count: entry.producer_count,
            consumer_count: entry.consumer_count,
            writable: entry.writable,
            created_at: entry.created_at,
            last_update: entry.last_update,
            outbound_connector_count: entry.outbound_connector_count,
        });
    }

    match serde_json::to_value(record_infos) {
        Ok(json) => Ok(json),
        Err(e) => Err(McpError::Internal(format!(
            "JSON serialization failed: {}",
            e
        ))),
    }
}
/// Reads the current value of a single record.
///
/// `args` is deserialized into `GetRecordParams`; when `args` is `None`,
/// `Value::Null` is deserialized instead, which fails because `record_name`
/// is required. A connection is taken from the shared pool when one is
/// available, otherwise a fresh `AimxClient` is opened. The value returned by
/// the client is passed back unchanged.
///
/// # Errors
///
/// - `McpError::InvalidParams` if the arguments cannot be deserialized.
/// - Any error returned by `super::resolve_socket_path`.
/// - `McpError::Client` if connecting or reading the record fails.
pub async fn get_record(args: Option<Value>) -> McpResult<Value> {
    debug!("🔍 get_record called with args: {:?}", args.as_ref());

    let params: GetRecordParams = serde_json::from_value(args.unwrap_or(Value::Null))
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    let socket_path = super::resolve_socket_path(params.socket_path)?;
    debug!(
        "🔌 Connecting to {} to get record '{}'",
        socket_path, params.record_name
    );

    // Prefer a pooled connection; fall back to a direct one without a pool.
    let mut client = if let Some(pool) = super::connection_pool() {
        pool.get_connection(&socket_path)
            .await
            .map_err(McpError::Client)?
    } else {
        AimxClient::connect(&socket_path)
            .await
            .map_err(McpError::Client)?
    };

    // Borrow the record name: it is still needed for the log line below.
    let value = client
        .get_record(&params.record_name)
        .await
        .map_err(McpError::Client)?;
    debug!("✅ Retrieved record '{}'", params.record_name);

    Ok(value)
}
/// Writes a new JSON value to a single record.
///
/// `args` is deserialized into `SetRecordParams`; when `args` is `None`,
/// `Value::Null` is deserialized instead, which fails because `record_name`
/// and `value` are required. A connection is taken from the shared pool when
/// one is available, otherwise a fresh `AimxClient` is opened. The client's
/// response is passed back unchanged.
///
/// # Errors
///
/// - `McpError::InvalidParams` if the arguments cannot be deserialized.
/// - Any error returned by `super::resolve_socket_path`.
/// - `McpError::Client` if connecting or writing the record fails.
pub async fn set_record(args: Option<Value>) -> McpResult<Value> {
    debug!("✏️ set_record called with args: {:?}", args.as_ref());

    let params: SetRecordParams = serde_json::from_value(args.unwrap_or(Value::Null))
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    let socket_path = super::resolve_socket_path(params.socket_path)?;
    debug!(
        "🔌 Connecting to {} to set record '{}'",
        socket_path, params.record_name
    );

    // Prefer a pooled connection; fall back to a direct one without a pool.
    let mut client = if let Some(pool) = super::connection_pool() {
        pool.get_connection(&socket_path)
            .await
            .map_err(McpError::Client)?
    } else {
        AimxClient::connect(&socket_path)
            .await
            .map_err(McpError::Client)?
    };

    // The value is moved into the call; the name is only borrowed because it
    // is still needed for the log line below.
    let result = client
        .set_record(&params.record_name, params.value)
        .await
        .map_err(McpError::Client)?;
    debug!("✅ Updated record '{}'", params.record_name);

    Ok(result)
}
/// Drains pending values from a record, optionally capped by `limit`.
///
/// `args` is deserialized into `DrainRecordParams`; when `args` is `None`,
/// `Value::Null` is deserialized instead, which fails because `record_name`
/// is required. Unlike the other handlers this one requires the connection
/// pool: the drain client is obtained from it and locked for the duration of
/// the call. The client's response is serialized to JSON and returned.
///
/// # Errors
///
/// - `McpError::InvalidParams` if the arguments cannot be deserialized.
/// - Any error returned by `super::resolve_socket_path`.
/// - `McpError::Internal` if the connection pool is not initialized, or if
///   the response cannot be serialized to JSON.
/// - `McpError::Client` if obtaining the drain client or draining fails. On a
///   drain failure the pooled drain client for this socket is also
///   invalidated in a background task.
pub async fn drain_record(args: Option<Value>) -> McpResult<Value> {
    debug!("🔄 drain_record called with args: {:?}", args.as_ref());

    let params: DrainRecordParams = serde_json::from_value(args.unwrap_or(Value::Null))
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    let socket_path = super::resolve_socket_path(params.socket_path)?;
    debug!(
        "🔌 Connecting to {} to drain record '{}'",
        socket_path, params.record_name
    );

    let pool = super::connection_pool()
        .ok_or_else(|| McpError::Internal("Connection pool not initialized".to_string()))?;
    let client_arc = pool
        .get_drain_client(&socket_path)
        .await
        .map_err(McpError::Client)?;
    let mut client = client_arc.lock().await;

    // Select the call variant first so both share a single error path.
    let drained = match params.limit {
        Some(limit) => {
            client
                .drain_record_with_limit(&params.record_name, limit)
                .await
        }
        None => client.drain_record(&params.record_name).await,
    };

    let response = drained.map_err(|e| {
        // Invalidation is spawned as a detached task, not awaited here,
        // while this function still holds the client lock guard.
        let socket = socket_path.clone();
        let pool = pool.clone();
        tokio::spawn(async move { pool.invalidate_drain_client(&socket).await });
        McpError::Client(e)
    })?;

    debug!(
        "✅ Drained {} values from record '{}'",
        response.count, params.record_name
    );

    serde_json::to_value(response)
        .map_err(|e| McpError::Internal(format!("JSON serialization failed: {}", e)))
}
#[cfg(test)]
mod tests {
    //! Failure-path tests for the record handlers. None of them needs a live
    //! server: they either pass no arguments or point at a socket path that
    //! is expected not to exist.

    use super::*;
    use serde_json::json;

    /// With no arguments, `list_records` deserializes `Value::Null` and must
    /// report an invalid-parameters error.
    #[tokio::test]
    async fn test_list_records_missing_socket_path() {
        let outcome = list_records(None).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(failure.message().contains("Invalid parameters"));
    }

    /// Listing against a socket that does not exist must surface a
    /// connection failure.
    #[tokio::test]
    async fn test_list_records_invalid_socket() {
        let arguments = json!({ "socket_path": "/tmp/nonexistent.sock" });

        let outcome = list_records(Some(arguments)).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(
            failure.message().contains("Failed to connect")
                || failure.message().contains("No such file")
        );
    }

    /// With no arguments, `get_record` must report an invalid-parameters error.
    #[tokio::test]
    async fn test_get_record_missing_params() {
        let outcome = get_record(None).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(failure.message().contains("Invalid parameters"));
    }

    /// Reading through a socket that does not exist must fail.
    #[tokio::test]
    async fn test_get_record_invalid_socket() {
        let arguments = json!({
            "socket_path": "/tmp/nonexistent.sock",
            "record_name": "TestRecord"
        });

        let outcome = get_record(Some(arguments)).await;
        assert!(outcome.is_err());
    }

    /// With no arguments, `set_record` must report an invalid-parameters error.
    #[tokio::test]
    async fn test_set_record_missing_params() {
        let outcome = set_record(None).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(failure.message().contains("Invalid parameters"));
    }

    /// Writing through a socket that does not exist must fail.
    #[tokio::test]
    async fn test_set_record_invalid_socket() {
        let arguments = json!({
            "socket_path": "/tmp/nonexistent.sock",
            "record_name": "TestRecord",
            "value": {"test": "value"}
        });

        let outcome = set_record(Some(arguments)).await;
        assert!(outcome.is_err());
    }
}