oxi_agent/agent.rs
1/// Core agent implementation
2use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10 CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider, transform_for_provider,
11};
12use parking_lot::RwLock;
13use std::sync::Arc;
14use std::sync::atomic::{AtomicBool, Ordering};
15
16// ── ProviderResolver trait ────────────────────────────────────────
17
/// Trait for resolving providers and models within an Agent.
///
/// This abstracts away global static registries, allowing SDK users
/// to provide isolated provider/model lookups.
///
/// When using the SDK (`oxi-sdk`), the `Oxi` engine implements this trait.
/// When using `Agent::new()` directly, a global fallback is used.
///
/// Implementations must be `Send + Sync + 'static` because the agent stores
/// the resolver behind an `Arc<dyn ProviderResolver>`.
pub trait ProviderResolver: Send + Sync + 'static {
    /// Resolve a provider by name, returning an Arc handle.
    ///
    /// Returns `None` when the resolver does not know a provider called `name`.
    fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;

    /// Resolve a model ID ("provider/model" or bare "model") to a Model.
    ///
    /// Returns `None` when the resolver does not know `model_id`.
    fn resolve_model(&self, model_id: &str) -> Option<Model>;
}
32
/// Global provider resolver — uses `oxi_ai` global functions.
///
/// This is the default resolver when using `Agent::new()`, preserving
/// backward compatibility with existing CLI usage.
///
/// Providers are looked up with `oxi_ai::get_provider` and models with
/// `crate::model_id::resolve_model_from_id`; the struct itself holds no state.
pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41 oxi_ai::get_provider(name).map(Arc::from)
42 }
43
44 fn resolve_model(&self, model_id: &str) -> Option<Model> {
45 crate::model_id::resolve_model_from_id(model_id)
46 }
47}
48
49// ── AgentInner ────────────────────────────────────────────────────
/// Mutable agent internals protected by a read-write lock.
struct AgentInner {
    /// Agent configuration (model ID, API key, system prompt, compaction
    /// strategy, ...). Read at the start of each run.
    config: AgentConfig,
    /// Provider handed to the agent loop at the start of each run.
    provider: Arc<dyn Provider>,
    /// Side-dispatch closures invoked for every `AgentEvent` emitted by
    /// the agent run methods. Used by `oxi-sdk` to bridge observability
    /// types (Tracer, CostTracker, ...) into the agent loop without
    /// leaking SDK types into `oxi-agent`.
    ///
    /// Guarded by its own `parking_lot::Mutex` so handlers can be added
    /// through a shared reference to `AgentInner`. The list is mutated only
    /// by `add_observability_dispatch`, and `run_with_channel_inner` copies
    /// it once at the start of a run instead of locking on every event, so
    /// contention on this mutex is low.
    observability_dispatch: parking_lot::Mutex<Vec<EventDispatchFn>>,
}
65
/// Type alias for an observability dispatch handler. Each entry is a
/// closure registered via [`Agent::add_observability_dispatch`] and
/// invoked on every emitted `AgentEvent`. The alias keeps the
/// [`AgentInner`] field readable without spelling out the `dyn Fn` type
/// inline.
type EventDispatchFn = Arc<dyn Fn(AgentEvent) + Send + Sync>;
71
72impl Clone for AgentInner {
73 fn clone(&self) -> Self {
74 Self {
75 config: self.config.clone(),
76 provider: Arc::clone(&self.provider),
77 // The dispatch list is *not* cloned: each `Agent` instance has
78 // its own observers. Cloning the AgentInner (rare; happens in
79 // `run_with_channel_inner` when sharing config across loops)
80 // gives the new loop an empty observer set, which is correct:
81 // the *Agent* retains the original dispatch list, and the
82 // temporary inner clone is discarded after the run.
83 observability_dispatch: parking_lot::Mutex::new(Vec::new()),
84 }
85 }
86}
/// Deferred model switch request, stored when the agent is running.
///
/// Written by `Agent::switch_model` / `Agent::switch_to_model` while a run
/// is in progress and consumed by `Agent::apply_pending_model_switch` once
/// the run has finished.
struct PendingModelSwitch {
    /// Model ID that was requested (used for logging when the switch is applied).
    model_id: String,
    /// Provider to install once the running loop has finished.
    provider: Arc<dyn Provider>,
    /// Whether messages need cross-provider transformation.
    needs_transform: bool,
    /// API the stored messages are transformed *from* when the switch is applied.
    old_api: oxi_ai::Api,
    /// API the stored messages are transformed *to* when the switch is applied.
    new_api: oxi_ai::Api,
}
105
/// Agent runtime.
///
/// Manages provider, tool registry, state, and compaction, providing an
/// agentic loop for prompt execution, model switching, tool calls, and fallback.
///
/// Supports session continuation, tokio-native event streaming, and deferred
/// model switching (changes are queued while a loop is running and applied
/// after it completes).
#[allow(missing_docs)]
pub struct Agent {
    /// Config, provider and observability handlers behind one read-write lock.
    inner: RwLock<AgentInner>,
    /// Tool registry shared with every agent loop started by this agent.
    tools: Arc<ToolRegistry>,
    /// Conversation state; synced back from the agent loop after a
    /// successful run.
    state: SharedState,
    /// Compaction manager built at construction time from the initial config.
    compaction_manager: CompactionManager,
    /// User-supplied hooks (steering, follow-up, should-stop-after-turn).
    hooks: parking_lot::RwLock<crate::config::AgentHooks>,
    /// Guard: true while a run is in progress. Prevents concurrent runs.
    is_running: Arc<AtomicBool>,
    /// Provider/model resolver. Uses global functions by default,
    /// or a custom resolver when created via `new_with_resolver()`.
    resolver: Arc<dyn ProviderResolver>,
    /// Shared cancellation flag. Set by `cancel()` (e.g. on Ctrl+C),
    /// propagated to AgentLoop's `external_stop` during each run.
    cancel_flag: Arc<AtomicBool>,
    /// Pending model switch — stored when the agent is running,
    /// applied after the current loop completes.
    pending_model_switch: RwLock<Option<PendingModelSwitch>>,
}
133
134impl Agent {
135 /// Create a new agent with the given provider, config, and tool registry.
136 ///
137 /// Uses the global `oxi_ai::get_provider()` / `resolve_model_from_id()`
138 /// for model switching. For isolated instances, use [`new_with_resolver`].
139 ///
140 /// [`new_with_resolver`]: Agent::new_with_resolver
141 pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
142 let resolver = Arc::new(GlobalProviderResolver);
143 Self::build_inner(provider, config, tools, resolver)
144 }
145
    /// Create an agent with a custom provider/model resolver.
    ///
    /// This is the preferred constructor for SDK usage where provider
    /// and model registries must be isolated from global state.
    ///
    /// `resolver` is used both at construction time (to resolve the model for
    /// the LLM compactor) and later for `switch_model` / `switch_to_model`.
    pub fn new_with_resolver(
        provider: Arc<dyn Provider>,
        config: AgentConfig,
        tools: Arc<ToolRegistry>,
        resolver: Arc<dyn ProviderResolver>,
    ) -> Self {
        Self::build_inner(provider, config, tools, resolver)
    }
158
159 /// Create an agent with an empty tool registry.
160 pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
161 Self::new(provider, config, Arc::new(ToolRegistry::new()))
162 }
163
    /// Acquire a shared read guard over the agent internals.
    ///
    /// Despite the name, the guard dereferences to the whole [`AgentInner`]
    /// (config, provider, observability handlers), not only the config.
    /// Keep the guard short-lived: writers such as `inner_mut()` block until
    /// every read guard has been dropped.
    fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
        self.inner.read()
    }
168
    /// Get a write guard for the agent inner state.
    ///
    /// Blocks until all other read/write guards on `inner` are released, so
    /// it must not be called while this thread still holds a guard from
    /// `config()` or `inner_mut()`.
    fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
        self.inner.write()
    }
173
174 /// Get the current model ID
175 pub fn model_id(&self) -> String {
176 self.config().config.model_id.clone()
177 }
178
179 /// Get the agent configuration (full clone)
180 pub fn get_config(&self) -> AgentConfig {
181 self.config().config.clone()
182 }
183
184 /// Internal constructor shared by `new()` and `new_with_resolver()`.
185 fn build_inner(
186 provider: Arc<dyn Provider>,
187 config: AgentConfig,
188 tools: Arc<ToolRegistry>,
189 resolver: Arc<dyn ProviderResolver>,
190 ) -> Self {
191 let mut compaction_manager =
192 CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
193
194 // Pre-initialize the LLM compactor if compaction is enabled
195 if config.compaction_strategy != CompactionStrategy::Disabled {
196 let model = resolver.resolve_model(&config.model_id);
197
198 if let Some(model) = model {
199 let llm_compactor =
200 Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
201 compaction_manager.set_compactor(llm_compactor);
202 }
203 }
204
205 Self {
206 inner: RwLock::new(AgentInner {
207 config,
208 provider,
209 observability_dispatch: parking_lot::Mutex::new(Vec::new()),
210 }),
211 tools,
212 state: SharedState::new(),
213 compaction_manager,
214 hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
215 is_running: Arc::new(AtomicBool::new(false)),
216 resolver,
217 cancel_flag: Arc::new(AtomicBool::new(false)),
218 pending_model_switch: RwLock::new(None),
219 }
220 }
221
    /// Get a reference to the provider resolver.
    ///
    /// This is the resolver supplied to `new_with_resolver()`, or the global
    /// fallback installed by `new()`.
    pub fn resolver(&self) -> &Arc<dyn ProviderResolver> {
        &self.resolver
    }
226
227 /// Switch the model used for future LLM calls.
228 ///
229 /// Switch model mid-conversation.
230 ///
231 /// If the agent is currently running, the switch is deferred: the new
232 /// model and provider are stored in `pending_model_switch` and applied
233 /// automatically when the current loop finishes. This ensures the
234 /// running loop completes with a consistent provider/model without
235 /// interruption.
236 ///
237 /// If the agent is idle, the switch takes effect immediately.
238 ///
239 /// If the new model uses a different provider API, the conversation
240 /// history is automatically transformed for cross-provider compatibility
241 /// (e.g. thinking blocks are converted to `<thinking>` tags).
242 ///
243 /// # Arguments
244 /// * `model_id` - New model ID in `provider/model` format
245 /// * `api_key` - Optional API key for the new provider (will be passed to StreamOptions)
246 ///
247 /// # Returns
248 /// `Ok(())` on success, or an error if the model/provider is unknown
249 pub fn switch_model(&self, model_id: &str, api_key: Option<String>) -> Result<()> {
250 let new_model = self
251 .resolver
252 .resolve_model(model_id)
253 .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
254
255 // Create the new provider via resolver
256 let new_provider = self
257 .resolver
258 .resolve_provider(&new_model.provider)
259 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
260
261 // Detect API change
262 let (old_api, needs_transform) = {
263 let inner = self.config();
264 let old_api = self
265 .resolver
266 .resolve_model(&inner.config.model_id)
267 .map(|m| m.api)
268 .unwrap_or(oxi_ai::Api::AnthropicMessages);
269 (old_api, old_api != new_model.api)
270 };
271
272 // If the agent is currently running, defer the switch.
273 if self.is_running.load(Ordering::SeqCst) {
274 tracing::info!(
275 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
276 model_id
277 );
278 *self.pending_model_switch.write() = Some(PendingModelSwitch {
279 model_id: model_id.to_string(),
280 provider: new_provider,
281 needs_transform,
282 old_api,
283 new_api: new_model.api,
284 });
285 // Update config immediately so model_id() returns the new value,
286 // but leave provider unchanged so the running loop keeps its provider.
287 {
288 let mut inner = self.inner_mut();
289 inner.config.model_id = model_id.to_string();
290 inner.config.api_key = api_key;
291 }
292 return Ok(());
293 }
294
295 // Agent is idle — apply immediately.
296 if needs_transform {
297 let messages = self.state.get_state().messages.clone();
298 let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
299 self.state.update(|s| {
300 s.replace_messages(transformed);
301 });
302 }
303
304 // Update config and provider atomically
305 let mut inner = self.inner_mut();
306 inner.config.model_id = model_id.to_string();
307 inner.config.api_key = api_key;
308 inner.provider = new_provider;
309
310 Ok(())
311 }
312
313 /// Switch the model using a pre-resolved `Model` object.
314 ///
315 /// This is useful when the caller has already looked up the model
316 /// and optionally created the provider.
317 ///
318 /// Like [`switch_model`], if the agent is currently running, the switch
319 /// is deferred until the current loop completes.
320 ///
321 /// [`switch_model`]: Agent::switch_model
322 pub fn switch_to_model(&self, model: &oxi_ai::Model, api_key: Option<String>) -> Result<()> {
323 let model_id = format!("{}/{}", model.provider, model.id);
324 let new_provider = self
325 .resolver
326 .resolve_provider(&model.provider)
327 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
328
329 // Detect API change
330 let (old_api, needs_transform) = {
331 let inner = self.config();
332 let old_api = self
333 .resolver
334 .resolve_model(&inner.config.model_id)
335 .map(|m| m.api)
336 .unwrap_or(oxi_ai::Api::AnthropicMessages);
337 (old_api, old_api != model.api)
338 };
339
340 // If the agent is currently running, defer the switch.
341 if self.is_running.load(Ordering::SeqCst) {
342 tracing::info!(
343 "[AGENT] Agent running, deferring model switch to '{}' until loop completes",
344 model_id
345 );
346 *self.pending_model_switch.write() = Some(PendingModelSwitch {
347 model_id: model_id.clone(),
348 provider: new_provider,
349 needs_transform,
350 old_api,
351 new_api: model.api,
352 });
353 let mut inner = self.inner_mut();
354 inner.config.model_id = model_id;
355 inner.config.api_key = api_key;
356 return Ok(());
357 }
358
359 // Agent is idle — apply immediately.
360 if needs_transform {
361 let messages = self.state.get_state().messages.clone();
362 let transformed = transform_for_provider(&messages, &old_api, &model.api);
363 self.state.update(|s| {
364 s.replace_messages(transformed);
365 });
366 }
367
368 let mut inner = self.inner_mut();
369 inner.config.model_id = model_id;
370 inner.config.api_key = api_key;
371 inner.provider = new_provider;
372
373 Ok(())
374 }
375
376 /// Refresh only the API key without changing model or provider.
377 /// Useful when the user stores a key after the session was already created.
378 pub fn refresh_api_key(&self, api_key: Option<String>) {
379 let mut inner = self.inner_mut();
380 inner.config.api_key = api_key;
381 }
382
    /// Get a handle to the tool registry.
    ///
    /// The returned `Arc` points at the same registry the agent uses, so
    /// tools registered through it are visible to the agent.
    pub fn tools(&self) -> Arc<ToolRegistry> {
        Arc::clone(&self.tools)
    }
387
    /// Get a snapshot of the current agent state.
    ///
    /// The snapshot is an owned `AgentState`; mutating it does not affect
    /// the agent. Use [`Agent::update_state`] to change the live state.
    pub fn state(&self) -> AgentState {
        self.state.get_state()
    }
392
    /// Update agent state in-place. Used by compaction to replace messages.
    ///
    /// `f` receives a mutable reference to the live state and is forwarded
    /// to `SharedState::update`.
    pub fn update_state(&self, f: impl FnOnce(&mut AgentState)) {
        self.state.update(f);
    }
397
    /// Reset agent state for a new conversation.
    ///
    /// Only the conversation state is reset (via `SharedState::reset`);
    /// config, provider, tools and hooks are left untouched.
    pub fn reset(&self) {
        self.state.reset();
    }
402
    /// Register a tool that the agent can invoke during a run.
    ///
    /// The tool is added to the shared [`ToolRegistry`], the same registry
    /// that is handed to each agent loop.
    pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
        self.tools.register(tool);
    }
407
408 /// Update the system prompt for future interactions.
409 pub fn set_system_prompt(&self, prompt: String) {
410 self.inner_mut().config.system_prompt = Some(prompt);
411 }
412
    /// Get the compaction manager.
    ///
    /// This is the manager built when the agent was constructed; it is not
    /// rebuilt by `set_compaction_strategy`.
    pub fn compaction_manager(&self) -> &CompactionManager {
        &self.compaction_manager
    }
417 /// Update the compaction strategy for future runs.
418 ///
419 /// The strategy is read fresh from the config at the start of each run
420 /// (see `run_with_channel_inner`), so this takes effect on the next
421 /// agent turn — never mid-run. Pair with `compaction_manager()` for
422 /// manual compaction, which is unaffected by the strategy.
423 pub fn set_compaction_strategy(&self, strategy: oxi_ai::CompactionStrategy) {
424 self.inner.write().config.compaction_strategy = strategy;
425 }
426 /// Get the compaction strategy that will be used on the next run.
427 ///
428 /// This reads from `inner.config` (mutable via `set_compaction_strategy`),
429 /// **not** from the `compaction_manager` field (which retains its
430 /// construction-time strategy). The agent loop reads from config fresh
431 /// each run, so this is the authoritative value.
432 pub fn compaction_strategy(&self) -> oxi_ai::CompactionStrategy {
433 self.inner.read().config.compaction_strategy.clone()
434 }
435
436 /// Run the agent with a prompt, collecting all events into a vector.
437 ///
438 /// Convenience wrapper around [`run_with_channel`](Self::run_with_channel) that gathers every
439 /// [`AgentEvent`] produced during the run.
440 pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
441 let mut events = Vec::new();
442 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
443 let result = self.run_with_channel(prompt, tx).await;
444 while let Ok(event) = rx.recv() {
445 events.push(event);
446 }
447 result.map(|r| (r, events))
448 }
449
450 /// Run the agent, delivering events through the provided channel.
451 ///
452 /// Delegates to the agent loop which implements the same 2-level agentic
453 /// loop matching pi-mono's architecture:
454 ///
455 /// ```text
456 /// AgentLoop.run_messages()
457 /// Outer loop (follow-up messages):
458 /// Inner loop (tool calls + steering):
459 /// 1. Inject pending messages (steering)
460 /// 2. Compaction check
461 /// 3. Stream LLM response (with accumulated partial messages)
462 /// 4. Execute tool calls if any
463 /// 5. Emit turn_end
464 /// 6. Check shouldStopAfterTurn
465 /// 7. Poll steering messages
466 /// Check follow-up messages
467 /// Exit
468 /// ```
469 pub async fn run_with_channel(
470 &self,
471 prompt: String,
472 tx: std::sync::mpsc::Sender<AgentEvent>,
473 ) -> Result<Response> {
474 // pi-mono: Agent.prompt() throws if activeRun exists.
475 // Prevent concurrent runs that would corrupt shared state.
476 if self
477 .is_running
478 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
479 .is_err()
480 {
481 return Err(Error::msg("Agent is already running"));
482 }
483
484 // Drop guard ensures is_running is cleared even on panic.
485 struct RunningGuard<'a>(&'a AtomicBool);
486 impl Drop for RunningGuard<'_> {
487 fn drop(&mut self) {
488 self.0.store(false, Ordering::SeqCst);
489 }
490 }
491 let _guard = RunningGuard(&self.is_running);
492 self.reset_cancel();
493
494 self.run_with_channel_inner(prompt, tx).await
495 }
496
497 /// Inner implementation of run_with_channel, called after the running guard is set.
498 async fn run_with_channel_inner(
499 &self,
500 prompt: String,
501 tx: std::sync::mpsc::Sender<AgentEvent>,
502 ) -> Result<Response> {
503 use crate::agent_loop::AgentLoop;
504
505 let (
506 provider,
507 system_prompt,
508 temperature,
509 max_tokens,
510 compaction_strategy,
511 context_window,
512 api_key,
513 workspace_dir,
514 ) = {
515 let inner = self.inner.read();
516 (
517 Arc::clone(&inner.provider) as Arc<dyn Provider>,
518 inner.config.system_prompt.clone(),
519 inner.config.temperature,
520 inner.config.max_tokens,
521 inner.config.compaction_strategy.clone(),
522 inner.config.context_window,
523 inner.config.api_key.clone(),
524 inner.config.workspace_dir.clone(),
525 )
526 }; // release read lock
527
528 // Build AgentLoopConfig from Agent's config
529 let loop_config = crate::agent_loop::config::AgentLoopConfig {
530 model_id: self.model_id(),
531 system_prompt,
532 temperature: temperature.unwrap_or(1.0) as f32,
533 max_tokens: max_tokens.unwrap_or(4096) as u32,
534 tool_execution: crate::config::ToolExecutionMode::Sequential,
535 compaction_strategy,
536 compaction_instruction: None,
537 context_window,
538 session_id: self.config().config.session_id.clone(),
539 transport: None,
540 compact_on_start: false,
541 max_retry_delay_ms: None,
542 auto_retry_enabled: true,
543 auto_retry_max_attempts: 3,
544 auto_retry_base_delay_ms: 1000,
545 api_key,
546 workspace_dir,
547 provider_options: self.config().config.provider_options.clone(),
548 on_compaction: None,
549 ttsr_engine: self.config().config.ttsr_engine.clone(),
550 memory: self.config().config.memory.clone(),
551 todo: self.config().config.todo.clone(),
552 agent_pool: self.config().config.agent_pool.clone(),
553 ..Default::default()
554 };
555
556 // Create AgentLoop. We give it a NEW SharedState and sync back after.
557 // (SharedState is not Clone, so we create a fresh one from current state)
558 let fresh_state = crate::state::SharedState::new();
559 let current = self.state.get_state();
560 fresh_state.update(|s| {
561 *s = current;
562 });
563
564 let mut agent_loop = AgentLoop::new_with_resolver(
565 provider,
566 loop_config,
567 Arc::clone(&self.tools),
568 fresh_state,
569 Arc::clone(&self.resolver),
570 );
571
572 // Add the user prompt to Agent.state() AFTER fresh_state is created.
573 // fresh_state got a copy of the pre-prompt state, so run_loop will
574 // add the prompt to fresh_state independently via initial_prompts.
575 // But persist_session() reads Agent.state() (not fresh_state), so it
576 // needs the user prompt there to write it to the session file.
577 // Sync happens at AgentEnd (after run_loop completes), where
578 // Agent.state is overwritten with fresh_state (which has all messages).
579 self.state.update(|s| {
580 s.messages
581 .push(oxi_ai::Message::User(oxi_ai::UserMessage::new(
582 prompt.clone(),
583 )));
584 });
585
586 // Pre-populate steering/follow-up from hooks
587 {
588 let hooks = self.hooks.read();
589 if let Some(ref get_steering) = hooks.get_steering_messages {
590 for msg_text in get_steering() {
591 agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
592 }
593 }
594 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
595 for msg_text in get_follow_up() {
596 agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
597 }
598 }
599
600 // Store hooks on AgentLoop so they can be polled each turn
601 // to pick up new messages injected during the run.
602 if let Some(ref get_steering) = hooks.get_steering_messages {
603 agent_loop.set_steering_hook(Arc::clone(get_steering));
604 }
605 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
606 agent_loop.set_follow_up_hook(Arc::clone(get_follow_up));
607 }
608 }
609 let mut al = agent_loop;
610
611 // Wire should_stop_after_turn hook: share AgentLoop's external_stop
612 // Arc with the emit callback. When the hook fires (Ctrl+C detected),
613 // it sets ext_stop. AgentLoop checks this in should_stop_after_turn()
614 // AND during streaming (streaming.rs checks external_stop each event).
615 //
616 // Arc<dyn Fn> can be cloned, so we read it without consuming.
617 let maybe_hook = {
618 let hooks_r = self.hooks.read();
619 hooks_r.should_stop_after_turn.clone()
620 };
621 let ext_stop = al.external_stop().clone();
622 let cancel_flag = self.cancel_flag.clone();
623
624 // Share cancel_flag with AgentLoop so the streaming loop can check
625 // it directly in the periodic timer — no emit callback required.
626 // This closes the gap where cancel() was ineffective when the
627 // provider stream produced no events.
628 al.set_cancel_signal(self.cancel_flag.clone());
629
630 // Create emit callback that sends through the channel.
631 // AgentLoop calls this synchronously. UnboundedSender::send() is
632 // non-blocking and never drops events (unlike try_send on bounded).
633 let tx_emit = tx.clone();
634
635 // Snapshot the observability_dispatch list once per run. This avoids
636 // holding an Agent lock on the emit-fn hot path while still letting
637 // SDK consumers register new dispatchers at any time (registers after
638 // this snapshot will fire on the next run).
639 let dispatch_handlers: Vec<EventDispatchFn> =
640 { self.inner.read().observability_dispatch.lock().clone() };
641 tracing::info!("[AGENT] Starting agent run with channel");
642 let result = al
643 .run(prompt.clone(), move |event: AgentEvent| {
644 // Forward event to channel (std::sync::mpsc — send from sync context)
645 tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
646 if let Err(e) = tx_emit.send(event.clone()) {
647 tracing::error!(
648 "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
649 e
650 );
651 } else {
652 tracing::info!("[AGENT-EMIT] Successfully sent event");
653 }
654
655 // Propagate cancellation from Agent::cancel() → external_stop.
656 // This runs on every event, ensuring the streaming loop detects
657 // cancellation promptly.
658 if cancel_flag.load(Ordering::SeqCst) {
659 ext_stop.store(true, Ordering::SeqCst);
660 }
661
662 // Fan out to SDK-side observability handlers (Tracer,
663 // CostTracker, ...). The dispatch list is snapshotted at
664 // run-start so we hold Arc clones, not a lock. This means
665 // handlers added mid-run do not fire until the next run.
666 for handler in dispatch_handlers.iter() {
667 handler(event.clone());
668 }
669 // Propagate should_stop → external_stop on every event, not
670 // just TurnEnd. The TUI hook only checks should_stop_flag.load(),
671 // so the context contents are irrelevant for non-TurnEnd events.
672 // This ensures streaming.rs detects cancellation immediately
673 // when the user presses Ctrl+C mid-stream.
674 if let Some(ref hook) = maybe_hook {
675 let ctx = ShouldStopAfterTurnContext {
676 message: match &event {
677 AgentEvent::TurnEnd {
678 assistant_message: oxi_ai::Message::Assistant(a),
679 ..
680 } => a.clone(),
681 _ => oxi_ai::AssistantMessage::new(
682 oxi_ai::Api::OpenAiCompletions,
683 "agent",
684 "agent-model",
685 ),
686 },
687 tool_results: match &event {
688 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
689 _ => Vec::new(),
690 },
691 iteration: 0,
692 };
693 if hook(&ctx) {
694 ext_stop.store(true, Ordering::SeqCst);
695 }
696 }
697 })
698 .await;
699
700 match result {
701 Ok(_events) => {
702 // Sync state back from AgentLoop
703 let loop_state = al.state().get_state();
704 self.state.update(|s| {
705 *s = loop_state;
706 });
707
708 // Apply any pending model switch that was deferred during the run.
709 // This transforms messages (if cross-provider) and swaps the provider
710 // so the next run uses the new model.
711 self.apply_pending_model_switch();
712
713 // Extract final response text from state
714 let state = self.state.get_state();
715 let final_text = state
716 .messages
717 .iter()
718 .rev()
719 .find_map(|m| match m {
720 oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
721 oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
722 _ => None,
723 }),
724 _ => None,
725 })
726 .unwrap_or_default();
727
728 let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
729
730 Ok(Response {
731 content: final_text,
732 stop_reason,
733 })
734 }
735 Err(e) => {
736 // Apply pending model switch even on error so the next run
737 // uses the new model.
738 self.apply_pending_model_switch();
739 Err(e)
740 }
741 }
742 }
743
744 // ── Helper methods for the agentic loop ────────────────────────
745
746 /// Set hooks for the agent loop.
747 pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
748 let mut h = self.hooks.write();
749 *h = hooks;
750 }
751
752 /// Register a side-dispatch closure called for every `AgentEvent`
753 /// emitted by `run`, `run_with_channel`, `run_streaming`,
754 /// `run_tokio_stream`, and `continue_with`.
755 ///
756 /// Multiple calls stack: every registered closure is invoked on
757 /// every event. Closures run synchronously on the agent-loop emit
758 /// thread, so they must be cheap and non-blocking. Long work
759 /// should be spawned off (e.g. `tokio::spawn`) by the closure
760 /// itself.
761 ///
762 /// Used by `oxi-sdk` to bridge observability types
763 /// (`Tracer`, `CostTracker`, `AuditLog`, `Authorizer` /
764 /// `AccessGate`) into the runtime without leaking those types
765 /// into `oxi-agent`.
766 ///
767 /// # Example
768 ///
769 /// ```ignore
770 /// agent.add_observability_dispatch(|event| match event {
771 /// AgentEvent::TurnStart { turn_number } => {
772 /// // open a span
773 /// }
774 /// AgentEvent::Usage { input_tokens, output_tokens } => {
775 /// // record cost
776 /// }
777 /// _ => {}
778 /// });
779 /// ```
780 pub fn add_observability_dispatch(&self, f: impl Fn(AgentEvent) + Send + Sync + 'static) {
781 let guard = self.inner.write();
782 let mut slot = guard.observability_dispatch.lock();
783 slot.push(Arc::new(f));
784 }
785
    /// Request cancellation of the current agent run.
    ///
    /// Sets a shared `cancel_flag` that is propagated to the `AgentLoop`'s
    /// `external_stop` on every event AND polled every ~500ms by the
    /// streaming loop's periodic check. This ensures cancellation is
    /// detected quickly even when the provider stream is completely hung
    /// (no events arriving).
    ///
    /// The flag is cleared again by `reset_cancel()`, which
    /// `run_with_channel` calls at the start of every run, so a `cancel()`
    /// issued while no run is active does not affect the next run.
    pub fn cancel(&self) {
        self.cancel_flag.store(true, Ordering::SeqCst);
    }
796
    /// Reset the cancellation flag before starting a new run.
    ///
    /// Called by `run_with_channel` right after the running guard is
    /// acquired; it only clears the flag and has no other effect.
    pub fn reset_cancel(&self) {
        self.cancel_flag.store(false, Ordering::SeqCst);
    }
801
802 /// Apply any pending model switch that was deferred during a running loop.
803 ///
804 /// Called after `run_with_channel_inner` completes (success or error).
805 /// Transforms messages for cross-provider switches and swaps the provider
806 /// so the next run uses the new model.
807 fn apply_pending_model_switch(&self) {
808 let pending = self.pending_model_switch.write().take();
809 if let Some(pending) = pending {
810 tracing::info!(
811 "[AGENT] Applying deferred model switch to '{}' (transform={})",
812 pending.model_id,
813 pending.needs_transform
814 );
815
816 // Transform messages if cross-provider
817 if pending.needs_transform {
818 let messages = self.state.get_state().messages.clone();
819 let transformed =
820 transform_for_provider(&messages, &pending.old_api, &pending.new_api);
821 self.state.update(|s| {
822 s.replace_messages(transformed);
823 });
824 }
825
826 // Swap the provider
827 let mut inner = self.inner_mut();
828 inner.provider = pending.provider;
829 // model_id was already updated in switch_model()
830 }
831 }
832
833 /// Run the agent, invoking `on_event` for each [`AgentEvent`] produced.
834 ///
835 /// Blocking convenience wrapper suitable for callers that prefer a
836 /// callback-based API over a channel.
837 pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
838 where
839 F: FnMut(AgentEvent) + Send,
840 {
841 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
842 let result = self.run_with_channel(prompt, tx).await;
843 while let Ok(event) = rx.recv() {
844 on_event(event);
845 }
846 result
847 }
848
849 // ── Session persistence ────────────────────────────────────────
850
851 /// Export the agent state as a JSON value.
852 ///
853 /// The serialized state includes conversation messages, token counts,
854 /// iteration progress, and stop reason. Use [`import_state`] to restore.
855 ///
856 /// [`import_state`]: Agent::import_state
857 pub fn export_state(&self) -> Result<serde_json::Value> {
858 let state = self.state.get_state();
859 serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
860 }
861
862 /// Import agent state from a JSON value.
863 ///
864 /// Restores conversation history, token counts, and iteration progress.
865 /// Typically used together with [`export_state`] for session persistence.
866 ///
867 /// [`export_state`]: Agent::export_state
868 pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
869 let state: AgentState = serde_json::from_value(value)
870 .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
871 self.state.update(|s| *s = state);
872 Ok(())
873 }
874
875 // ── Session continuation ───────────────────────────────────────
876
877 /// Continue the current session with a new prompt.
878 ///
879 /// Unlike `run()`, which can be used on a fresh agent, `continue_with`
880 /// preserves the existing conversation state and appends the new prompt.
881 /// This enables multi-turn interactions within the same session.
882 pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
883 let mut events = Vec::new();
884 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
885 let result = self.run_with_channel(prompt, tx).await;
886 while let Ok(event) = rx.recv() {
887 events.push(event);
888 }
889 result.map(|r| (r, events))
890 }
891
892 // ── Tokio-native streaming ─────────────────────────────────────
893
894 /// Run the agent with tokio-native event streaming.
895 ///
896 /// Returns a `tokio::sync::mpsc::Receiver` for events and a
897 /// `JoinHandle` for the response. This is the preferred API for
898 /// async runtimes (WebSocket/SSE gateways, tokio-based servers).
899 ///
900 /// # Example
901 ///
902 /// ```ignore
903 /// let (rx, handle) = agent.run_tokio_stream("Explain Rust".into()).await?;
904 /// while let Some(event) = rx.recv().await {
905 /// println!("Event: {:?}", event.type_name());
906 /// }
907 /// let response = handle.await??;
908 /// ```
909 pub async fn run_tokio_stream(
910 &self,
911 prompt: String,
912 ) -> Result<(
913 tokio::sync::mpsc::Receiver<AgentEvent>,
914 tokio::task::JoinHandle<Result<Response>>,
915 )> {
916 let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
917
918 if self
919 .is_running
920 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
921 .is_err()
922 {
923 return Err(Error::msg("Agent is already running"));
924 }
925
926 let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
927
928 let inner = self.inner.read().clone();
929 let tools = Arc::clone(&self.tools);
930 let resolver = Arc::clone(&self.resolver);
931
932 // Build AgentLoopConfig
933 let loop_config = crate::agent_loop::config::AgentLoopConfig {
934 model_id: inner.config.model_id.clone(),
935 system_prompt: inner.config.system_prompt.clone(),
936 temperature: inner.config.temperature.unwrap_or(1.0) as f32,
937 max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
938 tool_execution: crate::config::ToolExecutionMode::Sequential,
939 compaction_strategy: inner.config.compaction_strategy.clone(),
940 compaction_instruction: None,
941 context_window: inner.config.context_window,
942 session_id: inner.config.session_id.clone(),
943 transport: None,
944 compact_on_start: false,
945 max_retry_delay_ms: None,
946 auto_retry_enabled: true,
947 auto_retry_max_attempts: 3,
948 auto_retry_base_delay_ms: 1000,
949 api_key: inner.config.api_key.clone(),
950 workspace_dir: inner.config.workspace_dir.clone(),
951 provider_options: inner.config.provider_options.clone(),
952 on_compaction: None,
953 ttsr_engine: inner.config.ttsr_engine.clone(),
954 ..Default::default()
955 };
956
957 let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
958
959 // Share the SAME SharedState (Arc<RwLock<AgentState>>) with the
960 // agent loop so that state mutations inside the spawned task are
961 // visible through self.state() without an explicit sync step.
962 //
963 // Unlike run_with_channel_inner which creates a fresh SharedState
964 // and syncs back on completion, the tokio streaming API cannot
965 // access `self` inside the `'static` spawned task, so we share
966 // the underlying Arc instead.
967 //
968 // Pre-load current state into the shared Arc (in case it was
969 // modified by a previous run that used a different SharedState).
970 let shared_state = self.state.clone();
971
972 let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
973 provider,
974 loop_config,
975 tools,
976 shared_state.clone(),
977 resolver,
978 );
979
980 let maybe_hook = should_stop_hook;
981 let ext_stop = agent_loop.external_stop().clone();
982
983 // Clone the is_running Arc so the spawned task can clear it.
984 let is_running_flag = Arc::clone(&self.is_running);
985
986 // Snapshot the observability_dispatch list before the spawned
987 // task. The future is `'static` and cannot borrow `&self`,
988 // so we take the snapshot at run-start on the regular borrow
989 // stack and move the resulting Arc-clones into the task.
990 let dispatch_handlers: Vec<EventDispatchFn> = {
991 let guard = self.inner.read();
992 guard.observability_dispatch.lock().clone()
993 };
994
995 let handle = tokio::task::spawn(async move {
996 let result = agent_loop
997 .run(prompt, move |event: AgentEvent| {
998 // Forward to tokio channel (non-blocking)
999 let _ = tx.try_send(event.clone());
1000
1001 // Fan out to SDK-side observability handlers
1002 // (Tracer, CostTracker, ...).
1003 for handler in dispatch_handlers.iter() {
1004 handler(event.clone());
1005 }
1006 // Propagate should_stop → external_stop on every event,
1007 // not just TurnEnd. See run_with_channel_inner for rationale.
1008 if let Some(ref hook) = maybe_hook {
1009 let ctx = ShouldStopAfterTurnContext {
1010 message: match &event {
1011 AgentEvent::TurnEnd {
1012 assistant_message: oxi_ai::Message::Assistant(a),
1013 ..
1014 } => a.clone(),
1015 _ => oxi_ai::AssistantMessage::new(
1016 oxi_ai::Api::OpenAiCompletions,
1017 "agent",
1018 "agent-model",
1019 ),
1020 },
1021 tool_results: match &event {
1022 AgentEvent::TurnEnd { tool_results, .. } => tool_results.clone(),
1023 _ => Vec::new(),
1024 },
1025 iteration: 0,
1026 };
1027 if hook(&ctx) {
1028 ext_stop.store(true, Ordering::SeqCst);
1029 }
1030 }
1031 })
1032 .await;
1033
1034 // Clear the Agent's running flag
1035 is_running_flag.store(false, Ordering::SeqCst);
1036
1037 match result {
1038 Ok(_events) => {
1039 // State is already shared via the same SharedState Arc,
1040 // so self.state() will reflect all mutations.
1041 Ok(Response {
1042 content: String::new(),
1043 stop_reason: StopReason::Stop,
1044 })
1045 }
1046 Err(e) => Err(e),
1047 }
1048 });
1049
1050 Ok((rx, handle))
1051 }
1052}