use crate::chunking::{ChunkingStrategy, FileChunker};
use crate::error::Result;
use crate::models::{SearchParams, SearchStrategy};
use crate::storage::Storage;
use serde_json::{json, Value};
use std::path::Path;
use std::sync::Arc;
use uuid::Uuid;
pub struct MCPHandlers {
storage: Arc<Storage>,
}
impl MCPHandlers {
pub fn new(storage: Arc<Storage>) -> Self {
Self { storage }
}
pub async fn handle_tool_call(&self, tool_name: &str, params: Value) -> Result<Value> {
match tool_name {
"store_memory" => self.handle_store_memory(params).await,
"get_memory" => self.handle_get_memory(params).await,
"delete_memory" => self.handle_delete_memory(params).await,
"get_statistics" => self.handle_get_statistics().await,
"store_file" => self.handle_store_file(params).await,
"search_memory" => self.handle_search_memory(params).await,
_ => Err(crate::error::Error::MethodNotFound(format!(
"Unknown tool: {}",
tool_name
))),
}
}
async fn handle_store_memory(&self, params: Value) -> Result<Value> {
let content = params["content"]
.as_str()
.ok_or_else(|| crate::error::Error::InvalidParams("Missing content parameter".to_string()))?;
if content.len() > 1024 * 1024 {
return Err(crate::error::Error::InvalidParams(format!(
"Content size {} bytes exceeds maximum limit of 1MB (1048576 bytes)",
content.len()
)));
}
let context = params["context"]
.as_str()
.ok_or_else(|| {
crate::error::Error::InvalidParams("Missing required context parameter".to_string())
})?
.to_string();
if context.len() > 1000 {
return Err(crate::error::Error::InvalidParams(format!(
"Context length {} characters exceeds maximum limit of 1000 characters",
context.len()
)));
}
let summary = params["summary"]
.as_str()
.ok_or_else(|| {
crate::error::Error::InvalidParams("Missing required summary parameter".to_string())
})?
.to_string();
if summary.len() > 500 {
return Err(crate::error::Error::InvalidParams(format!(
"Summary length {} characters exceeds maximum limit of 500 characters",
summary.len()
)));
}
let tags = params["tags"]
.as_array()
.ok_or_else(|| {
crate::error::Error::InvalidParams("Missing required tags parameter".to_string())
})?
.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect::<Vec<_>>();
if tags.len() > 50 {
return Err(crate::error::Error::InvalidParams(format!(
"Tags count {} exceeds maximum limit of 50 tags",
tags.len()
)));
}
let id = self
.storage
.store(content, context, summary, Some(tags))
.await?;
Ok(json!({
"id": id.to_string(),
"message": "Memory stored successfully"
}))
}
async fn handle_get_memory(&self, params: Value) -> Result<Value> {
let id_str = params["id"]
.as_str()
.ok_or_else(|| crate::error::Error::InvalidParams("Missing id parameter".to_string()))?;
let id = Uuid::parse_str(id_str)
.map_err(|e| crate::error::Error::InvalidParams(format!("Invalid UUID: {}", e)))?;
match self.storage.get(id).await? {
Some(memory) => Ok(serde_json::to_value(memory)?),
None => Err(crate::error::Error::InvalidParams(format!(
"Memory not found: {}",
id
))),
}
}
async fn handle_delete_memory(&self, params: Value) -> Result<Value> {
let id_str = params["id"]
.as_str()
.ok_or_else(|| crate::error::Error::InvalidParams("Missing id parameter".to_string()))?;
let id = Uuid::parse_str(id_str)
.map_err(|e| crate::error::Error::InvalidParams(format!("Invalid UUID: {}", e)))?;
let deleted = self.storage.delete(id).await?;
Ok(json!({
"deleted": deleted,
"message": if deleted { "Memory deleted successfully" } else { "Memory not found" }
}))
}
async fn handle_get_statistics(&self) -> Result<Value> {
let stats = self.storage.stats().await?;
Ok(serde_json::to_value(stats)?)
}
async fn handle_store_file(&self, params: Value) -> Result<Value> {
let file_path = params["file_path"]
.as_str()
.ok_or_else(|| crate::error::Error::InvalidParams("Missing file_path parameter".to_string()))?;
if tokio::fs::metadata(file_path).await.is_err() {
return Err(crate::error::Error::InvalidParams(format!(
"File not found or not readable: {}",
file_path
)));
}
let chunk_size = params
.get("chunk_size")
.and_then(|v| v.as_u64())
.unwrap_or(8000) as usize;
if chunk_size < 1024 || chunk_size > 102400 {
return Err(crate::error::Error::InvalidParams(format!(
"Chunk size {} must be between 1024 and 102400 characters",
chunk_size
)));
}
let overlap = params
.get("overlap")
.and_then(|v| v.as_u64())
.unwrap_or(200) as usize;
if overlap >= chunk_size / 2 {
return Err(crate::error::Error::InvalidParams(format!(
"Overlap size {} must be less than half of chunk size ({})",
overlap,
chunk_size / 2
)));
}
let chunking_strategy: ChunkingStrategy = params
.get("chunking_strategy")
.and_then(|v| v.as_str())
.and_then(|s| s.parse().ok())
.unwrap_or_default();
let tags = params.get("tags").and_then(|v| v.as_array()).map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect::<Vec<_>>()
});
const MAX_FILE_SIZE: u64 = 50 * 1024 * 1024; let file_metadata = tokio::fs::metadata(file_path)
.await
.map_err(|e| crate::error::Error::InternalError(format!("Failed to get file metadata: {}", e)))?;
if file_metadata.len() > MAX_FILE_SIZE {
return Err(crate::error::Error::InvalidParams(format!(
"File size {} bytes exceeds maximum limit of 50MB ({})",
file_metadata.len(),
MAX_FILE_SIZE
)));
}
let content = if file_metadata.len() > 1024 * 1024 {
self.read_file_streaming(file_path).await?
} else {
tokio::fs::read_to_string(file_path)
.await
.map_err(|e| crate::error::Error::InternalError(format!("Failed to read file: {}", e)))?
};
let filename = Path::new(file_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown");
let content_len = content.len();
let mut stored_ids = Vec::new();
let chunker = FileChunker::with_strategy(chunk_size, overlap, chunking_strategy.clone());
let chunks = chunker.chunk_content(&content)?;
if chunks.len() == 1 {
let context = format!("Content from file: {}", filename);
let summary = format!(
"Complete content of {} ({} characters)",
filename, content_len
);
let id = self
.storage
.store(&content, context, summary, tags.clone())
.await?;
stored_ids.push(id.to_string());
} else {
let parent_id = Uuid::new_v4();
let total_chunks = chunks.len();
for (index, chunk) in chunks.into_iter().enumerate() {
let chunk_num = index + 1;
let context = format!(
"Chunk {} of {} from file: {}",
chunk_num, total_chunks, filename
);
let summary = format!(
"Part {} of {} from {} (bytes {}-{} of {})",
chunk_num,
total_chunks,
filename,
chunk.start_byte,
chunk.end_byte,
content_len
);
let mut chunk_tags = tags.clone().unwrap_or_default();
chunk_tags.push(format!("chunk_{}", chunk_num));
chunk_tags.push(format!("file_{}", filename));
chunk_tags.push(format!("strategy_{:?}", chunking_strategy).to_lowercase());
let id = self
.storage
.store_chunk(
&chunk.content,
context,
summary,
Some(chunk_tags),
chunk_num as i32,
total_chunks as i32,
parent_id,
)
.await?;
stored_ids.push(id.to_string());
}
}
Ok(json!({
"file_path": file_path,
"file_size": content_len,
"chunks_created": stored_ids.len(),
"chunk_ids": stored_ids,
"chunking_strategy": format!("{:?}", chunking_strategy),
"chunk_size": chunk_size,
"overlap": overlap,
"message": format!("Successfully ingested {} as {} chunk(s) using {:?} strategy", filename, stored_ids.len(), chunking_strategy)
}))
}
async fn handle_search_memory(&self, params: Value) -> Result<Value> {
let query = params["query"]
.as_str()
.ok_or_else(|| crate::error::Error::InvalidParams("Missing query parameter".to_string()))?
.to_string();
let tag_filter = params
.get("tag_filter")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect::<Vec<_>>()
});
let use_tag_embedding = params
.get("use_tag_embedding")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let use_content_embedding = params
.get("use_content_embedding")
.and_then(|v| v.as_bool())
.unwrap_or(true);
let similarity_threshold = params
.get("similarity_threshold")
.and_then(|v| v.as_f64())
.unwrap_or(0.7)
.clamp(0.0, 1.0);
let max_results = params
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(10)
.clamp(1, 100) as usize;
let search_strategy = params
.get("search_strategy")
.and_then(|v| v.as_str())
.map(|s| match s {
"tags_first" => SearchStrategy::TagsFirst,
"content_first" => SearchStrategy::ContentFirst,
_ => SearchStrategy::Hybrid,
})
.unwrap_or(SearchStrategy::Hybrid);
let boost_recent = params
.get("boost_recent")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let tag_weight = params
.get("tag_weight")
.and_then(|v| v.as_f64())
.unwrap_or(0.4)
.clamp(0.0, 1.0);
let content_weight = params
.get("content_weight")
.and_then(|v| v.as_f64())
.unwrap_or(0.6)
.clamp(0.0, 1.0);
let search_params = SearchParams {
query: query.clone(),
tag_filter: tag_filter.clone(),
use_tag_embedding,
use_content_embedding,
similarity_threshold,
max_results,
search_strategy: search_strategy.clone(),
boost_recent,
tag_weight,
content_weight,
};
let search_start = std::time::Instant::now();
let search_result_with_metadata = self
.storage
.search_memories_progressive_with_metadata(search_params.clone())
.await?;
let _search_duration = search_start.elapsed();
let formatted_results: Vec<Value> = search_result_with_metadata
.results
.iter()
.map(|result| {
json!({
"id": result.memory.id,
"content": result.memory.content,
"context": result.memory.context,
"summary": result.memory.summary,
"tags": result.memory.tags,
"chunk_index": result.memory.chunk_index,
"total_chunks": result.memory.total_chunks,
"parent_id": result.memory.parent_id,
"created_at": result.memory.created_at,
"updated_at": result.memory.updated_at,
"tag_similarity": result.tag_similarity,
"content_similarity": result.content_similarity,
"combined_score": result.combined_score,
"semantic_cluster": result.semantic_cluster
})
})
.collect();
if cfg!(test) {
let result_count = formatted_results.len();
Ok(json!({
"results": formatted_results,
"search_metadata": {
"query": query.clone(),
"total_results": result_count,
"similarity_threshold": similarity_threshold,
"max_results": max_results,
"search_strategy": format!("{:?}", search_strategy).to_lowercase(),
"boost_recent": boost_recent,
"tag_weight": tag_weight,
"content_weight": content_weight,
"use_tag_embedding": use_tag_embedding,
"use_content_embedding": use_content_embedding,
"tag_filter": tag_filter.clone(),
"search_time_ms": 0, "progressive_search": {},
"average_score": 0.0 }
}))
} else {
Ok(json!(formatted_results))
}
}
async fn read_file_streaming(&self, file_path: &str) -> Result<String> {
use tokio::io::{AsyncReadExt, BufReader};
const STREAM_BUFFER_SIZE: usize = 8192; const MAX_CONTENT_SIZE: usize = 50 * 1024 * 1024;
let file = tokio::fs::File::open(file_path)
.await
.map_err(|e| crate::error::Error::InternalError(format!("Failed to open file: {}", e)))?;
let mut reader = BufReader::with_capacity(STREAM_BUFFER_SIZE, file);
let mut content = String::new();
let mut buffer = vec![0u8; STREAM_BUFFER_SIZE];
let mut total_read = 0;
loop {
let bytes_read = reader
.read(&mut buffer)
.await
.map_err(|e| crate::error::Error::InternalError(format!("Failed to read file chunk: {}", e)))?;
if bytes_read == 0 {
break; }
total_read += bytes_read;
if total_read > MAX_CONTENT_SIZE {
return Err(crate::error::Error::InvalidParams(format!(
"File content exceeds maximum size limit of {} bytes during streaming",
MAX_CONTENT_SIZE
)));
}
let chunk_str = std::str::from_utf8(&buffer[..bytes_read])
.map_err(|e| crate::error::Error::InternalError(format!("Invalid UTF-8 in file: {}", e)))?;
content.push_str(chunk_str);
}
Ok(content)
}
}