1use std::future::Future;
4use std::pin::Pin;
5use std::time::Duration;
6
7use anyhow::Result;
8use reqwest::Client;
9use serde::{Deserialize, Serialize};
10use tracing::{debug, info};
11
12use super::{AiClient, AiClientCapabilities, AiClientMetadata, RequestOptions};
13use crate::claude::{error::ClaudeError, model_config::get_model_registry};
14
/// Upper bound applied to each context-length probe request, so a server
/// that does not implement the probed endpoint fails fast instead of hanging.
const PROBE_TIMEOUT: Duration = Duration::from_millis(2_000);
19
/// Backend that answered a successful context-length probe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeSource {
    /// LM Studio's `/api/v0/models` listing.
    LmStudio,
    /// Ollama's native `/api/show` endpoint.
    Ollama,
}

impl ProbeSource {
    /// Returns a fixed lowercase identifier for this source
    /// (`"lmstudio"` or `"ollama"`).
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            ProbeSource::Ollama => "ollama",
            ProbeSource::LmStudio => "lmstudio",
        }
    }
}
41
42#[derive(Serialize, Debug)]
44struct Message {
45 role: String,
46 content: String,
47}
48
49#[derive(Serialize, Debug)]
58struct ResponseFormatField {
59 #[serde(rename = "type")]
60 kind: &'static str,
61 json_schema: JsonSchemaSpec,
62}
63
64#[derive(Serialize, Debug)]
71struct JsonSchemaSpec {
72 name: &'static str,
73 strict: bool,
74 schema: serde_json::Value,
75}
76
77#[derive(Serialize, Debug)]
79struct OpenAiRequest {
80 model: String,
81 messages: Vec<Message>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 max_tokens: Option<i32>,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 max_completion_tokens: Option<i32>,
86 #[serde(skip_serializing_if = "Option::is_none")]
87 temperature: Option<f32>,
88 stream: bool,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 response_format: Option<ResponseFormatField>,
91}
92
93#[derive(Deserialize, Debug)]
95struct Choice {
96 message: ResponseMessage,
97 #[allow(dead_code)] finish_reason: Option<String>,
99}
100
101#[derive(Deserialize, Debug)]
103struct ResponseMessage {
104 #[allow(dead_code)] role: String,
106 content: String,
107}
108
109#[derive(Deserialize, Debug)]
111struct OpenAiResponse {
112 choices: Vec<Choice>,
113 model: Option<String>,
114 usage: Option<Usage>,
115}
116
117#[derive(Deserialize, Debug)]
119#[allow(dead_code)] struct Usage {
121 prompt_tokens: Option<i32>,
122 completion_tokens: Option<i32>,
123 total_tokens: Option<i32>,
124}
125
126pub struct OpenAiAiClient {
128 client: Client,
130 api_key: Option<String>,
132 model: String,
134 base_url: String,
136 max_tokens: Option<i32>,
138 temperature: Option<f32>,
140 active_beta: Option<(String, String)>,
142 loaded_context_length: Option<usize>,
147}
148
149#[derive(Deserialize, Debug)]
151struct LmStudioModelsResponse {
152 data: Vec<LmStudioModel>,
153}
154
155#[derive(Deserialize, Debug)]
157struct LmStudioModel {
158 id: String,
159 state: Option<String>,
160 loaded_context_length: Option<usize>,
161}
162
163impl OpenAiAiClient {
164 pub fn new(
166 model: String,
167 api_key: Option<String>,
168 base_url: String,
169 max_tokens: Option<i32>,
170 temperature: Option<f32>,
171 active_beta: Option<(String, String)>,
172 ) -> Result<Self> {
173 let client = super::build_http_client()?;
174
175 Ok(Self {
176 client,
177 api_key,
178 model,
179 base_url,
180 max_tokens,
181 temperature,
182 active_beta,
183 loaded_context_length: None,
184 })
185 }
186
187 pub fn new_ollama(
189 model: String,
190 base_url: Option<String>,
191 active_beta: Option<(String, String)>,
192 ) -> Result<Self> {
193 Self::new(
194 model,
195 None, base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
197 Some(4096), Some(0.1), active_beta,
200 )
201 }
202
203 pub fn new_openai(
205 model: String,
206 api_key: String,
207 active_beta: Option<(String, String)>,
208 ) -> Result<Self> {
209 Self::new(
210 model,
211 Some(api_key),
212 "https://api.openai.com".to_string(),
213 None, Some(0.1), active_beta,
216 )
217 }
218
219 fn get_max_tokens(&self) -> i32 {
221 if let Some(configured_max) = self.max_tokens {
222 return configured_max;
223 }
224 super::registry_max_output_tokens(&self.model, &self.active_beta)
225 }
226
227 fn get_api_url(&self) -> String {
234 let base = self.base_url.trim_end_matches('/');
235 let url = format!("{base}/v1/chat/completions");
236 debug!(base_url = %self.base_url, full_url = %url, "Constructed OpenAI-compatible API URL");
237 url
238 }
239
240 fn is_ollama(&self) -> bool {
242 self.base_url.contains("localhost")
243 || self.base_url.contains("127.0.0.1")
244 || self.api_key.is_none()
245 }
246
247 fn is_gpt5_series(&self) -> bool {
249 self.model.starts_with("gpt-5") || self.model.starts_with("o1")
250 }
251
252 #[must_use]
254 pub fn loaded_context_length(&self) -> Option<usize> {
255 self.loaded_context_length
256 }
257
258 pub fn set_loaded_context_length(&mut self, value: usize) {
262 self.loaded_context_length = Some(value);
263 }
264
265 fn build_request(
269 &self,
270 system_prompt: &str,
271 user_prompt: &str,
272 response_format: Option<ResponseFormatField>,
273 ) -> OpenAiRequest {
274 let mut messages = Vec::new();
275
276 if !system_prompt.is_empty() {
277 messages.push(Message {
278 role: "system".to_string(),
279 content: system_prompt.to_string(),
280 });
281 }
282
283 messages.push(Message {
284 role: "user".to_string(),
285 content: user_prompt.to_string(),
286 });
287
288 let max_tokens = self.get_max_tokens();
289 if self.is_gpt5_series() {
290 OpenAiRequest {
291 model: self.model.clone(),
292 messages,
293 max_tokens: None,
294 max_completion_tokens: Some(max_tokens),
295 temperature: None,
297 stream: false,
298 response_format,
299 }
300 } else {
301 OpenAiRequest {
302 model: self.model.clone(),
303 messages,
304 max_tokens: Some(max_tokens),
305 max_completion_tokens: None,
306 temperature: self.temperature,
307 stream: false,
308 response_format,
309 }
310 }
311 }
312
313 async fn send_inner(&self, request: OpenAiRequest) -> Result<String> {
321 debug!(
322 max_tokens = ?request.max_tokens,
323 max_completion_tokens = ?request.max_completion_tokens,
324 configured_temperature = ?self.temperature,
325 effective_temperature = ?request.temperature,
326 message_count = request.messages.len(),
327 is_gpt5_series = self.is_gpt5_series(),
328 response_format_set = request.response_format.is_some(),
329 "Built OpenAI-compatible request payload"
330 );
331
332 let api_url = self.get_api_url();
333 info!(url = %api_url, model = %self.model, "Sending request to OpenAI-compatible API");
334
335 let mut req_builder = self
336 .client
337 .post(&api_url)
338 .header("Content-Type", "application/json")
339 .json(&request);
340
341 if let Some(ref api_key) = self.api_key {
342 req_builder = req_builder.header("Authorization", format!("Bearer {api_key}"));
343 }
344
345 let response = req_builder
346 .send()
347 .await
348 .map_err(|e| ClaudeError::NetworkError(e.to_string()))?;
349
350 let response = super::check_error_response(response).await?;
351
352 let openai_response: OpenAiResponse = response
353 .json()
354 .await
355 .map_err(|e| ClaudeError::InvalidResponseFormat(e.to_string()))?;
356
357 debug!(
358 choice_count = openai_response.choices.len(),
359 model = ?openai_response.model,
360 usage = ?openai_response.usage,
361 "Received OpenAI-compatible API response"
362 );
363
364 let result = openai_response
365 .choices
366 .first()
367 .map(|choice| choice.message.content.clone())
368 .ok_or_else(|| {
369 ClaudeError::InvalidResponseFormat("No choices in response".to_string()).into()
370 });
371
372 super::log_response_success("OpenAI-compatible", &result);
373
374 result
375 }
376
377 pub async fn probe_loaded_context_length(&mut self) -> Option<ProbeSource> {
391 let host = host_root(&self.base_url);
392
393 if let Some(value) = probe_lm_studio(&self.client, &host, &self.model).await {
394 self.loaded_context_length = Some(value);
395 return Some(ProbeSource::LmStudio);
396 }
397
398 if let Some(value) = probe_ollama_native(&self.client, &host, &self.model).await {
399 self.loaded_context_length = Some(value);
400 return Some(ProbeSource::Ollama);
401 }
402
403 None
404 }
405}
406
/// Reduces a configured base URL to the server root used by the native
/// (non-OpenAI) probe endpoints.
///
/// Trailing slashes are removed. If the URL then ends in a `/v1` segment,
/// that segment and any slashes before it are dropped too. As a result,
/// `http://host:1234/`, `http://host:1234/v1` and `http://host:1234/v1/` all
/// become `http://host:1234`.
fn host_root(base_url: &str) -> String {
    let without_trailing = base_url.trim_end_matches('/');
    match without_trailing.strip_suffix("/v1") {
        Some(before_v1) => before_v1.trim_end_matches('/').to_owned(),
        None => without_trailing.to_owned(),
    }
}
420
421async fn probe_lm_studio(client: &Client, host: &str, model: &str) -> Option<usize> {
426 let url = format!("{host}/api/v0/models");
427 debug!(url = %url, model = %model, "Probing LM Studio for loaded context length");
428
429 let response = client.get(&url).timeout(PROBE_TIMEOUT).send().await.ok()?;
430 if !response.status().is_success() {
431 debug!(status = %response.status(), "LM Studio probe returned non-success");
432 return None;
433 }
434 let body: LmStudioModelsResponse = response.json().await.ok()?;
435 body.data
436 .into_iter()
437 .find(|entry| entry.id == model && entry.state.as_deref() == Some("loaded"))
438 .and_then(|entry| entry.loaded_context_length)
439}
440
441async fn probe_ollama_native(client: &Client, host: &str, model: &str) -> Option<usize> {
446 let url = format!("{host}/api/show");
447 debug!(url = %url, model = %model, "Probing Ollama for loaded context length");
448
449 let response = client
450 .post(&url)
451 .timeout(PROBE_TIMEOUT)
452 .json(&serde_json::json!({ "name": model }))
453 .send()
454 .await
455 .ok()?;
456 if !response.status().is_success() {
457 debug!(status = %response.status(), "Ollama probe returned non-success");
458 return None;
459 }
460 let body: serde_json::Value = response.json().await.ok()?;
461 let model_info = body.get("model_info")?.as_object()?;
462 for (key, value) in model_info {
463 if key.ends_with(".context_length") {
464 if let Some(n) = value.as_u64() {
465 return usize::try_from(n).ok();
466 }
467 }
468 }
469 None
470}
471
472impl AiClient for OpenAiAiClient {
473 fn send_request<'a>(
474 &'a self,
475 system_prompt: &'a str,
476 user_prompt: &'a str,
477 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
478 Box::pin(async move {
479 debug!(
480 system_prompt_len = system_prompt.len(),
481 user_prompt_len = user_prompt.len(),
482 model = %self.model,
483 base_url = %self.base_url,
484 is_ollama = self.is_ollama(),
485 "Preparing OpenAI-compatible API request"
486 );
487
488 let request = self.build_request(system_prompt, user_prompt, None);
489 self.send_inner(request).await
490 })
491 }
492
493 fn capabilities(&self) -> AiClientCapabilities {
494 AiClientCapabilities {
495 supports_response_schema: true,
496 }
497 }
498
499 fn send_request_with_options<'a>(
500 &'a self,
501 system_prompt: &'a str,
502 user_prompt: &'a str,
503 options: RequestOptions,
504 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
505 Box::pin(async move {
506 debug!(
507 system_prompt_len = system_prompt.len(),
508 user_prompt_len = user_prompt.len(),
509 has_schema = options.response_schema.is_some(),
510 model = %self.model,
511 base_url = %self.base_url,
512 is_ollama = self.is_ollama(),
513 "Preparing OpenAI-compatible API request (with options)"
514 );
515
516 let response_format = options.response_schema.map(|schema| ResponseFormatField {
517 kind: "json_schema",
518 json_schema: JsonSchemaSpec {
519 name: "response",
520 strict: true,
521 schema,
522 },
523 });
524
525 let request = self.build_request(system_prompt, user_prompt, response_format);
526 self.send_inner(request).await
527 })
528 }
529
530 fn get_metadata(&self) -> AiClientMetadata {
531 let registry = get_model_registry();
532
533 let max_context_length = if let Some(probed) = self.loaded_context_length {
537 probed
538 } else if registry.get_input_context(&self.model) > 0 {
539 registry.get_input_context(&self.model)
540 } else {
541 32768 };
543
544 let max_response_length = if registry.get_max_output_tokens(&self.model) > 0 {
545 registry.get_max_output_tokens(&self.model)
546 } else {
547 4096 };
549
550 let provider = if self.is_ollama() {
551 "Ollama".to_string()
552 } else {
553 "OpenAI".to_string()
554 };
555
556 AiClientMetadata {
557 provider,
558 model: self.model.clone(),
559 max_context_length,
560 max_response_length,
561 active_beta: self.active_beta.clone(),
562 }
563 }
564}
565
566#[cfg(test)]
567#[allow(clippy::unwrap_used, clippy::expect_used)]
568mod tests {
569 use super::*;
570
571 #[test]
572 fn new_ollama() {
573 let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
574 assert_eq!(client.model, "llama2");
575 assert_eq!(client.base_url, "http://localhost:11434");
576 assert!(client.api_key.is_none());
577 assert!(client.is_ollama());
578 }
579
580 #[test]
581 fn new_ollama_custom_url() {
582 let client = OpenAiAiClient::new_ollama(
583 "codellama".to_string(),
584 Some("http://192.168.1.100:11434".to_string()),
585 None,
586 )
587 .unwrap();
588 assert_eq!(client.base_url, "http://192.168.1.100:11434");
589 assert!(client.is_ollama());
590 }
591
592 #[test]
593 fn new_openai() {
594 let client =
595 OpenAiAiClient::new_openai("gpt-4".to_string(), "sk-test123".to_string(), None)
596 .unwrap();
597 assert_eq!(client.model, "gpt-4");
598 assert_eq!(client.base_url, "https://api.openai.com");
599 assert_eq!(client.api_key, Some("sk-test123".to_string()));
600 assert!(!client.is_ollama());
601 }
602
603 #[test]
604 fn get_api_url() {
605 let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
606 let url = client.get_api_url();
607 assert_eq!(url, "http://localhost:11434/v1/chat/completions");
608 }
609
610 #[test]
611 fn get_api_url_trailing_slash() {
612 let client = OpenAiAiClient::new(
613 "test-model".to_string(),
614 None,
615 "http://localhost:11434/".to_string(),
616 None,
617 None,
618 None,
619 )
620 .unwrap();
621 let url = client.get_api_url();
622 assert_eq!(url, "http://localhost:11434/v1/chat/completions");
623 }
624
625 #[test]
626 fn is_ollama_detection() {
627 let ollama_client = OpenAiAiClient::new(
629 "llama2".to_string(),
630 None,
631 "http://localhost:11434".to_string(),
632 None,
633 None,
634 None,
635 )
636 .unwrap();
637 assert!(ollama_client.is_ollama());
638
639 let local_client = OpenAiAiClient::new(
641 "llama2".to_string(),
642 Some("fake-key".to_string()),
643 "http://127.0.0.1:11434".to_string(),
644 None,
645 None,
646 None,
647 )
648 .unwrap();
649 assert!(local_client.is_ollama());
650
651 let no_key_client = OpenAiAiClient::new(
653 "llama2".to_string(),
654 None,
655 "http://remote-server.com".to_string(),
656 None,
657 None,
658 None,
659 )
660 .unwrap();
661 assert!(no_key_client.is_ollama());
662
663 let openai_client = OpenAiAiClient::new(
665 "gpt-4".to_string(),
666 Some("sk-real-key".to_string()),
667 "https://api.openai.com".to_string(),
668 None,
669 None,
670 None,
671 )
672 .unwrap();
673 assert!(!openai_client.is_ollama());
674 }
675
676 #[test]
679 fn gpt5_series_gpt5_models() {
680 let client = OpenAiAiClient::new(
681 "gpt-5-preview".to_string(),
682 Some("key".to_string()),
683 "https://api.openai.com".to_string(),
684 None,
685 None,
686 None,
687 )
688 .unwrap();
689 assert!(client.is_gpt5_series());
690
691 let client2 = OpenAiAiClient::new(
692 "gpt-5".to_string(),
693 Some("key".to_string()),
694 "https://api.openai.com".to_string(),
695 None,
696 None,
697 None,
698 )
699 .unwrap();
700 assert!(client2.is_gpt5_series());
701 }
702
703 #[test]
704 fn gpt5_series_o1_models() {
705 let client = OpenAiAiClient::new(
706 "o1-mini".to_string(),
707 Some("key".to_string()),
708 "https://api.openai.com".to_string(),
709 None,
710 None,
711 None,
712 )
713 .unwrap();
714 assert!(client.is_gpt5_series());
715
716 let client2 = OpenAiAiClient::new(
717 "o1-preview".to_string(),
718 Some("key".to_string()),
719 "https://api.openai.com".to_string(),
720 None,
721 None,
722 None,
723 )
724 .unwrap();
725 assert!(client2.is_gpt5_series());
726 }
727
728 #[test]
729 fn gpt5_series_regular_models_not_matched() {
730 let client = OpenAiAiClient::new(
731 "gpt-4".to_string(),
732 Some("key".to_string()),
733 "https://api.openai.com".to_string(),
734 None,
735 None,
736 None,
737 )
738 .unwrap();
739 assert!(!client.is_gpt5_series());
740
741 let client2 = OpenAiAiClient::new(
742 "gpt-4o-mini".to_string(),
743 Some("key".to_string()),
744 "https://api.openai.com".to_string(),
745 None,
746 None,
747 None,
748 )
749 .unwrap();
750 assert!(!client2.is_gpt5_series());
751 }
752
753 #[test]
756 fn get_max_tokens_configured_value_wins() {
757 let client = OpenAiAiClient::new(
758 "gpt-4".to_string(),
759 Some("key".to_string()),
760 "https://api.openai.com".to_string(),
761 Some(8192),
762 None,
763 None,
764 )
765 .unwrap();
766 assert_eq!(client.get_max_tokens(), 8192);
767 }
768
769 #[test]
770 fn get_max_tokens_from_registry() {
771 let client =
773 OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
774 let tokens = client.get_max_tokens();
775 assert!(tokens > 0, "expected positive token limit, got {tokens}");
777 }
778
779 #[test]
782 fn get_metadata_openai() {
783 let client =
784 OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
785 let metadata = client.get_metadata();
786 assert_eq!(metadata.provider, "OpenAI");
787 assert_eq!(metadata.model, "gpt-4o");
788 assert!(metadata.active_beta.is_none());
789 }
790
791 #[test]
792 fn get_metadata_ollama() {
793 let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
794 let metadata = client.get_metadata();
795 assert_eq!(metadata.provider, "Ollama");
796 assert_eq!(metadata.model, "llama2");
797 }
798
799 #[test]
800 fn get_metadata_with_beta() {
801 let beta = Some(("anthropic-beta".to_string(), "output-128k".to_string()));
802 let client =
803 OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), beta).unwrap();
804 let metadata = client.get_metadata();
805 assert!(metadata.active_beta.is_some());
806 let (key, value) = metadata.active_beta.unwrap();
807 assert_eq!(key, "anthropic-beta");
808 assert_eq!(value, "output-128k");
809 }
810
811 #[test]
814 fn request_gpt5_uses_max_completion_tokens() {
815 let request = OpenAiRequest {
816 model: "gpt-5".to_string(),
817 messages: vec![Message {
818 role: "user".to_string(),
819 content: "hello".to_string(),
820 }],
821 max_tokens: None,
822 max_completion_tokens: Some(4096),
823 temperature: None,
824 stream: false,
825 response_format: None,
826 };
827
828 let json = serde_json::to_string(&request).unwrap();
829 assert!(json.contains("max_completion_tokens"));
830 assert!(!json.contains("\"max_tokens\""));
832 }
833
834 #[test]
835 fn request_regular_model_uses_max_tokens() {
836 let request = OpenAiRequest {
837 model: "gpt-4".to_string(),
838 messages: vec![Message {
839 role: "user".to_string(),
840 content: "hello".to_string(),
841 }],
842 max_tokens: Some(4096),
843 max_completion_tokens: None,
844 temperature: Some(0.1),
845 stream: false,
846 response_format: None,
847 };
848
849 let json = serde_json::to_string(&request).unwrap();
850 assert!(json.contains("\"max_tokens\""));
851 assert!(!json.contains("max_completion_tokens"));
852 assert!(json.contains("\"temperature\""));
853 }
854
855 #[test]
860 fn capabilities_advertise_response_schema_support_openai() {
861 let client =
862 OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
863 assert!(client.capabilities().supports_response_schema);
864 }
865
866 #[test]
867 fn capabilities_advertise_response_schema_support_ollama() {
868 let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
869 assert!(client.capabilities().supports_response_schema);
870 }
871
872 #[test]
880 fn build_request_omits_response_format_without_schema() {
881 let client =
882 OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
883 let request = client.build_request("sys", "user", None);
884 let body = serde_json::to_value(&request).unwrap();
885 assert!(
886 body.get("response_format").is_none(),
887 "expected response_format to be omitted, got: {body}"
888 );
889 }
890
891 #[test]
894 fn build_request_embeds_response_format_with_schema_regular_model() {
895 let client =
896 OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
897 let schema = serde_json::json!({
898 "type": "object",
899 "properties": { "answer": { "type": "string" } },
900 "required": ["answer"],
901 "additionalProperties": false,
902 });
903 let response_format = Some(ResponseFormatField {
904 kind: "json_schema",
905 json_schema: JsonSchemaSpec {
906 name: "response",
907 strict: true,
908 schema: schema.clone(),
909 },
910 });
911 let request = client.build_request("sys", "user", response_format);
912 let body = serde_json::to_value(&request).unwrap();
913 assert_eq!(body["response_format"]["type"], "json_schema");
914 assert_eq!(body["response_format"]["json_schema"]["name"], "response");
915 assert_eq!(body["response_format"]["json_schema"]["strict"], true);
916 assert_eq!(body["response_format"]["json_schema"]["schema"], schema);
917 assert!(body.get("max_tokens").is_some());
919 assert!(body.get("max_completion_tokens").is_none());
920 }
921
922 #[test]
925 fn build_request_embeds_response_format_with_schema_gpt5() {
926 let client =
927 OpenAiAiClient::new_openai("gpt-5".to_string(), "key".to_string(), None).unwrap();
928 let schema = serde_json::json!({ "type": "object", "additionalProperties": false });
929 let response_format = Some(ResponseFormatField {
930 kind: "json_schema",
931 json_schema: JsonSchemaSpec {
932 name: "response",
933 strict: true,
934 schema: schema.clone(),
935 },
936 });
937 let request = client.build_request("sys", "user", response_format);
938 let body = serde_json::to_value(&request).unwrap();
939 assert_eq!(body["response_format"]["type"], "json_schema");
940 assert_eq!(body["response_format"]["json_schema"]["schema"], schema);
941 assert!(body.get("max_completion_tokens").is_some());
943 assert!(body.get("max_tokens").is_none());
944 assert!(body.get("temperature").is_none());
945 }
946
947 #[test]
950 fn host_root_strips_trailing_slash() {
951 assert_eq!(host_root("http://localhost:1234/"), "http://localhost:1234");
952 }
953
954 #[test]
955 fn host_root_strips_v1_suffix() {
956 assert_eq!(
957 host_root("http://localhost:1234/v1"),
958 "http://localhost:1234"
959 );
960 }
961
962 #[test]
963 fn host_root_strips_v1_with_trailing_slash() {
964 assert_eq!(
965 host_root("http://localhost:1234/v1/"),
966 "http://localhost:1234"
967 );
968 }
969
970 #[test]
971 fn host_root_passthrough_when_no_v1() {
972 assert_eq!(
973 host_root("http://localhost:11434"),
974 "http://localhost:11434"
975 );
976 }
977
978 #[test]
981 fn probe_source_as_str_stable() {
982 assert_eq!(ProbeSource::LmStudio.as_str(), "lmstudio");
983 assert_eq!(ProbeSource::Ollama.as_str(), "ollama");
984 }
985
986 #[test]
989 fn metadata_uses_probed_value_when_set() {
990 let mut client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
991 client.set_loaded_context_length(8192);
992 let metadata = client.get_metadata();
993 assert_eq!(metadata.max_context_length, 8192);
994 }
995
996 #[test]
997 fn metadata_falls_back_to_registry_when_probe_value_absent() {
998 let client =
1002 OpenAiAiClient::new_ollama("totally-unknown-model".to_string(), None, None).unwrap();
1003 let metadata = client.get_metadata();
1004 let expected = get_model_registry().get_input_context("totally-unknown-model");
1005 assert_eq!(metadata.max_context_length, expected);
1006 }
1007
1008 #[test]
1009 fn loaded_context_length_starts_unset() {
1010 let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
1011 assert!(client.loaded_context_length().is_none());
1012 }
1013
1014 #[test]
1015 fn loaded_context_length_round_trips() {
1016 let mut client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
1017 client.set_loaded_context_length(4096);
1018 assert_eq!(client.loaded_context_length(), Some(4096));
1019 }
1020
1021 use wiremock::matchers::{body_json, method, path};
1024 use wiremock::{Mock, MockServer, ResponseTemplate};
1025
1026 fn ollama_client_pointing_at(server_uri: &str, model: &str) -> OpenAiAiClient {
1027 OpenAiAiClient::new_ollama(model.to_string(), Some(server_uri.to_string()), None).unwrap()
1028 }
1029
1030 #[tokio::test]
1031 async fn probe_returns_lm_studio_value_when_model_loaded() {
1032 let server = MockServer::start().await;
1033 Mock::given(method("GET"))
1034 .and(path("/api/v0/models"))
1035 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1036 "data": [
1037 {
1038 "id": "llama-3.2-3b-instruct",
1039 "state": "loaded",
1040 "loaded_context_length": 4096_u64,
1041 "max_context_length": 131_072_u64,
1042 }
1043 ]
1044 })))
1045 .mount(&server)
1046 .await;
1047
1048 let mut client = ollama_client_pointing_at(&server.uri(), "llama-3.2-3b-instruct");
1049 let source = client.probe_loaded_context_length().await;
1050 assert_eq!(source, Some(ProbeSource::LmStudio));
1051 assert_eq!(client.loaded_context_length(), Some(4096));
1052 assert_eq!(client.get_metadata().max_context_length, 4096);
1054 }
1055
1056 #[tokio::test]
1057 async fn probe_skips_lm_studio_entry_when_model_not_loaded() {
1058 let server = MockServer::start().await;
1059 Mock::given(method("GET"))
1063 .and(path("/api/v0/models"))
1064 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1065 "data": [
1066 { "id": "model-a", "state": "not-loaded", "loaded_context_length": 4096_u64 }
1067 ]
1068 })))
1069 .mount(&server)
1070 .await;
1071 Mock::given(method("POST"))
1072 .and(path("/api/show"))
1073 .and(body_json(serde_json::json!({ "name": "model-a" })))
1074 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1075 "model_info": { "llama.context_length": 8192_u64 }
1076 })))
1077 .mount(&server)
1078 .await;
1079
1080 let mut client = ollama_client_pointing_at(&server.uri(), "model-a");
1081 let source = client.probe_loaded_context_length().await;
1082 assert_eq!(source, Some(ProbeSource::Ollama));
1083 assert_eq!(client.loaded_context_length(), Some(8192));
1084 }
1085
1086 #[tokio::test]
1087 async fn probe_skips_lm_studio_when_model_id_does_not_match() {
1088 let server = MockServer::start().await;
1089 Mock::given(method("GET"))
1090 .and(path("/api/v0/models"))
1091 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1092 "data": [
1093 { "id": "other-model", "state": "loaded", "loaded_context_length": 4096_u64 }
1094 ]
1095 })))
1096 .mount(&server)
1097 .await;
1098 Mock::given(method("POST"))
1099 .and(path("/api/show"))
1100 .respond_with(ResponseTemplate::new(404))
1101 .mount(&server)
1102 .await;
1103
1104 let mut client = ollama_client_pointing_at(&server.uri(), "wanted-model");
1105 let source = client.probe_loaded_context_length().await;
1106 assert!(source.is_none());
1107 assert!(client.loaded_context_length().is_none());
1108 }
1109
1110 #[tokio::test]
1111 async fn probe_falls_back_to_ollama_native() {
1112 let server = MockServer::start().await;
1113 Mock::given(method("GET"))
1114 .and(path("/api/v0/models"))
1115 .respond_with(ResponseTemplate::new(404))
1116 .mount(&server)
1117 .await;
1118 Mock::given(method("POST"))
1119 .and(path("/api/show"))
1120 .and(body_json(serde_json::json!({ "name": "qwen2.5-coder" })))
1121 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1122 "model_info": {
1123 "general.architecture": "qwen2",
1124 "qwen2.context_length": 32768_u64,
1125 "qwen2.embedding_length": 3584_u64
1126 }
1127 })))
1128 .mount(&server)
1129 .await;
1130
1131 let mut client = ollama_client_pointing_at(&server.uri(), "qwen2.5-coder");
1132 let source = client.probe_loaded_context_length().await;
1133 assert_eq!(source, Some(ProbeSource::Ollama));
1134 assert_eq!(client.loaded_context_length(), Some(32768));
1135 }
1136
1137 #[tokio::test]
1138 async fn probe_returns_none_when_neither_endpoint_responds() {
1139 let server = MockServer::start().await;
1140 Mock::given(method("GET"))
1141 .and(path("/api/v0/models"))
1142 .respond_with(ResponseTemplate::new(500))
1143 .mount(&server)
1144 .await;
1145 Mock::given(method("POST"))
1146 .and(path("/api/show"))
1147 .respond_with(ResponseTemplate::new(500))
1148 .mount(&server)
1149 .await;
1150
1151 let mut client = ollama_client_pointing_at(&server.uri(), "anything");
1152 let source = client.probe_loaded_context_length().await;
1153 assert!(source.is_none());
1154 assert!(client.loaded_context_length().is_none());
1155 let registry_value = get_model_registry().get_input_context("anything");
1157 assert_eq!(client.get_metadata().max_context_length, registry_value);
1158 }
1159
1160 #[tokio::test]
1161 async fn probe_returns_none_when_ollama_payload_lacks_context_length_key() {
1162 let server = MockServer::start().await;
1163 Mock::given(method("GET"))
1164 .and(path("/api/v0/models"))
1165 .respond_with(ResponseTemplate::new(404))
1166 .mount(&server)
1167 .await;
1168 Mock::given(method("POST"))
1169 .and(path("/api/show"))
1170 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1171 "model_info": {
1172 "general.architecture": "phantom",
1173 "phantom.embedding_length": 1024_u64
1174 }
1175 })))
1176 .mount(&server)
1177 .await;
1178
1179 let mut client = ollama_client_pointing_at(&server.uri(), "ghost");
1180 let source = client.probe_loaded_context_length().await;
1181 assert!(source.is_none());
1182 }
1183
1184 #[tokio::test]
1185 async fn probe_handles_v1_suffix_in_base_url() {
1186 let server = MockServer::start().await;
1190 Mock::given(method("GET"))
1191 .and(path("/api/v0/models"))
1192 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1193 "data": [
1194 { "id": "lm", "state": "loaded", "loaded_context_length": 2048_u64 }
1195 ]
1196 })))
1197 .mount(&server)
1198 .await;
1199
1200 let base_with_v1 = format!("{}/v1", server.uri());
1201 let mut client = ollama_client_pointing_at(&base_with_v1, "lm");
1202 let source = client.probe_loaded_context_length().await;
1203 assert_eq!(source, Some(ProbeSource::LmStudio));
1204 assert_eq!(client.loaded_context_length(), Some(2048));
1205 }
1206
1207 #[tokio::test]
1208 async fn probe_ignores_lm_studio_entry_with_no_loaded_context_length() {
1209 let server = MockServer::start().await;
1210 Mock::given(method("GET"))
1213 .and(path("/api/v0/models"))
1214 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1215 "data": [ { "id": "x", "state": "loaded", "loaded_context_length": null } ]
1216 })))
1217 .mount(&server)
1218 .await;
1219 Mock::given(method("POST"))
1220 .and(path("/api/show"))
1221 .respond_with(ResponseTemplate::new(404))
1222 .mount(&server)
1223 .await;
1224
1225 let mut client = ollama_client_pointing_at(&server.uri(), "x");
1226 let source = client.probe_loaded_context_length().await;
1227 assert!(source.is_none());
1228 }
1229
1230 #[tokio::test]
1231 async fn probe_returns_none_when_lm_studio_returns_invalid_json() {
1232 let server = MockServer::start().await;
1235 Mock::given(method("GET"))
1236 .and(path("/api/v0/models"))
1237 .respond_with(ResponseTemplate::new(200).set_body_string("<html>not json</html>"))
1238 .mount(&server)
1239 .await;
1240 Mock::given(method("POST"))
1241 .and(path("/api/show"))
1242 .respond_with(ResponseTemplate::new(404))
1243 .mount(&server)
1244 .await;
1245
1246 let mut client = ollama_client_pointing_at(&server.uri(), "anything");
1247 let source = client.probe_loaded_context_length().await;
1248 assert!(source.is_none());
1249 }
1250
1251 #[tokio::test]
1252 async fn probe_returns_none_when_ollama_returns_invalid_json() {
1253 let server = MockServer::start().await;
1256 Mock::given(method("GET"))
1257 .and(path("/api/v0/models"))
1258 .respond_with(ResponseTemplate::new(404))
1259 .mount(&server)
1260 .await;
1261 Mock::given(method("POST"))
1262 .and(path("/api/show"))
1263 .respond_with(ResponseTemplate::new(200).set_body_string("{not json"))
1264 .mount(&server)
1265 .await;
1266
1267 let mut client = ollama_client_pointing_at(&server.uri(), "anything");
1268 let source = client.probe_loaded_context_length().await;
1269 assert!(source.is_none());
1270 }
1271
1272 #[tokio::test]
1273 async fn probe_returns_none_when_ollama_response_lacks_model_info() {
1274 let server = MockServer::start().await;
1277 Mock::given(method("GET"))
1278 .and(path("/api/v0/models"))
1279 .respond_with(ResponseTemplate::new(404))
1280 .mount(&server)
1281 .await;
1282 Mock::given(method("POST"))
1283 .and(path("/api/show"))
1284 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1285 "details": { "family": "llama" }
1286 })))
1287 .mount(&server)
1288 .await;
1289
1290 let mut client = ollama_client_pointing_at(&server.uri(), "anything");
1291 let source = client.probe_loaded_context_length().await;
1292 assert!(source.is_none());
1293 }
1294
1295 #[tokio::test]
1296 async fn probe_returns_none_when_ollama_model_info_is_not_object() {
1297 let server = MockServer::start().await;
1301 Mock::given(method("GET"))
1302 .and(path("/api/v0/models"))
1303 .respond_with(ResponseTemplate::new(404))
1304 .mount(&server)
1305 .await;
1306 Mock::given(method("POST"))
1307 .and(path("/api/show"))
1308 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1309 "model_info": "not an object"
1310 })))
1311 .mount(&server)
1312 .await;
1313
1314 let mut client = ollama_client_pointing_at(&server.uri(), "anything");
1315 let source = client.probe_loaded_context_length().await;
1316 assert!(source.is_none());
1317 }
1318
1319 #[tokio::test]
1320 async fn probe_returns_none_when_ollama_context_length_is_not_u64() {
1321 let server = MockServer::start().await;
1325 Mock::given(method("GET"))
1326 .and(path("/api/v0/models"))
1327 .respond_with(ResponseTemplate::new(404))
1328 .mount(&server)
1329 .await;
1330 Mock::given(method("POST"))
1331 .and(path("/api/show"))
1332 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1333 "model_info": { "llama.context_length": "8192" }
1334 })))
1335 .mount(&server)
1336 .await;
1337
1338 let mut client = ollama_client_pointing_at(&server.uri(), "anything");
1339 let source = client.probe_loaded_context_length().await;
1340 assert!(source.is_none());
1341 }
1342
1343 #[tokio::test]
1344 async fn probe_returns_none_when_server_unreachable() {
1345 let mut client = ollama_client_pointing_at("http://127.0.0.1:1", "anything");
1350 let source = client.probe_loaded_context_length().await;
1351 assert!(source.is_none());
1352 }
1353
1354 async fn mock_chat_completion_ok(server: &MockServer) {
1359 Mock::given(method("POST"))
1360 .and(path("/v1/chat/completions"))
1361 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1362 "choices": [
1363 {
1364 "message": { "role": "assistant", "content": "ok" },
1365 "finish_reason": "stop"
1366 }
1367 ],
1368 "model": "test-model"
1369 })))
1370 .mount(server)
1371 .await;
1372 }
1373
1374 fn openai_client_pointing_at(server_uri: &str, model: &str) -> OpenAiAiClient {
1375 OpenAiAiClient::new(
1376 model.to_string(),
1377 Some("test-key".to_string()),
1378 server_uri.to_string(),
1379 Some(1024),
1380 Some(0.1),
1381 None,
1382 )
1383 .unwrap()
1384 }
1385
1386 #[tokio::test]
1394 async fn send_request_with_options_serializes_response_format_on_the_wire() {
1395 let _ = tracing_subscriber::fmt()
1396 .with_max_level(tracing::Level::DEBUG)
1397 .with_test_writer()
1398 .try_init();
1399
1400 let server = MockServer::start().await;
1401 mock_chat_completion_ok(&server).await;
1402
1403 let client = openai_client_pointing_at(&server.uri(), "gpt-4o");
1404 let schema = serde_json::json!({
1405 "type": "object",
1406 "properties": { "answer": { "type": "string" } },
1407 "required": ["answer"],
1408 "additionalProperties": false,
1409 });
1410 let options = RequestOptions::default().with_response_schema(schema.clone());
1411
1412 let result = client
1413 .send_request_with_options("system", "user", options)
1414 .await
1415 .unwrap();
1416 assert_eq!(result, "ok");
1417
1418 let received = server.received_requests().await.unwrap();
1419 assert_eq!(
1420 received.len(),
1421 1,
1422 "expected exactly one chat-completions request"
1423 );
1424 let body: serde_json::Value = serde_json::from_slice(&received[0].body).unwrap();
1425 assert_eq!(body["response_format"]["type"], "json_schema");
1426 assert_eq!(body["response_format"]["json_schema"]["name"], "response");
1427 assert_eq!(body["response_format"]["json_schema"]["strict"], true);
1428 assert_eq!(body["response_format"]["json_schema"]["schema"], schema);
1429 }
1430
1431 #[tokio::test]
1435 async fn send_request_omits_response_format_on_the_wire() {
1436 let server = MockServer::start().await;
1437 mock_chat_completion_ok(&server).await;
1438
1439 let client = openai_client_pointing_at(&server.uri(), "gpt-4o");
1440 let _ = client.send_request("system", "user").await.unwrap();
1441
1442 let received = server.received_requests().await.unwrap();
1443 assert_eq!(received.len(), 1);
1444 let body: serde_json::Value = serde_json::from_slice(&received[0].body).unwrap();
1445 assert!(
1446 body.get("response_format").is_none(),
1447 "expected response_format to be absent from wire body, got: {body}"
1448 );
1449 }
1450
/// An empty system prompt must not produce a system message. Only the user
/// message should appear in the built request.
#[test]
fn build_request_skips_empty_system_prompt() {
    let request = OpenAiAiClient::new_openai(String::from("gpt-4o"), String::from("key"), None)
        .unwrap()
        .build_request("", "user prompt", None);

    assert_eq!(request.messages.len(), 1);
    let first = &request.messages[0];
    assert_eq!(first.role, "user");
    assert_eq!(first.content, "user prompt");
}
1463
1464 #[tokio::test]
1470 async fn send_request_propagates_network_error_on_unreachable_server() {
1471 let client = openai_client_pointing_at("http://127.0.0.1:1", "gpt-4o");
1472 let err = client
1473 .send_request("system", "user")
1474 .await
1475 .expect_err("expected network error against closed port");
1476 let chain = format!("{err:#}");
1477 assert!(
1478 chain.to_lowercase().contains("network"),
1479 "expected network-error wording in chain, got: {chain}"
1480 );
1481 }
1482
1483 #[tokio::test]
1488 async fn send_request_propagates_http_error_response() {
1489 let server = MockServer::start().await;
1490 Mock::given(method("POST"))
1491 .and(path("/v1/chat/completions"))
1492 .respond_with(ResponseTemplate::new(500).set_body_string("upstream boom"))
1493 .mount(&server)
1494 .await;
1495
1496 let client = openai_client_pointing_at(&server.uri(), "gpt-4o");
1497 let err = client
1498 .send_request("system", "user")
1499 .await
1500 .expect_err("expected error from 500 response");
1501 let chain = format!("{err:#}");
1502 assert!(
1503 chain.contains("HTTP 500"),
1504 "expected 'HTTP 500' in error chain, got: {chain}"
1505 );
1506 }
1507
1508 #[tokio::test]
1511 async fn send_request_propagates_json_parse_error() {
1512 let server = MockServer::start().await;
1513 Mock::given(method("POST"))
1514 .and(path("/v1/chat/completions"))
1515 .respond_with(ResponseTemplate::new(200).set_body_string("{not valid json"))
1516 .mount(&server)
1517 .await;
1518
1519 let client = openai_client_pointing_at(&server.uri(), "gpt-4o");
1520 let err = client
1521 .send_request("system", "user")
1522 .await
1523 .expect_err("expected error from malformed JSON body");
1524 let chain = format!("{err:#}");
1528 assert!(
1529 chain.contains("Invalid response format"),
1530 "expected 'Invalid response format' in error chain, got: {chain}"
1531 );
1532 }
1533
1534 #[tokio::test]
1539 async fn send_request_errors_when_response_has_no_choices() {
1540 let server = MockServer::start().await;
1541 Mock::given(method("POST"))
1542 .and(path("/v1/chat/completions"))
1543 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1544 "choices": [],
1545 "model": "test-model"
1546 })))
1547 .mount(&server)
1548 .await;
1549
1550 let client = openai_client_pointing_at(&server.uri(), "gpt-4o");
1551 let err = client
1552 .send_request("system", "user")
1553 .await
1554 .expect_err("expected error when choices array is empty");
1555 let chain = format!("{err:#}");
1556 assert!(
1557 chain.contains("No choices in response"),
1558 "expected 'No choices in response' in error chain, got: {chain}"
1559 );
1560 }
1561}