use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value;
use crate::sources::SourceRegistry;
pub use super::unified_tools::{
DeduplicatePapersHandler, DownloadPaperHandler, GetCitationsHandler, GetPaperHandler,
GetReferencesHandler, LookupByDoiHandler, ReadPaperHandler, SearchByAuthorHandler,
SearchPapersHandler,
};
/// A single callable tool: descriptive metadata plus the handler that runs it.
///
/// Cloning a `Tool` clones the strings and schema and bumps the handler's `Arc`
/// reference count; the handler itself is shared, not duplicated.
#[derive(Clone)]
pub struct Tool {
/// Tool name; `ToolRegistry::register` uses it as the lookup key.
pub name: String,
/// Human-readable summary of what the tool does.
pub description: String,
/// JSON value describing the arguments the tool accepts.
pub input_schema: serde_json::Value,
/// Shared handler that is invoked when the tool is executed.
pub handler: Arc<dyn ToolHandler>,
}
/// Manual `Debug` implementation that prints the tool's metadata only.
///
/// The `handler` field is deliberately not part of the output; only `name`,
/// `description` and `input_schema` are shown.
impl std::fmt::Debug for Tool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Build the struct representation step by step instead of one long chain.
        let mut repr = f.debug_struct("Tool");
        repr.field("name", &self.name);
        repr.field("description", &self.description);
        repr.field("input_schema", &self.input_schema);
        repr.finish()
    }
}
/// Behaviour behind a `Tool`: asynchronously runs the tool against JSON arguments.
///
/// Implementors must be `Send + Sync` and `Debug`; handlers are stored as
/// `Arc<dyn ToolHandler>` inside `Tool`.
#[async_trait::async_trait]
pub trait ToolHandler: Send + Sync + std::fmt::Debug {
/// Executes the tool with the given JSON `args`.
///
/// Returns the tool's JSON result on success, or a `String` error message on failure.
async fn execute(&self, args: Value) -> Result<Value, String>;
}
/// Name-indexed collection of `Tool`s.
#[derive(Debug, Clone)]
pub struct ToolRegistry {
/// Registered tools, keyed by `Tool::name`.
tools: HashMap<String, Tool>,
}
impl ToolRegistry {
    /// Builds a registry containing the unified tool set backed by `sources`.
    ///
    /// Every source yielded by `sources.all()` is cloned into one shared
    /// `Arc<Vec<_>>`, which is then handed to each source-backed handler.
    pub fn from_sources(sources: &SourceRegistry) -> Self {
        let mut registry = Self {
            tools: HashMap::new(),
        };

        let collected: Vec<Arc<dyn crate::sources::Source>> = sources.all().cloned().collect();
        let shared_sources = Arc::new(collected);
        registry.register_unified_tools(&shared_sources);

        registry
    }

    /// Registers the nine unified tools: `search_papers`, `search_by_author`,
    /// `get_paper`, `download_paper`, `read_paper`, `get_citations`,
    /// `get_references`, `lookup_by_doi` and `deduplicate_papers`.
    ///
    /// Every handler except `DeduplicatePapersHandler` receives its own `Arc`
    /// handle to the same shared source list.
    fn register_unified_tools(&mut self, sources: &Arc<Vec<Arc<dyn crate::sources::Source>>>) {
        // Embedded in the search tool descriptions below.
        let sources_count = sources.len();

        self.register(Tool {
            name: "search_papers".to_string(),
            description: format!(
                "Search for papers across {} available research sources",
                sources_count
            ),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query string"
                    },
                    "source": {
                        "type": "string",
                        "description": "Specific source to search (e.g., 'arxiv', 'semantic', 'pubmed'). If not specified, searches all sources."
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results per source",
                        "default": 10
                    },
                    "year": {
                        "type": "string",
                        "description": "Year filter (e.g., '2020', '2018-2022', '2010-', '-2015')"
                    },
                    "category": {
                        "type": "string",
                        "description": "Category/subject filter"
                    }
                },
                "required": ["query"]
            }),
            handler: Arc::new(SearchPapersHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "search_by_author".to_string(),
            description: format!(
                "Search for papers by author across {} research sources",
                sources_count
            ),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "author": {
                        "type": "string",
                        "description": "Author name"
                    },
                    "source": {
                        "type": "string",
                        "description": "Specific source to search. If not specified, searches all sources with author search capability."
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum results per source",
                        "default": 10
                    }
                },
                "required": ["author"]
            }),
            handler: Arc::new(SearchByAuthorHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "get_paper".to_string(),
            description: "Get detailed metadata for a specific paper. Source is auto-detected from paper ID format.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "paper_id": {
                        "type": "string",
                        "description": "Paper identifier (e.g., '2301.12345', 'arXiv:2301.12345', 'PMC12345678')"
                    },
                    "source": {
                        "type": "string",
                        "description": "Override auto-detection and use specific source"
                    }
                },
                "required": ["paper_id"]
            }),
            handler: Arc::new(GetPaperHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "download_paper".to_string(),
            description: "Download a paper PDF to your local filesystem. Source is auto-detected from paper ID format.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "paper_id": {
                        "type": "string",
                        "description": "Paper identifier"
                    },
                    "source": {
                        "type": "string",
                        "description": "Override auto-detection and use specific source"
                    },
                    "output_path": {
                        "type": "string",
                        "description": "Save path for the PDF",
                        "default": "./downloads"
                    },
                    "auto_filename": {
                        "type": "boolean",
                        "description": "Auto-generate filename from paper title",
                        "default": true
                    }
                },
                "required": ["paper_id"]
            }),
            handler: Arc::new(DownloadPaperHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "read_paper".to_string(),
            description: "Extract and return the full text content from a paper PDF. Source is auto-detected from paper ID format. Requires poppler to be installed.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "paper_id": {
                        "type": "string",
                        "description": "Paper identifier"
                    },
                    "source": {
                        "type": "string",
                        "description": "Override auto-detection and use specific source"
                    }
                },
                "required": ["paper_id"]
            }),
            handler: Arc::new(ReadPaperHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "get_citations".to_string(),
            description:
                "Get papers that cite a specific paper. Prefers Semantic Scholar for best results."
                    .to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "paper_id": {
                        "type": "string",
                        "description": "Paper identifier"
                    },
                    "source": {
                        "type": "string",
                        "description": "Specific source (default: 'semantic')",
                        "default": "semantic"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum results",
                        "default": 20
                    }
                },
                "required": ["paper_id"]
            }),
            handler: Arc::new(GetCitationsHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "get_references".to_string(),
            description: "Get papers referenced by a specific paper. Prefers Semantic Scholar for best results.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "paper_id": {
                        "type": "string",
                        "description": "Paper identifier"
                    },
                    "source": {
                        "type": "string",
                        "description": "Specific source (default: 'semantic')",
                        "default": "semantic"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum results",
                        "default": 20
                    }
                },
                "required": ["paper_id"]
            }),
            handler: Arc::new(GetReferencesHandler {
                sources: Arc::clone(sources),
            }),
        });

        self.register(Tool {
            name: "lookup_by_doi".to_string(),
            description: "Look up a paper by its DOI across all sources that support DOI lookup.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "doi": {
                        "type": "string",
                        "description": "Digital Object Identifier (e.g., '10.48550/arXiv.2301.12345')"
                    },
                    "source": {
                        "type": "string",
                        "description": "Specific source to query. If not specified, queries all sources with DOI lookup capability."
                    }
                },
                "required": ["doi"]
            }),
            handler: Arc::new(LookupByDoiHandler {
                sources: Arc::clone(sources),
            }),
        });

        // The deduplication handler is constructed without a source list.
        self.register(Tool {
            name: "deduplicate_papers".to_string(),
            description: "Remove duplicate papers from a list using DOI matching and title similarity.".to_string(),
            input_schema: serde_json::json!({
                "type": "object",
                "properties": {
                    "papers": {
                        "type": "array",
                        "description": "Array of paper objects",
                        "items": {
                            "type": "object"
                        }
                    },
                    "strategy": {
                        "type": "string",
                        "description": "Deduplication strategy: 'first' (keep first), 'last' (keep last), or 'mark' (add is_duplicate flag)",
                        "enum": ["first", "last", "mark"],
                        "default": "first"
                    }
                },
                "required": ["papers"]
            }),
            handler: Arc::new(DeduplicatePapersHandler),
        });
    }

    /// Adds `tool` to the registry under its own name.
    ///
    /// A previously registered tool with the same name is replaced.
    pub fn register(&mut self, tool: Tool) {
        self.tools.insert(tool.name.clone(), tool);
    }

    /// Returns every registered tool, sorted by name.
    ///
    /// `HashMap` iteration order is arbitrary, so the result is sorted to give
    /// callers a stable, reproducible listing.
    pub fn all(&self) -> Vec<&Tool> {
        let mut tools: Vec<&Tool> = self.tools.values().collect();
        // Registry keys are the tool names, so names are unique and an
        // unstable sort cannot reorder equal elements observably.
        tools.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        tools
    }

    /// Looks up a tool by name, returning `None` when it is not registered.
    pub fn get(&self, name: &str) -> Option<&Tool> {
        self.tools.get(name)
    }

    /// Runs the tool called `name` with the JSON `args`.
    ///
    /// # Errors
    ///
    /// Returns `Tool '<name>' not found` when no such tool is registered;
    /// otherwise forwards whatever the tool's handler returns.
    pub async fn execute(&self, name: &str, args: Value) -> Result<Value, String> {
        let tool = self
            .get(name)
            .ok_or_else(|| format!("Tool '{}' not found", name))?;
        tool.handler.execute(args).await
    }
}