use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use tokio::sync::broadcast;
use tracing::{Instrument, Span};

use crate::llm::LlmClient;
use crate::skill::{Skill, SkillPrompter};
use crate::tool::{ToolPolicy, ToolRegistry};
use crate::types::{AgentConfig, AgentError, CheckpointData, CheckpointStep, MessageRole};
use crate::types::{AgentEvent, AgentResult, RunOutcome, SessionId};

use super::approval::ApprovalHandler;
use super::context::ContextWindowManager;
use super::middleware::{MiddlewareRef, PostLlmCtx, PreLlmCtx, UserMessageCtx};
use super::recovery::{ToolErrorAction, ToolErrorRecovery};
use super::session_store::SessionStore;
use super::AgentSession;

mod approval_flow;
mod llm;
mod tool_exec;

use tool_exec::ToolCallResult;
25
/// Turn budget for a single run when `AgentConfig::max_turns` is not set.
///
/// Every fresh LLM call consumes one turn; once the budget is exceeded the run ends with
/// `RunOutcome::Failed`.
const DEFAULT_MAX_TURNS: u32 = 50;
27
/// Agent runtime: owns the LLM client, tools, middleware chain, event bus and the in-memory
/// session map, and runs user turns through the LLM / tool-call loop.
pub struct AgentRuntime {
    /// LLM client; exposed through `AgentRuntime::client`.
    pub(crate) client: Arc<dyn LlmClient>,
    /// Runtime configuration. This file reads `system_prompt` (seeded into new sessions) and
    /// `max_turns` (turn budget per run).
    pub(crate) config: AgentConfig,
    /// Tool registry; its `definitions()` are offered to the LLM on every turn.
    pub(crate) tools: ToolRegistry,
    /// Optional approval handler; exposed through `AgentRuntime::approval_handler`.
    pub(crate) approval_handler: Option<Arc<dyn ApprovalHandler>>,
    /// Optional tool policy; exposed through `AgentRuntime::tool_policy`.
    pub(crate) tool_policy: Option<Arc<dyn ToolPolicy>>,
    /// Middleware chain whose user-message, pre-LLM and post-LLM hooks run in order.
    pub(crate) middlewares: Vec<MiddlewareRef>,
    /// Broadcast channel on which all `AgentEvent`s are published.
    pub(crate) event_bus: broadcast::Sender<AgentEvent>,
    /// Counter used by `create_session` to hand out session ids.
    pub(crate) next_session_id: AtomicU64,
    /// Sessions currently held in memory, keyed by id.
    pub(crate) sessions: HashMap<SessionId, AgentSession>,
    /// When set, trims the message history before each LLM call.
    pub(crate) context_manager: Option<ContextWindowManager>,
    /// Persistence backend: `restore_session` loads from it and runs save to it.
    pub(crate) session_store: Arc<dyn SessionStore>,
    /// Registered skills; exposed through `AgentRuntime::skills`.
    pub(crate) skills: Vec<Arc<dyn Skill>>,
    // NOTE(review): nothing in this file reads `skill_prompter`; `allow(dead_code)` silences
    // the warning. Confirm whether another module uses it or it can be removed.
    #[allow(dead_code)]
    pub(crate) skill_prompter: Arc<dyn SkillPrompter>,
    /// Decides whether to stop or retry after a non-cancellation tool execution error.
    pub(crate) error_recovery: Arc<dyn ToolErrorRecovery>,
}
45
46impl AgentRuntime {
47 pub fn create_session(&mut self) -> SessionId {
48 let id = SessionId {
49 id: self.next_session_id.fetch_add(1, Ordering::Relaxed),
50 external_id: None,
51 };
52 let mut session = AgentSession::new(id.clone());
53 if let Some(system_prompt) = self.config.system_prompt.as_deref() {
54 session.push_message(MessageRole::System, system_prompt);
55 }
56 self.sessions.insert(id.clone(), session);
57 id
58 }
59
60 pub async fn restore_session(&mut self, session_id: &SessionId) -> Option<&AgentSession> {
65 if self.sessions.contains_key(session_id) {
66 return self.sessions.get(session_id);
67 }
68 match self.session_store.load(session_id).await {
69 Ok(Some(session)) => {
70 self.sessions.insert(session_id.clone(), session);
71 self.sessions.get(session_id)
72 }
73 _ => None,
74 }
75 }
76
77 pub fn session(&self, session_id: &SessionId) -> Option<&AgentSession> {
78 self.sessions.get(session_id)
79 }
80
81 pub fn tools(&self) -> &ToolRegistry {
82 &self.tools
83 }
84
85 pub fn client(&self) -> &Arc<dyn LlmClient> {
86 &self.client
87 }
88
89 pub fn approval_handler(&self) -> Option<&Arc<dyn ApprovalHandler>> {
90 self.approval_handler.as_ref()
91 }
92
93 pub fn tool_policy(&self) -> Option<&Arc<dyn ToolPolicy>> {
94 self.tool_policy.as_ref()
95 }
96
97 pub fn subscribe_events(&self) -> broadcast::Receiver<AgentEvent> {
98 self.event_bus.subscribe()
99 }
100
101 pub fn session_store(&self) -> &Arc<dyn SessionStore> {
102 &self.session_store
103 }
104
105 pub fn skills(&self) -> &[Arc<dyn Skill>] {
106 &self.skills
107 }
108
109 fn cached_approval(&self, session_id: &SessionId, action_key: &str) -> bool {
110 self.sessions
111 .get(session_id)
112 .is_some_and(|session| session.is_action_allowed(action_key))
113 }
114
115 fn cache_approval(&mut self, session_id: &SessionId, action_key: String) {
116 if let Some(session) = self.sessions.get_mut(session_id) {
117 session.allow_action(action_key);
118 }
119 }
120
121 fn emit_event(&self, event: AgentEvent) {
122 let _ = self.event_bus.send(event);
123 }
124
125 fn session_or_err(&self, session_id: &SessionId) -> AgentResult<&AgentSession> {
126 self.sessions
127 .get(session_id)
128 .ok_or_else(|| AgentError::session_not_found(session_id.id))
129 }
130
131 fn session_mut_or_err(&mut self, session_id: &SessionId) -> AgentResult<&mut AgentSession> {
132 self.sessions
133 .get_mut(session_id)
134 .ok_or_else(|| AgentError::session_not_found(session_id.id))
135 }
136
137 fn drain_async_events<F>(
138 event_rx: &mut broadcast::Receiver<AgentEvent>,
139 on_event: &mut F,
140 ) -> AgentResult<()>
141 where
142 F: FnMut(AgentEvent) -> AgentResult<()>,
143 {
144 loop {
145 match event_rx.try_recv() {
146 Ok(event) => on_event(event)?,
147 Err(broadcast::error::TryRecvError::Empty) => break,
148 Err(broadcast::error::TryRecvError::Lagged(_)) => continue,
149 Err(broadcast::error::TryRecvError::Closed) => break,
150 }
151 }
152 Ok(())
153 }
154
155 pub async fn run_turn_with_handler<F>(
156 &mut self,
157 session_id: SessionId,
158 user_input: &str,
159 mut on_event: F,
160 ) -> AgentResult<RunOutcome>
161 where
162 F: FnMut(AgentEvent) -> AgentResult<()>,
163 {
164 let span = Span::current();
165 let _guard = span.enter();
166 tracing::info!(session_id = session_id.id, user_input = %user_input, "agent turn start");
167 drop(_guard);
168
169 let mut event_rx = self.subscribe_events();
170 let tool_definitions = self.tools.definitions();
171
172 let mut user_input_owned = user_input.to_string();
173
174 {
175 let mut ctx = UserMessageCtx {
176 session_id: session_id.clone(),
177 user_input: user_input_owned.clone(),
178 event_bus: self.event_bus.clone(),
179 };
180 for mw in &self.middlewares {
181 mw.on_user_message(&mut ctx).await?;
182 }
183 user_input_owned = ctx.user_input;
184 }
185
186 {
187 let session = self.session_mut_or_err(&session_id)?;
188 session.push_message(MessageRole::User, &user_input_owned);
189 }
190
191 self.emit_event(AgentEvent::Checkpoint {
192 session_id: session_id.clone(),
193 checkpoint: CheckpointData {
194 session_id: session_id.clone(),
195 user_input: user_input_owned.clone(),
196 step: CheckpointStep::AfterUserInput,
197 turn_count: 0,
198 },
199 });
200
201 let max_turns = self.config.max_turns.unwrap_or(DEFAULT_MAX_TURNS);
202 let mut turn_count: u32 = 0;
203
204 loop {
205 turn_count += 1;
206
207 if turn_count > max_turns {
208 self.emit_event(AgentEvent::RunFinished {
209 session_id: session_id.clone(),
210 });
211 Self::drain_async_events(&mut event_rx, &mut on_event)?;
212 break;
213 }
214
215 Self::drain_async_events(&mut event_rx, &mut on_event)?;
216
217 let turn_span = tracing::info_span!("turn", session_id = session_id.id, turn = turn_count);
218 let _turn_guard = turn_span.enter();
219
220 let mut messages: Vec<_> = self.session_or_err(&session_id)?.chat_messages().to_vec();
221 let mut tools_for_turn = tool_definitions.clone();
222
223 if let Some(ref ctx_mgr) = self.context_manager {
224 ctx_mgr.trim(&mut messages);
225 }
226
227 {
228 let mut ctx = PreLlmCtx {
229 session_id: session_id.clone(),
230 messages: messages.clone(),
231 tools: tools_for_turn.clone(),
232 event_bus: self.event_bus.clone(),
233 };
234 for mw in &self.middlewares {
235 mw.on_pre_llm(&mut ctx).await?;
236 }
237 messages = ctx.messages;
238 tools_for_turn = ctx.tools;
239 }
240
241 self.emit_event(AgentEvent::Checkpoint {
242 session_id: session_id.clone(),
243 checkpoint: CheckpointData {
244 session_id: session_id.clone(),
245 user_input: user_input_owned.clone(),
246 step: CheckpointStep::BeforeLlm {
247 messages: messages.clone(),
248 tools: tools_for_turn.clone(),
249 },
250 turn_count,
251 },
252 });
253
254 let aggregator = self
255 .execute_llm_turn(&session_id, &messages, &tools_for_turn, &mut event_rx, &mut on_event)
256 .await?;
257
258 let (mut full_text, mut is_tool_call, mut tool_calls) = aggregator.into_parts();
259
260 {
261 let mut ctx = PostLlmCtx {
262 session_id: session_id.clone(),
263 full_text: full_text.clone(),
264 is_tool_call,
265 tool_calls: tool_calls.clone(),
266 event_bus: self.event_bus.clone(),
267 };
268 for mw in &self.middlewares {
269 mw.on_post_llm(&mut ctx).await?;
270 }
271 full_text = ctx.full_text;
272 is_tool_call = ctx.is_tool_call;
273 tool_calls = ctx.tool_calls;
274 }
275
276 if full_text.is_empty() && !is_tool_call {
277 continue;
278 }
279
280 if !full_text.is_empty() {
281 let session = self.session_mut_or_err(&session_id)?;
282 session.push_message(MessageRole::Assistant, full_text);
283 }
284
285 if is_tool_call && !tool_calls.is_empty() {
286 self.emit_event(AgentEvent::Checkpoint {
287 session_id: session_id.clone(),
288 checkpoint: CheckpointData {
289 session_id: session_id.clone(),
290 user_input: user_input_owned.clone(),
291 step: CheckpointStep::BeforeToolCalls {
292 tool_calls: tool_calls.clone(),
293 },
294 turn_count,
295 },
296 });
297
298 match self
299 .handle_tool_calls(
300 &session_id,
301 &tool_calls,
302 &mut event_rx,
303 &mut on_event,
304 )
305 .await
306 {
307 Ok(ToolCallResult::Continue) => {
308 self.emit_event(AgentEvent::Checkpoint {
309 session_id: session_id.clone(),
310 checkpoint: CheckpointData {
311 session_id: session_id.clone(),
312 user_input: user_input_owned.clone(),
313 step: CheckpointStep::AfterToolCalls {
314 tool_calls: tool_calls.clone(),
315 results: Vec::new(),
316 },
317 turn_count,
318 },
319 });
320 continue;
321 }
322 Ok(ToolCallResult::Break) => {
323 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
324 Self::drain_async_events(&mut event_rx, &mut on_event)?;
325 break;
326 }
327 Err(e) => {
328 if e.is_cancelled() {
329 return Err(e);
330 }
331 let names: Vec<String> = tool_calls.iter().map(|(_, n, _)| n.clone()).collect();
332 let action = self.error_recovery.on_error(&session_id, &names, &e).await?;
333 match action {
334 ToolErrorAction::Stop => {
335 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
336 Self::drain_async_events(&mut event_rx, &mut on_event)?;
337 let session = self.session_or_err(&session_id)?;
338 let _ = self.session_store.save(session).await;
339 return Ok(RunOutcome::Failed {
340 error: format!("Tool execution failed: {}", e),
341 });
342 }
343 ToolErrorAction::Retry => {
344 let session = self.session_mut_or_err(&session_id)?;
345 session.push_message(
346 MessageRole::Assistant,
347 format!("(Failed to call tools: {})", names.join(", ")),
348 );
349 session.push_message(
350 MessageRole::User,
351 "Tool calls failed. Please simplify your plan and retry.",
352 );
353 continue;
354 }
355 }
356 }
357 }
358 }
359
360 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
361 Self::drain_async_events(&mut event_rx, &mut on_event)?;
362 break;
363 }
364
365 let outcome = if turn_count > max_turns {
366 RunOutcome::Failed { error: format!("Max turns ({max_turns}) reached, stopping forcibly") }
367 } else {
368 RunOutcome::Completed
369 };
370
371 let session = self.session_or_err(&session_id)?;
372 let _ = self.session_store.save(session).await;
373
374 tracing::info!(session_id = session_id.id, turn_count, "agent turn completed");
375 Ok(outcome)
376 }
377
378 pub async fn run_turn_stream(
379 &mut self,
380 session_id: SessionId,
381 user_input: &str,
382 ) -> AgentResult<(Vec<AgentEvent>, RunOutcome)> {
383 let mut events = Vec::new();
384 let outcome = self.run_turn_with_handler(session_id, user_input, |event| {
385 events.push(event);
386 Ok(())
387 })
388 .await?;
389 Ok((events, outcome))
390 }
391
392 pub async fn resume_from_checkpoint<F>(
393 &mut self,
394 checkpoint: CheckpointData,
395 mut on_event: F,
396 ) -> AgentResult<RunOutcome>
397 where
398 F: FnMut(AgentEvent) -> AgentResult<()>,
399 {
400 let session_id = checkpoint.session_id;
401 let user_input = checkpoint.user_input;
402 let turn_count = checkpoint.turn_count;
403
404 tracing::info!(session_id = session_id.id, turn_count, step = ?checkpoint.step, "resuming from checkpoint");
405
406 let mut event_rx = self.subscribe_events();
407 let tool_definitions = self.tools.definitions();
408 let max_turns = self.config.max_turns.unwrap_or(DEFAULT_MAX_TURNS);
409 let mut turn_count = turn_count;
410
411 match checkpoint.step {
412 CheckpointStep::BeforeToolCalls { tool_calls } => {
413 match self
414 .handle_tool_calls(
415 &session_id,
416 &tool_calls,
417 &mut event_rx,
418 &mut on_event,
419 )
420 .await
421 {
422 Ok(ToolCallResult::Continue) => {}
423 Ok(ToolCallResult::Break) => {
424 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
425 Self::drain_async_events(&mut event_rx, &mut on_event)?;
426 return Ok(RunOutcome::Completed);
427 }
428 Err(e) => {
429 if e.is_cancelled() {
430 return Err(e);
431 }
432 let names: Vec<String> = tool_calls.iter().map(|(_, n, _)| n.clone()).collect();
433 let action = self.error_recovery.on_error(&session_id, &names, &e).await?;
434 match action {
435 ToolErrorAction::Stop => {
436 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
437 Self::drain_async_events(&mut event_rx, &mut on_event)?;
438 let session = self.session_or_err(&session_id)?;
439 let _ = self.session_store.save(session).await;
440 return Ok(RunOutcome::Failed {
441 error: format!("Tool execution failed: {}", e),
442 });
443 }
444 ToolErrorAction::Retry => {
445 let session = self.session_mut_or_err(&session_id)?;
446 session.push_message(
447 MessageRole::Assistant,
448 format!("(Failed to call tools: {})", names.join(", ")),
449 );
450 session.push_message(
451 MessageRole::User,
452 "Tool calls failed. Please simplify your plan and retry.",
453 );
454 }
455 }
456 }
457 }
458 }
459 _ => {}
460 }
461
462 loop {
463 turn_count += 1;
464
465 if turn_count > max_turns {
466 self.emit_event(AgentEvent::RunFinished {
467 session_id: session_id.clone(),
468 });
469 Self::drain_async_events(&mut event_rx, &mut on_event)?;
470 break;
471 }
472
473 Self::drain_async_events(&mut event_rx, &mut on_event)?;
474
475 let mut messages: Vec<_> = self.session_or_err(&session_id)?.chat_messages().to_vec();
476 let mut tools_for_turn = tool_definitions.clone();
477
478 if let Some(ref ctx_mgr) = self.context_manager {
479 ctx_mgr.trim(&mut messages);
480 }
481
482 {
483 let mut ctx = PreLlmCtx {
484 session_id: session_id.clone(),
485 messages: messages.clone(),
486 tools: tools_for_turn.clone(),
487 event_bus: self.event_bus.clone(),
488 };
489 for mw in &self.middlewares {
490 mw.on_pre_llm(&mut ctx).await?;
491 }
492 messages = ctx.messages;
493 tools_for_turn = ctx.tools;
494 }
495
496 self.emit_event(AgentEvent::Checkpoint {
497 session_id: session_id.clone(),
498 checkpoint: CheckpointData {
499 session_id: session_id.clone(),
500 user_input: user_input.clone(),
501 step: CheckpointStep::BeforeLlm {
502 messages: messages.clone(),
503 tools: tools_for_turn.clone(),
504 },
505 turn_count,
506 },
507 });
508
509 let aggregator = self
510 .execute_llm_turn(&session_id, &messages, &tools_for_turn, &mut event_rx, &mut on_event)
511 .await?;
512
513 let (mut full_text, mut is_tool_call, mut tool_calls) = aggregator.into_parts();
514
515 {
516 let mut ctx = PostLlmCtx {
517 session_id: session_id.clone(),
518 full_text: full_text.clone(),
519 is_tool_call,
520 tool_calls: tool_calls.clone(),
521 event_bus: self.event_bus.clone(),
522 };
523 for mw in &self.middlewares {
524 mw.on_post_llm(&mut ctx).await?;
525 }
526 full_text = ctx.full_text;
527 is_tool_call = ctx.is_tool_call;
528 tool_calls = ctx.tool_calls;
529 }
530
531 if full_text.is_empty() && !is_tool_call {
532 continue;
533 }
534
535 if !full_text.is_empty() {
536 let session = self.session_mut_or_err(&session_id)?;
537 session.push_message(MessageRole::Assistant, full_text);
538 }
539
540 if is_tool_call && !tool_calls.is_empty() {
541 self.emit_event(AgentEvent::Checkpoint {
542 session_id: session_id.clone(),
543 checkpoint: CheckpointData {
544 session_id: session_id.clone(),
545 user_input: user_input.clone(),
546 step: CheckpointStep::BeforeToolCalls {
547 tool_calls: tool_calls.clone(),
548 },
549 turn_count,
550 },
551 });
552
553 match self
554 .handle_tool_calls(
555 &session_id,
556 &tool_calls,
557 &mut event_rx,
558 &mut on_event,
559 )
560 .await
561 {
562 Ok(ToolCallResult::Continue) => {
563 self.emit_event(AgentEvent::Checkpoint {
564 session_id: session_id.clone(),
565 checkpoint: CheckpointData {
566 session_id: session_id.clone(),
567 user_input: user_input.clone(),
568 step: CheckpointStep::AfterToolCalls {
569 tool_calls: tool_calls.clone(),
570 results: Vec::new(),
571 },
572 turn_count,
573 },
574 });
575 continue;
576 }
577 Ok(ToolCallResult::Break) => {
578 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
579 Self::drain_async_events(&mut event_rx, &mut on_event)?;
580 break;
581 }
582 Err(e) => {
583 if e.is_cancelled() {
584 return Err(e);
585 }
586 let names: Vec<String> = tool_calls.iter().map(|(_, n, _)| n.clone()).collect();
587 let action = self.error_recovery.on_error(&session_id, &names, &e).await?;
588 match action {
589 ToolErrorAction::Stop => {
590 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
591 Self::drain_async_events(&mut event_rx, &mut on_event)?;
592 let session = self.session_or_err(&session_id)?;
593 let _ = self.session_store.save(session).await;
594 return Ok(RunOutcome::Failed {
595 error: format!("Tool execution failed: {}", e),
596 });
597 }
598 ToolErrorAction::Retry => {
599 let session = self.session_mut_or_err(&session_id)?;
600 session.push_message(
601 MessageRole::Assistant,
602 format!("(Failed to call tools: {})", names.join(", ")),
603 );
604 session.push_message(
605 MessageRole::User,
606 "Tool calls failed. Please simplify your plan and retry.",
607 );
608 continue;
609 }
610 }
611 }
612 }
613 }
614
615 self.emit_event(AgentEvent::RunFinished { session_id: session_id.clone() });
616 Self::drain_async_events(&mut event_rx, &mut on_event)?;
617 break;
618 }
619
620 let outcome = if turn_count > max_turns {
621 RunOutcome::Failed { error: format!("Max turns ({max_turns}) reached, stopping forcibly") }
622 } else {
623 RunOutcome::Completed
624 };
625
626 let session = self.session_or_err(&session_id)?;
627 let _ = self.session_store.save(session).await;
628
629 tracing::info!(session_id = session_id.id, turn_count, "agent resume completed");
630 Ok(outcome)
631 }
632}