Skip to main content

research_master/mcp/
tools.rs

1//! Tool registry for MCP tools.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use serde_json::Value;
7
8use crate::sources::SourceRegistry;
9
10pub use super::unified_tools::{
11    DeduplicatePapersHandler, DownloadPaperHandler, GetCitationsHandler, GetPaperHandler,
12    GetReferencesHandler, LookupByDoiHandler, ReadPaperHandler, SearchByAuthorHandler,
13    SearchPapersHandler,
14};
15
16/// An MCP tool that can be called by the client
17#[derive(Clone)]
18pub struct Tool {
19    /// Tool name (e.g., "search_papers")
20    pub name: String,
21
22    /// Human-readable description
23    pub description: String,
24
25    /// JSON Schema for input parameters
26    pub input_schema: serde_json::Value,
27
28    /// Handler function to execute the tool
29    pub handler: Arc<dyn ToolHandler>,
30}
31
32impl std::fmt::Debug for Tool {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("Tool")
35            .field("name", &self.name)
36            .field("description", &self.description)
37            .field("input_schema", &self.input_schema)
38            .finish()
39    }
40}
41
42/// Handler for executing a tool
43#[async_trait::async_trait]
44pub trait ToolHandler: Send + Sync + std::fmt::Debug {
45    /// Execute the tool with the given arguments
46    async fn execute(&self, args: Value) -> Result<Value, String>;
47}
48
49/// Registry for all MCP tools
50#[derive(Debug, Clone)]
51pub struct ToolRegistry {
52    tools: HashMap<String, Tool>,
53}
54
55impl ToolRegistry {
56    /// Create a new tool registry and register all unified tools from the source registry
57    pub fn from_sources(sources: &SourceRegistry) -> Self {
58        let mut registry = Self {
59            tools: HashMap::new(),
60        };
61
62        // Convert sources to a shared Arc<Vec>
63        let sources_vec: Vec<Arc<dyn crate::sources::Source>> = sources.all().cloned().collect();
64        let sources_arc = Arc::new(sources_vec);
65
66        // Register unified tools
67        registry.register_unified_tools(&sources_arc);
68
69        registry
70    }
71
72    /// Register unified tools (9 tools total instead of per-source tools)
73    fn register_unified_tools(&mut self, sources: &Arc<Vec<Arc<dyn crate::sources::Source>>>) {
74        let sources_count = sources.len();
75
76        // 1. search_papers - Search across all or specific sources
77        self.register(Tool {
78            name: "search_papers".to_string(),
79            description: format!(
80                "Search for papers across {} available research sources",
81                sources_count
82            ),
83            input_schema: serde_json::json!({
84                "type": "object",
85                "properties": {
86                    "query": {
87                        "type": "string",
88                        "description": "Search query string"
89                    },
90                    "source": {
91                        "type": "string",
92                        "description": "Specific source to search (e.g., 'arxiv', 'semantic', 'pubmed'). If not specified, searches all sources."
93                    },
94                    "max_results": {
95                        "type": "integer",
96                        "description": "Maximum number of results per source",
97                        "default": 10
98                    },
99                    "year": {
100                        "type": "string",
101                        "description": "Year filter (e.g., '2020', '2018-2022', '2010-', '-2015')"
102                    },
103                    "category": {
104                        "type": "string",
105                        "description": "Category/subject filter"
106                    }
107                },
108                "required": ["query"]
109            }),
110            handler: Arc::new(SearchPapersHandler {
111                sources: sources.clone(),
112            }),
113        });
114
115        // 2. search_by_author - Author search across sources
116        self.register(Tool {
117            name: "search_by_author".to_string(),
118            description: format!(
119                "Search for papers by author across {} research sources",
120                sources_count
121            ),
122            input_schema: serde_json::json!({
123                "type": "object",
124                "properties": {
125                    "author": {
126                        "type": "string",
127                        "description": "Author name"
128                    },
129                    "source": {
130                        "type": "string",
131                        "description": "Specific source to search. If not specified, searches all sources with author search capability."
132                    },
133                    "max_results": {
134                        "type": "integer",
135                        "description": "Maximum results per source",
136                        "default": 10
137                    }
138                },
139                "required": ["author"]
140            }),
141            handler: Arc::new(SearchByAuthorHandler {
142                sources: sources.clone(),
143            }),
144        });
145
146        // 3. get_paper - Get paper metadata with auto-detection
147        self.register(Tool {
148            name: "get_paper".to_string(),
149            description: "Get detailed metadata for a specific paper. Source is auto-detected from paper ID format.".to_string(),
150            input_schema: serde_json::json!({
151                "type": "object",
152                "properties": {
153                    "paper_id": {
154                        "type": "string",
155                        "description": "Paper identifier (e.g., '2301.12345', 'arXiv:2301.12345', 'PMC12345678')"
156                    },
157                    "source": {
158                        "type": "string",
159                        "description": "Override auto-detection and use specific source"
160                    }
161                },
162                "required": ["paper_id"]
163            }),
164            handler: Arc::new(GetPaperHandler {
165                sources: sources.clone(),
166            }),
167        });
168
169        // 4. download_paper - Download with auto-detection
170        self.register(Tool {
171            name: "download_paper".to_string(),
172            description: "Download a paper PDF to your local filesystem. Source is auto-detected from paper ID format.".to_string(),
173            input_schema: serde_json::json!({
174                "type": "object",
175                "properties": {
176                    "paper_id": {
177                        "type": "string",
178                        "description": "Paper identifier"
179                    },
180                    "source": {
181                        "type": "string",
182                        "description": "Override auto-detection and use specific source"
183                    },
184                    "output_path": {
185                        "type": "string",
186                        "description": "Save path for the PDF",
187                        "default": "./downloads"
188                    },
189                    "auto_filename": {
190                        "type": "boolean",
191                        "description": "Auto-generate filename from paper title",
192                        "default": true
193                    }
194                },
195                "required": ["paper_id"]
196            }),
197            handler: Arc::new(DownloadPaperHandler {
198                sources: sources.clone(),
199            }),
200        });
201
202        // 5. read_paper - PDF text extraction with auto-detection
203        self.register(Tool {
204            name: "read_paper".to_string(),
205            description: "Extract and return the full text content from a paper PDF. Source is auto-detected from paper ID format. Requires poppler to be installed.".to_string(),
206            input_schema: serde_json::json!({
207                "type": "object",
208                "properties": {
209                    "paper_id": {
210                        "type": "string",
211                        "description": "Paper identifier"
212                    },
213                    "source": {
214                        "type": "string",
215                        "description": "Override auto-detection and use specific source"
216                    }
217                },
218                "required": ["paper_id"]
219            }),
220            handler: Arc::new(ReadPaperHandler {
221                sources: sources.clone(),
222            }),
223        });
224
225        // 6. get_citations - Get papers that cite a given paper
226        self.register(Tool {
227            name: "get_citations".to_string(),
228            description:
229                "Get papers that cite a specific paper. Prefers Semantic Scholar for best results."
230                    .to_string(),
231            input_schema: serde_json::json!({
232                "type": "object",
233                "properties": {
234                    "paper_id": {
235                        "type": "string",
236                        "description": "Paper identifier"
237                    },
238                    "source": {
239                        "type": "string",
240                        "description": "Specific source (default: 'semantic')",
241                        "default": "semantic"
242                    },
243                    "max_results": {
244                        "type": "integer",
245                        "description": "Maximum results",
246                        "default": 20
247                    }
248                },
249                "required": ["paper_id"]
250            }),
251            handler: Arc::new(GetCitationsHandler {
252                sources: sources.clone(),
253            }),
254        });
255
256        // 7. get_references - Get papers referenced by a given paper
257        self.register(Tool {
258            name: "get_references".to_string(),
259            description: "Get papers referenced by a specific paper. Prefers Semantic Scholar for best results.".to_string(),
260            input_schema: serde_json::json!({
261                "type": "object",
262                "properties": {
263                    "paper_id": {
264                        "type": "string",
265                        "description": "Paper identifier"
266                    },
267                    "source": {
268                        "type": "string",
269                        "description": "Specific source (default: 'semantic')",
270                        "default": "semantic"
271                    },
272                    "max_results": {
273                        "type": "integer",
274                        "description": "Maximum results",
275                        "default": 20
276                    }
277                },
278                "required": ["paper_id"]
279            }),
280            handler: Arc::new(GetReferencesHandler {
281                sources: sources.clone(),
282            }),
283        });
284
285        // 8. lookup_by_doi - DOI lookup across all sources
286        self.register(Tool {
287            name: "lookup_by_doi".to_string(),
288            description: "Look up a paper by its DOI across all sources that support DOI lookup.".to_string(),
289            input_schema: serde_json::json!({
290                "type": "object",
291                "properties": {
292                    "doi": {
293                        "type": "string",
294                        "description": "Digital Object Identifier (e.g., '10.48550/arXiv.2301.12345')"
295                    },
296                    "source": {
297                        "type": "string",
298                        "description": "Specific source to query. If not specified, queries all sources with DOI lookup capability."
299                    }
300                },
301                "required": ["doi"]
302            }),
303            handler: Arc::new(LookupByDoiHandler {
304                sources: sources.clone(),
305            }),
306        });
307
308        // 9. deduplicate_papers - Remove duplicates
309        self.register(Tool {
310            name: "deduplicate_papers".to_string(),
311            description: "Remove duplicate papers from a list using DOI matching and title similarity.".to_string(),
312            input_schema: serde_json::json!({
313                "type": "object",
314                "properties": {
315                    "papers": {
316                        "type": "array",
317                        "description": "Array of paper objects",
318                        "items": {
319                            "type": "object"
320                        }
321                    },
322                    "strategy": {
323                        "type": "string",
324                        "description": "Deduplication strategy: 'first' (keep first), 'last' (keep last), or 'mark' (add is_duplicate flag)",
325                        "enum": ["first", "last", "mark"],
326                        "default": "first"
327                    }
328                },
329                "required": ["papers"]
330            }),
331            handler: Arc::new(DeduplicatePapersHandler),
332        });
333    }
334
335    /// Register a tool
336    pub fn register(&mut self, tool: Tool) {
337        self.tools.insert(tool.name.clone(), tool);
338    }
339
340    /// Get all tools
341    pub fn all(&self) -> Vec<&Tool> {
342        self.tools.values().collect()
343    }
344
345    /// Get a tool by name
346    pub fn get(&self, name: &str) -> Option<&Tool> {
347        self.tools.get(name)
348    }
349
350    /// Execute a tool by name
351    pub async fn execute(&self, name: &str, args: Value) -> Result<Value, String> {
352        let tool = self
353            .get(name)
354            .ok_or_else(|| format!("Tool '{}' not found", name))?;
355
356        tool.handler.execute(args).await
357    }
358}