oxi_agent/agent.rs
1/// Core agent implementation
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10 CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
/// Trait for resolving providers and models within an Agent.
///
/// This abstracts away global static registries, allowing SDK users
/// to provide isolated provider/model lookups.
///
/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
/// When using `Agent::new()` directly, a global fallback is used.
///
/// Both lookups return `Option`; callers in this file turn `None` into a
/// "not found" error.
pub trait ProviderResolver: Send + Sync + 'static {
    /// Resolve a provider by name, returning an Arc handle.
    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;

    /// Resolve a model ID ("provider/model" or bare "model") to a Model.
    fn resolve_model(&self, model_id: &str) -> Option<Model>;
}
32
/// Global provider resolver.
///
/// Providers are looked up with `oxi_ai::get_provider` and models with
/// `crate::model_id::resolve_model_from_id`.
///
/// This is the default resolver when using `Agent::new()`, preserving
/// backward compatibility with existing CLI usage.
pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41 oxi_ai::get_provider(name).map(Arc::from)
42 }
43
44 fn resolve_model(&self, model_id: &str) -> Option<Model> {
45 crate::model_id::resolve_model_from_id(model_id)
46 }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
50
51/// Mutable agent internals protected by a read-write lock.
52struct AgentInner {
53 config: AgentConfig,
54 provider: Arc<dyn Provider>,
55}
56
57impl Clone for AgentInner {
58 fn clone(&self) -> Self {
59 Self {
60 config: self.config.clone(),
61 provider: Arc::clone(&self.provider),
62 }
63 }
64}
65
/// Deferred model switch request, stored when the agent is running.
///
/// Created by `Agent::switch_model` / `Agent::switch_to_model` while a run
/// is in progress and consumed by `Agent::apply_pending_model_switch`.
struct PendingModelSwitch {
    /// Model ID that was requested (used for logging when the switch is applied).
    model_id: String,
    /// Provider to install once the switch is applied.
    provider: Arc<dyn Provider>,
    /// Whether messages need cross-provider transformation.
    needs_transform: bool,
    /// API the conversation history is transformed from.
    old_api: oxi_ai::Api,
    /// API the conversation history is transformed to.
    new_api: oxi_ai::Api,
}
85
/// Agent runtime.
///
/// Manages provider, tool registry, state, and compaction, providing an
/// agentic loop for prompt execution, model switching, tool calls, and fallback.
///
/// Supports session continuation ([`Agent::continue_with`]), tokio-native
/// event streaming ([`Agent::run_tokio_stream`]), and deferred model
/// switching (changes are queued while a loop is running and applied
/// after it completes).
#[allow(missing_docs)]
pub struct Agent {
    /// Configuration and active provider, swapped together on a model switch.
    inner: RwLock<AgentInner>,
    /// Tool registry handed to every agent loop.
    tools: Arc<ToolRegistry>,
    /// Conversation state shared across runs.
    state: SharedState,
    /// Compaction manager built from the initial configuration.
    compaction_manager: CompactionManager,
    /// User-supplied hooks (steering, follow-up, should-stop).
    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
    /// Guard: true while a run is in progress. Prevents concurrent runs.
    is_running: Arc<AtomicBool>,
    /// Provider/model resolver. Uses global functions by default,
    /// or a custom resolver when created via `new_with_resolver()`.
    resolver: Arc<dyn ProviderResolver>,
    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
    /// propagated to AgentLoop's `external_stop` during each run.
    cancel_flag: Arc<AtomicBool>,
    /// Pending model switch — stored when the agent is running,
    /// applied after the current loop completes.
    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
}
113
114impl Agent {
115 /// Create a new agent with the given provider, config, and tool registry.
116 ///
117 /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
118 /// for model switching. For isolated instances, use [`new_with_resolver`].
119 ///
120 /// [`new_with_resolver`]: Agent::new_with_resolver
121 pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
122 let resolver = Arc::new(GlobalProviderResolver);
123 Self::build_inner(provider, config, tools, resolver)
124 }
125
126 /// Create an agent with a custom provider/model resolver.
127 ///
128 /// This is the preferred constructor for SDK usage where provider
129 /// and model registries must be isolated from global state.
130 pub fn new_with_resolver(
131 provider: Arc<dyn Provider>,
132 config: AgentConfig,
133 tools: Arc<ToolRegistry>,
134 resolver: Arc<dyn ProviderResolver>,
135 ) -> Self {
136 Self::build_inner(provider, config, tools, resolver)
137 }
138
139 /// Internal constructor shared by `new()` and `new_with_resolver()`.
140 fn build_inner(
141 provider: Arc<dyn Provider>,
142 config: AgentConfig,
143 tools: Arc<ToolRegistry>,
144 resolver: Arc<dyn ProviderResolver>,
145 ) -> Self {
146 let mut compaction_manager =
147 CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
148
149 // Pre-initialize the LLM compactor if compaction is enabled
150 if config.compaction_strategy != CompactionStrategy::Disabled {
151 let model = resolver.resolve_model(&config.model_id);
152
153 if let Some(model) = model {
154 let llm_compactor =
155 Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
156 compaction_manager.set_compactor(llm_compactor);
157 }
158 }
159
160 Self {
161 inner: RwLock::new(AgentInner { config, provider }),
162 tools,
163 state: SharedState::new(),
164 compaction_manager,
165 hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
166 is_running: Arc::new(AtomicBool::new(false)),
167 resolver,
168 cancel_flag: Arc::new(AtomicBool::new(false)),
169 pending_model_switch: RwLock::new(None),
170 }
171 }
172
173 /// Create an agent with an empty tool registry.
174 pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
175 Self::new(provider, config, Arc::new(ToolRegistry::new()))
176 }
177
178 /// Get the agent configuration (read guard)
179 fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
180 self.inner.read()
181 }
182
183 /// Get a write guard for the agent inner state
184 fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
185 self.inner.write()
186 }
187
188 /// Get the current model ID
189 pub fn model_id(&self) -> String {
190 self.config().config.model_id.clone()
191 }
192
193 /// Get the agent configuration (full clone)
194 pub fn get_config(&self) -> AgentConfig {
195 self.config().config.clone()
196 }
197
198 /// Get a reference to the provider resolver.
199 pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
200 &self.resolver
201 }
202
203 /// Switch the model used for future LLM calls.
204 ///
205 /// Switch model mid-conversation.
206 ///
207 /// If the agent is currently running, the switch is deferred: the new
208 /// model and provider are stored in `pending_model_switch` and applied
209 /// automatically when the current loop finishes. This ensures the
210 /// running loop completes with a consistent provider/model without
211 /// interruption.
212 ///
213 /// If the agent is idle, the switch takes effect immediately.
214 ///
215 /// If the new model uses a different provider API, the conversation
216 /// history is automatically transformed for cross-provider compatibility
217 /// (e.g. thinking blocks are converted to `<thinking>` tags).
218 ///
219 /// # Arguments
220 /// * `model_id` - New model ID in `provider/model` format
221 /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
222 ///
223 /// # Returns
224 /// `Ok(())` on success, or an error if the model/provider is unknown
225 pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
226 let new_model = self
227 .resolver
228 .resolve_model(model_id)
229 .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
230
231 // Create the new provider via resolver
232 let new_provider = self
233 .resolver
234 .resolve_provider(&new_model.provider)
235 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
236
237 // Detect API change
238 let (old_api, needs_transform) = {
239 let inner = self.config();
240 let old_api = self
241 .resolver
242 .resolve_model(&inner.config.model_id)
243 .map(|m| m.api)
244 .unwrap_or(oxi_ai::Api::AnthropicMessages);
245 (old_api, old_api != new_model.api)
246 };
247
248 // If the agent is currently running, defer the switch.
249 if self.is_running.load(Ordering::SeqCst) {
250 tracing::info!(
251 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
252 model_id
253 );
254 *self.pending_model_switch.write() = Some(PendingModelSwitch {
255 model_id: model_id.to_string(),
256 provider: new_provider,
257 needs_transform,
258 old_api,
259 new_api: new_model.api,
260 });
261 // Update config immediately so model_id() returns the new value,
262 // but leave provider unchanged so the running loop keeps its provider.
263 {
264 let mut inner = self.inner_mut();
265 inner.config.model_id = model_id.to_string();
266 inner.config.api_key = api_key;
267 }
268 return Ok(());
269 }
270
271 // Agent is idle — apply immediately.
272 if needs_transform {
273 let messages = self.state.get_state().messages.clone();
274 let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
275 self.state.update(|s| {
276 s.replace_messages(transformed);
277 });
278 }
279
280 // Update config and provider atomically
281 let mut inner = self.inner_mut();
282 inner.config.model_id = model_id.to_string();
283 inner.config.api_key = api_key;
284 inner.provider = new_provider;
285
286 Ok(())
287 }
288
289 /// Switch the model using a pre-resolved `Model` object.
290 ///
291 /// This is useful when the caller has already looked up the model
292 /// and optionally created the provider.
293 ///
294 /// Like [`switch_model`], if the agent is currently running, the switch
295 /// is deferred until the current loop completes.
296 ///
297 /// [`switch_model`]: Agent::switch_model
298 pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
299 let model_id = format!("{}/{}", model.provider, model.id);
300 let new_provider = self
301 .resolver
302 .resolve_provider(&model.provider)
303 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
304
305 // Detect API change
306 let (old_api, needs_transform) = {
307 let inner = self.config();
308 let old_api = self
309 .resolver
310 .resolve_model(&inner.config.model_id)
311 .map(|m| m.api)
312 .unwrap_or(oxi_ai::Api::AnthropicMessages);
313 (old_api, old_api != model.api)
314 };
315
316 // If the agent is currently running, defer the switch.
317 if self.is_running.load(Ordering::SeqCst) {
318 tracing::info!(
319 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
320 model_id
321 );
322 *self.pending_model_switch.write() = Some(PendingModelSwitch {
323 model_id: model_id.clone(),
324 provider: new_provider,
325 needs_transform,
326 old_api,
327 new_api: model.api,
328 });
329 let mut inner = self.inner_mut();
330 inner.config.model_id = model_id;
331 inner.config.api_key = api_key;
332 return Ok(());
333 }
334
335 // Agent is idle — apply immediately.
336 if needs_transform {
337 let messages = self.state.get_state().messages.clone();
338 let transformed = transform_for_provider(&messages, &old_api, &model.api);
339 self.state.update(|s| {
340 s.replace_messages(transformed);
341 });
342 }
343
344 let mut inner = self.inner_mut();
345 inner.config.model_id = model_id;
346 inner.config.api_key = api_key;
347 inner.provider = new_provider;
348
349 Ok(())
350 }
351
352 /// Refresh only the API key without changing model or provider.
353 /// Useful when the user stores a key after the session was already created.
354 pub fn refresh_api_key(&self, api_key: Option<String>) {
355 let mut inner = self.inner_mut();
356 inner.config.api_key = api_key;
357 }
358
359 /// Get a handle to the tool registry.
360 pub fn tools(&self) -> Arc<ToolRegistry> {
361 Arc::clone(&self.tools)
362 }
363
364 /// Get a snapshot of the current agent state.
365 pub fn state(&self) -> AgentState {
366 self.state.get_state()
367 }
368
369 /// Update agent state in-place. Used by compaction to replace messages.
370 pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
371 self.state.update(f);
372 }
373
374 /// Reset agent state for a new conversation
375 pub fn reset(&self) {
376 self.state.reset();
377 }
378
379 /// Register a tool that the agent can invoke during a run.
380 pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
381 self.tools.register(tool);
382 }
383
384 /// Update the system prompt for future interactions.
385 pub fn set_system_prompt(&self, prompt: String) {
386 self.inner_mut().config.system_prompt = Some(prompt);
387 }
388
389 /// Get the compaction manager
390 pub fn compaction_manager(&self) -> &CompactionManager {
391 &self.compaction_manager
392 }
393
394 /// Run the agent with a prompt, collecting all events into a vector.
395 ///
396 /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
397 /// [`AgentEvent`] produced during the run.
398 pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
399 let mut events = Vec::new();
400 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
401 let result = self.run_with_channel(prompt, tx).await;
402 while let Ok(event) = rx.recv() {
403 events.push(event);
404 }
405 result.map(|r| (r, events))
406 }
407
408 /// Run the agent, delivering events through the provided channel.
409 ///
410 /// Delegates to the agent loop which implements the same 2-level agentic
411 /// loop matching pi-mono's architecture:
412 ///
413 /// ```text
414 /// AgentLoop.run_messages()
415 /// Outer loop (follow-up messages):
416 /// Inner loop (tool calls + steering):
417 /// 1. Inject pending messages (steering)
418 /// 2. Compaction check
419 /// 3. Stream LLM response (with accumulated partial messages)
420 /// 4. Execute tool calls if any
421 /// 5. Emit turn_end
422 /// 6. Check shouldStopAfterTurn
423 /// 7. Poll steering messages
424 /// Check follow-up messages
425 /// Exit
426 /// ```
427 pub async fn run_with_channel(
428 &self,
429 prompt: String,
430 tx: std::sync::mpsc::Sender<AgentEvent>,
431 ) -> Result<Response> {
432 // pi-mono: Agent.prompt() throws if activeRun exists.
433 // Prevent concurrent runs that would corrupt shared state.
434 if self
435 .is_running
436 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
437 .is_err()
438 {
439 return Err(Error::msg("Agent is already running"));
440 }
441
442 // Drop guard ensures is_running is cleared even on panic.
443 struct RunningGuard<'a>(&'a AtomicBool);
444 impl Drop for RunningGuard<'_> {
445 fn drop(&mut self) {
446 self.0.store(false, Ordering::SeqCst);
447 }
448 }
449 let _guard = RunningGuard(&self.is_running);
450 self.reset_cancel();
451
452 self.run_with_channel_inner(prompt, tx).await
453 }
454
455 /// Inner implementation of run_with_channel, called after the running guard is set.
456 async fn run_with_channel_inner(
457 &self,
458 prompt: String,
459 tx: std::sync::mpsc::Sender<AgentEvent>,
460 ) -> Result<Response> {
461 use crate::agent_loop::AgentLoop;
462
463 let (
464 provider,
465 system_prompt,
466 temperature,
467 max_tokens,
468 compaction_strategy,
469 context_window,
470 api_key,
471 workspace_dir,
472 ) = {
473 let inner = self.inner.read();
474 (
475 Arc::clone(&inner.provider) as Arc<dyn Provider>,
476 inner.config.system_prompt.clone(),
477 inner.config.temperature,
478 inner.config.max_tokens,
479 inner.config.compaction_strategy.clone(),
480 inner.config.context_window,
481 inner.config.api_key.clone(),
482 inner.config.workspace_dir.clone(),
483 )
484 }; // release read lock
485
486 // Build AgentLoopConfig from Agent's config
487 let loop_config = crate::agent_loop::config::AgentLoopConfig {
488 model_id: self.model_id(),
489 system_prompt,
490 temperature: temperature.unwrap_or(1.0) as f32,
491 max_tokens: max_tokens.unwrap_or(4096) as u32,
492 tool_execution: crate::config::ToolExecutionMode::Sequential,
493 compaction_strategy,
494 compaction_instruction: None,
495 context_window,
496 session_id: self.config().config.session_id.clone(),
497 transport: None,
498 compact_on_start: false,
499 max_retry_delay_ms: None,
500 auto_retry_enabled: true,
501 auto_retry_max_attempts: 3,
502 auto_retry_base_delay_ms: 1000,
503 api_key,
504 workspace_dir,
505 provider_options: self.config().config.provider_options.clone(),
506 on_compaction: None,
507 ttsr_engine: self.config().config.ttsr_engine.clone(),
508 memory: self.config().config.memory.clone(),
509 todo: self.config().config.todo.clone(),
510 agent_pool: self.config().config.agent_pool.clone(),
511 ..Default::default()
512 };
513
514 // Create AgentLoop. We give it a NEW SharedState and sync back after.
515 // (SharedState is not Clone, so we create a fresh one from current state)
516 let fresh_state = crate::state::SharedState::new();
517 let current = self.state.get_state();
518 fresh_state.update(|s| {
519 *s = current;
520 });
521
522 let mut agent_loop = AgentLoop::new_with_resolver(
523 provider,
524 loop_config,
525 Arc::clone(&self.tools),
526 fresh_state,
527 Arc::clone(&self.resolver),
528 );
529
530 // Add the user prompt to Agent.state() AFTER fresh_state is created.
531 // fresh_state got a copy of the pre-prompt state, so run_loop will
532 // add the prompt to fresh_state independently via initial_prompts.
533 // But persist_session() reads Agent.state() (not fresh_state), so it
534 // needs the user prompt there to write it to the session file.
535 // Sync happens at AgentEnd (after run_loop completes), where
536 // Agent.state is overwritten with fresh_state (which has all messages).
537 self.state.update(|s| {
538 s.messages
539 .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
540 prompt.clone(),
541 )));
542 });
543
544 // Pre-populate steering/follow-up from hooks
545 {
546 let hooks = self.hooks.read();
547 if let Some(ref get_steering) = hooks.get_steering_messages {
548 for msg_text in get_steering() {
549 agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
550 }
551 }
552 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
553 for msg_text in get_follow_up() {
554 agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
555 }
556 }
557
558 // Store hooks on AgentLoop so they can be polled each turn
559 // to pick up new messages injected during the run.
560 if let Some(ref get_steering) = hooks.get_steering_messages {
561 agent_loop.set_steering_hook(Arc::clone(get_steering));
562 }
563 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
564 agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
565 }
566 }
567 let mut al = agent_loop;
568
569 // Wire should_stop_after_turn hook: share AgentLoop's external_stop
570 // Arc with the emit callback. When the hook fires (Ctrl+C detected),
571 // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
572 // AND during streaming (streaming.rs checks external_stop each event).
573 //
574 // Arc<dyn Fn> can be cloned, so we read it without consuming.
575 let maybe_hook = {
576 let hooks_r = self.hooks.read();
577 hooks_r.should_stop_after_turn.clone()
578 };
579 let ext_stop = al.external_stop().clone();
580 let cancel_flag = self.cancel_flag.clone();
581
582 // Share cancel_flag with AgentLoop so the streaming loop can check
583 // it directly in the periodic timer — no emit callback required.
584 // This closes the gap where cancel() was ineffective when the
585 // provider stream produced no events.
586 al.set_cancel_signal(self.cancel_flag.clone());
587
588 // Create emit callback that sends through the channel.
589 // AgentLoop calls this synchronously. UnboundedSender::send() is
590 // non-blocking and never drops events (unlike try_send on bounded).
591 let tx_emit = tx.clone();
592
593 // Run the agent loop
594 tracing::info!("[AGENT] Starting agent run with channel");
595 let result = al
596 .run(prompt.clone(), move |event: AgentEvent| {
597 // Forward event to channel (std::sync::mpsc — send from sync context)
598 tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
599 if let Err(e) = tx_emit.send(event.clone()) {
600 tracing::error!(
601 "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
602 e
603 );
604 } else {
605 tracing::info!("[AGENT-EMIT] Successfully sent event");
606 }
607
608 // Propagate cancellation from Agent::cancel() → external_stop.
609 // This runs on every event, ensuring the streaming loop detects
610 // cancellation promptly.
611 if cancel_flag.load(Ordering::SeqCst) {
612 ext_stop.store(true, Ordering::SeqCst);
613 }
614
615 // Propagate should_stop → external_stop on every event, not
616 // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
617 // so the context contents are irrelevant for non-TurnEnd events.
618 // This ensures streaming.rs detects cancellation immediately
619 // when the user presses Ctrl+C mid-stream.
620 if let Some(ref hook) = maybe_hook {
621 let ctx = ShouldStopAfterTurnContext {
622 message: match &event {
623 AgentEvent::TurnEnd {
624 assistant_message: oxi_ai::Message::Assistant(a),
625 ..
626 } => a.clone(),
627 _ => oxi_ai::AssistantMessage::new(
628 oxi_ai::Api::OpenAiCompletions,
629 "agent",
630 "agent-model",
631 ),
632 },
633 tool_results: match &event {
634 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
635 _ => Vec::new(),
636 },
637 iteration: 0,
638 };
639 if hook(&ctx) {
640 ext_stop.store(true, Ordering::SeqCst);
641 }
642 }
643 })
644 .await;
645
646 match result {
647 Ok(_events) => {
648 // Sync state back from AgentLoop
649 let loop_state = al.state().get_state();
650 self.state.update(|s| {
651 *s = loop_state;
652 });
653
654 // Apply any pending model switch that was deferred during the run.
655 // This transforms messages (if cross-provider) and swaps the provider
656 // so the next run uses the new model.
657 self.apply_pending_model_switch();
658
659 // Extract final response text from state
660 let state = self.state.get_state();
661 let final_text = state
662 .messages
663 .iter()
664 .rev()
665 .find_map(|m| match m {
666 oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
667 oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
668 _ => None,
669 }),
670 _ => None,
671 })
672 .unwrap_or_default();
673
674 let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
675
676 Ok(Response {
677 content: final_text,
678 stop_reason,
679 })
680 }
681 Err(e) => {
682 // Apply pending model switch even on error so the next run
683 // uses the new model.
684 self.apply_pending_model_switch();
685 Err(e)
686 }
687 }
688 }
689
690 // ── Helper methods for the agentic loop ────────────────────────
691
692 /// Set hooks for the agent loop.
693 pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
694 let mut h = self.hooks.write();
695 *h = hooks;
696 }
697
698 /// Request cancellation of the current agent run.
699 ///
700 /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
701 /// `external_stop` on every event AND polled every ~500ms by the
702 /// streaming loop's periodic check. This ensures cancellation is
703 /// detected quickly even when the provider stream is completely hung
704 /// (no events arriving).
705 pub fn cancel(&self) {
706 self.cancel_flag.store(true, Ordering::SeqCst);
707 }
708
709 /// Reset the cancellation flag before starting a new run.
710 pub fn reset_cancel(&self) {
711 self.cancel_flag.store(false, Ordering::SeqCst);
712 }
713
714 /// Apply any pending model switch that was deferred during a running loop.
715 ///
716 /// Called after `run_with_channel_inner` completes (success or error).
717 /// Transforms messages for cross-provider switches and swaps the provider
718 /// so the next run uses the new model.
719 fn apply_pending_model_switch(&self) {
720 let pending = self.pending_model_switch.write().take();
721 if let Some(pending) = pending {
722 tracing::info!(
723 "[AGENT] Applying deferred model switch to '{}' (transform={})",
724 pending.model_id,
725 pending.needs_transform
726 );
727
728 // Transform messages if cross-provider
729 if pending.needs_transform {
730 let messages = self.state.get_state().messages.clone();
731 let transformed =
732 transform_for_provider(&messages, &pending.old_api, &pending.new_api);
733 self.state.update(|s| {
734 s.replace_messages(transformed);
735 });
736 }
737
738 // Swap the provider
739 let mut inner = self.inner_mut();
740 inner.provider = pending.provider;
741 // model_id was already updated in switch_model()
742 }
743 }
744
745 /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
746 ///
747 /// Blocking convenience wrapper suitable for callers that prefer a
748 /// callback-based API over a channel.
749 pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
750 where
751 F: FnMut(AgentEvent) + Send,
752 {
753 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
754 let result = self.run_with_channel(prompt, tx).await;
755 while let Ok(event) = rx.recv() {
756 on_event(event);
757 }
758 result
759 }
760
761 // ── Session persistence ────────────────────────────────────────
762
763 /// Export the agent state as a JSON value.
764 ///
765 /// The serialized state includes conversation messages, token counts,
766 /// iteration progress, and stop reason. Use [`import_state`] to restore.
767 ///
768 /// [`import_state`]: Agent::import_state
769 pub fn export_state(&self) -> Result<serde_json::Value> {
770 let state = self.state.get_state();
771 serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
772 }
773
774 /// Import agent state from a JSON value.
775 ///
776 /// Restores conversation history, token counts, and iteration progress.
777 /// Typically used together with [`export_state`] for session persistence.
778 ///
779 /// [`export_state`]: Agent::export_state
780 pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
781 let state: AgentState = serde_json::from_value(value)
782 .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
783 self.state.update(|s| *s = state);
784 Ok(())
785 }
786
787 // ── Session continuation ───────────────────────────────────────
788
789 /// Continue the current session with a new prompt.
790 ///
791 /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
792 /// preserves the existing conversation state and appends the new prompt.
793 /// This enables multi-turn interactions within the same session.
794 pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
795 let mut events = Vec::new();
796 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
797 let result = self.run_with_channel(prompt, tx).await;
798 while let Ok(event) = rx.recv() {
799 events.push(event);
800 }
801 result.map(|r| (r, events))
802 }
803
804 // ── Tokio-native streaming ─────────────────────────────────────
805
806 /// Run the agent with tokio-native event streaming.
807 ///
808 /// Returns a `tokio::sync::mpsc::Receiver` for events and a
809 /// `JoinHandle` for the response. This is the preferred API for
810 /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
811 ///
812 /// # Example
813 ///
814 /// ```ignore
815 /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
816 /// while let Some(event) = rx.recv().await {
817 /// println!("Event: {:?}", event.type_name());
818 /// }
819 /// let response = handle.await??;
820 /// ```
821 pub async fn run_tokio_stream(
822 &self,
823 prompt: String,
824 ) -> Result<(
825 tokio::sync::mpsc::Receiver<AgentEvent>,
826 tokio::task::JoinHandle<Result<Response>>,
827 )> {
828 let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
829
830 if self
831 .is_running
832 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
833 .is_err()
834 {
835 return Err(Error::msg("Agent is already running"));
836 }
837
838 let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
839
840 let inner = self.inner.read().clone();
841 let tools = Arc::clone(&self.tools);
842 let resolver = Arc::clone(&self.resolver);
843
844 // Build AgentLoopConfig
845 let loop_config = crate::agent_loop::config::AgentLoopConfig {
846 model_id: inner.config.model_id.clone(),
847 system_prompt: inner.config.system_prompt.clone(),
848 temperature: inner.config.temperature.unwrap_or(1.0) as f32,
849 max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
850 tool_execution: crate::config::ToolExecutionMode::Sequential,
851 compaction_strategy: inner.config.compaction_strategy.clone(),
852 compaction_instruction: None,
853 context_window: inner.config.context_window,
854 session_id: inner.config.session_id.clone(),
855 transport: None,
856 compact_on_start: false,
857 max_retry_delay_ms: None,
858 auto_retry_enabled: true,
859 auto_retry_max_attempts: 3,
860 auto_retry_base_delay_ms: 1000,
861 api_key: inner.config.api_key.clone(),
862 workspace_dir: inner.config.workspace_dir.clone(),
863 provider_options: inner.config.provider_options.clone(),
864 on_compaction: None,
865 ttsr_engine: inner.config.ttsr_engine.clone(),
866 ..Default::default()
867 };
868
869 let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
870
871 // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
872 // agent loop so that state mutations inside the spawned task are
873 // visible through self.state() without an explicit sync step.
874 //
875 // Unlike run_with_channel_inner which creates a fresh SharedState
876 // and syncs back on completion, the tokio streaming API cannot
877 // access `self` inside the `'static` spawned task, so we share
878 // the underlying Arc instead.
879 //
880 // Pre-load current state into the shared Arc (in case it was
881 // modified by a previous run that used a different SharedState).
882 let shared_state = self.state.clone();
883
884 let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
885 provider,
886 loop_config,
887 tools,
888 shared_state.clone(),
889 resolver,
890 );
891
892 let maybe_hook = should_stop_hook;
893 let ext_stop = agent_loop.external_stop().clone();
894
895 // Clone the is_running Arc so the spawned task can clear it.
896 let is_running_flag = Arc::clone(&self.is_running);
897
898 let handle = tokio::task::spawn(async move {
899 let result = agent_loop
900 .run(prompt, move |event: AgentEvent| {
901 // Forward to tokio channel (non-blocking)
902 let _ = tx.try_send(event.clone());
903
904 // Propagate should_stop → external_stop on every event,
905 // not just TurnEnd. See run_with_channel_inner for rationale.
906 if let Some(ref hook) = maybe_hook {
907 let ctx = ShouldStopAfterTurnContext {
908 message: match &event {
909 AgentEvent::TurnEnd {
910 assistant_message: oxi_ai::Message::Assistant(a),
911 ..
912 } => a.clone(),
913 _ => oxi_ai::AssistantMessage::new(
914 oxi_ai::Api::OpenAiCompletions,
915 "agent",
916 "agent-model",
917 ),
918 },
919 tool_results: match &event {
920 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
921 _ => Vec::new(),
922 },
923 iteration: 0,
924 };
925 if hook(&ctx) {
926 ext_stop.store(true, Ordering::SeqCst);
927 }
928 }
929 })
930 .await;
931
932 // Clear the Agent's running flag
933 is_running_flag.store(false, Ordering::SeqCst);
934
935 match result {
936 Ok(_events) => {
937 // State is already shared via the same SharedState Arc,
938 // so self.state() will reflect all mutations.
939 Ok(Response {
940 content: String::new(),
941 stop_reason: StopReason::Stop,
942 })
943 }
944 Err(e) => Err(e),
945 }
946 });
947
948 Ok((rx, handle))
949 }
950}