1use futures::StreamExt;
8use neuron_tool::ToolRegistry;
9
10use neuron_types::{
11 CompletionRequest, ContentBlock, ContentItem, ContextStrategy, HookAction, LoopError, Message,
12 Provider, Role, StopReason, StreamError, StreamEvent, TokenUsage, ToolContext, ToolOutput,
13};
14
15use crate::loop_impl::{
16 AgentLoop, AgentResult, DEFAULT_ACTIVITY_TIMEOUT, accumulate_usage, check_request_limit,
17 check_token_limits, check_tool_call_limit, extract_text, fire_compaction_hooks,
18 fire_loop_iteration_hooks, fire_post_llm_hooks, fire_post_tool_hooks, fire_pre_llm_hooks,
19 fire_pre_tool_hooks,
20};
21
/// Outcome of a single turn yielded by [`StepIterator::next`].
#[derive(Debug)]
pub enum TurnResult {
    /// The LLM requested tool calls and they were executed this turn.
    ToolsExecuted {
        /// The requested calls as `(call_id, tool_name, input)` triples.
        calls: Vec<(String, String, serde_json::Value)>,
        /// One output per executed call, in the same order as `calls`.
        results: Vec<ToolOutput>,
    },
    /// The LLM produced a final answer; the run is complete.
    FinalResponse(AgentResult),
    /// Context compaction ran this turn (no tool execution happened).
    CompactionOccurred {
        /// Estimated token count before compaction.
        old_tokens: usize,
        /// Estimated token count after compaction.
        new_tokens: usize,
    },
    /// The configured `max_turns` limit was reached.
    MaxTurnsReached,
    /// The turn failed; subsequent calls to `next` return `None`.
    Error(LoopError),
}
46
/// Step-by-step driver over an [`AgentLoop`]: each `next` call performs one
/// turn (compaction, LLM request, or tool execution) and reports what
/// happened as a [`TurnResult`].
pub struct StepIterator<'a, P: Provider, C: ContextStrategy> {
    // Exclusive borrow of the loop whose messages/tools/hooks are driven.
    loop_ref: &'a mut AgentLoop<P, C>,
    // Tool-execution context; also carries the cancellation token.
    tool_ctx: &'a ToolContext,
    // Token usage accumulated across all turns of this run.
    total_usage: TokenUsage,
    // Completed LLM turns so far.
    turns: usize,
    // LLM requests issued (checked against `usage_limits`).
    request_count: usize,
    // Total tool calls executed (checked against `usage_limits`).
    tool_call_count: usize,
    // Once set, `next` always returns `None`.
    finished: bool,
}
63
impl<'a, P: Provider, C: ContextStrategy> StepIterator<'a, P, C> {
    /// Runs one turn of the agent loop and returns its outcome.
    ///
    /// Returns `None` once the iterator has finished (final response, error,
    /// or max-turns already reported). Order of checks per turn: cancellation,
    /// max-turns, request limit, loop-iteration hooks, compaction, LLM call,
    /// post-LLM hooks, token limits, then tool execution. Any terminating
    /// condition sets `finished` so later calls yield `None`.
    pub async fn next(&mut self) -> Option<TurnResult> {
        if self.finished {
            return None;
        }

        // Honor cancellation before doing any work this turn.
        if self.tool_ctx.cancellation_token.is_cancelled() {
            self.finished = true;
            return Some(TurnResult::Error(LoopError::Cancelled));
        }

        if let Some(max) = self.loop_ref.config.max_turns
            && self.turns >= max
        {
            self.finished = true;
            return Some(TurnResult::MaxTurnsReached);
        }

        if let Some(ref limits) = self.loop_ref.config.usage_limits
            && let Err(e) = check_request_limit(limits, self.request_count)
        {
            self.finished = true;
            return Some(TurnResult::Error(e));
        }

        // Hooks may terminate the run; other actions fall through.
        match fire_loop_iteration_hooks(&self.loop_ref.hooks, self.turns).await {
            Ok(Some(HookAction::Terminate { reason })) => {
                self.finished = true;
                return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
            }
            Err(e) => {
                self.finished = true;
                return Some(TurnResult::Error(e));
            }
            _ => {}
        }

        // Compact the context if the strategy asks for it. A compaction
        // consumes the whole turn: we return early without calling the LLM,
        // and the caller is expected to invoke `next` again.
        let token_count = self
            .loop_ref
            .context
            .token_estimate(&self.loop_ref.messages);
        if self
            .loop_ref
            .context
            .should_compact(&self.loop_ref.messages, token_count)
        {
            let old_tokens = token_count;
            match self
                .loop_ref
                .context
                .compact(self.loop_ref.messages.clone())
                .await
            {
                Ok(compacted) => {
                    self.loop_ref.messages = compacted;
                    let new_tokens = self
                        .loop_ref
                        .context
                        .token_estimate(&self.loop_ref.messages);

                    match fire_compaction_hooks(&self.loop_ref.hooks, old_tokens, new_tokens).await
                    {
                        Ok(Some(HookAction::Terminate { reason })) => {
                            self.finished = true;
                            return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
                        }
                        Err(e) => {
                            self.finished = true;
                            return Some(TurnResult::Error(e));
                        }
                        _ => {}
                    }

                    return Some(TurnResult::CompactionOccurred {
                        old_tokens,
                        new_tokens,
                    });
                }
                Err(e) => {
                    self.finished = true;
                    return Some(TurnResult::Error(e.into()));
                }
            }
        }

        // `model` is left empty — presumably the provider substitutes its
        // configured default. NOTE(review): confirm against Provider impls.
        let request = CompletionRequest {
            model: String::new(),
            messages: self.loop_ref.messages.clone(),
            system: Some(self.loop_ref.config.system_prompt.clone()),
            tools: self.loop_ref.tools.definitions(),
            ..Default::default()
        };

        match fire_pre_llm_hooks(&self.loop_ref.hooks, &request).await {
            Ok(Some(HookAction::Terminate { reason })) => {
                self.finished = true;
                return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
            }
            Err(e) => {
                self.finished = true;
                return Some(TurnResult::Error(e));
            }
            _ => {}
        }

        // Route the LLM call through the durability layer when configured,
        // otherwise call the provider directly.
        let response = if let Some(ref durable) = self.loop_ref.durability {
            let options = neuron_types::ActivityOptions {
                start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
                heartbeat_timeout: None,
                retry_policy: None,
            };
            match durable.0.erased_execute_llm_call(request, options).await {
                Ok(r) => r,
                Err(e) => {
                    self.finished = true;
                    return Some(TurnResult::Error(
                        neuron_types::ProviderError::Other(Box::new(e)).into(),
                    ));
                }
            }
        } else {
            match self.loop_ref.provider.complete(request).await {
                Ok(r) => r,
                Err(e) => {
                    self.finished = true;
                    return Some(TurnResult::Error(e.into()));
                }
            }
        };

        match fire_post_llm_hooks(&self.loop_ref.hooks, &response).await {
            Ok(Some(HookAction::Terminate { reason })) => {
                self.finished = true;
                return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
            }
            Err(e) => {
                self.finished = true;
                return Some(TurnResult::Error(e));
            }
            _ => {}
        }

        // Usage/turn accounting happens after the post-LLM hooks, so a
        // hook-terminated turn is not counted.
        accumulate_usage(&mut self.total_usage, &response.usage);
        self.request_count += 1;
        self.turns += 1;

        if let Some(ref limits) = self.loop_ref.config.usage_limits
            && let Err(e) = check_token_limits(limits, &self.total_usage)
        {
            self.finished = true;
            return Some(TurnResult::Error(e));
        }

        // Collect every ToolUse block as (call_id, tool_name, input).
        let tool_calls: Vec<_> = response
            .message
            .content
            .iter()
            .filter_map(|block| {
                if let ContentBlock::ToolUse { id, name, input } = block {
                    Some((id.clone(), name.clone(), input.clone()))
                } else {
                    None
                }
            })
            .collect();

        self.loop_ref.messages.push(response.message.clone());

        // Provider-signalled compaction: token counts are reported as zero
        // because the provider handled compaction internally.
        // NOTE(review): iterator is NOT marked finished here by design?
        // Caller must call `next` again — confirm this is intended.
        if response.stop_reason == StopReason::Compaction {
            return Some(TurnResult::CompactionOccurred {
                old_tokens: 0,
                new_tokens: 0,
            });
        }

        // NOTE(review): an EndTurn stop reason wins even when tool calls are
        // present — those calls are dropped unexecuted. Confirm intended.
        if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
            self.finished = true;
            let response_text = extract_text(&response.message);
            return Some(TurnResult::FinalResponse(AgentResult {
                response: response_text,
                messages: self.loop_ref.messages.clone(),
                usage: self.total_usage.clone(),
                turns: self.turns,
            }));
        }

        // Re-check cancellation before spending time on tool execution.
        if self.tool_ctx.cancellation_token.is_cancelled() {
            self.finished = true;
            return Some(TurnResult::Error(LoopError::Cancelled));
        }

        if let Some(ref limits) = self.loop_ref.config.usage_limits
            && let Err(e) = check_tool_call_limit(limits, self.tool_call_count, tool_calls.len())
        {
            self.finished = true;
            return Some(TurnResult::Error(e));
        }
        self.tool_call_count += tool_calls.len();

        let mut tool_result_blocks = Vec::new();
        let mut tool_outputs = Vec::new();

        // Parallel path only when enabled AND more than one call; both paths
        // produce results in the same order as `tool_calls`.
        if self.loop_ref.config.parallel_tool_execution && tool_calls.len() > 1 {
            let futs = tool_calls.iter().map(|(call_id, tool_name, input)| {
                self.loop_ref
                    .execute_single_tool(call_id, tool_name, input, self.tool_ctx)
            });
            let results = futures::future::join_all(futs).await;
            for result in results {
                match result {
                    Ok(block) => {
                        // Mirror the ToolResult block into a ToolOutput for the
                        // caller; `structured_content` is not carried by the
                        // content block, so it is reported as None.
                        if let ContentBlock::ToolResult {
                            content, is_error, ..
                        } = &block
                        {
                            tool_outputs.push(ToolOutput {
                                content: content.clone(),
                                structured_content: None,
                                is_error: *is_error,
                            });
                        }
                        tool_result_blocks.push(block);
                    }
                    Err(e) => {
                        self.finished = true;
                        return Some(TurnResult::Error(e));
                    }
                }
            }
        } else {
            for (call_id, tool_name, input) in &tool_calls {
                match self
                    .loop_ref
                    .execute_single_tool(call_id, tool_name, input, self.tool_ctx)
                    .await
                {
                    Ok(block) => {
                        if let ContentBlock::ToolResult {
                            content, is_error, ..
                        } = &block
                        {
                            tool_outputs.push(ToolOutput {
                                content: content.clone(),
                                structured_content: None,
                                is_error: *is_error,
                            });
                        }
                        tool_result_blocks.push(block);
                    }
                    Err(e) => {
                        self.finished = true;
                        return Some(TurnResult::Error(e));
                    }
                }
            }
        }

        // Tool results are fed back to the model as a User message.
        self.loop_ref.messages.push(Message {
            role: Role::User,
            content: tool_result_blocks,
        });

        Some(TurnResult::ToolsExecuted {
            calls: tool_calls,
            results: tool_outputs,
        })
    }

    /// Read-only view of the conversation so far (including tool results).
    #[must_use]
    pub fn messages(&self) -> &[Message] {
        &self.loop_ref.messages
    }

    /// Appends a message to the conversation before the next turn runs.
    pub fn inject_message(&mut self, message: Message) {
        self.loop_ref.messages.push(message);
    }

    /// Mutable access to the tool registry, e.g. to add/remove tools mid-run.
    #[must_use]
    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
        &mut self.loop_ref.tools
    }
}
374
impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
    /// Starts a step-by-step run: pushes `user_message` onto the conversation
    /// and returns a [`StepIterator`] whose `next` drives one turn at a time.
    ///
    /// Counters (`turns`, usage) start fresh for each `run_step` call.
    #[must_use]
    pub fn run_step<'a>(
        &'a mut self,
        user_message: Message,
        tool_ctx: &'a ToolContext,
    ) -> StepIterator<'a, P, C> {
        self.messages.push(user_message);
        StepIterator {
            loop_ref: self,
            tool_ctx,
            total_usage: TokenUsage::default(),
            turns: 0,
            request_count: 0,
            tool_call_count: 0,
            finished: false,
        }
    }

    /// Runs the full agent loop, emitting [`StreamEvent`]s on the returned
    /// channel. All failures are reported in-band as `StreamEvent::Error`.
    ///
    /// NOTE(review): the loop runs to completion inside this future and only
    /// then returns `rx`; with a bounded channel of 64, producing more than 64
    /// events before the caller starts receiving would block `tx.send` with no
    /// receiver draining — verify callers spawn/consume appropriately.
    pub async fn run_stream(
        &mut self,
        user_message: Message,
        tool_ctx: &ToolContext,
    ) -> tokio::sync::mpsc::Receiver<StreamEvent> {
        let (tx, rx) = tokio::sync::mpsc::channel(64);
        self.messages.push(user_message);

        // Per-run accounting, mirroring StepIterator's fields.
        let mut turns: usize = 0;
        let mut request_count: usize = 0;
        let mut tool_call_count: usize = 0;
        let mut total_usage = TokenUsage::default();

        loop {
            // Honor cancellation at the top of every turn.
            if tool_ctx.cancellation_token.is_cancelled() {
                let _ = tx
                    .send(StreamEvent::Error(StreamError::non_retryable("cancelled")))
                    .await;
                break;
            }

            if let Some(max) = self.config.max_turns
                && turns >= max
            {
                let _ = tx
                    .send(StreamEvent::Error(StreamError::non_retryable(format!(
                        "max turns reached ({max})"
                    ))))
                    .await;
                break;
            }

            if let Some(ref limits) = self.config.usage_limits
                && let Err(e) = check_request_limit(limits, request_count)
            {
                let _ = tx
                    .send(StreamEvent::Error(StreamError::non_retryable(format!(
                        "{e}"
                    ))))
                    .await;
                break;
            }

            // Hooks may terminate the run; other actions fall through.
            match fire_loop_iteration_hooks(&self.hooks, turns).await {
                Ok(Some(HookAction::Terminate { reason })) => {
                    let _ = tx
                        .send(StreamEvent::Error(StreamError::non_retryable(format!(
                            "hook terminated: {reason}"
                        ))))
                        .await;
                    break;
                }
                Err(e) => {
                    let _ = tx
                        .send(StreamEvent::Error(StreamError::non_retryable(format!(
                            "hook error: {e}"
                        ))))
                        .await;
                    break;
                }
                _ => {}
            }

            // Compact context if the strategy asks for it. Unlike the step
            // iterator, a successful compaction here falls through to the LLM
            // call in the same turn rather than returning early.
            let token_count = self.context.token_estimate(&self.messages);
            if self.context.should_compact(&self.messages, token_count) {
                let old_tokens = token_count;
                match self.context.compact(self.messages.clone()).await {
                    Ok(compacted) => {
                        self.messages = compacted;
                        let new_tokens = self.context.token_estimate(&self.messages);

                        match fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await {
                            Ok(Some(HookAction::Terminate { reason })) => {
                                let _ = tx
                                    .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                        "hook terminated: {reason}"
                                    ))))
                                    .await;
                                break;
                            }
                            Err(e) => {
                                let _ = tx
                                    .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                        "hook error: {e}"
                                    ))))
                                    .await;
                                break;
                            }
                            _ => {}
                        }
                    }
                    Err(e) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "compaction error: {e}"
                            ))))
                            .await;
                        break;
                    }
                }
            }

            // `model` left empty — presumably the provider fills its default.
            // NOTE(review): confirm against Provider implementations.
            let request = CompletionRequest {
                model: String::new(),
                messages: self.messages.clone(),
                system: Some(self.config.system_prompt.clone()),
                tools: self.tools.definitions(),
                ..Default::default()
            };

            match fire_pre_llm_hooks(&self.hooks, &request).await {
                Ok(Some(HookAction::Terminate { reason })) => {
                    let _ = tx
                        .send(StreamEvent::Error(StreamError::non_retryable(format!(
                            "hook terminated: {reason}"
                        ))))
                        .await;
                    break;
                }
                Err(e) => {
                    let _ = tx
                        .send(StreamEvent::Error(StreamError::non_retryable(format!(
                            "hook error: {e}"
                        ))))
                        .await;
                    break;
                }
                _ => {}
            }

            // Durable path: one non-streaming call, then synthesize stream
            // events (text deltas, usage, message-complete) for the receiver.
            // Non-durable path:真 streaming via `complete_stream`.
            let (message, turn_usage) = if let Some(ref durable) = self.durability {
                let options = neuron_types::ActivityOptions {
                    start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
                    heartbeat_timeout: None,
                    retry_policy: None,
                };
                let response = match durable.0.erased_execute_llm_call(request, options).await {
                    Ok(r) => r,
                    Err(e) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "durable error: {e}"
                            ))))
                            .await;
                        break;
                    }
                };

                // Replay each text block as a TextDelta; a failed send means
                // the receiver is gone, so stop the whole run.
                for block in &response.message.content {
                    if let ContentBlock::Text(text) = block
                        && tx.send(StreamEvent::TextDelta(text.clone())).await.is_err()
                    {
                        return rx;
                    }
                }
                if tx
                    .send(StreamEvent::Usage(response.usage.clone()))
                    .await
                    .is_err()
                {
                    return rx;
                }
                if tx
                    .send(StreamEvent::MessageComplete(response.message.clone()))
                    .await
                    .is_err()
                {
                    return rx;
                }

                match fire_post_llm_hooks(&self.hooks, &response).await {
                    Ok(Some(HookAction::Terminate { reason })) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "hook terminated: {reason}"
                            ))))
                            .await;
                        break;
                    }
                    Err(e) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "hook error: {e}"
                            ))))
                            .await;
                        break;
                    }
                    _ => {}
                }

                (response.message, response.usage)
            } else {
                let stream_handle = match self.provider.complete_stream(request).await {
                    Ok(h) => h,
                    Err(e) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "provider error: {e}"
                            ))))
                            .await;
                        break;
                    }
                };

                // Forward every provider event to the caller while assembling
                // the final message and usage for post-LLM hooks below.
                let mut assembled_message: Option<Message> = None;
                let mut assembled_response: Option<neuron_types::CompletionResponse> = None;
                let mut stream_usage = TokenUsage::default();
                let mut stream = stream_handle.receiver;

                while let Some(event) = stream.next().await {
                    match &event {
                        StreamEvent::MessageComplete(msg) => {
                            assembled_message = Some(msg.clone());
                        }
                        StreamEvent::Usage(u) => {
                            stream_usage = u.clone();
                            // Placeholder message if Usage arrives before
                            // MessageComplete; patched with the real message
                            // after the stream ends.
                            assembled_response = Some(neuron_types::CompletionResponse {
                                id: String::new(),
                                model: String::new(),
                                message: assembled_message.clone().unwrap_or(Message {
                                    role: Role::Assistant,
                                    content: vec![],
                                }),
                                usage: u.clone(),
                                stop_reason: StopReason::EndTurn,
                            });
                        }
                        _ => {}
                    }
                    if tx.send(event).await.is_err() {
                        return rx;
                    }
                }

                // A stream that never delivered MessageComplete is a protocol
                // violation; report and stop.
                let msg = match assembled_message {
                    Some(m) => m,
                    None => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(
                                "stream ended without MessageComplete",
                            )))
                            .await;
                        break;
                    }
                };

                // Fire post-LLM hooks with the best response we can build:
                // the usage-bearing one if a Usage event arrived, otherwise a
                // synthetic response with default usage.
                if let Some(mut resp) = assembled_response {
                    resp.message = msg.clone();
                    match fire_post_llm_hooks(&self.hooks, &resp).await {
                        Ok(Some(HookAction::Terminate { reason })) => {
                            let _ = tx
                                .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                    "hook terminated: {reason}"
                                ))))
                                .await;
                            break;
                        }
                        Err(e) => {
                            let _ = tx
                                .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                    "hook error: {e}"
                                ))))
                                .await;
                            break;
                        }
                        _ => {}
                    }
                } else {
                    let resp = neuron_types::CompletionResponse {
                        id: String::new(),
                        model: String::new(),
                        message: msg.clone(),
                        usage: TokenUsage::default(),
                        stop_reason: StopReason::EndTurn,
                    };
                    match fire_post_llm_hooks(&self.hooks, &resp).await {
                        Ok(Some(HookAction::Terminate { reason })) => {
                            let _ = tx
                                .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                    "hook terminated: {reason}"
                                ))))
                                .await;
                            break;
                        }
                        Err(e) => {
                            let _ = tx
                                .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                    "hook error: {e}"
                                ))))
                                .await;
                            break;
                        }
                        _ => {}
                    }
                }

                (msg, stream_usage)
            };

            request_count += 1;
            turns += 1;
            accumulate_usage(&mut total_usage, &turn_usage);

            if let Some(ref limits) = self.config.usage_limits
                && let Err(e) = check_token_limits(limits, &total_usage)
            {
                let _ = tx
                    .send(StreamEvent::Error(StreamError::non_retryable(format!(
                        "{e}"
                    ))))
                    .await;
                break;
            }

            // Collect every ToolUse block as (call_id, tool_name, input).
            let tool_calls: Vec<_> = message
                .content
                .iter()
                .filter_map(|block| {
                    if let ContentBlock::ToolUse { id, name, input } = block {
                        Some((id.clone(), name.clone(), input.clone()))
                    } else {
                        None
                    }
                })
                .collect();

            self.messages.push(message.clone());

            // No tool calls means the model is done: normal termination.
            if tool_calls.is_empty() {
                break;
            }

            // Re-check cancellation before spending time on tools.
            if tool_ctx.cancellation_token.is_cancelled() {
                let _ = tx
                    .send(StreamEvent::Error(StreamError::non_retryable("cancelled")))
                    .await;
                break;
            }

            if let Some(ref limits) = self.config.usage_limits
                && let Err(e) = check_tool_call_limit(limits, tool_call_count, tool_calls.len())
            {
                let _ = tx
                    .send(StreamEvent::Error(StreamError::non_retryable(format!(
                        "{e}"
                    ))))
                    .await;
                break;
            }
            tool_call_count += tool_calls.len();

            // Tools always run sequentially here (unlike StepIterator, which
            // supports parallel execution).
            let mut tool_result_blocks = Vec::new();
            for (call_id, tool_name, input) in &tool_calls {
                match fire_pre_tool_hooks(&self.hooks, tool_name, input).await {
                    Ok(Some(HookAction::Terminate { reason })) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "hook terminated: {reason}"
                            ))))
                            .await;
                        return rx;
                    }
                    // Skip: record an error-flagged result so the model sees
                    // why the call was not executed, then move on.
                    Ok(Some(HookAction::Skip { reason })) => {
                        tool_result_blocks.push(ContentBlock::ToolResult {
                            tool_use_id: call_id.clone(),
                            content: vec![ContentItem::Text(format!(
                                "Tool call skipped: {reason}"
                            ))],
                            is_error: true,
                        });
                        continue;
                    }
                    Err(e) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "hook error: {e}"
                            ))))
                            .await;
                        return rx;
                    }
                    _ => {}
                }

                // Execute via the durability layer when configured, otherwise
                // directly through the registry.
                let result = if let Some(ref durable) = self.durability {
                    let options = neuron_types::ActivityOptions {
                        start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
                        heartbeat_timeout: None,
                        retry_policy: None,
                    };
                    match durable
                        .0
                        .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
                        .await
                    {
                        Ok(r) => r,
                        Err(e) => {
                            let _ = tx
                                .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                    "durable tool error: {e}"
                                ))))
                                .await;
                            return rx;
                        }
                    }
                } else {
                    match self.tools.execute(tool_name, input.clone(), tool_ctx).await {
                        Ok(r) => r,
                        // ModelRetry is not fatal: surface the hint to the
                        // model as an error-flagged tool result so it can
                        // retry with corrected input.
                        Err(neuron_types::ToolError::ModelRetry(hint)) => {
                            ToolOutput {
                                content: vec![ContentItem::Text(hint)],
                                structured_content: None,
                                is_error: true,
                            }
                        }
                        Err(e) => {
                            let _ = tx
                                .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                    "tool error: {e}"
                                ))))
                                .await;
                            return rx;
                        }
                    }
                };

                match fire_post_tool_hooks(&self.hooks, tool_name, &result).await {
                    Ok(Some(HookAction::Terminate { reason })) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "hook terminated: {reason}"
                            ))))
                            .await;
                        return rx;
                    }
                    Err(e) => {
                        let _ = tx
                            .send(StreamEvent::Error(StreamError::non_retryable(format!(
                                "hook error: {e}"
                            ))))
                            .await;
                        return rx;
                    }
                    _ => {}
                }

                tool_result_blocks.push(ContentBlock::ToolResult {
                    tool_use_id: call_id.clone(),
                    content: result.content,
                    is_error: result.is_error,
                });
            }

            // Tool results go back to the model as a User message; loop again.
            self.messages.push(Message {
                role: Role::User,
                content: tool_result_blocks,
            });
        }

        rx
    }
}