1use crate::{ProviderDecision, ProviderRequest};
2use async_trait::async_trait;
3use std::collections::VecDeque;
4use std::sync::{Arc, Mutex};
5use thiserror::Error;
6
/// Coarse category of a [`ProviderError`].
///
/// Retryability is carried separately on `ProviderError::retryable` rather than implied by
/// the kind. The enum is fieldless, so it is `Copy` (read it out of a borrowed error without
/// cloning) and `Hash` (usable as a map/set key, e.g. for per-kind counters or policies).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderErrorKind {
    Timeout,
    Transport,
    RateLimited,
    InvalidResponse,
    Configuration,
    /// Kind used by `ProviderError::internal`, which always marks the error non-retryable.
    Internal,
}
16
17#[derive(Debug, Error, Clone, PartialEq, Eq)]
18#[error("{kind:?}: {message}")]
19pub struct ProviderError {
20 pub kind: ProviderErrorKind,
21 pub message: String,
22 pub retryable: bool,
23}
24
25impl ProviderError {
26 pub fn new(kind: ProviderErrorKind, message: impl Into<String>, retryable: bool) -> Self {
27 Self {
28 kind,
29 message: message.into(),
30 retryable,
31 }
32 }
33
34 pub fn internal(message: impl Into<String>) -> Self {
35 Self::new(ProviderErrorKind::Internal, message, false)
36 }
37}
38
39#[async_trait]
40pub trait LlmProvider: Send + Sync {
41 async fn generate_action(
42 &self,
43 input: ProviderRequest,
44 ) -> Result<ProviderDecision, ProviderError>;
45}
46
47#[async_trait]
48pub trait EmbeddingProvider: Send + Sync {
49 async fn generate_embeddings(&self, texts: Vec<String>)
50 -> Result<Vec<Vec<f32>>, ProviderError>;
51}
52
53type DynamicResponder =
54 dyn Fn(ProviderRequest) -> Result<ProviderDecision, ProviderError> + Send + Sync;
55
56#[derive(Clone)]
57pub struct MockLlmProvider {
58 responder: Arc<DynamicResponder>,
59 observed_inputs: Arc<Mutex<Vec<ProviderRequest>>>,
60}
61
62impl MockLlmProvider {
63 pub fn scripted(actions: Vec<crate::AgentAction>) -> Self {
64 let queue = Arc::new(Mutex::new(VecDeque::from(actions)));
65 Self::dynamic(move |_input| {
66 let action = queue
67 .lock()
68 .expect("script queue poisoned")
69 .pop_front()
70 .ok_or_else(|| ProviderError::internal("mock script exhausted"))?;
71 Ok(ProviderDecision {
72 action,
73 usage: None,
74 cache: None,
75 })
76 })
77 }
78
79 pub fn dynamic<F>(responder: F) -> Self
80 where
81 F: Fn(ProviderRequest) -> Result<ProviderDecision, ProviderError> + Send + Sync + 'static,
82 {
83 Self {
84 responder: Arc::new(responder),
85 observed_inputs: Arc::new(Mutex::new(Vec::new())),
86 }
87 }
88
89 pub fn observed_inputs(&self) -> Vec<ProviderRequest> {
90 self.observed_inputs
91 .lock()
92 .expect("observed input lock poisoned")
93 .clone()
94 }
95}
96
97#[async_trait]
98impl LlmProvider for MockLlmProvider {
99 async fn generate_action(
100 &self,
101 input: ProviderRequest,
102 ) -> Result<ProviderDecision, ProviderError> {
103 self.observed_inputs
104 .lock()
105 .expect("observed input lock poisoned")
106 .push(input.clone());
107 (self.responder)(input)
108 }
109}