use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::tokensave::TokenSave;
use crate::context::format_context_as_markdown;
use crate::errors::{TokenSaveError, Result};
use crate::types::BuildContextOptions;
const MAX_RESPONSE_CHARS: usize = 15_000;
/// Metadata describing one MCP tool: its identifier, a human-readable
/// description, and the JSON Schema for its input arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Unique tool identifier, e.g. "tokensave_search".
    pub name: String,
    /// Human-readable summary of what the tool does.
    pub description: String,
    /// JSON Schema object describing the tool's arguments; renamed to
    /// "inputSchema" on the wire to match the MCP protocol field name.
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,
}
pub fn get_tool_definitions() -> Vec<ToolDefinition> {
vec![
ToolDefinition {
name: "tokensave_search".to_string(),
description: "Search for symbols (functions, structs, traits, etc.) in the code graph by name or keyword.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query string to match against symbol names"
},
"limit": {
"type": "number",
"description": "Maximum number of results to return (default: 10)"
}
},
"required": ["query"]
}),
},
ToolDefinition {
name: "tokensave_context".to_string(),
description: "Build an AI-ready context for a task description. Returns relevant symbols, relationships, and code snippets.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"task": {
"type": "string",
"description": "Natural language description of the task or question"
},
"max_nodes": {
"type": "number",
"description": "Maximum number of symbols to include (default: 20)"
}
},
"required": ["task"]
}),
},
ToolDefinition {
name: "tokensave_callers".to_string(),
description: "Find all callers of a given node (function, method, etc.) up to a specified depth.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"node_id": {
"type": "string",
"description": "The unique node ID to find callers for"
},
"max_depth": {
"type": "number",
"description": "Maximum traversal depth (default: 3)"
}
},
"required": ["node_id"]
}),
},
ToolDefinition {
name: "tokensave_callees".to_string(),
description: "Find all callees of a given node (function, method, etc.) up to a specified depth.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"node_id": {
"type": "string",
"description": "The unique node ID to find callees for"
},
"max_depth": {
"type": "number",
"description": "Maximum traversal depth (default: 3)"
}
},
"required": ["node_id"]
}),
},
ToolDefinition {
name: "tokensave_impact".to_string(),
description: "Compute the impact radius of a node: all symbols that directly or indirectly depend on it.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"node_id": {
"type": "string",
"description": "The unique node ID to compute impact for"
},
"max_depth": {
"type": "number",
"description": "Maximum traversal depth (default: 3)"
}
},
"required": ["node_id"]
}),
},
ToolDefinition {
name: "tokensave_node".to_string(),
description: "Retrieve detailed information about a single node by its ID.".to_string(),
input_schema: json!({
"type": "object",
"properties": {
"node_id": {
"type": "string",
"description": "The unique node ID to retrieve"
}
},
"required": ["node_id"]
}),
},
ToolDefinition {
name: "tokensave_status".to_string(),
description: "Return aggregate statistics about the code graph (node/edge/file counts, DB size, etc.).".to_string(),
input_schema: json!({
"type": "object",
"properties": {}
}),
},
]
}
/// Outcome of a tool invocation: the MCP response payload plus the
/// source files the handler referenced while producing it.
pub struct ToolResult {
    /// MCP "content" response value returned to the client.
    pub value: Value,
    /// File paths referenced by the result (deduplicated by the handlers).
    pub touched_files: Vec<String>,
}
/// Dispatch an MCP tool invocation to the matching handler.
///
/// `server_stats` is only consumed by the status tool; every other handler
/// ignores it. An unrecognized tool name surfaces as a `Config` error.
pub async fn handle_tool_call(
    cg: &TokenSave,
    tool_name: &str,
    args: Value,
    server_stats: Option<Value>,
) -> Result<ToolResult> {
    match tool_name {
        "tokensave_search" => handle_search(cg, args).await,
        "tokensave_context" => handle_context(cg, args).await,
        "tokensave_callers" => handle_callers(cg, args).await,
        "tokensave_callees" => handle_callees(cg, args).await,
        "tokensave_impact" => handle_impact(cg, args).await,
        "tokensave_node" => handle_node(cg, args).await,
        "tokensave_status" => handle_status(cg, server_stats).await,
        unknown => Err(TokenSaveError::Config {
            message: format!("unknown tool: {}", unknown),
        }),
    }
}
/// Deduplicate an iterator of file paths, preserving first-seen order.
fn unique_file_paths<'a>(paths: impl Iterator<Item = &'a str>) -> Vec<String> {
    let mut seen = HashSet::new();
    // `insert` returns true only the first time a path is encountered,
    // so the filter keeps exactly the first occurrence of each path.
    paths
        .filter(|p| seen.insert(*p))
        .map(str::to_string)
        .collect()
}
/// Cap a response at `MAX_RESPONSE_CHARS` bytes, cutting on a UTF-8 char
/// boundary and appending a marker that records where the cut happened.
fn truncate_response(s: &str) -> String {
    if s.len() <= MAX_RESPONSE_CHARS {
        return s.to_string();
    }
    // Largest offset <= the cap that is a valid char boundary; offset 0 is
    // always a boundary, so `find` cannot fail.
    let cut = (0..=MAX_RESPONSE_CHARS)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .expect("offset 0 is always a char boundary");
    format!("{}\n\n[... truncated at {} chars]", &s[..cut], cut)
}
async fn handle_search(cg: &TokenSave, args: Value) -> Result<ToolResult> {
let query =
args.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| TokenSaveError::Config {
message: "missing required parameter: query".to_string(),
})?;
let limit = args
.get("limit")
.and_then(|v| v.as_u64())
.map(|v| v.min(500) as usize)
.unwrap_or(10);
let results = cg.search(query, limit).await?;
let touched_files = unique_file_paths(results.iter().map(|r| r.node.file_path.as_str()));
let items: Vec<Value> = results
.iter()
.map(|r| {
json!({
"id": r.node.id,
"name": r.node.name,
"kind": r.node.kind.as_str(),
"file": r.node.file_path,
"line": r.node.start_line,
"signature": r.node.signature,
"score": r.score,
})
})
.collect();
let output = serde_json::to_string_pretty(&items).unwrap_or_default();
Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": truncate_response(&output) }]
}),
touched_files,
})
}
async fn handle_context(cg: &TokenSave, args: Value) -> Result<ToolResult> {
let task = args
.get("task")
.and_then(|v| v.as_str())
.ok_or_else(|| TokenSaveError::Config {
message: "missing required parameter: task".to_string(),
})?;
let max_nodes = args
.get("max_nodes")
.and_then(|v| v.as_u64())
.map(|v| v.min(100) as usize)
.unwrap_or(20);
let options = BuildContextOptions {
max_nodes,
..Default::default()
};
let context = cg.build_context(task, &options).await?;
let touched_files = unique_file_paths(
context
.subgraph
.nodes
.iter()
.map(|n| n.file_path.as_str())
.chain(context.related_files.iter().map(|s| s.as_str())),
);
let output = format_context_as_markdown(&context);
Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": truncate_response(&output) }]
}),
touched_files,
})
}
async fn handle_callers(cg: &TokenSave, args: Value) -> Result<ToolResult> {
let node_id = args
.get("node_id")
.and_then(|v| v.as_str())
.ok_or_else(|| TokenSaveError::Config {
message: "missing required parameter: node_id".to_string(),
})?;
let max_depth = args
.get("max_depth")
.and_then(|v| v.as_u64())
.map(|v| v.min(10) as usize)
.unwrap_or(3);
let results = cg.get_callers(node_id, max_depth).await?;
let touched_files = unique_file_paths(results.iter().map(|(n, _)| n.file_path.as_str()));
let items: Vec<Value> = results
.iter()
.map(|(node, edge)| {
json!({
"node_id": node.id,
"name": node.name,
"kind": node.kind.as_str(),
"file": node.file_path,
"line": node.start_line,
"edge_kind": edge.kind.as_str(),
})
})
.collect();
let output = serde_json::to_string_pretty(&items).unwrap_or_default();
Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": truncate_response(&output) }]
}),
touched_files,
})
}
async fn handle_callees(cg: &TokenSave, args: Value) -> Result<ToolResult> {
let node_id = args
.get("node_id")
.and_then(|v| v.as_str())
.ok_or_else(|| TokenSaveError::Config {
message: "missing required parameter: node_id".to_string(),
})?;
let max_depth = args
.get("max_depth")
.and_then(|v| v.as_u64())
.map(|v| v.min(10) as usize)
.unwrap_or(3);
let results = cg.get_callees(node_id, max_depth).await?;
let touched_files = unique_file_paths(results.iter().map(|(n, _)| n.file_path.as_str()));
let items: Vec<Value> = results
.iter()
.map(|(node, edge)| {
json!({
"node_id": node.id,
"name": node.name,
"kind": node.kind.as_str(),
"file": node.file_path,
"line": node.start_line,
"edge_kind": edge.kind.as_str(),
})
})
.collect();
let output = serde_json::to_string_pretty(&items).unwrap_or_default();
Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": truncate_response(&output) }]
}),
touched_files,
})
}
async fn handle_impact(cg: &TokenSave, args: Value) -> Result<ToolResult> {
let node_id = args
.get("node_id")
.and_then(|v| v.as_str())
.ok_or_else(|| TokenSaveError::Config {
message: "missing required parameter: node_id".to_string(),
})?;
let max_depth = args
.get("max_depth")
.and_then(|v| v.as_u64())
.map(|v| v.min(10) as usize)
.unwrap_or(3);
let subgraph = cg.get_impact_radius(node_id, max_depth).await?;
let touched_files = unique_file_paths(subgraph.nodes.iter().map(|n| n.file_path.as_str()));
let nodes: Vec<Value> = subgraph
.nodes
.iter()
.map(|n| {
json!({
"id": n.id,
"name": n.name,
"kind": n.kind.as_str(),
"file": n.file_path,
"line": n.start_line,
})
})
.collect();
let output = json!({
"node_count": subgraph.nodes.len(),
"edge_count": subgraph.edges.len(),
"nodes": nodes,
});
let formatted = serde_json::to_string_pretty(&output).unwrap_or_default();
Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": truncate_response(&formatted) }]
}),
touched_files,
})
}
async fn handle_node(cg: &TokenSave, args: Value) -> Result<ToolResult> {
let node_id = args
.get("node_id")
.and_then(|v| v.as_str())
.ok_or_else(|| TokenSaveError::Config {
message: "missing required parameter: node_id".to_string(),
})?;
let node = cg.get_node(node_id).await?;
match node {
Some(n) => {
let touched_files = vec![n.file_path.clone()];
let output = json!({
"id": n.id,
"name": n.name,
"kind": n.kind.as_str(),
"qualified_name": n.qualified_name,
"file": n.file_path,
"start_line": n.start_line,
"end_line": n.end_line,
"signature": n.signature,
"docstring": n.docstring,
"visibility": n.visibility.as_str(),
"is_async": n.is_async,
});
let formatted = serde_json::to_string_pretty(&output).unwrap_or_default();
Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": truncate_response(&formatted) }]
}),
touched_files,
})
}
None => Ok(ToolResult {
value: json!({
"content": [{ "type": "text", "text": format!("Node not found: {}", node_id) }]
}),
touched_files: vec![],
}),
}
}
/// Handle `tokensave_status`: return graph-wide statistics, optionally
/// merged with server-level stats supplied by the transport layer.
async fn handle_status(cg: &TokenSave, server_stats: Option<Value>) -> Result<ToolResult> {
    let stats = cg.get_stats().await?;
    let mut report: Value = serde_json::to_value(&stats).unwrap_or(json!({}));
    // Fold in server-level stats when the caller supplied them.
    if let Some(ss) = server_stats {
        report["server"] = ss;
    }
    let rendered = serde_json::to_string_pretty(&report).unwrap_or_default();
    Ok(ToolResult {
        value: json!({
            "content": [{ "type": "text", "text": truncate_response(&rendered) }]
        }),
        touched_files: vec![],
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All seven tools are advertised.
    #[test]
    fn test_tool_definitions_complete() {
        let tools = get_tool_definitions();
        assert_eq!(tools.len(), 7);
        let tool_names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        for expected in [
            "tokensave_search",
            "tokensave_context",
            "tokensave_callers",
            "tokensave_callees",
            "tokensave_impact",
            "tokensave_node",
            "tokensave_status",
        ] {
            assert!(tool_names.contains(&expected), "missing tool: {}", expected);
        }
    }

    /// Every tool has a non-empty name/description and an object schema.
    #[test]
    fn test_tool_definitions_have_schemas() {
        let tools = get_tool_definitions();
        for tool in &tools {
            assert!(!tool.name.is_empty());
            assert!(!tool.description.is_empty());
            assert!(tool.input_schema.is_object());
            assert_eq!(tool.input_schema["type"], "object");
        }
    }

    /// Inputs at or under the cap pass through unchanged.
    #[test]
    fn test_truncate_short_response() {
        let short = "hello world";
        assert_eq!(truncate_response(short), short);
    }

    /// Long ASCII input keeps exactly the first MAX_RESPONSE_CHARS bytes
    /// and appends the truncation marker (previously only length and the
    /// marker substring were checked, not the preserved content).
    #[test]
    fn test_truncate_long_response() {
        let long = "x".repeat(20_000);
        let result = truncate_response(&long);
        assert!(result.starts_with("x".repeat(MAX_RESPONSE_CHARS).as_str()));
        assert!(result.ends_with("[... truncated at 15000 chars]"));
        assert!(result.len() < long.len());
        assert!(result.contains("[... truncated at 15000 chars]"));
    }

    /// Truncation must back off to a valid UTF-8 boundary instead of
    /// slicing through a multi-byte character.
    #[test]
    fn test_truncate_multibyte_boundary() {
        // 'a' (1 byte) followed by U+1D11E (4 bytes each) puts boundaries
        // at 1 + 4k; the largest boundary <= 15_000 is 14_997, so the cut
        // must back off from the raw cap without panicking.
        let long = format!("a{}", "\u{1D11E}".repeat(5_000)); // 20_001 bytes
        let result = truncate_response(&long);
        assert!(result.ends_with("[... truncated at 14997 chars]"));
        assert!(result.len() < long.len());
    }

    /// Definitions serialize cleanly through serde.
    #[test]
    fn test_tool_definitions_serializable() {
        let tools = get_tool_definitions();
        let json = serde_json::to_string(&tools).unwrap();
        assert!(json.contains("tokensave_search"));
        assert!(json.contains("tokensave_status"));
    }
}