Skip to main content

autoagents_llm/completion/
mod.rs

1use async_trait::async_trait;
2
3use crate::{
4    ToolCall,
5    chat::{ChatResponse, StructuredOutputFormat},
6    error::LLMError,
7};
8
/// Parameters for a single text-completion call to an LLM provider.
///
/// Build one with [`CompletionRequest::new`] for a bare prompt, or with
/// [`CompletionRequest::builder`] when optional parameters are needed.
#[derive(Clone, Debug)]
pub struct CompletionRequest {
    /// Prompt text the model is asked to continue.
    pub prompt: String,
    /// Upper bound on the number of generated tokens; `None` means no limit
    /// is specified by this request.
    pub max_tokens: Option<u32>,
    /// Sampling temperature controlling randomness; `None` means unspecified.
    /// Intended range is 0.0–1.0, but no range check is performed here.
    pub temperature: Option<f32>,
}

/// Output of a text-completion call: the text produced by the model.
#[derive(Clone, Debug)]
pub struct CompletionResponse {
    /// Text generated by the provider for the request's prompt.
    pub text: String,
}

27impl ChatResponse for CompletionResponse {
28    fn text(&self) -> Option<String> {
29        Some(self.text.clone())
30    }
31
32    fn tool_calls(&self) -> Option<Vec<ToolCall>> {
33        None
34    }
35}
36
37impl CompletionRequest {
38    /// Creates a new completion request with just a prompt.
39    ///
40    /// # Arguments
41    ///
42    /// * `prompt` - The input text to complete
43    pub fn new(prompt: impl Into<String>) -> Self {
44        Self {
45            prompt: prompt.into(),
46            max_tokens: None,
47            temperature: None,
48        }
49    }
50
51    /// Creates a builder for constructing a completion request.
52    ///
53    /// # Arguments
54    ///
55    /// * `prompt` - The input text to complete
56    pub fn builder(prompt: impl Into<String>) -> CompletionRequestBuilder {
57        CompletionRequestBuilder {
58            prompt: prompt.into(),
59            max_tokens: None,
60            temperature: None,
61        }
62    }
63}
64
/// Accumulates settings for a [`CompletionRequest`] before it is built.
///
/// Obtain one via [`CompletionRequest::builder`], chain the setters, then
/// call `build` to produce the request.
#[derive(Clone, Debug)]
pub struct CompletionRequestBuilder {
    /// Prompt text the built request will carry.
    pub prompt: String,
    /// Token limit for the built request; `None` means unset.
    pub max_tokens: Option<u32>,
    /// Sampling temperature for the built request; `None` means unset.
    /// Intended range is 0.0–1.0, but no range check is performed here.
    pub temperature: Option<f32>,
}

76impl CompletionRequestBuilder {
77    /// Sets the maximum number of tokens to generate.
78    pub fn max_tokens(mut self, val: u32) -> Self {
79        self.max_tokens = Some(val);
80        self
81    }
82
83    /// Sets the temperature parameter for controlling randomness.
84    pub fn temperature(mut self, val: f32) -> Self {
85        self.temperature = Some(val);
86        self
87    }
88
89    /// Builds the completion request with the configured parameters.
90    pub fn build(self) -> CompletionRequest {
91        CompletionRequest {
92            prompt: self.prompt,
93            max_tokens: self.max_tokens,
94            temperature: self.temperature,
95        }
96    }
97}
98
99/// Trait for providers that support text completion requests.
100#[async_trait]
101pub trait CompletionProvider {
102    /// Sends a completion request to generate text.
103    ///
104    /// # Arguments
105    ///
106    /// * `req` - The completion request parameters
107    ///
108    /// # Returns
109    ///
110    /// The generated completion text or an error
111    async fn complete(
112        &self,
113        req: &CompletionRequest,
114        json_schema: Option<StructuredOutputFormat>,
115    ) -> Result<CompletionResponse, LLMError>;
116}
117
118impl std::fmt::Display for CompletionResponse {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(f, "{}", self.text)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LLMError;

    #[test]
    fn test_completion_request_new() {
        let request = CompletionRequest::new("Hello, world!");
        assert_eq!(request.prompt, "Hello, world!");
        assert!(request.max_tokens.is_none());
        assert!(request.temperature.is_none());
    }

    #[test]
    fn test_completion_request_new_matches_default_builder() {
        // `new` and an untouched builder must agree on every field.
        let from_new = CompletionRequest::new("Same prompt");
        let from_builder = CompletionRequest::builder("Same prompt").build();

        assert_eq!(from_new.prompt, from_builder.prompt);
        assert_eq!(from_new.max_tokens, from_builder.max_tokens);
        assert_eq!(from_new.temperature, from_builder.temperature);
    }

    #[test]
    fn test_completion_request_builder() {
        let request = CompletionRequest::builder("Test prompt")
            .max_tokens(500)
            .temperature(0.8)
            .build();

        assert_eq!(request.prompt, "Test prompt");
        assert_eq!(request.max_tokens, Some(500));
        assert_eq!(request.temperature, Some(0.8));
    }

    #[test]
    fn test_completion_request_builder_partial() {
        let request = CompletionRequest::builder("Partial test")
            .max_tokens(100)
            .build();

        assert_eq!(request.prompt, "Partial test");
        assert_eq!(request.max_tokens, Some(100));
        assert!(request.temperature.is_none());
    }

    #[test]
    fn test_completion_request_builder_temperature_only() {
        let request = CompletionRequest::builder("Temperature only")
            .temperature(0.2)
            .build();

        assert_eq!(request.prompt, "Temperature only");
        assert!(request.max_tokens.is_none());
        assert_eq!(request.temperature, Some(0.2));
    }

    #[test]
    fn test_completion_request_builder_last_setter_wins() {
        // Repeated setter calls overwrite rather than accumulate.
        let request = CompletionRequest::builder("Override")
            .max_tokens(10)
            .max_tokens(20)
            .temperature(0.1)
            .temperature(0.9)
            .build();

        assert_eq!(request.max_tokens, Some(20));
        assert_eq!(request.temperature, Some(0.9));
    }

    #[test]
    fn test_completion_request_builder_chaining() {
        let builder = CompletionRequest::builder("Chain test")
            .max_tokens(200)
            .temperature(0.5);

        let request = builder.build();
        assert_eq!(request.prompt, "Chain test");
        assert_eq!(request.max_tokens, Some(200));
        assert_eq!(request.temperature, Some(0.5));
    }

    #[test]
    fn test_completion_request_clone() {
        let request = CompletionRequest::new("Cloneable prompt");
        let cloned = request.clone();

        assert_eq!(request.prompt, cloned.prompt);
        assert_eq!(request.max_tokens, cloned.max_tokens);
        assert_eq!(request.temperature, cloned.temperature);
    }

    #[test]
    fn test_completion_request_debug() {
        let request = CompletionRequest::new("Debug test");
        let debug_str = format!("{request:?}");
        assert!(debug_str.contains("CompletionRequest"));
        assert!(debug_str.contains("Debug test"));
    }

    #[test]
    fn test_completion_response_new() {
        let response = CompletionResponse {
            text: "Generated text".to_string(),
        };
        assert_eq!(response.text, "Generated text");
    }

    #[test]
    fn test_completion_response_clone() {
        let response = CompletionResponse {
            text: "Cloneable response".to_string(),
        };
        let cloned = response.clone();
        assert_eq!(response.text, cloned.text);
    }

    #[test]
    fn test_completion_response_debug() {
        let response = CompletionResponse {
            text: "Debug response".to_string(),
        };
        let debug_str = format!("{response:?}");
        assert!(debug_str.contains("CompletionResponse"));
        assert!(debug_str.contains("Debug response"));
    }

    #[test]
    fn test_completion_response_display() {
        let response = CompletionResponse {
            text: "Display test".to_string(),
        };
        assert_eq!(response.to_string(), "Display test");
    }

    #[test]
    fn test_completion_response_chat_response_trait() {
        let response = CompletionResponse {
            text: "Chat response test".to_string(),
        };

        // Test ChatResponse trait implementation
        assert_eq!(response.text(), Some("Chat response test".to_string()));
        assert!(response.tool_calls().is_none());
    }

    #[test]
    fn test_completion_request_builder_debug() {
        let builder = CompletionRequest::builder("Builder debug")
            .max_tokens(300)
            .temperature(0.9);

        let debug_str = format!("{builder:?}");
        assert!(debug_str.contains("CompletionRequestBuilder"));
        assert!(debug_str.contains("Builder debug"));
    }

    #[test]
    fn test_completion_request_builder_clone() {
        let builder = CompletionRequest::builder("Clone test")
            .max_tokens(400)
            .temperature(0.3);

        let cloned = builder.clone();
        let request1 = builder.build();
        let request2 = cloned.build();

        assert_eq!(request1.prompt, request2.prompt);
        assert_eq!(request1.max_tokens, request2.max_tokens);
        assert_eq!(request1.temperature, request2.temperature);
    }

    #[test]
    fn test_completion_request_with_string_types() {
        let request = CompletionRequest::new(String::from("String prompt"));
        assert_eq!(request.prompt, "String prompt");

        let request2 = CompletionRequest::builder(String::from("Builder string")).build();
        assert_eq!(request2.prompt, "Builder string");
    }

    #[test]
    fn test_completion_request_zero_max_tokens() {
        let request = CompletionRequest::builder("Zero tokens")
            .max_tokens(0)
            .build();
        assert_eq!(request.max_tokens, Some(0));
    }

    #[test]
    fn test_completion_request_extreme_temperature() {
        let request = CompletionRequest::builder("Extreme temp")
            .temperature(0.0)
            .build();
        assert_eq!(request.temperature, Some(0.0));

        let request2 = CompletionRequest::builder("Extreme temp 2")
            .temperature(1.0)
            .build();
        assert_eq!(request2.temperature, Some(1.0));
    }

    #[test]
    fn test_completion_response_empty_text() {
        let response = CompletionResponse {
            text: String::default(),
        };
        assert_eq!(response.text(), Some(String::default()));
        assert_eq!(response.to_string(), "");
    }

    #[test]
    fn test_completion_response_multiline_text() {
        let multiline_text = "Line 1\nLine 2\nLine 3";
        let response = CompletionResponse {
            text: multiline_text.to_string(),
        };
        assert_eq!(response.text(), Some(multiline_text.to_string()));
        assert_eq!(response.to_string(), multiline_text);
    }

    #[test]
    fn test_completion_response_unicode_text() {
        let unicode_text = "Hello δΈ–η•Œ! 🌍";
        let response = CompletionResponse {
            text: unicode_text.to_string(),
        };
        assert_eq!(response.text(), Some(unicode_text.to_string()));
        assert_eq!(response.to_string(), unicode_text);
    }

    /// Mock provider for exercising the `CompletionProvider` trait.
    struct MockCompletionProvider {
        /// When `true`, `complete` returns an error instead of a response.
        should_fail: bool,
    }

    impl MockCompletionProvider {
        fn new() -> Self {
            Self { should_fail: false }
        }

        fn new_failing() -> Self {
            Self { should_fail: true }
        }
    }

    #[async_trait::async_trait]
    impl CompletionProvider for MockCompletionProvider {
        async fn complete(
            &self,
            req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            if self.should_fail {
                Err(LLMError::ProviderError("Mock provider error".to_string()))
            } else {
                Ok(CompletionResponse {
                    text: format!("Completed: {}", req.prompt),
                })
            }
        }
    }

    #[tokio::test]
    async fn test_completion_provider_trait_success() {
        let provider = MockCompletionProvider::new();
        let request = CompletionRequest::new("Test prompt");

        // `expect` keeps the error's Debug output if this ever regresses.
        let response = provider
            .complete(&request, None)
            .await
            .expect("non-failing mock should return Ok");
        assert_eq!(response.text, "Completed: Test prompt");
    }

    #[tokio::test]
    async fn test_completion_provider_trait_failure() {
        let provider = MockCompletionProvider::new_failing();
        let request = CompletionRequest::new("Test prompt");

        let error = provider
            .complete(&request, None)
            .await
            .expect_err("failing mock should return Err");

        // Pin the variant and payload, not just the rendered message.
        assert!(matches!(
            &error,
            LLMError::ProviderError(msg) if msg == "Mock provider error"
        ));
        assert!(error.to_string().contains("Mock provider error"));
    }

    #[tokio::test]
    async fn test_completion_provider_with_parameters() {
        let provider = MockCompletionProvider::new();
        let request = CompletionRequest::builder("Parameterized prompt")
            .max_tokens(100)
            .temperature(0.7)
            .build();

        let response = provider
            .complete(&request, None)
            .await
            .expect("non-failing mock should return Ok");
        assert_eq!(response.text, "Completed: Parameterized prompt");
    }
}
387
388        let response = result.unwrap();
389        assert_eq!(response.text, "Completed: Parameterized prompt");
390    }
391}