// erio_llm_client/lib.rs

1#![doc = include_str!("../README.md")]
2
3//! Erio LLM Client - LLM provider abstraction and adapters for the agent runtime.
4
5pub mod error;
6pub mod openai;
7pub mod provider;
8pub mod request;
9pub mod response;
10
11pub use error::LlmError;
12pub use openai::OpenAiProvider;
13pub use provider::LlmProvider;
14pub use request::{CompletionRequest, ToolDefinition};
15pub use response::{CompletionResponse, StreamChunk, Usage};
16
// Tests for the crate's public surface: LlmError display text and retry
// classification, OpenAiProvider HTTP behavior against a wiremock server,
// OpenAI response parsing, the CompletionRequest builder, and the
// LlmError -> erio_core::CoreError conversion.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // === LlmError Display Tests ===
    // Each test pins the exact user-facing string produced by the error's
    // Display impl, so a wording change is a deliberate, reviewed change.

    #[test]
    fn llm_error_rate_limited_displays_message() {
        let err = LlmError::RateLimited { retry_after: None };
        assert_eq!(err.to_string(), "Rate limited");
    }

    #[test]
    fn llm_error_api_displays_status_and_message() {
        let err = LlmError::Api {
            status: 500,
            message: "internal server error".into(),
        };
        assert_eq!(err.to_string(), "API error (500): internal server error");
    }

    #[test]
    fn llm_error_invalid_response_displays_detail() {
        let err = LlmError::InvalidResponse("missing choices field".into());
        assert_eq!(err.to_string(), "Invalid response: missing choices field");
    }

    #[test]
    fn llm_error_timeout_displays_message() {
        let err = LlmError::Timeout;
        assert_eq!(err.to_string(), "Request timeout");
    }

    #[test]
    fn llm_error_auth_displays_message() {
        let err = LlmError::Auth;
        assert_eq!(err.to_string(), "Authentication failed");
    }

    #[test]
    fn llm_error_network_displays_detail() {
        let err = LlmError::Network("connection refused".into());
        assert_eq!(err.to_string(), "Network error: connection refused");
    }

    // === Retryable Tests ===
    // Transient failures (rate limit, timeout, network, 5xx) are retryable;
    // caller errors (auth, bad request, malformed response) are not.

    #[test]
    fn llm_error_rate_limited_is_retryable() {
        let err = LlmError::RateLimited {
            retry_after: Some(Duration::from_secs(5)),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_timeout_is_retryable() {
        let err = LlmError::Timeout;
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_network_is_retryable() {
        let err = LlmError::Network("connection reset".into());
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_api_500_is_retryable() {
        let err = LlmError::Api {
            status: 500,
            message: "server error".into(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_auth_is_not_retryable() {
        let err = LlmError::Auth;
        assert!(!err.is_retryable());
    }

    #[test]
    fn llm_error_invalid_response_is_not_retryable() {
        let err = LlmError::InvalidResponse("bad json".into());
        assert!(!err.is_retryable());
    }

    #[test]
    fn llm_error_api_400_is_not_retryable() {
        let err = LlmError::Api {
            status: 400,
            message: "bad request".into(),
        };
        assert!(!err.is_retryable());
    }

    // === OpenAiProvider Tests (wiremock) ===

    /// Fixture: a well-formed OpenAI chat-completion body with text content
    /// and a full `usage` object.
    fn valid_openai_response() -> serde_json::Value {
        serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Hello from GPT!",
                    "tool_calls": []
                }
            }],
            "model": "gpt-4",
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        })
    }

    /// Fixture: a chat-completion body carrying one tool call and no text
    /// content (`content: null`); `usage: null` exercises the optional path.
    fn tool_call_openai_response() -> serde_json::Value {
        serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"London\"}"
                        }
                    }]
                }
            }],
            "model": "gpt-4",
            "usage": null
        })
    }

    // NOTE(review): multi_thread flavor matches the wiremock-based tests
    // below; this particular test makes no requests, so the flavor is
    // presumably just for consistency — confirm before changing.
    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_name() {
        let provider = OpenAiProvider::new("http://localhost", "key");
        assert_eq!(provider.name(), "openai");
    }

    /// Creates a reqwest client that bypasses proxy for test use.
    fn no_proxy_client() -> reqwest::Client {
        reqwest::Client::builder().no_proxy().build().unwrap()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_sends_correct_request() {
        use wiremock::matchers::{header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        // The `authorization` header matcher verifies the provider sends the
        // API key as a Bearer token; `.expect(1)` fails the test on teardown
        // if the endpoint was not hit exactly once.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(valid_openai_response()))
            .expect(1)
            .mount(&mock_server)
            .await;

        // Retries disabled so assertions reflect the first (only) response.
        let provider = OpenAiProvider::new(mock_server.uri(), "test-key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, Some("Hello from GPT!".into()));
        assert_eq!(response.model, "gpt-4");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_parses_tool_calls() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(tool_call_openai_response()))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4")
            .message(erio_core::Message::user("What's the weather?"));

        let response = provider.complete(request).await.unwrap();
        // The tool call's JSON-string `arguments` should have been parsed
        // into a structured value addressable by key.
        assert!(response.content.is_none());
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.tool_calls[0].name, "get_weather");
        assert_eq!(response.tool_calls[0].arguments["city"], "London");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_auth_error_on_401() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "bad-key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::Auth)));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_rate_limited_on_429() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_api_error_on_500() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::Api { status: 500, .. })));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_retries_on_429_then_succeeds() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        // First two calls return 429, third succeeds
        // Mount order matters: once this mock's `up_to_n_times(2)` budget is
        // spent it stops matching, so the third request falls through to the
        // 200 mock mounted below.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429))
            .up_to_n_times(2)
            .expect(2)
            .mount(&mock_server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(valid_openai_response()))
            .expect(1)
            .mount(&mock_server)
            .await;

        // 1 ms initial delay keeps the backoff sleeps negligible in tests.
        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(
                erio_core::RetryConfig::builder()
                    .max_attempts(3)
                    .initial_delay(Duration::from_millis(1))
                    .build(),
            );
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, Some("Hello from GPT!".into()));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_does_not_retry_on_401() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401))
            .expect(1) // Should only be called once - no retry
            .mount(&mock_server)
            .await;

        // Retry config allows 3 attempts, but Auth is non-retryable, so the
        // `.expect(1)` above must still hold.
        let provider = OpenAiProvider::new(mock_server.uri(), "bad-key")
            .with_client(no_proxy_client())
            .with_retry(
                erio_core::RetryConfig::builder()
                    .max_attempts(3)
                    .initial_delay(Duration::from_millis(1))
                    .build(),
            );
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::Auth)));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_exhausts_retries_on_persistent_429() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429))
            .expect(3) // 1 initial + 2 retries
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(
                erio_core::RetryConfig::builder()
                    .max_attempts(3)
                    .initial_delay(Duration::from_millis(1))
                    .build(),
            );
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        // After exhausting all attempts the last error is surfaced as-is.
        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    // === CompletionResponse Tests ===
    // These exercise the wire-format structs directly, without HTTP.

    #[test]
    fn response_parses_openai_text_content() {
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Hello, world!",
                    "tool_calls": []
                }
            }],
            "model": "gpt-4",
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert_eq!(resp.content, Some("Hello, world!".into()));
        assert!(resp.tool_calls.is_empty());
        assert_eq!(resp.model, "gpt-4");
        assert_eq!(resp.usage.as_ref().unwrap().prompt_tokens, 10);
        assert_eq!(resp.usage.as_ref().unwrap().completion_tokens, 5);
        assert_eq!(resp.usage.as_ref().unwrap().total_tokens, 15);
    }

    #[test]
    fn response_parses_openai_tool_calls() {
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"London\"}"
                        }
                    }]
                }
            }],
            "model": "gpt-4",
            "usage": null
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert!(resp.content.is_none());
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0].id, "call_123");
        assert_eq!(resp.tool_calls[0].name, "get_weather");
        assert_eq!(resp.tool_calls[0].arguments["city"], "London");
    }

    #[test]
    fn response_parses_openai_multiple_tool_calls() {
        // Content and tool calls can coexist; order of calls is preserved.
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Let me check both.",
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"city\":\"London\"}"
                            }
                        },
                        {
                            "id": "call_2",
                            "type": "function",
                            "function": {
                                "name": "get_time",
                                "arguments": "{\"timezone\":\"UTC\"}"
                            }
                        }
                    ]
                }
            }],
            "model": "gpt-4",
            "usage": null
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert_eq!(resp.content, Some("Let me check both.".into()));
        assert_eq!(resp.tool_calls.len(), 2);
        assert_eq!(resp.tool_calls[0].name, "get_weather");
        assert_eq!(resp.tool_calls[1].name, "get_time");
    }

    #[test]
    fn response_returns_error_for_empty_choices() {
        let json = serde_json::json!({
            "choices": [],
            "model": "gpt-4",
            "usage": null
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let result = raw.into_completion_response();

        // An empty `choices` array deserializes fine but cannot yield a
        // completion, so conversion reports InvalidResponse.
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::InvalidResponse(_)));
    }

    #[test]
    fn response_handles_no_usage() {
        // `usage` omitted entirely (not just null) must also parse.
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "OK",
                    "tool_calls": []
                }
            }],
            "model": "gpt-4"
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert!(resp.usage.is_none());
    }

    // === StreamChunk Tests ===

    #[test]
    fn stream_chunk_delta_holds_content() {
        let chunk = StreamChunk::Delta {
            content: "Hello".into(),
        };
        assert_eq!(
            chunk,
            StreamChunk::Delta {
                content: "Hello".into()
            }
        );
    }

    #[test]
    fn stream_chunk_done_variant() {
        let chunk = StreamChunk::Done;
        assert_eq!(chunk, StreamChunk::Done);
    }

    // === CompletionRequest Tests ===

    #[test]
    fn request_new_sets_model() {
        // A fresh request has only the model set; everything else defaults
        // to empty/None/false.
        let req = CompletionRequest::new("gpt-4");
        assert_eq!(req.model, "gpt-4");
        assert!(req.messages.is_empty());
        assert!(req.tools.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(!req.stream);
    }

    #[test]
    fn request_builder_adds_message() {
        let req = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].text(), Some("Hello"));
    }

    #[test]
    fn request_builder_chains_messages() {
        let req = CompletionRequest::new("gpt-4")
            .message(erio_core::Message::system("You are helpful"))
            .message(erio_core::Message::user("Hi"));
        assert_eq!(req.messages.len(), 2);
    }

    #[test]
    fn request_builder_sets_temperature() {
        let req = CompletionRequest::new("gpt-4").temperature(0.7);
        assert_eq!(req.temperature, Some(0.7));
    }

    #[test]
    fn request_builder_sets_max_tokens() {
        let req = CompletionRequest::new("gpt-4").max_tokens(1024);
        assert_eq!(req.max_tokens, Some(1024));
    }

    #[test]
    fn request_builder_sets_tools() {
        let tools = vec![ToolDefinition {
            name: "shell".into(),
            description: "Run a shell command".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "command": {"type": "string"}
                },
                "required": ["command"]
            }),
        }];
        let req = CompletionRequest::new("gpt-4").tools(tools);
        assert_eq!(req.tools.as_ref().unwrap().len(), 1);
        assert_eq!(req.tools.as_ref().unwrap()[0].name, "shell");
    }

    #[test]
    fn request_builder_sets_stream() {
        let req = CompletionRequest::new("gpt-4").stream(true);
        assert!(req.stream);
    }

    // === CoreError Conversion ===
    // RateLimited/Auth have no stored status code; the conversion is
    // expected to synthesize the canonical 429/401.

    #[test]
    fn llm_error_converts_to_core_error() {
        let llm_err = LlmError::Api {
            status: 500,
            message: "server error".into(),
        };
        let core_err: erio_core::CoreError = llm_err.into();
        assert!(matches!(core_err, erio_core::CoreError::Llm { .. }));
    }

    #[test]
    fn llm_error_rate_limited_converts_with_429_status() {
        let llm_err = LlmError::RateLimited { retry_after: None };
        let core_err: erio_core::CoreError = llm_err.into();
        match core_err {
            erio_core::CoreError::Llm { status, .. } => {
                assert_eq!(status, Some(429));
            }
            _ => panic!("Expected CoreError::Llm"),
        }
    }

    #[test]
    fn llm_error_auth_converts_with_401_status() {
        let llm_err = LlmError::Auth;
        let core_err: erio_core::CoreError = llm_err.into();
        match core_err {
            erio_core::CoreError::Llm { status, .. } => {
                assert_eq!(status, Some(401));
            }
            _ => panic!("Expected CoreError::Llm"),
        }
    }
}