codetether_agent/provider/openai.rs

//! OpenAI provider implementation
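//!
//! Wraps `async_openai` for the official API and, via a configurable base
//! URL, arbitrary OpenAI-compatible endpoints. A minimal usage sketch; the
//! API key and the `use` paths below are illustrative assumptions, not
//! shipped values:
//!
//! ```no_run
//! # async fn demo() -> anyhow::Result<()> {
//! use codetether_agent::provider::openai::OpenAIProvider;
//! use codetether_agent::provider::Provider;
//!
//! let provider = OpenAIProvider::new("sk-your-key".to_string())?;
//! for model in provider.list_models().await? {
//!     println!("{} ({})", model.id, model.name);
//! }
//! # Ok(())
//! # }
//! ```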

use super::{
    CompletionRequest, CompletionResponse, ContentPart, EmbeddingRequest, EmbeddingResponse,
    FinishReason, Message, ModelInfo, Provider, Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client as HttpClient;
use serde_json::Value;

pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
    api_key: Option<String>,
    api_base: String,
    http: HttpClient,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("api_base", &self.api_base)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
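    /// Create a provider for the official OpenAI API (`https://api.openai.com/v1`).
    ///
    /// The internal HTTP client, used for raw `/models` and `/embeddings`
    /// calls, is built with a 45-second timeout.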
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key.clone());
        let api_base = "https://api.openai.com/v1".to_string();
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
            api_key: Some(api_key),
            api_base,
            http: HttpClient::builder()
                .timeout(std::time::Duration::from_secs(45))
                .build()?,
        })
    }

    /// Create with a custom base URL (for OpenAI-compatible providers like Moonshot).
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        Self::with_base_url_optional_key(Some(api_key), base_url, provider_name)
    }

    /// Create with a custom base URL and an optional API key.
    ///
    /// Useful for private in-cluster OpenAI-compatible endpoints that rely on
    /// network policy instead of bearer authentication.
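    ///
    /// A minimal sketch of a keyless, in-cluster setup (the URL and provider
    /// name are illustrative placeholders, not shipped values):
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use codetether_agent::provider::openai::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::with_base_url_optional_key(
    ///     None, // no bearer token; the endpoint is guarded by network policy
    ///     "http://llm.internal:8080/v1".to_string(),
    ///     "in-cluster",
    /// )?;
    /// # Ok(())
    /// # }
    /// ```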
    pub fn with_base_url_optional_key(
        api_key: Option<String>,
        base_url: String,
        provider_name: &str,
    ) -> Result<Self> {
        let api_key = api_key.filter(|key| !key.trim().is_empty());
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.as_ref().map(|key| key.len()).unwrap_or(0),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key.clone().unwrap_or_default())
            .with_api_base(base_url.clone());
        let api_base = base_url.trim_end_matches('/').to_string();
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
            api_key,
            api_base,
            http: HttpClient::builder()
                .timeout(std::time::Duration::from_secs(45))
                .build()?,
        })
    }

    /// Return known models for specific OpenAI-compatible providers
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3.5-32b", "Qwen 3.5 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M2.5", "MiniMax M2.5"),
                ("MiniMax-M2.5-highspeed", "MiniMax M2.5 Highspeed"),
                ("MiniMax-M2.1", "MiniMax M2.1"),
                ("MiniMax-M2.1-highspeed", "MiniMax M2.1 Highspeed"),
                ("MiniMax-M2", "MiniMax M2"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("Qwen/Qwen3.5-35B-A3B", "Qwen 3.5 35B A3B"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

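    /// Fetch live model metadata from the provider's `/models` endpoint.
    ///
    /// Best-effort: any transport, status, or parse failure is logged at
    /// debug level and an empty list is returned, letting `list_models`
    /// fall back to `provider_default_models`.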
    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
        let url = format!("{}/models", self.api_base);
        let mut request = self.http.get(&url);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let response = match request.send().await {
            Ok(response) => response,
            Err(error) => {
                tracing::debug!(
                    provider = %self.provider_name,
                    url = %url,
                    error = %error,
                    "Failed to fetch OpenAI-compatible /models endpoint"
                );
                return Vec::new();
            }
        };

        let status = response.status();
        if !status.is_success() {
            tracing::debug!(
                provider = %self.provider_name,
                url = %url,
                status = %status,
                "OpenAI-compatible /models endpoint returned non-success"
            );
            return Vec::new();
        }

        let payload: Value = match crate::provider::body_cap::json_capped(
            response,
            crate::provider::body_cap::PROVIDER_METADATA_BODY_CAP,
        )
        .await
        {
            Ok(payload) => payload,
            Err(error) => {
                tracing::debug!(
                    provider = %self.provider_name,
                    url = %url,
                    error = %error,
                    "Failed to parse OpenAI-compatible /models response (or exceeded body cap)"
                );
                return Vec::new();
            }
        };

        let models = Self::parse_models_payload(&payload, &self.provider_name);
        if models.is_empty() {
            tracing::debug!(
                provider = %self.provider_name,
                url = %url,
                "OpenAI-compatible /models payload did not contain any model ids"
            );
        }
        models
    }

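    /// Extract model entries from a `/models` response body.
    ///
    /// Accepts the standard OpenAI list shape (`{"data": [{"id": ...}, ...]}`)
    /// as well as the looser `{"data": ["model-id", ...]}` used by some
    /// compatible servers; anything else yields an empty list.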
    fn parse_models_payload(payload: &Value, provider_name: &str) -> Vec<ModelInfo> {
        payload
            .get("data")
            .and_then(Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|entry| Self::model_info_from_api_entry(entry, provider_name))
            .collect()
    }

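    /// Build a `ModelInfo` from a single `/models` entry.
    ///
    /// Capability and limit fields are read from optional hints
    /// (`supports_*`, `input_modalities`, `limits`, `pricing`), with
    /// conservative defaults when a hint is absent.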
    fn model_info_from_api_entry(entry: &Value, provider_name: &str) -> Option<ModelInfo> {
        let id = match entry {
            Value::String(id) => id.trim(),
            Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim(),
            _ => return None,
        };
        if id.is_empty() {
            return None;
        }

        let name = entry
            .get("name")
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .unwrap_or(id);

        let supports_vision = entry
            .get("supports_vision")
            .and_then(Value::as_bool)
            .or_else(|| {
                entry
                    .get("input_modalities")
                    .and_then(Value::as_array)
                    .map(|modalities| {
                        modalities.iter().any(|modality| {
                            modality
                                .as_str()
                                .is_some_and(|modality| modality.eq_ignore_ascii_case("image"))
                        })
                    })
            })
            .unwrap_or(false);

        Some(ModelInfo {
            id: id.to_string(),
            name: name.to_string(),
            provider: provider_name.to_string(),
            context_window: value_to_usize(
                entry
                    .pointer("/limits/max_context_window_tokens")
                    .or_else(|| entry.get("context_window")),
            )
            .unwrap_or(128_000),
            max_output_tokens: value_to_usize(
                entry
                    .pointer("/limits/max_output_tokens")
                    .or_else(|| entry.get("max_output_tokens")),
            ),
            supports_vision,
            supports_tools: entry
                .get("supports_tools")
                .and_then(Value::as_bool)
                .unwrap_or(true),
            supports_streaming: entry
                .get("supports_streaming")
                .and_then(Value::as_bool)
                .unwrap_or(true),
            input_cost_per_million: entry
                .pointer("/pricing/input_cost_per_million")
                .and_then(Value::as_f64),
            output_cost_per_million: entry
                .pointer("/pricing/output_cost_per_million")
                .and_then(Value::as_f64),
        })
    }

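    /// Convert provider-agnostic messages into `async_openai` chat messages.
    ///
    /// Text parts are joined with newlines; assistant tool calls and tool
    /// results map to their OpenAI counterparts, and any other non-text
    /// parts are currently dropped.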
    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                                ..
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

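    /// Convert tool definitions into OpenAI function-calling tool specs.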
    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }

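    /// Detect MiniMax's "invalid chat setting" rejection (error code 2013),
    /// which shows up in several textual shapes depending on the gateway.
    /// `complete` uses this to retry once with conservative defaults.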
    fn is_minimax_chat_setting_error(error: &str) -> bool {
        let normalized = error.to_ascii_lowercase();
        normalized.contains("invalid chat setting")
            || normalized.contains("(2013)")
            || normalized.contains("code: 2013")
            || normalized.contains("\"2013\"")
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // For non-OpenAI providers, try live discovery via the /models
        // endpoint first, then fall back to the hardcoded per-provider
        // defaults. Note: async-openai 0.32 does not expose a stable models
        // list API across all OpenAI-compatible endpoints, so discovery goes
        // through reqwest directly.
        if self.provider_name != "openai" {
            let discovered = self.discover_models_from_api().await;
            if !discovered.is_empty() {
                return Ok(discovered);
            }
            return Ok(self.provider_default_models());
        }

        // OpenAI default models
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages.clone());

        // Pass tools to the API if provided
        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
        if let Some(max) = request.max_tokens {
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let primary_request = req_builder.build()?;
        let response = match self.client.chat().create(primary_request).await {
            Ok(response) => response,
            Err(err)
                if self.provider_name == "minimax"
                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
            {
                tracing::warn!(
                    provider = "minimax",
                    error = %err,
                    "MiniMax rejected chat settings; retrying with conservative defaults"
                );

                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
                fallback_builder.model(&request.model).messages(messages);
                self.client.chat().create(fallback_builder.build()?).await?
            }
            Err(err) => return Err(err.into()),
        };

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                        thought_signature: None,
                    });
                }
            }
        }

        // Determine finish reason based on response
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

        Ok(stream
            .flat_map(|result| {
                let chunks: Vec<StreamChunk> = match result {
                    Ok(response) => {
                        let mut out = Vec::new();
                        if let Some(choice) = response.choices.first() {
                            // Text content delta
                            if let Some(content) = &choice.delta.content {
                                if !content.is_empty() {
                                    out.push(StreamChunk::Text(content.clone()));
                                }
                            }
                            // Tool call deltas
                            if let Some(tool_calls) = &choice.delta.tool_calls {
                                for tc in tool_calls {
                                    if let Some(func) = &tc.function {
                                        // First chunk for a tool call has id and name
                                        if let Some(id) = &tc.id {
                                            out.push(StreamChunk::ToolCallStart {
                                                id: id.clone(),
                                                name: func.name.clone().unwrap_or_default(),
                                            });
                                        }
                                        // Argument deltas
                                        if let Some(args) = &func.arguments {
                                            if !args.is_empty() {
                                                // Derive the id from tc.id or use index as fallback
                                                let id = tc.id.clone().unwrap_or_else(|| {
                                                    format!("tool_{}", tc.index)
                                                });
                                                out.push(StreamChunk::ToolCallDelta {
                                                    id,
                                                    arguments_delta: args.clone(),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        out
                    }
                    Err(e) => vec![StreamChunk::Error(e.to_string())],
                };
                futures::stream::iter(chunks)
            })
            .boxed())
    }

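    /// Issue an `/embeddings` request directly over HTTP.
    ///
    /// Uses the raw `reqwest` client (with bearer auth only when a key is
    /// configured) rather than `async_openai`, and re-sorts response rows by
    /// `index` so embeddings come back in input order.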
    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        if request.inputs.is_empty() {
            return Ok(EmbeddingResponse {
                embeddings: Vec::new(),
                usage: Usage::default(),
            });
        }

        let url = format!("{}/embeddings", self.api_base.trim_end_matches('/'));
        let body = OpenAIEmbeddingRequest {
            model: request.model,
            input: request.inputs,
        };

        let mut request_builder = self.http.post(url);
        if let Some(api_key) = self.api_key.as_deref().filter(|key| !key.is_empty()) {
            request_builder = request_builder.bearer_auth(api_key);
        }
        let response = request_builder.json(&body).send().await?;

        let status = response.status();
        let text = response.text().await?;
        if !status.is_success() {
            anyhow::bail!(
                "embedding request failed ({status}): {}",
                safe_char_prefix(&text, 500)
            );
        }

        let mut payload: OpenAIEmbeddingResponse = serde_json::from_str(&text)?;
        payload.data.sort_by_key(|item| item.index);
        let embeddings: Vec<Vec<f32>> = payload
            .data
            .into_iter()
            .map(|item| item.embedding)
            .collect();

        if embeddings.len() != body.input.len() {
            anyhow::bail!(
                "embedding response length mismatch: expected {}, got {}",
                body.input.len(),
                embeddings.len()
            );
        }

        let prompt_tokens = payload.usage.prompt_tokens.unwrap_or(0) as usize;
        let total_tokens = payload
            .usage
            .total_tokens
            .unwrap_or(payload.usage.prompt_tokens.unwrap_or(0))
            as usize;

        Ok(EmbeddingResponse {
            embeddings,
            usage: Usage {
                prompt_tokens,
                completion_tokens: 0,
                total_tokens,
                ..Default::default()
            },
        })
    }
}

fn value_to_usize(value: Option<&Value>) -> Option<usize> {
    value
        .and_then(Value::as_u64)
        .and_then(|value| usize::try_from(value).ok())
}

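/// Take a prefix by character count rather than byte offset, so truncating an
/// error body never splits a multi-byte UTF-8 code point.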
fn safe_char_prefix(input: &str, max_chars: usize) -> String {
    input.chars().take(max_chars).collect()
}

#[derive(Debug, serde::Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Debug, serde::Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    #[serde(default)]
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, serde::Deserialize)]
struct OpenAIEmbeddingData {
    index: usize,
    embedding: Vec<f32>,
}

#[derive(Debug, Default, serde::Deserialize)]
struct OpenAIEmbeddingUsage {
    #[serde(default)]
    prompt_tokens: Option<u32>,
    #[serde(default)]
    total_tokens: Option<u32>,
}

#[cfg(test)]
mod tests {
    use super::{OpenAIProvider, Provider};
    use serde_json::json;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
    }

    #[test]
    fn supports_openai_compatible_provider_without_api_key() {
        let provider = OpenAIProvider::with_base_url_optional_key(
            None,
            "http://localhost:8080/v1".to_string(),
            "huggingface",
        )
        .expect("provider should initialize without API key");

        assert_eq!(provider.name(), "huggingface");
    }

    #[test]
    fn parses_openai_compatible_models_payload() {
        let payload = json!({
            "object": "list",
            "data": [
                {
                    "id": "GLM-5-Turbo",
                    "name": "GLM-5-Turbo",
                    "limits": {
                        "max_context_window_tokens": 200000,
                        "max_output_tokens": 16384
                    },
                    "input_modalities": ["text"]
                }
            ]
        });

        let models = OpenAIProvider::parse_models_payload(&payload, "custom-openapi");

        assert_eq!(models.len(), 1);
        assert_eq!(models[0].id, "GLM-5-Turbo");
        assert_eq!(models[0].name, "GLM-5-Turbo");
        assert_eq!(models[0].provider, "custom-openapi");
        assert_eq!(models[0].context_window, 200_000);
        assert_eq!(models[0].max_output_tokens, Some(16_384));
    }

    #[test]
    fn parses_string_only_models_payload() {
        let payload = json!({
            "data": ["glm-5", "glm-5-turbo"]
        });

        let models = OpenAIProvider::parse_models_payload(&payload, "custom-openapi");

        assert_eq!(models.len(), 2);
        assert_eq!(models[0].id, "glm-5");
        assert_eq!(models[1].id, "glm-5-turbo");
    }
}
817}