use rmcp::handler::server::prompt::PromptContext;
use rmcp::handler::server::router::prompt::{PromptRoute, PromptRouter};
use rmcp::model::{
GetPromptResult, Prompt, PromptArgument, PromptMessage, PromptMessageContent, PromptMessageRole,
};
pub fn create_prompt_router<S: Send + Sync + Clone + 'static>() -> PromptRouter<S> {
PromptRouter::new()
.with_route(semantic_search_prompt())
.with_route(find_callers_prompt())
.with_route(find_callees_prompt())
.with_route(trace_path_prompt())
.with_route(explain_symbol_prompt())
.with_route(code_impact_prompt())
.with_route(ask_prompt())
}
/// Route definition for the `semantic_search` prompt.
///
/// Declares a required `query` argument and an optional `path` argument.
/// Rendering is delegated to [`handle_semantic_search`].
fn semantic_search_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let query_arg = PromptArgument {
        name: String::from("query"),
        title: Some(String::from("Search Query")),
        description: Some(String::from(
            "What to search for (e.g., 'authentication functions', 'public classes', 'database handlers')",
        )),
        required: Some(true),
    };
    let path_arg = PromptArgument {
        name: String::from("path"),
        title: Some(String::from("Path Filter")),
        description: Some(String::from(
            "Optional directory to limit search (e.g., 'src/auth')",
        )),
        required: Some(false),
    };

    let prompt = Prompt::new(
        "semantic_search",
        Some("Structural code search - find symbols by name/kind/visibility with 100% precision (not embedding similarity)"),
        Some(vec![query_arg, path_arg]),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_semantic_search(&ctx)) })
    })
}
/// Renders the `semantic_search` prompt.
///
/// Arguments read from the prompt context:
/// * `query` — the search text. Falls back to `"functions"` when it is absent or
///   not a string.
/// * `path` — an optional directory restriction, appended to the first line as
///   ` in path:<dir>`. Surrounding whitespace is trimmed. A blank value is treated
///   as "no path filter", so the prompt never ends in a dangling ` in path:`.
///
/// Returns a single user message. The message explains how to express the request
/// as sqry predicates (`name:`, `kind:`, `visibility:`, `lang:`) or through the
/// `filters` parameter.
fn handle_semantic_search<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    // `Option<&_>` is `Copy`, so both lookups can share this one borrow.
    let args = context.arguments.as_ref();

    let query = args
        .and_then(|m| m.get("query"))
        .and_then(|v| v.as_str())
        .unwrap_or("functions");

    let path_filter = args
        .and_then(|m| m.get("path"))
        .and_then(|v| v.as_str())
        .map(str::trim)
        // An empty or whitespace-only optional argument means "not provided".
        .filter(|p| !p.is_empty())
        .map(|p| format!(" in path:{p}"))
        .unwrap_or_default();

    let message = format!(
        r#"Use the sqry semantic_search or hierarchical_search tool to find code matching: "{query}"{path_filter}
Note: sqry provides deterministic results via AST analysis (not probabilistic embedding similarity).
Same query → same results. You get the COMPLETE list - critical for refactoring, security audits, and impact analysis.
Translate the user's query into sqry predicates in the `query` parameter:
- For symbol names: use `name:` predicate (e.g., `name:login`, `name~=/.*Handler/`)
- For symbol types: use `kind:` predicate (e.g., `kind:function`, `kind:class`, `kind:method`)
- For visibility: use `visibility:` predicate (e.g., `visibility:public`, `visibility:private`)
- For language: use `lang:` predicate (e.g., `lang:rust`, `lang:typescript`)
Example queries:
- "authentication functions" → semantic_search with query="name~=/^auth/ AND kind:function"
- "public classes" → semantic_search with query="visibility:public AND kind:class"
- "all methods in User class" → semantic_search with query="name~=/^User::/ AND kind:method"
Alternatively, use the `filters` parameter for simple structured constraints:
filters={{"language":["rust"],"symbol_kind":["function"]}}
Use `query` for complex boolean expressions with AND/OR/NOT/regex.
Use `filters` for simple pre-filtering by language, kind, or visibility.
Both can be combined:
query="name~=/^auth/" filters={{"language":["typescript"],"visibility":"public"}}
Use hierarchical_search for RAG-optimized results with file/container grouping."#
    );

    GetPromptResult {
        description: Some(format!("Search for code matching: {query}")),
        messages: vec![PromptMessage {
            role: PromptMessageRole::User,
            content: PromptMessageContent::Text { text: message },
        }],
    }
}
/// Route definition for the `find_callers` prompt.
///
/// Declares a single required `symbol` argument. Rendering is delegated to
/// [`handle_find_callers`].
fn find_callers_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let symbol_arg = PromptArgument {
        name: "symbol".to_owned(),
        title: Some("Symbol Name".to_owned()),
        description: Some(
            "The function or method to find callers for (e.g., 'authenticate', 'User::save')"
                .to_owned(),
        ),
        required: Some(true),
    };
    let prompt = Prompt::new(
        "find_callers",
        Some("Find all code that calls a specific function or method"),
        Some(vec![symbol_arg]),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_find_callers(&ctx)) })
    })
}
/// Renders the `find_callers` prompt.
///
/// Reads the `symbol` argument, falling back to `"main"` when it is absent or not
/// a string. Returns one user message that instructs the client to call
/// `relation_query` with `relation_type: "callers"` and `max_depth: 2`.
fn handle_find_callers<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    let args = context.arguments.as_ref();
    let symbol = args
        .and_then(|m| m.get("symbol"))
        .and_then(|value| value.as_str())
        .unwrap_or("main");

    let text = format!(
        r#"Use the sqry relation_query tool to find all callers of "{symbol}".
Call relation_query with:
- symbol: "{symbol}"
- relation_type: "callers"
- max_depth: 2 (increase for transitive callers)
This will show all functions/methods that call {symbol}, helping understand:
- Who depends on this code
- Impact of changing this function
- Call patterns in the codebase"#
    );
    let description = format!("Find all code that calls: {symbol}");

    let user_message = PromptMessage {
        role: PromptMessageRole::User,
        content: PromptMessageContent::Text { text },
    };
    GetPromptResult {
        description: Some(description),
        messages: vec![user_message],
    }
}
/// Route definition for the `find_callees` prompt.
///
/// Declares a single required `symbol` argument. Rendering is delegated to
/// [`handle_find_callees`].
fn find_callees_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let arguments = vec![PromptArgument {
        name: String::from("symbol"),
        title: Some(String::from("Symbol Name")),
        description: Some(String::from(
            "The function to analyze (e.g., 'process_request', 'main')",
        )),
        required: Some(true),
    }];
    let prompt = Prompt::new(
        "find_callees",
        Some("Find all functions/methods that a specific function calls"),
        Some(arguments),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_find_callees(&ctx)) })
    })
}
/// Renders the `find_callees` prompt.
///
/// Reads the `symbol` argument, falling back to `"main"` when it is absent or not
/// a string. Returns one user message that instructs the client to call
/// `relation_query` with `relation_type: "callees"` and `max_depth: 2`.
fn handle_find_callees<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    let symbol = context
        .arguments
        .as_ref()
        .and_then(|m| m.get("symbol"))
        .and_then(|value| value.as_str())
        .unwrap_or("main");

    let text = format!(
        r#"Use the sqry relation_query tool to find all functions called by "{symbol}".
Call relation_query with:
- symbol: "{symbol}"
- relation_type: "callees"
- max_depth: 2 (increase for transitive callees)
This will show all functions/methods that {symbol} calls, helping understand:
- Dependencies of this function
- What subsystems it touches
- Complexity and coupling"#
    );

    let messages = vec![PromptMessage {
        role: PromptMessageRole::User,
        content: PromptMessageContent::Text { text },
    }];
    let description = Some(format!("Find all functions called by: {symbol}"));
    GetPromptResult {
        description,
        messages,
    }
}
/// Route definition for the `trace_path` prompt.
///
/// Declares two required arguments, `from` (the start of the path) and `to`
/// (its end). Rendering is delegated to [`handle_trace_path`].
fn trace_path_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let from_arg = PromptArgument {
        name: "from".to_owned(),
        title: Some("Starting Function".to_owned()),
        description: Some(
            "The function where the path starts (e.g., 'main', 'handle_request')".to_owned(),
        ),
        required: Some(true),
    };
    let to_arg = PromptArgument {
        name: "to".to_owned(),
        title: Some("Target Function".to_owned()),
        description: Some(
            "The function where the path ends (e.g., 'database_query', 'send_email')".to_owned(),
        ),
        required: Some(true),
    };
    let prompt = Prompt::new(
        "trace_path",
        Some("Trace the call path between two functions - how does A eventually call B?"),
        Some(vec![from_arg, to_arg]),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_trace_path(&ctx)) })
    })
}
/// Renders the `trace_path` prompt.
///
/// Reads `from` (default `"main"`) and `to` (default `"target"`). A missing or
/// non-string value falls back to its default. Returns one user message that asks
/// the client to call `trace_path` with `max_hops: 5` and `max_paths: 3`.
fn handle_trace_path<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    let args = context.arguments.as_ref();
    let from = args
        .and_then(|m| m.get("from"))
        .and_then(|value| value.as_str())
        .unwrap_or("main");
    let to = args
        .and_then(|m| m.get("to"))
        .and_then(|value| value.as_str())
        .unwrap_or("target");

    let text = format!(
        r#"Use the sqry trace_path tool to find how "{from}" reaches "{to}".
Call trace_path with:
- from_symbol: "{from}"
- to_symbol: "{to}"
- max_hops: 5 (increase if path might be longer)
- max_paths: 3 (to see alternative routes)
This will show the call chain from {from} to {to}, helping understand:
- How control flows through the codebase
- Critical paths for debugging
- Dependencies between subsystems"#
    );
    let description = format!("Trace call path from {from} to {to}");

    let user_message = PromptMessage {
        role: PromptMessageRole::User,
        content: PromptMessageContent::Text { text },
    };
    GetPromptResult {
        description: Some(description),
        messages: vec![user_message],
    }
}
/// Route definition for the `explain_symbol` prompt.
///
/// Declares two required arguments, `file` (the path containing the symbol) and
/// `symbol` (its name). Rendering is delegated to [`handle_explain_symbol`].
fn explain_symbol_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let arguments = vec![
        PromptArgument {
            name: String::from("file"),
            title: Some(String::from("File Path")),
            description: Some(String::from(
                "Path to the file containing the symbol (e.g., 'src/auth/login.rs')",
            )),
            required: Some(true),
        },
        PromptArgument {
            name: String::from("symbol"),
            title: Some(String::from("Symbol Name")),
            description: Some(String::from(
                "Name of the symbol to explain (e.g., 'authenticate', 'UserService')",
            )),
            required: Some(true),
        },
    ];
    let prompt = Prompt::new(
        "explain_symbol",
        Some("Get detailed explanation of a code symbol including its context and relationships"),
        Some(arguments),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_explain_symbol(&ctx)) })
    })
}
/// Renders the `explain_symbol` prompt.
///
/// Reads `file` (default `"src/main.rs"`) and `symbol` (default `"main"`). A
/// missing or non-string value falls back to its default. Returns one user message
/// that asks the client to call `explain_code` with context and relations enabled.
fn handle_explain_symbol<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    let args = context.arguments.as_ref();
    let file = args
        .and_then(|m| m.get("file"))
        .and_then(|value| value.as_str())
        .unwrap_or("src/main.rs");
    let symbol = args
        .and_then(|m| m.get("symbol"))
        .and_then(|value| value.as_str())
        .unwrap_or("main");

    let text = format!(
        r#"Use the sqry explain_code tool to get detailed information about "{symbol}" in {file}.
Call explain_code with:
- file_path: "{file}"
- symbol_name: "{symbol}"
- include_context: true
- include_relations: true
This will provide:
- Symbol signature and documentation
- Surrounding context code
- Callers and callees relationships
- Import/export information"#
    );

    let messages = vec![PromptMessage {
        role: PromptMessageRole::User,
        content: PromptMessageContent::Text { text },
    }];
    let description = Some(format!("Explain symbol {symbol} in {file}"));
    GetPromptResult {
        description,
        messages,
    }
}
/// Route definition for the `code_impact` prompt.
///
/// Declares a single required `symbol` argument. Rendering is delegated to
/// [`handle_code_impact`].
fn code_impact_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let symbol_arg = PromptArgument {
        name: "symbol".to_owned(),
        title: Some("Symbol Name".to_owned()),
        description: Some(
            "The symbol to analyze impact for (e.g., 'UserService', 'validate_input')".to_owned(),
        ),
        required: Some(true),
    };
    let prompt = Prompt::new(
        "code_impact",
        Some("Analyze what code would be affected if a symbol is changed or removed"),
        Some(vec![symbol_arg]),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_code_impact(&ctx)) })
    })
}
/// Renders the `code_impact` prompt.
///
/// Reads the `symbol` argument, falling back to `"target"` when it is absent or
/// not a string. Returns one user message that asks the client to call
/// `dependency_impact` with `max_depth: 3` and indirect and file results included.
fn handle_code_impact<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    let args = context.arguments.as_ref();
    let symbol = args
        .and_then(|m| m.get("symbol"))
        .and_then(|value| value.as_str())
        .unwrap_or("target");

    let text = format!(
        r#"Use the sqry dependency_impact tool to analyze what would be affected by changing "{symbol}".
Call dependency_impact with:
- symbol: "{symbol}"
- max_depth: 3
- include_indirect: true
- include_files: true
This will show:
- Direct dependents (code that directly uses this symbol)
- Indirect dependents (transitive impact)
- Affected files list
- Risk assessment for the change"#
    );
    let description = format!("Analyze impact of changing: {symbol}");

    let user_message = PromptMessage {
        role: PromptMessageRole::User,
        content: PromptMessageContent::Text { text },
    };
    GetPromptResult {
        description: Some(description),
        messages: vec![user_message],
    }
}
/// Route definition for the `ask` prompt.
///
/// Declares a single required `question` argument holding free-form text.
/// Rendering is delegated to [`handle_ask`].
fn ask_prompt<S: Send + Sync + 'static>() -> PromptRoute<S> {
    let arguments = vec![PromptArgument {
        name: String::from("question"),
        title: Some(String::from("Question")),
        description: Some(String::from(
            "Your question in plain English (e.g., 'who calls the login function?', 'find all database queries')",
        )),
        required: Some(true),
    }];
    let prompt = Prompt::new(
        "ask",
        Some("Ask questions about code in natural language - sqry will translate to the right query"),
        Some(arguments),
    );

    PromptRoute::new_dyn(prompt, |ctx: PromptContext<'_, S>| {
        Box::pin(async move { Ok(handle_ask(&ctx)) })
    })
}
/// Renders the `ask` prompt.
///
/// Reads the `question` argument, falling back to `"find all functions"` when it
/// is absent or not a string. Returns one user message that directs the client to
/// the `sqry_ask` tool.
fn handle_ask<S>(context: &PromptContext<'_, S>) -> GetPromptResult {
    let question = context
        .arguments
        .as_ref()
        .and_then(|m| m.get("question"))
        .and_then(|value| value.as_str())
        .unwrap_or("find all functions");

    let text = format!(
        r#"Use the sqry sqry_ask tool to answer: "{question}"
The sqry_ask tool uses natural language understanding to:
1. Parse your question
2. Identify the appropriate sqry command (search, relation query, trace, etc.)
3. Generate and execute the correct query
4. Return results in a readable format
This is the easiest way to query code - just ask your question naturally!"#
    );

    let messages = vec![PromptMessage {
        role: PromptMessageRole::User,
        content: PromptMessageContent::Text { text },
    }];
    let description = Some(format!("Ask: {question}"));
    GetPromptResult {
        description,
        messages,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every prompt name that `create_prompt_router` is expected to register.
    const EXPECTED_PROMPTS: [&str; 7] = [
        "semantic_search",
        "find_callers",
        "find_callees",
        "trace_path",
        "explain_symbol",
        "code_impact",
        "ask",
    ];

    #[test]
    fn test_prompt_router_creation() {
        let router: PromptRouter<()> = create_prompt_router();
        let prompts = router.list_all();

        // Exact count: the old `>= 6` check would not notice an unexpected extra
        // prompt (and was already below the 7 that are registered).
        assert_eq!(prompts.len(), EXPECTED_PROMPTS.len());

        let names: Vec<&str> = prompts.iter().map(|p| p.name.as_str()).collect();
        for expected in EXPECTED_PROMPTS {
            assert!(
                names.contains(&expected),
                "missing prompt `{expected}`; registered: {names:?}"
            );
        }
    }

    #[test]
    fn test_semantic_search_prompt_has_arguments() {
        let router: PromptRouter<()> = create_prompt_router();
        let prompts = router.list_all();
        let search_prompt = prompts
            .iter()
            .find(|p| p.name == "semantic_search")
            .expect("semantic_search prompt should be registered");
        let args = search_prompt
            .arguments
            .as_ref()
            .expect("semantic_search should declare arguments");

        let query = args
            .iter()
            .find(|a| a.name == "query")
            .expect("semantic_search should declare a `query` argument");
        assert_eq!(query.required, Some(true));

        let path = args
            .iter()
            .find(|a| a.name == "path")
            .expect("semantic_search should declare a `path` argument");
        assert_eq!(path.required, Some(false));
    }
}