Skip to main content

oxi_agent/
agent.rs

//! Core agent implementation.
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10    CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
18/// Trait for resolving providers and models within an Agent.
19///
20/// This abstracts away global static registries, allowing SDK users
21/// to provide isolated provider/model lookups.
22///
23/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
24/// When using `Agent::new()` directly, a global fallback is used.
25pub trait ProviderResolver: Send + Sync + 'static {
26    /// Resolve a provider by name, returning an Arc handle.
27    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;
28
29    /// Resolve a model ID ("provider/model" or bare "model") to a Model.
30    fn resolve_model(&self, model_id: &str) -> Option<Model>;
31}
32
33/// Global provider resolver — uses `oxi_ai` global functions.
34///
35/// This is the default resolver when using `Agent::new()`, preserving
36/// backward compatibility with existing CLI usage.
37pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41        oxi_ai::get_provider(name).map(Arc::from)
42    }
43
44    fn resolve_model(&self, model_id: &str) -> Option<Model> {
45        crate::model_id::resolve_model_from_id(model_id)
46    }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
50/// Mutable agent internals protected by a read-write lock.
51struct AgentInner {
52    config: AgentConfig,
53    provider: Arc<dyn Provider>,
54    /// Side-dispatch closures invoked for every `AgentEvent` emitted by
55    /// the agent run methods. Used by `oxi-sdk` to bridge observability
56    /// types (Tracer, CostTracker, ...) into the agent loop without
57    /// leaking SDK types into `oxi-agent`.
58    ///
59    /// Lock-mutex rather than `RwLock`: dispatch lists mutate rarely
60    /// (only on `add_observability_dispatch`), but reads happen on every
61    /// event (high frequency), so a `Mutex` with cheap poison-free
62    /// acquisition is the right shape.
63    observability_dispatch: parking_lot::Mutex<Vec<EventDispatchFn>>,
64}
65
66/// Type alias for an observability dispatch handler. Each entry is a
67/// closure registered via [`Agent::add_observability_dispatch`] and
68/// invoked on every emitted `AgentEvent`. Named to keep the
69/// [`AgentInner`] field readable without an inline `dyn` route.
70type EventDispatchFn = Arc<dyn Fn(AgentEvent) + Send + Sync>;
71
72impl Clone for AgentInner {
73    fn clone(&self) -> Self {
74        Self {
75            config: self.config.clone(),
76            provider: Arc::clone(&self.provider),
77            // The dispatch list is *not* cloned: each `Agent` instance has
78            // its own observers. Cloning the AgentInner (rare; happens in
79            // `run_with_channel_inner` when sharing config across loops)
80            // gives the new loop an empty observer set, which is correct:
81            // the *Agent* retains the original dispatch list, and the
82            // temporary inner clone is discarded after the run.
83            observability_dispatch: parking_lot::Mutex::new(Vec::new()),
84        }
85    }
86}
87///
88/// Manages provider, tool registry, state, and compaction, providing an
89/// agentic loop for prompt execution, model switching, tool calls, and fallback.
90///
91/// Supports session continuation via [`continue_with`] and tokio-native
92/// event streaming via [`run_tokio_stream`].
93///
94/// [`continue_with`]: Agent::continue_with
95/// [`run_tokio_stream`]: Agent::run_tokio_stream
96/// Deferred model switch request, stored when the agent is running.
97struct PendingModelSwitch {
98    model_id: String,
99    provider: Arc<dyn Provider>,
100    /// Whether messages need cross-provider transformation.
101    needs_transform: bool,
102    old_api: oxi_ai::Api,
103    new_api: oxi_ai::Api,
104}
105
106/// Agent runtime.
107///
108/// Manages provider, tool registry, state, and compaction, providing an
109/// agentic loop for prompt execution, model switching, tool calls, and fallback.
110///
111/// Supports session continuation, tokio-native event streaming, and deferred
112/// model switching (changes are queued while a loop is running and applied
113/// after it completes).
114#[allow(missing_docs)]
115pub struct Agent {
116    inner: RwLock<AgentInner>,
117    tools: Arc<ToolRegistry>,
118    state: SharedState,
119    compaction_manager: CompactionManager,
120    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
121    /// Guard: true while a run is in progress. Prevents concurrent runs.
122    is_running: Arc<AtomicBool>,
123    /// Provider/model resolver. Uses global functions by default,
124    /// or a custom resolver when created via `new_with_resolver()`.
125    resolver: Arc<dyn ProviderResolver>,
126    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
127    /// propagated to AgentLoop's `external_stop` during each run.
128    cancel_flag: Arc<AtomicBool>,
129    /// Pending model switch — stored when the agent is running,
130    /// applied after the current loop completes.
131    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
132}
133
134impl Agent {
135    /// Create a new agent with the given provider, config, and tool registry.
136    ///
137    /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
138    /// for model switching. For isolated instances, use [`new_with_resolver`].
139    ///
140    /// [`new_with_resolver`]: Agent::new_with_resolver
141    pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
142        let resolver = Arc::new(GlobalProviderResolver);
143        Self::build_inner(provider, config, tools, resolver)
144    }
145
146    /// Create an agent with a custom provider/model resolver.
147    ///
148    /// This is the preferred constructor for SDK usage where provider
149    /// and model registries must be isolated from global state.
150    pub fn new_with_resolver(
151        provider: Arc<dyn Provider>,
152        config: AgentConfig,
153        tools: Arc<ToolRegistry>,
154        resolver: Arc<dyn ProviderResolver>,
155    ) -> Self {
156        Self::build_inner(provider, config, tools, resolver)
157    }
158
159    /// Create an agent with an empty tool registry.
160    pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
161        Self::new(provider, config, Arc::new(ToolRegistry::new()))
162    }
163
164    /// Get the agent configuration (read guard)
165    fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
166        self.inner.read()
167    }
168
169    /// Get a write guard for the agent inner state
170    fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
171        self.inner.write()
172    }
173
174    /// Get the current model ID
175    pub fn model_id(&self) -> String {
176        self.config().config.model_id.clone()
177    }
178
179    /// Get the agent configuration (full clone)
180    pub fn get_config(&self) -> AgentConfig {
181        self.config().config.clone()
182    }
183
184    /// Internal constructor shared by `new()` and `new_with_resolver()`.
185    fn build_inner(
186        provider: Arc<dyn Provider>,
187        config: AgentConfig,
188        tools: Arc<ToolRegistry>,
189        resolver: Arc<dyn ProviderResolver>,
190    ) -> Self {
191        let mut compaction_manager =
192            CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
193
194        // Pre-initialize the LLM compactor if compaction is enabled
195        if config.compaction_strategy != CompactionStrategy::Disabled {
196            let model = resolver.resolve_model(&config.model_id);
197
198            if let Some(model) = model {
199                let llm_compactor =
200                    Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
201                compaction_manager.set_compactor(llm_compactor);
202            }
203        }
204
205        Self {
206            inner: RwLock::new(AgentInner {
207                config,
208                provider,
209                observability_dispatch: parking_lot::Mutex::new(Vec::new()),
210            }),
211            tools,
212            state: SharedState::new(),
213            compaction_manager,
214            hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
215            is_running: Arc::new(AtomicBool::new(false)),
216            resolver,
217            cancel_flag: Arc::new(AtomicBool::new(false)),
218            pending_model_switch: RwLock::new(None),
219        }
220    }
221
222    /// Get a reference to the provider resolver.
223    pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
224        &self.resolver
225    }
226
227    /// Switch the model used for future LLM calls.
228    ///
229    /// Switch model mid-conversation.
230    ///
231    /// If the agent is currently running, the switch is deferred: the new
232    /// model and provider are stored in `pending_model_switch` and applied
233    /// automatically when the current loop finishes. This ensures the
234    /// running loop completes with a consistent provider/model without
235    /// interruption.
236    ///
237    /// If the agent is idle, the switch takes effect immediately.
238    ///
239    /// If the new model uses a different provider API, the conversation
240    /// history is automatically transformed for cross-provider compatibility
241    /// (e.g. thinking blocks are converted to `<thinking>` tags).
242    ///
243    /// # Arguments
244    /// * `model_id` - New model ID in `provider/model` format
245    /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
246    ///
247    /// # Returns
248    /// `Ok(())` on success, or an error if the model/provider is unknown
249    pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
250        let new_model = self
251            .resolver
252            .resolve_model(model_id)
253            .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
254
255        // Create the new provider via resolver
256        let new_provider = self
257            .resolver
258            .resolve_provider(&new_model.provider)
259            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
260
261        // Detect API change
262        let (old_api, needs_transform) = {
263            let inner = self.config();
264            let old_api = self
265                .resolver
266                .resolve_model(&inner.config.model_id)
267                .map(|m| m.api)
268                .unwrap_or(oxi_ai::Api::AnthropicMessages);
269            (old_api, old_api != new_model.api)
270        };
271
272        // If the agent is currently running, defer the switch.
273        if self.is_running.load(Ordering::SeqCst) {
274            tracing::info!(
275                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
276                model_id
277            );
278            *self.pending_model_switch.write() = Some(PendingModelSwitch {
279                model_id: model_id.to_string(),
280                provider: new_provider,
281                needs_transform,
282                old_api,
283                new_api: new_model.api,
284            });
285            // Update config immediately so model_id() returns the new value,
286            // but leave provider unchanged so the running loop keeps its provider.
287            {
288                let mut inner = self.inner_mut();
289                inner.config.model_id = model_id.to_string();
290                inner.config.api_key = api_key;
291            }
292            return Ok(());
293        }
294
295        // Agent is idle — apply immediately.
296        if needs_transform {
297            let messages = self.state.get_state().messages.clone();
298            let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
299            self.state.update(|s| {
300                s.replace_messages(transformed);
301            });
302        }
303
304        // Update config and provider atomically
305        let mut inner = self.inner_mut();
306        inner.config.model_id = model_id.to_string();
307        inner.config.api_key = api_key;
308        inner.provider = new_provider;
309
310        Ok(())
311    }
312
313    /// Switch the model using a pre-resolved `Model` object.
314    ///
315    /// This is useful when the caller has already looked up the model
316    /// and optionally created the provider.
317    ///
318    /// Like [`switch_model`], if the agent is currently running, the switch
319    /// is deferred until the current loop completes.
320    ///
321    /// [`switch_model`]: Agent::switch_model
322    pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
323        let model_id = format!("{}/{}", model.provider, model.id);
324        let new_provider = self
325            .resolver
326            .resolve_provider(&model.provider)
327            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
328
329        // Detect API change
330        let (old_api, needs_transform) = {
331            let inner = self.config();
332            let old_api = self
333                .resolver
334                .resolve_model(&inner.config.model_id)
335                .map(|m| m.api)
336                .unwrap_or(oxi_ai::Api::AnthropicMessages);
337            (old_api, old_api != model.api)
338        };
339
340        // If the agent is currently running, defer the switch.
341        if self.is_running.load(Ordering::SeqCst) {
342            tracing::info!(
343                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
344                model_id
345            );
346            *self.pending_model_switch.write() = Some(PendingModelSwitch {
347                model_id: model_id.clone(),
348                provider: new_provider,
349                needs_transform,
350                old_api,
351                new_api: model.api,
352            });
353            let mut inner = self.inner_mut();
354            inner.config.model_id = model_id;
355            inner.config.api_key = api_key;
356            return Ok(());
357        }
358
359        // Agent is idle — apply immediately.
360        if needs_transform {
361            let messages = self.state.get_state().messages.clone();
362            let transformed = transform_for_provider(&messages, &old_api, &model.api);
363            self.state.update(|s| {
364                s.replace_messages(transformed);
365            });
366        }
367
368        let mut inner = self.inner_mut();
369        inner.config.model_id = model_id;
370        inner.config.api_key = api_key;
371        inner.provider = new_provider;
372
373        Ok(())
374    }
375
376    /// Refresh only the API key without changing model or provider.
377    /// Useful when the user stores a key after the session was already created.
378    pub fn refresh_api_key(&self, api_key: Option<String>) {
379        let mut inner = self.inner_mut();
380        inner.config.api_key = api_key;
381    }
382
383    /// Get a handle to the tool registry.
384    pub fn tools(&self) -> Arc<ToolRegistry> {
385        Arc::clone(&self.tools)
386    }
387
388    /// Get a snapshot of the current agent state.
389    pub fn state(&self) -> AgentState {
390        self.state.get_state()
391    }
392
393    /// Update agent state in-place. Used by compaction to replace messages.
394    pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
395        self.state.update(f);
396    }
397
398    /// Reset agent state for a new conversation
399    pub fn reset(&self) {
400        self.state.reset();
401    }
402
403    /// Register a tool that the agent can invoke during a run.
404    pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
405        self.tools.register(tool);
406    }
407
408    /// Update the system prompt for future interactions.
409    pub fn set_system_prompt(&self, prompt: String) {
410        self.inner_mut().config.system_prompt = Some(prompt);
411    }
412
413    /// Get the compaction manager
414    pub fn compaction_manager(&self) -> &CompactionManager {
415        &self.compaction_manager
416    }
417    /// Update the compaction strategy for future runs.
418    ///
419    /// The strategy is read fresh from the config at the start of each run
420    /// (see `run_with_channel_inner`), so this takes effect on the next
421    /// agent turn — never mid-run. Pair with `compaction_manager()` for
422    /// manual compaction, which is unaffected by the strategy.
423    pub fn set_compaction_strategy(&self, strategy: oxi_ai::CompactionStrategy) {
424        self.inner.write().config.compaction_strategy = strategy;
425    }
426    /// Get the compaction strategy that will be used on the next run.
427    ///
428    /// This reads from `inner.config` (mutable via `set_compaction_strategy`),
429    /// **not** from the `compaction_manager` field (which retains its
430    /// construction-time strategy). The agent loop reads from config fresh
431    /// each run, so this is the authoritative value.
432    pub fn compaction_strategy(&self) -> oxi_ai::CompactionStrategy {
433        self.inner.read().config.compaction_strategy.clone()
434    }
435
436    /// Run the agent with a prompt, collecting all events into a vector.
437    ///
438    /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
439    /// [`AgentEvent`] produced during the run.
440    pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
441        let mut events = Vec::new();
442        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
443        let result = self.run_with_channel(prompt, tx).await;
444        while let Ok(event) = rx.recv() {
445            events.push(event);
446        }
447        result.map(|r| (r, events))
448    }
449
450    /// Run the agent, delivering events through the provided channel.
451    ///
452    /// Delegates to the agent loop which implements the same 2-level agentic
453    /// loop matching pi-mono's architecture:
454    ///
455    /// ```text
456    /// AgentLoop.run_messages()
457    ///   Outer loop (follow-up messages):
458    ///     Inner loop (tool calls + steering):
459    ///       1. Inject pending messages (steering)
460    ///       2. Compaction check
461    ///       3. Stream LLM response (with accumulated partial messages)
462    ///       4. Execute tool calls if any
463    ///       5. Emit turn_end
464    ///       6. Check shouldStopAfterTurn
465    ///       7. Poll steering messages
466    ///     Check follow-up messages
467    ///     Exit
468    /// ```
469    pub async fn run_with_channel(
470        &self,
471        prompt: String,
472        tx: std::sync::mpsc::Sender<AgentEvent>,
473    ) -> Result<Response> {
474        // pi-mono: Agent.prompt() throws if activeRun exists.
475        // Prevent concurrent runs that would corrupt shared state.
476        if self
477            .is_running
478            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
479            .is_err()
480        {
481            return Err(Error::msg("Agent is already running"));
482        }
483
484        // Drop guard ensures is_running is cleared even on panic.
485        struct RunningGuard<'a>(&'a AtomicBool);
486        impl Drop for RunningGuard<'_> {
487            fn drop(&mut self) {
488                self.0.store(false, Ordering::SeqCst);
489            }
490        }
491        let _guard = RunningGuard(&self.is_running);
492        self.reset_cancel();
493
494        self.run_with_channel_inner(prompt, tx).await
495    }
496
497    /// Inner implementation of run_with_channel, called after the running guard is set.
498    async fn run_with_channel_inner(
499        &self,
500        prompt: String,
501        tx: std::sync::mpsc::Sender<AgentEvent>,
502    ) -> Result<Response> {
503        use crate::agent_loop::AgentLoop;
504
505        let (
506            provider,
507            system_prompt,
508            temperature,
509            max_tokens,
510            compaction_strategy,
511            context_window,
512            api_key,
513            workspace_dir,
514        ) = {
515            let inner = self.inner.read();
516            (
517                Arc::clone(&inner.provider) as Arc<dyn Provider>,
518                inner.config.system_prompt.clone(),
519                inner.config.temperature,
520                inner.config.max_tokens,
521                inner.config.compaction_strategy.clone(),
522                inner.config.context_window,
523                inner.config.api_key.clone(),
524                inner.config.workspace_dir.clone(),
525            )
526        }; // release read lock
527
528        // Build AgentLoopConfig from Agent's config
529        let loop_config = crate::agent_loop::config::AgentLoopConfig {
530            model_id: self.model_id(),
531            system_prompt,
532            temperature: temperature.unwrap_or(1.0) as f32,
533            max_tokens: max_tokens.unwrap_or(4096) as u32,
534            tool_execution: crate::config::ToolExecutionMode::Sequential,
535            compaction_strategy,
536            compaction_instruction: None,
537            context_window,
538            session_id: self.config().config.session_id.clone(),
539            transport: None,
540            compact_on_start: false,
541            max_retry_delay_ms: None,
542            auto_retry_enabled: true,
543            auto_retry_max_attempts: 3,
544            auto_retry_base_delay_ms: 1000,
545            api_key,
546            workspace_dir,
547            provider_options: self.config().config.provider_options.clone(),
548            on_compaction: None,
549            ttsr_engine: self.config().config.ttsr_engine.clone(),
550            memory: self.config().config.memory.clone(),
551            todo: self.config().config.todo.clone(),
552            agent_pool: self.config().config.agent_pool.clone(),
553            ..Default::default()
554        };
555
556        // Create AgentLoop. We give it a NEW SharedState and sync back after.
557        // (SharedState is not Clone, so we create a fresh one from current state)
558        let fresh_state = crate::state::SharedState::new();
559        let current = self.state.get_state();
560        fresh_state.update(|s| {
561            *s = current;
562        });
563
564        let mut agent_loop = AgentLoop::new_with_resolver(
565            provider,
566            loop_config,
567            Arc::clone(&self.tools),
568            fresh_state,
569            Arc::clone(&self.resolver),
570        );
571
572        // Add the user prompt to Agent.state() AFTER fresh_state is created.
573        // fresh_state got a copy of the pre-prompt state, so run_loop will
574        // add the prompt to fresh_state independently via initial_prompts.
575        // But persist_session() reads Agent.state() (not fresh_state), so it
576        // needs the user prompt there to write it to the session file.
577        // Sync happens at AgentEnd (after run_loop completes), where
578        // Agent.state is overwritten with fresh_state (which has all messages).
579        self.state.update(|s| {
580            s.messages
581                .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
582                    prompt.clone(),
583                )));
584        });
585
586        // Pre-populate steering/follow-up from hooks
587        {
588            let hooks = self.hooks.read();
589            if let Some(ref get_steering) = hooks.get_steering_messages {
590                for msg_text in get_steering() {
591                    agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
592                }
593            }
594            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
595                for msg_text in get_follow_up() {
596                    agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
597                }
598            }
599
600            // Store hooks on AgentLoop so they can be polled each turn
601            // to pick up new messages injected during the run.
602            if let Some(ref get_steering) = hooks.get_steering_messages {
603                agent_loop.set_steering_hook(Arc::clone(get_steering));
604            }
605            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
606                agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
607            }
608        }
609        let mut al = agent_loop;
610
611        // Wire should_stop_after_turn hook: share AgentLoop's external_stop
612        // Arc with the emit callback. When the hook fires (Ctrl+C detected),
613        // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
614        // AND during streaming (streaming.rs checks external_stop each event).
615        //
616        // Arc<dyn Fn> can be cloned, so we read it without consuming.
617        let maybe_hook = {
618            let hooks_r = self.hooks.read();
619            hooks_r.should_stop_after_turn.clone()
620        };
621        let ext_stop = al.external_stop().clone();
622        let cancel_flag = self.cancel_flag.clone();
623
624        // Share cancel_flag with AgentLoop so the streaming loop can check
625        // it directly in the periodic timer — no emit callback required.
626        // This closes the gap where cancel() was ineffective when the
627        // provider stream produced no events.
628        al.set_cancel_signal(self.cancel_flag.clone());
629
630        // Create emit callback that sends through the channel.
631        // AgentLoop calls this synchronously. UnboundedSender::send() is
632        // non-blocking and never drops events (unlike try_send on bounded).
633        let tx_emit = tx.clone();
634
635        // Snapshot the observability_dispatch list once per run. This avoids
636        // holding an Agent lock on the emit-fn hot path while still letting
637        // SDK consumers register new dispatchers at any time (registers after
638        // this snapshot will fire on the next run).
639        let dispatch_handlers: Vec<EventDispatchFn> =
640            { self.inner.read().observability_dispatch.lock().clone() };
641        tracing::info!("[AGENT] Starting agent run with channel");
642        let result = al
643            .run(prompt.clone(), move |event: AgentEvent| {
644                // Forward event to channel (std::sync::mpsc — send from sync context)
645                tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
646                if let Err(e) = tx_emit.send(event.clone()) {
647                    tracing::error!(
648                        "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
649                        e
650                    );
651                } else {
652                    tracing::info!("[AGENT-EMIT] Successfully sent event");
653                }
654
655                // Propagate cancellation from Agent::cancel() → external_stop.
656                // This runs on every event, ensuring the streaming loop detects
657                // cancellation promptly.
658                if cancel_flag.load(Ordering::SeqCst) {
659                    ext_stop.store(true, Ordering::SeqCst);
660                }
661
662                // Fan out to SDK-side observability handlers (Tracer,
663                // CostTracker, ...). The dispatch list is snapshotted at
664                // run-start so we hold Arc clones, not a lock. This means
665                // handlers added mid-run do not fire until the next run.
666                for handler in dispatch_handlers.iter() {
667                    handler(event.clone());
668                }
669                // Propagate should_stop → external_stop on every event, not
670                // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
671                // so the context contents are irrelevant for non-TurnEnd events.
672                // This ensures streaming.rs detects cancellation immediately
673                // when the user presses Ctrl+C mid-stream.
674                if let Some(ref hook) = maybe_hook {
675                    let ctx = ShouldStopAfterTurnContext {
676                        message: match &event {
677                            AgentEvent::TurnEnd {
678                                assistant_message: oxi_ai::Message::Assistant(a),
679                                ..
680                            } => a.clone(),
681                            _ => oxi_ai::AssistantMessage::new(
682                                oxi_ai::Api::OpenAiCompletions,
683                                "agent",
684                                "agent-model",
685                            ),
686                        },
687                        tool_results: match &event {
688                            AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
689                            _ => Vec::new(),
690                        },
691                        iteration: 0,
692                    };
693                    if hook(&ctx) {
694                        ext_stop.store(true, Ordering::SeqCst);
695                    }
696                }
697            })
698            .await;
699
700        match result {
701            Ok(_events) => {
702                // Sync state back from AgentLoop
703                let loop_state = al.state().get_state();
704                self.state.update(|s| {
705                    *s = loop_state;
706                });
707
708                // Apply any pending model switch that was deferred during the run.
709                // This transforms messages (if cross-provider) and swaps the provider
710                // so the next run uses the new model.
711                self.apply_pending_model_switch();
712
713                // Extract final response text from state
714                let state = self.state.get_state();
715                let final_text = state
716                    .messages
717                    .iter()
718                    .rev()
719                    .find_map(|m| match m {
720                        oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
721                            oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
722                            _ => None,
723                        }),
724                        _ => None,
725                    })
726                    .unwrap_or_default();
727
728                let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
729
730                Ok(Response {
731                    content: final_text,
732                    stop_reason,
733                })
734            }
735            Err(e) => {
736                // Apply pending model switch even on error so the next run
737                // uses the new model.
738                self.apply_pending_model_switch();
739                Err(e)
740            }
741        }
742    }
743
744    // ── Helper methods for the agentic loop ────────────────────────
745
746    /// Set hooks for the agent loop.
747    pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
748        let mut h = self.hooks.write();
749        *h = hooks;
750    }
751
752    /// Register a side-dispatch closure called for every `AgentEvent`
753    /// emitted by `run`, `run_with_channel`, `run_streaming`,
754    /// `run_tokio_stream`, and `continue_with`.
755    ///
756    /// Multiple calls stack: every registered closure is invoked on
757    /// every event. Closures run synchronously on the agent-loop emit
758    /// thread, so they must be cheap and non-blocking. Long work
759    /// should be spawned off (e.g. `tokio::spawn`) by the closure
760    /// itself.
761    ///
762    /// Used by `oxi-sdk` to bridge observability types
763    /// (`Tracer`, `CostTracker`, `AuditLog`, `Authorizer` /
764    /// `AccessGate`) into the runtime without leaking those types
765    /// into `oxi-agent`.
766    ///
767    /// # Example
768    ///
769    /// ```ignore
770    /// agent.add_observability_dispatch(|event| match event {
771    ///     AgentEvent::TurnStart { turn_number } => {
772    ///         // open a span
773    ///     }
774    ///     AgentEvent::Usage { input_tokens, output_tokens } => {
775    ///         // record cost
776    ///     }
777    ///     _ => {}
778    /// });
779    /// ```
780    pub fn add_observability_dispatch(&self, f: impl Fn(AgentEvent) + Send + Sync + 'static) {
781        let guard = self.inner.write();
782        let mut slot = guard.observability_dispatch.lock();
783        slot.push(Arc::new(f));
784    }
785
786    /// Request cancellation of the current agent run.
787    ///
788    /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
789    /// `external_stop` on every event AND polled every ~500ms by the
790    /// streaming loop's periodic check. This ensures cancellation is
791    /// detected quickly even when the provider stream is completely hung
792    /// (no events arriving).
793    pub fn cancel(&self) {
794        self.cancel_flag.store(true, Ordering::SeqCst);
795    }
796
797    /// Reset the cancellation flag before starting a new run.
798    pub fn reset_cancel(&self) {
799        self.cancel_flag.store(false, Ordering::SeqCst);
800    }
801
802    /// Apply any pending model switch that was deferred during a running loop.
803    ///
804    /// Called after `run_with_channel_inner` completes (success or error).
805    /// Transforms messages for cross-provider switches and swaps the provider
806    /// so the next run uses the new model.
807    fn apply_pending_model_switch(&self) {
808        let pending = self.pending_model_switch.write().take();
809        if let Some(pending) = pending {
810            tracing::info!(
811                "[AGENT] Applying deferred model switch to '{}' (transform={})",
812                pending.model_id,
813                pending.needs_transform
814            );
815
816            // Transform messages if cross-provider
817            if pending.needs_transform {
818                let messages = self.state.get_state().messages.clone();
819                let transformed =
820                    transform_for_provider(&messages, &pending.old_api, &pending.new_api);
821                self.state.update(|s| {
822                    s.replace_messages(transformed);
823                });
824            }
825
826            // Swap the provider
827            let mut inner = self.inner_mut();
828            inner.provider = pending.provider;
829            // model_id was already updated in switch_model()
830        }
831    }
832
833    /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
834    ///
835    /// Blocking convenience wrapper suitable for callers that prefer a
836    /// callback-based API over a channel.
837    pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
838    where
839        F: FnMut(AgentEvent) + Send,
840    {
841        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
842        let result = self.run_with_channel(prompt, tx).await;
843        while let Ok(event) = rx.recv() {
844            on_event(event);
845        }
846        result
847    }
848
849    // ── Session persistence ────────────────────────────────────────
850
851    /// Export the agent state as a JSON value.
852    ///
853    /// The serialized state includes conversation messages, token counts,
854    /// iteration progress, and stop reason. Use [`import_state`] to restore.
855    ///
856    /// [`import_state`]: Agent::import_state
857    pub fn export_state(&self) -> Result<serde_json::Value> {
858        let state = self.state.get_state();
859        serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
860    }
861
862    /// Import agent state from a JSON value.
863    ///
864    /// Restores conversation history, token counts, and iteration progress.
865    /// Typically used together with [`export_state`] for session persistence.
866    ///
867    /// [`export_state`]: Agent::export_state
868    pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
869        let state: AgentState = serde_json::from_value(value)
870            .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
871        self.state.update(|s| *s = state);
872        Ok(())
873    }
874
875    // ── Session continuation ───────────────────────────────────────
876
877    /// Continue the current session with a new prompt.
878    ///
879    /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
880    /// preserves the existing conversation state and appends the new prompt.
881    /// This enables multi-turn interactions within the same session.
882    pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
883        let mut events = Vec::new();
884        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
885        let result = self.run_with_channel(prompt, tx).await;
886        while let Ok(event) = rx.recv() {
887            events.push(event);
888        }
889        result.map(|r| (r, events))
890    }
891
892    // ── Tokio-native streaming ─────────────────────────────────────
893
894    /// Run the agent with tokio-native event streaming.
895    ///
896    /// Returns a `tokio::sync::mpsc::Receiver` for events and a
897    /// `JoinHandle` for the response. This is the preferred API for
898    /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
899    ///
900    /// # Example
901    ///
902    /// ```ignore
903    /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
904    /// while let Some(event) = rx.recv().await {
905    ///     println!("Event: {:?}", event.type_name());
906    /// }
907    /// let response = handle.await??;
908    /// ```
909    pub async fn run_tokio_stream(
910        &self,
911        prompt: String,
912    ) -> Result<(
913        tokio::sync::mpsc::Receiver<AgentEvent>,
914        tokio::task::JoinHandle<Result<Response>>,
915    )> {
916        let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
917
918        if self
919            .is_running
920            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
921            .is_err()
922        {
923            return Err(Error::msg("Agent is already running"));
924        }
925
926        let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
927
928        let inner = self.inner.read().clone();
929        let tools = Arc::clone(&self.tools);
930        let resolver = Arc::clone(&self.resolver);
931
932        // Build AgentLoopConfig
933        let loop_config = crate::agent_loop::config::AgentLoopConfig {
934            model_id: inner.config.model_id.clone(),
935            system_prompt: inner.config.system_prompt.clone(),
936            temperature: inner.config.temperature.unwrap_or(1.0) as f32,
937            max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
938            tool_execution: crate::config::ToolExecutionMode::Sequential,
939            compaction_strategy: inner.config.compaction_strategy.clone(),
940            compaction_instruction: None,
941            context_window: inner.config.context_window,
942            session_id: inner.config.session_id.clone(),
943            transport: None,
944            compact_on_start: false,
945            max_retry_delay_ms: None,
946            auto_retry_enabled: true,
947            auto_retry_max_attempts: 3,
948            auto_retry_base_delay_ms: 1000,
949            api_key: inner.config.api_key.clone(),
950            workspace_dir: inner.config.workspace_dir.clone(),
951            provider_options: inner.config.provider_options.clone(),
952            on_compaction: None,
953            ttsr_engine: inner.config.ttsr_engine.clone(),
954            ..Default::default()
955        };
956
957        let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
958
959        // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
960        // agent loop so that state mutations inside the spawned task are
961        // visible through self.state() without an explicit sync step.
962        //
963        // Unlike run_with_channel_inner which creates a fresh SharedState
964        // and syncs back on completion, the tokio streaming API cannot
965        // access `self` inside the `'static` spawned task, so we share
966        // the underlying Arc instead.
967        //
968        // Pre-load current state into the shared Arc (in case it was
969        // modified by a previous run that used a different SharedState).
970        let shared_state = self.state.clone();
971
972        let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
973            provider,
974            loop_config,
975            tools,
976            shared_state.clone(),
977            resolver,
978        );
979
980        let maybe_hook = should_stop_hook;
981        let ext_stop = agent_loop.external_stop().clone();
982
983        // Clone the is_running Arc so the spawned task can clear it.
984        let is_running_flag = Arc::clone(&self.is_running);
985
986        // Snapshot the observability_dispatch list before the spawned
987        // task. The future is `'static` and cannot borrow `&self`,
988        // so we take the snapshot at run-start on the regular borrow
989        // stack and move the resulting Arc-clones into the task.
990        let dispatch_handlers: Vec<EventDispatchFn> = {
991            let guard = self.inner.read();
992            guard.observability_dispatch.lock().clone()
993        };
994
995        let handle = tokio::task::spawn(async move {
996            let result = agent_loop
997                .run(prompt, move |event: AgentEvent| {
998                    // Forward to tokio channel (non-blocking)
999                    let _ = tx.try_send(event.clone());
1000
1001                    // Fan out to SDK-side observability handlers
1002                    // (Tracer, CostTracker, ...).
1003                    for handler in dispatch_handlers.iter() {
1004                        handler(event.clone());
1005                    }
1006                    // Propagate should_stop → external_stop on every event,
1007                    // not just TurnEnd. See run_with_channel_inner for rationale.
1008                    if let Some(ref hook) = maybe_hook {
1009                        let ctx = ShouldStopAfterTurnContext {
1010                            message: match &event {
1011                                AgentEvent::TurnEnd {
1012                                    assistant_message: oxi_ai::Message::Assistant(a),
1013                                    ..
1014                                } => a.clone(),
1015                                _ => oxi_ai::AssistantMessage::new(
1016                                    oxi_ai::Api::OpenAiCompletions,
1017                                    "agent",
1018                                    "agent-model",
1019                                ),
1020                            },
1021                            tool_results: match &event {
1022                                AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
1023                                _ => Vec::new(),
1024                            },
1025                            iteration: 0,
1026                        };
1027                        if hook(&ctx) {
1028                            ext_stop.store(true, Ordering::SeqCst);
1029                        }
1030                    }
1031                })
1032                .await;
1033
1034            // Clear the Agent's running flag
1035            is_running_flag.store(false, Ordering::SeqCst);
1036
1037            match result {
1038                Ok(_events) => {
1039                    // State is already shared via the same SharedState Arc,
1040                    // so self.state() will reflect all mutations.
1041                    Ok(Response {
1042                        content: String::new(),
1043                        stop_reason: StopReason::Stop,
1044                    })
1045                }
1046                Err(e) => Err(e),
1047            }
1048        });
1049
1050        Ok((rx, handle))
1051    }
1052}