// oxi_agent/agent.rs

//! Core agent implementation.
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10    CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
18/// Trait for resolving providers and models within an Agent.
19///
20/// This abstracts away global static registries, allowing SDK users
21/// to provide isolated provider/model lookups.
22///
23/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
24/// When using `Agent::new()` directly, a global fallback is used.
25pub trait ProviderResolver: Send + Sync + 'static {
26    /// Resolve a provider by name, returning an Arc handle.
27    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;
28
29    /// Resolve a model ID ("provider/model" or bare "model") to a Model.
30    fn resolve_model(&self, model_id: &str) -> Option<Model>;
31}
32
33/// Global provider resolver — uses `oxi_ai` global functions.
34///
35/// This is the default resolver when using `Agent::new()`, preserving
36/// backward compatibility with existing CLI usage.
37pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41        oxi_ai::get_provider(name).map(Arc::from)
42    }
43
44    fn resolve_model(&self, model_id: &str) -> Option<Model> {
45        crate::model_id::resolve_model_from_id(model_id)
46    }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
50
51/// Mutable agent internals protected by a read-write lock.
52struct AgentInner {
53    config: AgentConfig,
54    provider: Arc<dyn Provider>,
55}
56
57impl Clone for AgentInner {
58    fn clone(&self) -> Self {
59        Self {
60            config: self.config.clone(),
61            provider: Arc::clone(&self.provider),
62        }
63    }
64}
65
66/// Agent runtime.
67///
68/// Manages provider, tool registry, state, and compaction, providing an
69/// agentic loop for prompt execution, model switching, tool calls, and fallback.
70///
71/// Supports session continuation via [`continue_with`] and tokio-native
72/// event streaming via [`run_tokio_stream`].
73///
74/// [`continue_with`]: Agent::continue_with
75/// [`run_tokio_stream`]: Agent::run_tokio_stream
76/// Deferred model switch request, stored when the agent is running.
77struct PendingModelSwitch {
78    model_id: String,
79    provider: Arc<dyn Provider>,
80    /// Whether messages need cross-provider transformation.
81    needs_transform: bool,
82    old_api: oxi_ai::Api,
83    new_api: oxi_ai::Api,
84}
85
86/// Agent runtime.
87///
88/// Manages provider, tool registry, state, and compaction, providing an
89/// agentic loop for prompt execution, model switching, tool calls, and fallback.
90///
91/// Supports session continuation, tokio-native event streaming, and deferred
92/// model switching (changes are queued while a loop is running and applied
93/// after it completes).
94#[allow(missing_docs)]
95pub struct Agent {
96    inner: RwLock<AgentInner>,
97    tools: Arc<ToolRegistry>,
98    state: SharedState,
99    compaction_manager: CompactionManager,
100    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
101    /// Guard: true while a run is in progress. Prevents concurrent runs.
102    is_running: Arc<AtomicBool>,
103    /// Provider/model resolver. Uses global functions by default,
104    /// or a custom resolver when created via `new_with_resolver()`.
105    resolver: Arc<dyn ProviderResolver>,
106    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
107    /// propagated to AgentLoop's `external_stop` during each run.
108    cancel_flag: Arc<AtomicBool>,
109    /// Pending model switch — stored when the agent is running,
110    /// applied after the current loop completes.
111    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
112}
113
114impl Agent {
115    /// Create a new agent with the given provider, config, and tool registry.
116    ///
117    /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
118    /// for model switching. For isolated instances, use [`new_with_resolver`].
119    ///
120    /// [`new_with_resolver`]: Agent::new_with_resolver
121    pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
122        let resolver = Arc::new(GlobalProviderResolver);
123        Self::build_inner(provider, config, tools, resolver)
124    }
125
126    /// Create an agent with a custom provider/model resolver.
127    ///
128    /// This is the preferred constructor for SDK usage where provider
129    /// and model registries must be isolated from global state.
130    pub fn new_with_resolver(
131        provider: Arc<dyn Provider>,
132        config: AgentConfig,
133        tools: Arc<ToolRegistry>,
134        resolver: Arc<dyn ProviderResolver>,
135    ) -> Self {
136        Self::build_inner(provider, config, tools, resolver)
137    }
138
139    /// Internal constructor shared by `new()` and `new_with_resolver()`.
140    fn build_inner(
141        provider: Arc<dyn Provider>,
142        config: AgentConfig,
143        tools: Arc<ToolRegistry>,
144        resolver: Arc<dyn ProviderResolver>,
145    ) -> Self {
146        let mut compaction_manager =
147            CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
148
149        // Pre-initialize the LLM compactor if compaction is enabled
150        if config.compaction_strategy != CompactionStrategy::Disabled {
151            let model = resolver.resolve_model(&config.model_id);
152
153            if let Some(model) = model {
154                let llm_compactor =
155                    Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
156                compaction_manager.set_compactor(llm_compactor);
157            }
158        }
159
160        Self {
161            inner: RwLock::new(AgentInner { config, provider }),
162            tools,
163            state: SharedState::new(),
164            compaction_manager,
165            hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
166            is_running: Arc::new(AtomicBool::new(false)),
167            resolver,
168            cancel_flag: Arc::new(AtomicBool::new(false)),
169            pending_model_switch: RwLock::new(None),
170        }
171    }
172
173    /// Create an agent with an empty tool registry.
174    pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
175        Self::new(provider, config, Arc::new(ToolRegistry::new()))
176    }
177
178    /// Get the agent configuration (read guard)
179    fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
180        self.inner.read()
181    }
182
183    /// Get a write guard for the agent inner state
184    fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
185        self.inner.write()
186    }
187
188    /// Get the current model ID
189    pub fn model_id(&self) -> String {
190        self.config().config.model_id.clone()
191    }
192
193    /// Get the agent configuration (full clone)
194    pub fn get_config(&self) -> AgentConfig {
195        self.config().config.clone()
196    }
197
198    /// Get a reference to the provider resolver.
199    pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
200        &self.resolver
201    }
202
203    /// Switch the model used for future LLM calls.
204    ///
205    /// Switch model mid-conversation.
206    ///
207    /// If the agent is currently running, the switch is deferred: the new
208    /// model and provider are stored in `pending_model_switch` and applied
209    /// automatically when the current loop finishes. This ensures the
210    /// running loop completes with a consistent provider/model without
211    /// interruption.
212    ///
213    /// If the agent is idle, the switch takes effect immediately.
214    ///
215    /// If the new model uses a different provider API, the conversation
216    /// history is automatically transformed for cross-provider compatibility
217    /// (e.g. thinking blocks are converted to `<thinking>` tags).
218    ///
219    /// # Arguments
220    /// * `model_id` - New model ID in `provider/model` format
221    /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
222    ///
223    /// # Returns
224    /// `Ok(())` on success, or an error if the model/provider is unknown
225    pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
226        let new_model = self
227            .resolver
228            .resolve_model(model_id)
229            .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
230
231        // Create the new provider via resolver
232        let new_provider = self
233            .resolver
234            .resolve_provider(&new_model.provider)
235            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
236
237        // Detect API change
238        let (old_api, needs_transform) = {
239            let inner = self.config();
240            let old_api = self
241                .resolver
242                .resolve_model(&inner.config.model_id)
243                .map(|m| m.api)
244                .unwrap_or(oxi_ai::Api::AnthropicMessages);
245            (old_api, old_api != new_model.api)
246        };
247
248        // If the agent is currently running, defer the switch.
249        if self.is_running.load(Ordering::SeqCst) {
250            tracing::info!(
251                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
252                model_id
253            );
254            *self.pending_model_switch.write() = Some(PendingModelSwitch {
255                model_id: model_id.to_string(),
256                provider: new_provider,
257                needs_transform,
258                old_api,
259                new_api: new_model.api,
260            });
261            // Update config immediately so model_id() returns the new value,
262            // but leave provider unchanged so the running loop keeps its provider.
263            {
264                let mut inner = self.inner_mut();
265                inner.config.model_id = model_id.to_string();
266                inner.config.api_key = api_key;
267            }
268            return Ok(());
269        }
270
271        // Agent is idle — apply immediately.
272        if needs_transform {
273            let messages = self.state.get_state().messages.clone();
274            let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
275            self.state.update(|s| {
276                s.replace_messages(transformed);
277            });
278        }
279
280        // Update config and provider atomically
281        let mut inner = self.inner_mut();
282        inner.config.model_id = model_id.to_string();
283        inner.config.api_key = api_key;
284        inner.provider = new_provider;
285
286        Ok(())
287    }
288
289    /// Switch the model using a pre-resolved `Model` object.
290    ///
291    /// This is useful when the caller has already looked up the model
292    /// and optionally created the provider.
293    ///
294    /// Like [`switch_model`], if the agent is currently running, the switch
295    /// is deferred until the current loop completes.
296    ///
297    /// [`switch_model`]: Agent::switch_model
298    pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
299        let model_id = format!("{}/{}", model.provider, model.id);
300        let new_provider = self
301            .resolver
302            .resolve_provider(&model.provider)
303            .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
304
305        // Detect API change
306        let (old_api, needs_transform) = {
307            let inner = self.config();
308            let old_api = self
309                .resolver
310                .resolve_model(&inner.config.model_id)
311                .map(|m| m.api)
312                .unwrap_or(oxi_ai::Api::AnthropicMessages);
313            (old_api, old_api != model.api)
314        };
315
316        // If the agent is currently running, defer the switch.
317        if self.is_running.load(Ordering::SeqCst) {
318            tracing::info!(
319                "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
320                model_id
321            );
322            *self.pending_model_switch.write() = Some(PendingModelSwitch {
323                model_id: model_id.clone(),
324                provider: new_provider,
325                needs_transform,
326                old_api,
327                new_api: model.api,
328            });
329            let mut inner = self.inner_mut();
330            inner.config.model_id = model_id;
331            inner.config.api_key = api_key;
332            return Ok(());
333        }
334
335        // Agent is idle — apply immediately.
336        if needs_transform {
337            let messages = self.state.get_state().messages.clone();
338            let transformed = transform_for_provider(&messages, &old_api, &model.api);
339            self.state.update(|s| {
340                s.replace_messages(transformed);
341            });
342        }
343
344        let mut inner = self.inner_mut();
345        inner.config.model_id = model_id;
346        inner.config.api_key = api_key;
347        inner.provider = new_provider;
348
349        Ok(())
350    }
351
352    /// Refresh only the API key without changing model or provider.
353    /// Useful when the user stores a key after the session was already created.
354    pub fn refresh_api_key(&self, api_key: Option<String>) {
355        let mut inner = self.inner_mut();
356        inner.config.api_key = api_key;
357    }
358
359    /// Get a handle to the tool registry.
360    pub fn tools(&self) -> Arc<ToolRegistry> {
361        Arc::clone(&self.tools)
362    }
363
364    /// Get a snapshot of the current agent state.
365    pub fn state(&self) -> AgentState {
366        self.state.get_state()
367    }
368
369    /// Update agent state in-place. Used by compaction to replace messages.
370    pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
371        self.state.update(f);
372    }
373
374    /// Reset agent state for a new conversation
375    pub fn reset(&self) {
376        self.state.reset();
377    }
378
379    /// Register a tool that the agent can invoke during a run.
380    pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
381        self.tools.register(tool);
382    }
383
384    /// Update the system prompt for future interactions.
385    pub fn set_system_prompt(&self, prompt: String) {
386        self.inner_mut().config.system_prompt = Some(prompt);
387    }
388
389    /// Get the compaction manager
390    pub fn compaction_manager(&self) -> &CompactionManager {
391        &self.compaction_manager
392    }
393
394    /// Run the agent with a prompt, collecting all events into a vector.
395    ///
396    /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
397    /// [`AgentEvent`] produced during the run.
398    pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
399        let mut events = Vec::new();
400        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
401        let result = self.run_with_channel(prompt, tx).await;
402        while let Ok(event) = rx.recv() {
403            events.push(event);
404        }
405        result.map(|r| (r, events))
406    }
407
408    /// Run the agent, delivering events through the provided channel.
409    ///
410    /// Delegates to the agent loop which implements the same 2-level agentic
411    /// loop matching pi-mono's architecture:
412    ///
413    /// ```text
414    /// AgentLoop.run_messages()
415    ///   Outer loop (follow-up messages):
416    ///     Inner loop (tool calls + steering):
417    ///       1. Inject pending messages (steering)
418    ///       2. Compaction check
419    ///       3. Stream LLM response (with accumulated partial messages)
420    ///       4. Execute tool calls if any
421    ///       5. Emit turn_end
422    ///       6. Check shouldStopAfterTurn
423    ///       7. Poll steering messages
424    ///     Check follow-up messages
425    ///     Exit
426    /// ```
427    pub async fn run_with_channel(
428        &self,
429        prompt: String,
430        tx: std::sync::mpsc::Sender<AgentEvent>,
431    ) -> Result<Response> {
432        // pi-mono: Agent.prompt() throws if activeRun exists.
433        // Prevent concurrent runs that would corrupt shared state.
434        if self
435            .is_running
436            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
437            .is_err()
438        {
439            return Err(Error::msg("Agent is already running"));
440        }
441
442        // Drop guard ensures is_running is cleared even on panic.
443        struct RunningGuard<'a>(&'a AtomicBool);
444        impl Drop for RunningGuard<'_> {
445            fn drop(&mut self) {
446                self.0.store(false, Ordering::SeqCst);
447            }
448        }
449        let _guard = RunningGuard(&self.is_running);
450        self.reset_cancel();
451
452        self.run_with_channel_inner(prompt, tx).await
453    }
454
455    /// Inner implementation of run_with_channel, called after the running guard is set.
456    async fn run_with_channel_inner(
457        &self,
458        prompt: String,
459        tx: std::sync::mpsc::Sender<AgentEvent>,
460    ) -> Result<Response> {
461        use crate::agent_loop::AgentLoop;
462
463        let (
464            provider,
465            system_prompt,
466            temperature,
467            max_tokens,
468            compaction_strategy,
469            context_window,
470            api_key,
471            workspace_dir,
472        ) = {
473            let inner = self.inner.read();
474            (
475                Arc::clone(&inner.provider) as Arc<dyn Provider>,
476                inner.config.system_prompt.clone(),
477                inner.config.temperature,
478                inner.config.max_tokens,
479                inner.config.compaction_strategy.clone(),
480                inner.config.context_window,
481                inner.config.api_key.clone(),
482                inner.config.workspace_dir.clone(),
483            )
484        }; // release read lock
485
486        // Build AgentLoopConfig from Agent's config
487        let loop_config = crate::agent_loop::config::AgentLoopConfig {
488            model_id: self.model_id(),
489            system_prompt,
490            temperature: temperature.unwrap_or(1.0) as f32,
491            max_tokens: max_tokens.unwrap_or(4096) as u32,
492            tool_execution: crate::config::ToolExecutionMode::Sequential,
493            compaction_strategy,
494            compaction_instruction: None,
495            context_window,
496            session_id: self.config().config.session_id.clone(),
497            transport: None,
498            compact_on_start: false,
499            max_retry_delay_ms: None,
500            auto_retry_enabled: true,
501            auto_retry_max_attempts: 3,
502            auto_retry_base_delay_ms: 1000,
503            api_key,
504            workspace_dir,
505            provider_options: self.config().config.provider_options.clone(),
506            on_compaction: None,
507            ttsr_engine: self.config().config.ttsr_engine.clone(),
508            memory: self.config().config.memory.clone(),
509            todo: self.config().config.todo.clone(),
510            agent_pool: self.config().config.agent_pool.clone(),
511            ..Default::default()
512        };
513
514        // Create AgentLoop. We give it a NEW SharedState and sync back after.
515        // (SharedState is not Clone, so we create a fresh one from current state)
516        let fresh_state = crate::state::SharedState::new();
517        let current = self.state.get_state();
518        fresh_state.update(|s| {
519            *s = current;
520        });
521
522        let mut agent_loop = AgentLoop::new_with_resolver(
523            provider,
524            loop_config,
525            Arc::clone(&self.tools),
526            fresh_state,
527            Arc::clone(&self.resolver),
528        );
529
530        // Add the user prompt to Agent.state() AFTER fresh_state is created.
531        // fresh_state got a copy of the pre-prompt state, so run_loop will
532        // add the prompt to fresh_state independently via initial_prompts.
533        // But persist_session() reads Agent.state() (not fresh_state), so it
534        // needs the user prompt there to write it to the session file.
535        // Sync happens at AgentEnd (after run_loop completes), where
536        // Agent.state is overwritten with fresh_state (which has all messages).
537        self.state.update(|s| {
538            s.messages
539                .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
540                    prompt.clone(),
541                )));
542        });
543
544        // Pre-populate steering/follow-up from hooks
545        {
546            let hooks = self.hooks.read();
547            if let Some(ref get_steering) = hooks.get_steering_messages {
548                for msg_text in get_steering() {
549                    agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
550                }
551            }
552            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
553                for msg_text in get_follow_up() {
554                    agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
555                }
556            }
557
558            // Store hooks on AgentLoop so they can be polled each turn
559            // to pick up new messages injected during the run.
560            if let Some(ref get_steering) = hooks.get_steering_messages {
561                agent_loop.set_steering_hook(Arc::clone(get_steering));
562            }
563            if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
564                agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
565            }
566        }
567        let mut al = agent_loop;
568
569        // Wire should_stop_after_turn hook: share AgentLoop's external_stop
570        // Arc with the emit callback. When the hook fires (Ctrl+C detected),
571        // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
572        // AND during streaming (streaming.rs checks external_stop each event).
573        //
574        // Arc<dyn Fn> can be cloned, so we read it without consuming.
575        let maybe_hook = {
576            let hooks_r = self.hooks.read();
577            hooks_r.should_stop_after_turn.clone()
578        };
579        let ext_stop = al.external_stop().clone();
580        let cancel_flag = self.cancel_flag.clone();
581
582        // Share cancel_flag with AgentLoop so the streaming loop can check
583        // it directly in the periodic timer — no emit callback required.
584        // This closes the gap where cancel() was ineffective when the
585        // provider stream produced no events.
586        al.set_cancel_signal(self.cancel_flag.clone());
587
588        // Create emit callback that sends through the channel.
589        // AgentLoop calls this synchronously. UnboundedSender::send() is
590        // non-blocking and never drops events (unlike try_send on bounded).
591        let tx_emit = tx.clone();
592
593        // Run the agent loop
594        tracing::info!("[AGENT] Starting agent run with channel");
595        let result = al
596            .run(prompt.clone(), move |event: AgentEvent| {
597                // Forward event to channel (std::sync::mpsc — send from sync context)
598                tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
599                if let Err(e) = tx_emit.send(event.clone()) {
600                    tracing::error!(
601                        "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
602                        e
603                    );
604                } else {
605                    tracing::info!("[AGENT-EMIT] Successfully sent event");
606                }
607
608                // Propagate cancellation from Agent::cancel() → external_stop.
609                // This runs on every event, ensuring the streaming loop detects
610                // cancellation promptly.
611                if cancel_flag.load(Ordering::SeqCst) {
612                    ext_stop.store(true, Ordering::SeqCst);
613                }
614
615                // Propagate should_stop → external_stop on every event, not
616                // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
617                // so the context contents are irrelevant for non-TurnEnd events.
618                // This ensures streaming.rs detects cancellation immediately
619                // when the user presses Ctrl+C mid-stream.
620                if let Some(ref hook) = maybe_hook {
621                    let ctx = ShouldStopAfterTurnContext {
622                        message: match &event {
623                            AgentEvent::TurnEnd {
624                                assistant_message: oxi_ai::Message::Assistant(a),
625                                ..
626                            } => a.clone(),
627                            _ => oxi_ai::AssistantMessage::new(
628                                oxi_ai::Api::OpenAiCompletions,
629                                "agent",
630                                "agent-model",
631                            ),
632                        },
633                        tool_results: match &event {
634                            AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
635                            _ => Vec::new(),
636                        },
637                        iteration: 0,
638                    };
639                    if hook(&ctx) {
640                        ext_stop.store(true, Ordering::SeqCst);
641                    }
642                }
643            })
644            .await;
645
646        match result {
647            Ok(_events) => {
648                // Sync state back from AgentLoop
649                let loop_state = al.state().get_state();
650                self.state.update(|s| {
651                    *s = loop_state;
652                });
653
654                // Apply any pending model switch that was deferred during the run.
655                // This transforms messages (if cross-provider) and swaps the provider
656                // so the next run uses the new model.
657                self.apply_pending_model_switch();
658
659                // Extract final response text from state
660                let state = self.state.get_state();
661                let final_text = state
662                    .messages
663                    .iter()
664                    .rev()
665                    .find_map(|m| match m {
666                        oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
667                            oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
668                            _ => None,
669                        }),
670                        _ => None,
671                    })
672                    .unwrap_or_default();
673
674                let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
675
676                Ok(Response {
677                    content: final_text,
678                    stop_reason,
679                })
680            }
681            Err(e) => {
682                // Apply pending model switch even on error so the next run
683                // uses the new model.
684                self.apply_pending_model_switch();
685                Err(e)
686            }
687        }
688    }
689
690    // ── Helper methods for the agentic loop ────────────────────────
691
692    /// Set hooks for the agent loop.
693    pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
694        let mut h = self.hooks.write();
695        *h = hooks;
696    }
697
698    /// Request cancellation of the current agent run.
699    ///
700    /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
701    /// `external_stop` on every event AND polled every ~500ms by the
702    /// streaming loop's periodic check. This ensures cancellation is
703    /// detected quickly even when the provider stream is completely hung
704    /// (no events arriving).
705    pub fn cancel(&self) {
706        self.cancel_flag.store(true, Ordering::SeqCst);
707    }
708
709    /// Reset the cancellation flag before starting a new run.
710    pub fn reset_cancel(&self) {
711        self.cancel_flag.store(false, Ordering::SeqCst);
712    }
713
714    /// Apply any pending model switch that was deferred during a running loop.
715    ///
716    /// Called after `run_with_channel_inner` completes (success or error).
717    /// Transforms messages for cross-provider switches and swaps the provider
718    /// so the next run uses the new model.
719    fn apply_pending_model_switch(&self) {
720        let pending = self.pending_model_switch.write().take();
721        if let Some(pending) = pending {
722            tracing::info!(
723                "[AGENT] Applying deferred model switch to '{}' (transform={})",
724                pending.model_id,
725                pending.needs_transform
726            );
727
728            // Transform messages if cross-provider
729            if pending.needs_transform {
730                let messages = self.state.get_state().messages.clone();
731                let transformed =
732                    transform_for_provider(&messages, &pending.old_api, &pending.new_api);
733                self.state.update(|s| {
734                    s.replace_messages(transformed);
735                });
736            }
737
738            // Swap the provider
739            let mut inner = self.inner_mut();
740            inner.provider = pending.provider;
741            // model_id was already updated in switch_model()
742        }
743    }
744
745    /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
746    ///
747    /// Blocking convenience wrapper suitable for callers that prefer a
748    /// callback-based API over a channel.
749    pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
750    where
751        F: FnMut(AgentEvent) + Send,
752    {
753        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
754        let result = self.run_with_channel(prompt, tx).await;
755        while let Ok(event) = rx.recv() {
756            on_event(event);
757        }
758        result
759    }
760
761    // ── Session persistence ────────────────────────────────────────
762
763    /// Export the agent state as a JSON value.
764    ///
765    /// The serialized state includes conversation messages, token counts,
766    /// iteration progress, and stop reason. Use [`import_state`] to restore.
767    ///
768    /// [`import_state`]: Agent::import_state
769    pub fn export_state(&self) -> Result<serde_json::Value> {
770        let state = self.state.get_state();
771        serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
772    }
773
774    /// Import agent state from a JSON value.
775    ///
776    /// Restores conversation history, token counts, and iteration progress.
777    /// Typically used together with [`export_state`] for session persistence.
778    ///
779    /// [`export_state`]: Agent::export_state
780    pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
781        let state: AgentState = serde_json::from_value(value)
782            .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
783        self.state.update(|s| *s = state);
784        Ok(())
785    }
786
787    // ── Session continuation ───────────────────────────────────────
788
789    /// Continue the current session with a new prompt.
790    ///
791    /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
792    /// preserves the existing conversation state and appends the new prompt.
793    /// This enables multi-turn interactions within the same session.
794    pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
795        let mut events = Vec::new();
796        let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
797        let result = self.run_with_channel(prompt, tx).await;
798        while let Ok(event) = rx.recv() {
799            events.push(event);
800        }
801        result.map(|r| (r, events))
802    }
803
804    // ── Tokio-native streaming ─────────────────────────────────────
805
806    /// Run the agent with tokio-native event streaming.
807    ///
808    /// Returns a `tokio::sync::mpsc::Receiver` for events and a
809    /// `JoinHandle` for the response. This is the preferred API for
810    /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
811    ///
812    /// # Example
813    ///
814    /// ```ignore
815    /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
816    /// while let Some(event) = rx.recv().await {
817    ///     println!("Event: {:?}", event.type_name());
818    /// }
819    /// let response = handle.await??;
820    /// ```
821    pub async fn run_tokio_stream(
822        &self,
823        prompt: String,
824    ) -> Result<(
825        tokio::sync::mpsc::Receiver<AgentEvent>,
826        tokio::task::JoinHandle<Result<Response>>,
827    )> {
828        let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
829
830        if self
831            .is_running
832            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
833            .is_err()
834        {
835            return Err(Error::msg("Agent is already running"));
836        }
837
838        let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
839
840        let inner = self.inner.read().clone();
841        let tools = Arc::clone(&self.tools);
842        let resolver = Arc::clone(&self.resolver);
843
844        // Build AgentLoopConfig
845        let loop_config = crate::agent_loop::config::AgentLoopConfig {
846            model_id: inner.config.model_id.clone(),
847            system_prompt: inner.config.system_prompt.clone(),
848            temperature: inner.config.temperature.unwrap_or(1.0) as f32,
849            max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
850            tool_execution: crate::config::ToolExecutionMode::Sequential,
851            compaction_strategy: inner.config.compaction_strategy.clone(),
852            compaction_instruction: None,
853            context_window: inner.config.context_window,
854            session_id: inner.config.session_id.clone(),
855            transport: None,
856            compact_on_start: false,
857            max_retry_delay_ms: None,
858            auto_retry_enabled: true,
859            auto_retry_max_attempts: 3,
860            auto_retry_base_delay_ms: 1000,
861            api_key: inner.config.api_key.clone(),
862            workspace_dir: inner.config.workspace_dir.clone(),
863            provider_options: inner.config.provider_options.clone(),
864            on_compaction: None,
865            ttsr_engine: inner.config.ttsr_engine.clone(),
866            ..Default::default()
867        };
868
869        let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
870
871        // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
872        // agent loop so that state mutations inside the spawned task are
873        // visible through self.state() without an explicit sync step.
874        //
875        // Unlike run_with_channel_inner which creates a fresh SharedState
876        // and syncs back on completion, the tokio streaming API cannot
877        // access `self` inside the `'static` spawned task, so we share
878        // the underlying Arc instead.
879        //
880        // Pre-load current state into the shared Arc (in case it was
881        // modified by a previous run that used a different SharedState).
882        let shared_state = self.state.clone();
883
884        let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
885            provider,
886            loop_config,
887            tools,
888            shared_state.clone(),
889            resolver,
890        );
891
892        let maybe_hook = should_stop_hook;
893        let ext_stop = agent_loop.external_stop().clone();
894
895        // Clone the is_running Arc so the spawned task can clear it.
896        let is_running_flag = Arc::clone(&self.is_running);
897
898        let handle = tokio::task::spawn(async move {
899            let result = agent_loop
900                .run(prompt, move |event: AgentEvent| {
901                    // Forward to tokio channel (non-blocking)
902                    let _ = tx.try_send(event.clone());
903
904                    // Propagate should_stop → external_stop on every event,
905                    // not just TurnEnd. See run_with_channel_inner for rationale.
906                    if let Some(ref hook) = maybe_hook {
907                        let ctx = ShouldStopAfterTurnContext {
908                            message: match &event {
909                                AgentEvent::TurnEnd {
910                                    assistant_message: oxi_ai::Message::Assistant(a),
911                                    ..
912                                } => a.clone(),
913                                _ => oxi_ai::AssistantMessage::new(
914                                    oxi_ai::Api::OpenAiCompletions,
915                                    "agent",
916                                    "agent-model",
917                                ),
918                            },
919                            tool_results: match &event {
920                                AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
921                                _ => Vec::new(),
922                            },
923                            iteration: 0,
924                        };
925                        if hook(&ctx) {
926                            ext_stop.store(true, Ordering::SeqCst);
927                        }
928                    }
929                })
930                .await;
931
932            // Clear the Agent's running flag
933            is_running_flag.store(false, Ordering::SeqCst);
934
935            match result {
936                Ok(_events) => {
937                    // State is already shared via the same SharedState Arc,
938                    // so self.state() will reflect all mutations.
939                    Ok(Response {
940                        content: String::new(),
941                        stop_reason: StopReason::Stop,
942                    })
943                }
944                Err(e) => Err(e),
945            }
946        });
947
948        Ok((rx, handle))
949    }
950}