1use std::sync::Arc;
2use std::time::{SystemTime, UNIX_EPOCH};
3use tokio_util::sync::CancellationToken;
4
5use crate::{
6 event_stream::{EventStream, EventStreamConsumer, EventStreamProducer},
7 provider_types::{StreamContext, StreamFn, StreamOptions},
8 tool::AgentTool,
9 types::*,
10};
11
12pub struct AgentLoopConfig {
13 pub model: Model,
14 pub api_key: String,
15 pub system_prompt: String,
16 pub tools: Vec<Arc<dyn AgentTool>>,
17 pub thinking: ThinkingLevel,
18 pub max_tokens: Option<usize>,
19 pub stream_fn: StreamFn,
20 pub get_steering_messages: Option<Box<dyn Fn() -> Vec<Message> + Send + Sync>>,
21 pub get_follow_up_messages: Option<Box<dyn Fn() -> Vec<Message> + Send + Sync>>,
22 pub transform_messages:
23 Option<Box<dyn Fn(&[Message], &Model) -> (Vec<Message>, Option<crate::compaction::CompactionResult>) + Send + Sync>>,
24 pub post_tools_hooks: Vec<Arc<dyn crate::hooks::PostToolsHook>>,
25}
26
27pub fn agent_loop(
28 prompts: Vec<Message>,
29 config: AgentLoopConfig,
30 cancel: CancellationToken,
31) -> EventStreamConsumer<AgentEvent, Vec<Message>> {
32 let (producer, consumer) = EventStream::<AgentEvent, Vec<Message>>::new().split();
33
34 tokio::spawn(async move {
35 run_loop(prompts, config, cancel, producer).await;
36 });
37
38 consumer
39}
40
41async fn run_loop(
42 prompts: Vec<Message>,
43 config: AgentLoopConfig,
44 cancel: CancellationToken,
45 mut stream: EventStreamProducer<AgentEvent, Vec<Message>>,
46) {
47 let _ = stream.push(AgentEvent::AgentStart).await;
48 let mut messages = prompts;
49
50 'outer: loop {
51 if cancel.is_cancelled() {
52 break;
53 }
54
55 loop {
57 if cancel.is_cancelled() {
58 break 'outer;
59 }
60
61 let _ = stream.push(AgentEvent::TurnStart).await;
62
63 let tool_defs: Vec<ToolDef> = config
65 .tools
66 .iter()
67 .map(|t| ToolDef {
68 name: t.name().to_string(),
69 description: t.description(),
70 parameters: t.parameters_schema(),
71 })
72 .collect();
73
74 let effective_messages = if let Some(ref transform) = config.transform_messages {
76 let (transformed, compaction) = transform(&messages, &config.model);
77 if let Some(result) = compaction {
78 let _ = stream
79 .push(AgentEvent::ContextCompacted {
80 original_estimate: result.original_estimate,
81 compacted_estimate: result.compacted_estimate,
82 messages_pruned: result.messages_pruned,
83 })
84 .await;
85 }
86 transformed
87 } else {
88 messages.clone()
89 };
90
91 let assistant_msg =
92 stream_assistant_response(&config, &effective_messages, &tool_defs, &stream).await;
93
94 let _ = stream
95 .push(AgentEvent::MessageEnd {
96 message: assistant_msg.clone(),
97 })
98 .await;
99
100 let (tool_calls, stop_reason) = extract_tool_calls_and_stop(&assistant_msg);
102 messages.push(assistant_msg.clone());
103
104 if tool_calls.is_empty() {
105 let _ = stream
106 .push(AgentEvent::TurnEnd {
107 message: assistant_msg,
108 tool_results: vec![],
109 })
110 .await;
111
112 if matches!(stop_reason, StopReason::Error | StopReason::Aborted) {
113 break 'outer;
114 }
115 break; }
117
118 let mut tool_results = vec![];
120
121 for (i, (id, name, args)) in tool_calls.iter().enumerate() {
122 if cancel.is_cancelled() {
123 break 'outer;
124 }
125
126 if i > 0 {
128 if let Some(ref get_steering) = config.get_steering_messages {
129 let steering = get_steering();
130 if !steering.is_empty() {
131 for (skip_id, skip_name, _) in &tool_calls[i..] {
133 let skip_result = Message::ToolResult {
134 tool_call_id: skip_id.clone(),
135 tool_name: skip_name.clone(),
136 content: vec![Content::Text {
137 text: "Tool execution skipped due to steering".into(),
138 }],
139 is_error: true,
140 timestamp: now_ms(),
141 };
142 tool_results.push(skip_result.clone());
143 messages.push(skip_result);
144 }
145 messages.extend(steering);
146 break;
147 }
148 }
149 }
150
151 let _ = stream
153 .push(AgentEvent::ToolExecutionStart {
154 tool_call_id: id.clone(),
155 tool_name: name.clone(),
156 args: args.clone(),
157 })
158 .await;
159
160 let (result, is_error) =
161 execute_tool_call(&config.tools, id, name, args.clone(), cancel.clone()).await;
162
163 let tool_result_msg = Message::ToolResult {
164 tool_call_id: id.clone(),
165 tool_name: name.clone(),
166 content: result.content.clone(),
167 is_error,
168 timestamp: now_ms(),
169 };
170
171 tool_results.push(tool_result_msg.clone());
172 messages.push(tool_result_msg);
173
174 let _ = stream
175 .push(AgentEvent::ToolExecutionEnd {
176 tool_call_id: id.clone(),
177 tool_name: name.clone(),
178 result,
179 is_error,
180 })
181 .await;
182 }
183
184 if !tool_calls.is_empty() && !config.post_tools_hooks.is_empty() {
186 let tool_names: Vec<String> = tool_calls.iter().map(|(_, name, _)| name.clone()).collect();
187 for hook in &config.post_tools_hooks {
188 if cancel.is_cancelled() {
189 break 'outer;
190 }
191 let _ = stream.push(AgentEvent::PostToolsHookStart {
192 hook_name: hook.name().to_string(),
193 }).await;
194
195 let hook_result = tokio::time::timeout(
196 hook.timeout(),
197 hook.execute(&tool_names, cancel.clone()),
198 ).await;
199
200 let result = match hook_result {
201 Ok(r) => r,
202 Err(_) => crate::hooks::PostToolsHookResult {
203 steering_message: None,
204 success: false,
205 summary: format!("{}: timed out", hook.name()),
206 },
207 };
208
209 let _ = stream.push(AgentEvent::PostToolsHookEnd {
210 hook_name: hook.name().to_string(),
211 success: result.success,
212 summary: result.summary,
213 }).await;
214
215 if let Some(steering) = result.steering_message {
217 messages.push(Message::User {
218 content: UserContent::Text(steering),
219 timestamp: now_ms(),
220 });
221 }
222 }
223 }
224
225 let _ = stream
226 .push(AgentEvent::TurnEnd {
227 message: assistant_msg,
228 tool_results,
229 })
230 .await;
231
232 }
234
235 if let Some(ref get_follow_up) = config.get_follow_up_messages {
237 let follow_ups = get_follow_up();
238 if !follow_ups.is_empty() {
239 messages.extend(follow_ups);
240 continue 'outer;
241 }
242 }
243
244 break 'outer;
245 }
246
247 let _ = stream
248 .push(AgentEvent::AgentEnd {
249 messages: messages.clone(),
250 })
251 .await;
252 stream.end(Some(messages));
253}
254
255async fn stream_assistant_response(
256 config: &AgentLoopConfig,
257 messages: &[Message],
258 tool_defs: &[ToolDef],
259 stream: &EventStreamProducer<AgentEvent, Vec<Message>>,
260) -> Message {
261 let context = StreamContext {
262 messages: messages.to_vec(),
263 tools: tool_defs.to_vec(),
264 };
265 let options = StreamOptions {
266 api_key: config.api_key.clone(),
267 system_prompt: Some(config.system_prompt.clone()),
268 max_tokens: config.max_tokens,
269 thinking: config.thinking,
270 };
271
272 let mut assistant_stream = (config.stream_fn)(&config.model, context, options);
273
274 let mut content_blocks: Vec<Content> = vec![];
276 let model_str = config.model.id.clone();
277 let usage = Usage::default();
278 let mut stop_reason = StopReason::Stop;
279 let mut first_event = true;
280
281 while let Some(event) = assistant_stream.next().await {
282 match &event {
283 AssistantStreamEvent::TextStart { .. } => {
284 content_blocks.push(Content::Text {
285 text: String::new(),
286 });
287 }
288 AssistantStreamEvent::TextDelta { delta, .. } => {
289 if let Some(Content::Text { ref mut text }) = content_blocks.last_mut() {
290 text.push_str(delta);
291 }
292 }
293 AssistantStreamEvent::ThinkingStart { .. } => {
294 content_blocks.push(Content::Thinking {
295 thinking: String::new(),
296 });
297 }
298 AssistantStreamEvent::ThinkingDelta { delta, .. } => {
299 if let Some(Content::Thinking { ref mut thinking }) = content_blocks.last_mut() {
300 thinking.push_str(delta);
301 }
302 }
303 AssistantStreamEvent::ToolCallStart { .. } => {
304 content_blocks.push(Content::ToolCall {
305 id: String::new(),
306 name: String::new(),
307 arguments: serde_json::Value::Null,
308 });
309 }
310 AssistantStreamEvent::ToolCallEnd { tool_call, .. } => {
311 if let Some(last) = content_blocks.last_mut() {
312 *last = tool_call.clone();
313 }
314 }
315 AssistantStreamEvent::Done { stop_reason: sr } => {
316 stop_reason = sr.clone();
317 }
318 AssistantStreamEvent::Error { stop_reason: sr } => {
319 stop_reason = sr.clone();
320 }
321 _ => {}
322 }
323
324 let partial_msg = Message::Assistant {
325 content: content_blocks.clone(),
326 model: model_str.clone(),
327 usage: usage.clone(),
328 stop_reason: stop_reason.clone(),
329 timestamp: now_ms(),
330 };
331
332 if first_event {
333 let _ = stream
334 .push(AgentEvent::MessageStart {
335 message: partial_msg.clone(),
336 })
337 .await;
338 first_event = false;
339 }
340
341 let _ = stream
342 .push(AgentEvent::MessageUpdate {
343 message: partial_msg,
344 event,
345 })
346 .await;
347 }
348
349 if let Some(final_msg) = assistant_stream.result().await {
351 final_msg
352 } else {
353 Message::Assistant {
355 content: content_blocks,
356 model: model_str,
357 usage,
358 stop_reason,
359 timestamp: now_ms(),
360 }
361 }
362}
363
364fn extract_tool_calls_and_stop(
365 msg: &Message,
366) -> (Vec<(String, String, serde_json::Value)>, StopReason) {
367 match msg {
368 Message::Assistant {
369 content,
370 stop_reason,
371 ..
372 } => {
373 let calls: Vec<_> = content
374 .iter()
375 .filter_map(|c| {
376 if let Content::ToolCall {
377 id,
378 name,
379 arguments,
380 } = c
381 {
382 Some((id.clone(), name.clone(), arguments.clone()))
383 } else {
384 None
385 }
386 })
387 .collect();
388 (calls, stop_reason.clone())
389 }
390 _ => (vec![], StopReason::Stop),
391 }
392}
393
394async fn execute_tool_call(
395 tools: &[Arc<dyn AgentTool>],
396 id: &str,
397 name: &str,
398 args: serde_json::Value,
399 cancel: CancellationToken,
400) -> (ToolResult, bool) {
401 let tool = tools.iter().find(|t| t.name() == name);
402
403 match tool {
404 Some(tool) => match tool.execute(id, args, cancel).await {
405 Ok(result) => (result, false),
406 Err(e) => {
407 let error_result = ToolResult {
408 content: vec![Content::Text {
409 text: format!("Tool error: {}", e),
410 }],
411 details: serde_json::json!({}),
412 };
413 (error_result, true)
414 }
415 },
416 None => {
417 let error_result = ToolResult {
418 content: vec![Content::Text {
419 text: format!("Tool '{}' not found", name),
420 }],
421 details: serde_json::json!({}),
422 };
423 (error_result, true)
424 }
425 }
426}
427
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// This is called from inside the spawned agent task, where a panic would kill the
/// run mid-conversation. So instead of panicking:
/// - a system clock set before 1970 yields `0`;
/// - a millisecond count too large for `u64` saturates at `u64::MAX`.
fn now_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider_types::{AssistantStream, StreamFn};
    use crate::tool::ToolError;
    use async_trait::async_trait;
    use std::sync::atomic::Ordering;

    /// Minimal model descriptor shared by every test.
    fn test_model() -> Model {
        Model {
            id: "test-model".into(),
            name: "Test Model".into(),
            provider: "test".into(),
            base_url: "http://localhost".into(),
            reasoning: false,
            context_window: 8192,
            max_tokens: 4096,
        }
    }

    /// Builds a scripted `StreamFn`.
    ///
    /// The n-th invocation replays the n-th `(events, final_msg)` pair from `calls`.
    /// Also returns the shared invocation counter, so tests can assert how many model
    /// turns ran. Invoking it more times than `calls` has entries panics on the index.
    fn mock_stream_fn_multi(
        calls: Vec<(Vec<AssistantStreamEvent>, Message)>,
    ) -> (Arc<std::sync::atomic::AtomicUsize>, StreamFn) {
        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let calls = Arc::new(calls);

        let f: StreamFn = Arc::new(move |_model: &Model, _ctx: StreamContext, _opts: StreamOptions| {
            // fetch_add returns the previous value, i.e. this call's zero-based index.
            let n = call_count_clone.fetch_add(1, Ordering::SeqCst);
            let (events, final_msg) = (*calls)[n].clone();

            let (mut producer, consumer) = AssistantStream::new().split();

            // Replay the scripted events on a separate task, then set the final result.
            tokio::spawn(async move {
                for event in events {
                    let _ = producer.push(event).await;
                }
                producer.end(Some(final_msg));
            });

            consumer
        });

        (call_count, f)
    }

    /// Single-call convenience wrapper around `mock_stream_fn_multi`.
    fn mock_stream_fn(
        events: Vec<AssistantStreamEvent>,
        final_msg: Message,
    ) -> StreamFn {
        let (_, f) = mock_stream_fn_multi(vec![(events, final_msg)]);
        f
    }

    /// A text-only reply ends the run after one turn and emits the full event sequence.
    #[tokio::test]
    async fn test_text_only_response() {
        let final_msg = Message::Assistant {
            content: vec![Content::Text {
                text: "Hello world".into(),
            }],
            model: "test-model".into(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            timestamp: 100,
        };

        let events = vec![
            AssistantStreamEvent::Start,
            AssistantStreamEvent::TextStart { index: 0 },
            AssistantStreamEvent::TextDelta {
                index: 0,
                delta: "Hello ".into(),
            },
            AssistantStreamEvent::TextDelta {
                index: 0,
                delta: "world".into(),
            },
            AssistantStreamEvent::TextEnd {
                index: 0,
                content: "Hello world".into(),
            },
            AssistantStreamEvent::Done {
                stop_reason: StopReason::Stop,
            },
        ];

        let config = AgentLoopConfig {
            model: test_model(),
            api_key: "test-key".into(),
            system_prompt: "You are a test".into(),
            tools: vec![],
            thinking: ThinkingLevel::Off,
            max_tokens: None,
            stream_fn: mock_stream_fn(events, final_msg),
            get_steering_messages: None,
            get_follow_up_messages: None,
            transform_messages: None,
            post_tools_hooks: vec![],
        };

        let prompts = vec![Message::User {
            content: UserContent::Text("hi".into()),
            timestamp: 1,
        }];

        let cancel = CancellationToken::new();
        let mut consumer = agent_loop(prompts, config, cancel);

        // Drain the stream; it ends once the loop has finished.
        let mut agent_events = vec![];
        while let Some(event) = consumer.next().await {
            agent_events.push(event);
        }

        // The run always opens with AgentStart followed by the first TurnStart.
        assert!(matches!(agent_events[0], AgentEvent::AgentStart));
        assert!(matches!(agent_events[1], AgentEvent::TurnStart));

        let has_msg_start = agent_events
            .iter()
            .any(|e| matches!(e, AgentEvent::MessageStart { .. }));
        assert!(has_msg_start);

        let has_msg_end = agent_events
            .iter()
            .any(|e| matches!(e, AgentEvent::MessageEnd { .. }));
        assert!(has_msg_end);

        // No tool calls were requested, so the turn closes with no tool results.
        let has_turn_end = agent_events.iter().any(|e| {
            matches!(e, AgentEvent::TurnEnd { tool_results, .. } if tool_results.is_empty())
        });
        assert!(has_turn_end);

        assert!(matches!(
            agent_events.last().unwrap(),
            AgentEvent::AgentEnd { .. }
        ));
    }

    /// Tool that ignores its arguments and always returns a fixed successful result.
    struct MockTool {
        tool_name: String,
        result: ToolResult,
    }

    #[async_trait]
    impl AgentTool for MockTool {
        fn name(&self) -> &str {
            &self.tool_name
        }
        fn label(&self) -> String {
            self.tool_name.clone()
        }
        fn description(&self) -> String {
            "A mock tool".into()
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(
            &self,
            _tool_call_id: &str,
            _params: serde_json::Value,
            _cancel: CancellationToken,
        ) -> Result<ToolResult, ToolError> {
            Ok(self.result.clone())
        }
    }

    /// A tool call is executed, and the model is queried a second time with the result.
    #[tokio::test]
    async fn test_tool_call_response() {
        let tool_call_msg = Message::Assistant {
            content: vec![Content::ToolCall {
                id: "call_1".into(),
                name: "mock_tool".into(),
                arguments: serde_json::json!({"key": "value"}),
            }],
            model: "test-model".into(),
            usage: Usage::default(),
            stop_reason: StopReason::ToolUse,
            timestamp: 100,
        };

        let text_msg = Message::Assistant {
            content: vec![Content::Text {
                text: "Done".into(),
            }],
            model: "test-model".into(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            timestamp: 200,
        };

        // Turn 1 requests `mock_tool`; turn 2 answers with plain text and ends the run.
        let (call_count, stream_fn) = mock_stream_fn_multi(vec![
            (
                vec![
                    AssistantStreamEvent::Start,
                    AssistantStreamEvent::ToolCallStart { index: 0 },
                    AssistantStreamEvent::ToolCallEnd {
                        index: 0,
                        tool_call: Content::ToolCall {
                            id: "call_1".into(),
                            name: "mock_tool".into(),
                            arguments: serde_json::json!({"key": "value"}),
                        },
                    },
                    AssistantStreamEvent::Done {
                        stop_reason: StopReason::ToolUse,
                    },
                ],
                tool_call_msg,
            ),
            (
                vec![
                    AssistantStreamEvent::Start,
                    AssistantStreamEvent::TextStart { index: 0 },
                    AssistantStreamEvent::TextDelta {
                        index: 0,
                        delta: "Done".into(),
                    },
                    AssistantStreamEvent::TextEnd {
                        index: 0,
                        content: "Done".into(),
                    },
                    AssistantStreamEvent::Done {
                        stop_reason: StopReason::Stop,
                    },
                ],
                text_msg,
            ),
        ]);

        let mock_tool = Arc::new(MockTool {
            tool_name: "mock_tool".into(),
            result: ToolResult {
                content: vec![Content::Text {
                    text: "tool output".into(),
                }],
                details: serde_json::json!({}),
            },
        });

        let config = AgentLoopConfig {
            model: test_model(),
            api_key: "test-key".into(),
            system_prompt: "You are a test".into(),
            tools: vec![mock_tool],
            thinking: ThinkingLevel::Off,
            max_tokens: None,
            stream_fn,
            get_steering_messages: None,
            get_follow_up_messages: None,
            transform_messages: None,
            post_tools_hooks: vec![],
        };

        let prompts = vec![Message::User {
            content: UserContent::Text("use the tool".into()),
            timestamp: 1,
        }];

        let cancel = CancellationToken::new();
        let mut consumer = agent_loop(prompts, config, cancel);

        let mut agent_events = vec![];
        while let Some(event) = consumer.next().await {
            agent_events.push(event);
        }

        let has_tool_start = agent_events
            .iter()
            .any(|e| matches!(e, AgentEvent::ToolExecutionStart { tool_name, .. } if tool_name == "mock_tool"));
        assert!(has_tool_start);

        let has_tool_end = agent_events.iter().any(
            |e| matches!(e, AgentEvent::ToolExecutionEnd { tool_name, is_error, .. } if tool_name == "mock_tool" && !is_error),
        );
        assert!(has_tool_end);

        // One model call for the tool request plus one for the follow-up answer.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        assert!(matches!(
            agent_events.last().unwrap(),
            AgentEvent::AgentEnd { .. }
        ));
    }

    /// Calling an unregistered tool yields an error result instead of aborting the run.
    #[tokio::test]
    async fn test_tool_not_found() {
        let tool_call_msg = Message::Assistant {
            content: vec![Content::ToolCall {
                id: "call_1".into(),
                name: "nonexistent_tool".into(),
                arguments: serde_json::json!({}),
            }],
            model: "test-model".into(),
            usage: Usage::default(),
            stop_reason: StopReason::ToolUse,
            timestamp: 100,
        };

        let text_msg = Message::Assistant {
            content: vec![Content::Text {
                text: "Sorry".into(),
            }],
            model: "test-model".into(),
            usage: Usage::default(),
            stop_reason: StopReason::Stop,
            timestamp: 200,
        };

        let (_call_count, stream_fn) = mock_stream_fn_multi(vec![
            (
                vec![
                    AssistantStreamEvent::Start,
                    AssistantStreamEvent::ToolCallStart { index: 0 },
                    AssistantStreamEvent::ToolCallEnd {
                        index: 0,
                        tool_call: Content::ToolCall {
                            id: "call_1".into(),
                            name: "nonexistent_tool".into(),
                            arguments: serde_json::json!({}),
                        },
                    },
                    AssistantStreamEvent::Done {
                        stop_reason: StopReason::ToolUse,
                    },
                ],
                tool_call_msg,
            ),
            (
                vec![
                    AssistantStreamEvent::Start,
                    AssistantStreamEvent::TextStart { index: 0 },
                    AssistantStreamEvent::TextDelta {
                        index: 0,
                        delta: "Sorry".into(),
                    },
                    AssistantStreamEvent::TextEnd {
                        index: 0,
                        content: "Sorry".into(),
                    },
                    AssistantStreamEvent::Done {
                        stop_reason: StopReason::Stop,
                    },
                ],
                text_msg,
            ),
        ]);

        // No tools are registered, so the requested tool cannot be found.
        let config = AgentLoopConfig {
            model: test_model(),
            api_key: "test-key".into(),
            system_prompt: "You are a test".into(),
            tools: vec![],
            thinking: ThinkingLevel::Off,
            max_tokens: None,
            stream_fn,
            get_steering_messages: None,
            get_follow_up_messages: None,
            transform_messages: None,
            post_tools_hooks: vec![],
        };

        let prompts = vec![Message::User {
            content: UserContent::Text("use the tool".into()),
            timestamp: 1,
        }];

        let cancel = CancellationToken::new();
        let mut consumer = agent_loop(prompts, config, cancel);

        let mut agent_events = vec![];
        while let Some(event) = consumer.next().await {
            agent_events.push(event);
        }

        let has_error_tool = agent_events.iter().any(
            |e| matches!(e, AgentEvent::ToolExecutionEnd { is_error, tool_name, .. } if *is_error && tool_name == "nonexistent_tool"),
        );
        assert!(has_error_tool);
    }
}