Skip to main content

oxi_agent/
agent.rs

1/// Core agent implementation
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10    CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
/// Lookup interface an `Agent` uses to turn names into providers and models.
///
/// Putting the lookup behind a trait keeps the agent independent of
/// global static registries, so SDK users can supply isolated
/// provider/model tables.
///
/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
/// When using `Agent::new()` directly, a global fallback is used.
pub trait ProviderResolver: Send + Sync + 'static {
    /// Resolve a provider by name, returning a shared `Arc` handle.
    ///
    /// Returns `None` when this resolver knows no provider called `name`.
    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;

    /// Resolve a model ID (`"provider/model"` or bare `"model"`) to a `Model`.
    ///
    /// Returns `None` when the ID cannot be resolved.
    fn resolve_model(&self, model_id: &str) -> Option<Model>;
}
32
33/// Global provider resolver — uses `oxi_ai` global functions.
34///
35/// This is the default resolver when using `Agent::new()`, preserving
36/// backward compatibility with existing CLI usage.
37pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41        oxi_ai::get_provider(name).map(Arc::from)
42    }
43
44    fn resolve_model(&self, model_id: &str) -> Option<Model> {
45        crate::model_id::resolve_model_from_id(model_id)
46    }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
/// Mutable agent internals protected by a read-write lock.
///
/// Holds the pieces that model switching and the configuration setters
/// rewrite: the configuration, the active provider, and the list of
/// observability callbacks.
struct AgentInner {
    /// Current agent configuration (model ID, API key, system prompt, ...).
    config: AgentConfig,
    /// Provider used for LLM calls on the next run.
    provider: Arc<dyn Provider>,
    /// Side-dispatch closures invoked for `AgentEvent`s emitted by the
    /// agent run methods. Used by `oxi-sdk` to bridge observability
    /// types (Tracer, CostTracker, ...) into the agent loop without
    /// leaking SDK types into `oxi-agent`.
    ///
    /// A `Mutex` rather than an `RwLock`: the list mutates rarely (only
    /// on `add_observability_dispatch`). In `run_with_channel_inner` it
    /// is cloned once per run, so the lock is not held while events are
    /// being emitted.
    observability_dispatch: parking_lot::Mutex<Vec<EventDispatchFn>>,
}
65
/// Type alias for an observability dispatch handler.
///
/// Each entry is a closure registered via
/// [`Agent::add_observability_dispatch`] and invoked with emitted
/// `AgentEvent`s. The alias keeps the [`AgentInner`] field readable
/// without spelling out the `Arc<dyn Fn ...>` type inline.
type EventDispatchFn = Arc<dyn Fn(AgentEvent) + Send + Sync>;
71
72impl Clone for AgentInner {
73    fn clone(&self) -> Self {
74        Self {
75            config: self.config.clone(),
76            provider: Arc::clone(&self.provider),
77            // The dispatch list is *not* cloned: each `Agent` instance has
78            // its own observers. Cloning the AgentInner (rare; happens in
79            // `run_with_channel_inner` when sharing config across loops)
80            // gives the new loop an empty observer set, which is correct:
81            // the *Agent* retains the original dispatch list, and the
82            // temporary inner clone is discarded after the run.
83            observability_dispatch: parking_lot::Mutex::new(Vec::new()),
84        }
85    }
86}
/// Deferred model switch request, stored when the agent is running.
///
/// Created by `switch_model` / `switch_to_model` when a run is in
/// progress, and consumed after the run via `apply_pending_model_switch`.
struct PendingModelSwitch {
    /// Model ID to switch to.
    model_id: String,
    /// Provider that will replace the agent's current provider.
    provider: Arc<dyn Provider>,
    /// Whether messages need cross-provider transformation.
    needs_transform: bool,
    /// API the conversation history is currently formatted for.
    old_api: oxi_ai::Api,
    /// API of the model being switched to.
    new_api: oxi_ai::Api,
}
105
/// Agent runtime.
///
/// Manages provider, tool registry, state, and compaction, providing an
/// agentic loop for prompt execution, model switching, tool calls, and fallback.
///
/// Supports session continuation, tokio-native event streaming, and deferred
/// model switching (changes are queued while a loop is running and applied
/// after it completes).
#[allow(missing_docs)]
pub struct Agent {
    /// Configuration, active provider and observability callbacks.
    inner: RwLock<AgentInner>,
    /// Tool registry shared with each agent loop created for a run.
    tools: Arc<ToolRegistry>,
    /// Conversation state; copied into the loop at run start and synced
    /// back when the run succeeds.
    state: SharedState,
    /// Compaction manager built at construction time from the initial config.
    compaction_manager: CompactionManager,
    /// User-supplied hooks (steering, follow-up, should-stop).
    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
    /// Guard: true while a run is in progress. Prevents concurrent runs.
    is_running: Arc<AtomicBool>,
    /// Provider/model resolver. Uses global functions by default,
    /// or a custom resolver when created via `new_with_resolver()`.
    resolver: Arc<dyn ProviderResolver>,
    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
    /// propagated to AgentLoop's `external_stop` during each run.
    cancel_flag: Arc<AtomicBool>,
    /// Pending model switch — stored when the agent is running,
    /// applied after the current loop completes.
    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
}
133
134impl Agent {
135    /// Create a new agent with the given provider, config, and tool registry.
136    ///
137    /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
138    /// for model switching. For isolated instances, use [`new_with_resolver`].
139    ///
140    /// [`new_with_resolver`]: Agent::new_with_resolver
141    pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
142        let resolver = Arc::new(GlobalProviderResolver);
143        Self::build_inner(provider, config, tools, resolver)
144    }
145
    /// Create an agent with a custom provider/model resolver.
    ///
    /// This is the preferred constructor for SDK usage where provider
    /// and model registries must be isolated from global state. The
    /// resolver is used for the initial compactor lookup and for every
    /// later model switch.
    pub fn new_with_resolver(
        provider: Arc<dyn Provider>,
        config: AgentConfig,
        tools: Arc<ToolRegistry>,
        resolver: Arc<dyn ProviderResolver>,
    ) -> Self {
        Self::build_inner(provider, config, tools, resolver)
    }
158
159    /// Create an agent with an empty tool registry.
160    pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
161        Self::new(provider, config, Arc::new(ToolRegistry::new()))
162    }
163
    /// Acquire a read guard over the agent internals.
    ///
    /// Despite the name, the guard covers the whole `AgentInner`; the
    /// configuration itself is reached through its `config` field.
    fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
        self.inner.read()
    }
168
    /// Acquire a write guard over the agent internals (config and provider).
    fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
        self.inner.write()
    }
173
174    /// Get the current model ID
175    pub fn model_id(&self) -> String {
176        self.config().config.model_id.clone()
177    }
178
179    /// Get the agent configuration (full clone)
180    pub fn get_config(&self) -> AgentConfig {
181        self.config().config.clone()
182    }
183
184    /// Internal constructor shared by `new()` and `new_with_resolver()`.
185    fn build_inner(
186        provider: Arc<dyn Provider>,
187        config: AgentConfig,
188        tools: Arc<ToolRegistry>,
189        resolver: Arc<dyn ProviderResolver>,
190    ) -> Self {
191        let mut compaction_manager =
192            CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
193
194        // Pre-initialize the LLM compactor if compaction is enabled
195        if config.compaction_strategy != CompactionStrategy::Disabled {
196            let model = resolver.resolve_model(&config.model_id);
197
198            if let Some(model) = model {
199                let llm_compactor =
200                    Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
201                compaction_manager.set_compactor(llm_compactor);
202            }
203        }
204
205        Self {
206            inner: RwLock::new(AgentInner {
207                config,
208                provider,
209                observability_dispatch: parking_lot::Mutex::new(Vec::new()),
210            }),
211            tools,
212            state: SharedState::new(),
213            compaction_manager,
214            hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
215            is_running: Arc::new(AtomicBool::new(false)),
216            resolver,
217            cancel_flag: Arc::new(AtomicBool::new(false)),
218            pending_model_switch: RwLock::new(None),
219        }
220    }
221
    /// Borrow the provider/model resolver this agent was constructed with.
    pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
        &self.resolver
    }
226
227    /// Switch the model used for future LLM calls.
228    ///
229    /// Switch model mid-conversation.
230    ///
231    /// If the agent is currently running, the switch is deferred: the new
232    /// model and provider are stored in `pending_model_switch` and applied
233    /// automatically when the current loop finishes. This ensures the
234    /// running loop completes with a consistent provider/model without
235    /// interruption.
236    ///
237    /// If the agent is idle, the switch takes effect immediately.
238    ///
239    /// If the new model uses a different provider API, the conversation
240    /// history is automatically transformed for cross-provider compatibility
241    /// (e.g. thinking blocks are converted to `<thinking>` tags).
242    ///
243    /// # Arguments
244    /// * `model_id` - New model ID in `provider/model` format
245    /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
246    ///
247    /// # Returns
248    /// `Ok(())` on success, or an error if the model/provider is unknown
249    pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
250        let new_model = self
251            .resolver
252            .resolve_model(model_id)
253            .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
254
255        // Create the new provider via resolver
256        let new_provider = self
257            .resolver
258            .resolve_provider(&new_model.provider)
259            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
260
261        // Detect API change
262        let (old_api, needs_transform) = {
263            let inner = self.config();
264            let old_api = self
265                .resolver
266                .resolve_model(&inner.config.model_id)
267                .map(|m| m.api)
268                .unwrap_or(oxi_ai::Api::AnthropicMessages);
269            (old_api, old_api != new_model.api)
270        };
271
272        // If the agent is currently running, defer the switch.
273        if self.is_running.load(Ordering::SeqCst) {
274            tracing::info!(
275                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
276                model_id
277            );
278            *self.pending_model_switch.write() = Some(PendingModelSwitch {
279                model_id: model_id.to_string(),
280                provider: new_provider,
281                needs_transform,
282                old_api,
283                new_api: new_model.api,
284            });
285            // Update config immediately so model_id() returns the new value,
286            // but leave provider unchanged so the running loop keeps its provider.
287            {
288                let mut inner = self.inner_mut();
289                inner.config.model_id = model_id.to_string();
290                inner.config.api_key = api_key;
291            }
292            return Ok(());
293        }
294
295        // Agent is idle — apply immediately.
296        if needs_transform {
297            let messages = self.state.get_state().messages.clone();
298            let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
299            self.state.update(|s| {
300                s.replace_messages(transformed);
301            });
302        }
303
304        // Update config and provider atomically
305        let mut inner = self.inner_mut();
306        inner.config.model_id = model_id.to_string();
307        inner.config.api_key = api_key;
308        inner.provider = new_provider;
309
310        Ok(())
311    }
312
313    /// Switch the model using a pre-resolved `Model` object.
314    ///
315    /// This is useful when the caller has already looked up the model
316    /// and optionally created the provider.
317    ///
318    /// Like [`switch_model`], if the agent is currently running, the switch
319    /// is deferred until the current loop completes.
320    ///
321    /// [`switch_model`]: Agent::switch_model
322    pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
323        let model_id = format!("{}/{}", model.provider, model.id);
324        let new_provider = self
325            .resolver
326            .resolve_provider(&model.provider)
327            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
328
329        // Detect API change
330        let (old_api, needs_transform) = {
331            let inner = self.config();
332            let old_api = self
333                .resolver
334                .resolve_model(&inner.config.model_id)
335                .map(|m| m.api)
336                .unwrap_or(oxi_ai::Api::AnthropicMessages);
337            (old_api, old_api != model.api)
338        };
339
340        // If the agent is currently running, defer the switch.
341        if self.is_running.load(Ordering::SeqCst) {
342            tracing::info!(
343                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
344                model_id
345            );
346            *self.pending_model_switch.write() = Some(PendingModelSwitch {
347                model_id: model_id.clone(),
348                provider: new_provider,
349                needs_transform,
350                old_api,
351                new_api: model.api,
352            });
353            let mut inner = self.inner_mut();
354            inner.config.model_id = model_id;
355            inner.config.api_key = api_key;
356            return Ok(());
357        }
358
359        // Agent is idle — apply immediately.
360        if needs_transform {
361            let messages = self.state.get_state().messages.clone();
362            let transformed = transform_for_provider(&messages, &old_api, &model.api);
363            self.state.update(|s| {
364                s.replace_messages(transformed);
365            });
366        }
367
368        let mut inner = self.inner_mut();
369        inner.config.model_id = model_id;
370        inner.config.api_key = api_key;
371        inner.provider = new_provider;
372
373        Ok(())
374    }
375
376    /// Refresh only the API key without changing model or provider.
377    /// Useful when the user stores a key after the session was already created.
378    pub fn refresh_api_key(&self, api_key: Option<String>) {
379        let mut inner = self.inner_mut();
380        inner.config.api_key = api_key;
381    }
382
383    /// Get a handle to the tool registry.
384    pub fn tools(&self) -> Arc<ToolRegistry> {
385        Arc::clone(&self.tools)
386    }
387
    /// Get a snapshot of the current agent state.
    ///
    /// The returned `AgentState` is an owned value obtained from the
    /// shared state via `get_state()`.
    pub fn state(&self) -> AgentState {
        self.state.get_state()
    }
392
    /// Update agent state in-place. Used by compaction to replace messages.
    ///
    /// `f` receives a mutable reference to the state and is forwarded to
    /// the shared state's `update`.
    pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
        self.state.update(f);
    }
397
    /// Reset agent state for a new conversation.
    ///
    /// Delegates to the shared state's `reset()`; configuration, provider
    /// and tools are not touched here.
    pub fn reset(&self) {
        self.state.reset();
    }
402
    /// Register a tool that the agent can invoke during a run.
    ///
    /// The tool is added to the shared registry, which is the same
    /// registry handed out by `tools()`.
    pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
        self.tools.register(tool);
    }
407
408    /// Update the system prompt for future interactions.
409    pub fn set_system_prompt(&self, prompt: String) {
410        self.inner_mut().config.system_prompt = Some(prompt);
411    }
412
    /// Borrow the compaction manager created when the agent was built.
    pub fn compaction_manager(&self) -> &CompactionManager {
        &self.compaction_manager
    }
417    /// Update the compaction strategy for future runs.
418    ///
419    /// The strategy is read fresh from the config at the start of each run
420    /// (see `run_with_channel_inner`), so this takes effect on the next
421    /// agent turn — never mid-run. Pair with `compaction_manager()` for
422    /// manual compaction, which is unaffected by the strategy.
423    pub fn set_compaction_strategy(&self, strategy: oxi_ai::CompactionStrategy) {
424        self.inner.write().config.compaction_strategy = strategy;
425    }
426    /// Get the compaction strategy that will be used on the next run.
427    ///
428    /// This reads from `inner.config` (mutable via `set_compaction_strategy`),
429    /// **not** from the `compaction_manager` field (which retains its
430    /// construction-time strategy). The agent loop reads from config fresh
431    /// each run, so this is the authoritative value.
432    pub fn compaction_strategy(&self) -> oxi_ai::CompactionStrategy {
433        self.inner.read().config.compaction_strategy.clone()
434    }
435
436    /// Run the agent with a prompt, collecting all events into a vector.
437    ///
438    /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
439    /// [`AgentEvent`] produced during the run.
440    pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
441        let mut events = Vec::new();
442        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
443        let result = self.run_with_channel(prompt, tx).await;
444        while let Ok(event) = rx.recv() {
445            events.push(event);
446        }
447        result.map(|r| (r, events))
448    }
449
450    /// Run the agent, delivering events through the provided channel.
451    ///
452    /// Delegates to the agent loop which implements the same 2-level agentic
453    /// loop matching pi-mono's architecture:
454    ///
455    /// ```text
456    /// AgentLoop.run_messages()
457    ///   Outer loop (follow-up messages):
458    ///     Inner loop (tool calls + steering):
459    ///       1. Inject pending messages (steering)
460    ///       2. Compaction check
461    ///       3. Stream LLM response (with accumulated partial messages)
462    ///       4. Execute tool calls if any
463    ///       5. Emit turn_end
464    ///       6. Check shouldStopAfterTurn
465    ///       7. Poll steering messages
466    ///     Check follow-up messages
467    ///     Exit
468    /// ```
469    pub async fn run_with_channel(
470        &self,
471        prompt: String,
472        tx: std::sync::mpsc::Sender<AgentEvent>,
473    ) -> Result<Response> {
474        // pi-mono: Agent.prompt() throws if activeRun exists.
475        // Prevent concurrent runs that would corrupt shared state.
476        if self
477            .is_running
478            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
479            .is_err()
480        {
481            return Err(Error::msg("Agent is already running"));
482        }
483
484        // Drop guard ensures is_running is cleared even on panic.
485        struct RunningGuard<'a>(&'a AtomicBool);
486        impl Drop for RunningGuard<'_> {
487            fn drop(&mut self) {
488                self.0.store(false, Ordering::SeqCst);
489            }
490        }
491        let _guard = RunningGuard(&self.is_running);
492        self.reset_cancel();
493
494        self.run_with_channel_inner(prompt, tx).await
495    }
496
497    /// Inner implementation of run_with_channel, called after the running guard is set.
498    async fn run_with_channel_inner(
499        &self,
500        prompt: String,
501        tx: std::sync::mpsc::Sender<AgentEvent>,
502    ) -> Result<Response> {
503        use crate::agent_loop::AgentLoop;
504
505        let (
506            provider,
507            system_prompt,
508            temperature,
509            max_tokens,
510            compaction_strategy,
511            context_window,
512            api_key,
513            workspace_dir,
514        ) = {
515            let inner = self.inner.read();
516            (
517                Arc::clone(&inner.provider) as Arc<dyn Provider>,
518                inner.config.system_prompt.clone(),
519                inner.config.temperature,
520                inner.config.max_tokens,
521                inner.config.compaction_strategy.clone(),
522                inner.config.context_window,
523                inner.config.api_key.clone(),
524                inner.config.workspace_dir.clone(),
525            )
526        }; // release read lock
527
528        // Build AgentLoopConfig from Agent's config
529        let loop_config = crate::agent_loop::config::AgentLoopConfig {
530            model_id: self.model_id(),
531            system_prompt,
532            temperature: temperature.unwrap_or(1.0) as f32,
533            max_tokens: max_tokens.unwrap_or(4096) as u32,
534            tool_execution: crate::config::ToolExecutionMode::Sequential,
535            compaction_strategy,
536            compaction_instruction: None,
537            context_window,
538            session_id: self.config().config.session_id.clone(),
539            transport: None,
540            compact_on_start: false,
541            max_retry_delay_ms: None,
542            auto_retry_enabled: true,
543            auto_retry_max_attempts: 3,
544            auto_retry_base_delay_ms: 1000,
545            api_key,
546            workspace_dir,
547            provider_options: self.config().config.provider_options.clone(),
548            on_compaction: None,
549            ttsr_engine: self.config().config.ttsr_engine.clone(),
550            memory: self.config().config.memory.clone(),
551            todo: self.config().config.todo.clone(),
552            agent_pool: self.config().config.agent_pool.clone(),
553            max_tool_result_bytes: self.config().config.max_tool_result_bytes,
554            subagent_runner: self.config().config.subagent_runner.clone(),
555            subagent_depth: self.config().config.subagent_depth,
556            ..Default::default()
557        };
558
559        // Create AgentLoop. We give it a NEW SharedState and sync back after.
560        // (SharedState is not Clone, so we create a fresh one from current state)
561        let fresh_state = crate::state::SharedState::new();
562        let current = self.state.get_state();
563        fresh_state.update(|s| {
564            *s = current;
565        });
566
567        let mut agent_loop = AgentLoop::new_with_resolver(
568            provider,
569            loop_config,
570            Arc::clone(&self.tools),
571            fresh_state,
572            Arc::clone(&self.resolver),
573        );
574
575        // Add the user prompt to Agent.state() AFTER fresh_state is created.
576        // fresh_state got a copy of the pre-prompt state, so run_loop will
577        // add the prompt to fresh_state independently via initial_prompts.
578        // But persist_session() reads Agent.state() (not fresh_state), so it
579        // needs the user prompt there to write it to the session file.
580        // Sync happens at AgentEnd (after run_loop completes), where
581        // Agent.state is overwritten with fresh_state (which has all messages).
582        self.state.update(|s| {
583            s.messages
584                .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
585                    prompt.clone(),
586                )));
587        });
588
589        // Pre-populate steering/follow-up from hooks
590        {
591            let hooks = self.hooks.read();
592            if let Some(ref get_steering) = hooks.get_steering_messages {
593                for msg_text in get_steering() {
594                    agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
595                }
596            }
597            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
598                for msg_text in get_follow_up() {
599                    agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
600                }
601            }
602
603            // Store hooks on AgentLoop so they can be polled each turn
604            // to pick up new messages injected during the run.
605            if let Some(ref get_steering) = hooks.get_steering_messages {
606                agent_loop.set_steering_hook(Arc::clone(get_steering));
607            }
608            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
609                agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
610            }
611        }
612        let mut al = agent_loop;
613
614        // Wire should_stop_after_turn hook: share AgentLoop's external_stop
615        // Arc with the emit callback. When the hook fires (Ctrl+C detected),
616        // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
617        // AND during streaming (streaming.rs checks external_stop each event).
618        //
619        // Arc<dyn Fn> can be cloned, so we read it without consuming.
620        let maybe_hook = {
621            let hooks_r = self.hooks.read();
622            hooks_r.should_stop_after_turn.clone()
623        };
624        let ext_stop = al.external_stop().clone();
625        let cancel_flag = self.cancel_flag.clone();
626
627        // Share cancel_flag with AgentLoop so the streaming loop can check
628        // it directly in the periodic timer — no emit callback required.
629        // This closes the gap where cancel() was ineffective when the
630        // provider stream produced no events.
631        al.set_cancel_signal(self.cancel_flag.clone());
632
633        // Create emit callback that sends through the channel.
634        // AgentLoop calls this synchronously. UnboundedSender::send() is
635        // non-blocking and never drops events (unlike try_send on bounded).
636        let tx_emit = tx.clone();
637
638        // Snapshot the observability_dispatch list once per run. This avoids
639        // holding an Agent lock on the emit-fn hot path while still letting
640        // SDK consumers register new dispatchers at any time (registers after
641        // this snapshot will fire on the next run).
642        let dispatch_handlers: Vec<EventDispatchFn> =
643            { self.inner.read().observability_dispatch.lock().clone() };
644        tracing::info!("[AGENT] Starting agent run with channel");
645        let result = al
646            .run(prompt.clone(), move |event: AgentEvent| {
647                // Forward event to channel (std::sync::mpsc — send from sync context)
648                tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
649                if let Err(e) = tx_emit.send(event.clone()) {
650                    tracing::error!(
651                        "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
652                        e
653                    );
654                } else {
655                    tracing::info!("[AGENT-EMIT] Successfully sent event");
656                }
657
658                // Propagate cancellation from Agent::cancel() → external_stop.
659                // This runs on every event, ensuring the streaming loop detects
660                // cancellation promptly.
661                if cancel_flag.load(Ordering::SeqCst) {
662                    ext_stop.store(true, Ordering::SeqCst);
663                }
664
665                // Fan out to SDK-side observability handlers (Tracer,
666                // CostTracker, ...). The dispatch list is snapshotted at
667                // run-start so we hold Arc clones, not a lock. This means
668                // handlers added mid-run do not fire until the next run.
669                for handler in dispatch_handlers.iter() {
670                    handler(event.clone());
671                }
672                // Propagate should_stop → external_stop on every event, not
673                // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
674                // so the context contents are irrelevant for non-TurnEnd events.
675                // This ensures streaming.rs detects cancellation immediately
676                // when the user presses Ctrl+C mid-stream.
677                if let Some(ref hook) = maybe_hook {
678                    let ctx = ShouldStopAfterTurnContext {
679                        message: match &event {
680                            AgentEvent::TurnEnd {
681                                assistant_message: oxi_ai::Message::Assistant(a),
682                                ..
683                            } => a.clone(),
684                            _ => oxi_ai::AssistantMessage::new(
685                                oxi_ai::Api::OpenAiCompletions,
686                                "agent",
687                                "agent-model",
688                            ),
689                        },
690                        tool_results: match &event {
691                            AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
692                            _ => Vec::new(),
693                        },
694                        iteration: 0,
695                    };
696                    if hook(&ctx) {
697                        ext_stop.store(true, Ordering::SeqCst);
698                    }
699                }
700            })
701            .await;
702
703        match result {
704            Ok(_events) => {
705                // Sync state back from AgentLoop
706                let loop_state = al.state().get_state();
707                self.state.update(|s| {
708                    *s = loop_state;
709                });
710
711                // Apply any pending model switch that was deferred during the run.
712                // This transforms messages (if cross-provider) and swaps the provider
713                // so the next run uses the new model.
714                self.apply_pending_model_switch();
715
716                // Extract final response text from state
717                let state = self.state.get_state();
718                let final_text = state
719                    .messages
720                    .iter()
721                    .rev()
722                    .find_map(|m| match m {
723                        oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
724                            oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
725                            _ => None,
726                        }),
727                        _ => None,
728                    })
729                    .unwrap_or_default();
730
731                let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
732
733                Ok(Response {
734                    content: final_text,
735                    stop_reason,
736                })
737            }
738            Err(e) => {
739                // Apply pending model switch even on error so the next run
740                // uses the new model.
741                self.apply_pending_model_switch();
742                Err(e)
743            }
744        }
745    }
746
747    // ── Helper methods for the agentic loop ────────────────────────
748
749    /// Set hooks for the agent loop.
750    pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
751        let mut h = self.hooks.write();
752        *h = hooks;
753    }
754
755    /// Register a side-dispatch closure called for every `AgentEvent`
756    /// emitted by `run`, `run_with_channel`, `run_streaming`,
757    /// `run_tokio_stream`, and `continue_with`.
758    ///
759    /// Multiple calls stack: every registered closure is invoked on
760    /// every event. Closures run synchronously on the agent-loop emit
761    /// thread, so they must be cheap and non-blocking. Long work
762    /// should be spawned off (e.g. `tokio::spawn`) by the closure
763    /// itself.
764    ///
765    /// Used by `oxi-sdk` to bridge observability types
766    /// (`Tracer`, `CostTracker`, `AuditLog`, `Authorizer` /
767    /// `AccessGate`) into the runtime without leaking those types
768    /// into `oxi-agent`.
769    ///
770    /// # Example
771    ///
772    /// ```ignore
773    /// agent.add_observability_dispatch(|event| match event {
774    ///     AgentEvent::TurnStart { turn_number } => {
775    ///         // open a span
776    ///     }
777    ///     AgentEvent::Usage { input_tokens, output_tokens } => {
778    ///         // record cost
779    ///     }
780    ///     _ => {}
781    /// });
782    /// ```
783    pub fn add_observability_dispatch(&self, f: impl Fn(AgentEvent) + Send + Sync + 'static) {
784        let guard = self.inner.write();
785        let mut slot = guard.observability_dispatch.lock();
786        slot.push(Arc::new(f));
787    }
788
789    /// Request cancellation of the current agent run.
790    ///
791    /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
792    /// `external_stop` on every event AND polled every ~500ms by the
793    /// streaming loop's periodic check. This ensures cancellation is
794    /// detected quickly even when the provider stream is completely hung
795    /// (no events arriving).
796    pub fn cancel(&self) {
797        self.cancel_flag.store(true, Ordering::SeqCst);
798    }
799
800    /// Reset the cancellation flag before starting a new run.
801    pub fn reset_cancel(&self) {
802        self.cancel_flag.store(false, Ordering::SeqCst);
803    }
804
805    /// Apply any pending model switch that was deferred during a running loop.
806    ///
807    /// Called after `run_with_channel_inner` completes (success or error).
808    /// Transforms messages for cross-provider switches and swaps the provider
809    /// so the next run uses the new model.
810    fn apply_pending_model_switch(&self) {
811        let pending = self.pending_model_switch.write().take();
812        if let Some(pending) = pending {
813            tracing::info!(
814                "[AGENT] Applying deferred model switch to '{}' (transform={})",
815                pending.model_id,
816                pending.needs_transform
817            );
818
819            // Transform messages if cross-provider
820            if pending.needs_transform {
821                let messages = self.state.get_state().messages.clone();
822                let transformed =
823                    transform_for_provider(&messages, &pending.old_api, &pending.new_api);
824                self.state.update(|s| {
825                    s.replace_messages(transformed);
826                });
827            }
828
829            // Swap the provider
830            let mut inner = self.inner_mut();
831            inner.provider = pending.provider;
832            // model_id was already updated in switch_model()
833        }
834    }
835
836    /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
837    ///
838    /// Blocking convenience wrapper suitable for callers that prefer a
839    /// callback-based API over a channel.
840    pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
841    where
842        F: FnMut(AgentEvent) + Send,
843    {
844        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
845        let result = self.run_with_channel(prompt, tx).await;
846        while let Ok(event) = rx.recv() {
847            on_event(event);
848        }
849        result
850    }
851
852    // ── Session persistence ────────────────────────────────────────
853
854    /// Export the agent state as a JSON value.
855    ///
856    /// The serialized state includes conversation messages, token counts,
857    /// iteration progress, and stop reason. Use [`import_state`] to restore.
858    ///
859    /// [`import_state`]: Agent::import_state
860    pub fn export_state(&self) -> Result<serde_json::Value> {
861        let state = self.state.get_state();
862        serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
863    }
864
865    /// Import agent state from a JSON value.
866    ///
867    /// Restores conversation history, token counts, and iteration progress.
868    /// Typically used together with [`export_state`] for session persistence.
869    ///
870    /// [`export_state`]: Agent::export_state
871    pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
872        let state: AgentState = serde_json::from_value(value)
873            .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
874        self.state.update(|s| *s = state);
875        Ok(())
876    }
877
878    // ── Session continuation ───────────────────────────────────────
879
880    /// Continue the current session with a new prompt.
881    ///
882    /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
883    /// preserves the existing conversation state and appends the new prompt.
884    /// This enables multi-turn interactions within the same session.
885    pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
886        let mut events = Vec::new();
887        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
888        let result = self.run_with_channel(prompt, tx).await;
889        while let Ok(event) = rx.recv() {
890            events.push(event);
891        }
892        result.map(|r| (r, events))
893    }
894
895    // ── Tokio-native streaming ─────────────────────────────────────
896
897    /// Run the agent with tokio-native event streaming.
898    ///
899    /// Returns a `tokio::sync::mpsc::Receiver` for events and a
900    /// `JoinHandle` for the response. This is the preferred API for
901    /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
902    ///
903    /// # Example
904    ///
905    /// ```ignore
906    /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
907    /// while let Some(event) = rx.recv().await {
908    ///     println!("Event: {:?}", event.type_name());
909    /// }
910    /// let response = handle.await??;
911    /// ```
912    pub async fn run_tokio_stream(
913        &self,
914        prompt: String,
915    ) -> Result<(
916        tokio::sync::mpsc::Receiver<AgentEvent>,
917        tokio::task::JoinHandle<Result<Response>>,
918    )> {
919        let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
920
921        if self
922            .is_running
923            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
924            .is_err()
925        {
926            return Err(Error::msg("Agent is already running"));
927        }
928
929        let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
930
931        let inner = self.inner.read().clone();
932        let tools = Arc::clone(&self.tools);
933        let resolver = Arc::clone(&self.resolver);
934
935        // Build AgentLoopConfig
936        let loop_config = crate::agent_loop::config::AgentLoopConfig {
937            model_id: inner.config.model_id.clone(),
938            system_prompt: inner.config.system_prompt.clone(),
939            temperature: inner.config.temperature.unwrap_or(1.0) as f32,
940            max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
941            tool_execution: crate::config::ToolExecutionMode::Sequential,
942            compaction_strategy: inner.config.compaction_strategy.clone(),
943            compaction_instruction: None,
944            context_window: inner.config.context_window,
945            session_id: inner.config.session_id.clone(),
946            transport: None,
947            compact_on_start: false,
948            max_retry_delay_ms: None,
949            auto_retry_enabled: true,
950            auto_retry_max_attempts: 3,
951            auto_retry_base_delay_ms: 1000,
952            api_key: inner.config.api_key.clone(),
953            workspace_dir: inner.config.workspace_dir.clone(),
954            provider_options: inner.config.provider_options.clone(),
955            on_compaction: None,
956            ttsr_engine: inner.config.ttsr_engine.clone(),
957            max_tool_result_bytes: inner.config.max_tool_result_bytes,
958            subagent_runner: inner.config.subagent_runner.clone(),
959            subagent_depth: inner.config.subagent_depth,
960            ..Default::default()
961        };
962
963        let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
964
965        // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
966        // agent loop so that state mutations inside the spawned task are
967        // visible through self.state() without an explicit sync step.
968        //
969        // Unlike run_with_channel_inner which creates a fresh SharedState
970        // and syncs back on completion, the tokio streaming API cannot
971        // access `self` inside the `'static` spawned task, so we share
972        // the underlying Arc instead.
973        //
974        // Pre-load current state into the shared Arc (in case it was
975        // modified by a previous run that used a different SharedState).
976        let shared_state = self.state.clone();
977
978        let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
979            provider,
980            loop_config,
981            tools,
982            shared_state.clone(),
983            resolver,
984        );
985
986        let maybe_hook = should_stop_hook;
987        let ext_stop = agent_loop.external_stop().clone();
988
989        // Clone the is_running Arc so the spawned task can clear it.
990        let is_running_flag = Arc::clone(&self.is_running);
991
992        // Snapshot the observability_dispatch list before the spawned
993        // task. The future is `'static` and cannot borrow `&self`,
994        // so we take the snapshot at run-start on the regular borrow
995        // stack and move the resulting Arc-clones into the task.
996        let dispatch_handlers: Vec<EventDispatchFn> = {
997            let guard = self.inner.read();
998            guard.observability_dispatch.lock().clone()
999        };
1000
1001        let handle = tokio::task::spawn(async move {
1002            let result = agent_loop
1003                .run(prompt, move |event: AgentEvent| {
1004                    // Forward to tokio channel (non-blocking)
1005                    let _ = tx.try_send(event.clone());
1006
1007                    // Fan out to SDK-side observability handlers
1008                    // (Tracer, CostTracker, ...).
1009                    for handler in dispatch_handlers.iter() {
1010                        handler(event.clone());
1011                    }
1012                    // Propagate should_stop → external_stop on every event,
1013                    // not just TurnEnd. See run_with_channel_inner for rationale.
1014                    if let Some(ref hook) = maybe_hook {
1015                        let ctx = ShouldStopAfterTurnContext {
1016                            message: match &event {
1017                                AgentEvent::TurnEnd {
1018                                    assistant_message: oxi_ai::Message::Assistant(a),
1019                                    ..
1020                                } => a.clone(),
1021                                _ => oxi_ai::AssistantMessage::new(
1022                                    oxi_ai::Api::OpenAiCompletions,
1023                                    "agent",
1024                                    "agent-model",
1025                                ),
1026                            },
1027                            tool_results: match &event {
1028                                AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
1029                                _ => Vec::new(),
1030                            },
1031                            iteration: 0,
1032                        };
1033                        if hook(&ctx) {
1034                            ext_stop.store(true, Ordering::SeqCst);
1035                        }
1036                    }
1037                })
1038                .await;
1039
1040            // Clear the Agent's running flag
1041            is_running_flag.store(false, Ordering::SeqCst);
1042
1043            match result {
1044                Ok(_events) => {
1045                    // State is already shared via the same SharedState Arc,
1046                    // so self.state() will reflect all mutations.
1047                    Ok(Response {
1048                        content: String::new(),
1049                        stop_reason: StopReason::Stop,
1050                    })
1051                }
1052                Err(e) => Err(e),
1053            }
1054        });
1055
1056        Ok((rx, handle))
1057    }
1058}