use crate::error::{McpError, McpResult};
use aimdb_client::AimxClient;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::debug;
/// Arguments accepted by `list_records`, deserialized from the tool-call JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ListRecordsParams {
/// Address handed to the connection pool / `AimxClient::connect`.
/// NOTE(review): presumably a Unix domain socket path — confirm against the client crate.
socket_path: String,
}
/// Arguments accepted by `get_record`, deserialized from the tool-call JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GetRecordParams {
/// Address handed to the connection pool / `AimxClient::connect`.
/// NOTE(review): presumably a Unix domain socket path — confirm against the client crate.
socket_path: String,
/// Name of the record to fetch; forwarded unchanged to the client's `get_record`.
record_name: String,
}
/// Arguments accepted by `set_record`, deserialized from the tool-call JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SetRecordParams {
/// Address handed to the connection pool / `AimxClient::connect`.
/// NOTE(review): presumably a Unix domain socket path — confirm against the client crate.
socket_path: String,
/// Name of the record to update; forwarded unchanged to the client's `set_record`.
record_name: String,
/// New value for the record, passed through as arbitrary JSON without
/// any validation in this module.
value: Value,
}
/// JSON-facing description of a single record, as returned by `list_records`.
///
/// Every field is copied one-to-one from the entries the client's
/// `list_records` call yields; field meanings below are inferred from the
/// names only and should be confirmed against the client crate.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RecordInfo {
/// Record name.
name: String,
/// Type identifier of the record, as reported by the client.
type_id: String,
/// Buffer kind backing the record, as reported by the client.
buffer_type: String,
/// Buffer capacity, if reported; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
buffer_capacity: Option<usize>,
/// Number of producers, as reported by the client.
producer_count: usize,
/// Number of consumers, as reported by the client.
consumer_count: usize,
/// Writable flag as reported by the client.
/// NOTE(review): presumably whether `set_record` is permitted — confirm.
writable: bool,
/// Creation timestamp string; its format is not visible from this module.
created_at: String,
/// Last-update timestamp string, if any; left out of the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
last_update: Option<String>,
/// Number of outbound connectors, as reported by the client.
outbound_connector_count: usize,
}
/// Lists the records exposed by the AimDB instance at `socket_path`.
///
/// # Arguments
///
/// * `args` - JSON object that must contain a string `socket_path` field.
///   `None` is treated as JSON `null` and therefore rejected.
///
/// # Returns
///
/// A JSON array with one [`RecordInfo`] object per record reported by the client.
///
/// # Errors
///
/// * [`McpError::InvalidParams`] - `args` is missing or does not deserialize
///   into the expected shape.
/// * [`McpError::Client`] - obtaining a connection or the list request failed.
/// * [`McpError::Internal`] - the result could not be serialized to JSON.
pub async fn list_records(args: Option<Value>) -> McpResult<Value> {
    debug!("📋 list_records called with args: {:?}", args.as_ref());

    let params: ListRecordsParams = serde_json::from_value(args.unwrap_or(Value::Null))
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    debug!("🔌 Connecting to {}", params.socket_path);

    // Use the pool when `super::connection_pool()` provides one; otherwise
    // fall back to a direct connection.
    let mut client = if let Some(pool) = super::connection_pool() {
        pool.get_connection(&params.socket_path)
            .await
            .map_err(McpError::Client)?
    } else {
        AimxClient::connect(&params.socket_path)
            .await
            .map_err(McpError::Client)?
    };

    let records = client.list_records().await.map_err(McpError::Client)?;
    debug!("✅ Found {} record(s)", records.len());

    // Re-map the client's record metadata into the local serializable struct.
    let record_infos: Vec<RecordInfo> = records
        .into_iter()
        .map(|r| RecordInfo {
            name: r.name,
            type_id: r.type_id,
            buffer_type: r.buffer_type,
            buffer_capacity: r.buffer_capacity,
            producer_count: r.producer_count,
            consumer_count: r.consumer_count,
            writable: r.writable,
            created_at: r.created_at,
            last_update: r.last_update,
            outbound_connector_count: r.outbound_connector_count,
        })
        .collect();

    serde_json::to_value(record_infos)
        .map_err(|e| McpError::Internal(format!("JSON serialization failed: {}", e)))
}
/// Fetches the current value of a single record.
///
/// # Arguments
///
/// * `args` - JSON object that must contain string `socket_path` and
///   `record_name` fields. `None` is treated as JSON `null` and therefore
///   rejected.
///
/// # Returns
///
/// The value returned by the client's `get_record` call, passed through
/// unchanged.
///
/// # Errors
///
/// * [`McpError::InvalidParams`] - `args` is missing or does not deserialize
///   into the expected shape.
/// * [`McpError::Client`] - obtaining a connection or the get request failed.
pub async fn get_record(args: Option<Value>) -> McpResult<Value> {
    debug!("🔍 get_record called with args: {:?}", args.as_ref());

    let params: GetRecordParams = serde_json::from_value(args.unwrap_or(Value::Null))
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    debug!(
        "🔌 Connecting to {} to get record '{}'",
        params.socket_path, params.record_name
    );

    // Use the pool when `super::connection_pool()` provides one; otherwise
    // fall back to a direct connection.
    let mut client = if let Some(pool) = super::connection_pool() {
        pool.get_connection(&params.socket_path)
            .await
            .map_err(McpError::Client)?
    } else {
        AimxClient::connect(&params.socket_path)
            .await
            .map_err(McpError::Client)?
    };

    let value = client
        .get_record(&params.record_name)
        .await
        .map_err(McpError::Client)?;

    debug!("✅ Retrieved record '{}'", params.record_name);
    Ok(value)
}
/// Writes a new value to a single record.
///
/// # Arguments
///
/// * `args` - JSON object that must contain string `socket_path` and
///   `record_name` fields plus a `value` field holding the JSON to write.
///   `None` is treated as JSON `null` and therefore rejected.
///
/// # Returns
///
/// Whatever the client's `set_record` call returns, passed through unchanged.
///
/// # Errors
///
/// * [`McpError::InvalidParams`] - `args` is missing or does not deserialize
///   into the expected shape.
/// * [`McpError::Client`] - obtaining a connection or the set request failed.
pub async fn set_record(args: Option<Value>) -> McpResult<Value> {
    debug!("✏️ set_record called with args: {:?}", args.as_ref());

    let params: SetRecordParams = serde_json::from_value(args.unwrap_or(Value::Null))
        .map_err(|e| McpError::InvalidParams(format!("Invalid parameters: {}", e)))?;

    debug!(
        "🔌 Connecting to {} to set record '{}'",
        params.socket_path, params.record_name
    );

    // Use the pool when `super::connection_pool()` provides one; otherwise
    // fall back to a direct connection.
    let mut client = if let Some(pool) = super::connection_pool() {
        pool.get_connection(&params.socket_path)
            .await
            .map_err(McpError::Client)?
    } else {
        AimxClient::connect(&params.socket_path)
            .await
            .map_err(McpError::Client)?
    };

    // `params.value` is moved into the call; `record_name` is only borrowed
    // so it can still be logged afterwards.
    let result = client
        .set_record(&params.record_name, params.value)
        .await
        .map_err(McpError::Client)?;

    debug!("✅ Updated record '{}'", params.record_name);
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Socket path the tests assume does not exist on the machine running them.
    const MISSING_SOCKET: &str = "/tmp/nonexistent.sock";

    /// `list_records` without arguments is rejected as invalid parameters.
    #[tokio::test]
    async fn test_list_records_missing_socket_path() {
        let outcome = list_records(None).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(failure.message().contains("Invalid parameters"));
    }

    /// `list_records` against a missing socket reports a connection failure.
    #[tokio::test]
    async fn test_list_records_invalid_socket() {
        let request = json!({ "socket_path": MISSING_SOCKET });

        let outcome = list_records(Some(request)).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        let connect_failed = failure.message().contains("Failed to connect");
        let file_missing = failure.message().contains("No such file");
        assert!(connect_failed || file_missing);
    }

    /// `get_record` without arguments is rejected as invalid parameters.
    #[tokio::test]
    async fn test_get_record_missing_params() {
        let outcome = get_record(None).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(failure.message().contains("Invalid parameters"));
    }

    /// `get_record` against a missing socket fails.
    #[tokio::test]
    async fn test_get_record_invalid_socket() {
        let request = json!({
            "socket_path": MISSING_SOCKET,
            "record_name": "TestRecord"
        });

        let outcome = get_record(Some(request)).await;
        assert!(outcome.is_err());
    }

    /// `set_record` without arguments is rejected as invalid parameters.
    #[tokio::test]
    async fn test_set_record_missing_params() {
        let outcome = set_record(None).await;
        assert!(outcome.is_err());

        let failure = outcome.unwrap_err();
        assert!(failure.message().contains("Invalid parameters"));
    }

    /// `set_record` against a missing socket fails.
    #[tokio::test]
    async fn test_set_record_invalid_socket() {
        let request = json!({
            "socket_path": MISSING_SOCKET,
            "record_name": "TestRecord",
            "value": {"test": "value"}
        });

        let outcome = set_record(Some(request)).await;
        assert!(outcome.is_err());
    }
}