Skip to main content

atomr_agents_deep_research_shell/
shallow.rs

1//! Shallow research path.
2//!
3//! When the [`IntentClassifier`](crate::IntentClassifier) routes a
4//! request to [`ResearchTier::Shallow`](crate::ResearchTier), the shell
5//! defers to a [`ShallowResearcher`] instead of the full deep harness.
6//! The default [`DirectSearchShallow`] issues one [`WebSearch`] call
7//! and synthesizes a [`ResearchResult`] directly — no clarifier, no
8//! planner, no critic, no verify loop.
9
10use std::sync::Arc;
11use std::time::Instant;
12
13use async_trait::async_trait;
14use atomr_agents_deep_research_core::{
15    Citation, CitationStatus, NodeKind, NodeStep, Plan, RawSearchHit, ResearchRequest, ResearchResult,
16    ResearchState, Telemetry,
17};
18use atomr_agents_web_search_core::{WebSearch, WebSearchRequest};
19use chrono::Utc;
20
21use crate::error::{Result, ShellError};
22
23/// Object-safe trait for the shallow research path.
24#[async_trait]
25pub trait ShallowResearcher: Send + Sync + 'static {
26    /// Produce a [`ResearchResult`] without engaging the deep harness.
27    ///
28    /// Implementations should set `result.strategy` to something
29    /// descriptive so callers can tell shallow runs from deep runs
30    /// downstream.
31    async fn run(&self, req: &ResearchRequest) -> Result<ResearchResult>;
32}
33
34/// Default shallow researcher: one web-search call, results rendered as
35/// a numbered markdown report.
36///
37/// This is intentionally non-LLM-driven so the shell can serve fast
38/// queries without provider credentials. It mirrors the
39/// `DeepResearchRoles::defaults()` philosophy: deterministic baseline,
40/// callers swap in something smarter when they need it.
41pub struct DirectSearchShallow {
42    search: Arc<dyn WebSearch>,
43    /// Floor on `max_results` per search. Defaults to `3`.
44    pub min_results: u32,
45    /// Provider label recorded against each `RawSearchHit` and used as
46    /// the shallow `strategy` source tag in the transcript.
47    pub source_label: String,
48}
49
50impl DirectSearchShallow {
51    /// Wire a shallow researcher around an existing `WebSearch`
52    /// implementation.
53    pub fn new(search: Arc<dyn WebSearch>) -> Self {
54        let label = search.provider_name().to_string();
55        Self {
56            search,
57            min_results: 3,
58            source_label: label,
59        }
60    }
61
62    /// Override the `min_results` floor (the actual request uses
63    /// `req.breadth.max(min_results)`).
64    pub fn with_min_results(mut self, n: u32) -> Self {
65        self.min_results = n;
66        self
67    }
68
69    /// Override the provider label stamped onto raw hits.
70    pub fn with_source_label(mut self, label: impl Into<String>) -> Self {
71        self.source_label = label.into();
72        self
73    }
74}
75
76#[async_trait]
77impl ShallowResearcher for DirectSearchShallow {
78    async fn run(&self, req: &ResearchRequest) -> Result<ResearchResult> {
79        let started = Instant::now();
80        let max_results = req.breadth.max(self.min_results);
81        let mut search_req = WebSearchRequest::new(req.query.clone()).with_max_results(max_results);
82        if !req.scope.allowed_domains.is_empty() {
83            search_req = search_req.with_allowed_domains(req.scope.allowed_domains.clone());
84        }
85        if !req.scope.blocked_domains.is_empty() {
86            search_req.blocked_domains = req.scope.blocked_domains.clone();
87        }
88
89        let hits = self
90            .search
91            .search(&search_req)
92            .await
93            .map_err(ShellError::WebSearch)?;
94
95        let now_ms = Utc::now().timestamp_millis();
96        let mut result = ResearchResult {
97            id: uuid::Uuid::new_v4().to_string(),
98            query: req.query.clone(),
99            strategy: "shallow-direct".to_string(),
100            state: ResearchState::Done,
101            final_report: None,
102            citations: Vec::new(),
103            plan: Some(Plan {
104                outline: vec!["Summary".to_string()],
105                sub_questions: Vec::new(),
106                rationale: None,
107            }),
108            transcript: Vec::new(),
109            coverage: Default::default(),
110            telemetry: Telemetry::default(),
111            artifacts: Default::default(),
112            model_id: None,
113            failure_reason: None,
114            created_at_ms: now_ms,
115            updated_at_ms: now_ms,
116        };
117
118        // Record raw hits in artifacts (one per returned hit).
119        for h in &hits {
120            result.artifacts.raw_search_hits.push(RawSearchHit {
121                provider: self.source_label.clone(),
122                url: h.url.clone(),
123                title: h.title.clone(),
124                snippet: h.snippet.clone(),
125                source: h.source.clone(),
126                captured_at: Utc::now(),
127                sub_question_id: None,
128                content: h.content.clone(),
129            });
130        }
131
132        // Build citations + final report.
133        if hits.is_empty() {
134            result.final_report = Some(format!("# {}\n\nNo results.\n", req.query));
135        } else {
136            let mut body = String::new();
137            body.push_str(&format!("# {}\n\n", req.query));
138            for (i, h) in hits.iter().enumerate() {
139                let n = (i as u32) + 1;
140                body.push_str(&format!("[{n}] **{}** — {}\n\n", h.title, h.snippet));
141                let mut citation = Citation::new(n, h.url.clone(), h.title.clone(), h.snippet.clone());
142                citation.source = h.source.clone();
143                citation.published = h.published;
144                citation.status = CitationStatus::Verified;
145                result.citations.push(citation);
146            }
147            body.push_str("## References\n\n");
148            for c in &result.citations {
149                body.push_str(&format!("[{}] {}\n", c.number, c.url));
150            }
151            result.final_report = Some(body);
152        }
153
154        // Add a single transcript entry summarizing the shallow run.
155        let summary = format!("Direct search returned {} hits", hits.len());
156        result.transcript.push(NodeStep {
157            role: NodeKind::Other,
158            label: "shallow-direct".to_string(),
159            ts: Utc::now(),
160            summary,
161            sub_question_id: None,
162        });
163
164        // Trivial telemetry: one tool call, measured wall time.
165        let elapsed_ms = started.elapsed().as_millis() as u64;
166        result.telemetry.tool_calls = 1;
167        result.telemetry.wall_ms = elapsed_ms;
168
169        let touch_ms = Utc::now().timestamp_millis();
170        result.updated_at_ms = touch_ms;
171
172        Ok(result)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_web_search_core::{MockWebSearch, WebSearchHit};
    use url::Url;

    /// Build a fixture hit whose snippet is derived from its title.
    fn hit(url: &str, title: &str) -> WebSearchHit {
        let parsed = Url::parse(url).unwrap();
        let snippet = format!("snippet for {title}");
        WebSearchHit::new(parsed, title, snippet)
    }

    /// A backend without fixtures yields the "No results" report and no
    /// citations.
    #[tokio::test]
    async fn empty_results_produce_no_results_report() {
        let researcher = DirectSearchShallow::new(Arc::new(MockWebSearch::new()));
        let request = ResearchRequest::new("anything");

        let outcome = researcher.run(&request).await.unwrap();

        assert_eq!(outcome.strategy, "shallow-direct");
        assert_eq!(outcome.state, ResearchState::Done);
        assert!(outcome.citations.is_empty());
        let report = outcome.final_report.as_deref().unwrap();
        assert!(report.contains("No results"));
    }

    /// Each returned hit becomes a 1-based citation, a raw artifact and
    /// a numbered line in the report.
    #[tokio::test]
    async fn results_become_numbered_citations() {
        let fixtures = vec![
            hit("https://rust-lang.org/", "Rust"),
            hit("https://blog.rust-lang.org/", "Blog"),
        ];
        let backend = MockWebSearch::new().with_fixture("rust", fixtures);
        let researcher = DirectSearchShallow::new(Arc::new(backend));
        let request = ResearchRequest::new("rust language");

        let outcome = researcher.run(&request).await.unwrap();

        assert_eq!(outcome.citations.len(), 2);
        assert_eq!(outcome.citations[0].number, 1);
        assert_eq!(outcome.citations[1].number, 2);

        let report = outcome.final_report.unwrap();
        for marker in ["[1]", "[2]", "## References"] {
            assert!(report.contains(marker));
        }

        assert_eq!(outcome.telemetry.tool_calls, 1);
        assert_eq!(outcome.artifacts.raw_search_hits.len(), 2);
    }
}