Skip to main content

agent_base/engine/runtime/
mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use crate::llm::LlmClient;
6use crate::skill::{Skill, SkillPrompter};
7use crate::tool::{ToolPolicy, ToolRegistry};
8use crate::types::{AgentConfig, MessageRole, AgentError, CheckpointData, CheckpointStep};
9use tokio::sync::broadcast;
10use tracing::Span;
11
12use crate::types::{AgentResult, AgentEvent, SessionId, RunOutcome};
13use super::approval::ApprovalHandler;
14use super::context::ContextWindowManager;
15use super::middleware::{MiddlewareRef, UserMessageCtx, PreLlmCtx, PostLlmCtx};
16use super::recovery::{ToolErrorAction, ToolErrorRecovery};
17use super::session_store::SessionStore;
18use super::AgentSession;
19
20mod approval_flow;
21mod llm;
22mod tool_exec;
23
24use tool_exec::ToolCallResult;
25
/// Turn limit applied to a run when `AgentConfig::max_turns` is not set.
const DEFAULT_MAX_TURNS: u32 = 50;
27
28pub struct AgentRuntime {
29    pub(crate) client: Arc<dyn LlmClient>,
30    pub(crate) config: AgentConfig,
31    pub(crate) tools: ToolRegistry,
32    pub(crate) approval_handler: Option<Arc<dyn ApprovalHandler>>,
33    pub(crate) tool_policy: Option<Arc<dyn ToolPolicy>>,
34    pub(crate) middlewares: Vec<MiddlewareRef>,
35    pub(crate) event_bus: broadcast::Sender<AgentEvent>,
36    pub(crate) next_session_id: AtomicU64,
37    pub(crate) sessions: HashMap<SessionId, AgentSession>,
38    pub(crate) context_manager: Option<ContextWindowManager>,
39    pub(crate) session_store: Arc<dyn SessionStore>,
40    pub(crate) skills: Vec<Arc<dyn Skill>>,
41    #[allow(dead_code)]
42    pub(crate) skill_prompter: Arc<dyn SkillPrompter>,
43    pub(crate) error_recovery: Arc<dyn ToolErrorRecovery>,
44}
45
46impl AgentRuntime {
47    pub fn create_session(&mut self) -> SessionId {
48        let id = SessionId {
49            id: self.next_session_id.fetch_add(1, Ordering::Relaxed),
50            external_id: None,
51        };
52        let mut session = AgentSession::new(id.clone());
53        if let Some(system_prompt) = self.config.system_prompt.as_deref() {
54            session.push_message(MessageRole::System, system_prompt);
55        }
56        self.sessions.insert(id.clone(), session);
57        id
58    }
59
60    /// Restore an existing session from persistence into runtime memory
61    ///
62    /// On success, the session can be used for continued execution.
63    /// Returns None if not found in persistence layer.
64    pub async fn restore_session(&mut self, session_id: &SessionId) -> Option<&AgentSession> {
65        if self.sessions.contains_key(session_id) {
66            return self.sessions.get(session_id);
67        }
68        match self.session_store.load(session_id).await {
69            Ok(Some(session)) => {
70                self.sessions.insert(session_id.clone(), session);
71                self.sessions.get(session_id)
72            }
73            _ => None,
74        }
75    }
76
77    pub fn session(&self, session_id: &SessionId) -> Option<&AgentSession> {
78        self.sessions.get(session_id)
79    }
80
81    pub fn tools(&self) -> &ToolRegistry {
82        &self.tools
83    }
84
85    pub fn client(&self) -> &Arc<dyn LlmClient> {
86        &self.client
87    }
88
89    pub fn approval_handler(&self) -> Option<&Arc<dyn ApprovalHandler>> {
90        self.approval_handler.as_ref()
91    }
92
93    pub fn tool_policy(&self) -> Option<&Arc<dyn ToolPolicy>> {
94        self.tool_policy.as_ref()
95    }
96
97    pub fn subscribe_events(&self) -> broadcast::Receiver<AgentEvent> {
98        self.event_bus.subscribe()
99    }
100
101    pub fn session_store(&self) -> &Arc<dyn SessionStore> {
102        &self.session_store
103    }
104
105    pub fn skills(&self) -> &[Arc<dyn Skill>] {
106        &self.skills
107    }
108
109    fn cached_approval(&self, session_id: &SessionId, action_key: &str) -> bool {
110        self.sessions
111            .get(session_id)
112            .is_some_and(|session| session.is_action_allowed(action_key))
113    }
114
115    fn cache_approval(&mut self, session_id: &SessionId, action_key: String) {
116        if let Some(session) = self.sessions.get_mut(session_id) {
117            session.allow_action(action_key);
118        }
119    }
120
121    fn emit_event(&self, event: AgentEvent) {
122        let _ = self.event_bus.send(event);
123    }
124
125    fn session_or_err(&self, session_id: &SessionId) -> AgentResult<&AgentSession> {
126        self.sessions
127            .get(session_id)
128            .ok_or_else(|| AgentError::session_not_found(session_id.id))
129    }
130
131    fn session_mut_or_err(&mut self, session_id: &SessionId) -> AgentResult<&mut AgentSession> {
132        self.sessions
133            .get_mut(session_id)
134            .ok_or_else(|| AgentError::session_not_found(session_id.id))
135    }
136
137    fn drain_async_events<F>(
138        event_rx: &mut broadcast::Receiver<AgentEvent>,
139        on_event: &mut F,
140    ) -> AgentResult<()>
141    where
142        F: FnMut(AgentEvent) -> AgentResult<()>,
143    {
144        loop {
145            match event_rx.try_recv() {
146                Ok(event) => on_event(event)?,
147                Err(broadcast::error::TryRecvError::Empty) => break,
148                Err(broadcast::error::TryRecvError::Lagged(_)) => continue,
149                Err(broadcast::error::TryRecvError::Closed) => break,
150            }
151        }
152        Ok(())
153    }
154
155    pub async fn run_turn_with_handler<F>(
156        &mut self,
157        session_id: SessionId,
158        user_input: &str,
159        mut on_event: F,
160    ) -> AgentResult<RunOutcome>
161    where
162        F: FnMut(AgentEvent) -> AgentResult<()>,
163    {
164        let span = Span::current();
165        let _guard = span.enter();
166        tracing::info!(session_id = session_id.id, user_input = %user_input, "agent turn start");
167        drop(_guard);
168
169        let mut event_rx = self.subscribe_events();
170        let tool_definitions = self.tools.definitions();
171
172        let mut user_input_owned = user_input.to_string();
173
174        {
175            let mut ctx = UserMessageCtx {
176                session_id: session_id.clone(),
177                user_input: user_input_owned.clone(),
178                event_bus: self.event_bus.clone(),
179            };
180            for mw in &self.middlewares {
181                mw.on_user_message(&mut ctx).await?;
182            }
183            user_input_owned = ctx.user_input;
184        }
185
186        {
187            let session = self.session_mut_or_err(&session_id)?;
188            session.push_message(MessageRole::User, &user_input_owned);
189        }
190
191        self.emit_event(AgentEvent::Checkpoint {
192            session_id: session_id.clone(),
193            checkpoint: CheckpointData {
194                session_id: session_id.clone(),
195                user_input: user_input_owned.clone(),
196                step: CheckpointStep::AfterUserInput,
197                turn_count: 0,
198            },
199        });
200
201        let max_turns = self.config.max_turns.unwrap_or(DEFAULT_MAX_TURNS);
202        let mut turn_count: u32 = 0;
203
204        loop {
205            turn_count += 1;
206
207            if turn_count > max_turns {
208                self.emit_event(AgentEvent::RunFinished {
209                    session_id: session_id.clone(),
210                });
211                Self::drain_async_events(&mut event_rx, &mut on_event)?;
212                break;
213            }
214
215            Self::drain_async_events(&mut event_rx, &mut on_event)?;
216
217            let turn_span = tracing::info_span!("turn", session_id = session_id.id, turn = turn_count);
218            let _turn_guard = turn_span.enter();
219
220            let mut messages: Vec<_> = self.session_or_err(&session_id)?.chat_messages().to_vec();
221            let mut tools_for_turn = tool_definitions.clone();
222
223            if let Some(ref ctx_mgr) = self.context_manager {
224                ctx_mgr.trim(&mut messages);
225            }
226
227            {
228                let mut ctx = PreLlmCtx {
229                    session_id: session_id.clone(),
230                    messages: messages.clone(),
231                    tools: tools_for_turn.clone(),
232                    event_bus: self.event_bus.clone(),
233                };
234                for mw in &self.middlewares {
235                    mw.on_pre_llm(&mut ctx).await?;
236                }
237                messages = ctx.messages;
238                tools_for_turn = ctx.tools;
239            }
240
241            self.emit_event(AgentEvent::Checkpoint {
242                session_id: session_id.clone(),
243                checkpoint: CheckpointData {
244                    session_id: session_id.clone(),
245                    user_input: user_input_owned.clone(),
246                    step: CheckpointStep::BeforeLlm {
247                        messages: messages.clone(),
248                        tools: tools_for_turn.clone(),
249                    },
250                    turn_count,
251                },
252            });
253
254            let aggregator = self
255                .execute_llm_turn(&session_id, &messages, &tools_for_turn, &mut event_rx, &mut on_event)
256                .await?;
257
258            let (mut full_text, mut is_tool_call, mut tool_calls) = aggregator.into_parts();
259
260            {
261                let mut ctx = PostLlmCtx {
262                    session_id: session_id.clone(),
263                    full_text: full_text.clone(),
264                    is_tool_call,
265                    tool_calls: tool_calls.clone(),
266                    event_bus: self.event_bus.clone(),
267                };
268                for mw in &self.middlewares {
269                    mw.on_post_llm(&mut ctx).await?;
270                }
271                full_text = ctx.full_text;
272                is_tool_call = ctx.is_tool_call;
273                tool_calls = ctx.tool_calls;
274            }
275
276            if full_text.is_empty() && !is_tool_call {
277                continue;
278            }
279
280            if !full_text.is_empty() {
281                let session = self.session_mut_or_err(&session_id)?;
282                session.push_message(MessageRole::Assistant, full_text);
283            }
284
285            if is_tool_call && !tool_calls.is_empty() {
286                self.emit_event(AgentEvent::Checkpoint {
287                    session_id: session_id.clone(),
288                    checkpoint: CheckpointData {
289                        session_id: session_id.clone(),
290                        user_input: user_input_owned.clone(),
291                        step: CheckpointStep::BeforeToolCalls {
292                            tool_calls: tool_calls.clone(),
293                        },
294                        turn_count,
295                    },
296                });
297
298                match self
299                    .handle_tool_calls(
300                        &session_id,
301                        &tool_calls,
302                        &mut event_rx,
303                        &mut on_event,
304                    )
305                    .await
306                {
307                    Ok(ToolCallResult::Continue) => {
308                        self.emit_event(AgentEvent::Checkpoint {
309                            session_id: session_id.clone(),
310                            checkpoint: CheckpointData {
311                                session_id: session_id.clone(),
312                                user_input: user_input_owned.clone(),
313                                step: CheckpointStep::AfterToolCalls {
314                                    tool_calls: tool_calls.clone(),
315                                    results: Vec::new(),
316                                },
317                                turn_count,
318                            },
319                        });
320                        continue;
321                    }
322                    Ok(ToolCallResult::Break) => {
323                        self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
324                        Self::drain_async_events(&mut event_rx, &mut on_event)?;
325                        break;
326                    }
327                    Err(e) => {
328                        if e.is_cancelled() {
329                            return Err(e);
330                        }
331                        let names: Vec<String> = tool_calls.iter().map(|(_, n, _)| n.clone()).collect();
332                        let action = self.error_recovery.on_error(&session_id, &names, &e).await?;
333                        match action {
334                            ToolErrorAction::Stop => {
335                                self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
336                                Self::drain_async_events(&mut event_rx, &mut on_event)?;
337                                let session = self.session_or_err(&session_id)?;
338                                let _ = self.session_store.save(session).await;
339                                return Ok(RunOutcome::Failed {
340                                    error: format!("Tool execution failed: {}", e),
341                                });
342                            }
343                            ToolErrorAction::Retry => {
344                                let session = self.session_mut_or_err(&session_id)?;
345                                session.push_message(
346                                    MessageRole::Assistant,
347                                    format!("(Failed to call tools: {})", names.join(", ")),
348                                );
349                                session.push_message(
350                                    MessageRole::User,
351                                    "Tool calls failed. Please simplify your plan and retry.",
352                                );
353                                continue;
354                            }
355                        }
356                    }
357                }
358            }
359
360            self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
361            Self::drain_async_events(&mut event_rx, &mut on_event)?;
362            break;
363        }
364
365        let outcome = if turn_count > max_turns {
366            RunOutcome::Failed { error: format!("Max turns ({max_turns}) reached, stopping forcibly") }
367        } else {
368            RunOutcome::Completed
369        };
370
371        let session = self.session_or_err(&session_id)?;
372        let _ = self.session_store.save(session).await;
373
374        tracing::info!(session_id = session_id.id, turn_count, "agent turn completed");
375        Ok(outcome)
376    }
377
378    pub async fn run_turn_stream(
379        &mut self,
380        session_id: SessionId,
381        user_input: &str,
382    ) -> AgentResult<(Vec<AgentEvent>, RunOutcome)> {
383        let mut events = Vec::new();
384        let outcome = self.run_turn_with_handler(session_id, user_input, |event| {
385            events.push(event);
386            Ok(())
387        })
388        .await?;
389        Ok((events, outcome))
390    }
391
392    pub async fn resume_from_checkpoint<F>(
393        &mut self,
394        checkpoint: CheckpointData,
395        mut on_event: F,
396    ) -> AgentResult<RunOutcome>
397    where
398        F: FnMut(AgentEvent) -> AgentResult<()>,
399    {
400        let session_id = checkpoint.session_id;
401        let user_input = checkpoint.user_input;
402        let turn_count = checkpoint.turn_count;
403
404        tracing::info!(session_id = session_id.id, turn_count, step = ?checkpoint.step, "resuming from checkpoint");
405
406        let mut event_rx = self.subscribe_events();
407        let tool_definitions = self.tools.definitions();
408        let max_turns = self.config.max_turns.unwrap_or(DEFAULT_MAX_TURNS);
409        let mut turn_count = turn_count;
410
411        match checkpoint.step {
412            CheckpointStep::BeforeToolCalls { tool_calls } => {
413                match self
414                    .handle_tool_calls(
415                        &session_id,
416                        &tool_calls,
417                        &mut event_rx,
418                        &mut on_event,
419                    )
420                    .await
421                {
422                    Ok(ToolCallResult::Continue) => {}
423                    Ok(ToolCallResult::Break) => {
424                        self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
425                        Self::drain_async_events(&mut event_rx, &mut on_event)?;
426                        return Ok(RunOutcome::Completed);
427                    }
428                    Err(e) => {
429                        if e.is_cancelled() {
430                            return Err(e);
431                        }
432                        let names: Vec<String> = tool_calls.iter().map(|(_, n, _)| n.clone()).collect();
433                        let action = self.error_recovery.on_error(&session_id, &names, &e).await?;
434                        match action {
435                            ToolErrorAction::Stop => {
436                                self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
437                                Self::drain_async_events(&mut event_rx, &mut on_event)?;
438                                let session = self.session_or_err(&session_id)?;
439                                let _ = self.session_store.save(session).await;
440                                return Ok(RunOutcome::Failed {
441                                    error: format!("Tool execution failed: {}", e),
442                                });
443                            }
444                            ToolErrorAction::Retry => {
445                                let session = self.session_mut_or_err(&session_id)?;
446                                session.push_message(
447                                    MessageRole::Assistant,
448                                    format!("(Failed to call tools: {})", names.join(", ")),
449                                );
450                                session.push_message(
451                                    MessageRole::User,
452                                    "Tool calls failed. Please simplify your plan and retry.",
453                                );
454                            }
455                        }
456                    }
457                }
458            }
459            _ => {}
460        }
461
462        loop {
463            turn_count += 1;
464
465            if turn_count > max_turns {
466                self.emit_event(AgentEvent::RunFinished {
467                    session_id: session_id.clone(),
468                });
469                Self::drain_async_events(&mut event_rx, &mut on_event)?;
470                break;
471            }
472
473            Self::drain_async_events(&mut event_rx, &mut on_event)?;
474
475            let mut messages: Vec<_> = self.session_or_err(&session_id)?.chat_messages().to_vec();
476            let mut tools_for_turn = tool_definitions.clone();
477
478            if let Some(ref ctx_mgr) = self.context_manager {
479                ctx_mgr.trim(&mut messages);
480            }
481
482            {
483                let mut ctx = PreLlmCtx {
484                    session_id: session_id.clone(),
485                    messages: messages.clone(),
486                    tools: tools_for_turn.clone(),
487                    event_bus: self.event_bus.clone(),
488                };
489                for mw in &self.middlewares {
490                    mw.on_pre_llm(&mut ctx).await?;
491                }
492                messages = ctx.messages;
493                tools_for_turn = ctx.tools;
494            }
495
496            self.emit_event(AgentEvent::Checkpoint {
497                session_id: session_id.clone(),
498                checkpoint: CheckpointData {
499                    session_id: session_id.clone(),
500                    user_input: user_input.clone(),
501                    step: CheckpointStep::BeforeLlm {
502                        messages: messages.clone(),
503                        tools: tools_for_turn.clone(),
504                    },
505                    turn_count,
506                },
507            });
508
509            let aggregator = self
510                .execute_llm_turn(&session_id, &messages, &tools_for_turn, &mut event_rx, &mut on_event)
511                .await?;
512
513            let (mut full_text, mut is_tool_call, mut tool_calls) = aggregator.into_parts();
514
515            {
516                let mut ctx = PostLlmCtx {
517                    session_id: session_id.clone(),
518                    full_text: full_text.clone(),
519                    is_tool_call,
520                    tool_calls: tool_calls.clone(),
521                    event_bus: self.event_bus.clone(),
522                };
523                for mw in &self.middlewares {
524                    mw.on_post_llm(&mut ctx).await?;
525                }
526                full_text = ctx.full_text;
527                is_tool_call = ctx.is_tool_call;
528                tool_calls = ctx.tool_calls;
529            }
530
531            if full_text.is_empty() && !is_tool_call {
532                continue;
533            }
534
535            if !full_text.is_empty() {
536                let session = self.session_mut_or_err(&session_id)?;
537                session.push_message(MessageRole::Assistant, full_text);
538            }
539
540            if is_tool_call && !tool_calls.is_empty() {
541                self.emit_event(AgentEvent::Checkpoint {
542                    session_id: session_id.clone(),
543                    checkpoint: CheckpointData {
544                        session_id: session_id.clone(),
545                        user_input: user_input.clone(),
546                        step: CheckpointStep::BeforeToolCalls {
547                            tool_calls: tool_calls.clone(),
548                        },
549                        turn_count,
550                    },
551                });
552
553                match self
554                    .handle_tool_calls(
555                        &session_id,
556                        &tool_calls,
557                        &mut event_rx,
558                        &mut on_event,
559                    )
560                    .await
561                {
562                    Ok(ToolCallResult::Continue) => {
563                        self.emit_event(AgentEvent::Checkpoint {
564                            session_id: session_id.clone(),
565                            checkpoint: CheckpointData {
566                                session_id: session_id.clone(),
567                                user_input: user_input.clone(),
568                                step: CheckpointStep::AfterToolCalls {
569                                    tool_calls: tool_calls.clone(),
570                                    results: Vec::new(),
571                                },
572                                turn_count,
573                            },
574                        });
575                        continue;
576                    }
577                    Ok(ToolCallResult::Break) => {
578                        self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
579                        Self::drain_async_events(&mut event_rx, &mut on_event)?;
580                        break;
581                    }
582                    Err(e) => {
583                        if e.is_cancelled() {
584                            return Err(e);
585                        }
586                        let names: Vec<String> = tool_calls.iter().map(|(_, n, _)| n.clone()).collect();
587                        let action = self.error_recovery.on_error(&session_id, &names, &e).await?;
588                        match action {
589                            ToolErrorAction::Stop => {
590                                self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
591                                Self::drain_async_events(&mut event_rx, &mut on_event)?;
592                                let session = self.session_or_err(&session_id)?;
593                                let _ = self.session_store.save(session).await;
594                                return Ok(RunOutcome::Failed {
595                                    error: format!("Tool execution failed: {}", e),
596                                });
597                            }
598                            ToolErrorAction::Retry => {
599                                let session = self.session_mut_or_err(&session_id)?;
600                                session.push_message(
601                                    MessageRole::Assistant,
602                                    format!("(Failed to call tools: {})", names.join(", ")),
603                                );
604                                session.push_message(
605                                    MessageRole::User,
606                                    "Tool calls failed. Please simplify your plan and retry.",
607                                );
608                                continue;
609                            }
610                        }
611                    }
612                }
613            }
614
615            self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
616            Self::drain_async_events(&mut event_rx, &mut on_event)?;
617            break;
618        }
619
620        let outcome = if turn_count > max_turns {
621            RunOutcome::Failed { error: format!("Max turns ({max_turns}) reached, stopping forcibly") }
622        } else {
623            RunOutcome::Completed
624        };
625
626        let session = self.session_or_err(&session_id)?;
627        let _ = self.session_store.save(session).await;
628
629        tracing::info!(session_id = session_id.id, turn_count, "agent resume completed");
630        Ok(outcome)
631    }
632}