1use super::config::DeepAgentConfig;
7use crate::middleware::{
8 AgentMiddleware, AnthropicPromptCachingMiddleware, BaseSystemPromptMiddleware,
9 FilesystemMiddleware, HumanInLoopMiddleware, MiddlewareContext, ModelRequest,
10 PlanningMiddleware, SubAgentDescriptor, SubAgentMiddleware, SubAgentRegistration,
11 SummarizationMiddleware,
12};
13use crate::planner::LlmBackedPlanner;
14use agents_core::agent::{
15 AgentDescriptor, AgentHandle, PlannerAction, PlannerContext, PlannerHandle,
16};
17use agents_core::hitl::{AgentInterrupt, HitlAction, HitlInterrupt};
18use agents_core::messaging::{AgentMessage, MessageContent, MessageMetadata, MessageRole};
19use agents_core::persistence::{Checkpointer, ThreadId};
20use agents_core::state::AgentStateSnapshot;
21use agents_core::tools::{ToolBox, ToolContext, ToolResult};
22use async_trait::async_trait;
23use serde_json::Value;
24use std::collections::{HashMap, HashSet};
25use std::sync::{Arc, RwLock};
26
/// Tool names treated as "built-in".
///
/// Only tools with one of these names are subject to the `builtin_tools`
/// allow-list (see `DeepAgent::should_include`); a tool with any other name is
/// always exposed to the planner.
const BUILTIN_TOOL_NAMES: &[&str] = &[
    "write_todos",
    "ls",
    "read_file",
    "write_file",
    "edit_file",
];
29
30pub struct DeepAgent {
37 descriptor: AgentDescriptor,
38 instructions: String,
39 planner: Arc<dyn PlannerHandle>,
40 middlewares: Vec<Arc<dyn AgentMiddleware>>,
41 base_tools: Vec<ToolBox>,
42 state: Arc<RwLock<AgentStateSnapshot>>,
43 history: Arc<RwLock<Vec<AgentMessage>>>,
44 _summarization: Option<Arc<SummarizationMiddleware>>,
45 hitl: Option<Arc<HumanInLoopMiddleware>>,
46 pending_hitl: Arc<RwLock<Option<HitlPending>>>,
47 builtin_tools: Option<HashSet<String>>,
48 checkpointer: Option<Arc<dyn Checkpointer>>,
49}
50
51struct HitlPending {
52 tool_name: String,
53 payload: Value,
54 tool: ToolBox,
55 message: AgentMessage,
56}
57
58impl DeepAgent {
59 fn collect_tools(&self) -> HashMap<String, ToolBox> {
60 let mut tools: HashMap<String, ToolBox> = HashMap::new();
61 for tool in &self.base_tools {
62 tools.insert(tool.schema().name.clone(), tool.clone());
63 }
64 for middleware in &self.middlewares {
65 for tool in middleware.tools() {
66 let tool_name = tool.schema().name.clone();
67 if self.should_include(&tool_name) {
68 tools.insert(tool_name, tool);
69 }
70 }
71 }
72 tools
73 }
74 fn should_include(&self, name: &str) -> bool {
77 let is_builtin = BUILTIN_TOOL_NAMES.contains(&name);
78 if !is_builtin {
79 return true;
80 }
81 match &self.builtin_tools {
82 None => true,
83 Some(selected) => selected.contains(name),
84 }
85 }
86
87 fn append_history(&self, message: AgentMessage) {
88 if let Ok(mut history) = self.history.write() {
89 history.push(message);
90 }
91 }
92
93 fn current_history(&self) -> Vec<AgentMessage> {
94 self.history.read().map(|h| h.clone()).unwrap_or_default()
95 }
96
97 pub async fn save_state(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
99 if let Some(ref checkpointer) = self.checkpointer {
100 let state = self
101 .state
102 .read()
103 .map_err(|_| anyhow::anyhow!("Failed to read agent state"))?
104 .clone();
105 checkpointer.save_state(thread_id, &state).await
106 } else {
107 tracing::warn!("Attempted to save state but no checkpointer is configured");
108 Ok(())
109 }
110 }
111
112 pub async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<bool> {
114 if let Some(ref checkpointer) = self.checkpointer {
115 if let Some(saved_state) = checkpointer.load_state(thread_id).await? {
116 *self
117 .state
118 .write()
119 .map_err(|_| anyhow::anyhow!("Failed to write agent state"))? = saved_state;
120 tracing::info!(thread_id = %thread_id, "Loaded agent state from checkpointer");
121 Ok(true)
122 } else {
123 tracing::debug!(thread_id = %thread_id, "No saved state found for thread");
124 Ok(false)
125 }
126 } else {
127 tracing::warn!("Attempted to load state but no checkpointer is configured");
128 Ok(false)
129 }
130 }
131
132 pub async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
134 if let Some(ref checkpointer) = self.checkpointer {
135 checkpointer.delete_thread(thread_id).await
136 } else {
137 tracing::warn!("Attempted to delete thread state but no checkpointer is configured");
138 Ok(())
139 }
140 }
141
142 pub async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
144 if let Some(ref checkpointer) = self.checkpointer {
145 checkpointer.list_threads().await
146 } else {
147 Ok(Vec::new())
148 }
149 }
150
151 async fn execute_tool(
152 &self,
153 tool: ToolBox,
154 _tool_name: String,
155 payload: Value,
156 ) -> anyhow::Result<AgentMessage> {
157 let state_snapshot = self.state.read().unwrap().clone();
158 let ctx = ToolContext::with_mutable_state(Arc::new(state_snapshot), self.state.clone());
159
160 let result = tool.execute(payload, ctx).await?;
161 Ok(self.apply_tool_result(result))
162 }
163
164 fn apply_tool_result(&self, result: ToolResult) -> AgentMessage {
165 match result {
166 ToolResult::Message(message) => {
167 self.append_history(message.clone());
168 message
169 }
170 ToolResult::WithStateUpdate {
171 message,
172 state_diff,
173 } => {
174 if let Ok(mut state) = self.state.write() {
175 let command = agents_core::command::Command::with_state(state_diff);
176 command.apply_to(&mut state);
177 }
178 self.append_history(message.clone());
179 message
180 }
181 }
182 }
183
184 pub fn current_interrupt(&self) -> Option<AgentInterrupt> {
185 self.pending_hitl.read().ok().and_then(|guard| {
186 guard.as_ref().map(|pending| {
187 AgentInterrupt::HumanInLoop(HitlInterrupt {
188 tool_name: pending.tool_name.clone(),
189 message: pending.message.clone(),
190 })
191 })
192 })
193 }
194
195 pub async fn resume_hitl(&self, action: HitlAction) -> anyhow::Result<AgentMessage> {
196 let pending = self
197 .pending_hitl
198 .write()
199 .ok()
200 .and_then(|mut guard| guard.take())
201 .ok_or_else(|| anyhow::anyhow!("No pending HITL action"))?;
202 match action {
203 HitlAction::Approve => {
204 let result = self
205 .execute_tool(
206 pending.tool.clone(),
207 pending.tool_name.clone(),
208 pending.payload.clone(),
209 )
210 .await?;
211 Ok(result)
212 }
213 HitlAction::Reject { reason } => {
214 let text =
215 reason.unwrap_or_else(|| "Tool execution rejected by human reviewer.".into());
216 let message = AgentMessage {
217 role: MessageRole::System,
218 content: MessageContent::Text(text),
219 metadata: None,
220 };
221 self.append_history(message.clone());
222 Ok(message)
223 }
224 HitlAction::Respond { message } => {
225 self.append_history(message.clone());
226 Ok(message)
227 }
228 HitlAction::Edit { action, args } => {
229 let tools = self.collect_tools();
231 if let Some(tool) = tools.get(&action).cloned() {
232 let result = self.execute_tool(tool, action, args).await?;
233 Ok(result)
234 } else {
235 Ok(AgentMessage {
236 role: MessageRole::System,
237 content: MessageContent::Text(format!(
238 "Edited tool '{}' not available",
239 action
240 )),
241 metadata: None,
242 })
243 }
244 }
245 }
246 }
247
248 pub async fn handle_message(
250 &self,
251 input: impl AsRef<str>,
252 state: Arc<AgentStateSnapshot>,
253 ) -> anyhow::Result<AgentMessage> {
254 self.handle_message_with_metadata(input, None, state).await
255 }
256
257 pub async fn handle_message_with_metadata(
259 &self,
260 input: impl AsRef<str>,
261 metadata: Option<MessageMetadata>,
262 state: Arc<AgentStateSnapshot>,
263 ) -> anyhow::Result<AgentMessage> {
264 let agent_message = AgentMessage {
265 role: MessageRole::User,
266 content: MessageContent::Text(input.as_ref().to_string()),
267 metadata,
268 };
269 self.handle_message_internal(agent_message, state).await
270 }
271
272 async fn handle_message_internal(
274 &self,
275 input: AgentMessage,
276 _state: Arc<AgentStateSnapshot>,
277 ) -> anyhow::Result<AgentMessage> {
278 self.append_history(input.clone());
279
280 let mut request = ModelRequest::new(&self.instructions, self.current_history());
281 let tools = self.collect_tools();
282 for middleware in &self.middlewares {
283 let mut ctx = MiddlewareContext::with_request(&mut request, self.state.clone());
284 middleware.modify_model_request(&mut ctx).await?;
285 }
286
287 let tool_schemas: Vec<_> = tools.values().map(|t| t.schema()).collect();
288 let context = PlannerContext {
289 history: request.messages.clone(),
290 system_prompt: request.system_prompt.clone(),
291 tools: tool_schemas,
292 };
293 let state_snapshot = Arc::new(self.state.read().map(|s| s.clone()).unwrap_or_default());
294
295 let decision = self.planner.plan(context, state_snapshot).await?;
296
297 match decision.next_action {
298 PlannerAction::Respond { message } => {
299 self.append_history(message.clone());
300 Ok(message)
301 }
302 PlannerAction::CallTool { tool_name, payload } => {
303 if let Some(tool) = tools.get(&tool_name).cloned() {
304 if let Some(hitl) = &self.hitl {
305 if let Some(policy) = hitl.requires_approval(&tool_name) {
306 let message_text = policy
307 .note
308 .clone()
309 .unwrap_or_else(|| "Awaiting human approval.".into());
310 let approval_message = AgentMessage {
311 role: MessageRole::System,
312 content: MessageContent::Text(format!(
313 "HITL_REQUIRED: Tool '{tool}' requires approval: {message}",
314 tool = tool_name,
315 message = message_text
316 )),
317 metadata: None,
318 };
319 let pending = HitlPending {
320 tool_name: tool_name.clone(),
321 payload: payload.clone(),
322 tool: tool.clone(),
323 message: approval_message.clone(),
324 };
325 if let Ok(mut guard) = self.pending_hitl.write() {
326 *guard = Some(pending);
327 }
328 self.append_history(approval_message.clone());
329 return Ok(approval_message);
330 }
331 }
332 self.execute_tool(tool.clone(), tool_name.clone(), payload.clone())
333 .await
334 } else {
335 Ok(AgentMessage {
336 role: MessageRole::Tool,
337 content: MessageContent::Text(format!(
338 "Tool '{tool}' not available",
339 tool = tool_name
340 )),
341 metadata: Some(MessageMetadata {
342 tool_call_id: None,
343 cache_control: None,
344 }),
345 })
346 }
347 }
348 PlannerAction::Terminate => Ok(AgentMessage {
349 role: MessageRole::Agent,
350 content: MessageContent::Text("Terminating conversation.".into()),
351 metadata: Some(MessageMetadata {
352 tool_call_id: None,
353 cache_control: None,
354 }),
355 }),
356 }
357 }
358}
359
360#[async_trait]
361impl AgentHandle for DeepAgent {
362 async fn describe(&self) -> AgentDescriptor {
363 self.descriptor.clone()
364 }
365
366 async fn handle_message(
367 &self,
368 input: AgentMessage,
369 _state: Arc<AgentStateSnapshot>,
370 ) -> anyhow::Result<AgentMessage> {
371 self.handle_message_internal(input, _state).await
372 }
373
374 async fn handle_message_stream(
375 &self,
376 input: AgentMessage,
377 _state: Arc<AgentStateSnapshot>,
378 ) -> anyhow::Result<agents_core::agent::AgentStream> {
379 use crate::planner::LlmBackedPlanner;
380 use agents_core::llm::{LlmRequest, StreamChunk};
381
382 self.append_history(input.clone());
384
385 let mut request = ModelRequest::new(&self.instructions, self.current_history());
387 let tools = self.collect_tools();
388
389 for middleware in &self.middlewares {
391 let mut ctx = MiddlewareContext::with_request(&mut request, self.state.clone());
392 middleware.modify_model_request(&mut ctx).await?;
393 }
394
395 let tool_schemas: Vec<_> = tools.values().map(|t| t.schema()).collect();
397 let llm_request = LlmRequest {
398 system_prompt: request.system_prompt.clone(),
399 messages: request.messages.clone(),
400 tools: tool_schemas,
401 };
402
403 let planner_any = self.planner.as_any();
405
406 if let Some(llm_planner) = planner_any.downcast_ref::<LlmBackedPlanner>() {
407 let model = llm_planner.model().clone();
409 let stream = model.generate_stream(llm_request).await?;
410 Ok(stream)
411 } else {
412 let response = self.handle_message_internal(input, _state).await?;
414 Ok(Box::pin(futures::stream::once(async move {
415 Ok(StreamChunk::Done { message: response })
416 })))
417 }
418 }
419}
420
421pub fn create_deep_agent_from_config(config: DeepAgentConfig) -> DeepAgent {
426 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
427 let history = Arc::new(RwLock::new(Vec::<AgentMessage>::new()));
428
429 let planning = Arc::new(PlanningMiddleware::new(state.clone()));
430 let filesystem = Arc::new(FilesystemMiddleware::new(state.clone()));
431
432 let mut registrations: Vec<SubAgentRegistration> = Vec::new();
434
435 for subagent_config in &config.subagent_configs {
437 let sub_planner = if let Some(ref model) = subagent_config.model {
439 Arc::new(LlmBackedPlanner::new(model.clone())) as Arc<dyn PlannerHandle>
441 } else {
442 config.planner.clone()
444 };
445
446 let mut sub_cfg = DeepAgentConfig::new(subagent_config.instructions.clone(), sub_planner);
448
449 if let Some(ref tools) = subagent_config.tools {
451 for tool in tools {
452 sub_cfg = sub_cfg.with_tool(tool.clone());
453 }
454 }
455
456 if let Some(ref builtin) = subagent_config.builtin_tools {
458 sub_cfg = sub_cfg.with_builtin_tools(builtin.iter().cloned());
459 }
460
461 sub_cfg = sub_cfg.with_auto_general_purpose(false);
463
464 sub_cfg = sub_cfg.with_prompt_caching(subagent_config.enable_prompt_caching);
466
467 let sub_agent = create_deep_agent_from_config(sub_cfg);
469
470 registrations.push(SubAgentRegistration {
472 descriptor: SubAgentDescriptor {
473 name: subagent_config.name.clone(),
474 description: subagent_config.description.clone(),
475 },
476 agent: Arc::new(sub_agent),
477 });
478 }
479
480 if config.auto_general_purpose {
482 let has_gp = registrations
483 .iter()
484 .any(|r| r.descriptor.name == "general-purpose");
485 if !has_gp {
486 let mut sub_cfg =
488 DeepAgentConfig::new(config.instructions.clone(), config.planner.clone())
489 .with_auto_general_purpose(false)
490 .with_prompt_caching(config.enable_prompt_caching);
491 if let Some(ref selected) = config.builtin_tools {
492 sub_cfg = sub_cfg.with_builtin_tools(selected.iter().cloned());
493 }
494 if let Some(ref sum) = config.summarization {
495 sub_cfg = sub_cfg.with_summarization(sum.clone());
496 }
497 for t in &config.tools {
498 sub_cfg = sub_cfg.with_tool(t.clone());
499 }
500
501 let gp = create_deep_agent_from_config(sub_cfg);
502 registrations.push(SubAgentRegistration {
503 descriptor: SubAgentDescriptor {
504 name: "general-purpose".into(),
505 description: "Default reasoning agent".into(),
506 },
507 agent: Arc::new(gp),
508 });
509 }
510 }
511
512 let subagent = Arc::new(SubAgentMiddleware::new(registrations));
513 let base_prompt = Arc::new(BaseSystemPromptMiddleware);
514 let summarization = config.summarization.as_ref().map(|cfg| {
515 Arc::new(SummarizationMiddleware::new(
516 cfg.messages_to_keep,
517 cfg.summary_note.clone(),
518 ))
519 });
520 let hitl = if config.tool_interrupts.is_empty() {
521 None
522 } else {
523 Some(Arc::new(HumanInLoopMiddleware::new(
524 config.tool_interrupts.clone(),
525 )))
526 };
527
528 let mut middlewares: Vec<Arc<dyn AgentMiddleware>> =
530 vec![base_prompt, planning, filesystem, subagent];
531 if let Some(ref summary) = summarization {
532 middlewares.push(summary.clone());
533 }
534 if config.enable_prompt_caching {
535 middlewares.push(Arc::new(AnthropicPromptCachingMiddleware::with_defaults()));
536 }
537 if let Some(ref hitl_mw) = hitl {
538 middlewares.push(hitl_mw.clone());
539 }
540
541 DeepAgent {
542 descriptor: AgentDescriptor {
543 name: "deep-agent".into(),
544 version: "0.0.1".into(),
545 description: Some("Rust deep agent".into()),
546 },
547 instructions: config.instructions,
548 planner: config.planner,
549 middlewares,
550 base_tools: config.tools,
551 state,
552 history,
553 _summarization: summarization,
554 hitl,
555 pending_hitl: Arc::new(RwLock::new(None)),
556 builtin_tools: config.builtin_tools,
557 checkpointer: config.checkpointer,
558 }
559}