#![doc = include_str!("../README.md")]

// Crate module map: error types, the OpenAI-backed provider, the generic
// provider trait, and the request/response data types exchanged with the API.
pub mod error;
pub mod openai;
pub mod provider;
pub mod request;
pub mod response;

// Re-export the primary types at the crate root so downstream code can
// import them without naming the submodules.
pub use error::LlmError;
pub use openai::OpenAiProvider;
pub use provider::LlmProvider;
pub use request::{CompletionRequest, ToolDefinition};
pub use response::{CompletionResponse, StreamChunk, Usage};
16
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // ---- LlmError: Display formatting ----------------------------------

    #[test]
    fn llm_error_rate_limited_displays_message() {
        let err = LlmError::RateLimited { retry_after: None };
        assert_eq!(err.to_string(), "Rate limited");
    }

    #[test]
    fn llm_error_api_displays_status_and_message() {
        let err = LlmError::Api {
            status: 500,
            message: "internal server error".into(),
        };
        assert_eq!(err.to_string(), "API error (500): internal server error");
    }

    #[test]
    fn llm_error_invalid_response_displays_detail() {
        let err = LlmError::InvalidResponse("missing choices field".into());
        assert_eq!(err.to_string(), "Invalid response: missing choices field");
    }

    #[test]
    fn llm_error_timeout_displays_message() {
        let err = LlmError::Timeout;
        assert_eq!(err.to_string(), "Request timeout");
    }

    #[test]
    fn llm_error_auth_displays_message() {
        let err = LlmError::Auth;
        assert_eq!(err.to_string(), "Authentication failed");
    }

    #[test]
    fn llm_error_network_displays_detail() {
        let err = LlmError::Network("connection refused".into());
        assert_eq!(err.to_string(), "Network error: connection refused");
    }

    // ---- LlmError: retryability classification --------------------------
    // Transient failures (rate limits, timeouts, network errors, 5xx) are
    // retryable; caller errors (auth, bad request, malformed response) are not.

    #[test]
    fn llm_error_rate_limited_is_retryable() {
        let err = LlmError::RateLimited {
            retry_after: Some(Duration::from_secs(5)),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_timeout_is_retryable() {
        let err = LlmError::Timeout;
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_network_is_retryable() {
        let err = LlmError::Network("connection reset".into());
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_api_500_is_retryable() {
        let err = LlmError::Api {
            status: 500,
            message: "server error".into(),
        };
        assert!(err.is_retryable());
    }

    #[test]
    fn llm_error_auth_is_not_retryable() {
        let err = LlmError::Auth;
        assert!(!err.is_retryable());
    }

    #[test]
    fn llm_error_invalid_response_is_not_retryable() {
        let err = LlmError::InvalidResponse("bad json".into());
        assert!(!err.is_retryable());
    }

    #[test]
    fn llm_error_api_400_is_not_retryable() {
        let err = LlmError::Api {
            status: 400,
            message: "bad request".into(),
        };
        assert!(!err.is_retryable());
    }

    // ---- Mock response fixtures -----------------------------------------

    /// A minimal successful OpenAI chat-completions payload: one choice with
    /// plain text content, no tool calls, and a populated `usage` object.
    fn valid_openai_response() -> serde_json::Value {
        serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Hello from GPT!",
                    "tool_calls": []
                }
            }],
            "model": "gpt-4",
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        })
    }

    /// A completion payload where the model answers with a single tool call
    /// (`get_weather`) instead of text; `content` is null and `usage` is null.
    /// Note `function.arguments` is a JSON-encoded string, as OpenAI sends it.
    fn tool_call_openai_response() -> serde_json::Value {
        serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"London\"}"
                        }
                    }]
                }
            }],
            "model": "gpt-4",
            "usage": null
        })
    }

    // ---- OpenAiProvider: HTTP behavior against a wiremock server ---------

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_name() {
        let provider = OpenAiProvider::new("http://localhost", "key");
        assert_eq!(provider.name(), "openai");
    }

    /// Build a reqwest client that ignores any HTTP(S)_PROXY environment
    /// variables, so requests always reach the local wiremock server directly.
    fn no_proxy_client() -> reqwest::Client {
        reqwest::Client::builder().no_proxy().build().unwrap()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_sends_correct_request() {
        use wiremock::matchers::{header, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        // Expect exactly one POST to /chat/completions carrying the bearer key.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .respond_with(ResponseTemplate::new(200).set_body_json(valid_openai_response()))
            .expect(1)
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "test-key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, Some("Hello from GPT!".into()));
        assert_eq!(response.model, "gpt-4");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_parses_tool_calls() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(tool_call_openai_response()))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4")
            .message(erio_core::Message::user("What's the weather?"));

        let response = provider.complete(request).await.unwrap();
        assert!(response.content.is_none());
        assert_eq!(response.tool_calls.len(), 1);
        assert_eq!(response.tool_calls[0].name, "get_weather");
        // The JSON-string arguments must have been parsed into a structured value.
        assert_eq!(response.tool_calls[0].arguments["city"], "London");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_auth_error_on_401() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "bad-key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::Auth)));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_rate_limited_on_429() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_returns_api_error_on_500() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(erio_core::RetryConfig::no_retry());
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::Api { status: 500, .. })));
    }

    // ---- OpenAiProvider: retry behavior ----------------------------------

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_retries_on_429_then_succeeds() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        // First two attempts hit this mock and are rate-limited; it then
        // stops matching (up_to_n_times), letting the 200 mock below answer.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429))
            .up_to_n_times(2)
            .expect(2)
            .mount(&mock_server)
            .await;

        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(valid_openai_response()))
            .expect(1)
            .mount(&mock_server)
            .await;

        // 3 attempts with a tiny delay so the retry loop can recover in-test.
        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(
                erio_core::RetryConfig::builder()
                    .max_attempts(3)
                    .initial_delay(Duration::from_millis(1))
                    .build(),
            );
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let response = provider.complete(request).await.unwrap();
        assert_eq!(response.content, Some("Hello from GPT!".into()));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_does_not_retry_on_401() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        // expect(1): a non-retryable auth failure must produce exactly one
        // request even though the retry config allows 3 attempts.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(401))
            .expect(1)
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "bad-key")
            .with_client(no_proxy_client())
            .with_retry(
                erio_core::RetryConfig::builder()
                    .max_attempts(3)
                    .initial_delay(Duration::from_millis(1))
                    .build(),
            );
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::Auth)));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn openai_provider_exhausts_retries_on_persistent_429() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        // expect(3): all allowed attempts are consumed before giving up.
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(ResponseTemplate::new(429))
            .expect(3)
            .mount(&mock_server)
            .await;

        let provider = OpenAiProvider::new(mock_server.uri(), "key")
            .with_client(no_proxy_client())
            .with_retry(
                erio_core::RetryConfig::builder()
                    .max_attempts(3)
                    .initial_delay(Duration::from_millis(1))
                    .build(),
            );
        let request = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));

        let result = provider.complete(request).await;
        assert!(matches!(result, Err(LlmError::RateLimited { .. })));
    }

    // ---- Response parsing (OpenAiResponse -> CompletionResponse) ---------

    #[test]
    fn response_parses_openai_text_content() {
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Hello, world!",
                    "tool_calls": []
                }
            }],
            "model": "gpt-4",
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert_eq!(resp.content, Some("Hello, world!".into()));
        assert!(resp.tool_calls.is_empty());
        assert_eq!(resp.model, "gpt-4");
        assert_eq!(resp.usage.as_ref().unwrap().prompt_tokens, 10);
        assert_eq!(resp.usage.as_ref().unwrap().completion_tokens, 5);
        assert_eq!(resp.usage.as_ref().unwrap().total_tokens, 15);
    }

    #[test]
    fn response_parses_openai_tool_calls() {
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_123",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"city\":\"London\"}"
                        }
                    }]
                }
            }],
            "model": "gpt-4",
            "usage": null
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert!(resp.content.is_none());
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0].id, "call_123");
        assert_eq!(resp.tool_calls[0].name, "get_weather");
        assert_eq!(resp.tool_calls[0].arguments["city"], "London");
    }

    #[test]
    fn response_parses_openai_multiple_tool_calls() {
        // Mixed message: text content alongside two tool calls; order must
        // be preserved.
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Let me check both.",
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"city\":\"London\"}"
                            }
                        },
                        {
                            "id": "call_2",
                            "type": "function",
                            "function": {
                                "name": "get_time",
                                "arguments": "{\"timezone\":\"UTC\"}"
                            }
                        }
                    ]
                }
            }],
            "model": "gpt-4",
            "usage": null
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert_eq!(resp.content, Some("Let me check both.".into()));
        assert_eq!(resp.tool_calls.len(), 2);
        assert_eq!(resp.tool_calls[0].name, "get_weather");
        assert_eq!(resp.tool_calls[1].name, "get_time");
    }

    #[test]
    fn response_returns_error_for_empty_choices() {
        let json = serde_json::json!({
            "choices": [],
            "model": "gpt-4",
            "usage": null
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let result = raw.into_completion_response();

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::InvalidResponse(_)));
    }

    #[test]
    fn response_handles_no_usage() {
        // "usage" key absent entirely (not just null) must still deserialize.
        let json = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "OK",
                    "tool_calls": []
                }
            }],
            "model": "gpt-4"
        });

        let raw: response::OpenAiResponse = serde_json::from_value(json).unwrap();
        let resp = raw.into_completion_response().unwrap();

        assert!(resp.usage.is_none());
    }

    // ---- StreamChunk ------------------------------------------------------

    #[test]
    fn stream_chunk_delta_holds_content() {
        let chunk = StreamChunk::Delta {
            content: "Hello".into(),
        };
        assert_eq!(
            chunk,
            StreamChunk::Delta {
                content: "Hello".into()
            }
        );
    }

    #[test]
    fn stream_chunk_done_variant() {
        let chunk = StreamChunk::Done;
        assert_eq!(chunk, StreamChunk::Done);
    }

    // ---- CompletionRequest builder ---------------------------------------

    #[test]
    fn request_new_sets_model() {
        let req = CompletionRequest::new("gpt-4");
        assert_eq!(req.model, "gpt-4");
        assert!(req.messages.is_empty());
        assert!(req.tools.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.temperature.is_none());
        assert!(!req.stream);
    }

    #[test]
    fn request_builder_adds_message() {
        let req = CompletionRequest::new("gpt-4").message(erio_core::Message::user("Hello"));
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].text(), Some("Hello"));
    }

    #[test]
    fn request_builder_chains_messages() {
        let req = CompletionRequest::new("gpt-4")
            .message(erio_core::Message::system("You are helpful"))
            .message(erio_core::Message::user("Hi"));
        assert_eq!(req.messages.len(), 2);
    }

    #[test]
    fn request_builder_sets_temperature() {
        let req = CompletionRequest::new("gpt-4").temperature(0.7);
        assert_eq!(req.temperature, Some(0.7));
    }

    #[test]
    fn request_builder_sets_max_tokens() {
        let req = CompletionRequest::new("gpt-4").max_tokens(1024);
        assert_eq!(req.max_tokens, Some(1024));
    }

    #[test]
    fn request_builder_sets_tools() {
        let tools = vec![ToolDefinition {
            name: "shell".into(),
            description: "Run a shell command".into(),
            // JSON-schema style parameter description, as the OpenAI tools
            // API expects.
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "command": {"type": "string"}
                },
                "required": ["command"]
            }),
        }];
        let req = CompletionRequest::new("gpt-4").tools(tools);
        assert_eq!(req.tools.as_ref().unwrap().len(), 1);
        assert_eq!(req.tools.as_ref().unwrap()[0].name, "shell");
    }

    #[test]
    fn request_builder_sets_stream() {
        let req = CompletionRequest::new("gpt-4").stream(true);
        assert!(req.stream);
    }

    // ---- LlmError -> erio_core::CoreError conversion ----------------------

    #[test]
    fn llm_error_converts_to_core_error() {
        let llm_err = LlmError::Api {
            status: 500,
            message: "server error".into(),
        };
        let core_err: erio_core::CoreError = llm_err.into();
        assert!(matches!(core_err, erio_core::CoreError::Llm { .. }));
    }

    #[test]
    fn llm_error_rate_limited_converts_with_429_status() {
        let llm_err = LlmError::RateLimited { retry_after: None };
        let core_err: erio_core::CoreError = llm_err.into();
        match core_err {
            erio_core::CoreError::Llm { status, .. } => {
                assert_eq!(status, Some(429));
            }
            _ => panic!("Expected CoreError::Llm"),
        }
    }

    #[test]
    fn llm_error_auth_converts_with_401_status() {
        let llm_err = LlmError::Auth;
        let core_err: erio_core::CoreError = llm_err.into();
        match core_err {
            erio_core::CoreError::Llm { status, .. } => {
                assert_eq!(status, Some(401));
            }
            _ => panic!("Expected CoreError::Llm"),
        }
    }
}