Skip to main content

codetether_agent/provider/
copilot.rs

1//! GitHub Copilot provider implementation using raw HTTP.
2
3use super::util;
4use super::{
5    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
6    Role, StreamChunk, ToolDefinition, Usage,
7};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::{Value, json};
13
/// Root of the public (non-enterprise) GitHub Copilot API.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";

/// Provider name reported when talking to the public Copilot endpoint.
const COPILOT_PROVIDER: &str = "github-copilot";

/// Provider name reported when talking to a GitHub Enterprise Copilot deployment.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
17
18pub struct CopilotProvider {
19    client: Client,
20    token: String,
21    base_url: String,
22    provider_name: String,
23}
24
25impl CopilotProvider {
26    pub fn new(token: String) -> Result<Self> {
27        Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
28    }
29
30    pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
31        let base_url = enterprise_base_url(&enterprise_url);
32        Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
33    }
34
35    pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
36        Ok(Self {
37            client: crate::provider::shared_http::shared_client().clone(),
38            token,
39            base_url: base_url.trim_end_matches('/').to_string(),
40            provider_name: provider_name.to_string(),
41        })
42    }
43
44    fn user_agent() -> String {
45        format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
46    }
47
48    fn convert_messages(messages: &[Message]) -> Vec<Value> {
49        messages
50            .iter()
51            .map(|msg| {
52                let role = match msg.role {
53                    Role::System => "system",
54                    Role::User => "user",
55                    Role::Assistant => "assistant",
56                    Role::Tool => "tool",
57                };
58
59                match msg.role {
60                    Role::Tool => {
61                        if let Some(ContentPart::ToolResult {
62                            tool_call_id,
63                            content,
64                        }) = msg.content.first()
65                        {
66                            json!({
67                                "role": "tool",
68                                "tool_call_id": tool_call_id,
69                                "content": content
70                            })
71                        } else {
72                            json!({ "role": role, "content": "" })
73                        }
74                    }
75                    Role::Assistant => {
76                        let text: String = msg
77                            .content
78                            .iter()
79                            .filter_map(|p| match p {
80                                ContentPart::Text { text } => Some(text.clone()),
81                                _ => None,
82                            })
83                            .collect::<Vec<_>>()
84                            .join("");
85
86                        let tool_calls: Vec<Value> = msg
87                            .content
88                            .iter()
89                            .filter_map(|p| match p {
90                                ContentPart::ToolCall {
91                                    id,
92                                    name,
93                                    arguments,
94                                    ..
95                                } => Some(json!({
96                                    "id": id,
97                                    "type": "function",
98                                    "function": {
99                                        "name": name,
100                                        "arguments": arguments
101                                    }
102                                })),
103                                _ => None,
104                            })
105                            .collect();
106
107                        if tool_calls.is_empty() {
108                            json!({ "role": "assistant", "content": text })
109                        } else {
110                            json!({
111                                "role": "assistant",
112                                "content": if text.is_empty() { "".to_string() } else { text },
113                                "tool_calls": tool_calls
114                            })
115                        }
116                    }
117                    _ => {
118                        let text: String = msg
119                            .content
120                            .iter()
121                            .filter_map(|p| match p {
122                                ContentPart::Text { text } => Some(text.clone()),
123                                _ => None,
124                            })
125                            .collect::<Vec<_>>()
126                            .join("\n");
127                        json!({ "role": role, "content": text })
128                    }
129                }
130            })
131            .collect()
132    }
133
134    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
135        tools
136            .iter()
137            .map(|t| {
138                json!({
139                    "type": "function",
140                    "function": {
141                        "name": t.name,
142                        "description": t.description,
143                        "parameters": t.parameters
144                    }
145                })
146            })
147            .collect()
148    }
149
150    fn is_agent_initiated(messages: &[Message]) -> bool {
151        messages
152            .iter()
153            .rev()
154            .find(|msg| msg.role != Role::System)
155            .map(|msg| msg.role != Role::User)
156            .unwrap_or(false)
157    }
158
159    fn has_vision_input(messages: &[Message]) -> bool {
160        messages.iter().any(|msg| {
161            msg.content
162                .iter()
163                .any(|part| matches!(part, ContentPart::Image { .. }))
164        })
165    }
166
167    /// Discover models dynamically from the Copilot /models API endpoint.
168    async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
169        let response = match self
170            .client
171            .get(format!("{}/models", self.base_url))
172            .header("Authorization", format!("Bearer {}", self.token))
173            .header("User-Agent", Self::user_agent())
174            .send()
175            .await
176        {
177            Ok(r) => r,
178            Err(e) => {
179                tracing::warn!(provider = %self.provider_name, error = %e, "Failed to fetch Copilot models endpoint");
180                return Vec::new();
181            }
182        };
183
184        let status = response.status();
185        if !status.is_success() {
186            let body = response.text().await.unwrap_or_default();
187            tracing::warn!(
188                provider = %self.provider_name,
189                status = %status,
190                body = %body.chars().take(200).collect::<String>(),
191                "Copilot /models endpoint returned non-success"
192            );
193            return Vec::new();
194        }
195
196        let parsed: CopilotModelsResponse = match crate::provider::body_cap::json_capped(
197            response,
198            crate::provider::body_cap::PROVIDER_METADATA_BODY_CAP,
199        )
200        .await
201        {
202            Ok(p) => p,
203            Err(e) => {
204                tracing::warn!(provider = %self.provider_name, error = %e, "Failed to parse Copilot models response (or exceeded body cap)");
205                return Vec::new();
206            }
207        };
208
209        let models: Vec<ModelInfo> = parsed
210            .data
211            .into_iter()
212            .filter(|model| {
213                // Skip models explicitly disabled in the picker
214                if model.model_picker_enabled == Some(false) {
215                    return false;
216                }
217                // Skip models with a disabled policy state
218                if let Some(ref policy) = model.policy
219                    && policy.state.as_deref() == Some("disabled")
220                {
221                    return false;
222                }
223                true
224            })
225            .map(|model| {
226                let caps = model.capabilities.as_ref();
227                let limits = caps.and_then(|c| c.limits.as_ref());
228                let supports = caps.and_then(|c| c.supports.as_ref());
229
230                ModelInfo {
231                    id: model.id.clone(),
232                    name: model.name.unwrap_or_else(|| model.id.clone()),
233                    provider: self.provider_name.clone(),
234                    context_window: limits
235                        .and_then(|l| l.max_context_window_tokens)
236                        .unwrap_or(128_000),
237                    max_output_tokens: limits.and_then(|l| l.max_output_tokens).or(Some(16_384)),
238                    supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
239                    supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
240                    supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
241                    input_cost_per_million: None,
242                    output_cost_per_million: None,
243                }
244            })
245            .collect();
246
247        tracing::info!(
248            provider = %self.provider_name,
249            count = models.len(),
250            "Discovered models from Copilot API"
251        );
252        models
253    }
254
255    /// Enrich models with pricing metadata from known premium request multipliers.
256    ///
257    /// Source: https://docs.github.com/en/copilot/concepts/billing/copilot-requests
258    /// Cost model: Premium requests at $0.04/request overflow rate.
259    /// We convert multiplier to approximate $/M tokens using ~4K tokens/request avg.
260    /// Formula: multiplier * $0.04 / 4K tokens * 1M = multiplier * $10/M tokens.
261    fn enrich_with_pricing(&self, models: &mut [ModelInfo]) {
262        // (display_name, premium_multiplier)
263        let pricing: std::collections::HashMap<&str, (&str, f64)> = [
264            ("claude-opus-4.5", ("Claude Opus 4.5", 3.0)),
265            ("claude-opus-4.6", ("Claude Opus 4.6", 3.0)),
266            ("claude-opus-41", ("Claude Opus 4.1", 10.0)),
267            ("claude-sonnet-4-6", ("Claude Sonnet 4.6", 1.0)),
268            ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 1.0)),
269            ("claude-sonnet-4", ("Claude Sonnet 4", 1.0)),
270            ("claude-haiku-4.5", ("Claude Haiku 4.5", 0.33)),
271            ("gpt-5.3-codex", ("GPT-5.3-Codex", 1.0)),
272            ("gpt-5.2", ("GPT-5.2", 1.0)),
273            ("gpt-5.2-codex", ("GPT-5.2-Codex", 1.0)),
274            ("gpt-5.1", ("GPT-5.1", 1.0)),
275            ("gpt-5.1-codex", ("GPT-5.1-Codex", 1.0)),
276            ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 0.33)),
277            ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 1.0)),
278            ("gpt-5", ("GPT-5", 1.0)),
279            ("gpt-5-mini", ("GPT-5 mini", 0.0)),
280            ("gpt-5-codex", ("GPT-5-Codex", 1.0)),
281            ("gpt-4.1", ("GPT-4.1", 0.0)),
282            ("gpt-4o", ("GPT-4o", 0.0)),
283            ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1.0)),
284            ("gemini-3.1-pro-preview", ("Gemini 3.1 Pro Preview", 1.0)),
285            (
286                "gemini-3.1-pro-preview-customtools",
287                ("Gemini 3.1 Pro Preview (Custom Tools)", 1.0),
288            ),
289            ("gemini-3-flash-preview", ("Gemini 3 Flash Preview", 0.33)),
290            ("gemini-3-pro-preview", ("Gemini 3 Pro Preview", 1.0)),
291            (
292                "gemini-3-pro-image-preview",
293                ("Gemini 3 Pro Image Preview", 1.0),
294            ),
295            ("grok-code-fast-1", ("Grok Code Fast 1", 0.25)),
296        ]
297        .into_iter()
298        .collect();
299
300        for model in models.iter_mut() {
301            if let Some((display_name, premium_mult)) = pricing.get(model.id.as_str()) {
302                // Set a friendlier display name when the API only returned the raw id
303                if model.name == model.id {
304                    model.name = display_name.to_string();
305                }
306                let approx_cost = premium_mult * 10.0;
307                model.input_cost_per_million = Some(approx_cost);
308                model.output_cost_per_million = Some(approx_cost);
309            } else {
310                // Unknown Copilot model — assume 1x premium request ($10/M approx)
311                if model.input_cost_per_million.is_none() {
312                    model.input_cost_per_million = Some(10.0);
313                }
314                if model.output_cost_per_million.is_none() {
315                    model.output_cost_per_million = Some(10.0);
316                }
317            }
318        }
319    }
320
321    /// Known models to use as a fallback when the /models API is unreachable.
322    fn known_models(&self) -> Vec<ModelInfo> {
323        let entries: &[(&str, &str, usize, usize, bool)] = &[
324            ("gpt-4o", "GPT-4o", 128_000, 16_384, true),
325            ("gpt-4.1", "GPT-4.1", 128_000, 32_768, false),
326            ("gpt-5", "GPT-5", 400_000, 128_000, false),
327            ("gpt-5-mini", "GPT-5 mini", 264_000, 64_000, false),
328            ("claude-sonnet-4", "Claude Sonnet 4", 200_000, 64_000, false),
329            (
330                "claude-sonnet-4.5",
331                "Claude Sonnet 4.5",
332                200_000,
333                64_000,
334                false,
335            ),
336            (
337                "claude-sonnet-4-6",
338                "Claude Sonnet 4.6",
339                200_000,
340                128_000,
341                false,
342            ),
343            (
344                "claude-haiku-4.5",
345                "Claude Haiku 4.5",
346                200_000,
347                64_000,
348                false,
349            ),
350            ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_000_000, 64_000, false),
351            (
352                "gemini-3.1-pro-preview",
353                "Gemini 3.1 Pro Preview",
354                1_048_576,
355                65_536,
356                false,
357            ),
358            (
359                "gemini-3.1-pro-preview-customtools",
360                "Gemini 3.1 Pro Preview (Custom Tools)",
361                1_048_576,
362                65_536,
363                false,
364            ),
365            (
366                "gemini-3-pro-preview",
367                "Gemini 3 Pro Preview",
368                1_048_576,
369                65_536,
370                false,
371            ),
372            (
373                "gemini-3-flash-preview",
374                "Gemini 3 Flash Preview",
375                1_048_576,
376                65_536,
377                false,
378            ),
379            (
380                "gemini-3-pro-image-preview",
381                "Gemini 3 Pro Image Preview",
382                65_536,
383                32_768,
384                false,
385            ),
386        ];
387
388        entries
389            .iter()
390            .map(|(id, name, ctx, max_out, vision)| ModelInfo {
391                id: id.to_string(),
392                name: name.to_string(),
393                provider: self.provider_name.clone(),
394                context_window: *ctx,
395                max_output_tokens: Some(*max_out),
396                supports_vision: *vision,
397                supports_tools: true,
398                supports_streaming: true,
399                input_cost_per_million: None,
400                output_cost_per_million: None,
401            })
402            .collect()
403    }
404}
405
406#[derive(Debug, Deserialize)]
407struct CopilotResponse {
408    choices: Vec<CopilotChoice>,
409    #[serde(default)]
410    usage: Option<CopilotUsage>,
411}
412
413#[derive(Debug, Deserialize)]
414struct CopilotChoice {
415    message: CopilotMessage,
416    #[serde(default)]
417    finish_reason: Option<String>,
418}
419
420#[derive(Debug, Deserialize)]
421struct CopilotMessage {
422    #[serde(default)]
423    content: Option<String>,
424    #[serde(default)]
425    tool_calls: Option<Vec<CopilotToolCall>>,
426}
427
428#[derive(Debug, Deserialize)]
429struct CopilotToolCall {
430    id: String,
431    #[serde(rename = "type")]
432    #[allow(dead_code)]
433    call_type: String,
434    function: CopilotFunction,
435}
436
437#[derive(Debug, Deserialize)]
438struct CopilotFunction {
439    name: String,
440    arguments: String,
441}
442
443#[derive(Debug, Deserialize)]
444struct CopilotUsage {
445    #[serde(default)]
446    prompt_tokens: usize,
447    #[serde(default)]
448    completion_tokens: usize,
449    #[serde(default)]
450    total_tokens: usize,
451}
452
453#[derive(Debug, Deserialize)]
454struct CopilotErrorResponse {
455    error: Option<CopilotErrorDetail>,
456    message: Option<String>,
457}
458
459#[derive(Debug, Deserialize)]
460struct CopilotErrorDetail {
461    message: Option<String>,
462    code: Option<String>,
463}
464
465#[derive(Debug, Deserialize)]
466struct CopilotModelsResponse {
467    data: Vec<CopilotModelInfo>,
468}
469
470#[derive(Debug, Deserialize)]
471struct CopilotModelInfo {
472    id: String,
473    #[serde(default)]
474    name: Option<String>,
475    #[serde(default)]
476    model_picker_enabled: Option<bool>,
477    #[serde(default)]
478    policy: Option<CopilotModelPolicy>,
479    #[serde(default)]
480    capabilities: Option<CopilotModelCapabilities>,
481}
482
483#[derive(Debug, Deserialize)]
484struct CopilotModelPolicy {
485    #[serde(default)]
486    state: Option<String>,
487}
488
489#[derive(Debug, Deserialize)]
490struct CopilotModelCapabilities {
491    #[serde(default)]
492    limits: Option<CopilotModelLimits>,
493    #[serde(default)]
494    supports: Option<CopilotModelSupports>,
495}
496
497#[derive(Debug, Deserialize)]
498struct CopilotModelLimits {
499    #[serde(default)]
500    max_context_window_tokens: Option<usize>,
501    #[serde(default)]
502    max_output_tokens: Option<usize>,
503}
504
505#[derive(Debug, Deserialize)]
506struct CopilotModelSupports {
507    #[serde(default)]
508    tool_calls: Option<bool>,
509    #[serde(default)]
510    vision: Option<bool>,
511    #[serde(default)]
512    streaming: Option<bool>,
513}
514
515#[async_trait]
516impl Provider for CopilotProvider {
517    fn name(&self) -> &str {
518        &self.provider_name
519    }
520
521    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
522        let mut models = self.discover_models_from_api().await;
523
524        // If API discovery returned nothing, fall back to known models
525        if models.is_empty() {
526            tracing::info!(provider = %self.provider_name, "No models from API, using known model catalog");
527            models = self.known_models();
528        }
529
530        // Enrich with pricing metadata from known premium request multipliers
531        self.enrich_with_pricing(&mut models);
532
533        // Filter out non-chat models (embeddings, etc.) and legacy dated variants
534        models.retain(|m| {
535            !m.id.starts_with("text-embedding")
536                && !m.id.contains("-embedding-")
537                && !is_dated_model_variant(&m.id)
538        });
539
540        // Deduplicate by id (API sometimes returns duplicates)
541        let mut seen = std::collections::HashSet::new();
542        models.retain(|m| seen.insert(m.id.clone()));
543
544        Ok(models)
545    }
546
547    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
548        let messages = Self::convert_messages(&request.messages);
549        let tools = Self::convert_tools(&request.tools);
550        let is_agent = Self::is_agent_initiated(&request.messages);
551        let has_vision = Self::has_vision_input(&request.messages);
552
553        let mut body = json!({
554            "model": request.model,
555            "messages": messages,
556        });
557
558        if !tools.is_empty() {
559            body["tools"] = json!(tools);
560        }
561        if let Some(temp) = request.temperature {
562            body["temperature"] = json!(temp);
563        }
564        if let Some(top_p) = request.top_p {
565            body["top_p"] = json!(top_p);
566        }
567        if let Some(max) = request.max_tokens {
568            body["max_tokens"] = json!(max);
569        }
570        if !request.stop.is_empty() {
571            body["stop"] = json!(request.stop);
572        }
573
574        let mut req = self
575            .client
576            .post(format!("{}/chat/completions", self.base_url))
577            .header("Authorization", format!("Bearer {}", self.token))
578            .header("Content-Type", "application/json")
579            .header("Openai-Intent", "conversation-edits")
580            .header("User-Agent", Self::user_agent())
581            .header("X-Initiator", if is_agent { "agent" } else { "user" });
582
583        if has_vision {
584            req = req.header("Copilot-Vision-Request", "true");
585        }
586
587        let response = req
588            .json(&body)
589            .send()
590            .await
591            .context("Failed to send Copilot request")?;
592
593        let status = response.status();
594        let text = response
595            .text()
596            .await
597            .context("Failed to read Copilot response")?;
598
599        if !status.is_success() {
600            if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
601                let message = err
602                    .error
603                    .and_then(|detail| {
604                        detail.message.map(|msg| {
605                            if let Some(code) = detail.code {
606                                format!("{} ({})", msg, code)
607                            } else {
608                                msg
609                            }
610                        })
611                    })
612                    .or(err.message)
613                    .unwrap_or_else(|| "Unknown Copilot API error".to_string());
614                anyhow::bail!("Copilot API error: {}", message);
615            }
616            anyhow::bail!("Copilot API error: {} {}", status, text);
617        }
618
619        let response: CopilotResponse = serde_json::from_str(&text).context(format!(
620            "Failed to parse Copilot response: {}",
621            util::truncate_bytes_safe(&text, 200)
622        ))?;
623
624        let choice = response
625            .choices
626            .first()
627            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
628
629        let mut content = Vec::new();
630        let mut has_tool_calls = false;
631
632        if let Some(text) = &choice.message.content
633            && !text.is_empty()
634        {
635            content.push(ContentPart::Text { text: text.clone() });
636        }
637
638        if let Some(tool_calls) = &choice.message.tool_calls {
639            has_tool_calls = !tool_calls.is_empty();
640            for tc in tool_calls {
641                content.push(ContentPart::ToolCall {
642                    id: tc.id.clone(),
643                    name: tc.function.name.clone(),
644                    arguments: tc.function.arguments.clone(),
645                    thought_signature: None,
646                });
647            }
648        }
649
650        let finish_reason = if has_tool_calls {
651            FinishReason::ToolCalls
652        } else {
653            match choice.finish_reason.as_deref() {
654                Some("stop") => FinishReason::Stop,
655                Some("length") => FinishReason::Length,
656                Some("tool_calls") => FinishReason::ToolCalls,
657                Some("content_filter") => FinishReason::ContentFilter,
658                _ => FinishReason::Stop,
659            }
660        };
661
662        Ok(CompletionResponse {
663            message: Message {
664                role: Role::Assistant,
665                content,
666            },
667            usage: Usage {
668                prompt_tokens: response
669                    .usage
670                    .as_ref()
671                    .map(|u| u.prompt_tokens)
672                    .unwrap_or(0),
673                completion_tokens: response
674                    .usage
675                    .as_ref()
676                    .map(|u| u.completion_tokens)
677                    .unwrap_or(0),
678                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
679                ..Default::default()
680            },
681            finish_reason,
682        })
683    }
684
685    async fn complete_stream(
686        &self,
687        request: CompletionRequest,
688    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
689        // For now, keep behavior aligned with other non-streaming providers.
690        let response = self.complete(request).await?;
691        let text = response
692            .message
693            .content
694            .iter()
695            .filter_map(|p| match p {
696                ContentPart::Text { text } => Some(text.clone()),
697                _ => None,
698            })
699            .collect::<Vec<_>>()
700            .join("");
701
702        Ok(Box::pin(futures::stream::once(async move {
703            StreamChunk::Text(text)
704        })))
705    }
706}
707
/// Check if a model ID is a dated variant (e.g. "gpt-4o-2024-05-13") that should
/// be filtered out in favor of the canonical alias (e.g. "gpt-4o").
///
/// Returns `true` when `id` ends in `-YYYY-MM-DD` (ASCII digits). The check is
/// done on raw bytes, so non-ASCII ids never panic on a char boundary; they
/// simply do not match.
fn is_dated_model_variant(id: &str) -> bool {
    const SUFFIX_LEN: usize = "-YYYY-MM-DD".len();

    let bytes = id.as_bytes();
    if bytes.len() < SUFFIX_LEN {
        return false;
    }
    // Suffix layout: '-' at offsets 0, 5 and 8, ASCII digits everywhere else.
    bytes[bytes.len() - SUFFIX_LEN..]
        .iter()
        .enumerate()
        .all(|(offset, &byte)| match offset {
            0 | 5 | 8 => byte == b'-',
            _ => byte.is_ascii_digit(),
        })
}
725
/// Reduces a user-supplied enterprise URL to its bare host part.
///
/// Surrounding whitespace, any leading `http://` / `https://` scheme (matched
/// case-insensitively, since URL schemes are case-insensitive), and trailing `/`
/// characters are removed. For example, `"HTTPS://company.ghe.com/"` becomes
/// `"company.ghe.com"`.
pub fn normalize_enterprise_domain(input: &str) -> String {
    /// Returns `s` without one leading `https://` or `http://`, ignoring ASCII case.
    fn strip_scheme(s: &str) -> Option<&str> {
        ["https://", "http://"].iter().find_map(|&scheme| {
            // `get` returns None instead of panicking on a non-char boundary.
            let head = s.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                Some(&s[scheme.len()..])
            } else {
                None
            }
        })
    }

    let mut domain = input.trim();
    while let Some(rest) = strip_scheme(domain) {
        domain = rest;
    }
    domain.trim_end_matches('/').to_string()
}
734
735pub fn enterprise_base_url(enterprise_url: &str) -> String {
736    format!(
737        "https://copilot-api.{}",
738        normalize_enterprise_domain(enterprise_url)
739    )
740}
741
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider pointed at the public endpoint with a placeholder token.
    fn test_provider() -> CopilotProvider {
        CopilotProvider::new("test-token".to_string()).expect("provider construction")
    }

    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let inputs = [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ];
        for input in inputs {
            assert_eq!(
                normalize_enterprise_domain(input),
                "company.ghe.com",
                "input: {input:?}"
            );
        }
    }

    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }

    #[test]
    fn is_dated_model_variant_detects_date_suffix() {
        let dated = [
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4.1-2025-04-14",
            "gpt-4o-mini-2024-07-18",
        ];
        let canonical = ["gpt-4o", "gpt-5", "claude-sonnet-4", "gemini-2.5-pro"];

        for id in dated {
            assert!(is_dated_model_variant(id), "{id} should be treated as dated");
        }
        for id in canonical {
            assert!(!is_dated_model_variant(id), "{id} should not be treated as dated");
        }
    }

    #[test]
    fn known_models_fallback_is_non_empty() {
        let models = test_provider().known_models();
        assert!(!models.is_empty());
        // Every fallback entry is expected to advertise tool support.
        assert!(models.iter().all(|m| m.supports_tools));
    }

    #[test]
    fn enrich_with_pricing_sets_costs() {
        let provider = test_provider();
        let mut models = vec![ModelInfo {
            id: "gpt-4o".to_string(),
            name: "gpt-4o".to_string(),
            provider: "github-copilot".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: true,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        }];

        provider.enrich_with_pricing(&mut models);
        let model = &models[0];

        // GPT-4o carries a 0x premium multiplier, so both costs come out as 0.0.
        assert_eq!(model.input_cost_per_million, Some(0.0));
        assert_eq!(model.output_cost_per_million, Some(0.0));
        // The raw id is replaced by the catalog display name.
        assert_eq!(model.name, "GPT-4o");
    }
}