1use super::config::DeepAgentConfig;
7use crate::middleware::{
8 AgentMiddleware, AnthropicPromptCachingMiddleware, BaseSystemPromptMiddleware,
9 DeepAgentPromptMiddleware, FilesystemMiddleware, HumanInLoopMiddleware, MiddlewareContext,
10 ModelRequest, PlanningMiddleware, SubAgentDescriptor, SubAgentMiddleware, SubAgentRegistration,
11 SummarizationMiddleware,
12};
13use crate::planner::LlmBackedPlanner;
14use agents_core::agent::{
15 AgentDescriptor, AgentHandle, PlannerAction, PlannerContext, PlannerHandle,
16};
17use agents_core::hitl::{AgentInterrupt, HitlAction};
18use agents_core::messaging::{AgentMessage, MessageContent, MessageMetadata, MessageRole};
19use agents_core::persistence::{Checkpointer, ThreadId};
20use agents_core::state::AgentStateSnapshot;
21use agents_core::tools::{ToolBox, ToolContext, ToolResult};
22use async_trait::async_trait;
23use serde_json::Value;
24use std::collections::{HashMap, HashSet};
25use std::sync::{Arc, RwLock};
26
/// Tool names treated as built-ins.
///
/// `DeepAgent::should_include` only filters these names against a configured
/// `builtin_tools` selection; any tool name not listed here is always exposed.
const BUILTIN_TOOL_NAMES: &[&str] = &[
    "write_todos",
    "ls",
    "read_file",
    "write_file",
    "edit_file",
];
29
30pub struct DeepAgent {
37 descriptor: AgentDescriptor,
38 instructions: String,
39 planner: Arc<dyn PlannerHandle>,
40 middlewares: Vec<Arc<dyn AgentMiddleware>>,
41 base_tools: Vec<ToolBox>,
42 state: Arc<RwLock<AgentStateSnapshot>>,
43 history: Arc<RwLock<Vec<AgentMessage>>>,
44 _summarization: Option<Arc<SummarizationMiddleware>>,
45 _hitl: Option<Arc<HumanInLoopMiddleware>>,
46 builtin_tools: Option<HashSet<String>>,
47 checkpointer: Option<Arc<dyn Checkpointer>>,
48}
49
50impl DeepAgent {
51 fn collect_tools(&self) -> HashMap<String, ToolBox> {
52 let mut tools: HashMap<String, ToolBox> = HashMap::new();
53 for tool in &self.base_tools {
54 tools.insert(tool.schema().name.clone(), tool.clone());
55 }
56 for middleware in &self.middlewares {
57 for tool in middleware.tools() {
58 let tool_name = tool.schema().name.clone();
59 if self.should_include(&tool_name) {
60 tools.insert(tool_name, tool);
61 }
62 }
63 }
64 tools
65 }
66 fn should_include(&self, name: &str) -> bool {
69 let is_builtin = BUILTIN_TOOL_NAMES.contains(&name);
70 if !is_builtin {
71 return true;
72 }
73 match &self.builtin_tools {
74 None => true,
75 Some(selected) => selected.contains(name),
76 }
77 }
78
79 fn append_history(&self, message: AgentMessage) {
80 if let Ok(mut history) = self.history.write() {
81 history.push(message);
82 }
83 }
84
85 fn current_history(&self) -> Vec<AgentMessage> {
86 self.history.read().map(|h| h.clone()).unwrap_or_default()
87 }
88
89 pub async fn save_state(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
91 if let Some(ref checkpointer) = self.checkpointer {
92 let state = self
93 .state
94 .read()
95 .map_err(|_| anyhow::anyhow!("Failed to read agent state"))?
96 .clone();
97 checkpointer.save_state(thread_id, &state).await
98 } else {
99 tracing::warn!("Attempted to save state but no checkpointer is configured");
100 Ok(())
101 }
102 }
103
104 pub async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<bool> {
106 if let Some(ref checkpointer) = self.checkpointer {
107 if let Some(saved_state) = checkpointer.load_state(thread_id).await? {
108 *self
109 .state
110 .write()
111 .map_err(|_| anyhow::anyhow!("Failed to write agent state"))? = saved_state;
112 tracing::info!(thread_id = %thread_id, "Loaded agent state from checkpointer");
113 Ok(true)
114 } else {
115 tracing::debug!(thread_id = %thread_id, "No saved state found for thread");
116 Ok(false)
117 }
118 } else {
119 tracing::warn!("Attempted to load state but no checkpointer is configured");
120 Ok(false)
121 }
122 }
123
124 pub async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
126 if let Some(ref checkpointer) = self.checkpointer {
127 checkpointer.delete_thread(thread_id).await
128 } else {
129 tracing::warn!("Attempted to delete thread state but no checkpointer is configured");
130 Ok(())
131 }
132 }
133
134 pub async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
136 if let Some(ref checkpointer) = self.checkpointer {
137 checkpointer.list_threads().await
138 } else {
139 Ok(Vec::new())
140 }
141 }
142
143 async fn execute_tool(
144 &self,
145 tool: ToolBox,
146 _tool_name: String,
147 payload: Value,
148 ) -> anyhow::Result<AgentMessage> {
149 let state_snapshot = self.state.read().unwrap().clone();
150 let ctx = ToolContext::with_mutable_state(Arc::new(state_snapshot), self.state.clone());
151
152 let result = tool.execute(payload, ctx).await?;
153 Ok(self.apply_tool_result(result))
154 }
155
156 fn apply_tool_result(&self, result: ToolResult) -> AgentMessage {
157 match result {
158 ToolResult::Message(message) => {
159 message
162 }
163 ToolResult::WithStateUpdate {
164 message,
165 state_diff,
166 } => {
167 if let Ok(mut state) = self.state.write() {
168 let command = agents_core::command::Command::with_state(state_diff);
169 command.apply_to(&mut state);
170 }
171 message
174 }
175 }
176 }
177
178 pub fn current_interrupt(&self) -> Option<AgentInterrupt> {
180 self.state
181 .read()
182 .ok()
183 .and_then(|guard| guard.pending_interrupts.first().cloned())
184 }
185
186 pub async fn resume_with_approval(&self, action: HitlAction) -> anyhow::Result<AgentMessage> {
188 let interrupt = {
190 let state_guard = self
191 .state
192 .read()
193 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on state"))?;
194 state_guard
195 .pending_interrupts
196 .first()
197 .cloned()
198 .ok_or_else(|| anyhow::anyhow!("No pending interrupts"))?
199 };
200
201 let result_message = match action {
202 HitlAction::Accept => {
203 let AgentInterrupt::HumanInLoop(hitl) = interrupt;
205 tracing::info!(
206 tool_name = %hitl.tool_name,
207 call_id = %hitl.call_id,
208 "β
HITL: Tool approved, executing with original arguments"
209 );
210
211 let tools = self.collect_tools();
212 let tool = tools
213 .get(&hitl.tool_name)
214 .cloned()
215 .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", hitl.tool_name))?;
216
217 self.execute_tool(tool, hitl.tool_name, hitl.tool_args)
218 .await?
219 }
220
221 HitlAction::Edit {
222 tool_name,
223 tool_args,
224 } => {
225 tracing::info!(
227 tool_name = %tool_name,
228 "βοΈ HITL: Tool edited, executing with modified arguments"
229 );
230
231 let tools = self.collect_tools();
232 let tool = tools
233 .get(&tool_name)
234 .cloned()
235 .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", tool_name))?;
236
237 self.execute_tool(tool, tool_name, tool_args).await?
238 }
239
240 HitlAction::Reject { reason } => {
241 tracing::info!("β HITL: Tool rejected");
243
244 let text = reason
245 .unwrap_or_else(|| "Tool execution rejected by human reviewer.".to_string());
246
247 let message = AgentMessage {
248 role: MessageRole::Tool,
249 content: MessageContent::Text(text),
250 metadata: None,
251 };
252
253 self.append_history(message.clone());
254 message
255 }
256
257 HitlAction::Respond { message } => {
258 tracing::info!("π¬ HITL: Custom response provided");
260
261 self.append_history(message.clone());
262 message
263 }
264 };
265
266 {
268 let mut state_guard = self
269 .state
270 .write()
271 .map_err(|_| anyhow::anyhow!("Failed to acquire write lock on state"))?;
272 state_guard.clear_interrupts();
273 }
274
275 if let Some(checkpointer) = &self.checkpointer {
277 let state_clone = self
278 .state
279 .read()
280 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on state"))?
281 .clone();
282 checkpointer
283 .save_state(&ThreadId::default(), &state_clone)
284 .await?;
285 }
286
287 Ok(result_message)
288 }
289
290 pub async fn handle_message(
292 &self,
293 input: impl AsRef<str>,
294 state: Arc<AgentStateSnapshot>,
295 ) -> anyhow::Result<AgentMessage> {
296 self.handle_message_with_metadata(input, None, state).await
297 }
298
299 pub async fn handle_message_with_metadata(
301 &self,
302 input: impl AsRef<str>,
303 metadata: Option<MessageMetadata>,
304 state: Arc<AgentStateSnapshot>,
305 ) -> anyhow::Result<AgentMessage> {
306 let agent_message = AgentMessage {
307 role: MessageRole::User,
308 content: MessageContent::Text(input.as_ref().to_string()),
309 metadata,
310 };
311 self.handle_message_internal(agent_message, state).await
312 }
313
314 async fn handle_message_internal(
316 &self,
317 input: AgentMessage,
318 _state: Arc<AgentStateSnapshot>,
319 ) -> anyhow::Result<AgentMessage> {
320 self.append_history(input.clone());
321
322 let max_iterations = 10;
324 let mut iteration = 0;
325
326 loop {
327 iteration += 1;
328 if iteration > max_iterations {
329 tracing::warn!(
330 "β οΈ Max iterations ({}) reached, stopping ReAct loop",
331 max_iterations
332 );
333 let response = AgentMessage {
334 role: MessageRole::Agent,
335 content: MessageContent::Text(
336 "I've reached the maximum number of steps. Let me summarize what I've done so far.".to_string()
337 ),
338 metadata: None,
339 };
340 self.append_history(response.clone());
341 return Ok(response);
342 }
343
344 tracing::debug!("π ReAct iteration {}/{}", iteration, max_iterations);
345
346 let mut request = ModelRequest::new(&self.instructions, self.current_history());
348 let tools = self.collect_tools();
349 for middleware in &self.middlewares {
350 let mut ctx = MiddlewareContext::with_request(&mut request, self.state.clone());
351 middleware.modify_model_request(&mut ctx).await?;
352 }
353
354 let tool_schemas: Vec<_> = tools.values().map(|t| t.schema()).collect();
355 let context = PlannerContext {
356 history: request.messages.clone(),
357 system_prompt: request.system_prompt.clone(),
358 tools: tool_schemas,
359 };
360 let state_snapshot = Arc::new(self.state.read().map(|s| s.clone()).unwrap_or_default());
361
362 let decision = self.planner.plan(context, state_snapshot).await?;
364
365 match decision.next_action {
366 PlannerAction::Respond { message } => {
367 self.append_history(message.clone());
369 return Ok(message);
370 }
371 PlannerAction::CallTool { tool_name, payload } => {
372 let tool_call_message = AgentMessage {
377 role: MessageRole::Agent,
378 content: MessageContent::Text(format!(
379 "Calling tool: {} with args: {}",
380 tool_name,
381 serde_json::to_string(&payload).unwrap_or_default()
382 )),
383 metadata: None,
384 };
385 self.append_history(tool_call_message);
386
387 if let Some(tool) = tools.get(&tool_name).cloned() {
388 let call_id = format!("call_{}", uuid::Uuid::new_v4());
390 for middleware in &self.middlewares {
391 if let Some(interrupt) = middleware
392 .before_tool_execution(&tool_name, &payload, &call_id)
393 .await?
394 {
395 {
397 let mut state_guard = self.state.write().map_err(|_| {
398 anyhow::anyhow!("Failed to acquire write lock on state")
399 })?;
400 state_guard.add_interrupt(interrupt.clone());
401 }
402
403 if let Some(checkpointer) = &self.checkpointer {
405 let state_clone = self
406 .state
407 .read()
408 .map_err(|_| {
409 anyhow::anyhow!("Failed to acquire read lock on state")
410 })?
411 .clone();
412 checkpointer
413 .save_state(&ThreadId::default(), &state_clone)
414 .await?;
415 }
416
417 let interrupt_message = AgentMessage {
419 role: MessageRole::System,
420 content: MessageContent::Text(format!(
421 "βΈοΈ Execution paused: Tool '{}' requires human approval",
422 tool_name
423 )),
424 metadata: None,
425 };
426 self.append_history(interrupt_message.clone());
427 return Ok(interrupt_message);
428 }
429 }
430
431 let start_time = std::time::Instant::now();
433 tracing::warn!(
434 "βοΈ EXECUTING TOOL: {} with payload: {}",
435 tool_name,
436 serde_json::to_string(&payload)
437 .unwrap_or_else(|_| "invalid json".to_string())
438 );
439
440 let result = self
441 .execute_tool(tool.clone(), tool_name.clone(), payload.clone())
442 .await;
443
444 let duration = start_time.elapsed();
445 match result {
446 Ok(tool_result_message) => {
447 let content_preview = match &tool_result_message.content {
448 MessageContent::Text(t) => {
449 if t.len() > 100 {
450 format!("{}... ({} chars)", &t[..100], t.len())
451 } else {
452 t.clone()
453 }
454 }
455 MessageContent::Json(v) => {
456 format!("JSON: {} bytes", v.to_string().len())
457 }
458 };
459 tracing::warn!(
460 "β
TOOL COMPLETED: {} in {:?} - Result: {}",
461 tool_name,
462 duration,
463 content_preview
464 );
465
466 self.append_history(tool_result_message);
468 }
470 Err(e) => {
471 tracing::error!(
472 "β TOOL FAILED: {} in {:?} - Error: {}",
473 tool_name,
474 duration,
475 e
476 );
477
478 let error_message = AgentMessage {
480 role: MessageRole::Tool,
481 content: MessageContent::Text(format!(
482 "Error executing {}: {}",
483 tool_name, e
484 )),
485 metadata: None,
486 };
487 self.append_history(error_message);
488 }
490 }
491 } else {
492 tracing::warn!("β οΈ Tool '{}' not found", tool_name);
494 let error_message = AgentMessage {
495 role: MessageRole::Tool,
496 content: MessageContent::Text(format!(
497 "Tool '{}' not found. Available tools: {}",
498 tool_name,
499 tools
500 .keys()
501 .map(|k| k.as_str())
502 .collect::<Vec<_>>()
503 .join(", ")
504 )),
505 metadata: None,
506 };
507 self.append_history(error_message);
508 }
510 }
511 PlannerAction::Terminate => {
512 tracing::debug!("π Agent terminated");
514 let message = AgentMessage {
515 role: MessageRole::Agent,
516 content: MessageContent::Text("Task completed.".into()),
517 metadata: None,
518 };
519 self.append_history(message.clone());
520 return Ok(message);
521 }
522 }
523 }
524 }
525}
526
527#[async_trait]
528impl AgentHandle for DeepAgent {
529 async fn describe(&self) -> AgentDescriptor {
530 self.descriptor.clone()
531 }
532
533 async fn handle_message(
534 &self,
535 input: AgentMessage,
536 _state: Arc<AgentStateSnapshot>,
537 ) -> anyhow::Result<AgentMessage> {
538 self.handle_message_internal(input, _state).await
539 }
540
541 async fn handle_message_stream(
542 &self,
543 input: AgentMessage,
544 _state: Arc<AgentStateSnapshot>,
545 ) -> anyhow::Result<agents_core::agent::AgentStream> {
546 use crate::planner::LlmBackedPlanner;
547 use agents_core::llm::{LlmRequest, StreamChunk};
548
549 self.append_history(input.clone());
551
552 let mut request = ModelRequest::new(&self.instructions, self.current_history());
554 let tools = self.collect_tools();
555
556 for middleware in &self.middlewares {
558 let mut ctx = MiddlewareContext::with_request(&mut request, self.state.clone());
559 middleware.modify_model_request(&mut ctx).await?;
560 }
561
562 let tool_schemas: Vec<_> = tools.values().map(|t| t.schema()).collect();
564 let llm_request = LlmRequest {
565 system_prompt: request.system_prompt.clone(),
566 messages: request.messages.clone(),
567 tools: tool_schemas,
568 };
569
570 let planner_any = self.planner.as_any();
572
573 if let Some(llm_planner) = planner_any.downcast_ref::<LlmBackedPlanner>() {
574 let model = llm_planner.model().clone();
576 let stream = model.generate_stream(llm_request).await?;
577 Ok(stream)
578 } else {
579 let response = self.handle_message_internal(input, _state).await?;
581 Ok(Box::pin(futures::stream::once(async move {
582 Ok(StreamChunk::Done { message: response })
583 })))
584 }
585 }
586
587 async fn current_interrupt(&self) -> anyhow::Result<Option<AgentInterrupt>> {
588 let state_guard = self
589 .state
590 .read()
591 .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on state"))?;
592 Ok(state_guard.pending_interrupts.first().cloned())
593 }
594
595 async fn resume_with_approval(
596 &self,
597 action: agents_core::hitl::HitlAction,
598 ) -> anyhow::Result<AgentMessage> {
599 self.resume_with_approval(action).await
600 }
601}
602
603pub fn create_deep_agent_from_config(config: DeepAgentConfig) -> DeepAgent {
608 let state = Arc::new(RwLock::new(AgentStateSnapshot::default()));
609 let history = Arc::new(RwLock::new(Vec::<AgentMessage>::new()));
610
611 let planning = Arc::new(PlanningMiddleware::new(state.clone()));
612 let filesystem = Arc::new(FilesystemMiddleware::new(state.clone()));
613
614 let mut registrations: Vec<SubAgentRegistration> = Vec::new();
616
617 for subagent_config in &config.subagent_configs {
619 let sub_planner = if let Some(ref model) = subagent_config.model {
621 Arc::new(LlmBackedPlanner::new(model.clone())) as Arc<dyn PlannerHandle>
623 } else {
624 config.planner.clone()
626 };
627
628 let mut sub_cfg = DeepAgentConfig::new(subagent_config.instructions.clone(), sub_planner);
630
631 if let Some(ref tools) = subagent_config.tools {
633 for tool in tools {
634 sub_cfg = sub_cfg.with_tool(tool.clone());
635 }
636 }
637
638 if let Some(ref builtin) = subagent_config.builtin_tools {
640 sub_cfg = sub_cfg.with_builtin_tools(builtin.iter().cloned());
641 }
642
643 sub_cfg = sub_cfg.with_auto_general_purpose(false);
645
646 sub_cfg = sub_cfg.with_prompt_caching(subagent_config.enable_prompt_caching);
648
649 let sub_agent = create_deep_agent_from_config(sub_cfg);
651
652 registrations.push(SubAgentRegistration {
654 descriptor: SubAgentDescriptor {
655 name: subagent_config.name.clone(),
656 description: subagent_config.description.clone(),
657 },
658 agent: Arc::new(sub_agent),
659 });
660 }
661
662 if config.auto_general_purpose {
664 let has_gp = registrations
665 .iter()
666 .any(|r| r.descriptor.name == "general-purpose");
667 if !has_gp {
668 let mut sub_cfg =
670 DeepAgentConfig::new(config.instructions.clone(), config.planner.clone())
671 .with_auto_general_purpose(false)
672 .with_prompt_caching(config.enable_prompt_caching);
673 if let Some(ref selected) = config.builtin_tools {
674 sub_cfg = sub_cfg.with_builtin_tools(selected.iter().cloned());
675 }
676 if let Some(ref sum) = config.summarization {
677 sub_cfg = sub_cfg.with_summarization(sum.clone());
678 }
679 for t in &config.tools {
680 sub_cfg = sub_cfg.with_tool(t.clone());
681 }
682
683 let gp = create_deep_agent_from_config(sub_cfg);
684 registrations.push(SubAgentRegistration {
685 descriptor: SubAgentDescriptor {
686 name: "general-purpose".into(),
687 description: "Default reasoning agent".into(),
688 },
689 agent: Arc::new(gp),
690 });
691 }
692 }
693
694 let subagent = Arc::new(SubAgentMiddleware::new(registrations));
695 let base_prompt = Arc::new(BaseSystemPromptMiddleware);
696 let deep_agent_prompt = Arc::new(DeepAgentPromptMiddleware::new(config.instructions.clone()));
697 let summarization = config.summarization.as_ref().map(|cfg| {
698 Arc::new(SummarizationMiddleware::new(
699 cfg.messages_to_keep,
700 cfg.summary_note.clone(),
701 ))
702 });
703 let hitl = if config.tool_interrupts.is_empty() {
704 None
705 } else {
706 if config.checkpointer.is_none() {
708 tracing::error!(
709 "β οΈ HITL middleware requires a checkpointer to persist interrupt state. \
710 HITL will be disabled. Please configure a checkpointer to enable HITL."
711 );
712 None
713 } else {
714 tracing::info!("π HITL enabled for {} tools", config.tool_interrupts.len());
715 Some(Arc::new(HumanInLoopMiddleware::new(
716 config.tool_interrupts.clone(),
717 )))
718 }
719 };
720
721 let mut middlewares: Vec<Arc<dyn AgentMiddleware>> = vec![
724 base_prompt,
725 deep_agent_prompt,
726 planning,
727 filesystem,
728 subagent,
729 ];
730 if let Some(ref summary) = summarization {
731 middlewares.push(summary.clone());
732 }
733 if config.enable_prompt_caching {
734 middlewares.push(Arc::new(AnthropicPromptCachingMiddleware::with_defaults()));
735 }
736 if let Some(ref hitl_mw) = hitl {
737 middlewares.push(hitl_mw.clone());
738 }
739
740 DeepAgent {
741 descriptor: AgentDescriptor {
742 name: "deep-agent".into(),
743 version: "0.0.1".into(),
744 description: Some("Rust deep agent".into()),
745 },
746 instructions: config.instructions,
747 planner: config.planner,
748 middlewares,
749 base_tools: config.tools,
750 state,
751 history,
752 _summarization: summarization,
753 _hitl: hitl,
754 builtin_tools: config.builtin_tools,
755 checkpointer: config.checkpointer,
756 }
757}