Skip to main content

batuta/oracle/
arxiv.rs

1//! arXiv Paper Enrichment for Oracle Mode
2//!
3//! Two-tier enrichment:
4//! - **Builtin** (`--arxiv`): Instant results from the curated `ArxivDatabase` (no network)
5//! - **Live** (`--arxiv-live`): Fresh papers from `export.arxiv.org/api/query` (Atom XML)
6
7use super::coursera::arxiv_db::ArxivDatabase;
8use super::coursera::types::ArxivCitation;
9use super::query_engine::ParsedQuery;
10use super::types::ProblemDomain;
11use serde::Serialize;
12
13// =============================================================================
14// Types
15// =============================================================================
16
17/// A paper from arXiv (either curated or fetched live).
18#[derive(Debug, Clone, Serialize)]
19pub struct ArxivPaper {
20    /// arXiv identifier (e.g., "2212.04356")
21    pub arxiv_id: String,
22    /// Paper title
23    pub title: String,
24    /// Author list (abbreviated, e.g., "Radford et al.")
25    pub authors: String,
26    /// Publication year
27    pub year: u16,
28    /// Brief summary (truncated to ~200 chars)
29    pub summary: String,
30    /// Canonical URL (e.g., `https://arxiv.org/abs/2212.04356`)
31    pub url: String,
32    /// PDF download URL
33    pub pdf_url: Option<String>,
34    /// ISO date string from arXiv (e.g., "2022-12-08T...")
35    pub published: Option<String>,
36}
37
38/// Indicates whether results came from the builtin DB or a live fetch.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
40pub enum ArxivSource {
41    Builtin,
42    Live,
43}
44
45/// Enrichment result returned by the enricher.
46#[derive(Debug, Clone, Serialize)]
47pub struct ArxivEnrichment {
48    pub papers: Vec<ArxivPaper>,
49    pub source: ArxivSource,
50    pub query_terms: Vec<String>,
51}
52
53// =============================================================================
54// ArxivPaper conversions
55// =============================================================================
56
57impl ArxivPaper {
58    /// Convert from the existing curated `ArxivCitation` type.
59    pub fn from_citation(c: &ArxivCitation) -> Self {
60        Self {
61            arxiv_id: c.arxiv_id.clone(),
62            title: c.title.clone(),
63            authors: c.authors.clone(),
64            year: c.year,
65            summary: c.abstract_snippet.clone(),
66            url: c.url.clone(),
67            pdf_url: Some(format!("https://arxiv.org/pdf/{}.pdf", c.arxiv_id)),
68            published: None,
69        }
70    }
71}
72
73// =============================================================================
74// Search term derivation
75// =============================================================================
76
/// Maps a stack component name (as detected in the query) to arXiv search terms.
///
/// Names are unique; lookups take the first exact match.
const COMPONENT_TERMS: &[(&str, &[&str])] = &[
    // Speech, inference and compute
    ("whisper-apr", &["whisper", "speech recognition"]),
    ("realizar", &["model inference", "optimization"]),
    ("trueno", &["SIMD", "GPU compute"]),
    // Learning and training
    ("aprender", &["machine learning", "algorithms"]),
    ("entrenar", &["training", "fine-tuning", "LoRA"]),
    ("repartir", &["distributed computing", "parallelism"]),
    // Data, graph and retrieval
    ("trueno-db", &["database", "analytics"]),
    ("trueno-graph", &["graph neural network", "graph analytics"]),
    ("trueno-rag", &["retrieval augmented generation"]),
    // Simulation and games
    ("simular", &["simulation", "monte carlo"]),
    ("jugar", &["game engine", "reinforcement learning"]),
    // Data loading and model management
    ("alimentar", &["data loading", "parquet"]),
    ("pacha", &["model registry", "model management"]),
    // Transpilers and systems
    ("depyler", &["python", "transpilation"]),
    ("decy", &["C++", "transpilation"]),
    ("bashrs", &["shell", "scripting"]),
    ("pepita", &["kernel", "operating system"]),
];
97
98/// Domain-to-arXiv search-term mapping.
99const DOMAIN_TERMS: &[(ProblemDomain, &[&str])] = &[
100    (ProblemDomain::SpeechRecognition, &["speech recognition", "ASR"]),
101    (ProblemDomain::Inference, &["model inference", "serving"]),
102    (ProblemDomain::DeepLearning, &["deep learning", "transformer"]),
103    (ProblemDomain::SupervisedLearning, &["supervised learning", "classification"]),
104    (ProblemDomain::UnsupervisedLearning, &["unsupervised learning", "clustering"]),
105    (ProblemDomain::LinearAlgebra, &["linear algebra", "SIMD"]),
106    (ProblemDomain::VectorSearch, &["vector search", "embedding"]),
107    (ProblemDomain::GraphAnalytics, &["graph neural network"]),
108    (ProblemDomain::DistributedCompute, &["distributed computing"]),
109    (ProblemDomain::PythonMigration, &["python", "machine learning"]),
110    (ProblemDomain::CMigration, &["systems programming"]),
111    (ProblemDomain::ShellMigration, &["automation"]),
112    (ProblemDomain::DataPipeline, &["data pipeline", "ETL"]),
113    (ProblemDomain::ModelServing, &["model serving", "edge deployment"]),
114    (ProblemDomain::Testing, &["mutation testing", "software testing"]),
115    (ProblemDomain::Profiling, &["profiling", "tracing"]),
116    (ProblemDomain::Validation, &["validation", "quality"]),
117];
118
119/// Derive search terms from a parsed oracle query.
120///
121/// Priority: mentioned components > detected domains > fallback to keywords.
122pub fn derive_search_terms(parsed: &ParsedQuery) -> Vec<String> {
123    let mut terms = Vec::new();
124
125    // 1. Component mentions → mapped terms
126    for comp in &parsed.mentioned_components {
127        if let Some(&(_, mapped)) = COMPONENT_TERMS.iter().find(|&&(name, _)| name == comp) {
128            terms.extend(mapped.iter().map(|s| (*s).to_string()));
129        }
130    }
131
132    // 2. Detected domains → mapped terms
133    for domain in &parsed.domains {
134        if let Some(&(_, mapped)) = DOMAIN_TERMS.iter().find(|&&(d, _)| d == *domain) {
135            for t in mapped {
136                if !terms.iter().any(|existing| existing == *t) {
137                    terms.push((*t).to_string());
138                }
139            }
140        }
141    }
142
143    // 3. Detected algorithms → direct use
144    for algo in &parsed.algorithms {
145        let readable = algo.replace('_', " ");
146        if !terms.iter().any(|existing| existing == &readable) {
147            terms.push(readable);
148        }
149    }
150
151    // 4. Fallback: use first 3 keywords if nothing was derived
152    if terms.is_empty() {
153        terms.extend(parsed.keywords.iter().take(3).cloned());
154    }
155
156    terms
157}
158
159// =============================================================================
160// ArxivEnricher
161// =============================================================================
162
163/// Enricher that can query the builtin curated database or fetch live from arXiv.
164#[derive(Default)]
165pub struct ArxivEnricher;
166
167impl ArxivEnricher {
168    pub fn new() -> Self {
169        Self
170    }
171
172    /// Enrich from the builtin curated database (instant, no network).
173    pub fn enrich_builtin(&self, parsed: &ParsedQuery, max: usize) -> ArxivEnrichment {
174        let terms = derive_search_terms(parsed);
175        let db = ArxivDatabase::builtin();
176
177        let term_refs: Vec<&str> = terms.iter().map(|s| s.as_str()).collect();
178        let citations = db.find_by_keywords(&term_refs, max);
179        let papers = citations.iter().map(ArxivPaper::from_citation).collect();
180
181        ArxivEnrichment { papers, source: ArxivSource::Builtin, query_terms: terms }
182    }
183
184    /// Enrich from the live arXiv API. Falls back to builtin on error.
185    #[cfg(feature = "native")]
186    pub async fn enrich_live(&self, parsed: &ParsedQuery, max: usize) -> ArxivEnrichment {
187        let terms = derive_search_terms(parsed);
188        let search_query = terms.join(" ");
189
190        match fetch_arxiv_api(&search_query, max).await {
191            Ok(papers) if !papers.is_empty() => {
192                ArxivEnrichment { papers, source: ArxivSource::Live, query_terms: terms }
193            }
194            _ => {
195                // Fallback to builtin on error or empty results
196                self.enrich_builtin(parsed, max)
197            }
198        }
199    }
200}
201
202// =============================================================================
203// Live arXiv API
204// =============================================================================
205
/// Fetch papers from the arXiv Atom API.
///
/// `query` is searched across all fields (`all:` prefix), and at most `max` results
/// are requested. The query is percent-encoded, so terms containing reserved URL
/// characters (e.g. `C++`, `&`, `#`) reach the API intact instead of being
/// reinterpreted as query syntax.
///
/// # Errors
///
/// Returns an error if:
/// - the HTTP client cannot be built,
/// - the request fails or exceeds the 10 s timeout,
/// - the server answers with a non-success status,
/// - reading the body fails,
/// - [`parse_arxiv_atom_xml`] fails.
#[cfg(feature = "native")]
pub async fn fetch_arxiv_api(query: &str, max: usize) -> anyhow::Result<Vec<ArxivPaper>> {
    let url = format!(
        "http://export.arxiv.org/api/query?search_query=all:{}&start=0&max_results={}",
        encode_query_component(query),
        max
    );

    let client = reqwest::Client::builder().timeout(std::time::Duration::from_secs(10)).build()?;

    // A 4xx/5xx body is an error page, not an Atom feed: report it as an error
    // rather than parsing it.
    let body = client.get(&url).send().await?.error_for_status()?.text().await?;
    parse_arxiv_atom_xml(&body)
}

/// Percent-encode `s` for use as a URL query value.
///
/// Spaces become `+`. RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass
/// through unchanged, and every other byte of the UTF-8 encoding becomes `%XX`.
#[cfg(feature = "native")]
fn encode_query_component(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for byte in s.bytes() {
        match byte {
            b' ' => out.push('+'),
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                out.push(char::from(byte));
            }
            _ => {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}
220
/// Scratch storage for the fields of one Atom `<entry>` while it is being parsed.
#[cfg(feature = "native")]
#[derive(Default)]
struct EntryAccum {
    /// Raw `<id>` text; in arXiv feeds an abs URL such as `http://arxiv.org/abs/2212.04356v2`.
    id: String,
    /// Raw `<title>` text (may contain line breaks).
    title: String,
    /// Raw `<summary>` text (the abstract).
    summary: String,
    /// Raw `<published>` timestamp.
    published: String,
    /// One entry per `<name>` text seen inside an `<author>`.
    authors: Vec<String>,
    /// `href` of the `<link title="pdf">`, if seen.
    pdf_url: Option<String>,
}

#[cfg(feature = "native")]
impl EntryAccum {
    fn new() -> Self {
        Self::default()
    }

    /// Reset every field so the accumulator can collect the next entry.
    fn clear(&mut self) {
        *self = Self::default();
    }

    /// Turn the collected fields into an [`ArxivPaper`].
    ///
    /// Returns `None` when the entry lacks an id or a title.
    fn into_paper(self) -> Option<ArxivPaper> {
        let EntryAccum { id, title, summary, published, authors, pdf_url } = self;
        if id.is_empty() || title.is_empty() {
            return None;
        }

        let arxiv_id = extract_arxiv_id(&id);
        let year = extract_year(&published);
        let published = (!published.is_empty()).then_some(published);

        Some(ArxivPaper {
            url: format!("https://arxiv.org/abs/{}", arxiv_id),
            arxiv_id,
            title: normalize_whitespace(&title),
            authors: format_authors(&authors),
            year,
            summary: truncate_summary(&summary),
            pdf_url,
            published,
        })
    }

    /// Route a text node to the field named by `tag`.
    ///
    /// Text of unknown tags is dropped. `name` text counts only inside an `<author>`,
    /// where each text node becomes a separate author.
    fn push_text(&mut self, tag: &str, text: String, in_author: bool) {
        let field = match tag {
            "id" => &mut self.id,
            "title" => &mut self.title,
            "summary" => &mut self.summary,
            "published" => &mut self.published,
            "name" if in_author => {
                self.authors.push(text);
                return;
            }
            _ => return,
        };
        field.push_str(&text);
    }
}
283
/// Return the `href` of a `<link>` element when its `title` attribute is `"pdf"`.
///
/// Attributes that fail to parse are skipped. If `href` appears more than once, the
/// last value wins. Returns `None` when the link is not the PDF link or its `href`
/// is missing or empty.
#[cfg(feature = "native")]
fn extract_pdf_href(attrs: quick_xml::events::attributes::Attributes) -> Option<String> {
    let (is_pdf, href) = attrs.flatten().fold((false, String::new()), |(is_pdf, href), attr| {
        let key = String::from_utf8_lossy(attr.key.as_ref());
        let val = String::from_utf8_lossy(&attr.value);
        match &*key {
            "title" if val == "pdf" => (true, href),
            "href" => (is_pdf, val.into_owned()),
            _ => (is_pdf, href),
        }
    });

    if is_pdf && !href.is_empty() {
        Some(href)
    } else {
        None
    }
}
301
/// Event-driven state machine that turns quick-xml events from an arXiv Atom feed
/// into [`ArxivPaper`] values. Only events inside `<entry>` contribute data.
#[cfg(feature = "native")]
struct AtomParser {
    /// Completed papers, in document order.
    papers: Vec<ArxivPaper>,
    /// Fields gathered for the entry currently open.
    accum: EntryAccum,
    /// Most recently opened element inside the entry (other than `author`/`link`);
    /// cleared by every closing tag except `</author>`.
    current_tag: String,
    /// True between `<entry>` and `</entry>`.
    in_entry: bool,
    /// True between `<author>` and `</author>`.
    in_author: bool,
}

#[cfg(feature = "native")]
impl AtomParser {
    fn new() -> Self {
        AtomParser {
            papers: Vec::new(),
            accum: EntryAccum::new(),
            current_tag: String::new(),
            in_entry: false,
            in_author: false,
        }
    }

    /// Handle an opening tag.
    fn handle_start(&mut self, e: &quick_xml::events::BytesStart<'_>) {
        let tag = String::from_utf8_lossy(e.name().as_ref()).into_owned();

        if tag == "entry" {
            self.in_entry = true;
            self.accum.clear();
            return;
        }
        // Feed-level elements (feed title, opensearch counters, ...) are ignored.
        if !self.in_entry {
            return;
        }

        if tag == "author" {
            self.in_author = true;
        } else if tag == "link" {
            self.record_pdf_link(e);
        } else {
            self.current_tag = tag;
        }
    }

    /// Handle a self-closing tag; only `<link .../>` inside an entry matters.
    fn handle_empty(&mut self, e: &quick_xml::events::BytesStart<'_>) {
        if self.in_entry && String::from_utf8_lossy(e.name().as_ref()) == "link" {
            self.record_pdf_link(e);
        }
    }

    /// Handle a text node by routing it to the current entry field.
    fn handle_text(&mut self, e: &quick_xml::events::BytesText<'_>) {
        if self.in_entry {
            let text = e.unescape().unwrap_or_default().to_string();
            self.accum.push_text(&self.current_tag, text, self.in_author);
        }
    }

    /// Handle a closing tag.
    fn handle_end(&mut self, e: &quick_xml::events::BytesEnd<'_>) {
        let name = e.name();
        let tag = String::from_utf8_lossy(name.as_ref());

        if tag == "author" {
            self.in_author = false;
            return;
        }
        if tag == "entry" {
            self.finish_entry();
        }
        // Any other closing tag ends the text run of the current element.
        self.current_tag.clear();
    }

    /// Store the PDF link of a `<link>` element, if it is one.
    fn record_pdf_link(&mut self, e: &quick_xml::events::BytesStart<'_>) {
        if let Some(href) = extract_pdf_href(e.attributes()) {
            self.accum.pdf_url = Some(href);
        }
    }

    /// Convert the accumulated entry into a paper (if complete) and leave the entry.
    fn finish_entry(&mut self) {
        let finished = std::mem::replace(&mut self.accum, EntryAccum::new());
        self.papers.extend(finished.into_paper());
        self.in_entry = false;
    }
}
388
/// Parse an arXiv Atom XML feed into `ArxivPaper` structs.
///
/// Entries without an `<id>` or `<title>` are skipped.
///
/// # Errors
///
/// Returns an error if the XML is malformed before any complete `<entry>` has been
/// parsed. If the document breaks after one or more entries were parsed, those
/// entries are returned as a best-effort result instead of being discarded.
#[cfg(feature = "native")]
pub fn parse_arxiv_atom_xml(xml: &str) -> anyhow::Result<Vec<ArxivPaper>> {
    use quick_xml::events::Event;
    use quick_xml::Reader;

    let mut reader = Reader::from_str(xml);
    let mut parser = AtomParser::new();

    loop {
        match reader.read_event() {
            Ok(Event::Start(ref e)) => parser.handle_start(e),
            Ok(Event::Empty(ref e)) => parser.handle_empty(e),
            Ok(Event::Text(ref e)) => parser.handle_text(e),
            Ok(Event::End(ref e)) => parser.handle_end(e),
            Ok(Event::Eof) => break,
            // Nothing usable was parsed: report the problem instead of pretending the
            // feed was simply empty.
            Err(e) if parser.papers.is_empty() => {
                return Err(anyhow::anyhow!(
                    "malformed arXiv Atom XML at byte {}: {}",
                    reader.buffer_position(),
                    e
                ));
            }
            // Keep the entries that were parsed before the error.
            Err(_) => break,
            _ => {}
        }
    }

    Ok(parser.papers)
}
411
412// =============================================================================
413// Helpers
414// =============================================================================
415
/// Extract the arXiv identifier from an abstract URL.
///
/// Handles modern IDs (`http://arxiv.org/abs/2212.04356v1` → `2212.04356`) and
/// old-style archive IDs, which contain a slash
/// (`http://arxiv.org/abs/hep-th/9901001v1` → `hep-th/9901001`).
///
/// A trailing version suffix (`v` followed by one or more digits) is stripped, and
/// surrounding whitespace is ignored. Input without an `/abs/` segment falls back
/// to its last path component.
#[cfg(feature = "native")]
fn extract_arxiv_id(url: &str) -> String {
    let url = url.trim();
    // Old-style IDs ("archive/YYMMNNN") contain a slash, so taking only the last
    // path segment would silently drop the archive name.
    let id = match url.split_once("/abs/") {
        Some((_, rest)) => rest,
        None => url.rsplit('/').next().unwrap_or(url),
    };

    match id.rsplit_once('v') {
        Some((stem, version))
            if !version.is_empty() && version.bytes().all(|b| b.is_ascii_digit()) =>
        {
            stem.to_string()
        }
        _ => id.to_string(),
    }
}
428
/// Parse the year from the first four bytes of an ISO-8601 timestamp such as
/// `"2022-12-08T..."`.
///
/// Returns `0` when the input is shorter than four bytes, the fourth byte is not a
/// char boundary, or the prefix does not parse as a `u16`.
#[cfg(feature = "native")]
fn extract_year(published: &str) -> u16 {
    match published.get(0..4) {
        Some(prefix) => prefix.parse().unwrap_or(0),
        None => 0,
    }
}
434
/// Abbreviate an author list for display.
///
/// - no authors → `"Unknown"`
/// - one author → the name unchanged
/// - two authors → `"A & B"`
/// - three or more → `"A et al."` (first author only)
pub fn format_authors(authors: &[String]) -> String {
    match authors {
        [] => String::from("Unknown"),
        [only] => only.clone(),
        [first, second] => format!("{} & {}", first, second),
        [first, ..] => format!("{} et al.", first),
    }
}
444
/// Normalize whitespace and shorten a summary to at most 200 characters.
///
/// Summaries within the limit are returned unchanged apart from whitespace
/// normalization. Longer ones are cut back to the last complete word before the
/// limit and get `"..."` appended. A single word longer than the limit is
/// hard-cut. The limit counts characters, not bytes, so non-ASCII abstracts are
/// not cut short.
#[cfg(feature = "native")]
fn truncate_summary(s: &str) -> String {
    const MAX_CHARS: usize = 200;

    let normalized = normalize_whitespace(s);
    // Byte offset of the first character past the limit; `None` means it all fits.
    let cut = match normalized.char_indices().nth(MAX_CHARS) {
        Some((idx, _)) => idx,
        None => return normalized,
    };
    let head = &normalized[..cut];

    // If the limit lands exactly on a space, `head` already ends with a whole word.
    if normalized[cut..].starts_with(' ') {
        return format!("{}...", head);
    }
    match head.rfind(' ') {
        Some(idx) => format!("{}...", &head[..idx]),
        None => format!("{}...", head),
    }
}
461
/// Replace every run of Unicode whitespace with a single space and drop leading and
/// trailing whitespace.
#[cfg(feature = "native")]
fn normalize_whitespace(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split_whitespace() {
        if !out.is_empty() {
            out.push(' ');
        }
        out.push_str(word);
    }
    out
}
467
468// =============================================================================
469// Tests
470// =============================================================================
471
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oracle::QueryEngine;

    /// Parse `query` with a fresh engine and derive its arXiv search terms.
    fn terms_for(query: &str) -> Vec<String> {
        let engine = QueryEngine::new();
        let parsed = engine.parse(query);
        derive_search_terms(&parsed)
    }

    /// Parse `query` and run builtin enrichment capped at `max` papers.
    fn builtin_for(query: &str, max: usize) -> ArxivEnrichment {
        let engine = QueryEngine::new();
        let parsed = engine.parse(query);
        ArxivEnricher::new().enrich_builtin(&parsed, max)
    }

    // -- search term derivation ----------------------------------------------

    #[test]
    fn test_derive_terms_from_component() {
        let terms = terms_for("improve whisper-apr performance");
        let mentions_whisper = terms.iter().any(|t| t.contains("whisper"));
        assert!(mentions_whisper, "Expected 'whisper' in terms: {:?}", terms);
    }

    #[test]
    fn test_derive_terms_from_domain() {
        let terms = terms_for("train a classifier with supervised learning");
        assert!(!terms.is_empty(), "Expected non-empty terms for supervised learning query");
    }

    #[test]
    fn test_derive_terms_fallback() {
        // No component, domain or algorithm applies, so keywords must be used.
        let terms = terms_for("completely unknown xyzzy topic");
        assert!(!terms.is_empty(), "Expected keyword fallback for unknown query");
    }

    // -- builtin enrichment ---------------------------------------------------

    #[test]
    fn test_enrich_builtin_returns_papers() {
        let result = builtin_for("whisper speech recognition", 5);
        assert!(!result.papers.is_empty(), "Expected papers for 'whisper speech recognition'");
        assert_eq!(result.source, ArxivSource::Builtin);
    }

    #[test]
    fn test_enrich_builtin_respects_max() {
        let count = builtin_for("deep learning transformer attention", 1).papers.len();
        assert!(count <= 1, "Expected at most 1 paper, got {}", count);
    }

    #[test]
    fn test_enrich_builtin_no_results() {
        // Keywords might still match something; the point is that this does not
        // panic and reports the builtin source.
        let result = builtin_for("xyzzy nonexistent gibberish", 5);
        assert_eq!(result.source, ArxivSource::Builtin);
    }

    #[test]
    fn test_from_citation() {
        let citation = ArxivCitation {
            arxiv_id: "2212.04356".to_string(),
            title: "Whisper: Robust Speech Recognition".to_string(),
            authors: "Radford et al.".to_string(),
            year: 2022,
            url: "https://arxiv.org/abs/2212.04356".to_string(),
            abstract_snippet: "Multitask speech model.".to_string(),
            topics: vec!["speech".to_string()],
        };
        let paper = ArxivPaper::from_citation(&citation);

        assert_eq!(paper.arxiv_id, "2212.04356");
        assert_eq!(paper.year, 2022);
        assert_eq!(paper.url, "https://arxiv.org/abs/2212.04356");

        let pdf = paper.pdf_url.as_deref();
        assert!(pdf.is_some());
        assert!(pdf.expect("unexpected failure").contains("2212.04356"));
    }

    // -- author formatting ----------------------------------------------------

    #[test]
    fn test_format_authors_zero() {
        assert_eq!(format_authors(&[]), "Unknown");
    }

    #[test]
    fn test_format_authors_one() {
        let authors = ["Alice".to_string()];
        assert_eq!(format_authors(&authors), "Alice");
    }

    #[test]
    fn test_format_authors_two() {
        let authors = ["Alice".to_string(), "Bob".to_string()];
        assert_eq!(format_authors(&authors), "Alice & Bob");
    }

    #[test]
    fn test_format_authors_three() {
        let authors = ["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];
        assert_eq!(format_authors(&authors), "Alice et al.");
    }

    // -- Atom parsing (native only) -------------------------------------------

    #[cfg(feature = "native")]
    #[test]
    fn test_parse_arxiv_atom_xml() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <entry>
    <id>http://arxiv.org/abs/2212.04356v2</id>
    <title>Robust Speech Recognition via Large-Scale Weak Supervision</title>
    <summary>We study the capabilities of speech processing systems.</summary>
    <published>2022-12-06T00:00:00Z</published>
    <author><name>Alec Radford</name></author>
    <author><name>Jong Wook Kim</name></author>
    <author><name>Tao Xu</name></author>
    <link title="pdf" href="http://arxiv.org/pdf/2212.04356v2" rel="related" type="application/pdf"/>
  </entry>
</feed>"#;

        let papers = parse_arxiv_atom_xml(xml).expect("unexpected failure");
        assert_eq!(papers.len(), 1);

        let first = &papers[0];
        assert_eq!(first.arxiv_id, "2212.04356");
        assert!(first.title.contains("Robust Speech Recognition"));
        assert_eq!(first.year, 2022);
        assert_eq!(first.authors, "Alec Radford et al.");
        assert!(first.pdf_url.is_some());
        assert!(first.published.is_some());
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_parse_arxiv_atom_xml_empty() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
</feed>"#;
        let papers = parse_arxiv_atom_xml(xml).expect("unexpected failure");
        assert!(papers.is_empty());
    }

    // -- helpers (native only) ------------------------------------------------

    #[cfg(feature = "native")]
    #[test]
    fn test_extract_arxiv_id_with_version() {
        let id = extract_arxiv_id("http://arxiv.org/abs/2212.04356v2");
        assert_eq!(id, "2212.04356");
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_extract_arxiv_id_without_version() {
        let id = extract_arxiv_id("http://arxiv.org/abs/2212.04356");
        assert_eq!(id, "2212.04356");
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_truncate_summary_short() {
        assert_eq!(truncate_summary("Short text."), "Short text.");
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_truncate_summary_long() {
        let long = "a ".repeat(200);
        let truncated = truncate_summary(&long);
        assert!(truncated.len() <= 210);
        assert!(truncated.ends_with("..."));
    }

    #[cfg(feature = "native")]
    #[test]
    fn test_normalize_whitespace() {
        let collapsed = normalize_whitespace("hello\n  world\t\tfoo");
        assert_eq!(collapsed, "hello world foo");
    }

    // -- live API (network, ignored by default) -------------------------------

    #[ignore = "requires network access to arXiv API"]
    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_live_arxiv_query() {
        let papers =
            fetch_arxiv_api("whisper speech recognition", 3).await.expect("unexpected failure");
        assert!(!papers.is_empty());
        for paper in &papers {
            assert!(!paper.title.is_empty());
            assert!(!paper.arxiv_id.is_empty());
        }
    }
}