codetether_agent/provider/openai.rs

//! OpenAI provider implementation

use super::{
    CompletionRequest, CompletionResponse, ContentPart, EmbeddingRequest, EmbeddingResponse,
    FinishReason, Message, ModelInfo, Provider, Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::Result;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FinishReason as OpenAIFinishReason, FunctionCall,
        FunctionObjectArgs,
    },
};
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::Client as HttpClient;
use serde_json::Value;

pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    provider_name: String,
    api_key: Option<String>,
    api_base: String,
    http: HttpClient,
}

impl std::fmt::Debug for OpenAIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIProvider")
            .field("provider_name", &self.provider_name)
            .field("api_base", &self.api_base)
            .field("client", &"<async_openai::Client>")
            .finish()
    }
}

impl OpenAIProvider {
    pub fn new(api_key: String) -> Result<Self> {
        tracing::debug!(
            provider = "openai",
            api_key_len = api_key.len(),
            "Creating OpenAI provider"
        );
        let config = OpenAIConfig::new().with_api_key(api_key.clone());
        let api_base = "https://api.openai.com/v1".to_string();
        Ok(Self {
            client: Client::with_config(config),
            provider_name: "openai".to_string(),
            api_key: Some(api_key),
            api_base,
            http: HttpClient::builder()
                .timeout(std::time::Duration::from_secs(45))
                .build()?,
        })
    }

    /// Create with custom base URL (for OpenAI-compatible providers like Moonshot)
    pub fn with_base_url(api_key: String, base_url: String, provider_name: &str) -> Result<Self> {
        Self::with_base_url_optional_key(Some(api_key), base_url, provider_name)
    }

    /// Create with custom base URL and optional API key.
    ///
    /// Useful for private in-cluster OpenAI-compatible endpoints that rely on
    /// network policy instead of bearer authentication.
    pub fn with_base_url_optional_key(
        api_key: Option<String>,
        base_url: String,
        provider_name: &str,
    ) -> Result<Self> {
        let api_key = api_key.filter(|key| !key.trim().is_empty());
        tracing::debug!(
            provider = provider_name,
            base_url = %base_url,
            api_key_len = api_key.as_ref().map(|key| key.len()).unwrap_or(0),
            "Creating OpenAI-compatible provider"
        );
        let config = OpenAIConfig::new()
            .with_api_key(api_key.clone().unwrap_or_default())
            .with_api_base(base_url.clone());
        let api_base = base_url.trim_end_matches('/').to_string();
        Ok(Self {
            client: Client::with_config(config),
            provider_name: provider_name.to_string(),
            api_key,
            api_base,
            http: HttpClient::builder()
                .timeout(std::time::Duration::from_secs(45))
                .build()?,
        })
    }
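
    // Illustrative construction (a hedged sketch; the key and URL values below
    // are placeholders, not values used elsewhere in this crate):
    //
    //     let openai = OpenAIProvider::new("sk-...".to_string())?;
    //     let moonshot = OpenAIProvider::with_base_url(
    //         "sk-...".to_string(),
    //         "https://api.moonshot.cn/v1".to_string(),
    //         "moonshot",
    //     )?;
    //     let local = OpenAIProvider::with_base_url_optional_key(
    //         None,
    //         "http://localhost:8080/v1".to_string(),
    //         "local",
    //     )?;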

    /// Return known models for specific OpenAI-compatible providers
    fn provider_default_models(&self) -> Vec<ModelInfo> {
        let models: Vec<(&str, &str)> = match self.provider_name.as_str() {
            "cerebras" => vec![
                ("llama3.1-8b", "Llama 3.1 8B"),
                ("llama-3.3-70b", "Llama 3.3 70B"),
                ("qwen-3.5-32b", "Qwen 3.5 32B"),
                ("gpt-oss-120b", "GPT-OSS 120B"),
            ],
            "minimax" => vec![
                ("MiniMax-M2.5", "MiniMax M2.5"),
                ("MiniMax-M2.5-highspeed", "MiniMax M2.5 Highspeed"),
                ("MiniMax-M2.1", "MiniMax M2.1"),
                ("MiniMax-M2.1-highspeed", "MiniMax M2.1 Highspeed"),
                ("MiniMax-M2", "MiniMax M2"),
            ],
            "zhipuai" => vec![],
            "novita" => vec![
                ("Qwen/Qwen3.5-35B-A3B", "Qwen 3.5 35B A3B"),
                ("deepseek/deepseek-v3-0324", "DeepSeek V3"),
                ("meta-llama/llama-3.1-70b-instruct", "Llama 3.1 70B"),
                ("meta-llama/llama-3.1-8b-instruct", "Llama 3.1 8B"),
            ],
            _ => vec![],
        };

        models
            .into_iter()
            .map(|(id, name)| ModelInfo {
                id: id.to_string(),
                name: name.to_string(),
                provider: self.provider_name.clone(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: None,
                output_cost_per_million: None,
            })
            .collect()
    }

    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
        let url = format!("{}/models", self.api_base);
        let mut request = self.http.get(&url);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let response = match request.send().await {
            Ok(response) => response,
            Err(error) => {
                tracing::debug!(
                    provider = %self.provider_name,
                    url = %url,
                    error = %error,
                    "Failed to fetch OpenAI-compatible /models endpoint"
                );
                return Vec::new();
            }
        };

        let status = response.status();
        if !status.is_success() {
            tracing::debug!(
                provider = %self.provider_name,
                url = %url,
                status = %status,
                "OpenAI-compatible /models endpoint returned non-success"
            );
            return Vec::new();
        }

        let payload: Value = match response.json().await {
            Ok(payload) => payload,
            Err(error) => {
                tracing::debug!(
                    provider = %self.provider_name,
                    url = %url,
                    error = %error,
                    "Failed to parse OpenAI-compatible /models response"
                );
                return Vec::new();
            }
        };

        let models = Self::parse_models_payload(&payload, &self.provider_name);
        if models.is_empty() {
            tracing::debug!(
                provider = %self.provider_name,
                url = %url,
                "OpenAI-compatible /models payload did not contain any model ids"
            );
        }
        models
    }

    fn parse_models_payload(payload: &Value, provider_name: &str) -> Vec<ModelInfo> {
        payload
            .get("data")
            .and_then(Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|entry| Self::model_info_from_api_entry(entry, provider_name))
            .collect()
    }

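    // The /models payload is accepted in two shapes (a sketch of the inputs
    // `model_info_from_api_entry` handles; see the tests at the bottom of
    // this file):
    //
    //     {"data": [{"id": "m", "name": "M", "limits": {...}, "pricing": {...}}]}
    //     {"data": ["model-a", "model-b"]}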
    fn model_info_from_api_entry(entry: &Value, provider_name: &str) -> Option<ModelInfo> {
        let id = match entry {
            Value::String(id) => id.trim(),
            Value::Object(_) => entry.get("id").and_then(Value::as_str)?.trim(),
            _ => return None,
        };
        if id.is_empty() {
            return None;
        }

        let name = entry
            .get("name")
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|name| !name.is_empty())
            .unwrap_or(id);

        let supports_vision = entry
            .get("supports_vision")
            .and_then(Value::as_bool)
            .or_else(|| {
                entry
                    .get("input_modalities")
                    .and_then(Value::as_array)
                    .map(|modalities| {
                        modalities.iter().any(|modality| {
                            modality
                                .as_str()
                                .is_some_and(|modality| modality.eq_ignore_ascii_case("image"))
                        })
                    })
            })
            .unwrap_or(false);

        Some(ModelInfo {
            id: id.to_string(),
            name: name.to_string(),
            provider: provider_name.to_string(),
            context_window: value_to_usize(
                entry
                    .pointer("/limits/max_context_window_tokens")
                    .or_else(|| entry.get("context_window")),
            )
            .unwrap_or(128_000),
            max_output_tokens: value_to_usize(
                entry
                    .pointer("/limits/max_output_tokens")
                    .or_else(|| entry.get("max_output_tokens")),
            ),
            supports_vision,
            supports_tools: entry
                .get("supports_tools")
                .and_then(Value::as_bool)
                .unwrap_or(true),
            supports_streaming: entry
                .get("supports_streaming")
                .and_then(Value::as_bool)
                .unwrap_or(true),
            input_cost_per_million: entry
                .pointer("/pricing/input_cost_per_million")
                .and_then(Value::as_f64),
            output_cost_per_million: entry
                .pointer("/pricing/output_cost_per_million")
                .and_then(Value::as_f64),
        })
    }

    fn convert_messages(messages: &[Message]) -> Result<Vec<ChatCompletionRequestMessage>> {
        let mut result = Vec::new();

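        // Flatten text parts into one newline-joined string; non-text parts
        // (e.g. images) are intentionally ignored on this text-only path,
        // while tool calls/results are handled per-role below.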
        for msg in messages {
            let content = msg
                .content
                .iter()
                .filter_map(|p| match p {
                    ContentPart::Text { text } => Some(text.clone()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("\n");

            match msg.role {
                Role::System => {
                    result.push(
                        ChatCompletionRequestSystemMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::User => {
                    result.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(content)
                            .build()?
                            .into(),
                    );
                }
                Role::Assistant => {
                    let tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                                ..
                            } => Some(ChatCompletionMessageToolCalls::Function(
                                ChatCompletionMessageToolCall {
                                    id: id.clone(),
                                    function: FunctionCall {
                                        name: name.clone(),
                                        arguments: arguments.clone(),
                                    },
                                },
                            )),
                            _ => None,
                        })
                        .collect();

                    let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                    if !content.is_empty() {
                        builder.content(content);
                    }
                    if !tool_calls.is_empty() {
                        builder.tool_calls(tool_calls);
                    }
                    result.push(builder.build()?.into());
                }
                Role::Tool => {
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            result.push(
                                ChatCompletionRequestToolMessageArgs::default()
                                    .tool_call_id(tool_call_id.clone())
                                    .content(content.clone())
                                    .build()?
                                    .into(),
                            );
                        }
                    }
                }
            }
        }

        Ok(result)
    }

    fn convert_tools(tools: &[ToolDefinition]) -> Result<Vec<ChatCompletionTools>> {
        let mut result = Vec::new();
        for tool in tools {
            result.push(ChatCompletionTools::Function(ChatCompletionTool {
                function: FunctionObjectArgs::default()
                    .name(&tool.name)
                    .description(&tool.description)
                    .parameters(tool.parameters.clone())
                    .build()?,
            }));
        }
        Ok(result)
    }

    fn is_minimax_chat_setting_error(error: &str) -> bool {
        let normalized = error.to_ascii_lowercase();
        normalized.contains("invalid chat setting")
            || normalized.contains("(2013)")
            || normalized.contains("code: 2013")
            || normalized.contains("\"2013\"")
    }
}

#[async_trait]
impl Provider for OpenAIProvider {
    fn name(&self) -> &str {
        &self.provider_name
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        // For non-OpenAI providers, first try dynamic discovery via the
        // /models endpoint, then fall back to provider-specific defaults.
        // Note: async-openai 0.32 does not expose a stable models-list API
        // across all OpenAI-compatible endpoints, so we query /models directly.
        if self.provider_name != "openai" {
            let discovered = self.discover_models_from_api().await;
            if !discovered.is_empty() {
                return Ok(discovered);
            }
            return Ok(self.provider_default_models());
        }

        // OpenAI default models
        Ok(vec![
            ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.5),
                output_cost_per_million: Some(10.0),
            },
            ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                provider: "openai".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(16_384),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.15),
                output_cost_per_million: Some(0.6),
            },
            ModelInfo {
                id: "o1".to_string(),
                name: "o1".to_string(),
                provider: "openai".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(100_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(60.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder.model(&request.model).messages(messages.clone());

        // Pass tools to the API if provided
        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(top_p) = request.top_p {
            req_builder.top_p(top_p);
        }
        if let Some(max) = request.max_tokens {
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let primary_request = req_builder.build()?;
        let response = match self.client.chat().create(primary_request).await {
            Ok(response) => response,
            Err(err)
                if self.provider_name == "minimax"
                    && Self::is_minimax_chat_setting_error(&err.to_string()) =>
            {
                tracing::warn!(
                    provider = "minimax",
                    error = %err,
                    "MiniMax rejected chat settings; retrying with conservative defaults"
                );

                let mut fallback_builder = CreateChatCompletionRequestArgs::default();
                fallback_builder.model(&request.model).messages(messages);
                self.client.chat().create(fallback_builder.build()?).await?
            }
            Err(err) => return Err(err.into()),
        };

        let choice = response
            .choices
            .first()
            .ok_or_else(|| anyhow::anyhow!("No choices"))?;

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        if let Some(text) = &choice.message.content {
            content.push(ContentPart::Text { text: text.clone() });
        }
        if let Some(tool_calls) = &choice.message.tool_calls {
            has_tool_calls = !tool_calls.is_empty();
            for tc in tool_calls {
                if let ChatCompletionMessageToolCalls::Function(func_call) = tc {
                    content.push(ContentPart::ToolCall {
                        id: func_call.id.clone(),
                        name: func_call.function.name.clone(),
                        arguments: func_call.function.arguments.clone(),
                        thought_signature: None,
                    });
                }
            }
        }

        // Determine finish reason based on response
        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match choice.finish_reason {
                Some(OpenAIFinishReason::Stop) => FinishReason::Stop,
                Some(OpenAIFinishReason::Length) => FinishReason::Length,
                Some(OpenAIFinishReason::ToolCalls) => FinishReason::ToolCalls,
                Some(OpenAIFinishReason::ContentFilter) => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.prompt_tokens as usize)
                    .unwrap_or(0),
                completion_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.completion_tokens as usize)
                    .unwrap_or(0),
                total_tokens: response
                    .usage
                    .as_ref()
                    .map(|u| u.total_tokens as usize)
                    .unwrap_or(0),
                ..Default::default()
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        tracing::debug!(
            provider = %self.provider_name,
            model = %request.model,
            message_count = request.messages.len(),
            "Starting streaming completion request"
        );

        let messages = Self::convert_messages(&request.messages)?;
        let tools = Self::convert_tools(&request.tools)?;

        let mut req_builder = CreateChatCompletionRequestArgs::default();
        req_builder
            .model(&request.model)
            .messages(messages)
            .stream(true);

        if !tools.is_empty() {
            req_builder.tools(tools);
        }
        if let Some(temp) = request.temperature {
            req_builder.temperature(temp);
        }
        if let Some(max) = request.max_tokens {
            if self.provider_name == "openai" {
                req_builder.max_completion_tokens(max as u32);
            } else {
                req_builder.max_tokens(max as u32);
            }
        }

        let stream = self
            .client
            .chat()
            .create_stream(req_builder.build()?)
            .await?;

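        // OpenAI-style tool-call streaming: the first delta for a call carries
        // its id and function name; later deltas carry argument fragments keyed
        // by index. Consumers reassemble the full JSON argument string by
        // concatenating ToolCallDelta arguments per id.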
        Ok(stream
            .flat_map(|result| {
                let chunks: Vec<StreamChunk> = match result {
                    Ok(response) => {
                        let mut out = Vec::new();
                        if let Some(choice) = response.choices.first() {
                            // Text content delta
                            if let Some(content) = &choice.delta.content {
                                if !content.is_empty() {
                                    out.push(StreamChunk::Text(content.clone()));
                                }
                            }
                            // Tool call deltas
                            if let Some(tool_calls) = &choice.delta.tool_calls {
                                for tc in tool_calls {
                                    if let Some(func) = &tc.function {
                                        // First chunk for a tool call has id and name
                                        if let Some(id) = &tc.id {
                                            out.push(StreamChunk::ToolCallStart {
                                                id: id.clone(),
                                                name: func.name.clone().unwrap_or_default(),
                                            });
                                        }
                                        // Argument deltas
                                        if let Some(args) = &func.arguments {
                                            if !args.is_empty() {
                                                // Derive the id from tc.id or use index as fallback
                                                let id = tc.id.clone().unwrap_or_else(|| {
                                                    format!("tool_{}", tc.index)
                                                });
                                                out.push(StreamChunk::ToolCallDelta {
                                                    id,
                                                    arguments_delta: args.clone(),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        out
                    }
                    Err(e) => vec![StreamChunk::Error(e.to_string())],
                };
                futures::stream::iter(chunks)
            })
            .boxed())
    }

    async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
        if request.inputs.is_empty() {
            return Ok(EmbeddingResponse {
                embeddings: Vec::new(),
                usage: Usage::default(),
            });
        }

        let url = format!("{}/embeddings", self.api_base.trim_end_matches('/'));
        let body = OpenAIEmbeddingRequest {
            model: request.model,
            input: request.inputs,
        };

        let mut request_builder = self.http.post(url);
        if let Some(api_key) = self.api_key.as_deref().filter(|key| !key.is_empty()) {
            request_builder = request_builder.bearer_auth(api_key);
        }
        let response = request_builder.json(&body).send().await?;

        let status = response.status();
        let text = response.text().await?;
        if !status.is_success() {
            anyhow::bail!(
                "embedding request failed ({status}): {}",
                safe_char_prefix(&text, 500)
            );
        }

        let mut payload: OpenAIEmbeddingResponse = serde_json::from_str(&text)?;
        payload.data.sort_by_key(|item| item.index);
        let embeddings: Vec<Vec<f32>> = payload
            .data
            .into_iter()
            .map(|item| item.embedding)
            .collect();

        if embeddings.len() != body.input.len() {
            anyhow::bail!(
                "embedding response length mismatch: expected {}, got {}",
                body.input.len(),
                embeddings.len()
            );
        }

        let prompt_tokens = payload.usage.prompt_tokens.unwrap_or(0) as usize;
        let total_tokens = payload
            .usage
            .total_tokens
            .unwrap_or(payload.usage.prompt_tokens.unwrap_or(0))
            as usize;

        Ok(EmbeddingResponse {
            embeddings,
            usage: Usage {
                prompt_tokens,
                completion_tokens: 0,
                total_tokens,
                ..Default::default()
            },
        })
    }
}

fn value_to_usize(value: Option<&Value>) -> Option<usize> {
    value
        .and_then(Value::as_u64)
        .and_then(|value| usize::try_from(value).ok())
}

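/// Truncate by characters rather than bytes so the prefix never slices
/// through a multi-byte UTF-8 sequence (e.g. in non-ASCII error bodies).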
fn safe_char_prefix(input: &str, max_chars: usize) -> String {
    input.chars().take(max_chars).collect()
}

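// Minimal wire types for the OpenAI-compatible /embeddings endpoint, roughly:
//   request:  {"model": "...", "input": ["..."]}
//   response: {"data": [{"index": 0, "embedding": [..]}],
//              "usage": {"prompt_tokens": n, "total_tokens": n}}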
#[derive(Debug, serde::Serialize)]
struct OpenAIEmbeddingRequest {
    model: String,
    input: Vec<String>,
}

#[derive(Debug, serde::Deserialize)]
struct OpenAIEmbeddingResponse {
    data: Vec<OpenAIEmbeddingData>,
    #[serde(default)]
    usage: OpenAIEmbeddingUsage,
}

#[derive(Debug, serde::Deserialize)]
struct OpenAIEmbeddingData {
    index: usize,
    embedding: Vec<f32>,
}

#[derive(Debug, Default, serde::Deserialize)]
struct OpenAIEmbeddingUsage {
    #[serde(default)]
    prompt_tokens: Option<u32>,
    #[serde(default)]
    total_tokens: Option<u32>,
}

#[cfg(test)]
mod tests {
    use super::{OpenAIProvider, Provider};
    use serde_json::json;

    #[test]
    fn detects_minimax_chat_setting_error_variants() {
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "bad_request_error: invalid params, invalid chat setting (2013)"
        ));
        assert!(OpenAIProvider::is_minimax_chat_setting_error(
            "code: 2013 invalid params"
        ));
        assert!(!OpenAIProvider::is_minimax_chat_setting_error(
            "rate limit exceeded"
        ));
    }

    #[test]
    fn supports_openai_compatible_provider_without_api_key() {
        let provider = OpenAIProvider::with_base_url_optional_key(
            None,
            "http://localhost:8080/v1".to_string(),
            "huggingface",
        )
        .expect("provider should initialize without API key");

        assert_eq!(provider.name(), "huggingface");
    }

    #[test]
    fn parses_openai_compatible_models_payload() {
        let payload = json!({
            "object": "list",
            "data": [
                {
                    "id": "GLM-5-Turbo",
                    "name": "GLM-5-Turbo",
                    "limits": {
                        "max_context_window_tokens": 200000,
                        "max_output_tokens": 16384
                    },
                    "input_modalities": ["text"]
                }
            ]
        });

        let models = OpenAIProvider::parse_models_payload(&payload, "custom-openapi");

        assert_eq!(models.len(), 1);
        assert_eq!(models[0].id, "GLM-5-Turbo");
        assert_eq!(models[0].name, "GLM-5-Turbo");
        assert_eq!(models[0].provider, "custom-openapi");
        assert_eq!(models[0].context_window, 200_000);
        assert_eq!(models[0].max_output_tokens, Some(16_384));
    }

    #[test]
    fn parses_string_only_models_payload() {
        let payload = json!({
            "data": ["glm-5", "glm-5-turbo"]
        });

        let models = OpenAIProvider::parse_models_payload(&payload, "custom-openapi");

        assert_eq!(models.len(), 2);
        assert_eq!(models[0].id, "glm-5");
        assert_eq!(models[1].id, "glm-5-turbo");
    }
}