use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use neuron_tool::ToolRegistry;
use neuron_types::{
ActivityOptions, CompletionRequest, CompletionResponse, ContentBlock, ContentItem,
ContextStrategy, DurableContext, DurableError, HookAction, HookError, HookEvent, LoopError,
Message, ObservabilityHook, Provider, ProviderError, Role, StopReason, TokenUsage, ToolContext,
ToolError, ToolOutput, UsageLimits,
};
use crate::config::LoopConfig;
/// Boxed, `Send` future produced by [`ErasedHook::erased_on_event`]; it resolves to
/// the hook's [`HookAction`] or a [`HookError`].
type HookFuture<'a> = Pin<Box<dyn Future<Output = Result<HookAction, HookError>> + Send + 'a>>;
/// Boxed-future form of an observability hook, so that hooks of different
/// concrete types can be stored behind `dyn ErasedHook`.
trait ErasedHook: Send + Sync {
/// Handles `event` and returns the hook's decision as a boxed future that may
/// borrow from both `self` and `event` for `'a`.
fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a>;
}
/// Blanket adapter: every [`ObservabilityHook`] can be used as an [`ErasedHook`].
impl<H: ObservabilityHook> ErasedHook for H {
    /// Forwards `event` to [`ObservabilityHook::on_event`] and boxes the returned future.
    fn erased_on_event<'a>(&'a self, event: HookEvent<'a>) -> HookFuture<'a> {
        let pending = self.on_event(event);
        Box::pin(pending)
    }
}
/// Type-erased, reference-counted handle to an [`ObservabilityHook`].
///
/// Erasing the concrete hook type lets hooks of different types share one collection.
pub struct BoxedHook(Arc<dyn ErasedHook>);

impl BoxedHook {
    /// Wraps `hook` in an [`Arc`] and erases its concrete type.
    #[must_use]
    pub fn new<H: ObservabilityHook + 'static>(hook: H) -> Self {
        let erased: Arc<dyn ErasedHook> = Arc::new(hook);
        Self(erased)
    }

    /// Delivers `event` to the wrapped hook and awaits its decision.
    async fn fire(&self, event: HookEvent<'_>) -> Result<HookAction, HookError> {
        let Self(inner) = self;
        inner.erased_on_event(event).await
    }
}
/// Boxed, `Send` future produced by [`ErasedDurable::erased_execute_llm_call`].
type DurableLlmFuture<'a> =
Pin<Box<dyn Future<Output = Result<CompletionResponse, DurableError>> + Send + 'a>>;
/// Boxed, `Send` future produced by [`ErasedDurable::erased_execute_tool`].
type DurableToolFuture<'a> =
Pin<Box<dyn Future<Output = Result<ToolOutput, DurableError>> + Send + 'a>>;
/// Boxed-future form of a durable execution context, usable as `dyn ErasedDurable`.
pub(crate) trait ErasedDurable: Send + Sync {
/// Executes an LLM completion for `request` under the given activity `options`.
/// The returned future borrows `self`.
fn erased_execute_llm_call(
&self,
request: CompletionRequest,
options: ActivityOptions,
) -> DurableLlmFuture<'_>;
/// Executes the tool called `tool_name` with `input` under the given activity
/// `options`. The returned future borrows `self`, `tool_name` and `ctx` for `'a`.
fn erased_execute_tool<'a>(
&'a self,
tool_name: &'a str,
input: serde_json::Value,
ctx: &'a ToolContext,
options: ActivityOptions,
) -> DurableToolFuture<'a>;
}
/// Blanket adapter: every [`DurableContext`] can be used as an [`ErasedDurable`].
impl<D: DurableContext> ErasedDurable for D {
    /// Forwards to [`DurableContext::execute_llm_call`] and boxes the returned future.
    fn erased_execute_llm_call(
        &self,
        request: CompletionRequest,
        options: ActivityOptions,
    ) -> DurableLlmFuture<'_> {
        let pending = self.execute_llm_call(request, options);
        Box::pin(pending)
    }

    /// Forwards to [`DurableContext::execute_tool`] and boxes the returned future.
    fn erased_execute_tool<'a>(
        &'a self,
        tool_name: &'a str,
        input: serde_json::Value,
        ctx: &'a ToolContext,
        options: ActivityOptions,
    ) -> DurableToolFuture<'a> {
        let pending = self.execute_tool(tool_name, input, ctx, options);
        Box::pin(pending)
    }
}
/// Type-erased, reference-counted handle to a [`DurableContext`].
pub struct BoxedDurable(pub(crate) Arc<dyn ErasedDurable>);

impl BoxedDurable {
    /// Wraps `durable` in an [`Arc`] and erases its concrete type.
    #[must_use]
    pub fn new<D: DurableContext + 'static>(durable: D) -> Self {
        let erased: Arc<dyn ErasedDurable> = Arc::new(durable);
        Self(erased)
    }
}
/// Outcome of a successfully completed [`AgentLoop::run`].
#[derive(Debug)]
pub struct AgentResult {
/// Concatenated text blocks of the final assistant message.
pub response: String,
/// Copy of the loop's message history at the moment the run finished.
pub messages: Vec<Message>,
/// Token usage accumulated over every LLM response received during the run.
pub usage: TokenUsage,
/// Number of LLM responses received during the run.
pub turns: usize,
}
/// Start-to-close timeout (two minutes) used for activities dispatched through
/// the durable context.
pub(crate) const DEFAULT_ACTIVITY_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Agent loop that alternates LLM completions and tool executions until the
/// model produces a final answer.
pub struct AgentLoop<P: Provider, C: ContextStrategy> {
/// LLM provider used for completions when no durable context is configured.
pub(crate) provider: P,
/// Tools advertised to the model and executed on its behalf.
pub(crate) tools: ToolRegistry,
/// Strategy that estimates token usage and compacts the history when needed.
pub(crate) context: C,
/// Observability hooks, fired in insertion order.
pub(crate) hooks: Vec<BoxedHook>,
/// Optional durable context; when set, LLM calls and tool executions go through it.
pub(crate) durability: Option<BoxedDurable>,
/// Loop configuration (system prompt, turn cap, usage limits, tool parallelism).
pub(crate) config: LoopConfig,
/// Conversation history; `run` appends to it and does not clear it between calls.
pub(crate) messages: Vec<Message>,
}
impl<P: Provider, C: ContextStrategy> AgentLoop<P, C> {
/// Creates a loop with the given collaborators, no hooks, no durable context
/// and an empty message history.
#[must_use]
pub fn new(provider: P, tools: ToolRegistry, context: C, config: LoopConfig) -> Self {
    let hooks = Vec::new();
    let messages = Vec::new();
    Self {
        provider,
        tools,
        context,
        config,
        hooks,
        messages,
        durability: None,
    }
}
/// Appends `hook` to the hook list and returns `self` for chaining.
pub fn add_hook<H: ObservabilityHook + 'static>(&mut self, hook: H) -> &mut Self {
    let boxed = BoxedHook::new(hook);
    self.hooks.push(boxed);
    self
}
/// Installs `durable` as the durable context, replacing any previous one, and
/// returns `self` for chaining.
pub fn set_durability<D: DurableContext + 'static>(&mut self, durable: D) -> &mut Self {
    let boxed = BoxedDurable::new(durable);
    self.durability = Some(boxed);
    self
}
#[must_use]
pub fn config(&self) -> &LoopConfig {
&self.config
}
/// Returns the conversation history accumulated so far.
#[must_use]
pub fn messages(&self) -> &[Message] {
    self.messages.as_slice()
}
#[must_use]
pub fn tools_mut(&mut self) -> &mut ToolRegistry {
&mut self.tools
}
/// Runs the agent loop for one user message until the model stops requesting tools.
///
/// `user_message` is appended to the stored history; then every iteration
/// 1. checks cancellation, the turn cap and the request limit,
/// 2. fires the loop-iteration hooks,
/// 3. compacts the history when the context strategy asks for it,
/// 4. sends the history to the model (through the durable context when one is
///    set, otherwise through the provider), wrapped by the pre/post LLM hooks,
/// 5. records usage, appends the assistant message, and either returns the
///    final text or executes the requested tools and appends their results.
///
/// Only `HookAction::Terminate` is acted upon for the hooks fired directly from
/// this method; any other non-`Continue` action they return is ignored here.
///
/// # Errors
///
/// - `LoopError::Cancelled` when the cancellation token of `tool_ctx` is set.
/// - `LoopError::MaxTurns` when `config.max_turns` has been reached.
/// - `LoopError::UsageLimitExceeded` when a configured usage limit is hit.
/// - `LoopError::HookTerminated` when a hook asks to terminate.
/// - Errors from compaction, the provider / durable context and tool execution
///   are propagated via `?`.
#[must_use = "this returns a Result that should be handled"]
pub async fn run(
&mut self,
user_message: Message,
tool_ctx: &ToolContext,
) -> Result<AgentResult, LoopError> {
self.messages.push(user_message);
// Per-run counters; the message history itself lives on `self`.
let mut total_usage = TokenUsage::default();
let mut turns: usize = 0;
let mut request_count: usize = 0;
let mut tool_call_count: usize = 0;
loop {
// Guards evaluated before any work is done for this iteration.
if tool_ctx.cancellation_token.is_cancelled() {
return Err(LoopError::Cancelled);
}
if let Some(max) = self.config.max_turns
&& turns >= max
{
return Err(LoopError::MaxTurns(max));
}
if let Some(ref limits) = self.config.usage_limits {
check_request_limit(limits, request_count)?;
}
if let Some(HookAction::Terminate { reason }) =
fire_loop_iteration_hooks(&self.hooks, turns).await?
{
return Err(LoopError::HookTerminated(reason));
}
// Compact the history if the strategy says so; a clone is handed over, so the
// stored history is only replaced when compaction succeeds.
let token_count = self.context.token_estimate(&self.messages);
if self.context.should_compact(&self.messages, token_count) {
let old_tokens = token_count;
self.messages = self.context.compact(self.messages.clone()).await?;
let new_tokens = self.context.token_estimate(&self.messages);
if let Some(HookAction::Terminate { reason }) =
fire_compaction_hooks(&self.hooks, old_tokens, new_tokens).await?
{
return Err(LoopError::HookTerminated(reason));
}
}
// NOTE(review): `model` is left empty — presumably the provider substitutes
// its own default model; confirm against the `Provider` implementations.
let request = CompletionRequest {
model: String::new(), messages: self.messages.clone(),
system: Some(self.config.system_prompt.clone()),
tools: self.tools.definitions(),
..Default::default()
};
if let Some(HookAction::Terminate { reason }) =
fire_pre_llm_hooks(&self.hooks, &request).await?
{
return Err(LoopError::HookTerminated(reason));
}
// With a durable context the completion goes through it (its error is wrapped in
// `ProviderError::Other`); otherwise the provider is called directly.
let response = if let Some(ref durable) = self.durability {
let options = ActivityOptions {
start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
heartbeat_timeout: None,
retry_policy: None,
};
durable
.0
.erased_execute_llm_call(request, options)
.await
.map_err(|e| ProviderError::Other(Box::new(e)))?
} else {
self.provider.complete(request).await?
};
if let Some(HookAction::Terminate { reason }) =
fire_post_llm_hooks(&self.hooks, &response).await?
{
return Err(LoopError::HookTerminated(reason));
}
// Book-keeping happens only after the post-LLM hooks let the response through.
accumulate_usage(&mut total_usage, &response.usage);
request_count += 1;
turns += 1;
if let Some(ref limits) = self.config.usage_limits {
check_token_limits(limits, &total_usage)?;
}
// Collect `(id, name, input)` for every tool-use block in the assistant message.
let tool_calls: Vec<_> = response
.message
.content
.iter()
.filter_map(|block| {
if let ContentBlock::ToolUse { id, name, input } = block {
Some((id.clone(), name.clone(), input.clone()))
} else {
None
}
})
.collect();
self.messages.push(response.message.clone());
// A `Compaction` stop reason starts the next iteration without running tools.
if response.stop_reason == StopReason::Compaction {
continue;
}
// Finished: either no tools were requested or the model explicitly ended its turn.
if tool_calls.is_empty() || response.stop_reason == StopReason::EndTurn {
let response_text = extract_text(&response.message);
return Ok(AgentResult {
response: response_text,
messages: self.messages.clone(),
usage: total_usage,
turns,
});
}
// Re-check cancellation before the tool executions start.
if tool_ctx.cancellation_token.is_cancelled() {
return Err(LoopError::Cancelled);
}
// The tool-call limit is checked against the whole batch before any call runs.
if let Some(ref limits) = self.config.usage_limits {
check_tool_call_limit(limits, tool_call_count, tool_calls.len())?;
}
tool_call_count += tool_calls.len();
// Run the batch concurrently only when enabled and there is more than one call;
// results keep the order of `tool_calls` in both branches.
let tool_result_blocks = if self.config.parallel_tool_execution && tool_calls.len() > 1
{
let futs = tool_calls.iter().map(|(call_id, tool_name, input)| {
self.execute_single_tool(call_id, tool_name, input, tool_ctx)
});
let results = futures::future::join_all(futs).await;
results.into_iter().collect::<Result<Vec<_>, _>>()?
} else {
let mut blocks = Vec::new();
for (call_id, tool_name, input) in &tool_calls {
blocks.push(
self.execute_single_tool(call_id, tool_name, input, tool_ctx)
.await?,
);
}
blocks
};
// Tool results are fed back to the model as a single user-role message.
self.messages.push(Message {
role: Role::User,
content: tool_result_blocks,
});
}
}
/// Convenience wrapper around `run`: wraps `text` in a user-role message with a
/// single text block and runs the loop with it.
///
/// # Errors
///
/// Returns whatever `run` returns.
#[must_use = "this returns a Result that should be handled"]
pub async fn run_text(
    &mut self,
    text: &str,
    tool_ctx: &ToolContext,
) -> Result<AgentResult, LoopError> {
    let content = vec![ContentBlock::Text(text.to_owned())];
    let user_message = Message {
        role: Role::User,
        content,
    };
    self.run(user_message, tool_ctx).await
}
/// Executes one tool call and converts its outcome into a tool-result block.
///
/// The pre-tool hooks run first: `Terminate` aborts with
/// `LoopError::HookTerminated`, `Skip` yields an error-flagged result without
/// running the tool. The tool then runs through the durable context when one is
/// configured, otherwise through the tool registry, followed by the post-tool
/// hooks (of which only `Terminate` is acted upon).
///
/// # Errors
///
/// Returns `LoopError::HookTerminated` on hook termination and propagates hook
/// and tool execution errors.
pub(crate) async fn execute_single_tool(
    &self,
    call_id: &str,
    tool_name: &str,
    input: &serde_json::Value,
    tool_ctx: &ToolContext,
) -> Result<ContentBlock, LoopError> {
    match fire_pre_tool_hooks(&self.hooks, tool_name, input).await? {
        Some(HookAction::Terminate { reason }) => {
            return Err(LoopError::HookTerminated(reason));
        }
        Some(HookAction::Skip { reason }) => {
            // Report the skip back to the model as a failed tool result.
            let notice = ContentItem::Text(format!("Tool call skipped: {reason}"));
            return Ok(ContentBlock::ToolResult {
                tool_use_id: call_id.to_owned(),
                content: vec![notice],
                is_error: true,
            });
        }
        Some(HookAction::Continue) | None => {}
    }

    let output = match self.durability.as_ref() {
        Some(durable) => {
            let options = ActivityOptions {
                start_to_close_timeout: DEFAULT_ACTIVITY_TIMEOUT,
                heartbeat_timeout: None,
                retry_policy: None,
            };
            durable
                .0
                .erased_execute_tool(tool_name, input.clone(), tool_ctx, options)
                .await
                .map_err(|err| ToolError::ExecutionFailed(Box::new(err)))?
        }
        None => {
            self.tools
                .execute(tool_name, input.clone(), tool_ctx)
                .await?
        }
    };

    let verdict = fire_post_tool_hooks(&self.hooks, tool_name, &output).await?;
    if let Some(HookAction::Terminate { reason }) = verdict {
        return Err(LoopError::HookTerminated(reason));
    }

    Ok(ContentBlock::ToolResult {
        tool_use_id: call_id.to_owned(),
        content: output.content,
        is_error: output.is_error,
    })
}
/// Starts a builder with the required `provider` and `context`; the tool
/// registry is empty, the configuration is the default one, and no hooks or
/// durable context are set.
#[must_use]
pub fn builder(provider: P, context: C) -> AgentLoopBuilder<P, C> {
    let tools = ToolRegistry::new();
    let config = LoopConfig::default();
    AgentLoopBuilder {
        provider,
        context,
        tools,
        config,
        hooks: Vec::new(),
        durability: None,
    }
}
}
/// Consuming builder for [`AgentLoop`]; every setter takes and returns the builder.
pub struct AgentLoopBuilder<P: Provider, C: ContextStrategy> {
/// LLM provider handed to the built loop.
provider: P,
/// Context strategy handed to the built loop.
context: C,
/// Tool registry handed to the built loop.
tools: ToolRegistry,
/// Loop configuration being assembled.
config: LoopConfig,
/// Hooks registered so far, in registration order.
hooks: Vec<BoxedHook>,
/// Optional durable context.
durability: Option<BoxedDurable>,
}
impl<P: Provider, C: ContextStrategy> AgentLoopBuilder<P, C> {
    /// Replaces the tool registry.
    #[must_use]
    pub fn tools(self, tools: ToolRegistry) -> Self {
        Self { tools, ..self }
    }

    /// Replaces the whole loop configuration, including values set by earlier
    /// configuration setters.
    #[must_use]
    pub fn config(self, config: LoopConfig) -> Self {
        Self { config, ..self }
    }

    /// Sets the system prompt in the configuration.
    #[must_use]
    pub fn system_prompt(self, prompt: impl Into<neuron_types::SystemPrompt>) -> Self {
        let mut config = self.config;
        config.system_prompt = prompt.into();
        Self { config, ..self }
    }

    /// Caps the number of turns at `max`.
    #[must_use]
    pub fn max_turns(self, max: usize) -> Self {
        let mut config = self.config;
        config.max_turns = Some(max);
        Self { config, ..self }
    }

    /// Enables or disables concurrent execution of multiple tool calls.
    #[must_use]
    pub fn parallel_tool_execution(self, parallel: bool) -> Self {
        let mut config = self.config;
        config.parallel_tool_execution = parallel;
        Self { config, ..self }
    }

    /// Sets the usage limits enforced by the loop.
    #[must_use]
    pub fn usage_limits(self, limits: UsageLimits) -> Self {
        let mut config = self.config;
        config.usage_limits = Some(limits);
        Self { config, ..self }
    }

    /// Registers an additional hook after the ones already added.
    #[must_use]
    pub fn hook<H: ObservabilityHook + 'static>(self, hook: H) -> Self {
        let mut hooks = self.hooks;
        hooks.push(BoxedHook::new(hook));
        Self { hooks, ..self }
    }

    /// Sets the durable context, replacing any previous one.
    #[must_use]
    pub fn durability<D: DurableContext + 'static>(self, durable: D) -> Self {
        let durability = Some(BoxedDurable::new(durable));
        Self { durability, ..self }
    }

    /// Finishes the builder, producing a loop with an empty message history.
    #[must_use]
    pub fn build(self) -> AgentLoop<P, C> {
        let Self {
            provider,
            context,
            tools,
            config,
            hooks,
            durability,
        } = self;
        AgentLoop {
            provider,
            tools,
            context,
            hooks,
            durability,
            config,
            messages: Vec::new(),
        }
    }
}
/// Fires `PreLlmCall` on each hook in order.
///
/// Returns the first action that is not `Continue`, or `None` when every hook
/// continues. A hook error becomes `LoopError::HookTerminated` carrying the
/// error's text.
pub(crate) async fn fire_pre_llm_hooks(
    hooks: &[BoxedHook],
    request: &CompletionRequest,
) -> Result<Option<HookAction>, LoopError> {
    for hook in hooks {
        match hook.fire(HookEvent::PreLlmCall { request }).await {
            Ok(HookAction::Continue) => {}
            Ok(decision) => return Ok(Some(decision)),
            Err(err) => return Err(LoopError::HookTerminated(err.to_string())),
        }
    }
    Ok(None)
}
/// Fires `PostLlmCall` on each hook in order.
///
/// Returns the first action that is not `Continue`, or `None` when every hook
/// continues. A hook error becomes `LoopError::HookTerminated` carrying the
/// error's text.
pub(crate) async fn fire_post_llm_hooks(
    hooks: &[BoxedHook],
    response: &CompletionResponse,
) -> Result<Option<HookAction>, LoopError> {
    for hook in hooks {
        match hook.fire(HookEvent::PostLlmCall { response }).await {
            Ok(HookAction::Continue) => {}
            Ok(decision) => return Ok(Some(decision)),
            Err(err) => return Err(LoopError::HookTerminated(err.to_string())),
        }
    }
    Ok(None)
}
/// Fires `PreToolExecution` on each hook in order.
///
/// Returns the first action that is not `Continue`, or `None` when every hook
/// continues. A hook error becomes `LoopError::HookTerminated` carrying the
/// error's text.
pub(crate) async fn fire_pre_tool_hooks(
    hooks: &[BoxedHook],
    tool_name: &str,
    input: &serde_json::Value,
) -> Result<Option<HookAction>, LoopError> {
    for hook in hooks {
        let event = HookEvent::PreToolExecution { tool_name, input };
        match hook.fire(event).await {
            Ok(HookAction::Continue) => {}
            Ok(decision) => return Ok(Some(decision)),
            Err(err) => return Err(LoopError::HookTerminated(err.to_string())),
        }
    }
    Ok(None)
}
/// Fires `PostToolExecution` on each hook in order.
///
/// Returns the first action that is not `Continue`, or `None` when every hook
/// continues. A hook error becomes `LoopError::HookTerminated` carrying the
/// error's text.
pub(crate) async fn fire_post_tool_hooks(
    hooks: &[BoxedHook],
    tool_name: &str,
    output: &ToolOutput,
) -> Result<Option<HookAction>, LoopError> {
    for hook in hooks {
        let event = HookEvent::PostToolExecution { tool_name, output };
        match hook.fire(event).await {
            Ok(HookAction::Continue) => {}
            Ok(decision) => return Ok(Some(decision)),
            Err(err) => return Err(LoopError::HookTerminated(err.to_string())),
        }
    }
    Ok(None)
}
/// Fires `LoopIteration` with the given `turn` on each hook in order.
///
/// Returns the first action that is not `Continue`, or `None` when every hook
/// continues. A hook error becomes `LoopError::HookTerminated` carrying the
/// error's text.
pub(crate) async fn fire_loop_iteration_hooks(
    hooks: &[BoxedHook],
    turn: usize,
) -> Result<Option<HookAction>, LoopError> {
    for hook in hooks {
        match hook.fire(HookEvent::LoopIteration { turn }).await {
            Ok(HookAction::Continue) => {}
            Ok(decision) => return Ok(Some(decision)),
            Err(err) => return Err(LoopError::HookTerminated(err.to_string())),
        }
    }
    Ok(None)
}
/// Fires `ContextCompaction` with the token estimates before and after
/// compaction on each hook in order.
///
/// Returns the first action that is not `Continue`, or `None` when every hook
/// continues. A hook error becomes `LoopError::HookTerminated` carrying the
/// error's text.
pub(crate) async fn fire_compaction_hooks(
    hooks: &[BoxedHook],
    old_tokens: usize,
    new_tokens: usize,
) -> Result<Option<HookAction>, LoopError> {
    for hook in hooks {
        let event = HookEvent::ContextCompaction {
            old_tokens,
            new_tokens,
        };
        match hook.fire(event).await {
            Ok(HookAction::Continue) => {}
            Ok(decision) => return Ok(Some(decision)),
            Err(err) => return Err(LoopError::HookTerminated(err.to_string())),
        }
    }
    Ok(None)
}
/// Fails with `LoopError::UsageLimitExceeded` once `request_count` has reached
/// the configured request limit; succeeds when no limit is set.
pub(crate) fn check_request_limit(
    limits: &UsageLimits,
    request_count: usize,
) -> Result<(), LoopError> {
    match limits.request_limit {
        Some(max) if request_count >= max => Err(LoopError::UsageLimitExceeded(format!(
            "request limit of {max} reached"
        ))),
        _ => Ok(()),
    }
}
/// Checks accumulated token `usage` against the configured input, output and
/// total token limits, in that order.
///
/// A limit is exceeded only when usage is strictly greater than it; limits that
/// are `None` are not enforced.
///
/// # Errors
///
/// Returns `LoopError::UsageLimitExceeded` describing the first limit exceeded.
pub(crate) fn check_token_limits(
    limits: &UsageLimits,
    usage: &TokenUsage,
) -> Result<(), LoopError> {
    if let Some(max) = limits.input_tokens_limit
        && usage.input_tokens > max
    {
        return Err(LoopError::UsageLimitExceeded(format!(
            "input token limit of {max} exceeded (used {})",
            usage.input_tokens
        )));
    }
    if let Some(max) = limits.output_tokens_limit
        && usage.output_tokens > max
    {
        return Err(LoopError::UsageLimitExceeded(format!(
            "output token limit of {max} exceeded (used {})",
            usage.output_tokens
        )));
    }
    if let Some(max) = limits.total_tokens_limit {
        // Saturate rather than wrap: a wrapped sum would look small in release
        // builds and silently bypass the limit (and would panic in debug builds).
        let total = usage.input_tokens.saturating_add(usage.output_tokens);
        if total > max {
            return Err(LoopError::UsageLimitExceeded(format!(
                "total token limit of {max} exceeded (used {total})"
            )));
        }
    }
    Ok(())
}
/// Fails with `LoopError::UsageLimitExceeded` when running `new_calls` more tool
/// calls on top of `current_count` would go past the configured tool-call
/// limit; succeeds when no limit is set.
pub(crate) fn check_tool_call_limit(
    limits: &UsageLimits,
    current_count: usize,
    new_calls: usize,
) -> Result<(), LoopError> {
    match limits.tool_calls_limit {
        Some(max) if current_count + new_calls > max => {
            Err(LoopError::UsageLimitExceeded(format!(
                "tool call limit of {max} would be exceeded ({} + {new_calls} calls)",
                current_count
            )))
        }
        _ => Ok(()),
    }
}
/// Concatenates, in order and without a separator, the text blocks of
/// `message`; all other block kinds are ignored.
pub(crate) fn extract_text(message: &Message) -> String {
    let mut text = String::new();
    for block in &message.content {
        if let ContentBlock::Text(fragment) = block {
            text.push_str(fragment);
        }
    }
    text
}
/// Adds the counters of `delta` onto `total`.
///
/// Input and output tokens are always added. The optional cache-read,
/// cache-creation and reasoning counters are added only when `delta` reports
/// them; a counter still `None` in `total` starts from zero in that case.
///
/// Every addition saturates at the integer maximum, so an absurd value reported
/// by a provider can neither panic (debug builds) nor wrap the running total
/// back to a small number (release builds).
pub(crate) fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
    total.input_tokens = total.input_tokens.saturating_add(delta.input_tokens);
    total.output_tokens = total.output_tokens.saturating_add(delta.output_tokens);
    if let Some(cache_read) = delta.cache_read_tokens {
        let slot = total.cache_read_tokens.get_or_insert(0);
        *slot = slot.saturating_add(cache_read);
    }
    if let Some(cache_creation) = delta.cache_creation_tokens {
        let slot = total.cache_creation_tokens.get_or_insert(0);
        *slot = slot.saturating_add(cache_creation);
    }
    if let Some(reasoning) = delta.reasoning_tokens {
        let slot = total.reasoning_tokens.get_or_insert(0);
        *slot = slot.saturating_add(reasoning);
    }
}