1use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::Duration;
7
8use neuron_tool::ToolRegistry;
9use neuron_types::{
10 ActivityOptions, CompletionRequest, CompletionResponse, ContentBlock, ContentItem,
11 ContextStrategy, DurableContext, DurableError, HookAction, HookError, HookEvent, LoopError,
12 Message, ObservabilityHook, Provider, ProviderError, Role, StopReason, TokenUsage, ToolContext,
13 ToolError, ToolOutput,
14};
15
16use crate::config::LoopConfig;
17
/// Boxed, `Send` future produced by a type-erased hook. It resolves to the
/// hook's [`HookAction`] decision, or to a [`HookError`] if the hook failed.
type HookFuture<'a> = Pin<Box<dyn Future<Output = Result<HookAction, HookError>> + Send + 'a>>;
22
/// Dyn-compatible adapter over [`ObservabilityHook`].
///
/// It lets hooks of different concrete types be stored together as
/// `dyn ErasedHook`. A blanket impl in this module provides it for every
/// `ObservabilityHook`.
trait ErasedHook: Send + Sync {
    /// Forwards `event` to the underlying hook and returns its future boxed.
    fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a>;
}
27
28impl<H: ObservabilityHook> ErasedHook for H {
29 fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a> {
30 Box::pin(self.on_event(event))
31 }
32}
33
/// A type-erased, reference-counted [`ObservabilityHook`] as stored by the
/// agent loop and its builder.
pub struct BoxedHook(Arc<dyn ErasedHook>);
38
39impl BoxedHook {
40 #[must_use]
42 pub fn new<H: ObservabilityHook + 'static>(hook: H) -> Self {
43 BoxedHook(Arc::new(hook))
44 }
45
46 async fn fire(&self, event: HookEvent<'_>) -> Result<HookAction, HookError> {
48 self.0.erased_on_event(event).await
49 }
50}
51
/// Boxed, `Send` future for an LLM call routed through a durability backend.
type DurableLlmFuture<'a> =
    Pin<Box<dyn Future<Output = Result<CompletionResponse, DurableError>> + Send + 'a>>;
57
/// Boxed, `Send` future for a tool execution routed through a durability backend.
type DurableToolFuture<'a> =
    Pin<Box<dyn Future<Output = Result<ToolOutput, DurableError>> + Send + 'a>>;
61
/// Dyn-compatible adapter over [`DurableContext`].
///
/// It lets any durability backend be stored as `dyn ErasedDurable`. A blanket
/// impl in this module provides it for every `DurableContext`.
pub(crate) trait ErasedDurable: Send + Sync {
    /// Runs `request` as a durable LLM activity with the given `options`.
    fn erased_execute_llm_call(
        &self,
        request: CompletionRequest,
        options: ActivityOptions,
    ) -> DurableLlmFuture<'_>;

    /// Runs the tool `tool_name` with `input` as a durable activity.
    fn erased_execute_tool<'a>(
        &'a self,
        tool_name: &'a str,
        input: serde_json::Value,
        ctx: &'a ToolContext,
        options: ActivityOptions,
    ) -> DurableToolFuture<'a>;
}
78
79impl<D: DurableContext> ErasedDurable for D {
80 fn erased_execute_llm_call(
81 &self,
82 request: CompletionRequest,
83 options: ActivityOptions,
84 ) -> DurableLlmFuture<'_> {
85 Box::pin(self.execute_llm_call(request, options))
86 }
87
88 fn erased_execute_tool<'a>(
89 &'a self,
90 tool_name: &'a str,
91 input: serde_json::Value,
92 ctx: &'a ToolContext,
93 options: ActivityOptions,
94 ) -> DurableToolFuture<'a> {
95 Box::pin(self.execute_tool(tool_name, input, ctx, options))
96 }
97}
98
/// A type-erased, reference-counted [`DurableContext`]. When one is set on the
/// loop, LLM calls and tool executions are routed through it.
pub struct BoxedDurable(pub(crate) Arc<dyn ErasedDurable>);
103
104impl BoxedDurable {
105 #[must_use]
107 pub fn new<D: DurableContext + 'static>(durable: D) -> Self {
108 BoxedDurable(Arc::new(durable))
109 }
110}
111
/// Outcome of a successful [`AgentLoop::run`].
#[derive(Debug)]
pub struct AgentResult {
    /// Text of the final assistant message (all text blocks concatenated).
    pub response: String,
    /// The loop's full conversation history at the moment the run finished.
    pub messages: Vec<Message>,
    /// Token usage summed across every LLM call made during this run.
    pub usage: TokenUsage,
    /// Number of LLM calls made during this run.
    pub turns: usize,
}
126
/// `start_to_close_timeout` used for every durable LLM call and tool
/// execution issued by the loop (two minutes).
pub(crate) const DEFAULT_ACTIVITY_TIMEOUT: Duration = Duration::from_secs(120);
131
/// Drives a conversation between an LLM [`Provider`] and a set of tools.
///
/// The loop keeps the conversation history across calls to `run`, compacts
/// it via the [`ContextStrategy`], and notifies registered hooks at each stage.
pub struct AgentLoop<P: Provider, C: ContextStrategy> {
    /// LLM backend used when no durability backend is configured.
    pub(crate) provider: P,
    /// Tools advertised to the model and executed on its behalf.
    pub(crate) tools: ToolRegistry,
    /// Strategy deciding when and how to compact `messages`.
    pub(crate) context: C,
    /// Hooks, fired in insertion order.
    pub(crate) hooks: Vec<BoxedHook>,
    /// Optional backend that LLM and tool calls are routed through.
    pub(crate) durability: Option<BoxedDurable>,
    /// Loop settings (system prompt, turn limit, ...).
    pub(crate) config: LoopConfig,
    /// Conversation history; grows across successive runs.
    pub(crate) messages: Vec<Message>,
}
145
146impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
147 #[must_use]
150 pub fn new(provider: P, tools: ToolRegistry, context: C, config: LoopConfig) -> Self {
151 Self {
152 provider,
153 tools,
154 context,
155 hooks: Vec::new(),
156 durability: None,
157 config,
158 messages: Vec::new(),
159 }
160 }
161
162 pub fn add_hook<H: ObservabilityHook + 'static>(&mut self, hook: H) -> &mut Self {
166 self.hooks.push(BoxedHook::new(hook));
167 self
168 }
169
170 pub fn set_durability<D: DurableContext + 'static>(&mut self, durable: D) -> &mut Self {
176 self.durability = Some(BoxedDurable::new(durable));
177 self
178 }
179
180 #[must_use]
182 pub fn config(&self) -> &LoopConfig {
183 &self.config
184 }
185
186 #[must_use]
188 pub fn messages(&self) -> &[Message] {
189 &self.messages
190 }
191
192 #[must_use]
194 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
195 &mut self.tools
196 }
197
198 #[must_use = "this returns a Result that should be handled"]
219 pub async fn run(
220 &mut self,
221 user_message: Message,
222 tool_ctx: &ToolContext,
223 ) -> Result<AgentResult, LoopError> {
224 self.messages.push(user_message);
225
226 let mut total_usage = TokenUsage::default();
227 let mut turns: usize = 0;
228
229 loop {
230 if let Some(max) = self.config.max_turns
232 && turns >= max
233 {
234 return Err(LoopError::MaxTurns(max));
235 }
236
237 if let Some(HookAction::Terminate { reason }) =
239 fire_loop_iteration_hooks(&self.hooks, turns).await?
240 {
241 return Err(LoopError::HookTerminated(reason));
242 }
243
244 let token_count = self.context.token_estimate(&self.messages);
246 if self.context.should_compact(&self.messages, token_count) {
247 let old_tokens = token_count;
248 self.messages = self.context.compact(self.messages.clone()).await?;
249 let new_tokens = self.context.token_estimate(&self.messages);
250
251 if let Some(HookAction::Terminate { reason }) =
253 fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await?
254 {
255 return Err(LoopError::HookTerminated(reason));
256 }
257 }
258
259 let request = CompletionRequest {
261 model: String::new(), messages: self.messages.clone(),
263 system: Some(self.config.system_prompt.clone()),
264 tools: self.tools.definitions(),
265 max_tokens: None,
266 temperature: None,
267 top_p: None,
268 stop_sequences: vec![],
269 tool_choice: None,
270 response_format: None,
271 thinking: None,
272 reasoning_effort: None,
273 extra: None,
274 };
275
276 if let Some(HookAction::Terminate { reason }) =
278 fire_pre_llm_hooks(&self.hooks, &request).await?
279 {
280 return Err(LoopError::HookTerminated(reason));
281 }
282
283 let response = if let Some(ref durable) = self.durability {
285 let options = ActivityOptions {
286 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
287 heartbeat_timeout: None,
288 retry_policy: None,
289 };
290 durable
291 .0
292 .erased_execute_llm_call(request, options)
293 .await
294 .map_err(|e| ProviderError::Other(Box::new(e)))?
295 } else {
296 self.provider.complete(request).await?
297 };
298
299 if let Some(HookAction::Terminate { reason }) =
301 fire_post_llm_hooks(&self.hooks, &response).await?
302 {
303 return Err(LoopError::HookTerminated(reason));
304 }
305
306 accumulate_usage(&mut total_usage, &response.usage);
308 turns += 1;
309
310 let tool_calls: Vec<_> = response
312 .message
313 .content
314 .iter()
315 .filter_map(|block| {
316 if let ContentBlock::ToolUse { id, name, input } = block {
317 Some((id.clone(), name.clone(), input.clone()))
318 } else {
319 None
320 }
321 })
322 .collect();
323
324 self.messages.push(response.message.clone());
326
327 if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
328 let response_text = extract_text(&response.message);
330 return Ok(AgentResult {
331 response: response_text,
332 messages: self.messages.clone(),
333 usage: total_usage,
334 turns,
335 });
336 }
337
338 let mut tool_result_blocks = Vec::new();
340 for (call_id, tool_name, input) in &tool_calls {
341 if let Some(action) =
343 fire_pre_tool_hooks(&self.hooks, tool_name, input).await?
344 {
345 match action {
346 HookAction::Terminate { reason } => {
347 return Err(LoopError::HookTerminated(reason));
348 }
349 HookAction::Skip { reason } => {
350 tool_result_blocks.push(ContentBlock::ToolResult {
352 tool_use_id: call_id.clone(),
353 content: vec![ContentItem::Text(format!(
354 "Tool call skipped: {reason}"
355 ))],
356 is_error: true,
357 });
358 continue;
359 }
360 HookAction::Continue => {}
361 }
362 }
363
364 let result = if let Some(ref durable) = self.durability {
366 let options = ActivityOptions {
367 start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
368 heartbeat_timeout: None,
369 retry_policy: None,
370 };
371 durable
372 .0
373 .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
374 .await
375 .map_err(|e| ToolError::ExecutionFailed(Box::new(e)))?
376 } else {
377 self.tools.execute(tool_name, input.clone(), tool_ctx).await?
378 };
379
380 if let Some(HookAction::Terminate { reason }) =
382 fire_post_tool_hooks(&self.hooks, tool_name, &result).await?
383 {
384 return Err(LoopError::HookTerminated(reason));
385 }
386
387 tool_result_blocks.push(ContentBlock::ToolResult {
388 tool_use_id: call_id.clone(),
389 content: result.content,
390 is_error: result.is_error,
391 });
392 }
393
394 self.messages.push(Message {
396 role: Role::User,
397 content: tool_result_blocks,
398 });
399 }
400 }
401
402 #[must_use = "this returns a Result that should be handled"]
407 pub async fn run_text(
408 &mut self,
409 text: &str,
410 tool_ctx: &ToolContext,
411 ) -> Result<AgentResult, LoopError> {
412 let message = Message {
413 role: Role::User,
414 content: vec![ContentBlock::Text(text.to_string())],
415 };
416 self.run(message, tool_ctx).await
417 }
418
419 #[must_use]
426 pub fn builder(provider: P, context: C) -> AgentLoopBuilder<P, C> {
427 AgentLoopBuilder {
428 provider,
429 context,
430 tools: ToolRegistry::new(),
431 config: LoopConfig::default(),
432 hooks: Vec::new(),
433 durability: None,
434 }
435 }
436}
437
/// Fluent builder for [`AgentLoop`], created by `AgentLoop::builder`.
pub struct AgentLoopBuilder<P: Provider, C: ContextStrategy> {
    /// LLM backend handed to the built loop.
    provider: P,
    /// Context-compaction strategy handed to the built loop.
    context: C,
    /// Tool registry (empty unless replaced via `tools`).
    tools: ToolRegistry,
    /// Loop configuration (`LoopConfig::default()` unless overridden).
    config: LoopConfig,
    /// Hooks in registration order.
    hooks: Vec<BoxedHook>,
    /// Optional durability backend.
    durability: Option<BoxedDurable>,
}
460
461impl<P: Provider, C: ContextStrategy> AgentLoopBuilder<P, C> {
462 #[must_use]
464 pub fn tools(mut self, tools: ToolRegistry) -> Self {
465 self.tools = tools;
466 self
467 }
468
469 #[must_use]
471 pub fn config(mut self, config: LoopConfig) -> Self {
472 self.config = config;
473 self
474 }
475
476 #[must_use]
478 pub fn system_prompt(mut self, prompt: impl Into<neuron_types::SystemPrompt>) -> Self {
479 self.config.system_prompt = prompt.into();
480 self
481 }
482
483 #[must_use]
485 pub fn max_turns(mut self, max: usize) -> Self {
486 self.config.max_turns = Some(max);
487 self
488 }
489
490 #[must_use]
492 pub fn parallel_tool_execution(mut self, parallel: bool) -> Self {
493 self.config.parallel_tool_execution = parallel;
494 self
495 }
496
497 #[must_use]
499 pub fn hook<H: ObservabilityHook + 'static>(mut self, hook: H) -> Self {
500 self.hooks.push(BoxedHook::new(hook));
501 self
502 }
503
504 #[must_use]
506 pub fn durability<D: DurableContext + 'static>(mut self, durable: D) -> Self {
507 self.durability = Some(BoxedDurable::new(durable));
508 self
509 }
510
511 #[must_use]
513 pub fn build(self) -> AgentLoop<P, C> {
514 AgentLoop {
515 provider: self.provider,
516 tools: self.tools,
517 context: self.context,
518 hooks: self.hooks,
519 durability: self.durability,
520 config: self.config,
521 messages: Vec::new(),
522 }
523 }
524}
525
526pub(crate) async fn fire_pre_llm_hooks(
530 hooks: &[BoxedHook],
531 request: &CompletionRequest,
532) -> Result<Option<HookAction>, LoopError> {
533 for hook in hooks {
534 let action = hook
535 .fire(HookEvent::PreLlmCall { request })
536 .await
537 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
538 if !matches!(action, HookAction::Continue) {
539 return Ok(Some(action));
540 }
541 }
542 Ok(None)
543}
544
545pub(crate) async fn fire_post_llm_hooks(
547 hooks: &[BoxedHook],
548 response: &CompletionResponse,
549) -> Result<Option<HookAction>, LoopError> {
550 for hook in hooks {
551 let action = hook
552 .fire(HookEvent::PostLlmCall { response })
553 .await
554 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
555 if !matches!(action, HookAction::Continue) {
556 return Ok(Some(action));
557 }
558 }
559 Ok(None)
560}
561
562pub(crate) async fn fire_pre_tool_hooks(
564 hooks: &[BoxedHook],
565 tool_name: &str,
566 input: &serde_json::Value,
567) -> Result<Option<HookAction>, LoopError> {
568 for hook in hooks {
569 let action = hook
570 .fire(HookEvent::PreToolExecution { tool_name, input })
571 .await
572 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
573 if !matches!(action, HookAction::Continue) {
574 return Ok(Some(action));
575 }
576 }
577 Ok(None)
578}
579
580pub(crate) async fn fire_post_tool_hooks(
582 hooks: &[BoxedHook],
583 tool_name: &str,
584 output: &ToolOutput,
585) -> Result<Option<HookAction>, LoopError> {
586 for hook in hooks {
587 let action = hook
588 .fire(HookEvent::PostToolExecution { tool_name, output })
589 .await
590 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
591 if !matches!(action, HookAction::Continue) {
592 return Ok(Some(action));
593 }
594 }
595 Ok(None)
596}
597
598pub(crate) async fn fire_loop_iteration_hooks(
600 hooks: &[BoxedHook],
601 turn: usize,
602) -> Result<Option<HookAction>, LoopError> {
603 for hook in hooks {
604 let action = hook
605 .fire(HookEvent::LoopIteration { turn })
606 .await
607 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
608 if !matches!(action, HookAction::Continue) {
609 return Ok(Some(action));
610 }
611 }
612 Ok(None)
613}
614
615pub(crate) async fn fire_compaction_hooks(
617 hooks: &[BoxedHook],
618 old_tokens: usize,
619 new_tokens: usize,
620) -> Result<Option<HookAction>, LoopError> {
621 for hook in hooks {
622 let action = hook
623 .fire(HookEvent::ContextCompaction {
624 old_tokens,
625 new_tokens,
626 })
627 .await
628 .map_err(|e| LoopError::HookTerminated(e.to_string()))?;
629 if !matches!(action, HookAction::Continue) {
630 return Ok(Some(action));
631 }
632 }
633 Ok(None)
634}
635
636pub(crate) fn extract_text(message: &Message) -> String {
640 message
641 .content
642 .iter()
643 .filter_map(|block| {
644 if let ContentBlock::Text(text) = block {
645 Some(text.as_str())
646 } else {
647 None
648 }
649 })
650 .collect::<Vec<_>>()
651 .join("")
652}
653
654pub(crate) fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
656 total.input_tokens += delta.input_tokens;
657 total.output_tokens += delta.output_tokens;
658 if let Some(cache_read) = delta.cache_read_tokens {
659 *total.cache_read_tokens.get_or_insert(0) += cache_read;
660 }
661 if let Some(cache_creation) = delta.cache_creation_tokens {
662 *total.cache_creation_tokens.get_or_insert(0) += cache_creation;
663 }
664 if let Some(reasoning) = delta.reasoning_tokens {
665 *total.reasoning_tokens.get_or_insert(0) += reasoning;
666 }
667}