1use crate::agent::Context;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::memory_policy::{MemoryAdapter, MemoryPolicy};
4use crate::agent::executor::tool_processor::ToolProcessor;
5use crate::agent::hooks::AgentHooks;
6use crate::agent::task::Task;
7use crate::channel::{Sender, channel};
8use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
9use crate::utils::{receiver_into_stream, spawn_future};
10use autoagents_llm::ToolCall;
11use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, StreamResponse, Tool};
12use autoagents_llm::error::LLMError;
13use autoagents_protocol::{Event, SubmissionId};
14use futures::{Stream, StreamExt};
15use serde_json::Value;
16use std::collections::HashSet;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc;
23
24#[cfg(target_arch = "wasm32")]
25use futures::channel::mpsc;
26
/// Whether a [`TurnEngine`] lets the LLM call tools.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare modes with `==`
/// or key collections by mode instead of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolMode {
    /// Tool calling is turned on.
    Enabled,
    /// Tool calling is turned off; non-streaming turns ignore any tool calls
    /// contained in the LLM response.
    Disabled,
}
33
/// Streaming strategy used by `TurnEngine::run_turn_stream`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so modes can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamMode {
    /// Streams text through the LLM's structured-stream API; no tool calls
    /// are processed.
    Structured,
    /// Streams text and tool-call chunks, executing completed tool calls at
    /// the end of the stream.
    Tool,
}
40
41#[derive(Debug, Clone)]
43pub struct TurnEngineConfig {
44 pub max_turns: usize,
45 pub tool_mode: ToolMode,
46 pub stream_mode: StreamMode,
47 pub memory_policy: MemoryPolicy,
48}
49
50impl TurnEngineConfig {
51 pub fn basic(max_turns: usize) -> Self {
52 Self {
53 max_turns,
54 tool_mode: ToolMode::Disabled,
55 stream_mode: StreamMode::Structured,
56 memory_policy: MemoryPolicy::basic(),
57 }
58 }
59
60 pub fn react(max_turns: usize) -> Self {
61 Self {
62 max_turns,
63 tool_mode: ToolMode::Enabled,
64 stream_mode: StreamMode::Tool,
65 memory_policy: MemoryPolicy::react(),
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct TurnEngineOutput {
73 pub response: String,
74 pub tool_calls: Vec<ToolCallResult>,
75}
76
77#[derive(Debug)]
79pub enum TurnDelta {
80 Text(String),
81 ToolResults(Vec<ToolCallResult>),
82 Done(crate::agent::executor::TurnResult<TurnEngineOutput>),
83}
84
85#[derive(Error, Debug)]
86pub enum TurnEngineError {
87 #[error("LLM error: {0}")]
88 LLMError(String),
89
90 #[error("Run aborted by hook")]
91 Aborted,
92
93 #[error("Other error: {0}")]
94 Other(String),
95}
96
97#[derive(Clone)]
99pub struct TurnState {
100 memory: MemoryAdapter,
101 stored_user: bool,
102}
103
104impl TurnState {
105 pub fn new(context: &Context, policy: MemoryPolicy) -> Self {
106 Self {
107 memory: MemoryAdapter::new(context.memory(), policy),
108 stored_user: false,
109 }
110 }
111
112 pub fn memory(&self) -> &MemoryAdapter {
113 &self.memory
114 }
115
116 pub fn stored_user(&self) -> bool {
117 self.stored_user
118 }
119
120 fn mark_user_stored(&mut self) {
121 self.stored_user = true;
122 }
123}
124
125#[derive(Debug, Clone)]
127pub struct TurnEngine {
128 config: TurnEngineConfig,
129}
130
131impl TurnEngine {
132 pub fn new(config: TurnEngineConfig) -> Self {
133 Self { config }
134 }
135
136 pub fn turn_state(&self, context: &Context) -> TurnState {
137 TurnState::new(context, self.config.memory_policy.clone())
138 }
139
140 pub async fn run_turn<H: AgentHooks>(
141 &self,
142 hooks: &H,
143 task: &Task,
144 context: &Context,
145 turn_state: &mut TurnState,
146 turn_index: usize,
147 max_turns: usize,
148 ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
149 let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
150 let tx_event = context.tx().ok();
151 EventHelper::send_turn_started(
152 &tx_event,
153 task.submission_id,
154 context.config().id,
155 turn_index,
156 max_turns,
157 )
158 .await;
159
160 hooks.on_turn_start(turn_index, context).await;
161
162 let include_user_prompt =
163 should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
164 let messages = self
165 .build_messages(context, task, turn_state.memory(), include_user_prompt)
166 .await;
167
168 if should_store_user(turn_state) {
169 turn_state.memory.store_user(task).await;
170 turn_state.mark_user_stored();
171 }
172
173 let tools = context.tools();
174 let response = self.get_llm_response(context, &messages, tools).await?;
175 let response_text = response.text().unwrap_or_default();
176
177 let tool_calls = if matches!(self.config.tool_mode, ToolMode::Enabled) {
178 response.tool_calls().unwrap_or_default()
179 } else {
180 Vec::new()
181 };
182
183 if !tool_calls.is_empty() {
184 let tool_results = process_tool_calls_with_hooks(
185 hooks,
186 context,
187 task.submission_id,
188 tools,
189 &tool_calls,
190 &tx_event,
191 )
192 .await;
193
194 turn_state
195 .memory
196 .store_tool_interaction(&tool_calls, &tool_results, &response_text)
197 .await;
198 record_tool_calls_state(context, &tool_results);
199
200 EventHelper::send_turn_completed(
201 &tx_event,
202 task.submission_id,
203 context.config().id,
204 turn_index,
205 false,
206 )
207 .await;
208 hooks.on_turn_complete(turn_index, context).await;
209
210 return Ok(crate::agent::executor::TurnResult::Continue(Some(
211 TurnEngineOutput {
212 response: response_text,
213 tool_calls: tool_results,
214 },
215 )));
216 }
217
218 if !response_text.is_empty() {
219 turn_state.memory.store_assistant(&response_text).await;
220 }
221
222 EventHelper::send_turn_completed(
223 &tx_event,
224 task.submission_id,
225 context.config().id,
226 turn_index,
227 true,
228 )
229 .await;
230 hooks.on_turn_complete(turn_index, context).await;
231
232 Ok(crate::agent::executor::TurnResult::Complete(
233 TurnEngineOutput {
234 response: response_text,
235 tool_calls: Vec::new(),
236 },
237 ))
238 }
239
240 pub async fn run_turn_stream<H>(
241 &self,
242 hooks: H,
243 task: &Task,
244 context: Arc<Context>,
245 turn_state: &mut TurnState,
246 turn_index: usize,
247 max_turns: usize,
248 ) -> Result<
249 Pin<Box<dyn Stream<Item = Result<TurnDelta, TurnEngineError>> + Send>>,
250 TurnEngineError,
251 >
252 where
253 H: AgentHooks + Clone + Send + Sync + 'static,
254 {
255 let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
256 let include_user_prompt =
257 should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
258 let messages = self
259 .build_messages(&context, task, turn_state.memory(), include_user_prompt)
260 .await;
261
262 if should_store_user(turn_state) {
263 turn_state.memory.store_user(task).await;
264 turn_state.mark_user_stored();
265 }
266
267 let (mut tx, rx) = channel::<Result<TurnDelta, TurnEngineError>>(100);
268 let engine = self.clone();
269 let context_clone = context.clone();
270 let task = task.clone();
271 let hooks = hooks.clone();
272 let memory = turn_state.memory.clone();
273 let messages = messages.clone();
274
275 spawn_future(async move {
276 let tx_event = context_clone.tx().ok();
277 EventHelper::send_turn_started(
278 &tx_event,
279 task.submission_id,
280 context_clone.config().id,
281 turn_index,
282 max_turns,
283 )
284 .await;
285 hooks.on_turn_start(turn_index, &context_clone).await;
286
287 let result = match engine.config.stream_mode {
288 StreamMode::Structured => {
289 engine
290 .stream_structured(&context_clone, &task, &memory, &mut tx, &messages)
291 .await
292 }
293 StreamMode::Tool => {
294 engine
295 .stream_with_tools(
296 &hooks,
297 &context_clone,
298 &task,
299 context_clone.tools(),
300 &memory,
301 &mut tx,
302 &messages,
303 )
304 .await
305 }
306 };
307
308 match result {
309 Ok(turn_result) => {
310 let final_turn =
311 matches!(turn_result, crate::agent::executor::TurnResult::Complete(_));
312 EventHelper::send_turn_completed(
313 &tx_event,
314 task.submission_id,
315 context_clone.config().id,
316 turn_index,
317 final_turn,
318 )
319 .await;
320 hooks.on_turn_complete(turn_index, &context_clone).await;
321 let _ = tx.send(Ok(TurnDelta::Done(turn_result))).await;
322 }
323 Err(err) => {
324 let _ = tx.send(Err(err)).await;
325 }
326 }
327 });
328
329 Ok(receiver_into_stream(rx))
330 }
331
332 async fn stream_structured(
333 &self,
334 context: &Context,
335 task: &Task,
336 memory: &MemoryAdapter,
337 tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
338 messages: &[ChatMessage],
339 ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
340 let mut stream = self.get_structured_stream(context, messages).await?;
341 let mut response_text = String::new();
342
343 while let Some(chunk_result) = stream.next().await {
344 let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
345 let content = chunk
346 .choices
347 .first()
348 .and_then(|choice| choice.delta.content.as_ref())
349 .map_or("", |value| value)
350 .to_string();
351
352 if content.is_empty() {
353 continue;
354 }
355
356 response_text.push_str(&content);
357
358 let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
359
360 let tx_event = context.tx().ok();
361 EventHelper::send_stream_chunk(
362 &tx_event,
363 task.submission_id,
364 StreamChunk::Text(content),
365 )
366 .await;
367 }
368
369 if !response_text.is_empty() {
370 memory.store_assistant(&response_text).await;
371 }
372
373 Ok(crate::agent::executor::TurnResult::Complete(
374 TurnEngineOutput {
375 response: response_text,
376 tool_calls: Vec::new(),
377 },
378 ))
379 }
380
381 #[allow(clippy::too_many_arguments)]
382 async fn stream_with_tools<H: AgentHooks>(
383 &self,
384 hooks: &H,
385 context: &Context,
386 task: &Task,
387 tools: &[Box<dyn ToolT>],
388 memory: &MemoryAdapter,
389 tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
390 messages: &[ChatMessage],
391 ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
392 let mut stream = self.get_tool_stream(context, messages, tools).await?;
393 let mut response_text = String::new();
394 let mut tool_calls = Vec::new();
395 let mut tool_call_ids = HashSet::new();
396
397 while let Some(chunk_result) = stream.next().await {
398 let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
399 let chunk_clone = chunk.clone();
400
401 match chunk {
402 StreamChunk::Text(content) => {
403 response_text.push_str(&content);
404 let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
405 }
406 StreamChunk::ToolUseComplete {
407 index: _,
408 tool_call,
409 } => {
410 if tool_call_ids.insert(tool_call.id.clone()) {
411 tool_calls.push(tool_call.clone());
412 let tx_event = context.tx().ok();
413 EventHelper::send_stream_tool_call(
414 &tx_event,
415 task.submission_id,
416 serde_json::to_value(tool_call).unwrap_or(Value::Null),
417 )
418 .await;
419 }
420 }
421 StreamChunk::Usage(_) => {}
422 _ => {}
423 }
424
425 let tx_event = context.tx().ok();
426 EventHelper::send_stream_chunk(&tx_event, task.submission_id, chunk_clone).await;
427 }
428
429 if tool_calls.is_empty() {
430 if !response_text.is_empty() {
431 memory.store_assistant(&response_text).await;
432 }
433 return Ok(crate::agent::executor::TurnResult::Complete(
434 TurnEngineOutput {
435 response: response_text,
436 tool_calls: Vec::new(),
437 },
438 ));
439 }
440
441 let tx_event = context.tx().ok();
442 let tool_results = process_tool_calls_with_hooks(
443 hooks,
444 context,
445 task.submission_id,
446 tools,
447 &tool_calls,
448 &tx_event,
449 )
450 .await;
451
452 memory
453 .store_tool_interaction(&tool_calls, &tool_results, &response_text)
454 .await;
455 record_tool_calls_state(context, &tool_results);
456
457 let _ = tx
458 .send(Ok(TurnDelta::ToolResults(tool_results.clone())))
459 .await;
460
461 Ok(crate::agent::executor::TurnResult::Continue(Some(
462 TurnEngineOutput {
463 response: response_text,
464 tool_calls: tool_results,
465 },
466 )))
467 }
468
469 async fn get_llm_response(
470 &self,
471 context: &Context,
472 messages: &[ChatMessage],
473 tools: &[Box<dyn ToolT>],
474 ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, TurnEngineError> {
475 let llm = context.llm();
476 let output_schema = context.config().output_schema.clone();
477
478 if matches!(self.config.tool_mode, ToolMode::Enabled) && !tools.is_empty() {
479 let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
480 llm.chat_with_tools(messages, Some(&tools_serialized), output_schema)
481 .await
482 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
483 } else {
484 llm.chat(messages, output_schema)
485 .await
486 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
487 }
488 }
489
490 async fn get_structured_stream(
491 &self,
492 context: &Context,
493 messages: &[ChatMessage],
494 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, TurnEngineError>
495 {
496 context
497 .llm()
498 .chat_stream_struct(messages, None, context.config().output_schema.clone())
499 .await
500 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
501 }
502
503 async fn get_tool_stream(
504 &self,
505 context: &Context,
506 messages: &[ChatMessage],
507 tools: &[Box<dyn ToolT>],
508 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, TurnEngineError>
509 {
510 let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
511 context
512 .llm()
513 .chat_stream_with_tools(
514 messages,
515 if tools_serialized.is_empty() {
516 None
517 } else {
518 Some(&tools_serialized)
519 },
520 context.config().output_schema.clone(),
521 )
522 .await
523 .map_err(|e| TurnEngineError::LLMError(e.to_string()))
524 }
525
526 async fn build_messages(
527 &self,
528 context: &Context,
529 task: &Task,
530 memory: &MemoryAdapter,
531 include_user_prompt: bool,
532 ) -> Vec<ChatMessage> {
533 let system_prompt = task
534 .system_prompt
535 .as_deref()
536 .unwrap_or(&context.config().description);
537 let mut messages = vec![ChatMessage {
538 role: ChatRole::System,
539 message_type: MessageType::Text,
540 content: system_prompt.to_string(),
541 }];
542
543 let recalled = memory.recall_messages(task).await;
544 messages.extend(recalled);
545
546 if include_user_prompt {
547 messages.push(user_message(task));
548 }
549
550 messages
551 }
552}
553
554pub fn record_task_state(context: &Context, task: &Task) {
555 let state = context.state();
556 #[cfg(not(target_arch = "wasm32"))]
557 if let Ok(mut guard) = state.try_lock() {
558 guard.record_task(task.clone());
559 };
560 #[cfg(target_arch = "wasm32")]
561 if let Some(mut guard) = state.try_lock() {
562 guard.record_task(task.clone());
563 };
564}
565
566fn user_message(task: &Task) -> ChatMessage {
567 if let Some((mime, image_data)) = &task.image {
568 ChatMessage {
569 role: ChatRole::User,
570 message_type: MessageType::Image(((*mime).into(), image_data.clone())),
571 content: task.prompt.clone(),
572 }
573 } else {
574 ChatMessage {
575 role: ChatRole::User,
576 message_type: MessageType::Text,
577 content: task.prompt.clone(),
578 }
579 }
580}
581
582fn should_include_user_prompt(memory: &MemoryAdapter, stored_user: bool) -> bool {
583 if !memory.is_enabled() {
584 return true;
585 }
586 if !memory.policy().recall {
587 return true;
588 }
589 if !memory.policy().store_user {
590 return true;
591 }
592 !stored_user
593}
594
595fn should_store_user(turn_state: &TurnState) -> bool {
596 if !turn_state.memory.is_enabled() {
597 return false;
598 }
599 if !turn_state.memory.policy().store_user {
600 return false;
601 }
602 !turn_state.stored_user
603}
604
/// Resolves the effective turn budget.
///
/// A `max_turns` of zero means "use the fallback", which is itself clamped to
/// at least one turn. Any other value is returned unchanged.
fn normalize_max_turns(max_turns: usize, fallback: usize) -> usize {
    match max_turns {
        0 => fallback.max(1),
        explicit => explicit,
    }
}
611
612fn record_tool_calls_state(context: &Context, tool_results: &[ToolCallResult]) {
613 if tool_results.is_empty() {
614 return;
615 }
616 let state = context.state();
617 #[cfg(not(target_arch = "wasm32"))]
618 if let Ok(mut guard) = state.try_lock() {
619 for result in tool_results {
620 guard.record_tool_call(result.clone());
621 }
622 };
623 #[cfg(target_arch = "wasm32")]
624 if let Some(mut guard) = state.try_lock() {
625 for result in tool_results {
626 guard.record_tool_call(result.clone());
627 }
628 };
629}
630
631async fn process_tool_calls_with_hooks<H: AgentHooks>(
632 hooks: &H,
633 context: &Context,
634 submission_id: SubmissionId,
635 tools: &[Box<dyn ToolT>],
636 tool_calls: &[ToolCall],
637 tx_event: &Option<mpsc::Sender<Event>>,
638) -> Vec<ToolCallResult> {
639 let mut results = Vec::new();
640 for call in tool_calls {
641 if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
642 hooks,
643 context,
644 submission_id,
645 tools,
646 call,
647 tx_event,
648 )
649 .await
650 {
651 results.push(result);
652 }
653 }
654 results
655}