Skip to main content

oxi_agent/
agent.rs

1/// Core agent implementation
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10    CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
/// Trait for resolving providers and models within an Agent.
///
/// This abstracts away global static registries, allowing SDK users
/// to provide isolated provider/model lookups.
///
/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
/// When using `Agent::new()` directly, a global fallback is used.
///
/// The agent stores its resolver as `Arc<dyn ProviderResolver>`, hence the
/// `Send + Sync + 'static` supertraits.
pub trait ProviderResolver: Send + Sync + 'static {
    /// Resolve a provider by name, returning an Arc handle.
    ///
    /// `None` is reported by callers as "provider not found".
    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;

    /// Resolve a model ID ("provider/model" or bare "model") to a Model.
    ///
    /// `None` is reported by callers as "model not found".
    fn resolve_model(&self, model_id: &str) -> Option<Model>;
}
32
33/// Global provider resolver — uses `oxi_ai` global functions.
34///
35/// This is the default resolver when using `Agent::new()`, preserving
36/// backward compatibility with existing CLI usage.
37pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41        oxi_ai::get_provider(name).map(Arc::from)
42    }
43
44    fn resolve_model(&self, model_id: &str) -> Option<Model> {
45        crate::model_id::resolve_model_from_id(model_id)
46    }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
50
51/// Mutable agent internals protected by a read-write lock.
52struct AgentInner {
53    config: AgentConfig,
54    provider: Arc<dyn Provider>,
55}
56
57impl Clone for AgentInner {
58    fn clone(&self) -> Self {
59        Self {
60            config: self.config.clone(),
61            provider: Arc::clone(&self.provider),
62        }
63    }
64}
65
/// Deferred model switch request, stored when the agent is running.
///
/// Created by `switch_model` / `switch_to_model` while a run is in progress
/// and consumed by `apply_pending_model_switch` once that run finishes.
struct PendingModelSwitch {
    /// ID of the model being switched to (used for logging when applied).
    model_id: String,
    /// Provider that replaces the agent's current provider when applied.
    provider: Arc<dyn Provider>,
    /// Whether messages need cross-provider transformation.
    needs_transform: bool,
    /// API the conversation history is transformed from.
    old_api: oxi_ai::Api,
    /// API the conversation history is transformed to.
    new_api: oxi_ai::Api,
}
85
/// Agent runtime.
///
/// Manages provider, tool registry, state, and compaction, providing an
/// agentic loop for prompt execution, model switching, tool calls, and fallback.
///
/// Supports session continuation, tokio-native event streaming, and deferred
/// model switching (changes are queued while a loop is running and applied
/// after it completes).
#[allow(missing_docs)]
pub struct Agent {
    /// Config and provider, guarded together so they can be swapped as a pair.
    inner: RwLock<AgentInner>,
    /// Tool registry shared with the agent loop for each run.
    tools: Arc<ToolRegistry>,
    /// Conversation state; copied into the loop at run start and synced back
    /// after a successful run.
    state: SharedState,
    /// Compaction manager built at construction time. It keeps the strategy it
    /// was constructed with; the per-run strategy is read from `inner.config`.
    compaction_manager: CompactionManager,
    /// User-supplied hooks (steering, follow-up, should-stop) read at run start.
    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
    /// Guard: true while a run is in progress. Prevents concurrent runs.
    is_running: Arc<AtomicBool>,
    /// Provider/model resolver. Uses global functions by default,
    /// or a custom resolver when created via `new_with_resolver()`.
    resolver: Arc<dyn ProviderResolver>,
    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
    /// propagated to AgentLoop's `external_stop` during each run.
    cancel_flag: Arc<AtomicBool>,
    /// Pending model switch — stored when the agent is running,
    /// applied after the current loop completes.
    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
}
113
114impl Agent {
115    /// Create a new agent with the given provider, config, and tool registry.
116    ///
117    /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
118    /// for model switching. For isolated instances, use [`new_with_resolver`].
119    ///
120    /// [`new_with_resolver`]: Agent::new_with_resolver
121    pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
122        let resolver = Arc::new(GlobalProviderResolver);
123        Self::build_inner(provider, config, tools, resolver)
124    }
125
    /// Create an agent with a custom provider/model resolver.
    ///
    /// This is the preferred constructor for SDK usage where provider
    /// and model registries must be isolated from global state.
    ///
    /// The `resolver` is used both at construction (to resolve the model for
    /// the LLM compactor) and later for model switching.
    pub fn new_with_resolver(
        provider: Arc<dyn Provider>,
        config: AgentConfig,
        tools: Arc<ToolRegistry>,
        resolver: Arc<dyn ProviderResolver>,
    ) -> Self {
        Self::build_inner(provider, config, tools, resolver)
    }
138
139    /// Internal constructor shared by `new()` and `new_with_resolver()`.
140    fn build_inner(
141        provider: Arc<dyn Provider>,
142        config: AgentConfig,
143        tools: Arc<ToolRegistry>,
144        resolver: Arc<dyn ProviderResolver>,
145    ) -> Self {
146        let mut compaction_manager =
147            CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
148
149        // Pre-initialize the LLM compactor if compaction is enabled
150        if config.compaction_strategy != CompactionStrategy::Disabled {
151            let model = resolver.resolve_model(&config.model_id);
152
153            if let Some(model) = model {
154                let llm_compactor =
155                    Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
156                compaction_manager.set_compactor(llm_compactor);
157            }
158        }
159
160        Self {
161            inner: RwLock::new(AgentInner { config, provider }),
162            tools,
163            state: SharedState::new(),
164            compaction_manager,
165            hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
166            is_running: Arc::new(AtomicBool::new(false)),
167            resolver,
168            cancel_flag: Arc::new(AtomicBool::new(false)),
169            pending_model_switch: RwLock::new(None),
170        }
171    }
172
173    /// Create an agent with an empty tool registry.
174    pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
175        Self::new(provider, config, Arc::new(ToolRegistry::new()))
176    }
177
    /// Acquire a read guard over the agent internals (config and provider).
    ///
    /// Despite the name, the guard dereferences to `AgentInner`; the
    /// configuration itself is the guard's `config` field.
    fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
        self.inner.read()
    }
182
    /// Acquire a write guard over the agent internals (config and provider).
    ///
    /// Blocks until the exclusive lock is available.
    fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
        self.inner.write()
    }
187
188    /// Get the current model ID
189    pub fn model_id(&self) -> String {
190        self.config().config.model_id.clone()
191    }
192
193    /// Get the agent configuration (full clone)
194    pub fn get_config(&self) -> AgentConfig {
195        self.config().config.clone()
196    }
197
    /// Get a reference to the provider resolver.
    ///
    /// This is the resolver supplied at construction: the global one for
    /// `new()`, or the caller's for `new_with_resolver()`.
    pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
        &self.resolver
    }
202
203    /// Switch the model used for future LLM calls.
204    ///
205    /// Switch model mid-conversation.
206    ///
207    /// If the agent is currently running, the switch is deferred: the new
208    /// model and provider are stored in `pending_model_switch` and applied
209    /// automatically when the current loop finishes. This ensures the
210    /// running loop completes with a consistent provider/model without
211    /// interruption.
212    ///
213    /// If the agent is idle, the switch takes effect immediately.
214    ///
215    /// If the new model uses a different provider API, the conversation
216    /// history is automatically transformed for cross-provider compatibility
217    /// (e.g. thinking blocks are converted to `<thinking>` tags).
218    ///
219    /// # Arguments
220    /// * `model_id` - New model ID in `provider/model` format
221    /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
222    ///
223    /// # Returns
224    /// `Ok(())` on success, or an error if the model/provider is unknown
225    pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
226        let new_model = self
227            .resolver
228            .resolve_model(model_id)
229            .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
230
231        // Create the new provider via resolver
232        let new_provider = self
233            .resolver
234            .resolve_provider(&new_model.provider)
235            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
236
237        // Detect API change
238        let (old_api, needs_transform) = {
239            let inner = self.config();
240            let old_api = self
241                .resolver
242                .resolve_model(&inner.config.model_id)
243                .map(|m| m.api)
244                .unwrap_or(oxi_ai::Api::AnthropicMessages);
245            (old_api, old_api != new_model.api)
246        };
247
248        // If the agent is currently running, defer the switch.
249        if self.is_running.load(Ordering::SeqCst) {
250            tracing::info!(
251                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
252                model_id
253            );
254            *self.pending_model_switch.write() = Some(PendingModelSwitch {
255                model_id: model_id.to_string(),
256                provider: new_provider,
257                needs_transform,
258                old_api,
259                new_api: new_model.api,
260            });
261            // Update config immediately so model_id() returns the new value,
262            // but leave provider unchanged so the running loop keeps its provider.
263            {
264                let mut inner = self.inner_mut();
265                inner.config.model_id = model_id.to_string();
266                inner.config.api_key = api_key;
267            }
268            return Ok(());
269        }
270
271        // Agent is idle — apply immediately.
272        if needs_transform {
273            let messages = self.state.get_state().messages.clone();
274            let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
275            self.state.update(|s| {
276                s.replace_messages(transformed);
277            });
278        }
279
280        // Update config and provider atomically
281        let mut inner = self.inner_mut();
282        inner.config.model_id = model_id.to_string();
283        inner.config.api_key = api_key;
284        inner.provider = new_provider;
285
286        Ok(())
287    }
288
289    /// Switch the model using a pre-resolved `Model` object.
290    ///
291    /// This is useful when the caller has already looked up the model
292    /// and optionally created the provider.
293    ///
294    /// Like [`switch_model`], if the agent is currently running, the switch
295    /// is deferred until the current loop completes.
296    ///
297    /// [`switch_model`]: Agent::switch_model
298    pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
299        let model_id = format!("{}/{}", model.provider, model.id);
300        let new_provider = self
301            .resolver
302            .resolve_provider(&model.provider)
303            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
304
305        // Detect API change
306        let (old_api, needs_transform) = {
307            let inner = self.config();
308            let old_api = self
309                .resolver
310                .resolve_model(&inner.config.model_id)
311                .map(|m| m.api)
312                .unwrap_or(oxi_ai::Api::AnthropicMessages);
313            (old_api, old_api != model.api)
314        };
315
316        // If the agent is currently running, defer the switch.
317        if self.is_running.load(Ordering::SeqCst) {
318            tracing::info!(
319                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
320                model_id
321            );
322            *self.pending_model_switch.write() = Some(PendingModelSwitch {
323                model_id: model_id.clone(),
324                provider: new_provider,
325                needs_transform,
326                old_api,
327                new_api: model.api,
328            });
329            let mut inner = self.inner_mut();
330            inner.config.model_id = model_id;
331            inner.config.api_key = api_key;
332            return Ok(());
333        }
334
335        // Agent is idle — apply immediately.
336        if needs_transform {
337            let messages = self.state.get_state().messages.clone();
338            let transformed = transform_for_provider(&messages, &old_api, &model.api);
339            self.state.update(|s| {
340                s.replace_messages(transformed);
341            });
342        }
343
344        let mut inner = self.inner_mut();
345        inner.config.model_id = model_id;
346        inner.config.api_key = api_key;
347        inner.provider = new_provider;
348
349        Ok(())
350    }
351
352    /// Refresh only the API key without changing model or provider.
353    /// Useful when the user stores a key after the session was already created.
354    pub fn refresh_api_key(&self, api_key: Option<String>) {
355        let mut inner = self.inner_mut();
356        inner.config.api_key = api_key;
357    }
358
    /// Get a handle to the tool registry.
    ///
    /// The returned `Arc` points at the same registry the agent uses.
    pub fn tools(&self) -> Arc<ToolRegistry> {
        Arc::clone(&self.tools)
    }
363
    /// Get a snapshot of the current agent state.
    ///
    /// The value is returned by value; changing it does not affect the agent.
    pub fn state(&self) -> AgentState {
        self.state.get_state()
    }
368
    /// Update agent state in-place. Used by compaction to replace messages.
    ///
    /// `f` receives a mutable reference to the state and is forwarded to the
    /// shared state's `update`.
    pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
        self.state.update(f);
    }
373
    /// Reset agent state for a new conversation.
    ///
    /// Only the shared state is reset; config, provider, tools, and hooks are
    /// left as they are.
    pub fn reset(&self) {
        self.state.reset();
    }
378
    /// Register a tool that the agent can invoke during a run.
    ///
    /// The tool is added to the shared registry, so it is also visible through
    /// handles previously returned by `tools()`.
    pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
        self.tools.register(tool);
    }
383
384    /// Update the system prompt for future interactions.
385    pub fn set_system_prompt(&self, prompt: String) {
386        self.inner_mut().config.system_prompt = Some(prompt);
387    }
388
    /// Get the compaction manager.
    ///
    /// This is the manager built at construction time; it is not rebuilt when
    /// the configured compaction strategy changes.
    pub fn compaction_manager(&self) -> &CompactionManager {
        &self.compaction_manager
    }
393    /// Update the compaction strategy for future runs.
394    ///
395    /// The strategy is read fresh from the config at the start of each run
396    /// (see `run_with_channel_inner`), so this takes effect on the next
397    /// agent turn — never mid-run. Pair with `compaction_manager()` for
398    /// manual compaction, which is unaffected by the strategy.
399    pub fn set_compaction_strategy(&self, strategy: oxi_ai::CompactionStrategy) {
400        self.inner.write().config.compaction_strategy = strategy;
401    }
402    /// Get the compaction strategy that will be used on the next run.
403    ///
404    /// This reads from `inner.config` (mutable via `set_compaction_strategy`),
405    /// **not** from the `compaction_manager` field (which retains its
406    /// construction-time strategy). The agent loop reads from config fresh
407    /// each run, so this is the authoritative value.
408    pub fn compaction_strategy(&self) -> oxi_ai::CompactionStrategy {
409        self.inner.read().config.compaction_strategy.clone()
410    }
411
412    /// Run the agent with a prompt, collecting all events into a vector.
413    ///
414    /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
415    /// [`AgentEvent`] produced during the run.
416    pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
417        let mut events = Vec::new();
418        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
419        let result = self.run_with_channel(prompt, tx).await;
420        while let Ok(event) = rx.recv() {
421            events.push(event);
422        }
423        result.map(|r| (r, events))
424    }
425
426    /// Run the agent, delivering events through the provided channel.
427    ///
428    /// Delegates to the agent loop which implements the same 2-level agentic
429    /// loop matching pi-mono's architecture:
430    ///
431    /// ```text
432    /// AgentLoop.run_messages()
433    ///   Outer loop (follow-up messages):
434    ///     Inner loop (tool calls + steering):
435    ///       1. Inject pending messages (steering)
436    ///       2. Compaction check
437    ///       3. Stream LLM response (with accumulated partial messages)
438    ///       4. Execute tool calls if any
439    ///       5. Emit turn_end
440    ///       6. Check shouldStopAfterTurn
441    ///       7. Poll steering messages
442    ///     Check follow-up messages
443    ///     Exit
444    /// ```
445    pub async fn run_with_channel(
446        &self,
447        prompt: String,
448        tx: std::sync::mpsc::Sender<AgentEvent>,
449    ) -> Result<Response> {
450        // pi-mono: Agent.prompt() throws if activeRun exists.
451        // Prevent concurrent runs that would corrupt shared state.
452        if self
453            .is_running
454            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
455            .is_err()
456        {
457            return Err(Error::msg("Agent is already running"));
458        }
459
460        // Drop guard ensures is_running is cleared even on panic.
461        struct RunningGuard<'a>(&'a AtomicBool);
462        impl Drop for RunningGuard<'_> {
463            fn drop(&mut self) {
464                self.0.store(false, Ordering::SeqCst);
465            }
466        }
467        let _guard = RunningGuard(&self.is_running);
468        self.reset_cancel();
469
470        self.run_with_channel_inner(prompt, tx).await
471    }
472
473    /// Inner implementation of run_with_channel, called after the running guard is set.
474    async fn run_with_channel_inner(
475        &self,
476        prompt: String,
477        tx: std::sync::mpsc::Sender<AgentEvent>,
478    ) -> Result<Response> {
479        use crate::agent_loop::AgentLoop;
480
481        let (
482            provider,
483            system_prompt,
484            temperature,
485            max_tokens,
486            compaction_strategy,
487            context_window,
488            api_key,
489            workspace_dir,
490        ) = {
491            let inner = self.inner.read();
492            (
493                Arc::clone(&inner.provider) as Arc<dyn Provider>,
494                inner.config.system_prompt.clone(),
495                inner.config.temperature,
496                inner.config.max_tokens,
497                inner.config.compaction_strategy.clone(),
498                inner.config.context_window,
499                inner.config.api_key.clone(),
500                inner.config.workspace_dir.clone(),
501            )
502        }; // release read lock
503
504        // Build AgentLoopConfig from Agent's config
505        let loop_config = crate::agent_loop::config::AgentLoopConfig {
506            model_id: self.model_id(),
507            system_prompt,
508            temperature: temperature.unwrap_or(1.0) as f32,
509            max_tokens: max_tokens.unwrap_or(4096) as u32,
510            tool_execution: crate::config::ToolExecutionMode::Sequential,
511            compaction_strategy,
512            compaction_instruction: None,
513            context_window,
514            session_id: self.config().config.session_id.clone(),
515            transport: None,
516            compact_on_start: false,
517            max_retry_delay_ms: None,
518            auto_retry_enabled: true,
519            auto_retry_max_attempts: 3,
520            auto_retry_base_delay_ms: 1000,
521            api_key,
522            workspace_dir,
523            provider_options: self.config().config.provider_options.clone(),
524            on_compaction: None,
525            ttsr_engine: self.config().config.ttsr_engine.clone(),
526            memory: self.config().config.memory.clone(),
527            todo: self.config().config.todo.clone(),
528            agent_pool: self.config().config.agent_pool.clone(),
529            ..Default::default()
530        };
531
532        // Create AgentLoop. We give it a NEW SharedState and sync back after.
533        // (SharedState is not Clone, so we create a fresh one from current state)
534        let fresh_state = crate::state::SharedState::new();
535        let current = self.state.get_state();
536        fresh_state.update(|s| {
537            *s = current;
538        });
539
540        let mut agent_loop = AgentLoop::new_with_resolver(
541            provider,
542            loop_config,
543            Arc::clone(&self.tools),
544            fresh_state,
545            Arc::clone(&self.resolver),
546        );
547
548        // Add the user prompt to Agent.state() AFTER fresh_state is created.
549        // fresh_state got a copy of the pre-prompt state, so run_loop will
550        // add the prompt to fresh_state independently via initial_prompts.
551        // But persist_session() reads Agent.state() (not fresh_state), so it
552        // needs the user prompt there to write it to the session file.
553        // Sync happens at AgentEnd (after run_loop completes), where
554        // Agent.state is overwritten with fresh_state (which has all messages).
555        self.state.update(|s| {
556            s.messages
557                .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
558                    prompt.clone(),
559                )));
560        });
561
562        // Pre-populate steering/follow-up from hooks
563        {
564            let hooks = self.hooks.read();
565            if let Some(ref get_steering) = hooks.get_steering_messages {
566                for msg_text in get_steering() {
567                    agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
568                }
569            }
570            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
571                for msg_text in get_follow_up() {
572                    agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
573                }
574            }
575
576            // Store hooks on AgentLoop so they can be polled each turn
577            // to pick up new messages injected during the run.
578            if let Some(ref get_steering) = hooks.get_steering_messages {
579                agent_loop.set_steering_hook(Arc::clone(get_steering));
580            }
581            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
582                agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
583            }
584        }
585        let mut al = agent_loop;
586
587        // Wire should_stop_after_turn hook: share AgentLoop's external_stop
588        // Arc with the emit callback. When the hook fires (Ctrl+C detected),
589        // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
590        // AND during streaming (streaming.rs checks external_stop each event).
591        //
592        // Arc<dyn Fn> can be cloned, so we read it without consuming.
593        let maybe_hook = {
594            let hooks_r = self.hooks.read();
595            hooks_r.should_stop_after_turn.clone()
596        };
597        let ext_stop = al.external_stop().clone();
598        let cancel_flag = self.cancel_flag.clone();
599
600        // Share cancel_flag with AgentLoop so the streaming loop can check
601        // it directly in the periodic timer — no emit callback required.
602        // This closes the gap where cancel() was ineffective when the
603        // provider stream produced no events.
604        al.set_cancel_signal(self.cancel_flag.clone());
605
606        // Create emit callback that sends through the channel.
607        // AgentLoop calls this synchronously. UnboundedSender::send() is
608        // non-blocking and never drops events (unlike try_send on bounded).
609        let tx_emit = tx.clone();
610
611        // Run the agent loop
612        tracing::info!("[AGENT] Starting agent run with channel");
613        let result = al
614            .run(prompt.clone(), move |event: AgentEvent| {
615                // Forward event to channel (std::sync::mpsc — send from sync context)
616                tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
617                if let Err(e) = tx_emit.send(event.clone()) {
618                    tracing::error!(
619                        "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
620                        e
621                    );
622                } else {
623                    tracing::info!("[AGENT-EMIT] Successfully sent event");
624                }
625
626                // Propagate cancellation from Agent::cancel() → external_stop.
627                // This runs on every event, ensuring the streaming loop detects
628                // cancellation promptly.
629                if cancel_flag.load(Ordering::SeqCst) {
630                    ext_stop.store(true, Ordering::SeqCst);
631                }
632
633                // Propagate should_stop → external_stop on every event, not
634                // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
635                // so the context contents are irrelevant for non-TurnEnd events.
636                // This ensures streaming.rs detects cancellation immediately
637                // when the user presses Ctrl+C mid-stream.
638                if let Some(ref hook) = maybe_hook {
639                    let ctx = ShouldStopAfterTurnContext {
640                        message: match &event {
641                            AgentEvent::TurnEnd {
642                                assistant_message: oxi_ai::Message::Assistant(a),
643                                ..
644                            } => a.clone(),
645                            _ => oxi_ai::AssistantMessage::new(
646                                oxi_ai::Api::OpenAiCompletions,
647                                "agent",
648                                "agent-model",
649                            ),
650                        },
651                        tool_results: match &event {
652                            AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
653                            _ => Vec::new(),
654                        },
655                        iteration: 0,
656                    };
657                    if hook(&ctx) {
658                        ext_stop.store(true, Ordering::SeqCst);
659                    }
660                }
661            })
662            .await;
663
664        match result {
665            Ok(_events) => {
666                // Sync state back from AgentLoop
667                let loop_state = al.state().get_state();
668                self.state.update(|s| {
669                    *s = loop_state;
670                });
671
672                // Apply any pending model switch that was deferred during the run.
673                // This transforms messages (if cross-provider) and swaps the provider
674                // so the next run uses the new model.
675                self.apply_pending_model_switch();
676
677                // Extract final response text from state
678                let state = self.state.get_state();
679                let final_text = state
680                    .messages
681                    .iter()
682                    .rev()
683                    .find_map(|m| match m {
684                        oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
685                            oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
686                            _ => None,
687                        }),
688                        _ => None,
689                    })
690                    .unwrap_or_default();
691
692                let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
693
694                Ok(Response {
695                    content: final_text,
696                    stop_reason,
697                })
698            }
699            Err(e) => {
700                // Apply pending model switch even on error so the next run
701                // uses the new model.
702                self.apply_pending_model_switch();
703                Err(e)
704            }
705        }
706    }
707
708    // ── Helper methods for the agentic loop ────────────────────────
709
710    /// Set hooks for the agent loop.
711    pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
712        let mut h = self.hooks.write();
713        *h = hooks;
714    }
715
    /// Request cancellation of the current agent run.
    ///
    /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
    /// `external_stop` on every event AND polled every ~500ms by the
    /// streaming loop's periodic check. This ensures cancellation is
    /// detected quickly even when the provider stream is completely hung
    /// (no events arriving).
    ///
    /// The flag is cleared again at the start of the next run, so a call made
    /// while the agent is idle has no lasting effect.
    pub fn cancel(&self) {
        self.cancel_flag.store(true, Ordering::SeqCst);
    }
726
    /// Reset the cancellation flag before starting a new run.
    ///
    /// `run_with_channel` calls this right after it claims the running flag.
    pub fn reset_cancel(&self) {
        self.cancel_flag.store(false, Ordering::SeqCst);
    }
731
732    /// Apply any pending model switch that was deferred during a running loop.
733    ///
734    /// Called after `run_with_channel_inner` completes (success or error).
735    /// Transforms messages for cross-provider switches and swaps the provider
736    /// so the next run uses the new model.
737    fn apply_pending_model_switch(&self) {
738        let pending = self.pending_model_switch.write().take();
739        if let Some(pending) = pending {
740            tracing::info!(
741                "[AGENT] Applying deferred model switch to '{}' (transform={})",
742                pending.model_id,
743                pending.needs_transform
744            );
745
746            // Transform messages if cross-provider
747            if pending.needs_transform {
748                let messages = self.state.get_state().messages.clone();
749                let transformed =
750                    transform_for_provider(&messages, &pending.old_api, &pending.new_api);
751                self.state.update(|s| {
752                    s.replace_messages(transformed);
753                });
754            }
755
756            // Swap the provider
757            let mut inner = self.inner_mut();
758            inner.provider = pending.provider;
759            // model_id was already updated in switch_model()
760        }
761    }
762
763    /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
764    ///
765    /// Blocking convenience wrapper suitable for callers that prefer a
766    /// callback-based API over a channel.
767    pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
768    where
769        F: FnMut(AgentEvent) + Send,
770    {
771        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
772        let result = self.run_with_channel(prompt, tx).await;
773        while let Ok(event) = rx.recv() {
774            on_event(event);
775        }
776        result
777    }
778
779    // ── Session persistence ────────────────────────────────────────
780
781    /// Export the agent state as a JSON value.
782    ///
783    /// The serialized state includes conversation messages, token counts,
784    /// iteration progress, and stop reason. Use [`import_state`] to restore.
785    ///
786    /// [`import_state`]: Agent::import_state
787    pub fn export_state(&self) -> Result<serde_json::Value> {
788        let state = self.state.get_state();
789        serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
790    }
791
792    /// Import agent state from a JSON value.
793    ///
794    /// Restores conversation history, token counts, and iteration progress.
795    /// Typically used together with [`export_state`] for session persistence.
796    ///
797    /// [`export_state`]: Agent::export_state
798    pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
799        let state: AgentState = serde_json::from_value(value)
800            .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
801        self.state.update(|s| *s = state);
802        Ok(())
803    }
804
805    // ── Session continuation ───────────────────────────────────────
806
807    /// Continue the current session with a new prompt.
808    ///
809    /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
810    /// preserves the existing conversation state and appends the new prompt.
811    /// This enables multi-turn interactions within the same session.
812    pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
813        let mut events = Vec::new();
814        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
815        let result = self.run_with_channel(prompt, tx).await;
816        while let Ok(event) = rx.recv() {
817            events.push(event);
818        }
819        result.map(|r| (r, events))
820    }
821
822    // ── Tokio-native streaming ─────────────────────────────────────
823
824    /// Run the agent with tokio-native event streaming.
825    ///
826    /// Returns a `tokio::sync::mpsc::Receiver` for events and a
827    /// `JoinHandle` for the response. This is the preferred API for
828    /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
829    ///
830    /// # Example
831    ///
832    /// ```ignore
833    /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
834    /// while let Some(event) = rx.recv().await {
835    ///     println!("Event: {:?}", event.type_name());
836    /// }
837    /// let response = handle.await??;
838    /// ```
839    pub async fn run_tokio_stream(
840        &self,
841        prompt: String,
842    ) -> Result<(
843        tokio::sync::mpsc::Receiver<AgentEvent>,
844        tokio::task::JoinHandle<Result<Response>>,
845    )> {
846        let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
847
848        if self
849            .is_running
850            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
851            .is_err()
852        {
853            return Err(Error::msg("Agent is already running"));
854        }
855
856        let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
857
858        let inner = self.inner.read().clone();
859        let tools = Arc::clone(&self.tools);
860        let resolver = Arc::clone(&self.resolver);
861
862        // Build AgentLoopConfig
863        let loop_config = crate::agent_loop::config::AgentLoopConfig {
864            model_id: inner.config.model_id.clone(),
865            system_prompt: inner.config.system_prompt.clone(),
866            temperature: inner.config.temperature.unwrap_or(1.0) as f32,
867            max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
868            tool_execution: crate::config::ToolExecutionMode::Sequential,
869            compaction_strategy: inner.config.compaction_strategy.clone(),
870            compaction_instruction: None,
871            context_window: inner.config.context_window,
872            session_id: inner.config.session_id.clone(),
873            transport: None,
874            compact_on_start: false,
875            max_retry_delay_ms: None,
876            auto_retry_enabled: true,
877            auto_retry_max_attempts: 3,
878            auto_retry_base_delay_ms: 1000,
879            api_key: inner.config.api_key.clone(),
880            workspace_dir: inner.config.workspace_dir.clone(),
881            provider_options: inner.config.provider_options.clone(),
882            on_compaction: None,
883            ttsr_engine: inner.config.ttsr_engine.clone(),
884            ..Default::default()
885        };
886
887        let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
888
889        // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
890        // agent loop so that state mutations inside the spawned task are
891        // visible through self.state() without an explicit sync step.
892        //
893        // Unlike run_with_channel_inner which creates a fresh SharedState
894        // and syncs back on completion, the tokio streaming API cannot
895        // access `self` inside the `'static` spawned task, so we share
896        // the underlying Arc instead.
897        //
898        // Pre-load current state into the shared Arc (in case it was
899        // modified by a previous run that used a different SharedState).
900        let shared_state = self.state.clone();
901
902        let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
903            provider,
904            loop_config,
905            tools,
906            shared_state.clone(),
907            resolver,
908        );
909
910        let maybe_hook = should_stop_hook;
911        let ext_stop = agent_loop.external_stop().clone();
912
913        // Clone the is_running Arc so the spawned task can clear it.
914        let is_running_flag = Arc::clone(&self.is_running);
915
916        let handle = tokio::task::spawn(async move {
917            let result = agent_loop
918                .run(prompt, move |event: AgentEvent| {
919                    // Forward to tokio channel (non-blocking)
920                    let _ = tx.try_send(event.clone());
921
922                    // Propagate should_stop → external_stop on every event,
923                    // not just TurnEnd. See run_with_channel_inner for rationale.
924                    if let Some(ref hook) = maybe_hook {
925                        let ctx = ShouldStopAfterTurnContext {
926                            message: match &event {
927                                AgentEvent::TurnEnd {
928                                    assistant_message: oxi_ai::Message::Assistant(a),
929                                    ..
930                                } => a.clone(),
931                                _ => oxi_ai::AssistantMessage::new(
932                                    oxi_ai::Api::OpenAiCompletions,
933                                    "agent",
934                                    "agent-model",
935                                ),
936                            },
937                            tool_results: match &event {
938                                AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
939                                _ => Vec::new(),
940                            },
941                            iteration: 0,
942                        };
943                        if hook(&ctx) {
944                            ext_stop.store(true, Ordering::SeqCst);
945                        }
946                    }
947                })
948                .await;
949
950            // Clear the Agent's running flag
951            is_running_flag.store(false, Ordering::SeqCst);
952
953            match result {
954                Ok(_events) => {
955                    // State is already shared via the same SharedState Arc,
956                    // so self.state() will reflect all mutations.
957                    Ok(Response {
958                        content: String::new(),
959                        stop_reason: StopReason::Stop,
960                    })
961                }
962                Err(e) => Err(e),
963            }
964        });
965
966        Ok((rx, handle))
967    }
968}