oxi_agent/agent.rs
//! Core agent implementation.
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10 CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
18/// Trait for resolving providers and models within an Agent.
19///
20/// This abstracts away global static registries, allowing SDK users
21/// to provide isolated provider/model lookups.
22///
23/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
24/// When using `Agent::new()` directly, a global fallback is used.
25pub trait ProviderResolver: Send + Sync + 'static {
26 /// Resolve a provider by name, returning an Arc handle.
27 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;
28
29 /// Resolve a model ID ("provider/model" or bare "model") to a Model.
30 fn resolve_model(&self, model_id: &str) -> Option<Model>;
31}
32
33/// Global provider resolver — uses `oxi_ai` global functions.
34///
35/// This is the default resolver when using `Agent::new()`, preserving
36/// backward compatibility with existing CLI usage.
37pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41 oxi_ai::get_provider(name).map(Arc::from)
42 }
43
44 fn resolve_model(&self, model_id: &str) -> Option<Model> {
45 crate::model_id::resolve_model_from_id(model_id)
46 }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
50
51/// Mutable agent internals protected by a read-write lock.
52struct AgentInner {
53 config: AgentConfig,
54 provider: Arc<dyn Provider>,
55}
56
57impl Clone for AgentInner {
58 fn clone(&self) -> Self {
59 Self {
60 config: self.config.clone(),
61 provider: Arc::clone(&self.provider),
62 }
63 }
64}
65
66/// Agent runtime.
67///
68/// Manages provider, tool registry, state, and compaction, providing an
69/// agentic loop for prompt execution, model switching, tool calls, and fallback.
70///
71/// Supports session continuation via [`continue_with`] and tokio-native
72/// event streaming via [`run_tokio_stream`].
73///
74/// [`continue_with`]: Agent::continue_with
75/// [`run_tokio_stream`]: Agent::run_tokio_stream
76/// Deferred model switch request, stored when the agent is running.
77struct PendingModelSwitch {
78 model_id: String,
79 provider: Arc<dyn Provider>,
80 /// Whether messages need cross-provider transformation.
81 needs_transform: bool,
82 old_api: oxi_ai::Api,
83 new_api: oxi_ai::Api,
84}
85
86/// Agent runtime.
87///
88/// Manages provider, tool registry, state, and compaction, providing an
89/// agentic loop for prompt execution, model switching, tool calls, and fallback.
90///
91/// Supports session continuation, tokio-native event streaming, and deferred
92/// model switching (changes are queued while a loop is running and applied
93/// after it completes).
94#[allow(missing_docs)]
95pub struct Agent {
96 inner: RwLock<AgentInner>,
97 tools: Arc<ToolRegistry>,
98 state: SharedState,
99 compaction_manager: CompactionManager,
100 hooks: parking_lot::RwLock<crate::config::AgentHooks>,
101 /// Guard: true while a run is in progress. Prevents concurrent runs.
102 is_running: Arc<AtomicBool>,
103 /// Provider/model resolver. Uses global functions by default,
104 /// or a custom resolver when created via `new_with_resolver()`.
105 resolver: Arc<dyn ProviderResolver>,
106 /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
107 /// propagated to AgentLoop's `external_stop` during each run.
108 cancel_flag: Arc<AtomicBool>,
109 /// Pending model switch — stored when the agent is running,
110 /// applied after the current loop completes.
111 pending_model_switch: RwLock<Option<PendingModelSwitch>>,
112}
113
114impl Agent {
115 /// Create a new agent with the given provider, config, and tool registry.
116 ///
117 /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
118 /// for model switching. For isolated instances, use [`new_with_resolver`].
119 ///
120 /// [`new_with_resolver`]: Agent::new_with_resolver
121 pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
122 let resolver = Arc::new(GlobalProviderResolver);
123 Self::build_inner(provider, config, tools, resolver)
124 }
125
126 /// Create an agent with a custom provider/model resolver.
127 ///
128 /// This is the preferred constructor for SDK usage where provider
129 /// and model registries must be isolated from global state.
130 pub fn new_with_resolver(
131 provider: Arc<dyn Provider>,
132 config: AgentConfig,
133 tools: Arc<ToolRegistry>,
134 resolver: Arc<dyn ProviderResolver>,
135 ) -> Self {
136 Self::build_inner(provider, config, tools, resolver)
137 }
138
139 /// Internal constructor shared by `new()` and `new_with_resolver()`.
140 fn build_inner(
141 provider: Arc<dyn Provider>,
142 config: AgentConfig,
143 tools: Arc<ToolRegistry>,
144 resolver: Arc<dyn ProviderResolver>,
145 ) -> Self {
146 let mut compaction_manager =
147 CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
148
149 // Pre-initialize the LLM compactor if compaction is enabled
150 if config.compaction_strategy != CompactionStrategy::Disabled {
151 let model = resolver.resolve_model(&config.model_id);
152
153 if let Some(model) = model {
154 let llm_compactor =
155 Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
156 compaction_manager.set_compactor(llm_compactor);
157 }
158 }
159
160 Self {
161 inner: RwLock::new(AgentInner { config, provider }),
162 tools,
163 state: SharedState::new(),
164 compaction_manager,
165 hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
166 is_running: Arc::new(AtomicBool::new(false)),
167 resolver,
168 cancel_flag: Arc::new(AtomicBool::new(false)),
169 pending_model_switch: RwLock::new(None),
170 }
171 }
172
173 /// Create an agent with an empty tool registry.
174 pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
175 Self::new(provider, config, Arc::new(ToolRegistry::new()))
176 }
177
178 /// Get the agent configuration (read guard)
179 fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
180 self.inner.read()
181 }
182
183 /// Get a write guard for the agent inner state
184 fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
185 self.inner.write()
186 }
187
188 /// Get the current model ID
189 pub fn model_id(&self) -> String {
190 self.config().config.model_id.clone()
191 }
192
193 /// Get the agent configuration (full clone)
194 pub fn get_config(&self) -> AgentConfig {
195 self.config().config.clone()
196 }
197
198 /// Get a reference to the provider resolver.
199 pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
200 &self.resolver
201 }
202
203 /// Switch the model used for future LLM calls.
204 ///
205 /// Switch model mid-conversation.
206 ///
207 /// If the agent is currently running, the switch is deferred: the new
208 /// model and provider are stored in `pending_model_switch` and applied
209 /// automatically when the current loop finishes. This ensures the
210 /// running loop completes with a consistent provider/model without
211 /// interruption.
212 ///
213 /// If the agent is idle, the switch takes effect immediately.
214 ///
215 /// If the new model uses a different provider API, the conversation
216 /// history is automatically transformed for cross-provider compatibility
217 /// (e.g. thinking blocks are converted to `<thinking>` tags).
218 ///
219 /// # Arguments
220 /// * `model_id` - New model ID in `provider/model` format
221 /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
222 ///
223 /// # Returns
224 /// `Ok(())` on success, or an error if the model/provider is unknown
225 pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
226 let new_model = self
227 .resolver
228 .resolve_model(model_id)
229 .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
230
231 // Create the new provider via resolver
232 let new_provider = self
233 .resolver
234 .resolve_provider(&new_model.provider)
235 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
236
237 // Detect API change
238 let (old_api, needs_transform) = {
239 let inner = self.config();
240 let old_api = self
241 .resolver
242 .resolve_model(&inner.config.model_id)
243 .map(|m| m.api)
244 .unwrap_or(oxi_ai::Api::AnthropicMessages);
245 (old_api, old_api != new_model.api)
246 };
247
248 // If the agent is currently running, defer the switch.
249 if self.is_running.load(Ordering::SeqCst) {
250 tracing::info!(
251 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
252 model_id
253 );
254 *self.pending_model_switch.write() = Some(PendingModelSwitch {
255 model_id: model_id.to_string(),
256 provider: new_provider,
257 needs_transform,
258 old_api,
259 new_api: new_model.api,
260 });
261 // Update config immediately so model_id() returns the new value,
262 // but leave provider unchanged so the running loop keeps its provider.
263 {
264 let mut inner = self.inner_mut();
265 inner.config.model_id = model_id.to_string();
266 inner.config.api_key = api_key;
267 }
268 return Ok(());
269 }
270
271 // Agent is idle — apply immediately.
272 if needs_transform {
273 let messages = self.state.get_state().messages.clone();
274 let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
275 self.state.update(|s| {
276 s.replace_messages(transformed);
277 });
278 }
279
280 // Update config and provider atomically
281 let mut inner = self.inner_mut();
282 inner.config.model_id = model_id.to_string();
283 inner.config.api_key = api_key;
284 inner.provider = new_provider;
285
286 Ok(())
287 }
288
289 /// Switch the model using a pre-resolved `Model` object.
290 ///
291 /// This is useful when the caller has already looked up the model
292 /// and optionally created the provider.
293 ///
294 /// Like [`switch_model`], if the agent is currently running, the switch
295 /// is deferred until the current loop completes.
296 ///
297 /// [`switch_model`]: Agent::switch_model
298 pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
299 let model_id = format!("{}/{}", model.provider, model.id);
300 let new_provider = self
301 .resolver
302 .resolve_provider(&model.provider)
303 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
304
305 // Detect API change
306 let (old_api, needs_transform) = {
307 let inner = self.config();
308 let old_api = self
309 .resolver
310 .resolve_model(&inner.config.model_id)
311 .map(|m| m.api)
312 .unwrap_or(oxi_ai::Api::AnthropicMessages);
313 (old_api, old_api != model.api)
314 };
315
316 // If the agent is currently running, defer the switch.
317 if self.is_running.load(Ordering::SeqCst) {
318 tracing::info!(
319 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
320 model_id
321 );
322 *self.pending_model_switch.write() = Some(PendingModelSwitch {
323 model_id: model_id.clone(),
324 provider: new_provider,
325 needs_transform,
326 old_api,
327 new_api: model.api,
328 });
329 let mut inner = self.inner_mut();
330 inner.config.model_id = model_id;
331 inner.config.api_key = api_key;
332 return Ok(());
333 }
334
335 // Agent is idle — apply immediately.
336 if needs_transform {
337 let messages = self.state.get_state().messages.clone();
338 let transformed = transform_for_provider(&messages, &old_api, &model.api);
339 self.state.update(|s| {
340 s.replace_messages(transformed);
341 });
342 }
343
344 let mut inner = self.inner_mut();
345 inner.config.model_id = model_id;
346 inner.config.api_key = api_key;
347 inner.provider = new_provider;
348
349 Ok(())
350 }
351
352 /// Refresh only the API key without changing model or provider.
353 /// Useful when the user stores a key after the session was already created.
354 pub fn refresh_api_key(&self, api_key: Option<String>) {
355 let mut inner = self.inner_mut();
356 inner.config.api_key = api_key;
357 }
358
359 /// Get a handle to the tool registry.
360 pub fn tools(&self) -> Arc<ToolRegistry> {
361 Arc::clone(&self.tools)
362 }
363
364 /// Get a snapshot of the current agent state.
365 pub fn state(&self) -> AgentState {
366 self.state.get_state()
367 }
368
369 /// Update agent state in-place. Used by compaction to replace messages.
370 pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
371 self.state.update(f);
372 }
373
374 /// Reset agent state for a new conversation
375 pub fn reset(&self) {
376 self.state.reset();
377 }
378
379 /// Register a tool that the agent can invoke during a run.
380 pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
381 self.tools.register(tool);
382 }
383
384 /// Update the system prompt for future interactions.
385 pub fn set_system_prompt(&self, prompt: String) {
386 self.inner_mut().config.system_prompt = Some(prompt);
387 }
388
389 /// Get the compaction manager
390 pub fn compaction_manager(&self) -> &CompactionManager {
391 &self.compaction_manager
392 }
393 /// Update the compaction strategy for future runs.
394 ///
395 /// The strategy is read fresh from the config at the start of each run
396 /// (see `run_with_channel_inner`), so this takes effect on the next
397 /// agent turn — never mid-run. Pair with `compaction_manager()` for
398 /// manual compaction, which is unaffected by the strategy.
399 pub fn set_compaction_strategy(&self, strategy: oxi_ai::CompactionStrategy) {
400 self.inner.write().config.compaction_strategy = strategy;
401 }
402 /// Get the compaction strategy that will be used on the next run.
403 ///
404 /// This reads from `inner.config` (mutable via `set_compaction_strategy`),
405 /// **not** from the `compaction_manager` field (which retains its
406 /// construction-time strategy). The agent loop reads from config fresh
407 /// each run, so this is the authoritative value.
408 pub fn compaction_strategy(&self) -> oxi_ai::CompactionStrategy {
409 self.inner.read().config.compaction_strategy.clone()
410 }
411
412 /// Run the agent with a prompt, collecting all events into a vector.
413 ///
414 /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
415 /// [`AgentEvent`] produced during the run.
416 pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
417 let mut events = Vec::new();
418 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
419 let result = self.run_with_channel(prompt, tx).await;
420 while let Ok(event) = rx.recv() {
421 events.push(event);
422 }
423 result.map(|r| (r, events))
424 }
425
426 /// Run the agent, delivering events through the provided channel.
427 ///
428 /// Delegates to the agent loop which implements the same 2-level agentic
429 /// loop matching pi-mono's architecture:
430 ///
431 /// ```text
432 /// AgentLoop.run_messages()
433 /// Outer loop (follow-up messages):
434 /// Inner loop (tool calls + steering):
435 /// 1. Inject pending messages (steering)
436 /// 2. Compaction check
437 /// 3. Stream LLM response (with accumulated partial messages)
438 /// 4. Execute tool calls if any
439 /// 5. Emit turn_end
440 /// 6. Check shouldStopAfterTurn
441 /// 7. Poll steering messages
442 /// Check follow-up messages
443 /// Exit
444 /// ```
445 pub async fn run_with_channel(
446 &self,
447 prompt: String,
448 tx: std::sync::mpsc::Sender<AgentEvent>,
449 ) -> Result<Response> {
450 // pi-mono: Agent.prompt() throws if activeRun exists.
451 // Prevent concurrent runs that would corrupt shared state.
452 if self
453 .is_running
454 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
455 .is_err()
456 {
457 return Err(Error::msg("Agent is already running"));
458 }
459
460 // Drop guard ensures is_running is cleared even on panic.
461 struct RunningGuard<'a>(&'a AtomicBool);
462 impl Drop for RunningGuard<'_> {
463 fn drop(&mut self) {
464 self.0.store(false, Ordering::SeqCst);
465 }
466 }
467 let _guard = RunningGuard(&self.is_running);
468 self.reset_cancel();
469
470 self.run_with_channel_inner(prompt, tx).await
471 }
472
473 /// Inner implementation of run_with_channel, called after the running guard is set.
474 async fn run_with_channel_inner(
475 &self,
476 prompt: String,
477 tx: std::sync::mpsc::Sender<AgentEvent>,
478 ) -> Result<Response> {
479 use crate::agent_loop::AgentLoop;
480
481 let (
482 provider,
483 system_prompt,
484 temperature,
485 max_tokens,
486 compaction_strategy,
487 context_window,
488 api_key,
489 workspace_dir,
490 ) = {
491 let inner = self.inner.read();
492 (
493 Arc::clone(&inner.provider) as Arc<dyn Provider>,
494 inner.config.system_prompt.clone(),
495 inner.config.temperature,
496 inner.config.max_tokens,
497 inner.config.compaction_strategy.clone(),
498 inner.config.context_window,
499 inner.config.api_key.clone(),
500 inner.config.workspace_dir.clone(),
501 )
502 }; // release read lock
503
504 // Build AgentLoopConfig from Agent's config
505 let loop_config = crate::agent_loop::config::AgentLoopConfig {
506 model_id: self.model_id(),
507 system_prompt,
508 temperature: temperature.unwrap_or(1.0) as f32,
509 max_tokens: max_tokens.unwrap_or(4096) as u32,
510 tool_execution: crate::config::ToolExecutionMode::Sequential,
511 compaction_strategy,
512 compaction_instruction: None,
513 context_window,
514 session_id: self.config().config.session_id.clone(),
515 transport: None,
516 compact_on_start: false,
517 max_retry_delay_ms: None,
518 auto_retry_enabled: true,
519 auto_retry_max_attempts: 3,
520 auto_retry_base_delay_ms: 1000,
521 api_key,
522 workspace_dir,
523 provider_options: self.config().config.provider_options.clone(),
524 on_compaction: None,
525 ttsr_engine: self.config().config.ttsr_engine.clone(),
526 memory: self.config().config.memory.clone(),
527 todo: self.config().config.todo.clone(),
528 agent_pool: self.config().config.agent_pool.clone(),
529 ..Default::default()
530 };
531
532 // Create AgentLoop. We give it a NEW SharedState and sync back after.
533 // (SharedState is not Clone, so we create a fresh one from current state)
534 let fresh_state = crate::state::SharedState::new();
535 let current = self.state.get_state();
536 fresh_state.update(|s| {
537 *s = current;
538 });
539
540 let mut agent_loop = AgentLoop::new_with_resolver(
541 provider,
542 loop_config,
543 Arc::clone(&self.tools),
544 fresh_state,
545 Arc::clone(&self.resolver),
546 );
547
548 // Add the user prompt to Agent.state() AFTER fresh_state is created.
549 // fresh_state got a copy of the pre-prompt state, so run_loop will
550 // add the prompt to fresh_state independently via initial_prompts.
551 // But persist_session() reads Agent.state() (not fresh_state), so it
552 // needs the user prompt there to write it to the session file.
553 // Sync happens at AgentEnd (after run_loop completes), where
554 // Agent.state is overwritten with fresh_state (which has all messages).
555 self.state.update(|s| {
556 s.messages
557 .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
558 prompt.clone(),
559 )));
560 });
561
562 // Pre-populate steering/follow-up from hooks
563 {
564 let hooks = self.hooks.read();
565 if let Some(ref get_steering) = hooks.get_steering_messages {
566 for msg_text in get_steering() {
567 agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
568 }
569 }
570 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
571 for msg_text in get_follow_up() {
572 agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
573 }
574 }
575
576 // Store hooks on AgentLoop so they can be polled each turn
577 // to pick up new messages injected during the run.
578 if let Some(ref get_steering) = hooks.get_steering_messages {
579 agent_loop.set_steering_hook(Arc::clone(get_steering));
580 }
581 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
582 agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
583 }
584 }
585 let mut al = agent_loop;
586
587 // Wire should_stop_after_turn hook: share AgentLoop's external_stop
588 // Arc with the emit callback. When the hook fires (Ctrl+C detected),
589 // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
590 // AND during streaming (streaming.rs checks external_stop each event).
591 //
592 // Arc<dyn Fn> can be cloned, so we read it without consuming.
593 let maybe_hook = {
594 let hooks_r = self.hooks.read();
595 hooks_r.should_stop_after_turn.clone()
596 };
597 let ext_stop = al.external_stop().clone();
598 let cancel_flag = self.cancel_flag.clone();
599
600 // Share cancel_flag with AgentLoop so the streaming loop can check
601 // it directly in the periodic timer — no emit callback required.
602 // This closes the gap where cancel() was ineffective when the
603 // provider stream produced no events.
604 al.set_cancel_signal(self.cancel_flag.clone());
605
606 // Create emit callback that sends through the channel.
607 // AgentLoop calls this synchronously. UnboundedSender::send() is
608 // non-blocking and never drops events (unlike try_send on bounded).
609 let tx_emit = tx.clone();
610
611 // Run the agent loop
612 tracing::info!("[AGENT] Starting agent run with channel");
613 let result = al
614 .run(prompt.clone(), move |event: AgentEvent| {
615 // Forward event to channel (std::sync::mpsc — send from sync context)
616 tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
617 if let Err(e) = tx_emit.send(event.clone()) {
618 tracing::error!(
619 "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
620 e
621 );
622 } else {
623 tracing::info!("[AGENT-EMIT] Successfully sent event");
624 }
625
626 // Propagate cancellation from Agent::cancel() → external_stop.
627 // This runs on every event, ensuring the streaming loop detects
628 // cancellation promptly.
629 if cancel_flag.load(Ordering::SeqCst) {
630 ext_stop.store(true, Ordering::SeqCst);
631 }
632
633 // Propagate should_stop → external_stop on every event, not
634 // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
635 // so the context contents are irrelevant for non-TurnEnd events.
636 // This ensures streaming.rs detects cancellation immediately
637 // when the user presses Ctrl+C mid-stream.
638 if let Some(ref hook) = maybe_hook {
639 let ctx = ShouldStopAfterTurnContext {
640 message: match &event {
641 AgentEvent::TurnEnd {
642 assistant_message: oxi_ai::Message::Assistant(a),
643 ..
644 } => a.clone(),
645 _ => oxi_ai::AssistantMessage::new(
646 oxi_ai::Api::OpenAiCompletions,
647 "agent",
648 "agent-model",
649 ),
650 },
651 tool_results: match &event {
652 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
653 _ => Vec::new(),
654 },
655 iteration: 0,
656 };
657 if hook(&ctx) {
658 ext_stop.store(true, Ordering::SeqCst);
659 }
660 }
661 })
662 .await;
663
664 match result {
665 Ok(_events) => {
666 // Sync state back from AgentLoop
667 let loop_state = al.state().get_state();
668 self.state.update(|s| {
669 *s = loop_state;
670 });
671
672 // Apply any pending model switch that was deferred during the run.
673 // This transforms messages (if cross-provider) and swaps the provider
674 // so the next run uses the new model.
675 self.apply_pending_model_switch();
676
677 // Extract final response text from state
678 let state = self.state.get_state();
679 let final_text = state
680 .messages
681 .iter()
682 .rev()
683 .find_map(|m| match m {
684 oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
685 oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
686 _ => None,
687 }),
688 _ => None,
689 })
690 .unwrap_or_default();
691
692 let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
693
694 Ok(Response {
695 content: final_text,
696 stop_reason,
697 })
698 }
699 Err(e) => {
700 // Apply pending model switch even on error so the next run
701 // uses the new model.
702 self.apply_pending_model_switch();
703 Err(e)
704 }
705 }
706 }
707
708 // ── Helper methods for the agentic loop ────────────────────────
709
710 /// Set hooks for the agent loop.
711 pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
712 let mut h = self.hooks.write();
713 *h = hooks;
714 }
715
716 /// Request cancellation of the current agent run.
717 ///
718 /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
719 /// `external_stop` on every event AND polled every ~500ms by the
720 /// streaming loop's periodic check. This ensures cancellation is
721 /// detected quickly even when the provider stream is completely hung
722 /// (no events arriving).
723 pub fn cancel(&self) {
724 self.cancel_flag.store(true, Ordering::SeqCst);
725 }
726
727 /// Reset the cancellation flag before starting a new run.
728 pub fn reset_cancel(&self) {
729 self.cancel_flag.store(false, Ordering::SeqCst);
730 }
731
732 /// Apply any pending model switch that was deferred during a running loop.
733 ///
734 /// Called after `run_with_channel_inner` completes (success or error).
735 /// Transforms messages for cross-provider switches and swaps the provider
736 /// so the next run uses the new model.
737 fn apply_pending_model_switch(&self) {
738 let pending = self.pending_model_switch.write().take();
739 if let Some(pending) = pending {
740 tracing::info!(
741 "[AGENT] Applying deferred model switch to '{}' (transform={})",
742 pending.model_id,
743 pending.needs_transform
744 );
745
746 // Transform messages if cross-provider
747 if pending.needs_transform {
748 let messages = self.state.get_state().messages.clone();
749 let transformed =
750 transform_for_provider(&messages, &pending.old_api, &pending.new_api);
751 self.state.update(|s| {
752 s.replace_messages(transformed);
753 });
754 }
755
756 // Swap the provider
757 let mut inner = self.inner_mut();
758 inner.provider = pending.provider;
759 // model_id was already updated in switch_model()
760 }
761 }
762
763 /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
764 ///
765 /// Blocking convenience wrapper suitable for callers that prefer a
766 /// callback-based API over a channel.
767 pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
768 where
769 F: FnMut(AgentEvent) + Send,
770 {
771 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
772 let result = self.run_with_channel(prompt, tx).await;
773 while let Ok(event) = rx.recv() {
774 on_event(event);
775 }
776 result
777 }
778
779 // ── Session persistence ────────────────────────────────────────
780
781 /// Export the agent state as a JSON value.
782 ///
783 /// The serialized state includes conversation messages, token counts,
784 /// iteration progress, and stop reason. Use [`import_state`] to restore.
785 ///
786 /// [`import_state`]: Agent::import_state
787 pub fn export_state(&self) -> Result<serde_json::Value> {
788 let state = self.state.get_state();
789 serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
790 }
791
792 /// Import agent state from a JSON value.
793 ///
794 /// Restores conversation history, token counts, and iteration progress.
795 /// Typically used together with [`export_state`] for session persistence.
796 ///
797 /// [`export_state`]: Agent::export_state
798 pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
799 let state: AgentState = serde_json::from_value(value)
800 .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
801 self.state.update(|s| *s = state);
802 Ok(())
803 }
804
805 // ── Session continuation ───────────────────────────────────────
806
807 /// Continue the current session with a new prompt.
808 ///
809 /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
810 /// preserves the existing conversation state and appends the new prompt.
811 /// This enables multi-turn interactions within the same session.
812 pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
813 let mut events = Vec::new();
814 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
815 let result = self.run_with_channel(prompt, tx).await;
816 while let Ok(event) = rx.recv() {
817 events.push(event);
818 }
819 result.map(|r| (r, events))
820 }
821
822 // ── Tokio-native streaming ─────────────────────────────────────
823
824 /// Run the agent with tokio-native event streaming.
825 ///
826 /// Returns a `tokio::sync::mpsc::Receiver` for events and a
827 /// `JoinHandle` for the response. This is the preferred API for
828 /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
829 ///
830 /// # Example
831 ///
832 /// ```ignore
833 /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
834 /// while let Some(event) = rx.recv().await {
835 /// println!("Event: {:?}", event.type_name());
836 /// }
837 /// let response = handle.await??;
838 /// ```
839 pub async fn run_tokio_stream(
840 &self,
841 prompt: String,
842 ) -> Result<(
843 tokio::sync::mpsc::Receiver<AgentEvent>,
844 tokio::task::JoinHandle<Result<Response>>,
845 )> {
846 let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
847
848 if self
849 .is_running
850 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
851 .is_err()
852 {
853 return Err(Error::msg("Agent is already running"));
854 }
855
856 let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
857
858 let inner = self.inner.read().clone();
859 let tools = Arc::clone(&self.tools);
860 let resolver = Arc::clone(&self.resolver);
861
862 // Build AgentLoopConfig
863 let loop_config = crate::agent_loop::config::AgentLoopConfig {
864 model_id: inner.config.model_id.clone(),
865 system_prompt: inner.config.system_prompt.clone(),
866 temperature: inner.config.temperature.unwrap_or(1.0) as f32,
867 max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
868 tool_execution: crate::config::ToolExecutionMode::Sequential,
869 compaction_strategy: inner.config.compaction_strategy.clone(),
870 compaction_instruction: None,
871 context_window: inner.config.context_window,
872 session_id: inner.config.session_id.clone(),
873 transport: None,
874 compact_on_start: false,
875 max_retry_delay_ms: None,
876 auto_retry_enabled: true,
877 auto_retry_max_attempts: 3,
878 auto_retry_base_delay_ms: 1000,
879 api_key: inner.config.api_key.clone(),
880 workspace_dir: inner.config.workspace_dir.clone(),
881 provider_options: inner.config.provider_options.clone(),
882 on_compaction: None,
883 ttsr_engine: inner.config.ttsr_engine.clone(),
884 ..Default::default()
885 };
886
887 let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
888
889 // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
890 // agent loop so that state mutations inside the spawned task are
891 // visible through self.state() without an explicit sync step.
892 //
893 // Unlike run_with_channel_inner which creates a fresh SharedState
894 // and syncs back on completion, the tokio streaming API cannot
895 // access `self` inside the `'static` spawned task, so we share
896 // the underlying Arc instead.
897 //
898 // Pre-load current state into the shared Arc (in case it was
899 // modified by a previous run that used a different SharedState).
900 let shared_state = self.state.clone();
901
902 let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
903 provider,
904 loop_config,
905 tools,
906 shared_state.clone(),
907 resolver,
908 );
909
910 let maybe_hook = should_stop_hook;
911 let ext_stop = agent_loop.external_stop().clone();
912
913 // Clone the is_running Arc so the spawned task can clear it.
914 let is_running_flag = Arc::clone(&self.is_running);
915
916 let handle = tokio::task::spawn(async move {
917 let result = agent_loop
918 .run(prompt, move |event: AgentEvent| {
919 // Forward to tokio channel (non-blocking)
920 let _ = tx.try_send(event.clone());
921
922 // Propagate should_stop → external_stop on every event,
923 // not just TurnEnd. See run_with_channel_inner for rationale.
924 if let Some(ref hook) = maybe_hook {
925 let ctx = ShouldStopAfterTurnContext {
926 message: match &event {
927 AgentEvent::TurnEnd {
928 assistant_message: oxi_ai::Message::Assistant(a),
929 ..
930 } => a.clone(),
931 _ => oxi_ai::AssistantMessage::new(
932 oxi_ai::Api::OpenAiCompletions,
933 "agent",
934 "agent-model",
935 ),
936 },
937 tool_results: match &event {
938 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
939 _ => Vec::new(),
940 },
941 iteration: 0,
942 };
943 if hook(&ctx) {
944 ext_stop.store(true, Ordering::SeqCst);
945 }
946 }
947 })
948 .await;
949
950 // Clear the Agent's running flag
951 is_running_flag.store(false, Ordering::SeqCst);
952
953 match result {
954 Ok(_events) => {
955 // State is already shared via the same SharedState Arc,
956 // so self.state() will reflect all mutations.
957 Ok(Response {
958 content: String::new(),
959 stop_reason: StopReason::Stop,
960 })
961 }
962 Err(e) => Err(e),
963 }
964 });
965
966 Ok((rx, handle))
967 }
968}