atomr_agents_deep_research_shell/
shell.rs1use std::sync::Arc;
6
7use async_trait::async_trait;
8use atomr_agents_callable::Callable;
9use atomr_agents_core::{CallCtx, Result as CoreResult, Value};
10use atomr_agents_deep_research_core::{ResearchRequest, ResearchResult};
11use atomr_agents_deep_research_harness::DeepResearchHarnessRef;
12
13use crate::classifier::{IntentClassifier, ResearchTier};
14use crate::error::ShellError;
15use crate::shallow::ShallowResearcher;
16
17#[derive(Clone)]
23pub struct DeepResearchShell {
24 classifier: Arc<dyn IntentClassifier>,
25 shallow: Arc<dyn ShallowResearcher>,
26 deep: DeepResearchHarnessRef,
27 label: String,
28}
29
30impl DeepResearchShell {
31 pub fn new(
34 classifier: Arc<dyn IntentClassifier>,
35 shallow: Arc<dyn ShallowResearcher>,
36 deep: DeepResearchHarnessRef,
37 ) -> Self {
38 let label = format!("deep-research-shell:{}", deep.id.as_str());
39 Self {
40 classifier,
41 shallow,
42 deep,
43 label,
44 }
45 }
46
47 pub fn deep(&self) -> &DeepResearchHarnessRef {
49 &self.deep
50 }
51
52 pub async fn run(&self, req: ResearchRequest) -> CoreResult<ResearchResult> {
54 let tier = self.classifier.classify(&req).await.map_err(|e| {
55 match e {
58 ShellError::Classifier(_) => e,
59 other => ShellError::Classifier(other.to_string()),
60 }
61 })?;
62 match tier {
63 ResearchTier::Shallow => Ok(self.shallow.run(&req).await?),
64 ResearchTier::Deep => {
65 let v = self.deep.run(req).await?;
66 Ok(serde_json::from_value::<ResearchResult>(v).map_err(ShellError::Serde)?)
67 }
68 }
69 }
70}
71
72#[async_trait]
73impl Callable for DeepResearchShell {
74 async fn call(&self, input: Value, _ctx: CallCtx) -> CoreResult<Value> {
75 let req = parse_request(input)?;
76 let result = self.run(req).await?;
77 Ok(serde_json::to_value(&result).map_err(ShellError::Serde)?)
78 }
79
80 fn label(&self) -> &str {
81 &self.label
82 }
83}
84
85fn parse_request(input: Value) -> CoreResult<ResearchRequest> {
89 if let Some(s) = input.as_str() {
90 return Ok(ResearchRequest::new(s));
91 }
92 serde_json::from_value(input).map_err(|e| ShellError::Serde(e).into())
93}