1use crate::agent::Context;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::memory_policy::{MemoryAdapter, MemoryPolicy};
4use crate::agent::executor::tool_processor::ToolProcessor;
5use crate::agent::hooks::AgentHooks;
6use crate::agent::task::Task;
7use crate::channel::{Sender, channel};
8use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
9use crate::utils::{receiver_into_stream, spawn_future};
10use autoagents_llm::ToolCall;
11use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, StreamResponse};
12use autoagents_llm::error::LLMError;
13use autoagents_protocol::{Event, SubmissionId};
14use futures::{Stream, StreamExt};
15use serde_json::Value;
16use std::collections::HashSet;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc;
23
24#[cfg(target_arch = "wasm32")]
25use futures::channel::mpsc;
26
/// Whether the engine offers the agent's tools to the LLM and executes
/// tool calls found in responses.
#[derive(Debug, Clone, Copy)]
pub enum ToolMode {
    /// Tools are serialized into the request; returned tool calls are executed.
    Enabled,
    /// Tool calls in LLM responses are ignored and never executed.
    Disabled,
}
33
/// Which streaming pipeline [`TurnEngine::run_turn_stream`] uses.
#[derive(Debug, Clone, Copy)]
pub enum StreamMode {
    /// Structured delta stream (`chat_stream_struct`): text only, no tool execution.
    Structured,
    /// Chunk stream with tool support (`chat_stream_with_tools`).
    Tool,
}
40
/// Configuration for a [`TurnEngine`].
#[derive(Debug, Clone)]
pub struct TurnEngineConfig {
    /// Fallback turn budget, used when a caller passes `max_turns == 0`
    /// (see `normalize_max_turns`).
    pub max_turns: usize,
    /// Whether tools are offered to the model and executed.
    pub tool_mode: ToolMode,
    /// Streaming pipeline selection for `run_turn_stream`.
    pub stream_mode: StreamMode,
    /// How conversation history is stored and recalled.
    pub memory_policy: MemoryPolicy,
}
49
impl TurnEngineConfig {
    /// Preset for a plain chat agent: tools disabled, structured streaming,
    /// `MemoryPolicy::basic()`.
    pub fn basic(max_turns: usize) -> Self {
        Self {
            max_turns,
            tool_mode: ToolMode::Disabled,
            stream_mode: StreamMode::Structured,
            memory_policy: MemoryPolicy::basic(),
        }
    }

    /// Preset for a ReAct-style agent: tools enabled, tool-aware streaming,
    /// `MemoryPolicy::react()`.
    pub fn react(max_turns: usize) -> Self {
        Self {
            max_turns,
            tool_mode: ToolMode::Enabled,
            stream_mode: StreamMode::Tool,
            memory_policy: MemoryPolicy::react(),
        }
    }
}
69
/// Payload of a single completed (or continuing) turn.
#[derive(Debug, Clone)]
pub struct TurnEngineOutput {
    /// Assistant response text accumulated during the turn (may be empty).
    pub response: String,
    /// Results of tool calls executed during the turn; empty when no tools ran.
    pub tool_calls: Vec<ToolCallResult>,
}
76
/// Incremental items emitted on the stream returned by
/// [`TurnEngine::run_turn_stream`].
#[derive(Debug)]
pub enum TurnDelta {
    /// A fragment of assistant text.
    Text(String),
    /// Results of the tool calls executed at the end of the LLM stream.
    ToolResults(Vec<ToolCallResult>),
    /// Terminal item carrying the turn's final result; no further items follow it.
    Done(crate::agent::executor::TurnResult<TurnEngineOutput>),
}
84
/// Errors produced by [`TurnEngine`] operations.
#[derive(Error, Debug)]
pub enum TurnEngineError {
    /// The underlying LLM call or stream failed; payload is the provider's error text.
    #[error("LLM error: {0}")]
    LLMError(String),

    /// A hook aborted the run. (Not constructed in this module; presumably
    /// raised by callers/hook integration elsewhere.)
    #[error("Run aborted by hook")]
    Aborted,

    /// Catch-all for failures outside the LLM path.
    #[error("Other error: {0}")]
    Other(String),
}
96
/// Mutable per-run state threaded through successive turns of one task.
#[derive(Clone)]
pub struct TurnState {
    // Memory access wrapped with the engine's memory policy.
    memory: MemoryAdapter,
    // True once the user prompt has been persisted to memory, so it is
    // stored (and explicitly re-sent) at most once per run.
    stored_user: bool,
}
103
impl TurnState {
    /// Creates fresh per-run state bound to the context's memory and `policy`.
    pub fn new(context: &Context, policy: MemoryPolicy) -> Self {
        Self {
            memory: MemoryAdapter::new(context.memory(), policy),
            stored_user: false,
        }
    }

    /// The policy-wrapped memory adapter for this run.
    pub fn memory(&self) -> &MemoryAdapter {
        &self.memory
    }

    /// Whether the user prompt has already been persisted to memory.
    pub fn stored_user(&self) -> bool {
        self.stored_user
    }

    // One-way latch flipped after the user prompt is stored.
    fn mark_user_stored(&mut self) {
        self.stored_user = true;
    }
}
124
/// Drives individual agent turns — LLM call, optional tool execution, memory
/// updates, and event emission — according to a [`TurnEngineConfig`].
#[derive(Debug, Clone)]
pub struct TurnEngine {
    config: TurnEngineConfig,
}
130
131impl TurnEngine {
    /// Creates an engine with the given configuration.
    pub fn new(config: TurnEngineConfig) -> Self {
        Self { config }
    }
135
    /// Creates fresh per-run [`TurnState`] using this engine's memory policy.
    pub fn turn_state(&self, context: &Context) -> TurnState {
        TurnState::new(context, self.config.memory_policy.clone())
    }
139
    /// Executes one non-streaming turn.
    ///
    /// Flow: emit `TurnStarted` → `on_turn_start` hook → build the prompt
    /// (system prompt + recalled memory, plus the user prompt when it has not
    /// yet been stored) → persist the user prompt once → call the LLM.
    /// When tool mode is enabled and the response carries tool calls, the
    /// tools are executed through the hook-aware processor, the interaction is
    /// stored to memory, and `TurnResult::Continue` is returned so the caller
    /// runs another turn; otherwise the assistant text is stored and
    /// `TurnResult::Complete` is returned. `TurnCompleted` and
    /// `on_turn_complete` fire on both paths.
    ///
    /// A `max_turns` of 0 is replaced by the configured fallback (min 1).
    pub async fn run_turn<H: AgentHooks>(
        &self,
        hooks: &H,
        task: &Task,
        context: &Context,
        turn_state: &mut TurnState,
        turn_index: usize,
        max_turns: usize,
    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
        let tx_event = context.tx().ok();
        EventHelper::send_turn_started(
            &tx_event,
            task.submission_id,
            context.config().id,
            turn_index,
            max_turns,
        )
        .await;

        hooks.on_turn_start(turn_index, context).await;

        // Decide whether the user prompt goes into this request before we
        // mutate `stored_user` below.
        let include_user_prompt =
            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
        let messages = self
            .build_messages(context, task, turn_state.memory(), include_user_prompt)
            .await;

        if should_store_user(turn_state) {
            turn_state.memory.store_user(task).await;
            turn_state.mark_user_stored();
        }

        let tools = context.tools();
        let response = self.get_llm_response(context, &messages, tools).await?;
        let response_text = response.text().unwrap_or_default();

        // Tool calls are honored only when tool mode is enabled; otherwise any
        // tool calls the model emitted are discarded.
        let tool_calls = if matches!(self.config.tool_mode, ToolMode::Enabled) {
            response.tool_calls().unwrap_or_default()
        } else {
            Vec::new()
        };

        if !tool_calls.is_empty() {
            let tool_results = process_tool_calls_with_hooks(
                hooks,
                context,
                task.submission_id,
                tools,
                &tool_calls,
                &tx_event,
            )
            .await;

            turn_state
                .memory
                .store_tool_interaction(&tool_calls, &tool_results, &response_text)
                .await;
            record_tool_calls_state(context, &tool_results);

            EventHelper::send_turn_completed(
                &tx_event,
                task.submission_id,
                context.config().id,
                turn_index,
                false, // not the final turn: tool results feed another turn
            )
            .await;
            hooks.on_turn_complete(turn_index, context).await;

            return Ok(crate::agent::executor::TurnResult::Continue(Some(
                TurnEngineOutput {
                    response: response_text,
                    tool_calls: tool_results,
                },
            )));
        }

        if !response_text.is_empty() {
            turn_state.memory.store_assistant(&response_text).await;
        }

        EventHelper::send_turn_completed(
            &tx_event,
            task.submission_id,
            context.config().id,
            turn_index,
            true, // final turn: no tool calls to process
        )
        .await;
        hooks.on_turn_complete(turn_index, context).await;

        Ok(crate::agent::executor::TurnResult::Complete(
            TurnEngineOutput {
                response: response_text,
                tool_calls: Vec::new(),
            },
        ))
    }
239
    /// Executes one turn in streaming mode, returning a stream of
    /// [`TurnDelta`] items that ends with `TurnDelta::Done` on success or a
    /// single `Err` on failure.
    ///
    /// Prompt construction and user-prompt bookkeeping happen synchronously
    /// before spawning; the LLM streaming itself runs on a spawned future that
    /// forwards deltas through a bounded channel (capacity 100). The spawned
    /// task emits `TurnStarted`/`TurnCompleted` and invokes the
    /// `on_turn_start`/`on_turn_complete` hooks around the selected pipeline
    /// ([`StreamMode::Structured`] or [`StreamMode::Tool`]).
    pub async fn run_turn_stream<H>(
        &self,
        hooks: H,
        task: &Task,
        context: Arc<Context>,
        turn_state: &mut TurnState,
        turn_index: usize,
        max_turns: usize,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TurnDelta, TurnEngineError>> + Send>>,
        TurnEngineError,
    >
    where
        H: AgentHooks + Clone + Send + Sync + 'static,
    {
        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
        let include_user_prompt =
            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
        let messages = self
            .build_messages(&context, task, turn_state.memory(), include_user_prompt)
            .await;

        // Persist the user prompt before spawning so `turn_state` does not
        // have to cross into the 'static task.
        if should_store_user(turn_state) {
            turn_state.memory.store_user(task).await;
            turn_state.mark_user_stored();
        }

        let (mut tx, rx) = channel::<Result<TurnDelta, TurnEngineError>>(100);
        // Clone everything the background task needs; it must be 'static.
        let engine = self.clone();
        let context_clone = context.clone();
        let task = task.clone();
        let hooks = hooks.clone();
        let memory = turn_state.memory.clone();
        let messages = messages.clone();

        spawn_future(async move {
            let tx_event = context_clone.tx().ok();
            EventHelper::send_turn_started(
                &tx_event,
                task.submission_id,
                context_clone.config().id,
                turn_index,
                max_turns,
            )
            .await;
            hooks.on_turn_start(turn_index, &context_clone).await;

            let result = match engine.config.stream_mode {
                StreamMode::Structured => {
                    engine
                        .stream_structured(&context_clone, &task, &memory, &mut tx, &messages)
                        .await
                }
                StreamMode::Tool => {
                    engine
                        .stream_with_tools(
                            &hooks,
                            &context_clone,
                            &task,
                            context_clone.tools(),
                            &memory,
                            &mut tx,
                            &messages,
                        )
                        .await
                }
            };

            match result {
                Ok(turn_result) => {
                    // `final_turn` distinguishes Complete from Continue for the event.
                    let final_turn =
                        matches!(turn_result, crate::agent::executor::TurnResult::Complete(_));
                    EventHelper::send_turn_completed(
                        &tx_event,
                        task.submission_id,
                        context_clone.config().id,
                        turn_index,
                        final_turn,
                    )
                    .await;
                    hooks.on_turn_complete(turn_index, &context_clone).await;
                    let _ = tx.send(Ok(TurnDelta::Done(turn_result))).await;
                }
                Err(err) => {
                    // Surface the failure to the consumer; no TurnCompleted is emitted.
                    let _ = tx.send(Err(err)).await;
                }
            }
        });

        Ok(receiver_into_stream(rx))
    }
331
    /// Streams a structured (text-only) LLM response.
    ///
    /// Each non-empty text delta is appended to the accumulated response,
    /// forwarded to `tx` as [`TurnDelta::Text`], and mirrored to the context
    /// event channel as a [`StreamChunk::Text`]. The full text is stored as an
    /// assistant message when the stream ends, and the turn always completes
    /// (this pipeline never executes tools).
    async fn stream_structured(
        &self,
        context: &Context,
        task: &Task,
        memory: &MemoryAdapter,
        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
        messages: &[ChatMessage],
    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
        let mut stream = self.get_structured_stream(context, messages).await?;
        let mut response_text = String::default();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
            // Only the first choice's delta is consumed; missing content maps to "".
            let content = chunk
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_ref())
                .map_or("", |value| value)
                .to_string();

            if content.is_empty() {
                continue;
            }

            response_text.push_str(&content);

            let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;

            let tx_event = context.tx().ok();
            EventHelper::send_stream_chunk(
                &tx_event,
                task.submission_id,
                StreamChunk::Text(content),
            )
            .await;
        }

        if !response_text.is_empty() {
            memory.store_assistant(&response_text).await;
        }

        Ok(crate::agent::executor::TurnResult::Complete(
            TurnEngineOutput {
                response: response_text,
                tool_calls: Vec::default(),
            },
        ))
    }
380
381 #[allow(clippy::too_many_arguments)]
382 async fn stream_with_tools<H: AgentHooks>(
383 &self,
384 hooks: &H,
385 context: &Context,
386 task: &Task,
387 tools: &[Box<dyn ToolT>],
388 memory: &MemoryAdapter,
389 tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
390 messages: &[ChatMessage],
391 ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
392 let mut stream = self.get_tool_stream(context, messages, tools).await?;
393 let mut response_text = String::default();
394 let mut tool_calls = Vec::default();
395 let mut tool_call_ids: HashSet<String> = HashSet::default();
396
397 while let Some(chunk_result) = stream.next().await {
398 let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
399 let chunk_clone = chunk.clone();
400
401 match chunk {
402 StreamChunk::Text(content) => {
403 response_text.push_str(&content);
404 let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
405 }
406 StreamChunk::ToolUseComplete { tool_call, .. } => {
407 if tool_call_ids.insert(tool_call.id.clone()) {
408 tool_calls.push(tool_call.clone());
409 let tx_event = context.tx().ok();
410 EventHelper::send_stream_tool_call(
411 &tx_event,
412 task.submission_id,
413 serde_json::to_value(tool_call).unwrap_or(Value::Null),
414 )
415 .await;
416 }
417 }
418 StreamChunk::Usage(_) => {}
419 _ => {}
420 }
421
422 let tx_event = context.tx().ok();
423 EventHelper::send_stream_chunk(&tx_event, task.submission_id, chunk_clone).await;
424 }
425
426 if tool_calls.is_empty() {
427 if !response_text.is_empty() {
428 memory.store_assistant(&response_text).await;
429 }
430 return Ok(crate::agent::executor::TurnResult::Complete(
431 TurnEngineOutput {
432 response: response_text,
433 tool_calls: Vec::new(),
434 },
435 ));
436 }
437
438 let tx_event = context.tx().ok();
439 let tool_results = process_tool_calls_with_hooks(
440 hooks,
441 context,
442 task.submission_id,
443 tools,
444 &tool_calls,
445 &tx_event,
446 )
447 .await;
448
449 memory
450 .store_tool_interaction(&tool_calls, &tool_results, &response_text)
451 .await;
452 record_tool_calls_state(context, &tool_results);
453
454 let _ = tx
455 .send(Ok(TurnDelta::ToolResults(tool_results.clone())))
456 .await;
457
458 Ok(crate::agent::executor::TurnResult::Continue(Some(
459 TurnEngineOutput {
460 response: response_text,
461 tool_calls: tool_results,
462 },
463 )))
464 }
465
466 async fn get_llm_response(
467 &self,
468 context: &Context,
469 messages: &[ChatMessage],
470 tools: &[Box<dyn ToolT>],
471 ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, TurnEngineError> {
472 let llm = context.llm();
473 let output_schema = context.config().output_schema.clone();
474
475 if matches!(self.config.tool_mode, ToolMode::Enabled) && !tools.is_empty() {
476 let cached = context.serialized_tools();
477 let tools_serialized = if let Some(cached) = cached {
478 cached
479 } else {
480 Arc::new(tools.iter().map(to_llm_tool).collect::<Vec<_>>())
481 };
482 llm.chat_with_tools(messages, Some(&tools_serialized), output_schema)
483 .await
484 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
485 } else {
486 llm.chat(messages, output_schema)
487 .await
488 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
489 }
490 }
491
492 async fn get_structured_stream(
493 &self,
494 context: &Context,
495 messages: &[ChatMessage],
496 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, TurnEngineError>
497 {
498 context
499 .llm()
500 .chat_stream_struct(messages, None, context.config().output_schema.clone())
501 .await
502 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
503 }
504
505 async fn get_tool_stream(
506 &self,
507 context: &Context,
508 messages: &[ChatMessage],
509 tools: &[Box<dyn ToolT>],
510 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, TurnEngineError>
511 {
512 let cached = context.serialized_tools();
513 let tools_serialized = if let Some(cached) = cached {
514 cached
515 } else {
516 Arc::new(tools.iter().map(to_llm_tool).collect::<Vec<_>>())
517 };
518 context
519 .llm()
520 .chat_stream_with_tools(
521 messages,
522 if tools_serialized.is_empty() {
523 None
524 } else {
525 Some(&tools_serialized)
526 },
527 context.config().output_schema.clone(),
528 )
529 .await
530 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
531 }
532
533 async fn build_messages(
534 &self,
535 context: &Context,
536 task: &Task,
537 memory: &MemoryAdapter,
538 include_user_prompt: bool,
539 ) -> Vec<ChatMessage> {
540 let system_prompt = task
541 .system_prompt
542 .as_deref()
543 .unwrap_or_else(|| &context.config().description);
544 let mut messages = vec![ChatMessage {
545 role: ChatRole::System,
546 message_type: MessageType::Text,
547 content: system_prompt.to_string(),
548 }];
549
550 let recalled = memory.recall_messages(task).await;
551 messages.extend(recalled);
552
553 if include_user_prompt {
554 messages.push(user_message(task));
555 }
556
557 messages
558 }
559}
560
/// Best-effort: records `task` in the context's shared state.
///
/// Uses `try_lock`, so the record is silently skipped if the state lock is
/// currently held. The two cfg branches exist because the native mutex's
/// `try_lock` returns `Result` while the wasm one returns `Option`.
pub fn record_task_state(context: &Context, task: &Task) {
    let state = context.state();
    #[cfg(not(target_arch = "wasm32"))]
    if let Ok(mut guard) = state.try_lock() {
        guard.record_task(task.clone());
    };
    #[cfg(target_arch = "wasm32")]
    if let Some(mut guard) = state.try_lock() {
        guard.record_task(task.clone());
    };
}
572
573fn user_message(task: &Task) -> ChatMessage {
574 if let Some((mime, image_data)) = &task.image {
575 ChatMessage {
576 role: ChatRole::User,
577 message_type: MessageType::Image(((*mime).into(), image_data.clone())),
578 content: task.prompt.clone(),
579 }
580 } else {
581 ChatMessage {
582 role: ChatRole::User,
583 message_type: MessageType::Text,
584 content: task.prompt.clone(),
585 }
586 }
587}
588
589fn should_include_user_prompt(memory: &MemoryAdapter, stored_user: bool) -> bool {
590 if !memory.is_enabled() {
591 return true;
592 }
593 if !memory.policy().recall {
594 return true;
595 }
596 if !memory.policy().store_user {
597 return true;
598 }
599 !stored_user
600}
601
602fn should_store_user(turn_state: &TurnState) -> bool {
603 if !turn_state.memory.is_enabled() {
604 return false;
605 }
606 if !turn_state.memory.policy().store_user {
607 return false;
608 }
609 !turn_state.stored_user
610}
611
/// Resolves the effective turn budget: a zero request falls back to the
/// configured default, clamped to at least one turn.
fn normalize_max_turns(max_turns: usize, fallback: usize) -> usize {
    match max_turns {
        0 => fallback.max(1),
        requested => requested,
    }
}
618
/// Best-effort: records executed tool calls in the context's shared state.
///
/// No-op for an empty slice. Uses `try_lock`, so the records are silently
/// skipped if the state lock is currently held (native `try_lock` returns
/// `Result`, wasm returns `Option`, hence the two cfg branches).
fn record_tool_calls_state(context: &Context, tool_results: &[ToolCallResult]) {
    if tool_results.is_empty() {
        return;
    }
    let state = context.state();
    #[cfg(not(target_arch = "wasm32"))]
    if let Ok(mut guard) = state.try_lock() {
        for result in tool_results {
            guard.record_tool_call(result.clone());
        }
    };
    #[cfg(target_arch = "wasm32")]
    if let Some(mut guard) = state.try_lock() {
        for result in tool_results {
            guard.record_tool_call(result.clone());
        }
    };
}
637
/// Executes each requested tool call sequentially through the hook-aware
/// processor and collects the produced results.
///
/// Calls for which the processor returns `None` (presumably filtered or
/// suppressed by hooks — the processor decides) are omitted from the output,
/// so `results.len()` may be smaller than `tool_calls.len()`.
async fn process_tool_calls_with_hooks<H: AgentHooks>(
    hooks: &H,
    context: &Context,
    submission_id: SubmissionId,
    tools: &[Box<dyn ToolT>],
    tool_calls: &[ToolCall],
    tx_event: &Option<mpsc::Sender<Event>>,
) -> Vec<ToolCallResult> {
    let mut results = Vec::new();
    for call in tool_calls {
        if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
            hooks,
            context,
            submission_id,
            tools,
            call,
            tx_event,
        )
        .await
        {
            results.push(result);
        }
    }
    results
}
663
#[cfg(test)]
mod tests {
    //! Unit tests for the turn engine: config presets, prompt/memory helpers,
    //! message assembly, and full turn execution in both plain and streaming
    //! modes against mock LLM providers.
    use super::*;
    use crate::agent::task::Task;
    use crate::agent::{AgentConfig, Context};
    use crate::tests::{ConfigurableLLMProvider, StaticChatResponse};
    use async_trait::async_trait;
    use autoagents_llm::ToolCall;
    use autoagents_llm::chat::{StreamChoice, StreamChunk, StreamDelta, StreamResponse};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;

    /// Minimal in-test tool that always succeeds with a fixed JSON value.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: serde_json::Value,
    }

    impl LocalTool {
        fn new(name: &str, output: serde_json::Value) -> Self {
            Self {
                name: name.to_string(),
                output,
            }
        }
    }

    impl crate::tool::ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "local tool"
        }

        fn args_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(
            &self,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }

    // --- config presets ---

    #[test]
    fn test_turn_engine_config_basic() {
        let config = TurnEngineConfig::basic(5);
        assert_eq!(config.max_turns, 5);
        assert!(matches!(config.tool_mode, ToolMode::Disabled));
        assert!(matches!(config.stream_mode, StreamMode::Structured));
        assert!(config.memory_policy.recall);
    }

    #[test]
    fn test_turn_engine_config_react() {
        let config = TurnEngineConfig::react(10);
        assert_eq!(config.max_turns, 10);
        assert!(matches!(config.tool_mode, ToolMode::Enabled));
        assert!(matches!(config.stream_mode, StreamMode::Tool));
        assert!(config.memory_policy.recall);
    }

    // --- normalize_max_turns ---

    #[test]
    fn test_normalize_max_turns_nonzero() {
        assert_eq!(normalize_max_turns(5, 10), 5);
    }

    #[test]
    fn test_normalize_max_turns_zero_uses_fallback() {
        assert_eq!(normalize_max_turns(0, 10), 10);
    }

    #[test]
    fn test_normalize_max_turns_zero_fallback_zero() {
        // Both zero still yields a budget of at least one turn.
        assert_eq!(normalize_max_turns(0, 0), 1);
    }

    // --- should_include_user_prompt ---

    #[test]
    fn test_should_include_user_prompt_no_memory() {
        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
        assert!(should_include_user_prompt(&adapter, false));
    }

    #[test]
    fn test_should_include_user_prompt_recall_disabled() {
        let mut policy = MemoryPolicy::basic();
        policy.recall = false;
        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
        let adapter = MemoryAdapter::new(
            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
            policy,
        );
        assert!(should_include_user_prompt(&adapter, false));
    }

    #[test]
    fn test_should_include_user_prompt_store_user_disabled() {
        let mut policy = MemoryPolicy::basic();
        policy.store_user = false;
        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
        let adapter = MemoryAdapter::new(
            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
            policy,
        );
        assert!(should_include_user_prompt(&adapter, false));
    }

    #[test]
    fn test_should_include_user_prompt_already_stored() {
        // Full memory (recall + store_user) plus an already-stored prompt is
        // the only combination where the prompt can be elided.
        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
        let adapter = MemoryAdapter::new(
            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
            MemoryPolicy::basic(),
        );
        assert!(!should_include_user_prompt(&adapter, true));
    }

    // --- should_store_user ---

    #[test]
    fn test_should_store_user_no_memory() {
        let state = TurnState {
            memory: MemoryAdapter::new(None, MemoryPolicy::basic()),
            stored_user: false,
        };
        assert!(!should_store_user(&state));
    }

    #[test]
    fn test_should_store_user_already_stored() {
        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
        let state = TurnState {
            memory: MemoryAdapter::new(
                Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
                MemoryPolicy::basic(),
            ),
            stored_user: true,
        };
        assert!(!should_store_user(&state));
    }

    // --- user_message ---

    #[test]
    fn test_user_message_text() {
        let task = Task::new("hello");
        let msg = user_message(&task);
        assert!(matches!(msg.role, ChatRole::User));
        assert!(matches!(msg.message_type, MessageType::Text));
        assert_eq!(msg.content, "hello");
    }

    #[test]
    fn test_user_message_image() {
        let mut task = Task::new("describe");
        task.image = Some((autoagents_protocol::ImageMime::PNG, vec![1, 2, 3]));
        let msg = user_message(&task);
        assert!(matches!(msg.role, ChatRole::User));
        assert!(matches!(msg.message_type, MessageType::Image(_)));
    }

    // --- TurnState ---

    #[test]
    fn test_turn_state_new_and_mark_user_stored() {
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test".to_string(),
            description: "test".to_string(),
            output_schema: None,
        };
        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
        let context = Context::new(llm, None).with_config(config);

        let mut state = TurnState::new(&context, MemoryPolicy::basic());
        assert!(!state.stored_user());
        state.mark_user_stored();
        assert!(state.stored_user());
    }

    // --- build_messages ---

    #[tokio::test]
    async fn test_build_messages_with_system_prompt() {
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test".to_string(),
            description: "default desc".to_string(),
            output_schema: None,
        };
        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
        let context = Context::new(llm, None).with_config(config);

        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
        let mut task = Task::new("user input");
        // Explicit task system prompt should override the agent description.
        task.system_prompt = Some("custom system".to_string());

        let messages = engine.build_messages(&context, &task, &adapter, true).await;
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "custom system");
        assert_eq!(messages[0].role, ChatRole::System);
        assert_eq!(messages[1].content, "user input");
    }

    #[tokio::test]
    async fn test_build_messages_without_user_prompt() {
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
        let context = Context::new(llm, None).with_config(config);

        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
        let task = Task::new("user input");

        // With include_user_prompt = false only the system message remains.
        let messages = engine
            .build_messages(&context, &task, &adapter, false)
            .await;
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, ChatRole::System);
    }

    // --- run_turn (non-streaming) ---

    #[tokio::test]
    async fn test_run_turn_no_tools_single_turn() {
        use crate::tests::MockAgentImpl;
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test".to_string(),
            description: "test desc".to_string(),
            output_schema: None,
        };
        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
        let context = Context::new(llm, None).with_config(config);

        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("test prompt");
        let hooks = MockAgentImpl::new("test", "test");

        let result = engine
            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
            .await;
        assert!(result.is_ok());
        let turn_result = result.unwrap();
        assert!(matches!(
            turn_result,
            crate::agent::executor::TurnResult::Complete(_)
        ));
        if let crate::agent::executor::TurnResult::Complete(output) = turn_result {
            assert_eq!(output.response, "Mock response");
        }
    }

    #[tokio::test]
    async fn test_run_turn_with_tool_calls_continues() {
        use crate::tests::MockAgentImpl;
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: autoagents_llm::FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        };

        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("Use tool".to_string()),
                tool_calls: Some(vec![tool_call.clone()]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "tool_agent".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Context::new(llm, None)
            .with_config(config)
            .with_tools(vec![Box::new(tool)]);

        let engine = TurnEngine::new(TurnEngineConfig {
            max_turns: 2,
            tool_mode: ToolMode::Enabled,
            stream_mode: StreamMode::Structured,
            memory_policy: MemoryPolicy::basic(),
        });
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("prompt");
        let hooks = MockAgentImpl::new("test", "test");

        let result = engine
            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 2)
            .await
            .unwrap();

        // Tool calls in the response must yield Continue with the results.
        match result {
            crate::agent::executor::TurnResult::Continue(Some(output)) => {
                assert_eq!(output.response, "Use tool");
                assert_eq!(output.tool_calls.len(), 1);
                assert!(output.tool_calls[0].success);
            }
            _ => panic!("expected Continue(Some)"),
        }

        // Executed tool calls should also be recorded in shared state.
        #[cfg(not(target_arch = "wasm32"))]
        if let Ok(state) = context.state().try_lock() {
            assert_eq!(state.tool_calls.len(), 1);
        }
    }

    #[tokio::test]
    async fn test_run_turn_tool_mode_disabled_ignores_tool_calls() {
        use crate::tests::MockAgentImpl;
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: autoagents_llm::FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        };

        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("No tools".to_string()),
                tool_calls: Some(vec![tool_call]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "tool_agent".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let context = Context::new(llm, None).with_config(config);

        let engine = TurnEngine::new(TurnEngineConfig {
            max_turns: 1,
            tool_mode: ToolMode::Disabled,
            stream_mode: StreamMode::Structured,
            memory_policy: MemoryPolicy::basic(),
        });
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("prompt");
        let hooks = MockAgentImpl::new("test", "test");

        let result = engine
            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
            .await
            .unwrap();

        // With ToolMode::Disabled the model's tool calls are dropped and the
        // turn completes with text only.
        match result {
            crate::agent::executor::TurnResult::Complete(output) => {
                assert_eq!(output.response, "No tools");
                assert!(output.tool_calls.is_empty());
            }
            _ => panic!("expected Complete"),
        }
    }

    // --- run_turn_stream ---

    #[tokio::test]
    async fn test_run_turn_stream_structured_aggregates_text() {
        use crate::tests::MockAgentImpl;
        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![
                StreamResponse {
                    choices: vec![StreamChoice {
                        delta: StreamDelta {
                            content: Some("Hello ".to_string()),
                            tool_calls: None,
                        },
                    }],
                    usage: None,
                },
                StreamResponse {
                    choices: vec![StreamChoice {
                        delta: StreamDelta {
                            content: Some("world".to_string()),
                            tool_calls: None,
                        },
                    }],
                    usage: None,
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "stream_agent".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let engine = TurnEngine::new(TurnEngineConfig {
            max_turns: 1,
            tool_mode: ToolMode::Disabled,
            stream_mode: StreamMode::Structured,
            memory_policy: MemoryPolicy::basic(),
        });
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("prompt");
        let hooks = MockAgentImpl::new("test", "test");

        let mut stream = engine
            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
            .await
            .unwrap();

        // Drain the stream until the terminal Done item and extract the
        // aggregated response text.
        let mut final_text = String::default();
        while let Some(delta) = stream.next().await {
            if let Ok(TurnDelta::Done(result)) = delta {
                final_text = match result {
                    crate::agent::executor::TurnResult::Complete(output) => output.response,
                    crate::agent::executor::TurnResult::Continue(Some(output)) => output.response,
                    crate::agent::executor::TurnResult::Continue(None) => String::default(),
                };
                break;
            }
        }

        assert_eq!(final_text, "Hello world");
    }

    #[tokio::test]
    async fn test_run_turn_stream_with_tools_executes_tools() {
        use crate::tests::MockAgentImpl;
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: autoagents_llm::FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        };

        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::Text("thinking".to_string()),
                StreamChunk::ToolUseComplete {
                    index: 0,
                    tool_call: tool_call.clone(),
                },
                StreamChunk::Done {
                    stop_reason: "tool_use".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "tool_stream_agent".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(config)
                .with_tools(vec![Box::new(tool)]),
        );
        let engine = TurnEngine::new(TurnEngineConfig {
            max_turns: 1,
            tool_mode: ToolMode::Enabled,
            stream_mode: StreamMode::Tool,
            memory_policy: MemoryPolicy::basic(),
        });
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("prompt");
        let hooks = MockAgentImpl::new("test", "test");

        let mut stream = engine
            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
            .await
            .unwrap();

        let mut final_result = None;
        while let Some(delta) = stream.next().await {
            if let Ok(TurnDelta::Done(result)) = delta {
                final_result = Some(result);
                break;
            }
        }

        // A streamed tool call must be executed and yield Continue(Some(..)).
        match final_result.expect("done") {
            crate::agent::executor::TurnResult::Continue(Some(output)) => {
                assert_eq!(output.tool_calls.len(), 1);
                assert!(output.tool_calls[0].success);
            }
            _ => panic!("expected Continue(Some)"),
        }
    }
}