1use crate::config::AgentConfig;
3use crate::config::ShouldStopAfterTurnContext;
4use crate::events::AgentEvent;
5use crate::state::{AgentState, SharedState};
6use crate::tools::{AgentTool, ToolRegistry};
7use crate::types::{Response, StopReason};
8use anyhow::{Error, Result};
9use oxi_ai::{
10 transform_for_provider, CompactionManager, CompactionStrategy, LlmCompactor, Model, Provider,
11};
12use parking_lot::RwLock;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Arc;
15
16pub trait ProviderResolver: Send + Sync + 'static {
26 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>>;
28
29 fn resolve_model(&self, model_id: &str) -> Option<Model>;
31}
32
33pub(crate) struct GlobalProviderResolver;
38
39impl ProviderResolver for GlobalProviderResolver {
40 fn resolve_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
41 oxi_ai::get_provider(name).map(Arc::from)
42 }
43
44 fn resolve_model(&self, model_id: &str) -> Option<Model> {
45 crate::model_id::resolve_model_from_id(model_id)
46 }
47}
48
49struct AgentInner {
53 config: AgentConfig,
54 provider: Arc<dyn Provider>,
55}
56
57impl Clone for AgentInner {
58 fn clone(&self) -> Self {
59 Self {
60 config: self.config.clone(),
61 provider: Arc::clone(&self.provider),
62 }
63 }
64}
65
66pub struct Agent {
77 inner: RwLock<AgentInner>,
78 tools: Arc<ToolRegistry>,
79 state: SharedState,
80 compaction_manager: CompactionManager,
81 hooks: parking_lot::RwLock<crate::config::AgentHooks>,
82 is_running: AtomicBool,
84 resolver: Arc<dyn ProviderResolver>,
87}
88
89impl Agent {
90 pub fn new(provider: Arc<dyn Provider>, config: AgentConfig, tools: Arc<ToolRegistry>) -> Self {
97 let resolver = Arc::new(GlobalProviderResolver);
98 Self::build_inner(provider, config, tools, resolver)
99 }
100
101 pub fn new_with_resolver(
106 provider: Arc<dyn Provider>,
107 config: AgentConfig,
108 tools: Arc<ToolRegistry>,
109 resolver: Arc<dyn ProviderResolver>,
110 ) -> Self {
111 Self::build_inner(provider, config, tools, resolver)
112 }
113
114 fn build_inner(
116 provider: Arc<dyn Provider>,
117 config: AgentConfig,
118 tools: Arc<ToolRegistry>,
119 resolver: Arc<dyn ProviderResolver>,
120 ) -> Self {
121 let mut compaction_manager =
122 CompactionManager::new(config.compaction_strategy.clone(), config.context_window);
123
124 if config.compaction_strategy != CompactionStrategy::Disabled {
126 let model = resolver.resolve_model(&config.model_id);
127
128 if let Some(model) = model {
129 let llm_compactor =
130 Arc::new(LlmCompactor::new(model.clone(), Arc::clone(&provider)));
131 compaction_manager.set_compactor(llm_compactor);
132 }
133 }
134
135 Self {
136 inner: RwLock::new(AgentInner { config, provider }),
137 tools,
138 state: SharedState::new(),
139 compaction_manager,
140 hooks: parking_lot::RwLock::new(crate::config::AgentHooks::default()),
141 is_running: AtomicBool::new(false),
142 resolver,
143 }
144 }
145
146 pub fn new_empty(provider: Arc<dyn Provider>, config: AgentConfig) -> Self {
148 Self::new(provider, config, Arc::new(ToolRegistry::new()))
149 }
150
151 fn config(&self) -> parking_lot::RwLockReadGuard<'_, AgentInner> {
153 self.inner.read()
154 }
155
156 fn inner_mut(&self) -> parking_lot::RwLockWriteGuard<'_, AgentInner> {
158 self.inner.write()
159 }
160
161 pub fn model_id(&self) -> String {
163 self.config().config.model_id.clone()
164 }
165
166 pub fn switch_model(&self, model_id: &str) -> Result<()> {
178 let new_model = self
179 .resolver
180 .resolve_model(model_id)
181 .ok_or_else(|| Error::msg(format!("Model '{}' not found", model_id)))?;
182
183 let new_provider = self
185 .resolver
186 .resolve_provider(&new_model.provider)
187 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", new_model.provider)))?;
188
189 {
191 let inner = self.config();
192 let old_model_id = &inner.config.model_id;
193 let old_api = self
194 .resolver
195 .resolve_model(old_model_id)
196 .map(|m| m.api)
197 .unwrap_or(oxi_ai::Api::AnthropicMessages);
198
199 if old_api != new_model.api {
200 let messages = self.state.get_state().messages.clone();
202 let transformed = transform_for_provider(&messages, &old_api, &new_model.api);
203 self.state.update(|s| {
204 s.replace_messages(transformed);
205 });
206 }
207 }
208
209 let mut inner = self.inner_mut();
211 inner.config.model_id = model_id.to_string();
212 inner.provider = new_provider;
213
214 Ok(())
215 }
216
217 pub fn switch_to_model(&self, model: &oxi_ai::Model) -> Result<()> {
222 let model_id = format!("{}/{}", model.provider, model.id);
223 let new_provider = self
224 .resolver
225 .resolve_provider(&model.provider)
226 .ok_or_else(|| Error::msg(format!("Provider '{}' not found", model.provider)))?;
227
228 {
230 let inner = self.config();
231 let old_api = self
232 .resolver
233 .resolve_model(&inner.config.model_id)
234 .map(|m| m.api)
235 .unwrap_or(oxi_ai::Api::AnthropicMessages);
236
237 if old_api != model.api {
238 let messages = self.state.get_state().messages.clone();
239 let transformed = transform_for_provider(&messages, &old_api, &model.api);
240 self.state.update(|s| {
241 s.replace_messages(transformed);
242 });
243 }
244 }
245
246 let mut inner = self.inner_mut();
247 inner.config.model_id = model_id;
248 inner.provider = new_provider;
249
250 Ok(())
251 }
252
253 pub fn tools(&self) -> Arc<ToolRegistry> {
255 Arc::clone(&self.tools)
256 }
257
258 pub fn state(&self) -> AgentState {
260 self.state.get_state()
261 }
262
263 pub fn reset(&self) {
265 self.state.reset();
266 }
267
268 pub fn add_tool<T: AgentTool + 'static>(&self, tool: T) {
270 self.tools.register(tool);
271 }
272
273 pub fn set_system_prompt(&self, prompt: String) {
275 self.inner_mut().config.system_prompt = Some(prompt);
276 }
277
278 pub fn compaction_manager(&self) -> &CompactionManager {
280 &self.compaction_manager
281 }
282
283 pub async fn run(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
288 let mut events = Vec::new();
289 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
290 let result = self.run_with_channel(prompt, tx).await;
291 while let Ok(event) = rx.recv() {
292 events.push(event);
293 }
294 result.map(|r| (r, events))
295 }
296
297 pub async fn run_with_channel(
317 &self,
318 prompt: String,
319 tx: std::sync::mpsc::Sender<AgentEvent>,
320 ) -> Result<Response> {
321 if self
324 .is_running
325 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
326 .is_err()
327 {
328 return Err(Error::msg("Agent is already running"));
329 }
330
331 let result = self.run_with_channel_inner(prompt, tx).await;
332
333 self.is_running.store(false, Ordering::SeqCst);
335 result
336 }
337
338 async fn run_with_channel_inner(
340 &self,
341 prompt: String,
342 tx: std::sync::mpsc::Sender<AgentEvent>,
343 ) -> Result<Response> {
344 use crate::agent_loop::AgentLoop;
345
346 let (
347 provider,
348 max_iterations,
349 system_prompt,
350 temperature,
351 max_tokens,
352 compaction_strategy,
353 context_window,
354 api_key,
355 workspace_dir,
356 ) = {
357 let inner = self.inner.read();
358 (
359 Arc::clone(&inner.provider) as Arc<dyn Provider>,
360 inner.config.max_iterations,
361 inner.config.system_prompt.clone(),
362 inner.config.temperature,
363 inner.config.max_tokens,
364 inner.config.compaction_strategy.clone(),
365 inner.config.context_window,
366 inner.config.api_key.clone(),
367 inner.config.workspace_dir.clone(),
368 )
369 }; let loop_config = crate::agent_loop::config::AgentLoopConfig {
373 model_id: self.model_id(),
374 system_prompt,
375 max_iterations,
376 temperature: temperature.unwrap_or(1.0) as f32,
377 max_tokens: max_tokens.unwrap_or(4096) as u32,
378 tool_execution: crate::config::ToolExecutionMode::Sequential,
379 compaction_strategy,
380 compaction_instruction: None,
381 context_window,
382 session_id: None,
383 transport: None,
384 compact_on_start: false,
385 max_retry_delay_ms: None,
386 auto_retry_enabled: true,
387 auto_retry_max_attempts: 3,
388 auto_retry_base_delay_ms: 1000,
389 api_key,
390 workspace_dir,
391 };
392
393 let fresh_state = crate::state::SharedState::new();
396 let current = self.state.get_state();
397 fresh_state.update(|s| {
398 *s = current;
399 });
400
401 let agent_loop = AgentLoop::new_with_resolver(
402 provider,
403 loop_config,
404 Arc::clone(&self.tools),
405 fresh_state,
406 Arc::clone(&self.resolver),
407 );
408
409 {
411 let hooks = self.hooks.read();
412 if let Some(ref get_steering) = hooks.get_steering_messages {
413 for msg_text in get_steering() {
414 agent_loop.steer(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
415 }
416 }
417 if let Some(ref get_follow_up) = hooks.get_follow_up_messages {
418 for msg_text in get_follow_up() {
419 agent_loop.follow_up(oxi_ai::Message::User(oxi_ai::UserMessage::new(msg_text)));
420 }
421 }
422 }
423 let al = agent_loop;
424
425 let maybe_hook = {
431 let hooks_r = self.hooks.read();
432 hooks_r.should_stop_after_turn.clone()
433 };
434 let ext_stop = al.external_stop().clone();
435
436 let tx_emit = tx.clone();
440
441 tracing::info!("[AGENT] Starting agent run with channel");
443 let result = al
444 .run(prompt.clone(), move |event: AgentEvent| {
445 tracing::info!("[AGENT-EMIT] Event: {:?}", std::mem::discriminant(&event));
447 if let Err(e) = tx_emit.send(event.clone()) {
448 tracing::error!(
449 "[AGENT-EMIT] Failed to send agent event to channel: {:?}",
450 e
451 );
452 } else {
453 tracing::info!("[AGENT-EMIT] Successfully sent event");
454 }
455
456 if let Some(ref hook) = maybe_hook {
461 if let AgentEvent::TurnEnd {
462 ref assistant_message,
463 ref tool_results,
464 ..
465 } = event
466 {
467 let asst = match assistant_message {
469 oxi_ai::Message::Assistant(a) => a.clone(),
470 _ => {
471 let ctx = ShouldStopAfterTurnContext {
473 message: oxi_ai::AssistantMessage::new(
474 oxi_ai::Api::OpenAiCompletions,
475 "agent",
476 "agent-model",
477 ),
478 tool_results: Vec::new(),
479 iteration: 0,
480 };
481 if hook(&ctx) {
482 ext_stop.store(true, Ordering::SeqCst);
483 }
484 return;
485 }
486 };
487 let ctx = ShouldStopAfterTurnContext {
488 message: asst,
489 tool_results: tool_results.clone(),
490 iteration: 0,
491 };
492 if hook(&ctx) {
493 ext_stop.store(true, Ordering::SeqCst);
494 }
495 }
496 }
497 })
498 .await;
499
500 match result {
501 Ok(_events) => {
502 let loop_state = al.state().get_state();
504 self.state.update(|s| {
505 *s = loop_state;
506 });
507
508 let state = self.state.get_state();
510 let final_text = state
511 .messages
512 .iter()
513 .rev()
514 .find_map(|m| match m {
515 oxi_ai::Message::Assistant(a) => a.content.iter().find_map(|b| match b {
516 oxi_ai::ContentBlock::Text(t) => Some(t.text.clone()),
517 _ => None,
518 }),
519 _ => None,
520 })
521 .unwrap_or_default();
522
523 let stop_reason = state.stop_reason.unwrap_or(StopReason::Stop);
524
525 Ok(Response {
526 content: final_text,
527 stop_reason,
528 })
529 }
530 Err(e) => Err(e),
531 }
532 }
533
534 pub fn set_hooks(&self, hooks: crate::config::AgentHooks) {
538 let mut h = self.hooks.write();
539 *h = hooks;
540 }
541
542 pub async fn run_streaming<F>(&self, prompt: String, mut on_event: F) -> Result<Response>
547 where
548 F: FnMut(AgentEvent) + Send,
549 {
550 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
551 let result = self.run_with_channel(prompt, tx).await;
552 while let Ok(event) = rx.recv() {
553 on_event(event);
554 }
555 result
556 }
557
558 pub fn export_state(&self) -> Result<serde_json::Value> {
567 let state = self.state.get_state();
568 serde_json::to_value(&state).map_err(|e| Error::msg(format!("State export failed: {}", e)))
569 }
570
571 pub fn import_state(&self, value: serde_json::Value) -> Result<()> {
578 let state: AgentState = serde_json::from_value(value)
579 .map_err(|e| Error::msg(format!("State import failed: {}", e)))?;
580 self.state.update(|s| *s = state);
581 Ok(())
582 }
583
584 pub async fn continue_with(&self, prompt: String) -> Result<(Response, Vec<AgentEvent>)> {
592 let mut events = Vec::new();
593 let (tx, rx) = std::sync::mpsc::channel::<AgentEvent>();
594 let result = self.run_with_channel(prompt, tx).await;
595 while let Ok(event) = rx.recv() {
596 events.push(event);
597 }
598 result.map(|r| (r, events))
599 }
600
601 pub async fn run_tokio_stream(
619 &self,
620 prompt: String,
621 ) -> Result<(
622 tokio::sync::mpsc::Receiver<AgentEvent>,
623 tokio::task::JoinHandle<Result<Response>>,
624 )> {
625 let (tx, rx) = tokio::sync::mpsc::channel::<AgentEvent>(256);
626
627 if self
628 .is_running
629 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
630 .is_err()
631 {
632 return Err(Error::msg("Agent is already running"));
633 }
634
635 let should_stop_hook = self.hooks.read().should_stop_after_turn.clone();
636
637 let state = self.state.clone();
638 let inner = self.inner.read().clone();
639 let tools = Arc::clone(&self.tools);
640 let resolver = Arc::clone(&self.resolver);
641
642 let loop_config = crate::agent_loop::config::AgentLoopConfig {
644 model_id: inner.config.model_id.clone(),
645 system_prompt: inner.config.system_prompt.clone(),
646 max_iterations: inner.config.max_iterations,
647 temperature: inner.config.temperature.unwrap_or(1.0) as f32,
648 max_tokens: inner.config.max_tokens.unwrap_or(4096) as u32,
649 tool_execution: crate::config::ToolExecutionMode::Sequential,
650 compaction_strategy: inner.config.compaction_strategy.clone(),
651 compaction_instruction: None,
652 context_window: inner.config.context_window,
653 session_id: None,
654 transport: None,
655 compact_on_start: false,
656 max_retry_delay_ms: None,
657 auto_retry_enabled: true,
658 auto_retry_max_attempts: 3,
659 auto_retry_base_delay_ms: 1000,
660 api_key: inner.config.api_key.clone(),
661 workspace_dir: inner.config.workspace_dir.clone(),
662 };
663
664 let provider: Arc<dyn Provider> = Arc::clone(&inner.provider);
665
666 let fresh_state = SharedState::new();
668 let current = state.get_state();
669 fresh_state.update(|s| *s = current);
670
671 let agent_loop = crate::agent_loop::AgentLoop::new_with_resolver(
672 provider,
673 loop_config,
674 tools,
675 fresh_state,
676 resolver,
677 );
678
679 let maybe_hook = should_stop_hook;
680 let ext_stop = agent_loop.external_stop().clone();
681
682 let is_running = Arc::new(AtomicBool::new(true));
683 let is_running_clone = Arc::clone(&is_running);
684
685 let handle = tokio::task::spawn(async move {
686 let result = agent_loop
687 .run(prompt, move |event: AgentEvent| {
688 let _ = tx.try_send(event.clone());
690
691 if let Some(ref hook) = maybe_hook {
692 if let AgentEvent::TurnEnd {
693 ref assistant_message,
694 ref tool_results,
695 ..
696 } = event
697 {
698 let asst = match assistant_message {
699 oxi_ai::Message::Assistant(a) => a.clone(),
700 _ => return,
701 };
702 let ctx = ShouldStopAfterTurnContext {
703 message: asst,
704 tool_results: tool_results.clone(),
705 iteration: 0,
706 };
707 if hook(&ctx) {
708 ext_stop.store(true, Ordering::SeqCst);
709 }
710 }
711 }
712 })
713 .await;
714
715 is_running_clone.store(false, Ordering::SeqCst);
717
718 match result {
719 Ok(_events) => {
720 Ok(Response {
723 content: String::new(),
724 stop_reason: StopReason::Stop,
725 })
726 }
727 Err(e) => Err(e),
728 }
729 });
730
731 Ok((rx, handle))
732 }
733}