1use neuron_tool::ToolRegistry;
8use futures::StreamExt;
9
10use neuron_types::{
11 CompletionRequest, ContentBlock, ContentItem, ContextStrategy, HookAction, LoopError, Message,
12 Provider, Role, StopReason, StreamError, StreamEvent, TokenUsage, ToolContext, ToolOutput,
13};
14
15use crate::loop_impl::{
16 accumulate_usage, extract_text, fire_compaction_hooks, fire_loop_iteration_hooks,
17 fire_post_llm_hooks, fire_post_tool_hooks, fire_pre_llm_hooks, fire_pre_tool_hooks,
18 AgentLoop, AgentResult, DEFAULT_ACTIVITY_TIMEOUT,
19};
20
/// Outcome of a single step taken by [`StepIterator::next`].
#[derive(Debug)]
pub enum TurnResult {
    /// The model requested tool calls and they were executed.
    ToolsExecuted {
        /// One `(tool_use_id, tool_name, input)` tuple per `ToolUse` block in the model's
        /// response, in the order the blocks appeared.
        calls: Vec<(String, String, serde_json::Value)>,
        /// Outputs extracted from the tool-result blocks produced for those calls, in call
        /// order.
        results: Vec<ToolOutput>,
    },
    /// The model produced its final answer; the step iterator is finished afterwards.
    FinalResponse(AgentResult),
    /// The conversation was compacted during this step.
    CompactionOccurred {
        /// Estimated token count before compaction (0 when the provider itself reported
        /// compaction via `StopReason::Compaction`).
        old_tokens: usize,
        /// Estimated token count after compaction (0 in the provider-reported case).
        new_tokens: usize,
    },
    /// The configured `max_turns` limit was reached before another LLM call.
    MaxTurnsReached,
    /// The loop stopped with an error, e.g. cancellation, hook termination, or a
    /// provider/compaction/tool failure.
    Error(LoopError),
}
45
46pub struct StepIterator<'a, P: Provider, C: ContextStrategy> {
54 loop_ref: &'a mut AgentLoop<P, C>,
55 tool_ctx: &'a ToolContext,
56 total_usage: TokenUsage,
57 turns: usize,
58 finished: bool,
59}
60
61impl<'a, P: Provider, C: ContextStrategy> StepIterator<'a, P, C> {
62 pub async fn next(&mut self) -> Option<TurnResult> {
67 if self.finished {
68 return None;
69 }
70
71 if self.tool_ctx.cancellation_token.is_cancelled() {
73 self.finished = true;
74 return Some(TurnResult::Error(LoopError::Cancelled));
75 }
76
77 if let Some(max) = self.loop_ref.config.max_turns
79 && self.turns >= max
80 {
81 self.finished = true;
82 return Some(TurnResult::MaxTurnsReached);
83 }
84
85 match fire_loop_iteration_hooks(&self.loop_ref.hooks, self.turns).await {
87 Ok(Some(HookAction::Terminate { reason })) => {
88 self.finished = true;
89 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
90 }
91 Err(e) => {
92 self.finished = true;
93 return Some(TurnResult::Error(e));
94 }
95 _ => {}
96 }
97
98 let token_count = self.loop_ref.context.token_estimate(&self.loop_ref.messages);
100 if self
101 .loop_ref
102 .context
103 .should_compact(&self.loop_ref.messages, token_count)
104 {
105 let old_tokens = token_count;
106 match self
107 .loop_ref
108 .context
109 .compact(self.loop_ref.messages.clone())
110 .await
111 {
112 Ok(compacted) => {
113 self.loop_ref.messages = compacted;
114 let new_tokens =
115 self.loop_ref.context.token_estimate(&self.loop_ref.messages);
116
117 match fire_compaction_hooks(&self.loop_ref.hooks, old_tokens, new_tokens).await
119 {
120 Ok(Some(HookAction::Terminate { reason })) => {
121 self.finished = true;
122 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
123 }
124 Err(e) => {
125 self.finished = true;
126 return Some(TurnResult::Error(e));
127 }
128 _ => {}
129 }
130
131 return Some(TurnResult::CompactionOccurred {
132 old_tokens,
133 new_tokens,
134 });
135 }
136 Err(e) => {
137 self.finished = true;
138 return Some(TurnResult::Error(e.into()));
139 }
140 }
141 }
142
143 let request = CompletionRequest {
145 model: String::new(),
146 messages: self.loop_ref.messages.clone(),
147 system: Some(self.loop_ref.config.system_prompt.clone()),
148 tools: self.loop_ref.tools.definitions(),
149 ..Default::default()
150 };
151
152 match fire_pre_llm_hooks(&self.loop_ref.hooks, &request).await {
154 Ok(Some(HookAction::Terminate { reason })) => {
155 self.finished = true;
156 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
157 }
158 Err(e) => {
159 self.finished = true;
160 return Some(TurnResult::Error(e));
161 }
162 _ => {}
163 }
164
165 let response = if let Some(ref durable) = self.loop_ref.durability {
167 let options = neuron_types::ActivityOptions {
168 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
169 heartbeat_timeout: None,
170 retry_policy: None,
171 };
172 match durable.0.erased_execute_llm_call(request, options).await {
173 Ok(r) => r,
174 Err(e) => {
175 self.finished = true;
176 return Some(TurnResult::Error(
177 neuron_types::ProviderError::Other(Box::new(e)).into(),
178 ));
179 }
180 }
181 } else {
182 match self.loop_ref.provider.complete(request).await {
183 Ok(r) => r,
184 Err(e) => {
185 self.finished = true;
186 return Some(TurnResult::Error(e.into()));
187 }
188 }
189 };
190
191 match fire_post_llm_hooks(&self.loop_ref.hooks, &response).await {
193 Ok(Some(HookAction::Terminate { reason })) => {
194 self.finished = true;
195 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
196 }
197 Err(e) => {
198 self.finished = true;
199 return Some(TurnResult::Error(e));
200 }
201 _ => {}
202 }
203
204 accumulate_usage(&mut self.total_usage, &response.usage);
206 self.turns += 1;
207
208 let tool_calls: Vec<_> = response
210 .message
211 .content
212 .iter()
213 .filter_map(|block| {
214 if let ContentBlock::ToolUse { id, name, input } = block {
215 Some((id.clone(), name.clone(), input.clone()))
216 } else {
217 None
218 }
219 })
220 .collect();
221
222 self.loop_ref.messages.push(response.message.clone());
224
225 if response.stop_reason == StopReason::Compaction {
228 return Some(TurnResult::CompactionOccurred {
229 old_tokens: 0,
230 new_tokens: 0,
231 });
232 }
233
234 if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
235 self.finished = true;
236 let response_text = extract_text(&response.message);
237 return Some(TurnResult::FinalResponse(AgentResult {
238 response: response_text,
239 messages: self.loop_ref.messages.clone(),
240 usage: self.total_usage.clone(),
241 turns: self.turns,
242 }));
243 }
244
245 if self.tool_ctx.cancellation_token.is_cancelled() {
247 self.finished = true;
248 return Some(TurnResult::Error(LoopError::Cancelled));
249 }
250
251 let mut tool_result_blocks = Vec::new();
253 let mut tool_outputs = Vec::new();
254
255 if self.loop_ref.config.parallel_tool_execution && tool_calls.len() > 1 {
256 let futs = tool_calls.iter().map(|(call_id, tool_name, input)| {
257 self.loop_ref.execute_single_tool(call_id, tool_name, input, self.tool_ctx)
258 });
259 let results = futures::future::join_all(futs).await;
260 for result in results {
261 match result {
262 Ok(block) => {
263 if let ContentBlock::ToolResult { content, is_error, .. } = &block {
265 tool_outputs.push(ToolOutput {
266 content: content.clone(),
267 structured_content: None,
268 is_error: *is_error,
269 });
270 }
271 tool_result_blocks.push(block);
272 }
273 Err(e) => {
274 self.finished = true;
275 return Some(TurnResult::Error(e));
276 }
277 }
278 }
279 } else {
280 for (call_id, tool_name, input) in &tool_calls {
281 match self.loop_ref.execute_single_tool(call_id, tool_name, input, self.tool_ctx).await {
282 Ok(block) => {
283 if let ContentBlock::ToolResult { content, is_error, .. } = &block {
284 tool_outputs.push(ToolOutput {
285 content: content.clone(),
286 structured_content: None,
287 is_error: *is_error,
288 });
289 }
290 tool_result_blocks.push(block);
291 }
292 Err(e) => {
293 self.finished = true;
294 return Some(TurnResult::Error(e));
295 }
296 }
297 }
298 }
299
300 self.loop_ref.messages.push(Message {
302 role: Role::User,
303 content: tool_result_blocks,
304 });
305
306 Some(TurnResult::ToolsExecuted {
307 calls: tool_calls,
308 results: tool_outputs,
309 })
310 }
311
312 #[must_use]
314 pub fn messages(&self) -> &[Message] {
315 &self.loop_ref.messages
316 }
317
318 pub fn inject_message(&mut self, message: Message) {
320 self.loop_ref.messages.push(message);
321 }
322
323 #[must_use]
325 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
326 &mut self.loop_ref.tools
327 }
328}
329
330impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
331 #[must_use]
340 pub fn run_step<'a>(
341 &'a mut self,
342 user_message: Message,
343 tool_ctx: &'a ToolContext,
344 ) -> StepIterator<'a, P, C> {
345 self.messages.push(user_message);
346 StepIterator {
347 loop_ref: self,
348 tool_ctx,
349 total_usage: TokenUsage::default(),
350 turns: 0,
351 finished: false,
352 }
353 }
354
355 pub async fn run_stream(
374 &mut self,
375 user_message: Message,
376 tool_ctx: &ToolContext,
377 ) -> tokio::sync::mpsc::Receiver<StreamEvent> {
378 let (tx, rx) = tokio::sync::mpsc::channel(64);
379 self.messages.push(user_message);
380
381 let mut turns: usize = 0;
382
383 loop {
384 if tool_ctx.cancellation_token.is_cancelled() {
386 let _ = tx
387 .send(StreamEvent::Error(StreamError::non_retryable(
388 "cancelled",
389 )))
390 .await;
391 break;
392 }
393
394 if let Some(max) = self.config.max_turns
396 && turns >= max
397 {
398 let _ = tx
399 .send(StreamEvent::Error(StreamError::non_retryable(format!(
400 "max turns reached ({max})"
401 ))))
402 .await;
403 break;
404 }
405
406 match fire_loop_iteration_hooks(&self.hooks, turns).await {
408 Ok(Some(HookAction::Terminate { reason })) => {
409 let _ = tx
410 .send(StreamEvent::Error(StreamError::non_retryable(format!(
411 "hook terminated: {reason}"
412 ))))
413 .await;
414 break;
415 }
416 Err(e) => {
417 let _ = tx
418 .send(StreamEvent::Error(StreamError::non_retryable(format!(
419 "hook error: {e}"
420 ))))
421 .await;
422 break;
423 }
424 _ => {}
425 }
426
427 let token_count = self.context.token_estimate(&self.messages);
429 if self.context.should_compact(&self.messages, token_count) {
430 let old_tokens = token_count;
431 match self.context.compact(self.messages.clone()).await {
432 Ok(compacted) => {
433 self.messages = compacted;
434 let new_tokens = self.context.token_estimate(&self.messages);
435
436 match fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await {
438 Ok(Some(HookAction::Terminate { reason })) => {
439 let _ = tx
440 .send(StreamEvent::Error(StreamError::non_retryable(format!(
441 "hook terminated: {reason}"
442 ))))
443 .await;
444 break;
445 }
446 Err(e) => {
447 let _ = tx
448 .send(StreamEvent::Error(StreamError::non_retryable(format!(
449 "hook error: {e}"
450 ))))
451 .await;
452 break;
453 }
454 _ => {}
455 }
456 }
457 Err(e) => {
458 let _ = tx
459 .send(StreamEvent::Error(StreamError::non_retryable(format!(
460 "compaction error: {e}"
461 ))))
462 .await;
463 break;
464 }
465 }
466 }
467
468 let request = CompletionRequest {
470 model: String::new(),
471 messages: self.messages.clone(),
472 system: Some(self.config.system_prompt.clone()),
473 tools: self.tools.definitions(),
474 ..Default::default()
475 };
476
477 match fire_pre_llm_hooks(&self.hooks, &request).await {
479 Ok(Some(HookAction::Terminate { reason })) => {
480 let _ = tx
481 .send(StreamEvent::Error(StreamError::non_retryable(format!(
482 "hook terminated: {reason}"
483 ))))
484 .await;
485 break;
486 }
487 Err(e) => {
488 let _ = tx
489 .send(StreamEvent::Error(StreamError::non_retryable(format!(
490 "hook error: {e}"
491 ))))
492 .await;
493 break;
494 }
495 _ => {}
496 }
497
498 let message = if let Some(ref durable) = self.durability {
501 let options = neuron_types::ActivityOptions {
503 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
504 heartbeat_timeout: None,
505 retry_policy: None,
506 };
507 let response = match durable.0.erased_execute_llm_call(request, options).await {
508 Ok(r) => r,
509 Err(e) => {
510 let _ = tx
511 .send(StreamEvent::Error(StreamError::non_retryable(format!(
512 "durable error: {e}"
513 ))))
514 .await;
515 break;
516 }
517 };
518
519 for block in &response.message.content {
521 if let ContentBlock::Text(text) = block
522 && tx.send(StreamEvent::TextDelta(text.clone())).await.is_err()
523 {
524 return rx;
525 }
526 }
527 if tx.send(StreamEvent::Usage(response.usage.clone())).await.is_err() {
528 return rx;
529 }
530 if tx
531 .send(StreamEvent::MessageComplete(response.message.clone()))
532 .await
533 .is_err()
534 {
535 return rx;
536 }
537
538 match fire_post_llm_hooks(&self.hooks, &response).await {
540 Ok(Some(HookAction::Terminate { reason })) => {
541 let _ = tx
542 .send(StreamEvent::Error(StreamError::non_retryable(format!(
543 "hook terminated: {reason}"
544 ))))
545 .await;
546 break;
547 }
548 Err(e) => {
549 let _ = tx
550 .send(StreamEvent::Error(StreamError::non_retryable(format!(
551 "hook error: {e}"
552 ))))
553 .await;
554 break;
555 }
556 _ => {}
557 }
558
559 response.message
560 } else {
561 let stream_handle = match self.provider.complete_stream(request).await {
563 Ok(h) => h,
564 Err(e) => {
565 let _ = tx
566 .send(StreamEvent::Error(StreamError::non_retryable(format!(
567 "provider error: {e}"
568 ))))
569 .await;
570 break;
571 }
572 };
573
574 let mut assembled_message: Option<Message> = None;
576 let mut assembled_response: Option<neuron_types::CompletionResponse> = None;
577 let mut stream = stream_handle.receiver;
578
579 while let Some(event) = stream.next().await {
580 match &event {
581 StreamEvent::MessageComplete(msg) => {
582 assembled_message = Some(msg.clone());
583 }
584 StreamEvent::Usage(u) => {
585 assembled_response = Some(neuron_types::CompletionResponse {
587 id: String::new(),
588 model: String::new(),
589 message: assembled_message.clone().unwrap_or(Message {
590 role: Role::Assistant,
591 content: vec![],
592 }),
593 usage: u.clone(),
594 stop_reason: StopReason::EndTurn,
595 });
596 }
597 _ => {}
598 }
599 if tx.send(event).await.is_err() {
601 return rx;
603 }
604 }
605
606 let msg = match assembled_message {
608 Some(m) => m,
609 None => {
610 let _ = tx
611 .send(StreamEvent::Error(StreamError::non_retryable(
612 "stream ended without MessageComplete",
613 )))
614 .await;
615 break;
616 }
617 };
618
619 if let Some(mut resp) = assembled_response {
621 resp.message = msg.clone();
622 match fire_post_llm_hooks(&self.hooks, &resp).await {
623 Ok(Some(HookAction::Terminate { reason })) => {
624 let _ = tx
625 .send(StreamEvent::Error(StreamError::non_retryable(format!(
626 "hook terminated: {reason}"
627 ))))
628 .await;
629 break;
630 }
631 Err(e) => {
632 let _ = tx
633 .send(StreamEvent::Error(StreamError::non_retryable(format!(
634 "hook error: {e}"
635 ))))
636 .await;
637 break;
638 }
639 _ => {}
640 }
641 } else {
642 let resp = neuron_types::CompletionResponse {
644 id: String::new(),
645 model: String::new(),
646 message: msg.clone(),
647 usage: TokenUsage::default(),
648 stop_reason: StopReason::EndTurn,
649 };
650 match fire_post_llm_hooks(&self.hooks, &resp).await {
651 Ok(Some(HookAction::Terminate { reason })) => {
652 let _ = tx
653 .send(StreamEvent::Error(StreamError::non_retryable(format!(
654 "hook terminated: {reason}"
655 ))))
656 .await;
657 break;
658 }
659 Err(e) => {
660 let _ = tx
661 .send(StreamEvent::Error(StreamError::non_retryable(format!(
662 "hook error: {e}"
663 ))))
664 .await;
665 break;
666 }
667 _ => {}
668 }
669 }
670
671 msg
672 };
673
674 turns += 1;
675
676 let tool_calls: Vec<_> = message
678 .content
679 .iter()
680 .filter_map(|block| {
681 if let ContentBlock::ToolUse { id, name, input } = block {
682 Some((id.clone(), name.clone(), input.clone()))
683 } else {
684 None
685 }
686 })
687 .collect();
688
689 self.messages.push(message.clone());
690
691 if tool_calls.is_empty() {
697 break;
699 }
700
701 if tool_ctx.cancellation_token.is_cancelled() {
703 let _ = tx
704 .send(StreamEvent::Error(StreamError::non_retryable(
705 "cancelled",
706 )))
707 .await;
708 break;
709 }
710
711 let mut tool_result_blocks = Vec::new();
713 for (call_id, tool_name, input) in &tool_calls {
714 match fire_pre_tool_hooks(&self.hooks, tool_name, input).await {
716 Ok(Some(HookAction::Terminate { reason })) => {
717 let _ = tx
718 .send(StreamEvent::Error(StreamError::non_retryable(format!(
719 "hook terminated: {reason}"
720 ))))
721 .await;
722 return rx;
723 }
724 Ok(Some(HookAction::Skip { reason })) => {
725 tool_result_blocks.push(ContentBlock::ToolResult {
726 tool_use_id: call_id.clone(),
727 content: vec![ContentItem::Text(format!(
728 "Tool call skipped: {reason}"
729 ))],
730 is_error: true,
731 });
732 continue;
733 }
734 Err(e) => {
735 let _ = tx
736 .send(StreamEvent::Error(StreamError::non_retryable(format!(
737 "hook error: {e}"
738 ))))
739 .await;
740 return rx;
741 }
742 _ => {}
743 }
744
745 let result = if let Some(ref durable) = self.durability {
747 let options = neuron_types::ActivityOptions {
748 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
749 heartbeat_timeout: None,
750 retry_policy: None,
751 };
752 match durable
753 .0
754 .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
755 .await
756 {
757 Ok(r) => r,
758 Err(e) => {
759 let _ = tx
760 .send(StreamEvent::Error(StreamError::non_retryable(format!(
761 "durable tool error: {e}"
762 ))))
763 .await;
764 return rx;
765 }
766 }
767 } else {
768 match self.tools.execute(tool_name, input.clone(), tool_ctx).await {
769 Ok(r) => r,
770 Err(neuron_types::ToolError::ModelRetry(hint)) => {
771 ToolOutput {
774 content: vec![ContentItem::Text(hint)],
775 structured_content: None,
776 is_error: true,
777 }
778 }
779 Err(e) => {
780 let _ = tx
781 .send(StreamEvent::Error(StreamError::non_retryable(format!(
782 "tool error: {e}"
783 ))))
784 .await;
785 return rx;
786 }
787 }
788 };
789
790 match fire_post_tool_hooks(&self.hooks, tool_name, &result).await {
792 Ok(Some(HookAction::Terminate { reason })) => {
793 let _ = tx
794 .send(StreamEvent::Error(StreamError::non_retryable(format!(
795 "hook terminated: {reason}"
796 ))))
797 .await;
798 return rx;
799 }
800 Err(e) => {
801 let _ = tx
802 .send(StreamEvent::Error(StreamError::non_retryable(format!(
803 "hook error: {e}"
804 ))))
805 .await;
806 return rx;
807 }
808 _ => {}
809 }
810
811 tool_result_blocks.push(ContentBlock::ToolResult {
812 tool_use_id: call_id.clone(),
813 content: result.content,
814 is_error: result.is_error,
815 });
816 }
817
818 self.messages.push(Message {
819 role: Role::User,
820 content: tool_result_blocks,
821 });
822 }
823
824 rx
825 }
826}