1use crate::{FierrosError, FierrosResult};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
6pub enum MessageRole {
7 System,
8 User,
9 Assistant,
10 Tool,
11}
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Message {
15 pub role: MessageRole,
16 pub content: String,
17}
18
19impl Message {
20 pub fn system(content: impl Into<String>) -> Self {
21 Self {
22 role: MessageRole::System,
23 content: content.into(),
24 }
25 }
26
27 pub fn user(content: impl Into<String>) -> Self {
28 Self {
29 role: MessageRole::User,
30 content: content.into(),
31 }
32 }
33
34 pub fn assistant(content: impl Into<String>) -> Self {
35 Self {
36 role: MessageRole::Assistant,
37 content: content.into(),
38 }
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
43pub struct CompletionRequest {
44 pub messages: Vec<Message>,
45 pub temperature: f32,
46 pub max_tokens: Option<u32>,
47}
48
49impl CompletionRequest {
50 pub fn from_user(content: impl Into<String>) -> Self {
51 Self {
52 messages: vec![Message::user(content)],
53 temperature: 0.0,
54 max_tokens: None,
55 }
56 }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct TokenUsage {
61 pub input_tokens: u32,
62 pub output_tokens: u32,
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
66pub struct CompletionResponse {
67 pub content: String,
68 pub usage: Option<TokenUsage>,
69}
70
71#[async_trait]
72pub trait Llm: Send + Sync {
73 async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse>;
74}
75
76#[derive(Debug, Clone)]
77pub struct MockLlm {
78 behavior: MockLlmBehavior,
79}
80
81#[derive(Debug, Clone)]
82enum MockLlmBehavior {
83 Success {
84 response: String,
85 usage: Option<TokenUsage>,
86 },
87 Failure(FierrosError),
88}
89
90impl MockLlm {
91 pub fn new(response: impl Into<String>) -> Self {
92 Self {
93 behavior: MockLlmBehavior::Success {
94 response: response.into(),
95 usage: Some(TokenUsage {
96 input_tokens: 0,
97 output_tokens: 0,
98 }),
99 },
100 }
101 }
102
103 pub fn failing(error: FierrosError) -> Self {
104 Self {
105 behavior: MockLlmBehavior::Failure(error),
106 }
107 }
108
109 pub fn with_usage(mut self, usage: TokenUsage) -> Self {
110 if let MockLlmBehavior::Success {
111 usage: current_usage,
112 ..
113 } = &mut self.behavior
114 {
115 *current_usage = Some(usage);
116 }
117 self
118 }
119}
120
121#[async_trait]
122impl Llm for MockLlm {
123 async fn complete(&self, _request: CompletionRequest) -> FierrosResult<CompletionResponse> {
124 match &self.behavior {
125 MockLlmBehavior::Success { response, usage } => Ok(CompletionResponse {
126 content: response.clone(),
127 usage: usage.clone(),
128 }),
129 MockLlmBehavior::Failure(error) => Err(error.clone()),
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::{CompletionRequest, Llm, Message, MessageRole, MockLlm, TokenUsage};
    use crate::FierrosError;

    #[test]
    fn message_constructors_set_roles() {
        // Check role AND content for every constructor; previously only `assistant`'s role
        // was actually asserted despite the test's name.
        let cases = [
            (Message::system("s"), MessageRole::System, "s"),
            (Message::user("x"), MessageRole::User, "x"),
            (Message::assistant("a"), MessageRole::Assistant, "a"),
        ];
        for (message, expected_role, expected_content) in cases {
            assert_eq!(message.role, expected_role);
            assert_eq!(message.content, expected_content);
        }
    }

    #[tokio::test]
    async fn mock_llm_returns_configured_response() {
        let llm = MockLlm::new("answer");
        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();
        assert_eq!(response.content, "answer");
        // The default usage is all zeros; compare the whole value, not just one field.
        assert_eq!(
            response.usage,
            Some(TokenUsage {
                input_tokens: 0,
                output_tokens: 0,
            })
        );
    }

    #[tokio::test]
    async fn mock_llm_can_override_usage() {
        let llm = MockLlm::new("answer").with_usage(TokenUsage {
            input_tokens: 12,
            output_tokens: 7,
        });
        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();

        assert_eq!(
            response.usage,
            Some(TokenUsage {
                input_tokens: 12,
                output_tokens: 7,
            })
        );
    }

    #[tokio::test]
    async fn mock_llm_can_return_configured_error() {
        let llm = MockLlm::failing(FierrosError::Provider("downstream unavailable".into()));
        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert_eq!(
            error,
            FierrosError::Provider("downstream unavailable".into())
        );
    }

    #[test]
    fn completion_request_from_user_has_defaults() {
        let request = CompletionRequest::from_user("question");
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].role, MessageRole::User);
        assert_eq!(request.messages[0].content, "question");
        assert_eq!(request.temperature, 0.0);
        assert_eq!(request.max_tokens, None);
    }
}