use crate::error::{Error, Result};
#[cfg(feature = "memory")]
use crate::mcp::tools::{Tool, ToolHandler, ToolResult};
#[cfg(feature = "memory")]
use crate::retrieval::HybridRetriever;
#[cfg(feature = "memory")]
use async_trait::async_trait;
#[cfg(feature = "memory")]
use serde_json::json;
use serde_json::Value;
use std::collections::HashMap;
#[cfg(feature = "memory")]
use std::sync::Arc;
#[cfg(feature = "memory")]
use tracing::{debug, error, info, instrument};
#[cfg(feature = "memory")]
use uuid::Uuid;
#[cfg(feature = "memory")]
#[derive(Clone)]
/// MCP tool handler exposing docset operations (chunk retrieval, neighbor
/// context, hybrid search, ingestion, listing) backed by a `HybridRetriever`.
///
/// `Clone` is cheap: the handler only holds an `Arc`.
pub struct DocsetHandler {
    // Shared retriever; `Arc` lets one handler instance back all five tools.
    retriever: Arc<HybridRetriever>,
}
#[cfg(feature = "memory")]
impl DocsetHandler {
    /// Create a handler that serves docset tools from the given retriever.
    pub fn new(retriever: Arc<HybridRetriever>) -> Self {
        DocsetHandler { retriever }
    }
}
#[cfg(feature = "memory")]
impl DocsetHandler {
    /// Return the MCP `Tool` definitions for all five docset tools, in the
    /// same order `call_tool` dispatches them.
    pub fn tool_definitions() -> Vec<Tool> {
        vec![
            Self::get_chunk_tool(),
            Self::get_neighbors_tool(),
            Self::search_tool(),
            Self::ingest_tool(),
            Self::list_tool(),
        ]
    }

    /// Schema for `rkmem_docs_get_chunk`: both `doc_id` and `chunk_id`
    /// (UUID strings) are required.
    fn get_chunk_tool() -> Tool {
        Tool::with_schema(
            "rkmem_docs_get_chunk",
            "Retrieve a specific document chunk by ID. Use this when you need the full content \
            of a specific chunk identified in search results or references.",
            json!({
                "type": "object",
                "properties": {
                    "doc_id": {
                        "type": "string",
                        "description": "Document UUID to retrieve chunk from"
                    },
                    "chunk_id": {
                        "type": "string",
                        "description": "Specific chunk UUID within the document"
                    }
                },
                "required": ["doc_id", "chunk_id"],
                "additionalProperties": false
            }),
        )
    }

    /// Schema for `rkmem_docs_get_neighbors`: requires `doc_id` and
    /// `chunk_id`; optional `count` (1..=10, default 3) bounds the window.
    fn get_neighbors_tool() -> Tool {
        Tool::with_schema(
            "rkmem_docs_get_neighbors",
            "Get neighboring chunks for context around a specific chunk. Use this to get \
            surrounding context when analyzing a specific chunk in detail.",
            json!({
                "type": "object",
                "properties": {
                    "doc_id": {
                        "type": "string",
                        "description": "Document UUID to retrieve neighbors from"
                    },
                    "chunk_id": {
                        "type": "string",
                        "description": "Reference chunk UUID to get neighbors for"
                    },
                    "count": {
                        "type": "integer",
                        "description": "Number of neighbors to retrieve (default: 3)",
                        "minimum": 1,
                        "maximum": 10,
                        "default": 3
                    }
                },
                "required": ["doc_id", "chunk_id"],
                "additionalProperties": false
            }),
        )
    }

    /// Schema for `rkmem_docs_search`: requires `query`; optional `top_k`
    /// (1..=50, default 10) and `min_score` (0.0..=1.0, default 0.0).
    fn search_tool() -> Tool {
        Tool::with_schema(
            "rkmem_docs_search",
            "Search documents using hybrid retrieval (semantic + keyword). Use this to find \
            relevant documents and chunks for a query across the entire docset.",
            json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query to find relevant documents"
                    },
                    "top_k": {
                        "type": "integer",
                        "description": "Maximum number of results to return (default: 10)",
                        "minimum": 1,
                        "maximum": 50,
                        "default": 10
                    },
                    "min_score": {
                        "type": "number",
                        "description": "Minimum relevance score threshold (default: 0.0)",
                        "minimum": 0.0,
                        "maximum": 1.0,
                        "default": 0.0
                    }
                },
                "required": ["query"],
                "additionalProperties": false
            }),
        )
    }

    /// Schema for `rkmem_docs_ingest`: requires `content` and `title`;
    /// optional `source_url` and `tags`.
    fn ingest_tool() -> Tool {
        Tool::with_schema(
            "rkmem_docs_ingest",
            "Ingest new documents into the docset. Use this to add new content for retrieval \
            and analysis. Supports various document formats and sources.",
            json!({
                "type": "object",
                "properties": {
                    "content": {
                        "type": "string",
                        "description": "Document content to ingest"
                    },
                    "title": {
                        "type": "string",
                        "description": "Document title"
                    },
                    "source_url": {
                        "type": "string",
                        "description": "Source URL (optional)"
                    },
                    "tags": {
                        "type": "array",
                        "items": { "type": "string" },
                        "description": "Tags for categorization (optional)"
                    }
                },
                "required": ["content", "title"],
                "additionalProperties": false
            }),
        )
    }

    /// Schema for `rkmem_docs_list`: all arguments optional — `limit`
    /// (1..=100, default 20) and a `tag_filter` string.
    fn list_tool() -> Tool {
        Tool::with_schema(
            "rkmem_docs_list",
            "List available documents in the docset. Use this to see what content is available \
            for search and retrieval.",
            json!({
                "type": "object",
                "properties": {
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of documents to list (default: 20)",
                        "minimum": 1,
                        "maximum": 100,
                        "default": 20
                    },
                    "tag_filter": {
                        "type": "string",
                        "description": "Filter by tag (optional)"
                    }
                },
                "additionalProperties": false
            }),
        )
    }
}
#[cfg(feature = "memory")]
impl DocsetHandler {
    /// Build a read-only `AccessContext` for storage calls issued on behalf
    /// of the MCP server; `operation` labels the call for access auditing.
    fn read_context(operation: &str) -> reasonkit_mem::storage::AccessContext {
        reasonkit_mem::storage::AccessContext::new(
            "mcp".to_string(),
            reasonkit_mem::storage::AccessLevel::Read,
            operation.to_string(),
        )
    }

    /// Parse a UUID tool argument. On failure, return the user-facing error
    /// `ToolResult`; `label` names the field in the message ("document ID",
    /// "chunk ID") so the wording matches the previous inline checks.
    fn parse_uuid(value: &str, label: &str) -> std::result::Result<Uuid, ToolResult> {
        Uuid::parse_str(value)
            .map_err(|_| ToolResult::error(format!("Invalid {} format", label)))
    }

    /// Dispatch one docset tool invocation by name.
    ///
    /// Unknown names yield an error `ToolResult` (not an `Err`) so the MCP
    /// client still receives a well-formed tool response.
    #[instrument(skip(self, arguments), fields(tool = %name))]
    pub async fn call_tool(
        &self,
        name: &str,
        arguments: HashMap<String, Value>,
    ) -> Result<ToolResult> {
        info!(tool = %name, "Executing docset tool");
        match name {
            "rkmem_docs_get_chunk" => self.handle_get_chunk(arguments).await,
            "rkmem_docs_get_neighbors" => self.handle_get_neighbors(arguments).await,
            "rkmem_docs_search" => self.handle_search(arguments).await,
            "rkmem_docs_ingest" => self.handle_ingest(arguments).await,
            "rkmem_docs_list" => self.handle_list(arguments).await,
            _ => {
                error!(tool = %name, "Unknown tool requested");
                Ok(ToolResult::error(format!("Unknown docset tool: {}", name)))
            }
        }
    }

    /// `rkmem_docs_get_chunk`: return the full text and metadata of one
    /// chunk identified by `doc_id` + `chunk_id`.
    async fn handle_get_chunk(&self, args: HashMap<String, Value>) -> Result<ToolResult> {
        let doc_id_str = match extract_required_string(&args, "doc_id") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        let chunk_id_str = match extract_required_string(&args, "chunk_id") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        debug!(doc_id = %doc_id_str, chunk_id = %chunk_id_str, "Retrieving chunk");
        let doc_id = match Self::parse_uuid(&doc_id_str, "document ID") {
            Ok(value) => value,
            Err(result) => return Ok(result),
        };
        let chunk_id = match Self::parse_uuid(&chunk_id_str, "chunk ID") {
            Ok(value) => value,
            Err(result) => return Ok(result),
        };
        let context = Self::read_context("get_chunk");
        match self
            .retriever
            .storage()
            .get_document(&doc_id, &context)
            .await
        {
            Ok(Some(document)) => {
                if let Some(chunk) = document.chunks.iter().find(|c| c.id == chunk_id) {
                    let result = json!({
                        "doc_id": doc_id.to_string(),
                        "chunk_id": chunk_id.to_string(),
                        "text": chunk.text,
                        "section": chunk.section,
                        "start_char": chunk.start_char,
                        "end_char": chunk.end_char,
                        "token_count": chunk.token_count,
                        "page": chunk.page,
                        "embedding_ids": chunk.embedding_ids,
                        "document_title": document.metadata.title,
                        "document_url": document.source.url,
                    });
                    Ok(ToolResult::text(serde_json::to_string_pretty(&result)?))
                } else {
                    Ok(ToolResult::error(format!(
                        "Chunk {} not found in document {}",
                        chunk_id, doc_id
                    )))
                }
            }
            Ok(None) => Ok(ToolResult::error(format!("Document {} not found", doc_id))),
            Err(e) => Ok(ToolResult::error(format!(
                "Failed to retrieve document {}: {}",
                doc_id, e
            ))),
        }
    }

    /// `rkmem_docs_get_neighbors`: return a window of chunks around the
    /// reference chunk. The window is `count / 2` chunks on each side and
    /// includes the reference chunk itself, clamped to document bounds.
    async fn handle_get_neighbors(&self, args: HashMap<String, Value>) -> Result<ToolResult> {
        let doc_id_str = match extract_required_string(&args, "doc_id") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        let chunk_id_str = match extract_required_string(&args, "chunk_id") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        let count = args
            .get("count")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(3);
        debug!(doc_id = %doc_id_str, chunk_id = %chunk_id_str, count = count, "Retrieving neighbors");
        let doc_id = match Self::parse_uuid(&doc_id_str, "document ID") {
            Ok(value) => value,
            Err(result) => return Ok(result),
        };
        let chunk_id = match Self::parse_uuid(&chunk_id_str, "chunk ID") {
            Ok(value) => value,
            Err(result) => return Ok(result),
        };
        let context = Self::read_context("get_neighbors");
        match self
            .retriever
            .storage()
            .get_document(&doc_id, &context)
            .await
        {
            Ok(Some(document)) => {
                if let Some(chunk_index) = document.chunks.iter().position(|c| c.id == chunk_id) {
                    // Symmetric window around the reference chunk, clamped to
                    // the document's chunk range (saturating on the left).
                    let start = chunk_index.saturating_sub(count / 2);
                    let end = (chunk_index + count / 2 + 1).min(document.chunks.len());
                    let neighbors: Vec<_> = document.chunks[start..end]
                        .iter()
                        .map(|chunk| {
                            json!({
                                "chunk_id": chunk.id.to_string(),
                                "text": chunk.text,
                                "section": chunk.section,
                                "index": chunk.index,
                                "start_char": chunk.start_char,
                                "end_char": chunk.end_char,
                            })
                        })
                        .collect();
                    let result = json!({
                        "doc_id": doc_id.to_string(),
                        "reference_chunk_id": chunk_id.to_string(),
                        "neighbors": neighbors,
                        "total_neighbors": neighbors.len(),
                        "document_title": document.metadata.title,
                    });
                    Ok(ToolResult::text(serde_json::to_string_pretty(&result)?))
                } else {
                    Ok(ToolResult::error(format!(
                        "Chunk {} not found in document {}",
                        chunk_id, doc_id
                    )))
                }
            }
            Ok(None) => Ok(ToolResult::error(format!("Document {} not found", doc_id))),
            Err(e) => Ok(ToolResult::error(format!(
                "Failed to retrieve document {}: {}",
                doc_id, e
            ))),
        }
    }

    /// `rkmem_docs_search`: hybrid (semantic + keyword) search over the
    /// docset, returning up to `top_k` results with score >= `min_score`.
    async fn handle_search(&self, args: HashMap<String, Value>) -> Result<ToolResult> {
        let query = match extract_required_string(&args, "query") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        let top_k = args
            .get("top_k")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(10);
        let min_score = args
            .get("min_score")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.0) as f32;
        debug!(query = %query, top_k = top_k, min_score = min_score, "Performing search");
        match self.retriever.search(&query, top_k).await {
            Ok(results) => {
                let formatted_results: Vec<Value> = results
                    .into_iter()
                    // BUG FIX: `min_score` was accepted (and advertised in
                    // the tool schema) but never applied; enforce it here.
                    // The default of 0.0 keeps prior behavior for callers
                    // that omit the argument.
                    .filter(|result| result.score >= min_score)
                    .map(|result| {
                        json!({
                            "score": result.score,
                            "doc_id": result.doc_id.to_string(),
                            "chunk_id": result.chunk_id.to_string(),
                            "text_preview": result.text.chars().take(200).collect::<String>(),
                            // Section is not tracked on search results yet.
                            "section": "",
                            "match_source": format!("{:?}", result.match_source),
                        })
                    })
                    .collect();
                let result = json!({
                    "query": query,
                    "results": formatted_results,
                    "total_results": formatted_results.len(),
                });
                Ok(ToolResult::text(serde_json::to_string_pretty(&result)?))
            }
            Err(e) => Ok(ToolResult::error(format!("Search failed: {}", e))),
        }
    }

    /// `rkmem_docs_ingest`: build a single-chunk `Document` from raw text
    /// and add it to the retriever's index.
    ///
    /// The content is stored as one chunk covering the full text; finer
    /// chunking is presumably handled by the main ingestion pipeline —
    /// TODO(review): confirm.
    async fn handle_ingest(&self, args: HashMap<String, Value>) -> Result<ToolResult> {
        let content = match extract_required_string(&args, "content") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        let title = match extract_required_string(&args, "title") {
            Ok(value) => value,
            Err(e) => return Ok(ToolResult::error(e.to_string())),
        };
        let source_url = args
            .get("source_url")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        // Non-string entries in `tags` are silently skipped.
        let tags: Vec<String> = args
            .get("tags")
            .and_then(|v| v.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default();
        debug!(title = %title, "Ingesting new document");
        let source = reasonkit_mem::Source {
            source_type: reasonkit_mem::SourceType::Api,
            url: source_url,
            path: None,
            arxiv_id: None,
            github_repo: None,
            retrieved_at: chrono::Utc::now(),
            version: None,
        };
        let mut document =
            reasonkit_mem::Document::new(reasonkit_mem::DocumentType::Documentation, source);
        document.metadata.title = Some(title);
        document.metadata.tags = tags;
        document.content = reasonkit_mem::DocumentContent {
            raw: content.clone(),
            format: reasonkit_mem::ContentFormat::Text,
            // NOTE(review): language is hard-coded to English — confirm this
            // is acceptable if non-English documents are expected.
            language: "en".to_string(),
            word_count: content.split_whitespace().count(),
            char_count: content.len(),
        };
        document.chunks = vec![reasonkit_mem::Chunk {
            id: uuid::Uuid::new_v4(),
            text: content.clone(),
            index: 0,
            start_char: 0,
            end_char: content.len(),
            token_count: None,
            section: None,
            page: None,
            embedding_ids: reasonkit_mem::EmbeddingIds::default(),
        }];
        match self.retriever.add_document(&document).await {
            Ok(_) => {
                let result = json!({
                    "status": "success",
                    "doc_id": document.id.to_string(),
                    "title": document.metadata.title,
                    "chunk_count": document.chunks.len(),
                    "word_count": document.content.word_count,
                });
                Ok(ToolResult::text(serde_json::to_string_pretty(&result)?))
            }
            Err(e) => Ok(ToolResult::error(format!("Ingestion failed: {}", e))),
        }
    }

    /// `rkmem_docs_list`: placeholder — listing is not implemented yet; the
    /// response echoes the parsed arguments with `status: "not_implemented"`
    /// rather than failing the tool call.
    async fn handle_list(&self, args: HashMap<String, Value>) -> Result<ToolResult> {
        let limit = args
            .get("limit")
            .and_then(|v| v.as_u64())
            .map(|v| v as usize)
            .unwrap_or(20);
        let tag_filter = args
            .get("tag_filter")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        debug!(limit = limit, tag_filter = ?tag_filter, "Listing documents");
        let result = json!({
            "status": "not_implemented",
            "message": "Document listing not yet implemented in this MCP tool",
            "limit": limit,
            "tag_filter": tag_filter,
        });
        Ok(ToolResult::text(serde_json::to_string_pretty(&result)?))
    }
}
// FIX: `#[cfg]` now precedes `#[async_trait]`, matching every other gated
// item in this file and guaranteeing the feature gate is evaluated before
// the proc-macro attribute is expanded.
#[cfg(feature = "memory")]
#[async_trait]
impl ToolHandler for DocsetHandler {
    /// Generic MCP entry point. The concrete tool is selected via the
    /// `_tool` key expected in `arguments`, then dispatched to `call_tool`.
    /// A missing `_tool` key is a protocol error (`Error::Mcp`).
    async fn call(&self, arguments: HashMap<String, Value>) -> Result<ToolResult> {
        let tool_name = arguments
            .get("_tool")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string())
            .ok_or_else(|| Error::Mcp("Missing _tool identifier in arguments".into()))?;
        self.call_tool(&tool_name, arguments).await
    }
}
#[allow(dead_code)]
fn extract_required_string(args: &HashMap<String, Value>, key: &str) -> Result<String> {
args.get(key)
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| Error::Mcp(format!("Missing required argument: {}", key)))
}
#[cfg(all(feature = "memory", feature = "mcp-server-pro"))]
/// Register all five docset tools on `server`, sharing one handler instance
/// (behind an `Arc`) across every registration.
pub async fn register_docset_tools<T: crate::mcp::McpServerTrait + ?Sized>(
    server: &T,
    retriever: std::sync::Arc<HybridRetriever>,
) {
    let handler = Arc::new(DocsetHandler::new(retriever));
    let tools = DocsetHandler::tool_definitions();
    for tool in tools {
        server.register_tool(tool, Arc::clone(&handler)).await;
    }
    tracing::info!("Registered 5 docset tools: get_chunk, get_neighbors, search, ingest, list");
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Tool-definition tests are gated because `DocsetHandler` only exists
    // with the `memory` feature enabled.
    #[test]
    #[cfg(feature = "memory")]
    fn test_tool_definitions_count() {
        let tools = DocsetHandler::tool_definitions();
        assert_eq!(tools.len(), 5, "Should have 5 docset tool definitions");
    }

    #[test]
    #[cfg(feature = "memory")]
    fn test_tool_definitions_names() {
        let tools = DocsetHandler::tool_definitions();
        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"rkmem_docs_get_chunk"));
        assert!(names.contains(&"rkmem_docs_get_neighbors"));
        assert!(names.contains(&"rkmem_docs_search"));
        assert!(names.contains(&"rkmem_docs_ingest"));
        assert!(names.contains(&"rkmem_docs_list"));
    }

    #[test]
    #[cfg(feature = "memory")]
    fn test_tool_definitions_have_descriptions() {
        let tools = DocsetHandler::tool_definitions();
        for tool in tools {
            assert!(
                tool.description.is_some(),
                "Tool {} should have description",
                tool.name
            );
        }
    }

    #[test]
    fn test_extract_required_string_success() {
        let mut args = HashMap::new();
        args.insert("doc_id".to_string(), json!("test-id"));
        let result = extract_required_string(&args, "doc_id");
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test-id");
    }

    #[test]
    fn test_extract_required_string_missing() {
        let args = HashMap::new();
        let result = extract_required_string(&args, "doc_id");
        assert!(result.is_err());
    }

    // Replaces the former vacuous `assert!(true)` placeholder test
    // (`test_handler_creation`): non-string JSON values must be rejected,
    // not coerced to strings.
    #[test]
    fn test_extract_required_string_wrong_type() {
        let mut args = HashMap::new();
        args.insert("doc_id".to_string(), json!(42));
        assert!(extract_required_string(&args, "doc_id").is_err());
    }
}