oxi_agent/agent.rs
//! Core agent implementation.
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10 CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
/// Trait for resolving providers and models within an Agent.
///
/// This abstracts away global static registries, allowing SDK users
/// to provide isolated provider/model lookups.
///
/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
/// When using `Agent::new()` directly, a global fallback is used.
///
/// Both lookups return `None` for unknown names; `Agent::switch_model`
/// turns that into a "not found" error.
pub trait ProviderResolver: Send + Sync + 'static {
    /// Resolve a provider by name, returning an Arc handle, or `None`
    /// if no provider is registered under `name`.
    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;

    /// Resolve a model ID ("provider/model" or bare "model") to a Model,
    /// or `None` if the ID is not known to this resolver.
    fn resolve_model(&self, model_id: &str) -> Option<Model>;
}
32
33/// Global provider resolver — uses `oxi_ai` global functions.
34///
35/// This is the default resolver when using `Agent::new()`, preserving
36/// backward compatibility with existing CLI usage.
37pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41 oxi_ai::get_provider(name).map(Arc::from)
42 }
43
44 fn resolve_model(&self, model_id: &str) -> Option<Model> {
45 crate::model_id::resolve_model_from_id(model_id)
46 }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
/// Mutable agent internals protected by a read-write lock.
struct AgentInner {
    /// Active agent configuration (model ID, API key, system prompt,
    /// compaction strategy, ...). Mutated by the `switch_*` / `set_*` methods.
    config: AgentConfig,
    /// Provider handed to the agent loop at the start of each run.
    provider: Arc<dyn Provider>,
    /// Side-dispatch closures invoked for every `AgentEvent` emitted by
    /// the agent run methods. Used by `oxi-sdk` to bridge observability
    /// types (Tracer, CostTracker, ...) into the agent loop without
    /// leaking SDK types into `oxi-agent`.
    ///
    /// Kept behind its own `Mutex` so handlers can be registered through
    /// a shared reference to `AgentInner`. The list is mutated only by
    /// `add_observability_dispatch` and is cloned once at the start of a
    /// run (see `run_with_channel_inner`), so the emit path itself never
    /// takes this lock.
    observability_dispatch: parking_lot::Mutex<Vec<EventDispatchFn>>,
}
65
/// Type alias for an observability dispatch handler. Each entry is a
/// closure registered via [`Agent::add_observability_dispatch`] and
/// invoked on every emitted `AgentEvent`. Named to keep the
/// [`AgentInner`] field readable without spelling out the `dyn` type
/// inline. The `Arc` lets a run clone the handler list cheaply.
type EventDispatchFn = Arc<dyn Fn(AgentEvent) + Send + Sync>;
71
72impl Clone for AgentInner {
73 fn clone(&self) -> Self {
74 Self {
75 config: self.config.clone(),
76 provider: Arc::clone(&self.provider),
77 // The dispatch list is *not* cloned: each `Agent` instance has
78 // its own observers. Cloning the AgentInner (rare; happens in
79 // `run_with_channel_inner` when sharing config across loops)
80 // gives the new loop an empty observer set, which is correct:
81 // the *Agent* retains the original dispatch list, and the
82 // temporary inner clone is discarded after the run.
83 observability_dispatch: parking_lot::Mutex::new(Vec::new()),
84 }
85 }
86}
/// Deferred model switch request, stored when the agent is running.
///
/// Written by `switch_model` / `switch_to_model` while a run is in
/// progress and consumed by `apply_pending_model_switch` once the run
/// finishes.
struct PendingModelSwitch {
    /// Model ID the agent is switching to (used for logging when applied).
    model_id: String,
    /// Provider to install once the running loop has finished.
    provider: Arc<dyn Provider>,
    /// Whether messages need cross-provider transformation.
    needs_transform: bool,
    /// API the conversation history is transformed from.
    old_api: oxi_ai::Api,
    /// API the conversation history is transformed to.
    new_api: oxi_ai::Api,
}
105
/// Agent runtime.
///
/// Manages provider, tool registry, state, and compaction, providing an
/// agentic loop for prompt execution, model switching, tool calls, and fallback.
///
/// Supports session continuation, tokio-native event streaming, and deferred
/// model switching (changes are queued while a loop is running and applied
/// after it completes).
// NOTE(review): the reason for allowing `missing_docs` here is not visible
// in this file — confirm whether it is still needed.
#[allow(missing_docs)]
pub struct Agent {
    /// Config, provider and observability handlers, behind one read-write lock.
    inner: RwLock<AgentInner>,
    /// Tool registry shared with every agent loop started by this agent.
    tools: Arc<ToolRegistry>,
    /// Conversation state; copied into each loop and synced back after a
    /// successful run.
    state: SharedState,
    /// Compaction manager built at construction time. It keeps the
    /// construction-time strategy; the per-run strategy is read from
    /// `inner.config` instead.
    compaction_manager: CompactionManager,
    /// User-supplied hooks (steering, follow-up, should-stop), replaced
    /// wholesale by `set_hooks`.
    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
    /// Guard: true while a run is in progress. Prevents concurrent runs.
    is_running: Arc<AtomicBool>,
    /// Provider/model resolver. Uses global functions by default,
    /// or a custom resolver when created via `new_with_resolver()`.
    resolver: Arc<dyn ProviderResolver>,
    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
    /// propagated to AgentLoop's `external_stop` during each run.
    cancel_flag: Arc<AtomicBool>,
    /// Pending model switch — stored when the agent is running,
    /// applied after the current loop completes.
    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
}
133
134impl Agent {
135 /// Create a new agent with the given provider, config, and tool registry.
136 ///
137 /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
138 /// for model switching. For isolated instances, use [`new_with_resolver`].
139 ///
140 /// [`new_with_resolver`]: Agent::new_with_resolver
141 pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
142 let resolver = Arc::new(GlobalProviderResolver);
143 Self::build_inner(provider, config, tools, resolver)
144 }
145
    /// Create an agent with a custom provider/model resolver.
    ///
    /// This is the preferred constructor for SDK usage where provider
    /// and model registries must be isolated from global state. The
    /// resolver is used both for the initial compactor setup and for
    /// every later `switch_model` / `switch_to_model` call.
    pub fn new_with_resolver(
        provider: Arc<dyn Provider>,
        config: AgentConfig,
        tools: Arc<ToolRegistry>,
        resolver: Arc<dyn ProviderResolver>,
    ) -> Self {
        Self::build_inner(provider, config, tools, resolver)
    }
158
159 /// Create an agent with an empty tool registry.
160 pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
161 Self::new(provider, config, Arc::new(ToolRegistry::new()))
162 }
163
    /// Acquire a read guard over the agent internals.
    ///
    /// Despite the name, the guard exposes the whole [`AgentInner`]
    /// (config, provider and dispatch list), not just the config.
    /// Keep the guard short-lived: writers block while it is held.
    fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
        self.inner.read()
    }
168
    /// Get a write guard for the agent inner state.
    ///
    /// Holding the guard excludes all readers and other writers, so it
    /// should be dropped as soon as the mutation is done.
    fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
        self.inner.write()
    }
173
174 /// Get the current model ID
175 pub fn model_id(&self) -> String {
176 self.config().config.model_id.clone()
177 }
178
179 /// Get the agent configuration (full clone)
180 pub fn get_config(&self) -> AgentConfig {
181 self.config().config.clone()
182 }
183
184 /// Internal constructor shared by `new()` and `new_with_resolver()`.
185 fn build_inner(
186 provider: Arc<dyn Provider>,
187 config: AgentConfig,
188 tools: Arc<ToolRegistry>,
189 resolver: Arc<dyn ProviderResolver>,
190 ) -> Self {
191 let mut compaction_manager =
192 CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
193
194 // Pre-initialize the LLM compactor if compaction is enabled
195 if config.compaction_strategy != CompactionStrategy::Disabled {
196 let model = resolver.resolve_model(&config.model_id);
197
198 if let Some(model) = model {
199 let llm_compactor =
200 Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
201 compaction_manager.set_compactor(llm_compactor);
202 }
203 }
204
205 Self {
206 inner: RwLock::new(AgentInner {
207 config,
208 provider,
209 observability_dispatch: parking_lot::Mutex::new(Vec::new()),
210 }),
211 tools,
212 state: SharedState::new(),
213 compaction_manager,
214 hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
215 is_running: Arc::new(AtomicBool::new(false)),
216 resolver,
217 cancel_flag: Arc::new(AtomicBool::new(false)),
218 pending_model_switch: RwLock::new(None),
219 }
220 }
221
    /// Get a reference to the provider resolver.
    ///
    /// This is the resolver chosen at construction time; clone the `Arc`
    /// to share it.
    pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
        &self.resolver
    }
226
227 /// Switch the model used for future LLM calls.
228 ///
229 /// Switch model mid-conversation.
230 ///
231 /// If the agent is currently running, the switch is deferred: the new
232 /// model and provider are stored in `pending_model_switch` and applied
233 /// automatically when the current loop finishes. This ensures the
234 /// running loop completes with a consistent provider/model without
235 /// interruption.
236 ///
237 /// If the agent is idle, the switch takes effect immediately.
238 ///
239 /// If the new model uses a different provider API, the conversation
240 /// history is automatically transformed for cross-provider compatibility
241 /// (e.g. thinking blocks are converted to `<thinking>` tags).
242 ///
243 /// # Arguments
244 /// * `model_id` - New model ID in `provider/model` format
245 /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
246 ///
247 /// # Returns
248 /// `Ok(())` on success, or an error if the model/provider is unknown
249 pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
250 let new_model = self
251 .resolver
252 .resolve_model(model_id)
253 .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
254
255 // Create the new provider via resolver
256 let new_provider = self
257 .resolver
258 .resolve_provider(&new_model.provider)
259 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
260
261 // Detect API change
262 let (old_api, needs_transform) = {
263 let inner = self.config();
264 let old_api = self
265 .resolver
266 .resolve_model(&inner.config.model_id)
267 .map(|m| m.api)
268 .unwrap_or(oxi_ai::Api::AnthropicMessages);
269 (old_api, old_api != new_model.api)
270 };
271
272 // If the agent is currently running, defer the switch.
273 if self.is_running.load(Ordering::SeqCst) {
274 tracing::info!(
275 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
276 model_id
277 );
278 *self.pending_model_switch.write() = Some(PendingModelSwitch {
279 model_id: model_id.to_string(),
280 provider: new_provider,
281 needs_transform,
282 old_api,
283 new_api: new_model.api,
284 });
285 // Update config immediately so model_id() returns the new value,
286 // but leave provider unchanged so the running loop keeps its provider.
287 {
288 let mut inner = self.inner_mut();
289 inner.config.model_id = model_id.to_string();
290 inner.config.api_key = api_key;
291 }
292 return Ok(());
293 }
294
295 // Agent is idle — apply immediately.
296 if needs_transform {
297 let messages = self.state.get_state().messages.clone();
298 let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
299 self.state.update(|s| {
300 s.replace_messages(transformed);
301 });
302 }
303
304 // Update config and provider atomically
305 let mut inner = self.inner_mut();
306 inner.config.model_id = model_id.to_string();
307 inner.config.api_key = api_key;
308 inner.provider = new_provider;
309
310 Ok(())
311 }
312
313 /// Switch the model using a pre-resolved `Model` object.
314 ///
315 /// This is useful when the caller has already looked up the model
316 /// and optionally created the provider.
317 ///
318 /// Like [`switch_model`], if the agent is currently running, the switch
319 /// is deferred until the current loop completes.
320 ///
321 /// [`switch_model`]: Agent::switch_model
322 pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
323 let model_id = format!("{}/{}", model.provider, model.id);
324 let new_provider = self
325 .resolver
326 .resolve_provider(&model.provider)
327 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
328
329 // Detect API change
330 let (old_api, needs_transform) = {
331 let inner = self.config();
332 let old_api = self
333 .resolver
334 .resolve_model(&inner.config.model_id)
335 .map(|m| m.api)
336 .unwrap_or(oxi_ai::Api::AnthropicMessages);
337 (old_api, old_api != model.api)
338 };
339
340 // If the agent is currently running, defer the switch.
341 if self.is_running.load(Ordering::SeqCst) {
342 tracing::info!(
343 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
344 model_id
345 );
346 *self.pending_model_switch.write() = Some(PendingModelSwitch {
347 model_id: model_id.clone(),
348 provider: new_provider,
349 needs_transform,
350 old_api,
351 new_api: model.api,
352 });
353 let mut inner = self.inner_mut();
354 inner.config.model_id = model_id;
355 inner.config.api_key = api_key;
356 return Ok(());
357 }
358
359 // Agent is idle — apply immediately.
360 if needs_transform {
361 let messages = self.state.get_state().messages.clone();
362 let transformed = transform_for_provider(&messages, &old_api, &model.api);
363 self.state.update(|s| {
364 s.replace_messages(transformed);
365 });
366 }
367
368 let mut inner = self.inner_mut();
369 inner.config.model_id = model_id;
370 inner.config.api_key = api_key;
371 inner.provider = new_provider;
372
373 Ok(())
374 }
375
376 /// Refresh only the API key without changing model or provider.
377 /// Useful when the user stores a key after the session was already created.
378 pub fn refresh_api_key(&self, api_key: Option<String>) {
379 let mut inner = self.inner_mut();
380 inner.config.api_key = api_key;
381 }
382
    /// Get a handle to the tool registry.
    ///
    /// The returned `Arc` points at the same registry the agent uses, so
    /// tools registered through it are visible to later runs.
    pub fn tools(&self) -> Arc<ToolRegistry> {
        Arc::clone(&self.tools)
    }
387
    /// Get a snapshot of the current agent state.
    ///
    /// The value is owned by the caller; mutating it does not affect the
    /// agent. Use [`Agent::update_state`] to change the live state.
    pub fn state(&self) -> AgentState {
        self.state.get_state()
    }
392
    /// Update agent state in-place. Used by compaction to replace messages.
    ///
    /// `f` receives a mutable reference to the live state and is run by
    /// the underlying `SharedState::update`.
    pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
        self.state.update(f);
    }
397
    /// Reset agent state for a new conversation.
    ///
    /// Only the conversation state is reset; config, provider, tools,
    /// hooks and any pending model switch are left as they are.
    pub fn reset(&self) {
        self.state.reset();
    }
402
    /// Register a tool that the agent can invoke during a run.
    ///
    /// The tool is added to the shared registry, which is handed to each
    /// agent loop when a run starts.
    pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
        self.tools.register(tool);
    }
407
408 /// Update the system prompt for future interactions.
409 pub fn set_system_prompt(&self, prompt: String) {
410 self.inner_mut().config.system_prompt = Some(prompt);
411 }
412
    /// Get the compaction manager.
    ///
    /// This is the manager built at construction time; it is not rebuilt
    /// by `set_compaction_strategy` or by model switches.
    pub fn compaction_manager(&self) -> &CompactionManager {
        &self.compaction_manager
    }
417 /// Update the compaction strategy for future runs.
418 ///
419 /// The strategy is read fresh from the config at the start of each run
420 /// (see `run_with_channel_inner`), so this takes effect on the next
421 /// agent turn — never mid-run. Pair with `compaction_manager()` for
422 /// manual compaction, which is unaffected by the strategy.
423 pub fn set_compaction_strategy(&self, strategy: oxi_ai::CompactionStrategy) {
424 self.inner.write().config.compaction_strategy = strategy;
425 }
426 /// Get the compaction strategy that will be used on the next run.
427 ///
428 /// This reads from `inner.config` (mutable via `set_compaction_strategy`),
429 /// **not** from the `compaction_manager` field (which retains its
430 /// construction-time strategy). The agent loop reads from config fresh
431 /// each run, so this is the authoritative value.
432 pub fn compaction_strategy(&self) -> oxi_ai::CompactionStrategy {
433 self.inner.read().config.compaction_strategy.clone()
434 }
435
436 /// Run the agent with a prompt, collecting all events into a vector.
437 ///
438 /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
439 /// [`AgentEvent`] produced during the run.
440 pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
441 let mut events = Vec::new();
442 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
443 let result = self.run_with_channel(prompt, tx).await;
444 while let Ok(event) = rx.recv() {
445 events.push(event);
446 }
447 result.map(|r| (r, events))
448 }
449
450 /// Run the agent, delivering events through the provided channel.
451 ///
452 /// Delegates to the agent loop which implements the same 2-level agentic
453 /// loop matching pi-mono's architecture:
454 ///
455 /// ```text
456 /// AgentLoop.run_messages()
457 /// Outer loop (follow-up messages):
458 /// Inner loop (tool calls + steering):
459 /// 1. Inject pending messages (steering)
460 /// 2. Compaction check
461 /// 3. Stream LLM response (with accumulated partial messages)
462 /// 4. Execute tool calls if any
463 /// 5. Emit turn_end
464 /// 6. Check shouldStopAfterTurn
465 /// 7. Poll steering messages
466 /// Check follow-up messages
467 /// Exit
468 /// ```
469 pub async fn run_with_channel(
470 &self,
471 prompt: String,
472 tx: std::sync::mpsc::Sender<AgentEvent>,
473 ) -> Result<Response> {
474 // pi-mono: Agent.prompt() throws if activeRun exists.
475 // Prevent concurrent runs that would corrupt shared state.
476 if self
477 .is_running
478 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
479 .is_err()
480 {
481 return Err(Error::msg("Agent is already running"));
482 }
483
484 // Drop guard ensures is_running is cleared even on panic.
485 struct RunningGuard<'a>(&'a AtomicBool);
486 impl Drop for RunningGuard<'_> {
487 fn drop(&mut self) {
488 self.0.store(false, Ordering::SeqCst);
489 }
490 }
491 let _guard = RunningGuard(&self.is_running);
492 self.reset_cancel();
493
494 self.run_with_channel_inner(prompt, tx).await
495 }
496
497 /// Inner implementation of run_with_channel, called after the running guard is set.
498 async fn run_with_channel_inner(
499 &self,
500 prompt: String,
501 tx: std::sync::mpsc::Sender<AgentEvent>,
502 ) -> Result<Response> {
503 use crate::agent_loop::AgentLoop;
504
505 let (
506 provider,
507 system_prompt,
508 temperature,
509 max_tokens,
510 compaction_strategy,
511 context_window,
512 api_key,
513 workspace_dir,
514 ) = {
515 let inner = self.inner.read();
516 (
517 Arc::clone(&inner.provider) as Arc<dyn Provider>,
518 inner.config.system_prompt.clone(),
519 inner.config.temperature,
520 inner.config.max_tokens,
521 inner.config.compaction_strategy.clone(),
522 inner.config.context_window,
523 inner.config.api_key.clone(),
524 inner.config.workspace_dir.clone(),
525 )
526 }; // release read lock
527
528 // Build AgentLoopConfig from Agent's config
529 let loop_config = crate::agent_loop::config::AgentLoopConfig {
530 model_id: self.model_id(),
531 system_prompt,
532 temperature: temperature.unwrap_or(1.0) as f32,
533 max_tokens: max_tokens.unwrap_or(4096) as u32,
534 tool_execution: crate::config::ToolExecutionMode::Sequential,
535 compaction_strategy,
536 compaction_instruction: None,
537 context_window,
538 session_id: self.config().config.session_id.clone(),
539 transport: None,
540 compact_on_start: false,
541 max_retry_delay_ms: None,
542 auto_retry_enabled: true,
543 auto_retry_max_attempts: 3,
544 auto_retry_base_delay_ms: 1000,
545 api_key,
546 workspace_dir,
547 provider_options: self.config().config.provider_options.clone(),
548 on_compaction: None,
549 ttsr_engine: self.config().config.ttsr_engine.clone(),
550 memory: self.config().config.memory.clone(),
551 todo: self.config().config.todo.clone(),
552 agent_pool: self.config().config.agent_pool.clone(),
553 max_tool_result_bytes: self.config().config.max_tool_result_bytes,
554 subagent_runner: self.config().config.subagent_runner.clone(),
555 subagent_depth: self.config().config.subagent_depth,
556 ..Default::default()
557 };
558
559 // Create AgentLoop. We give it a NEW SharedState and sync back after.
560 // (SharedState is not Clone, so we create a fresh one from current state)
561 let fresh_state = crate::state::SharedState::new();
562 let current = self.state.get_state();
563 fresh_state.update(|s| {
564 *s = current;
565 });
566
567 let mut agent_loop = AgentLoop::new_with_resolver(
568 provider,
569 loop_config,
570 Arc::clone(&self.tools),
571 fresh_state,
572 Arc::clone(&self.resolver),
573 );
574
575 // Add the user prompt to Agent.state() AFTER fresh_state is created.
576 // fresh_state got a copy of the pre-prompt state, so run_loop will
577 // add the prompt to fresh_state independently via initial_prompts.
578 // But persist_session() reads Agent.state() (not fresh_state), so it
579 // needs the user prompt there to write it to the session file.
580 // Sync happens at AgentEnd (after run_loop completes), where
581 // Agent.state is overwritten with fresh_state (which has all messages).
582 self.state.update(|s| {
583 s.messages
584 .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
585 prompt.clone(),
586 )));
587 });
588
589 // Pre-populate steering/follow-up from hooks
590 {
591 let hooks = self.hooks.read();
592 if let Some(ref get_steering) = hooks.get_steering_messages {
593 for msg_text in get_steering() {
594 agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
595 }
596 }
597 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
598 for msg_text in get_follow_up() {
599 agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
600 }
601 }
602
603 // Store hooks on AgentLoop so they can be polled each turn
604 // to pick up new messages injected during the run.
605 if let Some(ref get_steering) = hooks.get_steering_messages {
606 agent_loop.set_steering_hook(Arc::clone(get_steering));
607 }
608 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
609 agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
610 }
611 }
612 let mut al = agent_loop;
613
614 // Wire should_stop_after_turn hook: share AgentLoop's external_stop
615 // Arc with the emit callback. When the hook fires (Ctrl+C detected),
616 // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
617 // AND during streaming (streaming.rs checks external_stop each event).
618 //
619 // Arc<dyn Fn> can be cloned, so we read it without consuming.
620 let maybe_hook = {
621 let hooks_r = self.hooks.read();
622 hooks_r.should_stop_after_turn.clone()
623 };
624 let ext_stop = al.external_stop().clone();
625 let cancel_flag = self.cancel_flag.clone();
626
627 // Share cancel_flag with AgentLoop so the streaming loop can check
628 // it directly in the periodic timer — no emit callback required.
629 // This closes the gap where cancel() was ineffective when the
630 // provider stream produced no events.
631 al.set_cancel_signal(self.cancel_flag.clone());
632
633 // Create emit callback that sends through the channel.
634 // AgentLoop calls this synchronously. UnboundedSender::send() is
635 // non-blocking and never drops events (unlike try_send on bounded).
636 let tx_emit = tx.clone();
637
638 // Snapshot the observability_dispatch list once per run. This avoids
639 // holding an Agent lock on the emit-fn hot path while still letting
640 // SDK consumers register new dispatchers at any time (registers after
641 // this snapshot will fire on the next run).
642 let dispatch_handlers: Vec<EventDispatchFn> =
643 { self.inner.read().observability_dispatch.lock().clone() };
644 tracing::info!("[AGENT] Starting agent run with channel");
645 let result = al
646 .run(prompt.clone(), move |event: AgentEvent| {
647 // Forward event to channel (std::sync::mpsc — send from sync context)
648 tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
649 if let Err(e) = tx_emit.send(event.clone()) {
650 tracing::error!(
651 "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
652 e
653 );
654 } else {
655 tracing::info!("[AGENT-EMIT] Successfully sent event");
656 }
657
658 // Propagate cancellation from Agent::cancel() → external_stop.
659 // This runs on every event, ensuring the streaming loop detects
660 // cancellation promptly.
661 if cancel_flag.load(Ordering::SeqCst) {
662 ext_stop.store(true, Ordering::SeqCst);
663 }
664
665 // Fan out to SDK-side observability handlers (Tracer,
666 // CostTracker, ...). The dispatch list is snapshotted at
667 // run-start so we hold Arc clones, not a lock. This means
668 // handlers added mid-run do not fire until the next run.
669 for handler in dispatch_handlers.iter() {
670 handler(event.clone());
671 }
672 // Propagate should_stop → external_stop on every event, not
673 // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
674 // so the context contents are irrelevant for non-TurnEnd events.
675 // This ensures streaming.rs detects cancellation immediately
676 // when the user presses Ctrl+C mid-stream.
677 if let Some(ref hook) = maybe_hook {
678 let ctx = ShouldStopAfterTurnContext {
679 message: match &event {
680 AgentEvent::TurnEnd {
681 assistant_message: oxi_ai::Message::Assistant(a),
682 ..
683 } => a.clone(),
684 _ => oxi_ai::AssistantMessage::new(
685 oxi_ai::Api::OpenAiCompletions,
686 "agent",
687 "agent-model",
688 ),
689 },
690 tool_results: match &event {
691 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
692 _ => Vec::new(),
693 },
694 iteration: 0,
695 };
696 if hook(&ctx) {
697 ext_stop.store(true, Ordering::SeqCst);
698 }
699 }
700 })
701 .await;
702
703 match result {
704 Ok(_events) => {
705 // Sync state back from AgentLoop
706 let loop_state = al.state().get_state();
707 self.state.update(|s| {
708 *s = loop_state;
709 });
710
711 // Apply any pending model switch that was deferred during the run.
712 // This transforms messages (if cross-provider) and swaps the provider
713 // so the next run uses the new model.
714 self.apply_pending_model_switch();
715
716 // Extract final response text from state
717 let state = self.state.get_state();
718 let final_text = state
719 .messages
720 .iter()
721 .rev()
722 .find_map(|m| match m {
723 oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
724 oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
725 _ => None,
726 }),
727 _ => None,
728 })
729 .unwrap_or_default();
730
731 let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
732
733 Ok(Response {
734 content: final_text,
735 stop_reason,
736 })
737 }
738 Err(e) => {
739 // Apply pending model switch even on error so the next run
740 // uses the new model.
741 self.apply_pending_model_switch();
742 Err(e)
743 }
744 }
745 }
746
747 // ── Helper methods for the agentic loop ────────────────────────
748
749 /// Set hooks for the agent loop.
750 pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
751 let mut h = self.hooks.write();
752 *h = hooks;
753 }
754
755 /// Register a side-dispatch closure called for every `AgentEvent`
756 /// emitted by `run`, `run_with_channel`, `run_streaming`,
757 /// `run_tokio_stream`, and `continue_with`.
758 ///
759 /// Multiple calls stack: every registered closure is invoked on
760 /// every event. Closures run synchronously on the agent-loop emit
761 /// thread, so they must be cheap and non-blocking. Long work
762 /// should be spawned off (e.g. `tokio::spawn`) by the closure
763 /// itself.
764 ///
765 /// Used by `oxi-sdk` to bridge observability types
766 /// (`Tracer`, `CostTracker`, `AuditLog`, `Authorizer` /
767 /// `AccessGate`) into the runtime without leaking those types
768 /// into `oxi-agent`.
769 ///
770 /// # Example
771 ///
772 /// ```ignore
773 /// agent.add_observability_dispatch(|event| match event {
774 /// AgentEvent::TurnStart { turn_number } => {
775 /// // open a span
776 /// }
777 /// AgentEvent::Usage { input_tokens, output_tokens } => {
778 /// // record cost
779 /// }
780 /// _ => {}
781 /// });
782 /// ```
783 pub fn add_observability_dispatch(&self, f: impl Fn(AgentEvent) + Send + Sync + 'static) {
784 let guard = self.inner.write();
785 let mut slot = guard.observability_dispatch.lock();
786 slot.push(Arc::new(f));
787 }
788
    /// Request cancellation of the current agent run.
    ///
    /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
    /// `external_stop` on every event AND polled every ~500ms by the
    /// streaming loop's periodic check. This ensures cancellation is
    /// detected quickly even when the provider stream is completely hung
    /// (no events arriving).
    ///
    /// The flag is cleared again at the start of the next run (see
    /// `run_with_channel`), so calling this while idle has no lasting effect.
    pub fn cancel(&self) {
        self.cancel_flag.store(true, Ordering::SeqCst);
    }
799
    /// Reset the cancellation flag before starting a new run.
    ///
    /// Called by `run_with_channel` right after the running guard is taken.
    pub fn reset_cancel(&self) {
        self.cancel_flag.store(false, Ordering::SeqCst);
    }
804
805 /// Apply any pending model switch that was deferred during a running loop.
806 ///
807 /// Called after `run_with_channel_inner` completes (success or error).
808 /// Transforms messages for cross-provider switches and swaps the provider
809 /// so the next run uses the new model.
810 fn apply_pending_model_switch(&self) {
811 let pending = self.pending_model_switch.write().take();
812 if let Some(pending) = pending {
813 tracing::info!(
814 "[AGENT] Applying deferred model switch to '{}' (transform={})",
815 pending.model_id,
816 pending.needs_transform
817 );
818
819 // Transform messages if cross-provider
820 if pending.needs_transform {
821 let messages = self.state.get_state().messages.clone();
822 let transformed =
823 transform_for_provider(&messages, &pending.old_api, &pending.new_api);
824 self.state.update(|s| {
825 s.replace_messages(transformed);
826 });
827 }
828
829 // Swap the provider
830 let mut inner = self.inner_mut();
831 inner.provider = pending.provider;
832 // model_id was already updated in switch_model()
833 }
834 }
835
836 /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
837 ///
838 /// Blocking convenience wrapper suitable for callers that prefer a
839 /// callback-based API over a channel.
840 pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
841 where
842 F: FnMut(AgentEvent) + Send,
843 {
844 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
845 let result = self.run_with_channel(prompt, tx).await;
846 while let Ok(event) = rx.recv() {
847 on_event(event);
848 }
849 result
850 }
851
852 // ── Session persistence ────────────────────────────────────────
853
854 /// Export the agent state as a JSON value.
855 ///
856 /// The serialized state includes conversation messages, token counts,
857 /// iteration progress, and stop reason. Use [`import_state`] to restore.
858 ///
859 /// [`import_state`]: Agent::import_state
860 pub fn export_state(&self) -> Result<serde_json::Value> {
861 let state = self.state.get_state();
862 serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
863 }
864
865 /// Import agent state from a JSON value.
866 ///
867 /// Restores conversation history, token counts, and iteration progress.
868 /// Typically used together with [`export_state`] for session persistence.
869 ///
870 /// [`export_state`]: Agent::export_state
871 pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
872 let state: AgentState = serde_json::from_value(value)
873 .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
874 self.state.update(|s| *s = state);
875 Ok(())
876 }
877
878 // ── Session continuation ───────────────────────────────────────
879
880 /// Continue the current session with a new prompt.
881 ///
882 /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
883 /// preserves the existing conversation state and appends the new prompt.
884 /// This enables multi-turn interactions within the same session.
885 pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
886 let mut events = Vec::new();
887 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
888 let result = self.run_with_channel(prompt, tx).await;
889 while let Ok(event) = rx.recv() {
890 events.push(event);
891 }
892 result.map(|r| (r, events))
893 }
894
895 // ── Tokio-native streaming ─────────────────────────────────────
896
897 /// Run the agent with tokio-native event streaming.
898 ///
899 /// Returns a `tokio::sync::mpsc::Receiver` for events and a
900 /// `JoinHandle` for the response. This is the preferred API for
901 /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
902 ///
903 /// # Example
904 ///
905 /// ```ignore
906 /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
907 /// while let Some(event) = rx.recv().await {
908 /// println!("Event: {:?}", event.type_name());
909 /// }
910 /// let response = handle.await??;
911 /// ```
912 pub async fn run_tokio_stream(
913 &self,
914 prompt: String,
915 ) -> Result<(
916 tokio::sync::mpsc::Receiver<AgentEvent>,
917 tokio::task::JoinHandle<Result<Response>>,
918 )> {
919 let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
920
921 if self
922 .is_running
923 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
924 .is_err()
925 {
926 return Err(Error::msg("Agent is already running"));
927 }
928
929 let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
930
931 let inner = self.inner.read().clone();
932 let tools = Arc::clone(&self.tools);
933 let resolver = Arc::clone(&self.resolver);
934
935 // Build AgentLoopConfig
936 let loop_config = crate::agent_loop::config::AgentLoopConfig {
937 model_id: inner.config.model_id.clone(),
938 system_prompt: inner.config.system_prompt.clone(),
939 temperature: inner.config.temperature.unwrap_or(1.0) as f32,
940 max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
941 tool_execution: crate::config::ToolExecutionMode::Sequential,
942 compaction_strategy: inner.config.compaction_strategy.clone(),
943 compaction_instruction: None,
944 context_window: inner.config.context_window,
945 session_id: inner.config.session_id.clone(),
946 transport: None,
947 compact_on_start: false,
948 max_retry_delay_ms: None,
949 auto_retry_enabled: true,
950 auto_retry_max_attempts: 3,
951 auto_retry_base_delay_ms: 1000,
952 api_key: inner.config.api_key.clone(),
953 workspace_dir: inner.config.workspace_dir.clone(),
954 provider_options: inner.config.provider_options.clone(),
955 on_compaction: None,
956 ttsr_engine: inner.config.ttsr_engine.clone(),
957 max_tool_result_bytes: inner.config.max_tool_result_bytes,
958 subagent_runner: inner.config.subagent_runner.clone(),
959 subagent_depth: inner.config.subagent_depth,
960 ..Default::default()
961 };
962
963 let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
964
965 // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
966 // agent loop so that state mutations inside the spawned task are
967 // visible through self.state() without an explicit sync step.
968 //
969 // Unlike run_with_channel_inner which creates a fresh SharedState
970 // and syncs back on completion, the tokio streaming API cannot
971 // access `self` inside the `'static` spawned task, so we share
972 // the underlying Arc instead.
973 //
974 // Pre-load current state into the shared Arc (in case it was
975 // modified by a previous run that used a different SharedState).
976 let shared_state = self.state.clone();
977
978 let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
979 provider,
980 loop_config,
981 tools,
982 shared_state.clone(),
983 resolver,
984 );
985
986 let maybe_hook = should_stop_hook;
987 let ext_stop = agent_loop.external_stop().clone();
988
989 // Clone the is_running Arc so the spawned task can clear it.
990 let is_running_flag = Arc::clone(&self.is_running);
991
992 // Snapshot the observability_dispatch list before the spawned
993 // task. The future is `'static` and cannot borrow `&self`,
994 // so we take the snapshot at run-start on the regular borrow
995 // stack and move the resulting Arc-clones into the task.
996 let dispatch_handlers: Vec<EventDispatchFn> = {
997 let guard = self.inner.read();
998 guard.observability_dispatch.lock().clone()
999 };
1000
1001 let handle = tokio::task::spawn(async move {
1002 let result = agent_loop
1003 .run(prompt, move |event: AgentEvent| {
1004 // Forward to tokio channel (non-blocking)
1005 let _ = tx.try_send(event.clone());
1006
1007 // Fan out to SDK-side observability handlers
1008 // (Tracer, CostTracker, ...).
1009 for handler in dispatch_handlers.iter() {
1010 handler(event.clone());
1011 }
1012 // Propagate should_stop → external_stop on every event,
1013 // not just TurnEnd. See run_with_channel_inner for rationale.
1014 if let Some(ref hook) = maybe_hook {
1015 let ctx = ShouldStopAfterTurnContext {
1016 message: match &event {
1017 AgentEvent::TurnEnd {
1018 assistant_message: oxi_ai::Message::Assistant(a),
1019 ..
1020 } => a.clone(),
1021 _ => oxi_ai::AssistantMessage::new(
1022 oxi_ai::Api::OpenAiCompletions,
1023 "agent",
1024 "agent-model",
1025 ),
1026 },
1027 tool_results: match &event {
1028 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
1029 _ => Vec::new(),
1030 },
1031 iteration: 0,
1032 };
1033 if hook(&ctx) {
1034 ext_stop.store(true, Ordering::SeqCst);
1035 }
1036 }
1037 })
1038 .await;
1039
1040 // Clear the Agent's running flag
1041 is_running_flag.store(false, Ordering::SeqCst);
1042
1043 match result {
1044 Ok(_events) => {
1045 // State is already shared via the same SharedState Arc,
1046 // so self.state() will reflect all mutations.
1047 Ok(Response {
1048 content: String::new(),
1049 stop_reason: StopReason::Stop,
1050 })
1051 }
1052 Err(e) => Err(e),
1053 }
1054 });
1055
1056 Ok((rx, handle))
1057 }
1058}