1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Duration;
7
8use neuron_tool::ToolRegistry;
9use neuron_types::{
10 ActivityOptions, CompletionRequest, CompletionResponse, ContentBlock, ContentItem,
11 ContextStrategy, DurableContext, DurableError, HookAction, HookError, HookEvent, LoopError,
12 Message, ObservabilityHook, Provider, ProviderError, Role, StopReason, TokenUsage, ToolContext,
13 ToolError, ToolOutput, UsageLimits,
14};
15
16use crate::config::LoopConfig;
17
18type HookFuture<'a> = Pin<Box<dyn Future<Output = Result<HookAction, HookError>> + Send + 'a>>;
22
23trait ErasedHook: Send + Sync {
25 fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a>;
26}
27
28impl<H: ObservabilityHook> ErasedHook for H {
29 fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a> {
30 Box::pin(self.on_event(event))
31 }
32}
33
34pub struct BoxedHook(Arc<dyn ErasedHook>);
38
39impl BoxedHook {
40 #[must_use]
42 pub fn new<H: ObservabilityHook + 'static>(hook: H) -> Self {
43 BoxedHook(Arc::new(hook))
44 }
45
46 async fn fire(&self, event: HookEvent<'_>) -> Result<HookAction, HookError> {
48 self.0.erased_on_event(event).await
49 }
50}
51
52type DurableLlmFuture<'a> =
56 Pin<Box<dyn Future<Output = Result<CompletionResponse, DurableError>> + Send + 'a>>;
57
58type DurableToolFuture<'a> =
60 Pin<Box<dyn Future<Output = Result<ToolOutput, DurableError>> + Send + 'a>>;
61
62pub(crate) trait ErasedDurable: Send + Sync {
64 fn erased_execute_llm_call(
65 &self,
66 request: CompletionRequest,
67 options: ActivityOptions,
68 ) -> DurableLlmFuture<'_>;
69
70 fn erased_execute_tool<'a>(
71 &'a self,
72 tool_name: &'a str,
73 input: serde_json::Value,
74 ctx: &'a ToolContext,
75 options: ActivityOptions,
76 ) -> DurableToolFuture<'a>;
77}
78
79impl<D: DurableContext> ErasedDurable for D {
80 fn erased_execute_llm_call(
81 &self,
82 request: CompletionRequest,
83 options: ActivityOptions,
84 ) -> DurableLlmFuture<'_> {
85 Box::pin(self.execute_llm_call(request, options))
86 }
87
88 fn erased_execute_tool<'a>(
89 &'a self,
90 tool_name: &'a str,
91 input: serde_json::Value,
92 ctx: &'a ToolContext,
93 options: ActivityOptions,
94 ) -> DurableToolFuture<'a> {
95 Box::pin(self.execute_tool(tool_name, input, ctx, options))
96 }
97}
98
99pub struct BoxedDurable(pub(crate) Arc<dyn ErasedDurable>);
103
104impl BoxedDurable {
105 #[must_use]
107 pub fn new<D: DurableContext + 'static>(durable: D) -> Self {
108 BoxedDurable(Arc::new(durable))
109 }
110}
111
/// Outcome of a successful `AgentLoop::run`.
#[derive(Debug)]
pub struct AgentResult {
    /// All text blocks of the final assistant message, concatenated with no
    /// separator.
    pub response: String,
    /// Snapshot of the loop's full conversation history at completion. This
    /// includes messages from earlier runs on the same loop, because the
    /// history is never cleared between runs.
    pub messages: Vec<Message>,
    /// Token usage summed over every model call made during this run.
    pub usage: TokenUsage,
    /// Number of model calls completed during this run.
    pub turns: usize,
}
126
/// Start-to-close timeout applied to every durable LLM and tool activity
/// issued by the loop (two minutes).
pub(crate) const DEFAULT_ACTIVITY_TIMEOUT: Duration = Duration::from_secs(2 * 60);
131
/// Drives a conversation between a model provider and a set of tools.
///
/// Every call to `run` appends to `messages`, so the history persists
/// across runs on the same instance.
pub struct AgentLoop<P: Provider, C: ContextStrategy> {
    /// Model backend used for completions when no durable context is set.
    pub(crate) provider: P,
    /// Tools whose definitions are sent with each request and which are
    /// executed when the model asks for them (unless a durable context is set).
    pub(crate) tools: ToolRegistry,
    /// Strategy that estimates token counts and compacts the history.
    pub(crate) context: C,
    /// Observability hooks, consulted in insertion order.
    pub(crate) hooks: Vec<BoxedHook>,
    /// When `Some`, LLM and tool calls go through this durable context
    /// instead of `provider` / `tools`.
    pub(crate) durability: Option<BoxedDurable>,
    /// System prompt, turn/usage limits and tool-parallelism settings.
    pub(crate) config: LoopConfig,
    /// Conversation history sent to the model on every request.
    pub(crate) messages: Vec<Message>,
}
145
146impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
147 #[must_use]
150 pub fn new(provider: P, tools: ToolRegistry, context: C, config: LoopConfig) -> Self {
151 Self {
152 provider,
153 tools,
154 context,
155 hooks: Vec::new(),
156 durability: None,
157 config,
158 messages: Vec::new(),
159 }
160 }
161
162 pub fn add_hook<H: ObservabilityHook + 'static>(&mut self, hook: H) -> &mut Self {
166 self.hooks.push(BoxedHook::new(hook));
167 self
168 }
169
170 pub fn set_durability<D: DurableContext + 'static>(&mut self, durable: D) -> &mut Self {
176 self.durability = Some(BoxedDurable::new(durable));
177 self
178 }
179
180 #[must_use]
182 pub fn config(&self) -> &LoopConfig {
183 &self.config
184 }
185
186 #[must_use]
188 pub fn messages(&self) -> &[Message] {
189 &self.messages
190 }
191
192 #[must_use]
194 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
195 &mut self.tools
196 }
197
198 #[must_use = "this returns a Result that should be handled"]
219 pub async fn run(
220 &mut self,
221 user_message: Message,
222 tool_ctx: &ToolContext,
223 ) -> Result<AgentResult, LoopError> {
224 self.messages.push(user_message);
225
226 let mut total_usage = TokenUsage::default();
227 let mut turns: usize = 0;
228 let mut request_count: usize = 0;
229 let mut tool_call_count: usize = 0;
230
231 loop {
232 if tool_ctx.cancellation_token.is_cancelled() {
234 return Err(LoopError::Cancelled);
235 }
236
237 if let Some(max) = self.config.max_turns
239 && turns >= max
240 {
241 return Err(LoopError::MaxTurns(max));
242 }
243
244 if let Some(ref limits) = self.config.usage_limits {
246 check_request_limit(limits, request_count)?;
247 }
248
249 if let Some(HookAction::Terminate { reason }) =
251 fire_loop_iteration_hooks(&self.hooks, turns).await?
252 {
253 return Err(LoopError::HookTerminated(reason));
254 }
255
256 let token_count = self.context.token_estimate(&self.messages);
258 if self.context.should_compact(&self.messages, token_count) {
259 let old_tokens = token_count;
260 self.messages = self.context.compact(self.messages.clone()).await?;
261 let new_tokens = self.context.token_estimate(&self.messages);
262
263 if let Some(HookAction::Terminate { reason }) =
265 fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await?
266 {
267 return Err(LoopError::HookTerminated(reason));
268 }
269 }
270
271 let request = CompletionRequest {
273 model: String::new(), messages: self.messages.clone(),
275 system: Some(self.config.system_prompt.clone()),
276 tools: self.tools.definitions(),
277 ..Default::default()
278 };
279
280 if let Some(HookAction::Terminate { reason }) =
282 fire_pre_llm_hooks(&self.hooks, &request).await?
283 {
284 return Err(LoopError::HookTerminated(reason));
285 }
286
287 let response = if let Some(ref durable) = self.durability {
289 let options = ActivityOptions {
290 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
291 heartbeat_timeout: None,
292 retry_policy: None,
293 };
294 durable
295 .0
296 .erased_execute_llm_call(request, options)
297 .await
298 .map_err(|e| ProviderError::Other(Box::new(e)))?
299 } else {
300 self.provider.complete(request).await?
301 };
302
303 if let Some(HookAction::Terminate { reason }) =
305 fire_post_llm_hooks(&self.hooks, &response).await?
306 {
307 return Err(LoopError::HookTerminated(reason));
308 }
309
310 accumulate_usage(&mut total_usage, &response.usage);
312 request_count += 1;
313 turns += 1;
314
315 if let Some(ref limits) = self.config.usage_limits {
317 check_token_limits(limits, &total_usage)?;
318 }
319
320 let tool_calls: Vec<_> = response
322 .message
323 .content
324 .iter()
325 .filter_map(|block| {
326 if let ContentBlock::ToolUse { id, name, input } = block {
327 Some((id.clone(), name.clone(), input.clone()))
328 } else {
329 None
330 }
331 })
332 .collect();
333
334 self.messages.push(response.message.clone());
336
337 if response.stop_reason == StopReason::Compaction {
340 continue;
341 }
342
343 if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
344 let response_text = extract_text(&response.message);
346 return Ok(AgentResult {
347 response: response_text,
348 messages: self.messages.clone(),
349 usage: total_usage,
350 turns,
351 });
352 }
353
354 if tool_ctx.cancellation_token.is_cancelled() {
356 return Err(LoopError::Cancelled);
357 }
358
359 if let Some(ref limits) = self.config.usage_limits {
361 check_tool_call_limit(limits, tool_call_count, tool_calls.len())?;
362 }
363 tool_call_count += tool_calls.len();
364
365 let tool_result_blocks = if self.config.parallel_tool_execution && tool_calls.len() > 1
367 {
368 let futs = tool_calls.iter().map(|(call_id, tool_name, input)| {
369 self.execute_single_tool(call_id, tool_name, input, tool_ctx)
370 });
371 let results = futures::future::join_all(futs).await;
372 results.into_iter().collect::<Result<Vec<_>, _>>()?
373 } else {
374 let mut blocks = Vec::new();
375 for (call_id, tool_name, input) in &tool_calls {
376 blocks.push(
377 self.execute_single_tool(call_id, tool_name, input, tool_ctx)
378 .await?,
379 );
380 }
381 blocks
382 };
383
384 self.messages.push(Message {
386 role: Role::User,
387 content: tool_result_blocks,
388 });
389 }
390 }
391
392 #[must_use = "this returns a Result that should be handled"]
397 pub async fn run_text(
398 &mut self,
399 text: &str,
400 tool_ctx: &ToolContext,
401 ) -> Result<AgentResult, LoopError> {
402 let message = Message {
403 role: Role::User,
404 content: vec![ContentBlock::Text(text.to_string())],
405 };
406 self.run(message, tool_ctx).await
407 }
408
409 pub(crate) async fn execute_single_tool(
413 &self,
414 call_id: &str,
415 tool_name: &str,
416 input: &serde_json::Value,
417 tool_ctx: &ToolContext,
418 ) -> Result<ContentBlock, LoopError> {
419 if let Some(action) = fire_pre_tool_hooks(&self.hooks, tool_name, input).await? {
421 match action {
422 HookAction::Terminate { reason } => {
423 return Err(LoopError::HookTerminated(reason));
424 }
425 HookAction::Skip { reason } => {
426 return Ok(ContentBlock::ToolResult {
427 tool_use_id: call_id.to_string(),
428 content: vec![ContentItem::Text(format!("Tool call skipped: {reason}"))],
429 is_error: true,
430 });
431 }
432 HookAction::Continue => {}
433 }
434 }
435
436 let result = if let Some(ref durable) = self.durability {
438 let options = ActivityOptions {
439 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
440 heartbeat_timeout: None,
441 retry_policy: None,
442 };
443 durable
444 .0
445 .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
446 .await
447 .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?
448 } else {
449 self.tools
450 .execute(tool_name, input.clone(), tool_ctx)
451 .await?
452 };
453
454 if let Some(HookAction::Terminate { reason }) =
456 fire_post_tool_hooks(&self.hooks, tool_name, &result).await?
457 {
458 return Err(LoopError::HookTerminated(reason));
459 }
460
461 Ok(ContentBlock::ToolResult {
462 tool_use_id: call_id.to_string(),
463 content: result.content,
464 is_error: result.is_error,
465 })
466 }
467
468 #[must_use]
475 pub fn builder(provider: P, context: C) -> AgentLoopBuilder<P, C> {
476 AgentLoopBuilder {
477 provider,
478 context,
479 tools: ToolRegistry::new(),
480 config: LoopConfig::default(),
481 hooks: Vec::new(),
482 durability: None,
483 }
484 }
485}
486
/// Fluent builder for [`AgentLoop`], created by `AgentLoop::builder`.
pub struct AgentLoopBuilder<P: Provider, C: ContextStrategy> {
    /// Model provider handed to the built loop.
    provider: P,
    /// Context strategy handed to the built loop.
    context: C,
    /// Tool registry; replaced by the `tools` setter.
    tools: ToolRegistry,
    /// Loop configuration. The `config` setter replaces it wholesale; the
    /// other setters edit individual fields.
    config: LoopConfig,
    /// Hooks added through `hook`, in insertion order.
    hooks: Vec<BoxedHook>,
    /// Durable context set through `durability`, if any.
    durability: Option<BoxedDurable>,
}
509
510impl<P: Provider, C: ContextStrategy> AgentLoopBuilder<P, C> {
511 #[must_use]
513 pub fn tools(mut self, tools: ToolRegistry) -> Self {
514 self.tools = tools;
515 self
516 }
517
518 #[must_use]
520 pub fn config(mut self, config: LoopConfig) -> Self {
521 self.config = config;
522 self
523 }
524
525 #[must_use]
527 pub fn system_prompt(mut self, prompt: impl Into<neuron_types::SystemPrompt>) -> Self {
528 self.config.system_prompt = prompt.into();
529 self
530 }
531
532 #[must_use]
534 pub fn max_turns(mut self, max: usize) -> Self {
535 self.config.max_turns = Some(max);
536 self
537 }
538
539 #[must_use]
541 pub fn parallel_tool_execution(mut self, parallel: bool) -> Self {
542 self.config.parallel_tool_execution = parallel;
543 self
544 }
545
546 #[must_use]
548 pub fn usage_limits(mut self, limits: UsageLimits) -> Self {
549 self.config.usage_limits = Some(limits);
550 self
551 }
552
553 #[must_use]
555 pub fn hook<H: ObservabilityHook + 'static>(mut self, hook: H) -> Self {
556 self.hooks.push(BoxedHook::new(hook));
557 self
558 }
559
560 #[must_use]
562 pub fn durability<D: DurableContext + 'static>(mut self, durable: D) -> Self {
563 self.durability = Some(BoxedDurable::new(durable));
564 self
565 }
566
567 #[must_use]
569 pub fn build(self) -> AgentLoop<P, C> {
570 AgentLoop {
571 provider: self.provider,
572 tools: self.tools,
573 context: self.context,
574 hooks: self.hooks,
575 durability: self.durability,
576 config: self.config,
577 messages: Vec::new(),
578 }
579 }
580}
581
582pub(crate) async fn fire_pre_llm_hooks(
586 hooks: &[BoxedHook],
587 request: &CompletionRequest,
588) -> Result<Option<HookAction>, LoopError> {
589 for hook in hooks {
590 let action = hook
591 .fire(HookEvent::PreLlmCall { request })
592 .await
593 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
594 if !matches!(action, HookAction::Continue) {
595 return Ok(Some(action));
596 }
597 }
598 Ok(None)
599}
600
601pub(crate) async fn fire_post_llm_hooks(
603 hooks: &[BoxedHook],
604 response: &CompletionResponse,
605) -> Result<Option<HookAction>, LoopError> {
606 for hook in hooks {
607 let action = hook
608 .fire(HookEvent::PostLlmCall { response })
609 .await
610 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
611 if !matches!(action, HookAction::Continue) {
612 return Ok(Some(action));
613 }
614 }
615 Ok(None)
616}
617
618pub(crate) async fn fire_pre_tool_hooks(
620 hooks: &[BoxedHook],
621 tool_name: &str,
622 input: &serde_json::Value,
623) -> Result<Option<HookAction>, LoopError> {
624 for hook in hooks {
625 let action = hook
626 .fire(HookEvent::PreToolExecution { tool_name, input })
627 .await
628 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
629 if !matches!(action, HookAction::Continue) {
630 return Ok(Some(action));
631 }
632 }
633 Ok(None)
634}
635
636pub(crate) async fn fire_post_tool_hooks(
638 hooks: &[BoxedHook],
639 tool_name: &str,
640 output: &ToolOutput,
641) -> Result<Option<HookAction>, LoopError> {
642 for hook in hooks {
643 let action = hook
644 .fire(HookEvent::PostToolExecution { tool_name, output })
645 .await
646 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
647 if !matches!(action, HookAction::Continue) {
648 return Ok(Some(action));
649 }
650 }
651 Ok(None)
652}
653
654pub(crate) async fn fire_loop_iteration_hooks(
656 hooks: &[BoxedHook],
657 turn: usize,
658) -> Result<Option<HookAction>, LoopError> {
659 for hook in hooks {
660 let action = hook
661 .fire(HookEvent::LoopIteration { turn })
662 .await
663 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
664 if !matches!(action, HookAction::Continue) {
665 return Ok(Some(action));
666 }
667 }
668 Ok(None)
669}
670
671pub(crate) async fn fire_compaction_hooks(
673 hooks: &[BoxedHook],
674 old_tokens: usize,
675 new_tokens: usize,
676) -> Result<Option<HookAction>, LoopError> {
677 for hook in hooks {
678 let action = hook
679 .fire(HookEvent::ContextCompaction {
680 old_tokens,
681 new_tokens,
682 })
683 .await
684 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
685 if !matches!(action, HookAction::Continue) {
686 return Ok(Some(action));
687 }
688 }
689 Ok(None)
690}
691
692pub(crate) fn check_request_limit(
696 limits: &UsageLimits,
697 request_count: usize,
698) -> Result<(), LoopError> {
699 if let Some(max) = limits.request_limit
700 && request_count >= max
701 {
702 return Err(LoopError::UsageLimitExceeded(format!(
703 "request limit of {max} reached"
704 )));
705 }
706 Ok(())
707}
708
709pub(crate) fn check_token_limits(
711 limits: &UsageLimits,
712 usage: &TokenUsage,
713) -> Result<(), LoopError> {
714 if let Some(max) = limits.input_tokens_limit
715 && usage.input_tokens > max
716 {
717 return Err(LoopError::UsageLimitExceeded(format!(
718 "input token limit of {max} exceeded (used {})",
719 usage.input_tokens
720 )));
721 }
722 if let Some(max) = limits.output_tokens_limit
723 && usage.output_tokens > max
724 {
725 return Err(LoopError::UsageLimitExceeded(format!(
726 "output token limit of {max} exceeded (used {})",
727 usage.output_tokens
728 )));
729 }
730 if let Some(max) = limits.total_tokens_limit {
731 let total = usage.input_tokens + usage.output_tokens;
732 if total > max {
733 return Err(LoopError::UsageLimitExceeded(format!(
734 "total token limit of {max} exceeded (used {total})"
735 )));
736 }
737 }
738 Ok(())
739}
740
741pub(crate) fn check_tool_call_limit(
743 limits: &UsageLimits,
744 current_count: usize,
745 new_calls: usize,
746) -> Result<(), LoopError> {
747 if let Some(max) = limits.tool_calls_limit
748 && current_count + new_calls > max
749 {
750 return Err(LoopError::UsageLimitExceeded(format!(
751 "tool call limit of {max} would be exceeded ({} + {new_calls} calls)",
752 current_count
753 )));
754 }
755 Ok(())
756}
757
758pub(crate) fn extract_text(message: &Message) -> String {
762 message
763 .content
764 .iter()
765 .filter_map(|block| {
766 if let ContentBlock::Text(text) = block {
767 Some(text.as_str())
768 } else {
769 None
770 }
771 })
772 .collect::<Vec<_>>()
773 .join("")
774}
775
776pub(crate) fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
778 total.input_tokens += delta.input_tokens;
779 total.output_tokens += delta.output_tokens;
780 if let Some(cache_read) = delta.cache_read_tokens {
781 *total.cache_read_tokens.get_or_insert(0) += cache_read;
782 }
783 if let Some(cache_creation) = delta.cache_creation_tokens {
784 *total.cache_creation_tokens.get_or_insert(0) += cache_creation;
785 }
786 if let Some(reasoning) = delta.reasoning_tokens {
787 *total.reasoning_tokens.get_or_insert(0) += reasoning;
788 }
789}