Skip to main content

fierros_core/
llm.rs

1use crate::{FierrosError, FierrosResult};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum MessageRole {
7    System,
8    User,
9    Assistant,
10    Tool,
11}
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Message {
15    pub role: MessageRole,
16    pub content: String,
17}
18
19impl Message {
20    pub fn system(content: impl Into<String>) -> Self {
21        Self {
22            role: MessageRole::System,
23            content: content.into(),
24        }
25    }
26
27    pub fn user(content: impl Into<String>) -> Self {
28        Self {
29            role: MessageRole::User,
30            content: content.into(),
31        }
32    }
33
34    pub fn assistant(content: impl Into<String>) -> Self {
35        Self {
36            role: MessageRole::Assistant,
37            content: content.into(),
38        }
39    }
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
43pub struct CompletionRequest {
44    pub messages: Vec<Message>,
45    pub temperature: f32,
46    pub max_tokens: Option<u32>,
47}
48
49impl CompletionRequest {
50    pub fn from_user(content: impl Into<String>) -> Self {
51        Self {
52            messages: vec![Message::user(content)],
53            temperature: 0.0,
54            max_tokens: None,
55        }
56    }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct TokenUsage {
61    pub input_tokens: u32,
62    pub output_tokens: u32,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct CompletionResponse {
67    pub content: String,
68    pub usage: Option<TokenUsage>,
69}
70
71#[async_trait]
72pub trait Llm: Send + Sync {
73    async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse>;
74}
75
76#[derive(Debug, Clone)]
77pub struct MockLlm {
78    behavior: MockLlmBehavior,
79}
80
81#[derive(Debug, Clone)]
82enum MockLlmBehavior {
83    Success {
84        response: String,
85        usage: Option<TokenUsage>,
86    },
87    Failure(FierrosError),
88}
89
90impl MockLlm {
91    pub fn new(response: impl Into<String>) -> Self {
92        Self {
93            behavior: MockLlmBehavior::Success {
94                response: response.into(),
95                usage: Some(TokenUsage {
96                    input_tokens: 0,
97                    output_tokens: 0,
98                }),
99            },
100        }
101    }
102
103    pub fn failing(error: FierrosError) -> Self {
104        Self {
105            behavior: MockLlmBehavior::Failure(error),
106        }
107    }
108
109    pub fn with_usage(mut self, usage: TokenUsage) -> Self {
110        if let MockLlmBehavior::Success {
111            usage: current_usage,
112            ..
113        } = &mut self.behavior
114        {
115            *current_usage = Some(usage);
116        }
117        self
118    }
119}
120
121#[async_trait]
122impl Llm for MockLlm {
123    async fn complete(&self, _request: CompletionRequest) -> FierrosResult<CompletionResponse> {
124        match &self.behavior {
125            MockLlmBehavior::Success { response, usage } => Ok(CompletionResponse {
126                content: response.clone(),
127                usage: usage.clone(),
128            }),
129            MockLlmBehavior::Failure(error) => Err(error.clone()),
130        }
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::{CompletionRequest, Llm, Message, MessageRole, MockLlm, TokenUsage};
    use crate::FierrosError;

    /// Every constructor must set both the role it is named after and the content.
    /// (Previously only `assistant`'s role was actually checked.)
    #[test]
    fn message_constructors_set_roles() {
        let system = Message::system("s");
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(system.content, "s");

        let user = Message::user("x");
        assert_eq!(user.role, MessageRole::User);
        assert_eq!(user.content, "x");

        let assistant = Message::assistant("a");
        assert_eq!(assistant.role, MessageRole::Assistant);
        assert_eq!(assistant.content, "a");
    }

    #[tokio::test]
    async fn mock_llm_returns_configured_response() {
        let llm = MockLlm::new("answer");
        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();
        assert_eq!(response.content, "answer");
        // `MockLlm::new` reports zero usage on both sides, not just for input.
        assert_eq!(
            response.usage,
            Some(TokenUsage {
                input_tokens: 0,
                output_tokens: 0,
            })
        );
    }

    #[tokio::test]
    async fn mock_llm_can_override_usage() {
        let llm = MockLlm::new("answer").with_usage(TokenUsage {
            input_tokens: 12,
            output_tokens: 7,
        });
        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();

        assert_eq!(
            response.usage,
            Some(TokenUsage {
                input_tokens: 12,
                output_tokens: 7,
            })
        );
    }

    #[tokio::test]
    async fn mock_llm_can_return_configured_error() {
        let llm = MockLlm::failing(FierrosError::Provider("downstream unavailable".into()));
        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("downstream unavailable".into())
        );
    }

    /// `with_usage` on a failing mock must not turn it into a successful one.
    #[tokio::test]
    async fn with_usage_on_failing_mock_keeps_error() {
        let llm = MockLlm::failing(FierrosError::Provider("downstream unavailable".into()))
            .with_usage(TokenUsage {
                input_tokens: 1,
                output_tokens: 1,
            });
        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("downstream unavailable".into())
        );
    }

    #[test]
    fn completion_request_from_user_has_defaults() {
        let request = CompletionRequest::from_user("question");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].role, MessageRole::User);
        assert_eq!(request.messages[0].content, "question");
        assert_eq!(request.temperature, 0.0);
        assert_eq!(request.max_tokens, None);
    }
}