1use std::collections::HashMap;
4use std::sync::Arc;
5
6use serde_json::Value;
7
8use crate::sources::SourceRegistry;
9
10pub use super::unified_tools::{
11 DeduplicatePapersHandler, DownloadPaperHandler, GetCitationsHandler, GetPaperHandler,
12 GetReferencesHandler, LookupByDoiHandler, ReadPaperHandler, SearchByAuthorHandler,
13 SearchPapersHandler,
14};
15
/// A single tool exposed through the registry: its name, a human-readable
/// description, the JSON Schema of its arguments, and the handler that runs it.
#[derive(Clone)]
pub struct Tool {
    /// Tool name. `ToolRegistry` stores tools keyed by this value, so
    /// registering a second tool with the same name replaces the first.
    pub name: String,

    /// Human-readable description of what the tool does.
    pub description: String,

    /// JSON Schema describing the arguments the tool accepts.
    pub input_schema: serde_json::Value,

    /// Handler that executes the tool; held in an `Arc` so cloning a `Tool`
    /// only bumps a reference count instead of duplicating the handler.
    pub handler: Arc<dyn ToolHandler>,
}
31
32impl std::fmt::Debug for Tool {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("Tool")
35 .field("name", &self.name)
36 .field("description", &self.description)
37 .field("input_schema", &self.input_schema)
38 .finish()
39 }
40}
41
42#[async_trait::async_trait]
44pub trait ToolHandler: Send + Sync + std::fmt::Debug {
45 async fn execute(&self, args: Value) -> Result<Value, String>;
47}
48
49#[derive(Debug, Clone)]
51pub struct ToolRegistry {
52 tools: HashMap<String, Tool>,
53}
54
55impl ToolRegistry {
56 pub fn from_sources(sources: &SourceRegistry) -> Self {
58 let mut registry = Self {
59 tools: HashMap::new(),
60 };
61
62 let sources_vec: Vec<Arc<dyn crate::sources::Source>> = sources.all().cloned().collect();
64 let sources_arc = Arc::new(sources_vec);
65
66 registry.register_unified_tools(&sources_arc);
68
69 registry
70 }
71
72 fn register_unified_tools(&mut self, sources: &Arc<Vec<Arc<dyn crate::sources::Source>>>) {
74 let sources_count = sources.len();
75
76 self.register(Tool {
78 name: "search_papers".to_string(),
79 description: format!(
80 "Search for papers across {} available research sources",
81 sources_count
82 ),
83 input_schema: serde_json::json!({
84 "type": "object",
85 "properties": {
86 "query": {
87 "type": "string",
88 "description": "Search query string"
89 },
90 "source": {
91 "type": "string",
92 "description": "Specific source to search (e.g., 'arxiv', 'semantic', 'pubmed'). If not specified, searches all sources."
93 },
94 "max_results": {
95 "type": "integer",
96 "description": "Maximum number of results per source",
97 "default": 10
98 },
99 "year": {
100 "type": "string",
101 "description": "Year filter (e.g., '2020', '2018-2022', '2010-', '-2015')"
102 },
103 "category": {
104 "type": "string",
105 "description": "Category/subject filter"
106 }
107 },
108 "required": ["query"]
109 }),
110 handler: Arc::new(SearchPapersHandler {
111 sources: sources.clone(),
112 }),
113 });
114
115 self.register(Tool {
117 name: "search_by_author".to_string(),
118 description: format!(
119 "Search for papers by author across {} research sources",
120 sources_count
121 ),
122 input_schema: serde_json::json!({
123 "type": "object",
124 "properties": {
125 "author": {
126 "type": "string",
127 "description": "Author name"
128 },
129 "source": {
130 "type": "string",
131 "description": "Specific source to search. If not specified, searches all sources with author search capability."
132 },
133 "max_results": {
134 "type": "integer",
135 "description": "Maximum results per source",
136 "default": 10
137 }
138 },
139 "required": ["author"]
140 }),
141 handler: Arc::new(SearchByAuthorHandler {
142 sources: sources.clone(),
143 }),
144 });
145
146 self.register(Tool {
148 name: "get_paper".to_string(),
149 description: "Get detailed metadata for a specific paper. Source is auto-detected from paper ID format.".to_string(),
150 input_schema: serde_json::json!({
151 "type": "object",
152 "properties": {
153 "paper_id": {
154 "type": "string",
155 "description": "Paper identifier (e.g., '2301.12345', 'arXiv:2301.12345', 'PMC12345678')"
156 },
157 "source": {
158 "type": "string",
159 "description": "Override auto-detection and use specific source"
160 }
161 },
162 "required": ["paper_id"]
163 }),
164 handler: Arc::new(GetPaperHandler {
165 sources: sources.clone(),
166 }),
167 });
168
169 self.register(Tool {
171 name: "download_paper".to_string(),
172 description: "Download a paper PDF to your local filesystem. Source is auto-detected from paper ID format.".to_string(),
173 input_schema: serde_json::json!({
174 "type": "object",
175 "properties": {
176 "paper_id": {
177 "type": "string",
178 "description": "Paper identifier"
179 },
180 "source": {
181 "type": "string",
182 "description": "Override auto-detection and use specific source"
183 },
184 "output_path": {
185 "type": "string",
186 "description": "Save path for the PDF",
187 "default": "./downloads"
188 },
189 "auto_filename": {
190 "type": "boolean",
191 "description": "Auto-generate filename from paper title",
192 "default": true
193 }
194 },
195 "required": ["paper_id"]
196 }),
197 handler: Arc::new(DownloadPaperHandler {
198 sources: sources.clone(),
199 }),
200 });
201
202 self.register(Tool {
204 name: "read_paper".to_string(),
205 description: "Extract and return the full text content from a paper PDF. Source is auto-detected from paper ID format. Requires poppler to be installed.".to_string(),
206 input_schema: serde_json::json!({
207 "type": "object",
208 "properties": {
209 "paper_id": {
210 "type": "string",
211 "description": "Paper identifier"
212 },
213 "source": {
214 "type": "string",
215 "description": "Override auto-detection and use specific source"
216 }
217 },
218 "required": ["paper_id"]
219 }),
220 handler: Arc::new(ReadPaperHandler {
221 sources: sources.clone(),
222 }),
223 });
224
225 self.register(Tool {
227 name: "get_citations".to_string(),
228 description:
229 "Get papers that cite a specific paper. Prefers Semantic Scholar for best results."
230 .to_string(),
231 input_schema: serde_json::json!({
232 "type": "object",
233 "properties": {
234 "paper_id": {
235 "type": "string",
236 "description": "Paper identifier"
237 },
238 "source": {
239 "type": "string",
240 "description": "Specific source (default: 'semantic')",
241 "default": "semantic"
242 },
243 "max_results": {
244 "type": "integer",
245 "description": "Maximum results",
246 "default": 20
247 }
248 },
249 "required": ["paper_id"]
250 }),
251 handler: Arc::new(GetCitationsHandler {
252 sources: sources.clone(),
253 }),
254 });
255
256 self.register(Tool {
258 name: "get_references".to_string(),
259 description: "Get papers referenced by a specific paper. Prefers Semantic Scholar for best results.".to_string(),
260 input_schema: serde_json::json!({
261 "type": "object",
262 "properties": {
263 "paper_id": {
264 "type": "string",
265 "description": "Paper identifier"
266 },
267 "source": {
268 "type": "string",
269 "description": "Specific source (default: 'semantic')",
270 "default": "semantic"
271 },
272 "max_results": {
273 "type": "integer",
274 "description": "Maximum results",
275 "default": 20
276 }
277 },
278 "required": ["paper_id"]
279 }),
280 handler: Arc::new(GetReferencesHandler {
281 sources: sources.clone(),
282 }),
283 });
284
285 self.register(Tool {
287 name: "lookup_by_doi".to_string(),
288 description: "Look up a paper by its DOI across all sources that support DOI lookup.".to_string(),
289 input_schema: serde_json::json!({
290 "type": "object",
291 "properties": {
292 "doi": {
293 "type": "string",
294 "description": "Digital Object Identifier (e.g., '10.48550/arXiv.2301.12345')"
295 },
296 "source": {
297 "type": "string",
298 "description": "Specific source to query. If not specified, queries all sources with DOI lookup capability."
299 }
300 },
301 "required": ["doi"]
302 }),
303 handler: Arc::new(LookupByDoiHandler {
304 sources: sources.clone(),
305 }),
306 });
307
308 self.register(Tool {
310 name: "deduplicate_papers".to_string(),
311 description: "Remove duplicate papers from a list using DOI matching and title similarity.".to_string(),
312 input_schema: serde_json::json!({
313 "type": "object",
314 "properties": {
315 "papers": {
316 "type": "array",
317 "description": "Array of paper objects",
318 "items": {
319 "type": "object"
320 }
321 },
322 "strategy": {
323 "type": "string",
324 "description": "Deduplication strategy: 'first' (keep first), 'last' (keep last), or 'mark' (add is_duplicate flag)",
325 "enum": ["first", "last", "mark"],
326 "default": "first"
327 }
328 },
329 "required": ["papers"]
330 }),
331 handler: Arc::new(DeduplicatePapersHandler),
332 });
333 }
334
335 pub fn register(&mut self, tool: Tool) {
337 self.tools.insert(tool.name.clone(), tool);
338 }
339
340 pub fn all(&self) -> Vec<&Tool> {
342 self.tools.values().collect()
343 }
344
345 pub fn get(&self, name: &str) -> Option<&Tool> {
347 self.tools.get(name)
348 }
349
350 pub async fn execute(&self, name: &str, args: Value) -> Result<Value, String> {
352 let tool = self
353 .get(name)
354 .ok_or_else(|| format!("Tool '{}' not found", name))?;
355
356 tool.handler.execute(args).await
357 }
358}