1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Duration;
7
8use neuron_tool::ToolRegistry;
9use neuron_types::{
10 ActivityOptions, CompletionRequest, CompletionResponse, ContentBlock, ContentItem,
11 ContextStrategy, DurableContext, DurableError, HookAction, HookError, HookEvent, LoopError,
12 Message, ObservabilityHook, Provider, ProviderError, Role, StopReason, TokenUsage, ToolContext,
13 ToolError, ToolOutput,
14};
15
16use crate::config::LoopConfig;
17
18type HookFuture<'a> = Pin<Box<dyn Future<Output = Result<HookAction, HookError>> + Send + 'a>>;
22
23trait ErasedHook: Send + Sync {
25 fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a>;
26}
27
28impl<H: ObservabilityHook> ErasedHook for H {
29 fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a> {
30 Box::pin(self.on_event(event))
31 }
32}
33
34pub struct BoxedHook(Arc<dyn ErasedHook>);
38
39impl BoxedHook {
40 #[must_use]
42 pub fn new<H: ObservabilityHook + 'static>(hook: H) -> Self {
43 BoxedHook(Arc::new(hook))
44 }
45
46 async fn fire(&self, event: HookEvent<'_>) -> Result<HookAction, HookError> {
48 self.0.erased_on_event(event).await
49 }
50}
51
/// Boxed, `Send` future returned by a type-erased durable LLM call.
type DurableLlmFuture<'a> =
    Pin<Box<dyn Future<Output = Result<CompletionResponse, DurableError>> + Send + 'a>>;

/// Boxed, `Send` future returned by a type-erased durable tool execution.
type DurableToolFuture<'a> =
    Pin<Box<dyn Future<Output = Result<ToolOutput, DurableError>> + Send + 'a>>;
61
/// Dyn-compatible adapter over [`DurableContext`].
///
/// Every `DurableContext` implements this trait by forwarding to the matching method and
/// boxing the returned future. This lets a backend be held as `Arc<dyn ErasedDurable>`
/// inside [`BoxedDurable`].
pub(crate) trait ErasedDurable: Send + Sync {
    /// Executes `request` as a durable activity configured by `options`.
    fn erased_execute_llm_call(
        &self,
        request: CompletionRequest,
        options: ActivityOptions,
    ) -> DurableLlmFuture<'_>;

    /// Executes the tool named `tool_name` with `input` and `ctx` as a durable activity
    /// configured by `options`.
    fn erased_execute_tool<'a>(
        &'a self,
        tool_name: &'a str,
        input: serde_json::Value,
        ctx: &'a ToolContext,
        options: ActivityOptions,
    ) -> DurableToolFuture<'a>;
}
78
79impl<D: DurableContext> ErasedDurable for D {
80 fn erased_execute_llm_call(
81 &self,
82 request: CompletionRequest,
83 options: ActivityOptions,
84 ) -> DurableLlmFuture<'_> {
85 Box::pin(self.execute_llm_call(request, options))
86 }
87
88 fn erased_execute_tool<'a>(
89 &'a self,
90 tool_name: &'a str,
91 input: serde_json::Value,
92 ctx: &'a ToolContext,
93 options: ActivityOptions,
94 ) -> DurableToolFuture<'a> {
95 Box::pin(self.execute_tool(tool_name, input, ctx, options))
96 }
97}
98
99pub struct BoxedDurable(pub(crate) Arc<dyn ErasedDurable>);
103
104impl BoxedDurable {
105 #[must_use]
107 pub fn new<D: DurableContext + 'static>(durable: D) -> Self {
108 BoxedDurable(Arc::new(durable))
109 }
110}
111
/// Outcome of a successful [`AgentLoop::run`].
#[derive(Debug)]
pub struct AgentResult {
    /// Text blocks of the final assistant message, concatenated in order.
    pub response: String,
    /// Full conversation history when the run finished. This includes messages from
    /// earlier runs on the same loop.
    pub messages: Vec<Message>,
    /// Token usage summed over every LLM call made during this run.
    pub usage: TokenUsage,
    /// Number of LLM completions made during this run.
    pub turns: usize,
}
126
/// Start-to-close timeout (120 seconds) applied to each durable activity issued by the loop.
/// Activities are LLM calls and tool executions.
pub(crate) const DEFAULT_ACTIVITY_TIMEOUT: Duration = Duration::from_secs(120);
131
/// Drives a conversation with an LLM [`Provider`], executing requested tools until the
/// model finishes.
///
/// Build one with [`AgentLoop::new`] or [`AgentLoop::builder`]. The conversation history
/// is kept on the loop and grows across successive calls to [`AgentLoop::run`].
pub struct AgentLoop<P: Provider, C: ContextStrategy> {
    /// Model backend used for completions when no durability backend is set.
    pub(crate) provider: P,
    /// Tools offered to the model and executed on its request.
    pub(crate) tools: ToolRegistry,
    /// Strategy deciding when and how to compact the history.
    pub(crate) context: C,
    /// Observability hooks, invoked in insertion order.
    pub(crate) hooks: Vec<BoxedHook>,
    /// When set, LLM calls and tool executions are routed through this backend.
    pub(crate) durability: Option<BoxedDurable>,
    /// Loop settings (system prompt, turn limit, tool parallelism, ...).
    pub(crate) config: LoopConfig,
    /// Conversation history sent with each completion request.
    pub(crate) messages: Vec<Message>,
}
145
146impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
147 #[must_use]
150 pub fn new(provider: P, tools: ToolRegistry, context: C, config: LoopConfig) -> Self {
151 Self {
152 provider,
153 tools,
154 context,
155 hooks: Vec::new(),
156 durability: None,
157 config,
158 messages: Vec::new(),
159 }
160 }
161
162 pub fn add_hook<H: ObservabilityHook + 'static>(&mut self, hook: H) -> &mut Self {
166 self.hooks.push(BoxedHook::new(hook));
167 self
168 }
169
170 pub fn set_durability<D: DurableContext + 'static>(&mut self, durable: D) -> &mut Self {
176 self.durability = Some(BoxedDurable::new(durable));
177 self
178 }
179
180 #[must_use]
182 pub fn config(&self) -> &LoopConfig {
183 &self.config
184 }
185
186 #[must_use]
188 pub fn messages(&self) -> &[Message] {
189 &self.messages
190 }
191
192 #[must_use]
194 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
195 &mut self.tools
196 }
197
198 #[must_use = "this returns a Result that should be handled"]
219 pub async fn run(
220 &mut self,
221 user_message: Message,
222 tool_ctx: &ToolContext,
223 ) -> Result<AgentResult, LoopError> {
224 self.messages.push(user_message);
225
226 let mut total_usage = TokenUsage::default();
227 let mut turns: usize = 0;
228
229 loop {
230 if tool_ctx.cancellation_token.is_cancelled() {
232 return Err(LoopError::Cancelled);
233 }
234
235 if let Some(max) = self.config.max_turns
237 && turns >= max
238 {
239 return Err(LoopError::MaxTurns(max));
240 }
241
242 if let Some(HookAction::Terminate { reason }) =
244 fire_loop_iteration_hooks(&self.hooks, turns).await?
245 {
246 return Err(LoopError::HookTerminated(reason));
247 }
248
249 let token_count = self.context.token_estimate(&self.messages);
251 if self.context.should_compact(&self.messages, token_count) {
252 let old_tokens = token_count;
253 self.messages = self.context.compact(self.messages.clone()).await?;
254 let new_tokens = self.context.token_estimate(&self.messages);
255
256 if let Some(HookAction::Terminate { reason }) =
258 fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await?
259 {
260 return Err(LoopError::HookTerminated(reason));
261 }
262 }
263
264 let request = CompletionRequest {
266 model: String::new(), messages: self.messages.clone(),
268 system: Some(self.config.system_prompt.clone()),
269 tools: self.tools.definitions(),
270 ..Default::default()
271 };
272
273 if let Some(HookAction::Terminate { reason }) =
275 fire_pre_llm_hooks(&self.hooks, &request).await?
276 {
277 return Err(LoopError::HookTerminated(reason));
278 }
279
280 let response = if let Some(ref durable) = self.durability {
282 let options = ActivityOptions {
283 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
284 heartbeat_timeout: None,
285 retry_policy: None,
286 };
287 durable
288 .0
289 .erased_execute_llm_call(request, options)
290 .await
291 .map_err(|e| ProviderError::Other(Box::new(e)))?
292 } else {
293 self.provider.complete(request).await?
294 };
295
296 if let Some(HookAction::Terminate { reason }) =
298 fire_post_llm_hooks(&self.hooks, &response).await?
299 {
300 return Err(LoopError::HookTerminated(reason));
301 }
302
303 accumulate_usage(&mut total_usage, &response.usage);
305 turns += 1;
306
307 let tool_calls: Vec<_> = response
309 .message
310 .content
311 .iter()
312 .filter_map(|block| {
313 if let ContentBlock::ToolUse { id, name, input } = block {
314 Some((id.clone(), name.clone(), input.clone()))
315 } else {
316 None
317 }
318 })
319 .collect();
320
321 self.messages.push(response.message.clone());
323
324 if response.stop_reason == StopReason::Compaction {
327 continue;
328 }
329
330 if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
331 let response_text = extract_text(&response.message);
333 return Ok(AgentResult {
334 response: response_text,
335 messages: self.messages.clone(),
336 usage: total_usage,
337 turns,
338 });
339 }
340
341 if tool_ctx.cancellation_token.is_cancelled() {
343 return Err(LoopError::Cancelled);
344 }
345
346 let tool_result_blocks = if self.config.parallel_tool_execution && tool_calls.len() > 1 {
348 let futs = tool_calls.iter().map(|(call_id, tool_name, input)| {
349 self.execute_single_tool(call_id, tool_name, input, tool_ctx)
350 });
351 let results = futures::future::join_all(futs).await;
352 results.into_iter().collect::<Result<Vec<_>, _>>()?
353 } else {
354 let mut blocks = Vec::new();
355 for (call_id, tool_name, input) in &tool_calls {
356 blocks.push(self.execute_single_tool(call_id, tool_name, input, tool_ctx).await?);
357 }
358 blocks
359 };
360
361 self.messages.push(Message {
363 role: Role::User,
364 content: tool_result_blocks,
365 });
366 }
367 }
368
369 #[must_use = "this returns a Result that should be handled"]
374 pub async fn run_text(
375 &mut self,
376 text: &str,
377 tool_ctx: &ToolContext,
378 ) -> Result<AgentResult, LoopError> {
379 let message = Message {
380 role: Role::User,
381 content: vec![ContentBlock::Text(text.to_string())],
382 };
383 self.run(message, tool_ctx).await
384 }
385
386 pub(crate) async fn execute_single_tool(
390 &self,
391 call_id: &str,
392 tool_name: &str,
393 input: &serde_json::Value,
394 tool_ctx: &ToolContext,
395 ) -> Result<ContentBlock, LoopError> {
396 if let Some(action) = fire_pre_tool_hooks(&self.hooks, tool_name, input).await? {
398 match action {
399 HookAction::Terminate { reason } => {
400 return Err(LoopError::HookTerminated(reason));
401 }
402 HookAction::Skip { reason } => {
403 return Ok(ContentBlock::ToolResult {
404 tool_use_id: call_id.to_string(),
405 content: vec![ContentItem::Text(format!("Tool call skipped: {reason}"))],
406 is_error: true,
407 });
408 }
409 HookAction::Continue => {}
410 }
411 }
412
413 let result = if let Some(ref durable) = self.durability {
415 let options = ActivityOptions {
416 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
417 heartbeat_timeout: None,
418 retry_policy: None,
419 };
420 durable
421 .0
422 .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
423 .await
424 .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?
425 } else {
426 self.tools.execute(tool_name, input.clone(), tool_ctx).await?
427 };
428
429 if let Some(HookAction::Terminate { reason }) =
431 fire_post_tool_hooks(&self.hooks, tool_name, &result).await?
432 {
433 return Err(LoopError::HookTerminated(reason));
434 }
435
436 Ok(ContentBlock::ToolResult {
437 tool_use_id: call_id.to_string(),
438 content: result.content,
439 is_error: result.is_error,
440 })
441 }
442
443 #[must_use]
450 pub fn builder(provider: P, context: C) -> AgentLoopBuilder<P, C> {
451 AgentLoopBuilder {
452 provider,
453 context,
454 tools: ToolRegistry::new(),
455 config: LoopConfig::default(),
456 hooks: Vec::new(),
457 durability: None,
458 }
459 }
460}
461
/// Incremental builder for [`AgentLoop`], created by [`AgentLoop::builder`].
///
/// Defaults: an empty [`ToolRegistry`], `LoopConfig::default()`, no hooks and no
/// durability backend.
pub struct AgentLoopBuilder<P: Provider, C: ContextStrategy> {
    /// Model backend for the loop.
    provider: P,
    /// Context strategy for the loop.
    context: C,
    /// Tools offered to the model.
    tools: ToolRegistry,
    /// Loop settings, edited by the individual setter methods.
    config: LoopConfig,
    /// Hooks, in the order they were added.
    hooks: Vec<BoxedHook>,
    /// Optional durability backend.
    durability: Option<BoxedDurable>,
}
484
485impl<P: Provider, C: ContextStrategy> AgentLoopBuilder<P, C> {
486 #[must_use]
488 pub fn tools(mut self, tools: ToolRegistry) -> Self {
489 self.tools = tools;
490 self
491 }
492
493 #[must_use]
495 pub fn config(mut self, config: LoopConfig) -> Self {
496 self.config = config;
497 self
498 }
499
500 #[must_use]
502 pub fn system_prompt(mut self, prompt: impl Into<neuron_types::SystemPrompt>) -> Self {
503 self.config.system_prompt = prompt.into();
504 self
505 }
506
507 #[must_use]
509 pub fn max_turns(mut self, max: usize) -> Self {
510 self.config.max_turns = Some(max);
511 self
512 }
513
514 #[must_use]
516 pub fn parallel_tool_execution(mut self, parallel: bool) -> Self {
517 self.config.parallel_tool_execution = parallel;
518 self
519 }
520
521 #[must_use]
523 pub fn hook<H: ObservabilityHook + 'static>(mut self, hook: H) -> Self {
524 self.hooks.push(BoxedHook::new(hook));
525 self
526 }
527
528 #[must_use]
530 pub fn durability<D: DurableContext + 'static>(mut self, durable: D) -> Self {
531 self.durability = Some(BoxedDurable::new(durable));
532 self
533 }
534
535 #[must_use]
537 pub fn build(self) -> AgentLoop<P, C> {
538 AgentLoop {
539 provider: self.provider,
540 tools: self.tools,
541 context: self.context,
542 hooks: self.hooks,
543 durability: self.durability,
544 config: self.config,
545 messages: Vec::new(),
546 }
547 }
548}
549
550pub(crate) async fn fire_pre_llm_hooks(
554 hooks: &[BoxedHook],
555 request: &CompletionRequest,
556) -> Result<Option<HookAction>, LoopError> {
557 for hook in hooks {
558 let action = hook
559 .fire(HookEvent::PreLlmCall { request })
560 .await
561 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
562 if !matches!(action, HookAction::Continue) {
563 return Ok(Some(action));
564 }
565 }
566 Ok(None)
567}
568
569pub(crate) async fn fire_post_llm_hooks(
571 hooks: &[BoxedHook],
572 response: &CompletionResponse,
573) -> Result<Option<HookAction>, LoopError> {
574 for hook in hooks {
575 let action = hook
576 .fire(HookEvent::PostLlmCall { response })
577 .await
578 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
579 if !matches!(action, HookAction::Continue) {
580 return Ok(Some(action));
581 }
582 }
583 Ok(None)
584}
585
586pub(crate) async fn fire_pre_tool_hooks(
588 hooks: &[BoxedHook],
589 tool_name: &str,
590 input: &serde_json::Value,
591) -> Result<Option<HookAction>, LoopError> {
592 for hook in hooks {
593 let action = hook
594 .fire(HookEvent::PreToolExecution { tool_name, input })
595 .await
596 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
597 if !matches!(action, HookAction::Continue) {
598 return Ok(Some(action));
599 }
600 }
601 Ok(None)
602}
603
604pub(crate) async fn fire_post_tool_hooks(
606 hooks: &[BoxedHook],
607 tool_name: &str,
608 output: &ToolOutput,
609) -> Result<Option<HookAction>, LoopError> {
610 for hook in hooks {
611 let action = hook
612 .fire(HookEvent::PostToolExecution { tool_name, output })
613 .await
614 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
615 if !matches!(action, HookAction::Continue) {
616 return Ok(Some(action));
617 }
618 }
619 Ok(None)
620}
621
622pub(crate) async fn fire_loop_iteration_hooks(
624 hooks: &[BoxedHook],
625 turn: usize,
626) -> Result<Option<HookAction>, LoopError> {
627 for hook in hooks {
628 let action = hook
629 .fire(HookEvent::LoopIteration { turn })
630 .await
631 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
632 if !matches!(action, HookAction::Continue) {
633 return Ok(Some(action));
634 }
635 }
636 Ok(None)
637}
638
639pub(crate) async fn fire_compaction_hooks(
641 hooks: &[BoxedHook],
642 old_tokens: usize,
643 new_tokens: usize,
644) -> Result<Option<HookAction>, LoopError> {
645 for hook in hooks {
646 let action = hook
647 .fire(HookEvent::ContextCompaction {
648 old_tokens,
649 new_tokens,
650 })
651 .await
652 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
653 if !matches!(action, HookAction::Continue) {
654 return Ok(Some(action));
655 }
656 }
657 Ok(None)
658}
659
660pub(crate) fn extract_text(message: &Message) -> String {
664 message
665 .content
666 .iter()
667 .filter_map(|block| {
668 if let ContentBlock::Text(text) = block {
669 Some(text.as_str())
670 } else {
671 None
672 }
673 })
674 .collect::<Vec<_>>()
675 .join("")
676}
677
678pub(crate) fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
680 total.input_tokens += delta.input_tokens;
681 total.output_tokens += delta.output_tokens;
682 if let Some(cache_read) = delta.cache_read_tokens {
683 *total.cache_read_tokens.get_or_insert(0) += cache_read;
684 }
685 if let Some(cache_creation) = delta.cache_creation_tokens {
686 *total.cache_creation_tokens.get_or_insert(0) += cache_creation;
687 }
688 if let Some(reasoning) = delta.reasoning_tokens {
689 *total.reasoning_tokens.get_or_insert(0) += reasoning;
690 }
691}