Skip to main content

codetether_agent/provider/
copilot.rs

1//! GitHub Copilot provider implementation using raw HTTP.
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Base URL that [`CopilotProvider::new`] passes to `with_base_url`.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name given to instances created through [`CopilotProvider::new`].
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name given to instances created through [`CopilotProvider::enterprise`].
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
/// Provider that talks to a Copilot-compatible HTTP API through `reqwest`.
pub struct CopilotProvider {
    /// HTTP client reused for every request made by this provider.
    client: Client,
    /// Token sent as `Authorization: Bearer <token>` on each request.
    token: String,
    /// API root without a trailing slash; endpoint paths are appended to it.
    base_url: String,
    /// Name returned by `Provider::name` and stamped on discovered models.
    provider_name: String,
}
23
24impl CopilotProvider {
25    pub fn new(token: String) -> Result<Self> {
26        Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27    }
28
29    pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30        let base_url = enterprise_base_url(&enterprise_url);
31        Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32    }
33
34    pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35        Ok(Self {
36            client: Client::new(),
37            token,
38            base_url: base_url.trim_end_matches('/').to_string(),
39            provider_name: provider_name.to_string(),
40        })
41    }
42
43    fn user_agent() -> String {
44        format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45    }
46
47    fn convert_messages(messages: &[Message]) -> Vec<Value> {
48        messages
49            .iter()
50            .map(|msg| {
51                let role = match msg.role {
52                    Role::System => "system",
53                    Role::User => "user",
54                    Role::Assistant => "assistant",
55                    Role::Tool => "tool",
56                };
57
58                match msg.role {
59                    Role::Tool => {
60                        if let Some(ContentPart::ToolResult {
61                            tool_call_id,
62                            content,
63                        }) = msg.content.first()
64                        {
65                            json!({
66                                "role": "tool",
67                                "tool_call_id": tool_call_id,
68                                "content": content
69                            })
70                        } else {
71                            json!({ "role": role, "content": "" })
72                        }
73                    }
74                    Role::Assistant => {
75                        let text: String = msg
76                            .content
77                            .iter()
78                            .filter_map(|p| match p {
79                                ContentPart::Text { text } => Some(text.clone()),
80                                _ => None,
81                            })
82                            .collect::<Vec<_>>()
83                            .join("");
84
85                        let tool_calls: Vec<Value> = msg
86                            .content
87                            .iter()
88                            .filter_map(|p| match p {
89                                ContentPart::ToolCall {
90                                    id,
91                                    name,
92                                    arguments,
93                                    ..
94                                } => Some(json!({
95                                    "id": id,
96                                    "type": "function",
97                                    "function": {
98                                        "name": name,
99                                        "arguments": arguments
100                                    }
101                                })),
102                                _ => None,
103                            })
104                            .collect();
105
106                        if tool_calls.is_empty() {
107                            json!({ "role": "assistant", "content": text })
108                        } else {
109                            json!({
110                                "role": "assistant",
111                                "content": if text.is_empty() { "".to_string() } else { text },
112                                "tool_calls": tool_calls
113                            })
114                        }
115                    }
116                    _ => {
117                        let text: String = msg
118                            .content
119                            .iter()
120                            .filter_map(|p| match p {
121                                ContentPart::Text { text } => Some(text.clone()),
122                                _ => None,
123                            })
124                            .collect::<Vec<_>>()
125                            .join("\n");
126                        json!({ "role": role, "content": text })
127                    }
128                }
129            })
130            .collect()
131    }
132
133    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
134        tools
135            .iter()
136            .map(|t| {
137                json!({
138                    "type": "function",
139                    "function": {
140                        "name": t.name,
141                        "description": t.description,
142                        "parameters": t.parameters
143                    }
144                })
145            })
146            .collect()
147    }
148
149    fn is_agent_initiated(messages: &[Message]) -> bool {
150        messages
151            .iter()
152            .rev()
153            .find(|msg| msg.role != Role::System)
154            .map(|msg| msg.role != Role::User)
155            .unwrap_or(false)
156    }
157
158    fn has_vision_input(messages: &[Message]) -> bool {
159        messages.iter().any(|msg| {
160            msg.content
161                .iter()
162                .any(|part| matches!(part, ContentPart::Image { .. }))
163        })
164    }
165
166    /// Discover models dynamically from the Copilot /models API endpoint.
167    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
168        let response = match self
169            .client
170            .get(format!("{}/models", self.base_url))
171            .header("Authorization", format!("Bearer {}", self.token))
172            .header("User-Agent", Self::user_agent())
173            .send()
174            .await
175        {
176            Ok(r) => r,
177            Err(e) => {
178                tracing::warn!(provider = %self.provider_name, error = %e, "Failed to fetch Copilot models endpoint");
179                return Vec::new();
180            }
181        };
182
183        let status = response.status();
184        if !status.is_success() {
185            let body = response.text().await.unwrap_or_default();
186            tracing::warn!(
187                provider = %self.provider_name,
188                status = %status,
189                body = %body.chars().take(200).collect::<String>(),
190                "Copilot /models endpoint returned non-success"
191            );
192            return Vec::new();
193        }
194
195        let parsed: CopilotModelsResponse = match response.json().await {
196            Ok(p) => p,
197            Err(e) => {
198                tracing::warn!(provider = %self.provider_name, error = %e, "Failed to parse Copilot models response");
199                return Vec::new();
200            }
201        };
202
203        let models: Vec<ModelInfo> = parsed
204            .data
205            .into_iter()
206            .filter(|model| {
207                // Skip models explicitly disabled in the picker
208                if model.model_picker_enabled == Some(false) {
209                    return false;
210                }
211                // Skip models with a disabled policy state
212                if let Some(ref policy) = model.policy {
213                    if policy.state.as_deref() == Some("disabled") {
214                        return false;
215                    }
216                }
217                true
218            })
219            .map(|model| {
220                let caps = model.capabilities.as_ref();
221                let limits = caps.and_then(|c| c.limits.as_ref());
222                let supports = caps.and_then(|c| c.supports.as_ref());
223
224                ModelInfo {
225                    id: model.id.clone(),
226                    name: model.name.unwrap_or_else(|| model.id.clone()),
227                    provider: self.provider_name.clone(),
228                    context_window: limits
229                        .and_then(|l| l.max_context_window_tokens)
230                        .unwrap_or(128_000),
231                    max_output_tokens: limits.and_then(|l| l.max_output_tokens).or(Some(16_384)),
232                    supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
233                    supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
234                    supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
235                    input_cost_per_million: None,
236                    output_cost_per_million: None,
237                }
238            })
239            .collect();
240
241        tracing::info!(
242            provider = %self.provider_name,
243            count = models.len(),
244            "Discovered models from Copilot API"
245        );
246        models
247    }
248
249    /// Enrich models with pricing metadata from known premium request multipliers.
250    ///
251    /// Source: https://docs.github.com/en/copilot/concepts/billing/copilot-requests
252    /// Cost model: Premium requests at $0.04/request overflow rate.
253    /// We convert multiplier to approximate $/M tokens using ~4K tokens/request avg.
254    /// Formula: multiplier * $0.04 / 4K tokens * 1M = multiplier * $10/M tokens.
255    fn enrich_with_pricing(&self, models: &mut [ModelInfo]) {
256        // (display_name, premium_multiplier)
257        let pricing: std::collections::HashMap<&str, (&str, f64)> = [
258            ("claude-opus-4.5", ("Claude Opus 4.5", 3.0)),
259            ("claude-opus-4.6", ("Claude Opus 4.6", 3.0)),
260            ("claude-opus-41", ("Claude Opus 4.1", 10.0)),
261            ("claude-sonnet-4-6", ("Claude Sonnet 4.6", 1.0)),
262            ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 1.0)),
263            ("claude-sonnet-4", ("Claude Sonnet 4", 1.0)),
264            ("claude-haiku-4.5", ("Claude Haiku 4.5", 0.33)),
265            ("gpt-5.3-codex", ("GPT-5.3-Codex", 1.0)),
266            ("gpt-5.2", ("GPT-5.2", 1.0)),
267            ("gpt-5.2-codex", ("GPT-5.2-Codex", 1.0)),
268            ("gpt-5.1", ("GPT-5.1", 1.0)),
269            ("gpt-5.1-codex", ("GPT-5.1-Codex", 1.0)),
270            ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 0.33)),
271            ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 1.0)),
272            ("gpt-5", ("GPT-5", 1.0)),
273            ("gpt-5-mini", ("GPT-5 mini", 0.0)),
274            ("gpt-5-codex", ("GPT-5-Codex", 1.0)),
275            ("gpt-4.1", ("GPT-4.1", 0.0)),
276            ("gpt-4o", ("GPT-4o", 0.0)),
277            ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1.0)),
278            ("gemini-3.1-pro-preview", ("Gemini 3.1 Pro Preview", 1.0)),
279            (
280                "gemini-3.1-pro-preview-customtools",
281                ("Gemini 3.1 Pro Preview (Custom Tools)", 1.0),
282            ),
283            ("gemini-3-flash-preview", ("Gemini 3 Flash Preview", 0.33)),
284            ("gemini-3-pro-preview", ("Gemini 3 Pro Preview", 1.0)),
285            (
286                "gemini-3-pro-image-preview",
287                ("Gemini 3 Pro Image Preview", 1.0),
288            ),
289            ("grok-code-fast-1", ("Grok Code Fast 1", 0.25)),
290        ]
291        .into_iter()
292        .collect();
293
294        for model in models.iter_mut() {
295            if let Some((display_name, premium_mult)) = pricing.get(model.id.as_str()) {
296                // Set a friendlier display name when the API only returned the raw id
297                if model.name == model.id {
298                    model.name = display_name.to_string();
299                }
300                let approx_cost = premium_mult * 10.0;
301                model.input_cost_per_million = Some(approx_cost);
302                model.output_cost_per_million = Some(approx_cost);
303            } else {
304                // Unknown Copilot model — assume 1x premium request ($10/M approx)
305                if model.input_cost_per_million.is_none() {
306                    model.input_cost_per_million = Some(10.0);
307                }
308                if model.output_cost_per_million.is_none() {
309                    model.output_cost_per_million = Some(10.0);
310                }
311            }
312        }
313    }
314
315    /// Known models to use as a fallback when the /models API is unreachable.
316    fn known_models(&self) -> Vec<ModelInfo> {
317        let entries: &[(&str, &str, usize, usize, bool)] = &[
318            ("gpt-4o", "GPT-4o", 128_000, 16_384, true),
319            ("gpt-4.1", "GPT-4.1", 128_000, 32_768, false),
320            ("gpt-5", "GPT-5", 400_000, 128_000, false),
321            ("gpt-5-mini", "GPT-5 mini", 264_000, 64_000, false),
322            ("claude-sonnet-4", "Claude Sonnet 4", 200_000, 64_000, false),
323            (
324                "claude-sonnet-4.5",
325                "Claude Sonnet 4.5",
326                200_000,
327                64_000,
328                false,
329            ),
330            (
331                "claude-sonnet-4-6",
332                "Claude Sonnet 4.6",
333                200_000,
334                128_000,
335                false,
336            ),
337            (
338                "claude-haiku-4.5",
339                "Claude Haiku 4.5",
340                200_000,
341                64_000,
342                false,
343            ),
344            ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_000_000, 64_000, false),
345            (
346                "gemini-3.1-pro-preview",
347                "Gemini 3.1 Pro Preview",
348                1_048_576,
349                65_536,
350                false,
351            ),
352            (
353                "gemini-3.1-pro-preview-customtools",
354                "Gemini 3.1 Pro Preview (Custom Tools)",
355                1_048_576,
356                65_536,
357                false,
358            ),
359            (
360                "gemini-3-pro-preview",
361                "Gemini 3 Pro Preview",
362                1_048_576,
363                65_536,
364                false,
365            ),
366            (
367                "gemini-3-flash-preview",
368                "Gemini 3 Flash Preview",
369                1_048_576,
370                65_536,
371                false,
372            ),
373            (
374                "gemini-3-pro-image-preview",
375                "Gemini 3 Pro Image Preview",
376                65_536,
377                32_768,
378                false,
379            ),
380        ];
381
382        entries
383            .iter()
384            .map(|(id, name, ctx, max_out, vision)| ModelInfo {
385                id: id.to_string(),
386                name: name.to_string(),
387                provider: self.provider_name.clone(),
388                context_window: *ctx,
389                max_output_tokens: Some(*max_out),
390                supports_vision: *vision,
391                supports_tools: true,
392                supports_streaming: true,
393                input_cost_per_million: None,
394                output_cost_per_million: None,
395            })
396            .collect()
397    }
398}
399
/// Body of a successful `/chat/completions` response, as parsed by `complete`.
#[derive(Debug, Deserialize)]
struct CopilotResponse {
    /// Returned choices; `complete` only reads the first one.
    choices: Vec<CopilotChoice>,
    /// Token accounting; `complete` reports zeros when it is absent.
    #[serde(default)]
    usage: Option<CopilotUsage>,
}
406
/// One entry of `CopilotResponse::choices`.
#[derive(Debug, Deserialize)]
struct CopilotChoice {
    /// The assistant message for this choice.
    message: CopilotMessage,
    /// Raw finish reason string (`"stop"`, `"length"`, ...), if provided.
    #[serde(default)]
    finish_reason: Option<String>,
}
413
/// Assistant message carried by a `CopilotChoice`.
#[derive(Debug, Deserialize)]
struct CopilotMessage {
    /// Text content; may be missing or empty.
    #[serde(default)]
    content: Option<String>,
    /// Tool invocations requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<CopilotToolCall>>,
}
421
/// A single tool call inside `CopilotMessage::tool_calls`.
#[derive(Debug, Deserialize)]
struct CopilotToolCall {
    /// Identifier copied into `ContentPart::ToolCall::id`.
    id: String,
    /// Value of the JSON `type` key; required when deserializing but never read
    /// by the code visible in this file, hence the `dead_code` allowance.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    /// Function name and arguments of the call.
    function: CopilotFunction,
}
430
/// Function payload of a `CopilotToolCall`.
#[derive(Debug, Deserialize)]
struct CopilotFunction {
    /// Name of the function to invoke.
    name: String,
    /// Arguments as a string; `complete` passes it through without parsing.
    arguments: String,
}
436
/// Token usage section of a `CopilotResponse`; each counter defaults to 0 when missing.
#[derive(Debug, Deserialize)]
struct CopilotUsage {
    /// Copied into `Usage::prompt_tokens`.
    #[serde(default)]
    prompt_tokens: usize,
    /// Copied into `Usage::completion_tokens`.
    #[serde(default)]
    completion_tokens: usize,
    /// Copied into `Usage::total_tokens`.
    #[serde(default)]
    total_tokens: usize,
}
446
/// Error body that `complete` tries to parse from a non-success response.
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    /// Nested error object; its message is preferred when present.
    error: Option<CopilotErrorDetail>,
    /// Top-level message used when `error` yields no message.
    message: Option<String>,
}
452
/// Nested `error` object of a `CopilotErrorResponse`.
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    /// Human-readable error description.
    message: Option<String>,
    /// Error code; `complete` appends it to the message in parentheses.
    code: Option<String>,
}
458
/// Body of the `/models` response parsed by `discover_models_from_api`.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    /// The listed models.
    data: Vec<CopilotModelInfo>,
}
463
/// One model entry of a `CopilotModelsResponse`.
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    /// Model identifier; also used as the display name when `name` is absent.
    id: String,
    /// Optional display name.
    #[serde(default)]
    name: Option<String>,
    /// When `Some(false)`, discovery skips the model.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    /// Policy information; a `"disabled"` state makes discovery skip the model.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    /// Limits and feature flags, when the API reports them.
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}
476
/// Policy section of a `CopilotModelInfo`.
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    /// Policy state string; discovery compares it against `"disabled"`.
    #[serde(default)]
    state: Option<String>,
}
482
/// Capabilities section of a `CopilotModelInfo`.
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    /// Token limits, if reported.
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    /// Feature flags, if reported.
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}
490
/// Token limits of a model as reported by the `/models` endpoint.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    /// Context window size; discovery falls back to 128_000 when absent.
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    /// Output token cap; discovery falls back to 16_384 when absent.
    #[serde(default)]
    max_output_tokens: Option<usize>,
}
498
/// Feature flags of a model as reported by the `/models` endpoint.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    /// Tool calling support; discovery assumes `true` when absent.
    #[serde(default)]
    tool_calls: Option<bool>,
    /// Vision support; discovery assumes `false` when absent.
    #[serde(default)]
    vision: Option<bool>,
    /// Streaming support; discovery assumes `true` when absent.
    #[serde(default)]
    streaming: Option<bool>,
}
508
509#[async_trait]
510impl Provider for CopilotProvider {
511    fn name(&self) -> &str {
512        &self.provider_name
513    }
514
515    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
516        let mut models = self.discover_models_from_api().await;
517
518        // If API discovery returned nothing, fall back to known models
519        if models.is_empty() {
520            tracing::info!(provider = %self.provider_name, "No models from API, using known model catalog");
521            models = self.known_models();
522        }
523
524        // Enrich with pricing metadata from known premium request multipliers
525        self.enrich_with_pricing(&mut models);
526
527        // Filter out non-chat models (embeddings, etc.) and legacy dated variants
528        models.retain(|m| {
529            !m.id.starts_with("text-embedding")
530                && !m.id.contains("-embedding-")
531                && !is_dated_model_variant(&m.id)
532        });
533
534        // Deduplicate by id (API sometimes returns duplicates)
535        let mut seen = std::collections::HashSet::new();
536        models.retain(|m| seen.insert(m.id.clone()));
537
538        Ok(models)
539    }
540
541    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
542        let messages = Self::convert_messages(&request.messages);
543        let tools = Self::convert_tools(&request.tools);
544        let is_agent = Self::is_agent_initiated(&request.messages);
545        let has_vision = Self::has_vision_input(&request.messages);
546
547        let mut body = json!({
548            "model": request.model,
549            "messages": messages,
550        });
551
552        if !tools.is_empty() {
553            body["tools"] = json!(tools);
554        }
555        if let Some(temp) = request.temperature {
556            body["temperature"] = json!(temp);
557        }
558        if let Some(top_p) = request.top_p {
559            body["top_p"] = json!(top_p);
560        }
561        if let Some(max) = request.max_tokens {
562            body["max_tokens"] = json!(max);
563        }
564        if !request.stop.is_empty() {
565            body["stop"] = json!(request.stop);
566        }
567
568        let mut req = self
569            .client
570            .post(format!("{}/chat/completions", self.base_url))
571            .header("Authorization", format!("Bearer {}", self.token))
572            .header("Content-Type", "application/json")
573            .header("Openai-Intent", "conversation-edits")
574            .header("User-Agent", Self::user_agent())
575            .header("X-Initiator", if is_agent { "agent" } else { "user" });
576
577        if has_vision {
578            req = req.header("Copilot-Vision-Request", "true");
579        }
580
581        let response = req
582            .json(&body)
583            .send()
584            .await
585            .context("Failed to send Copilot request")?;
586
587        let status = response.status();
588        let text = response
589            .text()
590            .await
591            .context("Failed to read Copilot response")?;
592
593        if !status.is_success() {
594            if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
595                let message = err
596                    .error
597                    .and_then(|detail| {
598                        detail.message.map(|msg| {
599                            if let Some(code) = detail.code {
600                                format!("{} ({})", msg, code)
601                            } else {
602                                msg
603                            }
604                        })
605                    })
606                    .or(err.message)
607                    .unwrap_or_else(|| "Unknown Copilot API error".to_string());
608                anyhow::bail!("Copilot API error: {}", message);
609            }
610            anyhow::bail!("Copilot API error: {} {}", status, text);
611        }
612
613        let response: CopilotResponse = serde_json::from_str(&text).context(format!(
614            "Failed to parse Copilot response: {}",
615            &text[..text.len().min(200)]
616        ))?;
617
618        let choice = response
619            .choices
620            .first()
621            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
622
623        let mut content = Vec::new();
624        let mut has_tool_calls = false;
625
626        if let Some(text) = &choice.message.content {
627            if !text.is_empty() {
628                content.push(ContentPart::Text { text: text.clone() });
629            }
630        }
631
632        if let Some(tool_calls) = &choice.message.tool_calls {
633            has_tool_calls = !tool_calls.is_empty();
634            for tc in tool_calls {
635                content.push(ContentPart::ToolCall {
636                    id: tc.id.clone(),
637                    name: tc.function.name.clone(),
638                    arguments: tc.function.arguments.clone(),
639                    thought_signature: None,
640                });
641            }
642        }
643
644        let finish_reason = if has_tool_calls {
645            FinishReason::ToolCalls
646        } else {
647            match choice.finish_reason.as_deref() {
648                Some("stop") => FinishReason::Stop,
649                Some("length") => FinishReason::Length,
650                Some("tool_calls") => FinishReason::ToolCalls,
651                Some("content_filter") => FinishReason::ContentFilter,
652                _ => FinishReason::Stop,
653            }
654        };
655
656        Ok(CompletionResponse {
657            message: Message {
658                role: Role::Assistant,
659                content,
660            },
661            usage: Usage {
662                prompt_tokens: response
663                    .usage
664                    .as_ref()
665                    .map(|u| u.prompt_tokens)
666                    .unwrap_or(0),
667                completion_tokens: response
668                    .usage
669                    .as_ref()
670                    .map(|u| u.completion_tokens)
671                    .unwrap_or(0),
672                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
673                ..Default::default()
674            },
675            finish_reason,
676        })
677    }
678
679    async fn complete_stream(
680        &self,
681        request: CompletionRequest,
682    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
683        // For now, keep behavior aligned with other non-streaming providers.
684        let response = self.complete(request).await?;
685        let text = response
686            .message
687            .content
688            .iter()
689            .filter_map(|p| match p {
690                ContentPart::Text { text } => Some(text.clone()),
691                _ => None,
692            })
693            .collect::<Vec<_>>()
694            .join("");
695
696        Ok(Box::pin(futures::stream::once(async move {
697            StreamChunk::Text(text)
698        })))
699    }
700}
701
/// Check if a model ID is a dated variant (e.g. "gpt-4o-2024-05-13") that should
/// be filtered out in favor of the canonical alias (e.g. "gpt-4o").
///
/// Returns `true` only when the ID ends in a `-YYYY-MM-DD` suffix made of ASCII
/// digits and dashes. The check works on raw bytes, so IDs containing
/// multi-byte UTF-8 characters are handled without slicing on a non-character
/// boundary.
fn is_dated_model_variant(id: &str) -> bool {
    // Length of "-YYYY-MM-DD".
    const SUFFIX_LEN: usize = 11;

    let bytes = id.as_bytes();
    let start = match bytes.len().checked_sub(SUFFIX_LEN) {
        Some(start) => start,
        None => return false,
    };

    bytes[start..].iter().enumerate().all(|(offset, byte)| match offset {
        // Dash positions in "-YYYY-MM-DD".
        0 | 5 | 8 => *byte == b'-',
        _ => byte.is_ascii_digit(),
    })
}
719
/// Reduce an enterprise URL to its bare host part.
///
/// Surrounding whitespace is trimmed, leading `https://` prefixes and then
/// leading `http://` prefixes are removed, and trailing `/` characters are
/// dropped.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let trimmed = input.trim();
    let without_https = trimmed.trim_start_matches("https://");
    let without_scheme = without_https.trim_start_matches("http://");
    without_scheme.trim_end_matches('/').to_owned()
}
728
729pub fn enterprise_base_url(enterprise_url: &str) -> String {
730    format!(
731        "https://copilot-api.{}",
732        normalize_enterprise_domain(enterprise_url)
733    )
734}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme prefixes and trailing slashes are stripped from enterprise URLs.
    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let inputs = [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ];
        for input in inputs.iter() {
            assert_eq!(normalize_enterprise_domain(input), "company.ghe.com");
        }
    }

    /// The enterprise API lives on the `copilot-api.` subdomain.
    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }

    /// Only ids that end in `-YYYY-MM-DD` count as dated variants.
    #[test]
    fn is_dated_model_variant_detects_date_suffix() {
        let dated = [
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4.1-2025-04-14",
            "gpt-4o-mini-2024-07-18",
        ];
        for id in dated.iter() {
            assert!(is_dated_model_variant(id));
        }

        let canonical = ["gpt-4o", "gpt-5", "claude-sonnet-4", "gemini-2.5-pro"];
        for id in canonical.iter() {
            assert!(!is_dated_model_variant(id));
        }
    }

    /// The fallback catalog is populated and every entry supports tools.
    #[test]
    fn known_models_fallback_is_non_empty() {
        let provider = CopilotProvider::new("test-token".to_string()).unwrap();
        let models = provider.known_models();
        assert!(!models.is_empty());
        for model in models.iter() {
            assert!(model.supports_tools);
        }
    }

    /// Known models receive costs and a display name from the pricing table.
    #[test]
    fn enrich_with_pricing_sets_costs() {
        let provider = CopilotProvider::new("test-token".to_string()).unwrap();
        let gpt_4o = ModelInfo {
            id: "gpt-4o".to_string(),
            name: "gpt-4o".to_string(),
            provider: "github-copilot".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: true,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        };
        let mut models = vec![gpt_4o];

        provider.enrich_with_pricing(&mut models);

        let enriched = &models[0];
        // GPT-4o is 0x premium (free), so cost = 0.0
        assert_eq!(enriched.input_cost_per_million, Some(0.0));
        assert_eq!(enriched.output_cost_per_million, Some(0.0));
        // Name should be enriched
        assert_eq!(enriched.name, "GPT-4o");
    }
}