Skip to main content

neuron_loop/
loop_impl.rs

1//! Core AgentLoop struct and run methods.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Duration;
7
8use neuron_tool::ToolRegistry;
9use neuron_types::{
10    ActivityOptions, CompletionRequest, CompletionResponse, ContentBlock, ContentItem,
11    ContextStrategy, DurableContext, DurableError, HookAction, HookError, HookEvent, LoopError,
12    Message, ObservabilityHook, Provider, ProviderError, Role, StopReason, TokenUsage, ToolContext,
13    ToolError, ToolOutput,
14};
15
16use crate::config::LoopConfig;
17
18// --- Type erasure for ObservabilityHook (RPITIT is not dyn-compatible) ---
19
20/// Type alias for a pinned, boxed, Send future returning a HookAction.
21type HookFuture<'a> = Pin<Box<dyn Future<Output = Result<HookAction, HookError>> + Send + 'a>>;
22
23/// Dyn-compatible wrapper for [`ObservabilityHook`].
24trait ErasedHook: Send + Sync {
25    fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a>;
26}
27
28impl<H: ObservabilityHook> ErasedHook for H {
29    fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a> {
30        Box::pin(self.on_event(event))
31    }
32}
33
34/// A type-erased observability hook for use in [`AgentLoop`].
35///
36/// Wraps any [`ObservabilityHook`] into a dyn-compatible form.
37pub struct BoxedHook(Arc<dyn ErasedHook>);
38
39impl BoxedHook {
40    /// Wrap any [`ObservabilityHook`] into a type-erased `BoxedHook`.
41    #[must_use]
42    pub fn new<H: ObservabilityHook + 'static>(hook: H) -> Self {
43        BoxedHook(Arc::new(hook))
44    }
45
46    /// Fire this hook with an event.
47    async fn fire(&self, event: HookEvent<'_>) -> Result<HookAction, HookError> {
48        self.0.erased_on_event(event).await
49    }
50}
51
52// --- Type erasure for DurableContext (RPITIT is not dyn-compatible) ---
53
54/// Type alias for durable LLM call future.
55type DurableLlmFuture<'a> =
56    Pin<Box<dyn Future<Output = Result<CompletionResponse, DurableError>> + Send + 'a>>;
57
58/// Type alias for durable tool call future.
59type DurableToolFuture<'a> =
60    Pin<Box<dyn Future<Output = Result<ToolOutput, DurableError>> + Send + 'a>>;
61
62/// Dyn-compatible wrapper for [`DurableContext`].
63pub(crate) trait ErasedDurable: Send + Sync {
64    fn erased_execute_llm_call(
65        &self,
66        request: CompletionRequest,
67        options: ActivityOptions,
68    ) -> DurableLlmFuture<'_>;
69
70    fn erased_execute_tool<'a>(
71        &'a self,
72        tool_name: &'a str,
73        input: serde_json::Value,
74        ctx: &'a ToolContext,
75        options: ActivityOptions,
76    ) -> DurableToolFuture<'a>;
77}
78
79impl<D: DurableContext> ErasedDurable for D {
80    fn erased_execute_llm_call(
81        &self,
82        request: CompletionRequest,
83        options: ActivityOptions,
84    ) -> DurableLlmFuture<'_> {
85        Box::pin(self.execute_llm_call(request, options))
86    }
87
88    fn erased_execute_tool<'a>(
89        &'a self,
90        tool_name: &'a str,
91        input: serde_json::Value,
92        ctx: &'a ToolContext,
93        options: ActivityOptions,
94    ) -> DurableToolFuture<'a> {
95        Box::pin(self.execute_tool(tool_name, input, ctx, options))
96    }
97}
98
99/// A type-erased durable context for use in [`AgentLoop`].
100///
101/// Wraps any [`DurableContext`] into a dyn-compatible form.
102pub struct BoxedDurable(pub(crate) Arc<dyn ErasedDurable>);
103
104impl BoxedDurable {
105    /// Wrap any [`DurableContext`] into a type-erased `BoxedDurable`.
106    #[must_use]
107    pub fn new<D: DurableContext + 'static>(durable: D) -> Self {
108        BoxedDurable(Arc::new(durable))
109    }
110}
111
112// --- AgentResult ---
113
114/// The result of a completed agent loop run.
115#[derive(Debug)]
116pub struct AgentResult {
117    /// The final text response from the model.
118    pub response: String,
119    /// All messages in the conversation (including tool calls/results).
120    pub messages: Vec<Message>,
121    /// Cumulative token usage across all turns.
122    pub usage: TokenUsage,
123    /// Number of turns completed.
124    pub turns: usize,
125}
126
// --- AgentLoop ---

/// Start-to-close timeout for durable activities (LLM calls and tool
/// executions): two minutes.
pub(crate) const DEFAULT_ACTIVITY_TIMEOUT: Duration = Duration::from_secs(2 * 60);
131
132/// The agentic while loop: drives provider + tool + context interactions.
133///
134/// Generic over `P: Provider` (the LLM backend) and `C: ContextStrategy`
135/// (the compaction strategy). Hooks and durability are optional.
136pub struct AgentLoop<P: Provider, C: ContextStrategy> {
137    pub(crate) provider: P,
138    pub(crate) tools: ToolRegistry,
139    pub(crate) context: C,
140    pub(crate) hooks: Vec<BoxedHook>,
141    pub(crate) durability: Option<BoxedDurable>,
142    pub(crate) config: LoopConfig,
143    pub(crate) messages: Vec<Message>,
144}
145
146impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
147    /// Create a new `AgentLoop` with the given provider, tools, context strategy,
148    /// and configuration.
149    #[must_use]
150    pub fn new(provider: P, tools: ToolRegistry, context: C, config: LoopConfig) -> Self {
151        Self {
152            provider,
153            tools,
154            context,
155            hooks: Vec::new(),
156            durability: None,
157            config,
158            messages: Vec::new(),
159        }
160    }
161
162    /// Add an observability hook to the loop.
163    ///
164    /// Hooks are called in order of registration at each event point.
165    pub fn add_hook<H: ObservabilityHook + 'static>(&mut self, hook: H) -> &mut Self {
166        self.hooks.push(BoxedHook::new(hook));
167        self
168    }
169
170    /// Set the durable context for crash-recoverable execution.
171    ///
172    /// When set, LLM calls and tool executions go through the durable context
173    /// so they can be journaled, replayed, and recovered by engines like
174    /// Temporal, Restate, or Inngest.
175    pub fn set_durability<D: DurableContext + 'static>(&mut self, durable: D) -> &mut Self {
176        self.durability = Some(BoxedDurable::new(durable));
177        self
178    }
179
180    /// Returns a reference to the current configuration.
181    #[must_use]
182    pub fn config(&self) -> &LoopConfig {
183        &self.config
184    }
185
186    /// Returns a reference to the current messages.
187    #[must_use]
188    pub fn messages(&self) -> &[Message] {
189        &self.messages
190    }
191
192    /// Returns a mutable reference to the tool registry.
193    #[must_use]
194    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
195        &mut self.tools
196    }
197
198    /// Run the agentic loop to completion.
199    ///
200    /// Appends the user message, then loops: call provider, execute tools if
201    /// needed, append results, repeat until the model returns a text-only
202    /// response or the turn limit is reached.
203    ///
204    /// When durability is set, LLM calls go through
205    /// [`DurableContext::execute_llm_call`] and tool calls go through
206    /// [`DurableContext::execute_tool`].
207    ///
208    /// Fires [`HookEvent`] at each step. If a hook returns
209    /// [`HookAction::Terminate`], the loop stops with
210    /// [`LoopError::HookTerminated`].
211    ///
212    /// # Errors
213    ///
214    /// Returns `LoopError::MaxTurns` if the turn limit is exceeded,
215    /// `LoopError::Provider` on provider failures, `LoopError::Tool`
216    /// on tool execution failures, or `LoopError::HookTerminated` if
217    /// a hook requests termination.
218    #[must_use = "this returns a Result that should be handled"]
219    pub async fn run(
220        &mut self,
221        user_message: Message,
222        tool_ctx: &ToolContext,
223    ) -> Result<AgentResult, LoopError> {
224        self.messages.push(user_message);
225
226        let mut total_usage = TokenUsage::default();
227        let mut turns: usize = 0;
228
229        loop {
230            // Check max turns
231            if let Some(max) = self.config.max_turns
232                && turns >= max
233            {
234                return Err(LoopError::MaxTurns(max));
235            }
236
237            // Fire LoopIteration hooks
238            if let Some(HookAction::Terminate { reason }) =
239                fire_loop_iteration_hooks(&self.hooks, turns).await?
240            {
241                return Err(LoopError::HookTerminated(reason));
242            }
243
244            // Check context compaction
245            let token_count = self.context.token_estimate(&self.messages);
246            if self.context.should_compact(&self.messages, token_count) {
247                let old_tokens = token_count;
248                self.messages = self.context.compact(self.messages.clone()).await?;
249                let new_tokens = self.context.token_estimate(&self.messages);
250
251                // Fire ContextCompaction hooks
252                if let Some(HookAction::Terminate { reason }) =
253                    fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await?
254                {
255                    return Err(LoopError::HookTerminated(reason));
256                }
257            }
258
259            // Build completion request
260            let request = CompletionRequest {
261                model: String::new(), // Provider decides the model
262                messages: self.messages.clone(),
263                system: Some(self.config.system_prompt.clone()),
264                tools: self.tools.definitions(),
265                max_tokens: None,
266                temperature: None,
267                top_p: None,
268                stop_sequences: vec![],
269                tool_choice: None,
270                response_format: None,
271                thinking: None,
272                reasoning_effort: None,
273                extra: None,
274            };
275
276            // Fire PreLlmCall hooks
277            if let Some(HookAction::Terminate { reason }) =
278                fire_pre_llm_hooks(&self.hooks, &request).await?
279            {
280                return Err(LoopError::HookTerminated(reason));
281            }
282
283            // Call provider (via durability wrapper if present)
284            let response = if let Some(ref durable) = self.durability {
285                let options = ActivityOptions {
286                    start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
287                    heartbeat_timeout: None,
288                    retry_policy: None,
289                };
290                durable
291                    .0
292                    .erased_execute_llm_call(request, options)
293                    .await
294                    .map_err(|e| ProviderError::Other(Box::new(e)))?
295            } else {
296                self.provider.complete(request).await?
297            };
298
299            // Fire PostLlmCall hooks
300            if let Some(HookAction::Terminate { reason }) =
301                fire_post_llm_hooks(&self.hooks, &response).await?
302            {
303                return Err(LoopError::HookTerminated(reason));
304            }
305
306            // Accumulate usage
307            accumulate_usage(&mut total_usage, &response.usage);
308            turns += 1;
309
310            // Check for tool calls in the response
311            let tool_calls: Vec<_> = response
312                .message
313                .content
314                .iter()
315                .filter_map(|block| {
316                    if let ContentBlock::ToolUse { id, name, input } = block {
317                        Some((id.clone(), name.clone(), input.clone()))
318                    } else {
319                        None
320                    }
321                })
322                .collect();
323
324            // Append assistant message to conversation
325            self.messages.push(response.message.clone());
326
327            if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
328                // No tool calls — extract text and return
329                let response_text = extract_text(&response.message);
330                return Ok(AgentResult {
331                    response: response_text,
332                    messages: self.messages.clone(),
333                    usage: total_usage,
334                    turns,
335                });
336            }
337
338            // Execute tool calls and collect results
339            let mut tool_result_blocks = Vec::new();
340            for (call_id, tool_name, input) in &tool_calls {
341                // Fire PreToolExecution hooks
342                if let Some(action) =
343                    fire_pre_tool_hooks(&self.hooks, tool_name, input).await?
344                {
345                    match action {
346                        HookAction::Terminate { reason } => {
347                            return Err(LoopError::HookTerminated(reason));
348                        }
349                        HookAction::Skip { reason } => {
350                            // Skip the tool call and return a rejection message
351                            tool_result_blocks.push(ContentBlock::ToolResult {
352                                tool_use_id: call_id.clone(),
353                                content: vec![ContentItem::Text(format!(
354                                    "Tool call skipped: {reason}"
355                                ))],
356                                is_error: true,
357                            });
358                            continue;
359                        }
360                        HookAction::Continue => {}
361                    }
362                }
363
364                // Execute tool (via durability wrapper if present)
365                let result = if let Some(ref durable) = self.durability {
366                    let options = ActivityOptions {
367                        start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
368                        heartbeat_timeout: None,
369                        retry_policy: None,
370                    };
371                    durable
372                        .0
373                        .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
374                        .await
375                        .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?
376                } else {
377                    self.tools.execute(tool_name, input.clone(), tool_ctx).await?
378                };
379
380                // Fire PostToolExecution hooks
381                if let Some(HookAction::Terminate { reason }) =
382                    fire_post_tool_hooks(&self.hooks, tool_name, &result).await?
383                {
384                    return Err(LoopError::HookTerminated(reason));
385                }
386
387                tool_result_blocks.push(ContentBlock::ToolResult {
388                    tool_use_id: call_id.clone(),
389                    content: result.content,
390                    is_error: result.is_error,
391                });
392            }
393
394            // Append tool results as a user message
395            self.messages.push(Message {
396                role: Role::User,
397                content: tool_result_blocks,
398            });
399        }
400    }
401
402    /// Convenience method to run the loop with a plain text message.
403    ///
404    /// Wraps `text` into a `Message { role: User, content: [Text(text)] }`
405    /// and calls [`run`](Self::run).
406    #[must_use = "this returns a Result that should be handled"]
407    pub async fn run_text(
408        &mut self,
409        text: &str,
410        tool_ctx: &ToolContext,
411    ) -> Result<AgentResult, LoopError> {
412        let message = Message {
413            role: Role::User,
414            content: vec![ContentBlock::Text(text.to_string())],
415        };
416        self.run(message, tool_ctx).await
417    }
418
419    /// Create a builder with the required provider and context strategy.
420    ///
421    /// All other options have sensible defaults:
422    /// - Empty tool registry
423    /// - Default loop config (no turn limit, empty system prompt)
424    /// - No hooks or durability
425    #[must_use]
426    pub fn builder(provider: P, context: C) -> AgentLoopBuilder<P, C> {
427        AgentLoopBuilder {
428            provider,
429            context,
430            tools: ToolRegistry::new(),
431            config: LoopConfig::default(),
432            hooks: Vec::new(),
433            durability: None,
434        }
435    }
436}
437
438/// Builder for constructing an [`AgentLoop`] with optional configuration.
439///
440/// Created via [`AgentLoop::builder`]. Only `provider` and `context` are required;
441/// everything else has sensible defaults.
442///
443/// # Example
444///
445/// ```ignore
446/// let agent = AgentLoop::builder(provider, context)
447///     .tools(tools)
448///     .system_prompt("You are a helpful assistant.")
449///     .max_turns(10)
450///     .build();
451/// ```
452pub struct AgentLoopBuilder<P: Provider, C: ContextStrategy> {
453    provider: P,
454    context: C,
455    tools: ToolRegistry,
456    config: LoopConfig,
457    hooks: Vec<BoxedHook>,
458    durability: Option<BoxedDurable>,
459}
460
461impl<P: Provider, C: ContextStrategy> AgentLoopBuilder<P, C> {
462    /// Set the tool registry.
463    #[must_use]
464    pub fn tools(mut self, tools: ToolRegistry) -> Self {
465        self.tools = tools;
466        self
467    }
468
469    /// Set the full loop configuration.
470    #[must_use]
471    pub fn config(mut self, config: LoopConfig) -> Self {
472        self.config = config;
473        self
474    }
475
476    /// Set the system prompt (convenience for setting `config.system_prompt`).
477    #[must_use]
478    pub fn system_prompt(mut self, prompt: impl Into<neuron_types::SystemPrompt>) -> Self {
479        self.config.system_prompt = prompt.into();
480        self
481    }
482
483    /// Set the maximum number of turns (convenience for setting `config.max_turns`).
484    #[must_use]
485    pub fn max_turns(mut self, max: usize) -> Self {
486        self.config.max_turns = Some(max);
487        self
488    }
489
490    /// Enable parallel tool execution (convenience for setting `config.parallel_tool_execution`).
491    #[must_use]
492    pub fn parallel_tool_execution(mut self, parallel: bool) -> Self {
493        self.config.parallel_tool_execution = parallel;
494        self
495    }
496
497    /// Add an observability hook.
498    #[must_use]
499    pub fn hook<H: ObservabilityHook + 'static>(mut self, hook: H) -> Self {
500        self.hooks.push(BoxedHook::new(hook));
501        self
502    }
503
504    /// Set the durable context for crash-recoverable execution.
505    #[must_use]
506    pub fn durability<D: DurableContext + 'static>(mut self, durable: D) -> Self {
507        self.durability = Some(BoxedDurable::new(durable));
508        self
509    }
510
511    /// Build the [`AgentLoop`].
512    #[must_use]
513    pub fn build(self) -> AgentLoop<P, C> {
514        AgentLoop {
515            provider: self.provider,
516            tools: self.tools,
517            context: self.context,
518            hooks: self.hooks,
519            durability: self.durability,
520            config: self.config,
521            messages: Vec::new(),
522        }
523    }
524}
525
526// --- Hook firing helpers ---
527
528/// Fire all hooks for a PreLlmCall event, returning the first non-Continue action.
529pub(crate) async fn fire_pre_llm_hooks(
530    hooks: &[BoxedHook],
531    request: &CompletionRequest,
532) -> Result<Option<HookAction>, LoopError> {
533    for hook in hooks {
534        let action = hook
535            .fire(HookEvent::PreLlmCall { request })
536            .await
537            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
538        if !matches!(action, HookAction::Continue) {
539            return Ok(Some(action));
540        }
541    }
542    Ok(None)
543}
544
545/// Fire all hooks for a PostLlmCall event, returning the first non-Continue action.
546pub(crate) async fn fire_post_llm_hooks(
547    hooks: &[BoxedHook],
548    response: &CompletionResponse,
549) -> Result<Option<HookAction>, LoopError> {
550    for hook in hooks {
551        let action = hook
552            .fire(HookEvent::PostLlmCall { response })
553            .await
554            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
555        if !matches!(action, HookAction::Continue) {
556            return Ok(Some(action));
557        }
558    }
559    Ok(None)
560}
561
562/// Fire all hooks for a PreToolExecution event, returning the first non-Continue action.
563pub(crate) async fn fire_pre_tool_hooks(
564    hooks: &[BoxedHook],
565    tool_name: &str,
566    input: &serde_json::Value,
567) -> Result<Option<HookAction>, LoopError> {
568    for hook in hooks {
569        let action = hook
570            .fire(HookEvent::PreToolExecution { tool_name, input })
571            .await
572            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
573        if !matches!(action, HookAction::Continue) {
574            return Ok(Some(action));
575        }
576    }
577    Ok(None)
578}
579
580/// Fire all hooks for a PostToolExecution event, returning the first non-Continue action.
581pub(crate) async fn fire_post_tool_hooks(
582    hooks: &[BoxedHook],
583    tool_name: &str,
584    output: &ToolOutput,
585) -> Result<Option<HookAction>, LoopError> {
586    for hook in hooks {
587        let action = hook
588            .fire(HookEvent::PostToolExecution { tool_name, output })
589            .await
590            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
591        if !matches!(action, HookAction::Continue) {
592            return Ok(Some(action));
593        }
594    }
595    Ok(None)
596}
597
598/// Fire all hooks for a LoopIteration event, returning the first non-Continue action.
599pub(crate) async fn fire_loop_iteration_hooks(
600    hooks: &[BoxedHook],
601    turn: usize,
602) -> Result<Option<HookAction>, LoopError> {
603    for hook in hooks {
604        let action = hook
605            .fire(HookEvent::LoopIteration { turn })
606            .await
607            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
608        if !matches!(action, HookAction::Continue) {
609            return Ok(Some(action));
610        }
611    }
612    Ok(None)
613}
614
615/// Fire all hooks for a ContextCompaction event, returning the first non-Continue action.
616pub(crate) async fn fire_compaction_hooks(
617    hooks: &[BoxedHook],
618    old_tokens: usize,
619    new_tokens: usize,
620) -> Result<Option<HookAction>, LoopError> {
621    for hook in hooks {
622        let action = hook
623            .fire(HookEvent::ContextCompaction {
624                old_tokens,
625                new_tokens,
626            })
627            .await
628            .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
629        if !matches!(action, HookAction::Continue) {
630            return Ok(Some(action));
631        }
632    }
633    Ok(None)
634}
635
636// --- Utility functions ---
637
638/// Extract text content from a message.
639pub(crate) fn extract_text(message: &Message) -> String {
640    message
641        .content
642        .iter()
643        .filter_map(|block| {
644            if let ContentBlock::Text(text) = block {
645                Some(text.as_str())
646            } else {
647                None
648            }
649        })
650        .collect::<Vec<_>>()
651        .join("")
652}
653
654/// Accumulate token usage from a response into the total.
655pub(crate) fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
656    total.input_tokens += delta.input_tokens;
657    total.output_tokens += delta.output_tokens;
658    if let Some(cache_read) = delta.cache_read_tokens {
659        *total.cache_read_tokens.get_or_insert(0) += cache_read;
660    }
661    if let Some(cache_creation) = delta.cache_creation_tokens {
662        *total.cache_creation_tokens.get_or_insert(0) += cache_creation;
663    }
664    if let Some(reasoning) = delta.reasoning_tokens {
665        *total.reasoning_tokens.get_or_insert(0) += reasoning;
666    }
667}