atomr_agents_deep_research_shell/
shallow.rs1use std::sync::Arc;
11use std::time::Instant;
12
13use async_trait::async_trait;
14use atomr_agents_deep_research_core::{
15 Citation, CitationStatus, NodeKind, NodeStep, Plan, RawSearchHit, ResearchRequest, ResearchResult,
16 ResearchState, Telemetry,
17};
18use atomr_agents_web_search_core::{WebSearch, WebSearchRequest};
19use chrono::Utc;
20
21use crate::error::{Result, ShellError};
22
23#[async_trait]
25pub trait ShallowResearcher: Send + Sync + 'static {
26 async fn run(&self, req: &ResearchRequest) -> Result<ResearchResult>;
32}
33
34pub struct DirectSearchShallow {
42 search: Arc<dyn WebSearch>,
43 pub min_results: u32,
45 pub source_label: String,
48}
49
50impl DirectSearchShallow {
51 pub fn new(search: Arc<dyn WebSearch>) -> Self {
54 let label = search.provider_name().to_string();
55 Self {
56 search,
57 min_results: 3,
58 source_label: label,
59 }
60 }
61
62 pub fn with_min_results(mut self, n: u32) -> Self {
65 self.min_results = n;
66 self
67 }
68
69 pub fn with_source_label(mut self, label: impl Into<String>) -> Self {
71 self.source_label = label.into();
72 self
73 }
74}
75
76#[async_trait]
77impl ShallowResearcher for DirectSearchShallow {
78 async fn run(&self, req: &ResearchRequest) -> Result<ResearchResult> {
79 let started = Instant::now();
80 let max_results = req.breadth.max(self.min_results);
81 let mut search_req = WebSearchRequest::new(req.query.clone()).with_max_results(max_results);
82 if !req.scope.allowed_domains.is_empty() {
83 search_req = search_req.with_allowed_domains(req.scope.allowed_domains.clone());
84 }
85 if !req.scope.blocked_domains.is_empty() {
86 search_req.blocked_domains = req.scope.blocked_domains.clone();
87 }
88
89 let hits = self
90 .search
91 .search(&search_req)
92 .await
93 .map_err(ShellError::WebSearch)?;
94
95 let now_ms = Utc::now().timestamp_millis();
96 let mut result = ResearchResult {
97 id: uuid::Uuid::new_v4().to_string(),
98 query: req.query.clone(),
99 strategy: "shallow-direct".to_string(),
100 state: ResearchState::Done,
101 final_report: None,
102 citations: Vec::new(),
103 plan: Some(Plan {
104 outline: vec!["Summary".to_string()],
105 sub_questions: Vec::new(),
106 rationale: None,
107 }),
108 transcript: Vec::new(),
109 coverage: Default::default(),
110 telemetry: Telemetry::default(),
111 artifacts: Default::default(),
112 model_id: None,
113 failure_reason: None,
114 created_at_ms: now_ms,
115 updated_at_ms: now_ms,
116 };
117
118 for h in &hits {
120 result.artifacts.raw_search_hits.push(RawSearchHit {
121 provider: self.source_label.clone(),
122 url: h.url.clone(),
123 title: h.title.clone(),
124 snippet: h.snippet.clone(),
125 source: h.source.clone(),
126 captured_at: Utc::now(),
127 sub_question_id: None,
128 content: h.content.clone(),
129 });
130 }
131
132 if hits.is_empty() {
134 result.final_report = Some(format!("# {}\n\nNo results.\n", req.query));
135 } else {
136 let mut body = String::new();
137 body.push_str(&format!("# {}\n\n", req.query));
138 for (i, h) in hits.iter().enumerate() {
139 let n = (i as u32) + 1;
140 body.push_str(&format!("[{n}] **{}** — {}\n\n", h.title, h.snippet));
141 let mut citation = Citation::new(n, h.url.clone(), h.title.clone(), h.snippet.clone());
142 citation.source = h.source.clone();
143 citation.published = h.published;
144 citation.status = CitationStatus::Verified;
145 result.citations.push(citation);
146 }
147 body.push_str("## References\n\n");
148 for c in &result.citations {
149 body.push_str(&format!("[{}] {}\n", c.number, c.url));
150 }
151 result.final_report = Some(body);
152 }
153
154 let summary = format!("Direct search returned {} hits", hits.len());
156 result.transcript.push(NodeStep {
157 role: NodeKind::Other,
158 label: "shallow-direct".to_string(),
159 ts: Utc::now(),
160 summary,
161 sub_question_id: None,
162 });
163
164 let elapsed_ms = started.elapsed().as_millis() as u64;
166 result.telemetry.tool_calls = 1;
167 result.telemetry.wall_ms = elapsed_ms;
168
169 let touch_ms = Utc::now().timestamp_millis();
170 result.updated_at_ms = touch_ms;
171
172 Ok(result)
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_web_search_core::{MockWebSearch, WebSearchHit};
    use url::Url;

    /// Builds a search hit for `url` whose snippet is derived from `title`.
    fn hit(url: &str, title: &str) -> WebSearchHit {
        let parsed = Url::parse(url).unwrap();
        let snippet = format!("snippet for {title}");
        WebSearchHit::new(parsed, title, snippet)
    }

    /// A backend with no fixtures yields a finished result whose report says
    /// that nothing was found.
    #[tokio::test]
    async fn empty_results_produce_no_results_report() {
        let researcher = DirectSearchShallow::new(Arc::new(MockWebSearch::new()));
        let request = ResearchRequest::new("anything");

        let outcome = researcher.run(&request).await.unwrap();

        assert_eq!(outcome.strategy, "shallow-direct");
        assert_eq!(outcome.state, ResearchState::Done);
        assert!(outcome.citations.is_empty());
        let report = outcome.final_report.as_deref().unwrap();
        assert!(report.contains("No results"));
    }

    /// Each returned hit becomes a citation numbered from 1, is listed in the
    /// report, and is kept as a raw search hit.
    #[tokio::test]
    async fn results_become_numbered_citations() {
        let fixtures = vec![
            hit("https://rust-lang.org/", "Rust"),
            hit("https://blog.rust-lang.org/", "Blog"),
        ];
        let backend = MockWebSearch::new().with_fixture("rust", fixtures);
        let researcher = DirectSearchShallow::new(Arc::new(backend));
        let request = ResearchRequest::new("rust language");

        let outcome = researcher.run(&request).await.unwrap();

        assert_eq!(outcome.citations.len(), 2);
        assert_eq!(outcome.citations[0].number, 1);
        assert_eq!(outcome.citations[1].number, 2);
        let report = outcome.final_report.unwrap();
        for needle in ["[1]", "[2]", "## References"] {
            assert!(report.contains(needle));
        }
        assert_eq!(outcome.telemetry.tool_calls, 1);
        assert_eq!(outcome.artifacts.raw_search_hits.len(), 2);
    }
}