Skip to main content

rain_engine_core/
llm.rs

1use crate::{ProviderDecision, ProviderRequest};
2use async_trait::async_trait;
3use std::collections::VecDeque;
4use std::sync::{Arc, Mutex};
5use thiserror::Error;
6
/// Coarse classification carried by a [`ProviderError`] in its `kind` field.
///
/// The variants are plain unit tags; this module only ever constructs
/// [`ProviderErrorKind::Internal`] itself (via `ProviderError::internal`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProviderErrorKind {
    Timeout,
    Transport,
    RateLimited,
    InvalidResponse,
    Configuration,
    Internal,
}
16
17#[derive(Debug, Error, Clone, PartialEq, Eq)]
18#[error("{kind:?}: {message}")]
19pub struct ProviderError {
20    pub kind: ProviderErrorKind,
21    pub message: String,
22    pub retryable: bool,
23}
24
25impl ProviderError {
26    pub fn new(kind: ProviderErrorKind, message: impl Into<String>, retryable: bool) -> Self {
27        Self {
28            kind,
29            message: message.into(),
30            retryable,
31        }
32    }
33
34    pub fn internal(message: impl Into<String>) -> Self {
35        Self::new(ProviderErrorKind::Internal, message, false)
36    }
37}
38
39#[async_trait]
40pub trait LlmProvider: Send + Sync {
41    async fn generate_action(
42        &self,
43        input: ProviderRequest,
44    ) -> Result<ProviderDecision, ProviderError>;
45}
46
47#[async_trait]
48pub trait EmbeddingProvider: Send + Sync {
49    async fn generate_embeddings(&self, texts: Vec<String>)
50    -> Result<Vec<Vec<f32>>, ProviderError>;
51}
52
53type DynamicResponder =
54    dyn Fn(ProviderRequest) -> Result<ProviderDecision, ProviderError> + Send + Sync;
55
56#[derive(Clone)]
57pub struct MockLlmProvider {
58    responder: Arc<DynamicResponder>,
59    observed_inputs: Arc<Mutex<Vec<ProviderRequest>>>,
60}
61
62impl MockLlmProvider {
63    pub fn scripted(actions: Vec<crate::AgentAction>) -> Self {
64        let queue = Arc::new(Mutex::new(VecDeque::from(actions)));
65        Self::dynamic(move |_input| {
66            let action = queue
67                .lock()
68                .expect("script queue poisoned")
69                .pop_front()
70                .ok_or_else(|| ProviderError::internal("mock script exhausted"))?;
71            Ok(ProviderDecision {
72                action,
73                usage: None,
74                cache: None,
75            })
76        })
77    }
78
79    pub fn dynamic<F>(responder: F) -> Self
80    where
81        F: Fn(ProviderRequest) -> Result<ProviderDecision, ProviderError> + Send + Sync + 'static,
82    {
83        Self {
84            responder: Arc::new(responder),
85            observed_inputs: Arc::new(Mutex::new(Vec::new())),
86        }
87    }
88
89    pub fn observed_inputs(&self) -> Vec<ProviderRequest> {
90        self.observed_inputs
91            .lock()
92            .expect("observed input lock poisoned")
93            .clone()
94    }
95}
96
97#[async_trait]
98impl LlmProvider for MockLlmProvider {
99    async fn generate_action(
100        &self,
101        input: ProviderRequest,
102    ) -> Result<ProviderDecision, ProviderError> {
103        self.observed_inputs
104            .lock()
105            .expect("observed input lock poisoned")
106            .push(input.clone());
107        (self.responder)(input)
108    }
109}