autoagents_llm/completion/
mod.rs1use async_trait::async_trait;
2
3use crate::{
4 chat::{ChatResponse, StructuredOutputFormat},
5 error::LLMError,
6 ToolCall,
7};
8
/// A single-prompt text completion request, passed to [`CompletionProvider::complete`].
///
/// Construct it with [`CompletionRequest::new`] for a bare prompt, or with
/// [`CompletionRequest::builder`] to also set the optional parameters.
#[derive(Debug, Clone, PartialEq)]
pub struct CompletionRequest {
    /// The prompt text to complete.
    pub prompt: String,
    /// Optional token limit for the completion; `None` means it is not set.
    /// (How the limit is applied is up to the provider.)
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; `None` means it is not set.
    pub temperature: Option<f32>,
}
19
/// The result of a [`CompletionProvider::complete`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompletionResponse {
    /// The text returned by the provider.
    pub text: String,
}
26
27impl ChatResponse for CompletionResponse {
28 fn text(&self) -> Option<String> {
29 Some(self.text.clone())
30 }
31
32 fn tool_calls(&self) -> Option<Vec<ToolCall>> {
33 None
34 }
35}
36
37impl CompletionRequest {
38 pub fn new(prompt: impl Into<String>) -> Self {
44 Self {
45 prompt: prompt.into(),
46 max_tokens: None,
47 temperature: None,
48 }
49 }
50
51 pub fn builder(prompt: impl Into<String>) -> CompletionRequestBuilder {
57 CompletionRequestBuilder {
58 prompt: prompt.into(),
59 max_tokens: None,
60 temperature: None,
61 }
62 }
63}
64
/// Fluent builder for [`CompletionRequest`], created by [`CompletionRequest::builder`].
///
/// Each setter consumes the builder and returns it; call `build` to obtain the
/// finished request. The type is `#[must_use]` because a builder that is
/// dropped without `build` has no effect.
#[must_use = "a CompletionRequestBuilder does nothing until `.build()` is called"]
#[derive(Debug, Clone, PartialEq)]
pub struct CompletionRequestBuilder {
    /// Prompt text that `build` moves into the request.
    pub prompt: String,
    /// Token limit that `build` moves into the request; `None` until set.
    pub max_tokens: Option<u32>,
    /// Sampling temperature that `build` moves into the request; `None` until set.
    pub temperature: Option<f32>,
}
75
76impl CompletionRequestBuilder {
77 pub fn max_tokens(mut self, val: u32) -> Self {
79 self.max_tokens = Some(val);
80 self
81 }
82
83 pub fn temperature(mut self, val: f32) -> Self {
85 self.temperature = Some(val);
86 self
87 }
88
89 pub fn build(self) -> CompletionRequest {
91 CompletionRequest {
92 prompt: self.prompt,
93 max_tokens: self.max_tokens,
94 temperature: self.temperature,
95 }
96 }
97}
98
/// A backend that can produce a text completion for a single prompt.
///
/// The trait is declared through `#[async_trait]`, which is what allows
/// `complete` to be an `async fn` here.
#[async_trait]
pub trait CompletionProvider {
    /// Produces a completion for `req`.
    ///
    /// # Arguments
    /// * `req` - the prompt plus its optional `max_tokens` / `temperature`.
    /// * `json_schema` - optional structured-output format; presumably asks the
    ///   provider to constrain its output to that schema (TODO confirm per provider).
    ///
    /// # Errors
    /// Returns an [`LLMError`] when the completion cannot be produced; which
    /// variants occur depends on the implementing provider.
    async fn complete(
        &self,
        req: &CompletionRequest,
        json_schema: Option<StructuredOutputFormat>,
    ) -> Result<CompletionResponse, LLMError>;
}
117
118impl std::fmt::Display for CompletionResponse {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(f, "{}", self.text)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LLMError;

    #[test]
    fn test_completion_request_new() {
        let request = CompletionRequest::new("Hello, world!");
        assert_eq!(request.prompt, "Hello, world!");
        assert!(request.max_tokens.is_none());
        assert!(request.temperature.is_none());
    }

    #[test]
    fn test_completion_request_builder() {
        let request = CompletionRequest::builder("Test prompt")
            .max_tokens(500)
            .temperature(0.8)
            .build();

        assert_eq!(request.prompt, "Test prompt");
        assert_eq!(request.max_tokens, Some(500));
        assert_eq!(request.temperature, Some(0.8));
    }

    #[test]
    fn test_completion_request_builder_partial() {
        let request = CompletionRequest::builder("Partial test")
            .max_tokens(100)
            .build();

        assert_eq!(request.prompt, "Partial test");
        assert_eq!(request.max_tokens, Some(100));
        assert!(request.temperature.is_none());
    }

    #[test]
    fn test_completion_request_builder_chaining() {
        let builder = CompletionRequest::builder("Chain test")
            .max_tokens(200)
            .temperature(0.5);

        let request = builder.build();
        assert_eq!(request.prompt, "Chain test");
        assert_eq!(request.max_tokens, Some(200));
        assert_eq!(request.temperature, Some(0.5));
    }

    #[test]
    fn test_completion_request_clone() {
        let request = CompletionRequest::new("Cloneable prompt");
        let cloned = request.clone();

        assert_eq!(request.prompt, cloned.prompt);
        assert_eq!(request.max_tokens, cloned.max_tokens);
        assert_eq!(request.temperature, cloned.temperature);
    }

    #[test]
    fn test_completion_request_debug() {
        let request = CompletionRequest::new("Debug test");
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("CompletionRequest"));
        assert!(debug_str.contains("Debug test"));
    }

    #[test]
    fn test_completion_response_new() {
        let response = CompletionResponse {
            text: "Generated text".to_string(),
        };
        assert_eq!(response.text, "Generated text");
    }

    #[test]
    fn test_completion_response_clone() {
        let response = CompletionResponse {
            text: "Cloneable response".to_string(),
        };
        let cloned = response.clone();
        assert_eq!(response.text, cloned.text);
    }

    #[test]
    fn test_completion_response_debug() {
        let response = CompletionResponse {
            text: "Debug response".to_string(),
        };
        let debug_str = format!("{response:?}");
        assert!(debug_str.contains("CompletionResponse"));
        assert!(debug_str.contains("Debug response"));
    }

    #[test]
    fn test_completion_response_display() {
        let response = CompletionResponse {
            text: "Display test".to_string(),
        };
        assert_eq!(response.to_string(), "Display test");
    }

    #[test]
    fn test_completion_response_chat_response_trait() {
        let response = CompletionResponse {
            text: "Chat response test".to_string(),
        };

        assert_eq!(response.text(), Some("Chat response test".to_string()));
        assert!(response.tool_calls().is_none());
    }

    #[test]
    fn test_completion_request_builder_debug() {
        let builder = CompletionRequest::builder("Builder debug")
            .max_tokens(300)
            .temperature(0.9);

        let debug_str = format!("{builder:?}");
        assert!(debug_str.contains("CompletionRequestBuilder"));
        assert!(debug_str.contains("Builder debug"));
    }

    #[test]
    fn test_completion_request_builder_clone() {
        let builder = CompletionRequest::builder("Clone test")
            .max_tokens(400)
            .temperature(0.3);

        let cloned = builder.clone();
        let request1 = builder.build();
        let request2 = cloned.build();

        assert_eq!(request1.prompt, request2.prompt);
        assert_eq!(request1.max_tokens, request2.max_tokens);
        assert_eq!(request1.temperature, request2.temperature);
    }

    #[test]
    fn test_completion_request_with_string_types() {
        let request = CompletionRequest::new(String::from("String prompt"));
        assert_eq!(request.prompt, "String prompt");

        let request2 = CompletionRequest::builder(String::from("Builder string")).build();
        assert_eq!(request2.prompt, "Builder string");
    }

    #[test]
    fn test_completion_request_zero_max_tokens() {
        let request = CompletionRequest::builder("Zero tokens")
            .max_tokens(0)
            .build();
        assert_eq!(request.max_tokens, Some(0));
    }

    #[test]
    fn test_completion_request_extreme_temperature() {
        let request = CompletionRequest::builder("Extreme temp")
            .temperature(0.0)
            .build();
        assert_eq!(request.temperature, Some(0.0));

        let request2 = CompletionRequest::builder("Extreme temp 2")
            .temperature(1.0)
            .build();
        assert_eq!(request2.temperature, Some(1.0));

        // The builder performs no range validation, so values above 1.0 are
        // stored unchanged.
        let request3 = CompletionRequest::builder("Extreme temp 3")
            .temperature(2.0)
            .build();
        assert_eq!(request3.temperature, Some(2.0));
    }

    #[test]
    fn test_completion_response_empty_text() {
        let response = CompletionResponse {
            text: String::new(),
        };
        assert_eq!(response.text(), Some(String::new()));
        assert_eq!(response.to_string(), "");
    }

    #[test]
    fn test_completion_response_multiline_text() {
        let multiline_text = "Line 1\nLine 2\nLine 3";
        let response = CompletionResponse {
            text: multiline_text.to_string(),
        };
        assert_eq!(response.text(), Some(multiline_text.to_string()));
        assert_eq!(response.to_string(), multiline_text);
    }

    #[test]
    fn test_completion_response_unicode_text() {
        // CJK characters plus a 4-byte emoji exercise multi-byte UTF-8 text.
        let unicode_text = "Hello 世界! 🌍";
        let response = CompletionResponse {
            text: unicode_text.to_string(),
        };
        assert_eq!(response.text(), Some(unicode_text.to_string()));
        assert_eq!(response.to_string(), unicode_text);
    }

    /// Test double for [`CompletionProvider`]: echoes the prompt, or fails on demand.
    struct MockCompletionProvider {
        /// When `true`, `complete` returns `LLMError::ProviderError` instead of a response.
        should_fail: bool,
    }

    impl MockCompletionProvider {
        /// A provider whose `complete` succeeds.
        fn new() -> Self {
            Self { should_fail: false }
        }

        /// A provider whose `complete` always errors.
        fn new_failing() -> Self {
            Self { should_fail: true }
        }
    }

    #[async_trait::async_trait]
    impl CompletionProvider for MockCompletionProvider {
        async fn complete(
            &self,
            req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            if self.should_fail {
                Err(LLMError::ProviderError("Mock provider error".to_string()))
            } else {
                Ok(CompletionResponse {
                    text: format!("Completed: {}", req.prompt),
                })
            }
        }
    }

    #[tokio::test]
    async fn test_completion_provider_trait_success() {
        let provider = MockCompletionProvider::new();
        let request = CompletionRequest::new("Test prompt");

        // `expect` reports the error payload on failure, unlike
        // `assert!(result.is_ok())` followed by `unwrap()`.
        let response = provider
            .complete(&request, None)
            .await
            .expect("mock provider should succeed");
        assert_eq!(response.text, "Completed: Test prompt");
    }

    #[tokio::test]
    async fn test_completion_provider_trait_failure() {
        let provider = MockCompletionProvider::new_failing();
        let request = CompletionRequest::new("Test prompt");

        let error = provider
            .complete(&request, None)
            .await
            .expect_err("failing mock provider must return an error");
        assert!(error.to_string().contains("Mock provider error"));
    }

    #[tokio::test]
    async fn test_completion_provider_with_parameters() {
        let provider = MockCompletionProvider::new();
        let request = CompletionRequest::builder("Parameterized prompt")
            .max_tokens(100)
            .temperature(0.7)
            .build();

        let response = provider
            .complete(&request, None)
            .await
            .expect("mock provider should succeed");
        assert_eq!(response.text, "Completed: Parameterized prompt");

        // The request is only borrowed by `complete`, so its parameters are intact.
        assert_eq!(request.max_tokens, Some(100));
        assert_eq!(request.temperature, Some(0.7));
    }
}