// neuron_loop/loop_impl.rs

//! Core AgentLoop struct and run methods.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Duration;
7
8use neuron_tool::ToolRegistry;
9use neuron_types::{
10    ActivityOptions, CompletionRequest, CompletionResponse, ContentBlock, ContentItem,
11    ContextStrategy, DurableContext, DurableError, HookAction, HookError, HookEvent, LoopError,
12    Message, ObservabilityHook, Provider, ProviderError, Role, StopReason, TokenUsage, ToolContext,
13    ToolError, ToolOutput,
14};
15
16use crate::config::LoopConfig;
17
18// --- Type erasure for ObservabilityHook (RPITIT is not dyn-compatible) ---
19
20/// Type alias for a pinned, boxed, Send future returning a HookAction.
21type HookFuture<'a> = Pin<Box<dyn Future<Output = Result<HookAction, HookError>> + Send + 'a>>;
22
23/// Dyn-compatible wrapper for [`ObservabilityHook`].
24trait ErasedHook: Send + Sync {
25    fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a>;
26}
27
28impl<H: ObservabilityHook> ErasedHook for H {
29    fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a> {
30        Box::pin(self.on_event(event))
31    }
32}
33
34/// A type-erased observability hook for use in [`AgentLoop`].
35///
36/// Wraps any [`ObservabilityHook`] into a dyn-compatible form.
37pub struct BoxedHook(Arc<dyn ErasedHook>);
38
39impl BoxedHook {
40    /// Wrap any [`ObservabilityHook`] into a type-erased `BoxedHook`.
41    #[must_use]
42    pub fn new<H: ObservabilityHook + 'static>(hook: H) -> Self {
43        BoxedHook(Arc::new(hook))
44    }
45
46    /// Fire this hook with an event.
47    async fn fire(&self, event: HookEvent<'_>) -> Result<HookAction, HookError> {
48        self.0.erased_on_event(event).await
49    }
50}
51
52// --- Type erasure for DurableContext (RPITIT is not dyn-compatible) ---
53
54/// Type alias for durable LLM call future.
55type DurableLlmFuture<'a> =
56    Pin<Box<dyn Future<Output = Result<CompletionResponse, DurableError>> + Send + 'a>>;
57
58/// Type alias for durable tool call future.
59type DurableToolFuture<'a> =
60    Pin<Box<dyn Future<Output = Result<ToolOutput, DurableError>> + Send + 'a>>;
61
62/// Dyn-compatible wrapper for [`DurableContext`].
63pub(crate) trait ErasedDurable: Send + Sync {
64    fn erased_execute_llm_call(
65        &self,
66        request: CompletionRequest,
67        options: ActivityOptions,
68    ) -> DurableLlmFuture<'_>;
69
70    fn erased_execute_tool<'a>(
71        &'a self,
72        tool_name: &'a str,
73        input: serde_json::Value,
74        ctx: &'a ToolContext,
75        options: ActivityOptions,
76    ) -> DurableToolFuture<'a>;
77}
78
79impl<D: DurableContext> ErasedDurable for D {
80    fn erased_execute_llm_call(
81        &self,
82        request: CompletionRequest,
83        options: ActivityOptions,
84    ) -> DurableLlmFuture<'_> {
85        Box::pin(self.execute_llm_call(request, options))
86    }
87
88    fn erased_execute_tool<'a>(
89        &'a self,
90        tool_name: &'a str,
91        input: serde_json::Value,
92        ctx: &'a ToolContext,
93        options: ActivityOptions,
94    ) -> DurableToolFuture<'a> {
95        Box::pin(self.execute_tool(tool_name, input, ctx, options))
96    }
97}
98
99/// A type-erased durable context for use in [`AgentLoop`].
100///
101/// Wraps any [`DurableContext`] into a dyn-compatible form.
102pub struct BoxedDurable(pub(crate) Arc<dyn ErasedDurable>);
103
104impl BoxedDurable {
105    /// Wrap any [`DurableContext`] into a type-erased `BoxedDurable`.
106    #[must_use]
107    pub fn new<D: DurableContext + 'static>(durable: D) -> Self {
108        BoxedDurable(Arc::new(durable))
109    }
110}
111
112// --- AgentResult ---
113
114/// The result of a completed agent loop run.
115#[derive(Debug)]
116pub struct AgentResult {
117    /// The final text response from the model.
118    pub response: String,
119    /// All messages in the conversation (including tool calls/results).
120    pub messages: Vec<Message>,
121    /// Cumulative token usage across all turns.
122    pub usage: TokenUsage,
123    /// Number of turns completed.
124    pub turns: usize,
125}
126
127// --- AgentLoop ---
128
/// Start-to-close timeout (two minutes) handed to the durable context for
/// every LLM call and tool execution.
pub(crate) const DEFAULT_ACTIVITY_TIMEOUT: Duration = Duration::from_secs(2 * 60);
131
132/// The agentic while loop: drives provider + tool + context interactions.
133///
134/// Generic over `P: Provider` (the LLM backend) and `C: ContextStrategy`
135/// (the compaction strategy). Hooks and durability are optional.
136pub struct AgentLoop<P: Provider, C: ContextStrategy> {
137    pub(crate) provider: P,
138    pub(crate) tools: ToolRegistry,
139    pub(crate) context: C,
140    pub(crate) hooks: Vec<BoxedHook>,
141    pub(crate) durability: Option<BoxedDurable>,
142    pub(crate) config: LoopConfig,
143    pub(crate) messages: Vec<Message>,
144}
145
146impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
147    /// Create a new `AgentLoop` with the given provider, tools, context strategy,
148    /// and configuration.
149    #[must_use]
150    pub fn new(provider: P, tools: ToolRegistry, context: C, config: LoopConfig) -> Self {
151        Self {
152            provider,
153            tools,
154            context,
155            hooks: Vec::new(),
156            durability: None,
157            config,
158            messages: Vec::new(),
159        }
160    }
161
162    /// Add an observability hook to the loop.
163    ///
164    /// Hooks are called in order of registration at each event point.
165    pub fn add_hook<H: ObservabilityHook + 'static>(&mut self, hook: H) -> &mut Self {
166        self.hooks.push(BoxedHook::new(hook));
167        self
168    }
169
170    /// Set the durable context for crash-recoverable execution.
171    ///
172    /// When set, LLM calls and tool executions go through the durable context
173    /// so they can be journaled, replayed, and recovered by engines like
174    /// Temporal, Restate, or Inngest.
175    pub fn set_durability<D: DurableContext + 'static>(&mut self, durable: D) -> &mut Self {
176        self.durability = Some(BoxedDurable::new(durable));
177        self
178    }
179
180    /// Returns a reference to the current configuration.
181    #[must_use]
182    pub fn config(&self) -> &LoopConfig {
183        &self.config
184    }
185
186    /// Returns a reference to the current messages.
187    #[must_use]
188    pub fn messages(&self) -> &[Message] {
189        &self.messages
190    }
191
192    /// Returns a mutable reference to the tool registry.
193    #[must_use]
194    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
195        &mut self.tools
196    }
197
198    /// Run the agentic loop to completion.
199    ///
200    /// Appends the user message, then loops: call provider, execute tools if
201    /// needed, append results, repeat until the model returns a text-only
202    /// response or the turn limit is reached.
203    ///
204    /// When durability is set, LLM calls go through
205    /// [`DurableContext::execute_llm_call`] and tool calls go through
206    /// [`DurableContext::execute_tool`].
207    ///
208    /// Fires [`HookEvent`] at each step. If a hook returns
209    /// [`HookAction::Terminate`], the loop stops with
210    /// [`LoopError::HookTerminated`].
211    ///
212    /// # Errors
213    ///
214    /// Returns `LoopError::MaxTurns` if the turn limit is exceeded,
215    /// `LoopError::Provider` on provider failures, `LoopError::Tool`
216    /// on tool execution failures, or `LoopError::HookTerminated` if
217    /// a hook requests termination.
218    #[must_use = "this returns a Result that should be handled"]
219    pub async fn run(
220        &mut self,
221        user_message: Message,
222        tool_ctx: &ToolContext,
223    ) -> Result<AgentResult, LoopError> {
224        self.messages.push(user_message);
225
226        let mut total_usage = TokenUsage::default();
227        let mut turns: usize = 0;
228
229        loop {
230            // Check cancellation
231            if tool_ctx.cancellation_token.is_cancelled() {
232                return Err(LoopError::Cancelled);
233            }
234
235            // Check max turns
236            if let Some(max) = self.config.max_turns
237                && turns >= max
238            {
239                return Err(LoopError::MaxTurns(max));
240            }
241
242            // Fire LoopIteration hooks
243            if let Some(HookAction::Terminate { reason }) =
244                fire_loop_iteration_hooks(&self.hooks, turns).await?
245            {
246                return Err(LoopError::HookTerminated(reason));
247            }
248
249            // Check context compaction
250            let token_count = self.context.token_estimate(&self.messages);
251            if self.context.should_compact(&self.messages, token_count) {
252                let old_tokens = token_count;
253                self.messages = self.context.compact(self.messages.clone()).await?;
254                let new_tokens = self.context.token_estimate(&self.messages);
255
256                // Fire ContextCompaction hooks
257                if let Some(HookAction::Terminate { reason }) =
258                    fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await?
259                {
260                    return Err(LoopError::HookTerminated(reason));
261                }
262            }
263
264            // Build completion request
265            let request = CompletionRequest {
266                model: String::new(), // Provider decides the model
267                messages: self.messages.clone(),
268                system: Some(self.config.system_prompt.clone()),
269                tools: self.tools.definitions(),
270                ..Default::default()
271            };
272
273            // Fire PreLlmCall hooks
274            if let Some(HookAction::Terminate { reason }) =
275                fire_pre_llm_hooks(&self.hooks, &request).await?
276            {
277                return Err(LoopError::HookTerminated(reason));
278            }
279
280            // Call provider (via durability wrapper if present)
281            let response = if let Some(ref durable) = self.durability {
282                let options = ActivityOptions {
283                    start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
284                    heartbeat_timeout: None,
285                    retry_policy: None,
286                };
287                durable
288                    .0
289                    .erased_execute_llm_call(request, options)
290                    .await
291                    .map_err(|e| ProviderError::Other(Box::new(e)))?
292            } else {
293                self.provider.complete(request).await?
294            };
295
296            // Fire PostLlmCall hooks
297            if let Some(HookAction::Terminate { reason }) =
298                fire_post_llm_hooks(&self.hooks, &response).await?
299            {
300                return Err(LoopError::HookTerminated(reason));
301            }
302
303            // Accumulate usage
304            accumulate_usage(&mut total_usage, &response.usage);
305            turns += 1;
306
307            // Check for tool calls in the response
308            let tool_calls: Vec<_> = response
309                .message
310                .content
311                .iter()
312                .filter_map(|block| {
313                    if let ContentBlock::ToolUse { id, name, input } = block {
314                        Some((id.clone(), name.clone(), input.clone()))
315                    } else {
316                        None
317                    }
318                })
319                .collect();
320
321            // Append assistant message to conversation
322            self.messages.push(response.message.clone());
323
324            // Server-side compaction: the provider paused to compact context.
325            // Continue the loop so the next iteration picks up the compacted state.
326            if response.stop_reason == StopReason::Compaction {
327                continue;
328            }
329
330            if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
331                // No tool calls — extract text and return
332                let response_text = extract_text(&response.message);
333                return Ok(AgentResult {
334                    response: response_text,
335                    messages: self.messages.clone(),
336                    usage: total_usage,
337                    turns,
338                });
339            }
340
341            // Check cancellation before tool execution
342            if tool_ctx.cancellation_token.is_cancelled() {
343                return Err(LoopError::Cancelled);
344            }
345
346            // Execute tool calls and collect results
347            let tool_result_blocks = if self.config.parallel_tool_execution && tool_calls.len() > 1 {
348                let futs = tool_calls.iter().map(|(call_id, tool_name, input)| {
349                    self.execute_single_tool(call_id, tool_name, input, tool_ctx)
350                });
351                let results = futures::future::join_all(futs).await;
352                results.into_iter().collect::<Result<Vec<_>, _>>()?
353            } else {
354                let mut blocks = Vec::new();
355                for (call_id, tool_name, input) in &tool_calls {
356                    blocks.push(self.execute_single_tool(call_id, tool_name, input, tool_ctx).await?);
357                }
358                blocks
359            };
360
361            // Append tool results as a user message
362            self.messages.push(Message {
363                role: Role::User,
364                content: tool_result_blocks,
365            });
366        }
367    }
368
369    /// Convenience method to run the loop with a plain text message.
370    ///
371    /// Wraps `text` into a `Message { role: User, content: [Text(text)] }`
372    /// and calls [`run`](Self::run).
373    #[must_use = "this returns a Result that should be handled"]
374    pub async fn run_text(
375        &mut self,
376        text: &str,
377        tool_ctx: &ToolContext,
378    ) -> Result<AgentResult, LoopError> {
379        let message = Message {
380            role: Role::User,
381            content: vec![ContentBlock::Text(text.to_string())],
382        };
383        self.run(message, tool_ctx).await
384    }
385
386    /// Execute a single tool call, including pre/post hooks and durability routing.
387    ///
388    /// Returns the tool result as a [`ContentBlock::ToolResult`].
389    pub(crate) async fn execute_single_tool(
390        &self,
391        call_id: &str,
392        tool_name: &str,
393        input: &serde_json::Value,
394        tool_ctx: &ToolContext,
395    ) -> Result<ContentBlock, LoopError> {
396        // Fire PreToolExecution hooks
397        if let Some(action) = fire_pre_tool_hooks(&self.hooks, tool_name, input).await? {
398            match action {
399                HookAction::Terminate { reason } => {
400                    return Err(LoopError::HookTerminated(reason));
401                }
402                HookAction::Skip { reason } => {
403                    return Ok(ContentBlock::ToolResult {
404                        tool_use_id: call_id.to_string(),
405                        content: vec![ContentItem::Text(format!("Tool call skipped: {reason}"))],
406                        is_error: true,
407                    });
408                }
409                HookAction::Continue => {}
410            }
411        }
412
413        // Execute tool (via durability wrapper if present)
414        let result = if let Some(ref durable) = self.durability {
415            let options = ActivityOptions {
416                start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
417                heartbeat_timeout: None,
418                retry_policy: None,
419            };
420            durable
421                .0
422                .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
423                .await
424                .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?
425        } else {
426            self.tools.execute(tool_name, input.clone(), tool_ctx).await?
427        };
428
429        // Fire PostToolExecution hooks
430        if let Some(HookAction::Terminate { reason }) =
431            fire_post_tool_hooks(&self.hooks, tool_name, &result).await?
432        {
433            return Err(LoopError::HookTerminated(reason));
434        }
435
436        Ok(ContentBlock::ToolResult {
437            tool_use_id: call_id.to_string(),
438            content: result.content,
439            is_error: result.is_error,
440        })
441    }
442
443    /// Create a builder with the required provider and context strategy.
444    ///
445    /// All other options have sensible defaults:
446    /// - Empty tool registry
447    /// - Default loop config (no turn limit, empty system prompt)
448    /// - No hooks or durability
449    #[must_use]
450    pub fn builder(provider: P, context: C) -> AgentLoopBuilder<P, C> {
451        AgentLoopBuilder {
452            provider,
453            context,
454            tools: ToolRegistry::new(),
455            config: LoopConfig::default(),
456            hooks: Vec::new(),
457            durability: None,
458        }
459    }
460}
461
462/// Builder for constructing an [`AgentLoop`] with optional configuration.
463///
464/// Created via [`AgentLoop::builder`]. Only `provider` and `context` are required;
465/// everything else has sensible defaults.
466///
467/// # Example
468///
469/// ```ignore
470/// let agent = AgentLoop::builder(provider, context)
471///     .tools(tools)
472///     .system_prompt("You are a helpful assistant.")
473///     .max_turns(10)
474///     .build();
475/// ```
476pub struct AgentLoopBuilder<P: Provider, C: ContextStrategy> {
477    provider: P,
478    context: C,
479    tools: ToolRegistry,
480    config: LoopConfig,
481    hooks: Vec<BoxedHook>,
482    durability: Option<BoxedDurable>,
483}
484
485impl<P: Provider, C: ContextStrategy> AgentLoopBuilder<P, C> {
486    /// Set the tool registry.
487    #[must_use]
488    pub fn tools(mut self, tools: ToolRegistry) -> Self {
489        self.tools = tools;
490        self
491    }
492
493    /// Set the full loop configuration.
494    #[must_use]
495    pub fn config(mut self, config: LoopConfig) -> Self {
496        self.config = config;
497        self
498    }
499
500    /// Set the system prompt (convenience for setting `config.system_prompt`).
501    #[must_use]
502    pub fn system_prompt(mut self, prompt: impl Into<neuron_types::SystemPrompt>) -> Self {
503        self.config.system_prompt = prompt.into();
504        self
505    }
506
507    /// Set the maximum number of turns (convenience for setting `config.max_turns`).
508    #[must_use]
509    pub fn max_turns(mut self, max: usize) -> Self {
510        self.config.max_turns = Some(max);
511        self
512    }
513
514    /// Enable parallel tool execution (convenience for setting `config.parallel_tool_execution`).
515    #[must_use]
516    pub fn parallel_tool_execution(mut self, parallel: bool) -> Self {
517        self.config.parallel_tool_execution = parallel;
518        self
519    }
520
521    /// Add an observability hook.
522    #[must_use]
523    pub fn hook<H: ObservabilityHook + 'static>(mut self, hook: H) -> Self {
524        self.hooks.push(BoxedHook::new(hook));
525        self
526    }
527
528    /// Set the durable context for crash-recoverable execution.
529    #[must_use]
530    pub fn durability<D: DurableContext + 'static>(mut self, durable: D) -> Self {
531        self.durability = Some(BoxedDurable::new(durable));
532        self
533    }
534
535    /// Build the [`AgentLoop`].
536    #[must_use]
537    pub fn build(self) -> AgentLoop<P, C> {
538        AgentLoop {
539            provider: self.provider,
540            tools: self.tools,
541            context: self.context,
542            hooks: self.hooks,
543            durability: self.durability,
544            config: self.config,
545            messages: Vec::new(),
546        }
547    }
548}
549
550// --- Hook firing helpers ---
551
552/// Fire all hooks for a PreLlmCall event, returning the first non-Continue action.
553pub(crate) async fn fire_pre_llm_hooks(
554    hooks: &[BoxedHook],
555    request: &CompletionRequest,
556) -> Result<Option<HookAction>, LoopError> {
557    for hook in hooks {
558        let action = hook
559            .fire(HookEvent::PreLlmCall { request })
560            .await
561            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
562        if !matches!(action, HookAction::Continue) {
563            return Ok(Some(action));
564        }
565    }
566    Ok(None)
567}
568
569/// Fire all hooks for a PostLlmCall event, returning the first non-Continue action.
570pub(crate) async fn fire_post_llm_hooks(
571    hooks: &[BoxedHook],
572    response: &CompletionResponse,
573) -> Result<Option<HookAction>, LoopError> {
574    for hook in hooks {
575        let action = hook
576            .fire(HookEvent::PostLlmCall { response })
577            .await
578            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
579        if !matches!(action, HookAction::Continue) {
580            return Ok(Some(action));
581        }
582    }
583    Ok(None)
584}
585
586/// Fire all hooks for a PreToolExecution event, returning the first non-Continue action.
587pub(crate) async fn fire_pre_tool_hooks(
588    hooks: &[BoxedHook],
589    tool_name: &str,
590    input: &serde_json::Value,
591) -> Result<Option<HookAction>, LoopError> {
592    for hook in hooks {
593        let action = hook
594            .fire(HookEvent::PreToolExecution { tool_name, input })
595            .await
596            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
597        if !matches!(action, HookAction::Continue) {
598            return Ok(Some(action));
599        }
600    }
601    Ok(None)
602}
603
604/// Fire all hooks for a PostToolExecution event, returning the first non-Continue action.
605pub(crate) async fn fire_post_tool_hooks(
606    hooks: &[BoxedHook],
607    tool_name: &str,
608    output: &ToolOutput,
609) -> Result<Option<HookAction>, LoopError> {
610    for hook in hooks {
611        let action = hook
612            .fire(HookEvent::PostToolExecution { tool_name, output })
613            .await
614            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
615        if !matches!(action, HookAction::Continue) {
616            return Ok(Some(action));
617        }
618    }
619    Ok(None)
620}
621
622/// Fire all hooks for a LoopIteration event, returning the first non-Continue action.
623pub(crate) async fn fire_loop_iteration_hooks(
624    hooks: &[BoxedHook],
625    turn: usize,
626) -> Result<Option<HookAction>, LoopError> {
627    for hook in hooks {
628        let action = hook
629            .fire(HookEvent::LoopIteration { turn })
630            .await
631            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
632        if !matches!(action, HookAction::Continue) {
633            return Ok(Some(action));
634        }
635    }
636    Ok(None)
637}
638
639/// Fire all hooks for a ContextCompaction event, returning the first non-Continue action.
640pub(crate) async fn fire_compaction_hooks(
641    hooks: &[BoxedHook],
642    old_tokens: usize,
643    new_tokens: usize,
644) -> Result<Option<HookAction>, LoopError> {
645    for hook in hooks {
646        let action = hook
647            .fire(HookEvent::ContextCompaction {
648                old_tokens,
649                new_tokens,
650            })
651            .await
652            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
653        if !matches!(action, HookAction::Continue) {
654            return Ok(Some(action));
655        }
656    }
657    Ok(None)
658}
659
660// --- Utility functions ---
661
662/// Extract text content from a message.
663pub(crate) fn extract_text(message: &Message) -> String {
664    message
665        .content
666        .iter()
667        .filter_map(|block| {
668            if let ContentBlock::Text(text) = block {
669                Some(text.as_str())
670            } else {
671                None
672            }
673        })
674        .collect::<Vec<_>>()
675        .join("")
676}
677
678/// Accumulate token usage from a response into the total.
679pub(crate) fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
680    total.input_tokens += delta.input_tokens;
681    total.output_tokens += delta.output_tokens;
682    if let Some(cache_read) = delta.cache_read_tokens {
683        *total.cache_read_tokens.get_or_insert(0) += cache_read;
684    }
685    if let Some(cache_creation) = delta.cache_creation_tokens {
686        *total.cache_creation_tokens.get_or_insert(0) += cache_creation;
687    }
688    if let Some(reasoning) = delta.reasoning_tokens {
689        *total.reasoning_tokens.get_or_insert(0) += reasoning;
690    }
691}