1use neuron_tool::ToolRegistry;
8use futures::StreamExt;
9
10use neuron_types::{
11 CompletionRequest, ContentBlock, ContentItem, ContextStrategy, HookAction, LoopError, Message,
12 Provider, Role, StopReason, StreamError, StreamEvent, TokenUsage, ToolContext, ToolOutput,
13};
14
15use crate::loop_impl::{
16 accumulate_usage, extract_text, fire_compaction_hooks, fire_loop_iteration_hooks,
17 fire_post_llm_hooks, fire_post_tool_hooks, fire_pre_llm_hooks, fire_pre_tool_hooks,
18 AgentLoop, AgentResult, DEFAULT_ACTIVITY_TIMEOUT,
19};
20
/// The outcome of a single `StepIterator::next` call.
#[derive(Debug)]
pub enum TurnResult {
    /// The LLM requested tools and every requested call was processed.
    ToolsExecuted {
        /// `(tool_use_id, tool_name, input)` for each `ToolUse` block of the response,
        /// in the order they appeared.
        calls: Vec<(String, String, serde_json::Value)>,
        /// One output per entry in `calls`, in the same order. A call skipped by a
        /// pre-tool hook gets a synthetic error output describing the skip.
        results: Vec<ToolOutput>,
    },
    /// The LLM response had no tool calls (or stopped with `EndTurn`); the run is over.
    FinalResponse(AgentResult),
    /// The context strategy compacted the conversation. No LLM call was made this step.
    CompactionOccurred {
        /// Token estimate of the conversation before compaction.
        old_tokens: usize,
        /// Token estimate of the conversation after compaction.
        new_tokens: usize,
    },
    /// The configured `max_turns` limit was reached before another LLM call.
    MaxTurnsReached,
    /// A step failed, or a hook asked to terminate; the iterator is finished.
    Error(LoopError),
}
45
/// Drives an `AgentLoop` one step at a time; created by `AgentLoop::run_step`.
///
/// Await `next` repeatedly until it returns `None`. Between steps the caller can
/// inspect the conversation, inject messages, or change the tool registry.
pub struct StepIterator<'a, P: Provider, C: ContextStrategy> {
    /// The loop being driven; its messages, tools, hooks and config are used directly.
    loop_ref: &'a mut AgentLoop<P, C>,
    /// Context handed to every tool execution.
    tool_ctx: &'a ToolContext,
    /// Token usage accumulated over every LLM response accepted by this iterator.
    total_usage: TokenUsage,
    /// Number of LLM responses accepted so far (bumped after post-LLM hooks pass);
    /// compared against `max_turns`.
    turns: usize,
    /// Set once a terminal result has been yielded; `next` then returns `None`.
    finished: bool,
}
60
61impl<'a, P: Provider, C: ContextStrategy> StepIterator<'a, P, C> {
62 pub async fn next(&mut self) -> Option<TurnResult> {
67 if self.finished {
68 return None;
69 }
70
71 if let Some(max) = self.loop_ref.config.max_turns
73 && self.turns >= max
74 {
75 self.finished = true;
76 return Some(TurnResult::MaxTurnsReached);
77 }
78
79 match fire_loop_iteration_hooks(&self.loop_ref.hooks, self.turns).await {
81 Ok(Some(HookAction::Terminate { reason })) => {
82 self.finished = true;
83 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
84 }
85 Err(e) => {
86 self.finished = true;
87 return Some(TurnResult::Error(e));
88 }
89 _ => {}
90 }
91
92 let token_count = self.loop_ref.context.token_estimate(&self.loop_ref.messages);
94 if self
95 .loop_ref
96 .context
97 .should_compact(&self.loop_ref.messages, token_count)
98 {
99 let old_tokens = token_count;
100 match self
101 .loop_ref
102 .context
103 .compact(self.loop_ref.messages.clone())
104 .await
105 {
106 Ok(compacted) => {
107 self.loop_ref.messages = compacted;
108 let new_tokens =
109 self.loop_ref.context.token_estimate(&self.loop_ref.messages);
110
111 match fire_compaction_hooks(&self.loop_ref.hooks, old_tokens, new_tokens).await
113 {
114 Ok(Some(HookAction::Terminate { reason })) => {
115 self.finished = true;
116 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
117 }
118 Err(e) => {
119 self.finished = true;
120 return Some(TurnResult::Error(e));
121 }
122 _ => {}
123 }
124
125 return Some(TurnResult::CompactionOccurred {
126 old_tokens,
127 new_tokens,
128 });
129 }
130 Err(e) => {
131 self.finished = true;
132 return Some(TurnResult::Error(e.into()));
133 }
134 }
135 }
136
137 let request = CompletionRequest {
139 model: String::new(),
140 messages: self.loop_ref.messages.clone(),
141 system: Some(self.loop_ref.config.system_prompt.clone()),
142 tools: self.loop_ref.tools.definitions(),
143 max_tokens: None,
144 temperature: None,
145 top_p: None,
146 stop_sequences: vec![],
147 tool_choice: None,
148 response_format: None,
149 thinking: None,
150 reasoning_effort: None,
151 extra: None,
152 };
153
154 match fire_pre_llm_hooks(&self.loop_ref.hooks, &request).await {
156 Ok(Some(HookAction::Terminate { reason })) => {
157 self.finished = true;
158 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
159 }
160 Err(e) => {
161 self.finished = true;
162 return Some(TurnResult::Error(e));
163 }
164 _ => {}
165 }
166
167 let response = if let Some(ref durable) = self.loop_ref.durability {
169 let options = neuron_types::ActivityOptions {
170 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
171 heartbeat_timeout: None,
172 retry_policy: None,
173 };
174 match durable.0.erased_execute_llm_call(request, options).await {
175 Ok(r) => r,
176 Err(e) => {
177 self.finished = true;
178 return Some(TurnResult::Error(
179 neuron_types::ProviderError::Other(Box::new(e)).into(),
180 ));
181 }
182 }
183 } else {
184 match self.loop_ref.provider.complete(request).await {
185 Ok(r) => r,
186 Err(e) => {
187 self.finished = true;
188 return Some(TurnResult::Error(e.into()));
189 }
190 }
191 };
192
193 match fire_post_llm_hooks(&self.loop_ref.hooks, &response).await {
195 Ok(Some(HookAction::Terminate { reason })) => {
196 self.finished = true;
197 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
198 }
199 Err(e) => {
200 self.finished = true;
201 return Some(TurnResult::Error(e));
202 }
203 _ => {}
204 }
205
206 accumulate_usage(&mut self.total_usage, &response.usage);
208 self.turns += 1;
209
210 let tool_calls: Vec<_> = response
212 .message
213 .content
214 .iter()
215 .filter_map(|block| {
216 if let ContentBlock::ToolUse { id, name, input } = block {
217 Some((id.clone(), name.clone(), input.clone()))
218 } else {
219 None
220 }
221 })
222 .collect();
223
224 self.loop_ref.messages.push(response.message.clone());
226
227 if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
228 self.finished = true;
229 let response_text = extract_text(&response.message);
230 return Some(TurnResult::FinalResponse(AgentResult {
231 response: response_text,
232 messages: self.loop_ref.messages.clone(),
233 usage: self.total_usage.clone(),
234 turns: self.turns,
235 }));
236 }
237
238 let mut tool_result_blocks = Vec::new();
240 let mut tool_outputs = Vec::new();
241 for (call_id, tool_name, input) in &tool_calls {
242 match fire_pre_tool_hooks(&self.loop_ref.hooks, tool_name, input).await {
244 Ok(Some(HookAction::Terminate { reason })) => {
245 self.finished = true;
246 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
247 }
248 Ok(Some(HookAction::Skip { reason })) => {
249 let output = ToolOutput {
250 content: vec![ContentItem::Text(format!(
251 "Tool call skipped: {reason}"
252 ))],
253 structured_content: None,
254 is_error: true,
255 };
256 tool_result_blocks.push(ContentBlock::ToolResult {
257 tool_use_id: call_id.clone(),
258 content: output.content.clone(),
259 is_error: true,
260 });
261 tool_outputs.push(output);
262 continue;
263 }
264 Err(e) => {
265 self.finished = true;
266 return Some(TurnResult::Error(e));
267 }
268 _ => {}
269 }
270
271 let result = if let Some(ref durable) = self.loop_ref.durability {
273 let options = neuron_types::ActivityOptions {
274 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
275 heartbeat_timeout: None,
276 retry_policy: None,
277 };
278 match durable
279 .0
280 .erased_execute_tool(tool_name, input.clone(), self.tool_ctx, options)
281 .await
282 {
283 Ok(r) => r,
284 Err(e) => {
285 self.finished = true;
286 return Some(TurnResult::Error(
287 neuron_types::ToolError::ExecutionFailed(Box::new(e)).into(),
288 ));
289 }
290 }
291 } else {
292 match self
293 .loop_ref
294 .tools
295 .execute(tool_name, input.clone(), self.tool_ctx)
296 .await
297 {
298 Ok(r) => r,
299 Err(e) => {
300 self.finished = true;
301 return Some(TurnResult::Error(e.into()));
302 }
303 }
304 };
305
306 match fire_post_tool_hooks(&self.loop_ref.hooks, tool_name, &result).await {
308 Ok(Some(HookAction::Terminate { reason })) => {
309 self.finished = true;
310 return Some(TurnResult::Error(LoopError::HookTerminated(reason)));
311 }
312 Err(e) => {
313 self.finished = true;
314 return Some(TurnResult::Error(e));
315 }
316 _ => {}
317 }
318
319 tool_result_blocks.push(ContentBlock::ToolResult {
320 tool_use_id: call_id.clone(),
321 content: result.content.clone(),
322 is_error: result.is_error,
323 });
324 tool_outputs.push(result);
325 }
326
327 self.loop_ref.messages.push(Message {
329 role: Role::User,
330 content: tool_result_blocks,
331 });
332
333 Some(TurnResult::ToolsExecuted {
334 calls: tool_calls,
335 results: tool_outputs,
336 })
337 }
338
339 #[must_use]
341 pub fn messages(&self) -> &[Message] {
342 &self.loop_ref.messages
343 }
344
345 pub fn inject_message(&mut self, message: Message) {
347 self.loop_ref.messages.push(message);
348 }
349
350 #[must_use]
352 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
353 &mut self.loop_ref.tools
354 }
355}
356
357impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
358 #[must_use]
367 pub fn run_step<'a>(
368 &'a mut self,
369 user_message: Message,
370 tool_ctx: &'a ToolContext,
371 ) -> StepIterator<'a, P, C> {
372 self.messages.push(user_message);
373 StepIterator {
374 loop_ref: self,
375 tool_ctx,
376 total_usage: TokenUsage::default(),
377 turns: 0,
378 finished: false,
379 }
380 }
381
382 pub async fn run_stream(
401 &mut self,
402 user_message: Message,
403 tool_ctx: &ToolContext,
404 ) -> tokio::sync::mpsc::Receiver<StreamEvent> {
405 let (tx, rx) = tokio::sync::mpsc::channel(64);
406 self.messages.push(user_message);
407
408 let mut turns: usize = 0;
409
410 loop {
411 if let Some(max) = self.config.max_turns
413 && turns >= max
414 {
415 let _ = tx
416 .send(StreamEvent::Error(StreamError::non_retryable(format!(
417 "max turns reached ({max})"
418 ))))
419 .await;
420 break;
421 }
422
423 match fire_loop_iteration_hooks(&self.hooks, turns).await {
425 Ok(Some(HookAction::Terminate { reason })) => {
426 let _ = tx
427 .send(StreamEvent::Error(StreamError::non_retryable(format!(
428 "hook terminated: {reason}"
429 ))))
430 .await;
431 break;
432 }
433 Err(e) => {
434 let _ = tx
435 .send(StreamEvent::Error(StreamError::non_retryable(format!(
436 "hook error: {e}"
437 ))))
438 .await;
439 break;
440 }
441 _ => {}
442 }
443
444 let token_count = self.context.token_estimate(&self.messages);
446 if self.context.should_compact(&self.messages, token_count) {
447 let old_tokens = token_count;
448 match self.context.compact(self.messages.clone()).await {
449 Ok(compacted) => {
450 self.messages = compacted;
451 let new_tokens = self.context.token_estimate(&self.messages);
452
453 match fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await {
455 Ok(Some(HookAction::Terminate { reason })) => {
456 let _ = tx
457 .send(StreamEvent::Error(StreamError::non_retryable(format!(
458 "hook terminated: {reason}"
459 ))))
460 .await;
461 break;
462 }
463 Err(e) => {
464 let _ = tx
465 .send(StreamEvent::Error(StreamError::non_retryable(format!(
466 "hook error: {e}"
467 ))))
468 .await;
469 break;
470 }
471 _ => {}
472 }
473 }
474 Err(e) => {
475 let _ = tx
476 .send(StreamEvent::Error(StreamError::non_retryable(format!(
477 "compaction error: {e}"
478 ))))
479 .await;
480 break;
481 }
482 }
483 }
484
485 let request = CompletionRequest {
487 model: String::new(),
488 messages: self.messages.clone(),
489 system: Some(self.config.system_prompt.clone()),
490 tools: self.tools.definitions(),
491 max_tokens: None,
492 temperature: None,
493 top_p: None,
494 stop_sequences: vec![],
495 tool_choice: None,
496 response_format: None,
497 thinking: None,
498 reasoning_effort: None,
499 extra: None,
500 };
501
502 match fire_pre_llm_hooks(&self.hooks, &request).await {
504 Ok(Some(HookAction::Terminate { reason })) => {
505 let _ = tx
506 .send(StreamEvent::Error(StreamError::non_retryable(format!(
507 "hook terminated: {reason}"
508 ))))
509 .await;
510 break;
511 }
512 Err(e) => {
513 let _ = tx
514 .send(StreamEvent::Error(StreamError::non_retryable(format!(
515 "hook error: {e}"
516 ))))
517 .await;
518 break;
519 }
520 _ => {}
521 }
522
523 let message = if let Some(ref durable) = self.durability {
526 let options = neuron_types::ActivityOptions {
528 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
529 heartbeat_timeout: None,
530 retry_policy: None,
531 };
532 let response = match durable.0.erased_execute_llm_call(request, options).await {
533 Ok(r) => r,
534 Err(e) => {
535 let _ = tx
536 .send(StreamEvent::Error(StreamError::non_retryable(format!(
537 "durable error: {e}"
538 ))))
539 .await;
540 break;
541 }
542 };
543
544 for block in &response.message.content {
546 if let ContentBlock::Text(text) = block
547 && tx.send(StreamEvent::TextDelta(text.clone())).await.is_err()
548 {
549 return rx;
550 }
551 }
552 if tx.send(StreamEvent::Usage(response.usage.clone())).await.is_err() {
553 return rx;
554 }
555 if tx
556 .send(StreamEvent::MessageComplete(response.message.clone()))
557 .await
558 .is_err()
559 {
560 return rx;
561 }
562
563 match fire_post_llm_hooks(&self.hooks, &response).await {
565 Ok(Some(HookAction::Terminate { reason })) => {
566 let _ = tx
567 .send(StreamEvent::Error(StreamError::non_retryable(format!(
568 "hook terminated: {reason}"
569 ))))
570 .await;
571 break;
572 }
573 Err(e) => {
574 let _ = tx
575 .send(StreamEvent::Error(StreamError::non_retryable(format!(
576 "hook error: {e}"
577 ))))
578 .await;
579 break;
580 }
581 _ => {}
582 }
583
584 response.message
585 } else {
586 let stream_handle = match self.provider.complete_stream(request).await {
588 Ok(h) => h,
589 Err(e) => {
590 let _ = tx
591 .send(StreamEvent::Error(StreamError::non_retryable(format!(
592 "provider error: {e}"
593 ))))
594 .await;
595 break;
596 }
597 };
598
599 let mut assembled_message: Option<Message> = None;
601 let mut assembled_response: Option<neuron_types::CompletionResponse> = None;
602 let mut stream = stream_handle.receiver;
603
604 while let Some(event) = stream.next().await {
605 match &event {
606 StreamEvent::MessageComplete(msg) => {
607 assembled_message = Some(msg.clone());
608 }
609 StreamEvent::Usage(u) => {
610 assembled_response = Some(neuron_types::CompletionResponse {
612 id: String::new(),
613 model: String::new(),
614 message: assembled_message.clone().unwrap_or(Message {
615 role: Role::Assistant,
616 content: vec![],
617 }),
618 usage: u.clone(),
619 stop_reason: StopReason::EndTurn,
620 });
621 }
622 _ => {}
623 }
624 if tx.send(event).await.is_err() {
626 return rx;
628 }
629 }
630
631 let msg = match assembled_message {
633 Some(m) => m,
634 None => {
635 let _ = tx
636 .send(StreamEvent::Error(StreamError::non_retryable(
637 "stream ended without MessageComplete",
638 )))
639 .await;
640 break;
641 }
642 };
643
644 if let Some(mut resp) = assembled_response {
646 resp.message = msg.clone();
647 match fire_post_llm_hooks(&self.hooks, &resp).await {
648 Ok(Some(HookAction::Terminate { reason })) => {
649 let _ = tx
650 .send(StreamEvent::Error(StreamError::non_retryable(format!(
651 "hook terminated: {reason}"
652 ))))
653 .await;
654 break;
655 }
656 Err(e) => {
657 let _ = tx
658 .send(StreamEvent::Error(StreamError::non_retryable(format!(
659 "hook error: {e}"
660 ))))
661 .await;
662 break;
663 }
664 _ => {}
665 }
666 } else {
667 let resp = neuron_types::CompletionResponse {
669 id: String::new(),
670 model: String::new(),
671 message: msg.clone(),
672 usage: TokenUsage::default(),
673 stop_reason: StopReason::EndTurn,
674 };
675 match fire_post_llm_hooks(&self.hooks, &resp).await {
676 Ok(Some(HookAction::Terminate { reason })) => {
677 let _ = tx
678 .send(StreamEvent::Error(StreamError::non_retryable(format!(
679 "hook terminated: {reason}"
680 ))))
681 .await;
682 break;
683 }
684 Err(e) => {
685 let _ = tx
686 .send(StreamEvent::Error(StreamError::non_retryable(format!(
687 "hook error: {e}"
688 ))))
689 .await;
690 break;
691 }
692 _ => {}
693 }
694 }
695
696 msg
697 };
698
699 turns += 1;
700
701 let tool_calls: Vec<_> = message
703 .content
704 .iter()
705 .filter_map(|block| {
706 if let ContentBlock::ToolUse { id, name, input } = block {
707 Some((id.clone(), name.clone(), input.clone()))
708 } else {
709 None
710 }
711 })
712 .collect();
713
714 self.messages.push(message.clone());
715
716 if tool_calls.is_empty() {
717 break;
719 }
720
721 let mut tool_result_blocks = Vec::new();
723 for (call_id, tool_name, input) in &tool_calls {
724 match fire_pre_tool_hooks(&self.hooks, tool_name, input).await {
726 Ok(Some(HookAction::Terminate { reason })) => {
727 let _ = tx
728 .send(StreamEvent::Error(StreamError::non_retryable(format!(
729 "hook terminated: {reason}"
730 ))))
731 .await;
732 return rx;
733 }
734 Ok(Some(HookAction::Skip { reason })) => {
735 tool_result_blocks.push(ContentBlock::ToolResult {
736 tool_use_id: call_id.clone(),
737 content: vec![ContentItem::Text(format!(
738 "Tool call skipped: {reason}"
739 ))],
740 is_error: true,
741 });
742 continue;
743 }
744 Err(e) => {
745 let _ = tx
746 .send(StreamEvent::Error(StreamError::non_retryable(format!(
747 "hook error: {e}"
748 ))))
749 .await;
750 return rx;
751 }
752 _ => {}
753 }
754
755 let result = if let Some(ref durable) = self.durability {
757 let options = neuron_types::ActivityOptions {
758 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
759 heartbeat_timeout: None,
760 retry_policy: None,
761 };
762 match durable
763 .0
764 .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
765 .await
766 {
767 Ok(r) => r,
768 Err(e) => {
769 let _ = tx
770 .send(StreamEvent::Error(StreamError::non_retryable(format!(
771 "durable tool error: {e}"
772 ))))
773 .await;
774 return rx;
775 }
776 }
777 } else {
778 match self.tools.execute(tool_name, input.clone(), tool_ctx).await {
779 Ok(r) => r,
780 Err(e) => {
781 let _ = tx
782 .send(StreamEvent::Error(StreamError::non_retryable(format!(
783 "tool error: {e}"
784 ))))
785 .await;
786 return rx;
787 }
788 }
789 };
790
791 match fire_post_tool_hooks(&self.hooks, tool_name, &result).await {
793 Ok(Some(HookAction::Terminate { reason })) => {
794 let _ = tx
795 .send(StreamEvent::Error(StreamError::non_retryable(format!(
796 "hook terminated: {reason}"
797 ))))
798 .await;
799 return rx;
800 }
801 Err(e) => {
802 let _ = tx
803 .send(StreamEvent::Error(StreamError::non_retryable(format!(
804 "hook error: {e}"
805 ))))
806 .await;
807 return rx;
808 }
809 _ => {}
810 }
811
812 tool_result_blocks.push(ContentBlock::ToolResult {
813 tool_use_id: call_id.clone(),
814 content: result.content,
815 is_error: result.is_error,
816 });
817 }
818
819 self.messages.push(Message {
820 role: Role::User,
821 content: tool_result_blocks,
822 });
823 }
824
825 rx
826 }
827}