use std::sync::Arc;
use serde_json::Value;
use super::tools::ToolHandler;
/// Guesses which registered source owns `paper_id`, looking only at the shape of the identifier.
///
/// The rules below are tried top to bottom and the first match decides the outcome:
///
/// 1. Starts with `arxiv:` (case-insensitive), **or** is longer than four bytes and its first
///    (up to) nine characters are all `char::is_numeric` or `.` -> source `arxiv`.
/// 2. Starts with `PMC` (case-insensitive) -> source `pmc`.
/// 3. Starts with `hal-` (case-insensitive) -> source `hal`.
/// 4. Contains exactly one `/` -> source `iacr`.
/// 5. Starts with `10.` -> source `semantic` if it supports DOI lookup, otherwise the first
///    registered source that does.
/// 6. Anything else -> `arxiv` if registered, otherwise `semantic`.
///
/// Note the consequence of this ordering: an identifier beginning with `10.` only reaches
/// rule 5 when rules 1 and 4 did not claim it first (a DOI with a single `/` resolves to
/// `iacr`, a purely numeric/dotted one resolves to `arxiv`).
///
/// # Errors
///
/// * Rules 1-4 do not fall through: when the rule matches but its source is not registered,
///   an error naming the missing source is returned.
/// * Rule 5 falls through to rule 6 when no source supports DOI lookup.
/// * If rule 6 finds neither `arxiv` nor `semantic`, a generic "could not auto-detect" error
///   is returned.
fn auto_detect_source(
    sources: &Arc<Vec<Arc<dyn crate::sources::Source>>>,
    paper_id: &str,
) -> Result<Arc<dyn crate::sources::Source>, String> {
    // Single place for "give me the registered source with this id".
    let by_id = |wanted: &str| sources.iter().find(|s| s.id() == wanted).cloned();

    let lowered = paper_id.to_lowercase();

    // Rule 1: explicit arXiv prefix or an identifier that looks like `2301.12345`.
    let numeric_head =
        paper_id.len() > 4 && paper_id.chars().take(9).all(|c| c.is_numeric() || c == '.');
    if lowered.starts_with("arxiv:") || numeric_head {
        return by_id("arxiv").ok_or_else(|| "arXiv source not available".to_string());
    }

    // Rule 2: PubMed Central identifiers.
    if paper_id_upper_start(paper_id, "PMC") {
        return by_id("pmc").ok_or_else(|| "PMC source not available".to_string());
    }

    // Rule 3: HAL identifiers.
    if lowered.starts_with("hal-") {
        return by_id("hal").ok_or_else(|| "HAL source not available".to_string());
    }

    // Rule 4: `<year>/<number>`-style identifiers (exactly one slash).
    if paper_id.matches('/').count() == 1 {
        return by_id("iacr").ok_or_else(|| "IACR source not available".to_string());
    }

    // Rule 5: DOI-looking identifiers; prefer `semantic`, else any DOI-capable source.
    if paper_id.starts_with("10.") {
        let doi_capable = sources
            .iter()
            .find(|s| s.id() == "semantic" && s.supports_doi_lookup())
            .or_else(|| sources.iter().find(|s| s.supports_doi_lookup()));
        if let Some(source) = doi_capable {
            return Ok(Arc::clone(source));
        }
    }

    // Rule 6: generic fallbacks.
    by_id("arxiv")
        .or_else(|| by_id("semantic"))
        .ok_or_else(|| "Could not auto-detect source. Please specify source explicitly.".to_string())
}
/// Reports whether `paper_id` begins with `prefix`, ignoring the letter case of `paper_id`.
///
/// The first `prefix.len()` bytes of `paper_id` are upper-cased and compared verbatim with
/// `prefix`, so `prefix` itself must already be upper case for a match to be possible.
///
/// Returns `false` when `paper_id` is shorter than `prefix`, and also when `prefix.len()`
/// does not fall on a character boundary of `paper_id` (for example an identifier that
/// starts with multi-byte characters). The checked `str::get` is used on purpose: direct
/// range indexing would panic in that second case.
fn paper_id_upper_start(paper_id: &str, prefix: &str) -> bool {
    match paper_id.get(..prefix.len()) {
        Some(head) => head.to_uppercase() == prefix,
        None => false,
    }
}
#[derive(Debug)]
pub struct SearchPapersHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for SearchPapersHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let query = args
.get("query")
.and_then(|v| v.as_str())
.ok_or("Missing 'query' parameter")?;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(10) as usize;
let year = args
.get("year")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let category = args
.get("category")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let source_filter = args.get("source").and_then(|v| v.as_str());
let mut all_results = Vec::new();
for source in self.sources.iter() {
if let Some(filter) = source_filter {
if source.id() != filter {
continue;
}
}
if !source.supports_search() {
continue;
}
let mut search_query = crate::models::SearchQuery::new(query).max_results(max_results);
if let Some(ref year) = year {
search_query = search_query.year(year);
}
if let Some(ref cat) = category {
search_query = search_query.category(cat);
}
match source.search(&search_query).await {
Ok(response) => {
all_results.extend(response.papers);
}
Err(e) => {
tracing::warn!("Search failed for {}: {}", source.id(), e);
}
}
}
serde_json::to_value(all_results).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct SearchByAuthorHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for SearchByAuthorHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let author = args
.get("author")
.and_then(|v| v.as_str())
.ok_or("Missing 'author' parameter")?;
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(10) as usize;
let year = args.get("year").and_then(|v| v.as_str());
let source_filter = args.get("source").and_then(|v| v.as_str());
let mut all_results = Vec::new();
for source in self.sources.iter() {
if let Some(filter) = source_filter {
if source.id() != filter {
continue;
}
}
if !source.supports_author_search() {
continue;
}
match source.search_by_author(author, max_results, year).await {
Ok(response) => {
all_results.extend(response.papers);
}
Err(e) => {
tracing::warn!("Author search failed for {}: {}", source.id(), e);
}
}
}
serde_json::to_value(all_results).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct GetPaperHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for GetPaperHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let paper_id = args
.get("paper_id")
.and_then(|v| v.as_str())
.ok_or("Missing 'paper_id' parameter")?;
let source_override = args.get("source").and_then(|v| v.as_str());
let source = self.find_source(paper_id, source_override)?;
let search_query = crate::models::SearchQuery::new(paper_id).max_results(1);
let response = source
.search(&search_query)
.await
.map_err(|e| e.to_string())?;
if response.papers.is_empty() {
return Err(format!("Paper '{}' not found in {}", paper_id, source.id()));
}
serde_json::to_value(&response.papers[0]).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct DownloadPaperHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for DownloadPaperHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let paper_id = args
.get("paper_id")
.and_then(|v| v.as_str())
.ok_or("Missing 'paper_id' parameter")?;
let source_override = args.get("source").and_then(|v| v.as_str());
let output_path = args
.get("output_path")
.and_then(|v| v.as_str())
.unwrap_or("./downloads");
let source = self.find_source(paper_id, source_override)?;
let request = crate::models::DownloadRequest::new(paper_id, output_path);
let result = source.download(&request).await.map_err(|e| e.to_string())?;
serde_json::to_value(result).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct ReadPaperHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for ReadPaperHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let paper_id = args
.get("paper_id")
.and_then(|v| v.as_str())
.ok_or("Missing 'paper_id' parameter")?;
let source_override = args.get("source").and_then(|v| v.as_str());
let source = self.find_source(paper_id, source_override)?;
let request = crate::models::ReadRequest::new(paper_id, "./downloads");
let result = source.read(&request).await.map_err(|e| e.to_string())?;
serde_json::to_value(result).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct GetCitationsHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for GetCitationsHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let paper_id = args
.get("paper_id")
.and_then(|v| v.as_str())
.ok_or("Missing 'paper_id' parameter")?;
let source_override = args.get("source").and_then(|v| v.as_str());
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(20) as usize;
let source_id = source_override.unwrap_or("semantic");
let source = self
.sources
.iter()
.find(|s| s.id() == source_id)
.ok_or_else(|| format!("Source '{}' not found", source_id))?;
if !source.supports_citations() {
return Err(format!("Source '{}' does not support citations", source_id));
}
let request = crate::models::CitationRequest::new(paper_id).max_results(max_results);
let response = source
.get_citations(&request)
.await
.map_err(|e| e.to_string())?;
serde_json::to_value(response).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct GetReferencesHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for GetReferencesHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let paper_id = args
.get("paper_id")
.and_then(|v| v.as_str())
.ok_or("Missing 'paper_id' parameter")?;
let source_override = args.get("source").and_then(|v| v.as_str());
let max_results = args
.get("max_results")
.and_then(|v| v.as_u64())
.unwrap_or(20) as usize;
let source_id = source_override.unwrap_or("semantic");
let source = self
.sources
.iter()
.find(|s| s.id() == source_id)
.ok_or_else(|| format!("Source '{}' not found", source_id))?;
if !source.supports_citations() {
return Err(format!(
"Source '{}' does not support references",
source_id
));
}
let request = crate::models::CitationRequest::new(paper_id).max_results(max_results);
let response = source
.get_references(&request)
.await
.map_err(|e| e.to_string())?;
serde_json::to_value(response).map_err(|e| e.to_string())
}
}
#[derive(Debug)]
pub struct LookupByDoiHandler {
pub sources: Arc<Vec<Arc<dyn crate::sources::Source>>>,
}
#[async_trait::async_trait]
impl ToolHandler for LookupByDoiHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let doi = args
.get("doi")
.and_then(|v| v.as_str())
.ok_or("Missing 'doi' parameter")?;
let source_filter = args.get("source").and_then(|v| v.as_str());
for source in self.sources.iter() {
if let Some(filter) = source_filter {
if source.id() != filter {
continue;
}
}
if !source.supports_doi_lookup() {
continue;
}
match source.get_by_doi(doi).await {
Ok(paper) => {
return serde_json::to_value(paper).map_err(|e| e.to_string());
}
Err(e) => {
tracing::debug!("DOI lookup failed for {}: {}", source.id(), e);
}
}
}
Err(format!("Paper with DOI '{}' not found", doi))
}
}
#[derive(Debug)]
pub struct DeduplicatePapersHandler;
#[async_trait::async_trait]
impl ToolHandler for DeduplicatePapersHandler {
async fn execute(&self, args: Value) -> Result<Value, String> {
let papers: Vec<crate::models::Paper> = serde_json::from_value(
args.get("papers")
.ok_or("Missing 'papers' parameter")?
.clone(),
)
.map_err(|e| format!("Invalid papers array: {}", e))?;
let strategy_str = args
.get("strategy")
.and_then(|v| v.as_str())
.unwrap_or("first");
let strategy = match strategy_str {
"last" => crate::utils::DuplicateStrategy::Last,
"mark" => crate::utils::DuplicateStrategy::Mark,
_ => crate::utils::DuplicateStrategy::First,
};
let deduped = crate::utils::deduplicate_papers(papers, strategy);
serde_json::to_value(deduped).map_err(|e| e.to_string())
}
}
/// Shared source resolution used by the get/download/read handlers.
///
/// When `source_override` is given, the registered source with exactly that id is returned;
/// otherwise the choice is delegated to `auto_detect_source`, which inspects `paper_id`.
///
/// # Errors
///
/// Returns `Source '<id>' not found` when an override names a source that is not registered,
/// or whatever error `auto_detect_source` produces when no override is given.
fn find_source_or_detect(
    sources: &Arc<Vec<Arc<dyn crate::sources::Source>>>,
    paper_id: &str,
    source_override: Option<&str>,
) -> Result<Arc<dyn crate::sources::Source>, String> {
    match source_override {
        Some(source_id) => sources
            .iter()
            .find(|s| s.id() == source_id)
            .cloned()
            .ok_or_else(|| format!("Source '{}' not found", source_id)),
        None => auto_detect_source(sources, paper_id),
    }
}

impl GetPaperHandler {
    /// Resolves the source for `paper_id`: the explicit `source_override` if given,
    /// otherwise one auto-detected from the identifier.
    fn find_source(
        &self,
        paper_id: &str,
        source_override: Option<&str>,
    ) -> Result<Arc<dyn crate::sources::Source>, String> {
        find_source_or_detect(&self.sources, paper_id, source_override)
    }
}

impl DownloadPaperHandler {
    /// Resolves the source for `paper_id`: the explicit `source_override` if given,
    /// otherwise one auto-detected from the identifier.
    fn find_source(
        &self,
        paper_id: &str,
        source_override: Option<&str>,
    ) -> Result<Arc<dyn crate::sources::Source>, String> {
        find_source_or_detect(&self.sources, paper_id, source_override)
    }
}

impl ReadPaperHandler {
    /// Resolves the source for `paper_id`: the explicit `source_override` if given,
    /// otherwise one auto-detected from the identifier.
    fn find_source(
        &self,
        paper_id: &str,
        source_override: Option<&str>,
    ) -> Result<Arc<dyn crate::sources::Source>, String> {
        find_source_or_detect(&self.sources, paper_id, source_override)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{CitationRequest, DownloadRequest, ReadRequest};
    use crate::sources::{Source, SourceCapabilities};
    use std::sync::Arc;

    /// Stand-in [`Source`] carrying only an id and a capability set. Every operation that
    /// would talk to a backend is `unimplemented!()`, since these tests only exercise
    /// source selection and never call them.
    #[derive(Debug)]
    struct MockSource {
        id: String,
        capabilities: SourceCapabilities,
    }

    impl MockSource {
        fn new(id: &str, capabilities: SourceCapabilities) -> Self {
            let id = id.to_string();
            Self { id, capabilities }
        }
    }

    #[async_trait::async_trait]
    impl Source for MockSource {
        fn id(&self) -> &str {
            self.id.as_str()
        }

        // The mock has no separate display name; it reuses the id.
        fn name(&self) -> &str {
            self.id.as_str()
        }

        fn capabilities(&self) -> SourceCapabilities {
            self.capabilities
        }

        async fn search(
            &self,
            _query: &crate::models::SearchQuery,
        ) -> Result<crate::models::SearchResponse, crate::sources::SourceError> {
            unimplemented!()
        }

        async fn download(
            &self,
            _request: &DownloadRequest,
        ) -> Result<crate::models::DownloadResult, crate::sources::SourceError> {
            unimplemented!()
        }

        async fn read(
            &self,
            _request: &ReadRequest,
        ) -> Result<crate::models::ReadResult, crate::sources::SourceError> {
            unimplemented!()
        }

        async fn get_citations(
            &self,
            _request: &CitationRequest,
        ) -> Result<crate::models::SearchResponse, crate::sources::SourceError> {
            unimplemented!()
        }

        async fn get_references(
            &self,
            _request: &CitationRequest,
        ) -> Result<crate::models::SearchResponse, crate::sources::SourceError> {
            unimplemented!()
        }

        fn supports_doi_lookup(&self) -> bool {
            self.capabilities.contains(SourceCapabilities::DOI_LOOKUP)
        }

        async fn get_by_doi(
            &self,
            _doi: &str,
        ) -> Result<crate::models::Paper, crate::sources::SourceError> {
            unimplemented!()
        }

        async fn get_related(
            &self,
            _request: &CitationRequest,
        ) -> Result<crate::models::SearchResponse, crate::sources::SourceError> {
            unimplemented!()
        }

        fn validate_id(&self, _id: &str) -> Result<(), crate::sources::SourceError> {
            Ok(())
        }
    }

    /// Builds the standard registry: five fully-capable mocks, in this order.
    fn make_test_sources() -> Vec<Arc<dyn Source>> {
        ["arxiv", "semantic", "pmc", "hal", "iacr"]
            .iter()
            .map(|id| Arc::new(MockSource::new(id, SourceCapabilities::all())) as Arc<dyn Source>)
            .collect()
    }

    /// Runs detection against the standard registry and checks the id of the chosen source.
    #[track_caller]
    fn assert_detects(paper_id: &str, expected_source: &str) {
        let registry = Arc::new(make_test_sources());
        let detected = auto_detect_source(&registry, paper_id);
        assert!(detected.is_ok());
        assert_eq!(detected.unwrap().id(), expected_source);
    }

    #[test]
    fn test_auto_detect_arxiv_numeric() {
        assert_detects("2301.12345", "arxiv");
    }

    #[test]
    fn test_auto_detect_arxiv_prefix() {
        assert_detects("arxiv:2301.12345", "arxiv");
    }

    #[test]
    fn test_auto_detect_pmc() {
        assert_detects("PMC12345", "pmc");
    }

    #[test]
    fn test_auto_detect_pmc_lowercase() {
        assert_detects("pmc12345", "pmc");
    }

    #[test]
    fn test_auto_detect_hal() {
        assert_detects("hal-12345", "hal");
    }

    #[test]
    fn test_auto_detect_iacr() {
        assert_detects("2023/1234", "iacr");
    }

    #[test]
    fn test_auto_detect_doi() {
        // Pins current precedence: the single-slash rule claims this DOI before the
        // `10.` rule is consulted.
        assert_detects("10.12345/testpaper", "iacr");
    }

    #[test]
    fn test_auto_detect_doi_no_slash() {
        // Pins current precedence: digits and dots only, so the numeric arXiv rule wins.
        assert_detects("10.12345.67890", "arxiv");
    }

    #[test]
    fn test_auto_detect_fallback() {
        assert_detects("unknown-id-123", "arxiv");
    }

    #[test]
    fn test_auto_detect_source_not_available() {
        // Neither fallback source (`arxiv`, `semantic`) is registered here.
        let only_pmc: Vec<Arc<dyn Source>> =
            vec![Arc::new(MockSource::new("pmc", SourceCapabilities::SEARCH))];
        let detected = auto_detect_source(&Arc::new(only_pmc), "unknown-id");
        assert!(detected.is_err());
    }

    #[test]
    fn test_paper_id_upper_start_basic() {
        for matching in ["PMC12345", "pmc12345", "Pmc12345"].iter() {
            assert!(paper_id_upper_start(matching, "PMC"));
        }
        for rejected in ["ABC12345", "PM"].iter() {
            assert!(!paper_id_upper_start(rejected, "PMC"));
        }
    }

    #[test]
    fn test_paper_id_upper_start_edge_cases() {
        for too_short in ["", "PM"].iter() {
            assert!(!paper_id_upper_start(too_short, "PMC"));
        }
    }
}