Skip to main content

codetether_agent/provider/
copilot.rs

1//! GitHub Copilot provider implementation using raw HTTP.
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// API root used by `CopilotProvider::new` when no enterprise host is given.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name reported by instances created with `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name reported by instances created with `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
/// Chat-completion provider that talks to the GitHub Copilot HTTP API.
///
/// Bundles the HTTP client, bearer token and endpoint used for every request.
pub struct CopilotProvider {
    /// HTTP client used for all requests issued by this provider.
    client: Client,
    /// Token sent as `Authorization: Bearer <token>` on each request.
    token: String,
    /// API root (stored without trailing slashes); request paths are appended to it.
    base_url: String,
    /// Name returned by `Provider::name` and stamped on every listed model.
    provider_name: String,
}
23
24impl CopilotProvider {
25    pub fn new(token: String) -> Result<Self> {
26        Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27    }
28
29    pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30        let base_url = enterprise_base_url(&enterprise_url);
31        Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32    }
33
34    pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35        Ok(Self {
36            client: Client::new(),
37            token,
38            base_url: base_url.trim_end_matches('/').to_string(),
39            provider_name: provider_name.to_string(),
40        })
41    }
42
43    fn user_agent() -> String {
44        format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45    }
46
47    fn convert_messages(messages: &[Message]) -> Vec<Value> {
48        messages
49            .iter()
50            .map(|msg| {
51                let role = match msg.role {
52                    Role::System => "system",
53                    Role::User => "user",
54                    Role::Assistant => "assistant",
55                    Role::Tool => "tool",
56                };
57
58                match msg.role {
59                    Role::Tool => {
60                        if let Some(ContentPart::ToolResult {
61                            tool_call_id,
62                            content,
63                        }) = msg.content.first()
64                        {
65                            json!({
66                                "role": "tool",
67                                "tool_call_id": tool_call_id,
68                                "content": content
69                            })
70                        } else {
71                            json!({ "role": role, "content": "" })
72                        }
73                    }
74                    Role::Assistant => {
75                        let text: String = msg
76                            .content
77                            .iter()
78                            .filter_map(|p| match p {
79                                ContentPart::Text { text } => Some(text.clone()),
80                                _ => None,
81                            })
82                            .collect::<Vec<_>>()
83                            .join("");
84
85                        let tool_calls: Vec<Value> = msg
86                            .content
87                            .iter()
88                            .filter_map(|p| match p {
89                                ContentPart::ToolCall {
90                                    id,
91                                    name,
92                                    arguments,
93                                } => Some(json!({
94                                    "id": id,
95                                    "type": "function",
96                                    "function": {
97                                        "name": name,
98                                        "arguments": arguments
99                                    }
100                                })),
101                                _ => None,
102                            })
103                            .collect();
104
105                        if tool_calls.is_empty() {
106                            json!({ "role": "assistant", "content": text })
107                        } else {
108                            json!({
109                                "role": "assistant",
110                                "content": if text.is_empty() { "".to_string() } else { text },
111                                "tool_calls": tool_calls
112                            })
113                        }
114                    }
115                    _ => {
116                        let text: String = msg
117                            .content
118                            .iter()
119                            .filter_map(|p| match p {
120                                ContentPart::Text { text } => Some(text.clone()),
121                                _ => None,
122                            })
123                            .collect::<Vec<_>>()
124                            .join("\n");
125                        json!({ "role": role, "content": text })
126                    }
127                }
128            })
129            .collect()
130    }
131
132    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133        tools
134            .iter()
135            .map(|t| {
136                json!({
137                    "type": "function",
138                    "function": {
139                        "name": t.name,
140                        "description": t.description,
141                        "parameters": t.parameters
142                    }
143                })
144            })
145            .collect()
146    }
147
148    fn is_agent_initiated(messages: &[Message]) -> bool {
149        messages
150            .iter()
151            .rev()
152            .find(|msg| msg.role != Role::System)
153            .map(|msg| msg.role != Role::User)
154            .unwrap_or(false)
155    }
156
157    fn has_vision_input(messages: &[Message]) -> bool {
158        messages.iter().any(|msg| {
159            msg.content
160                .iter()
161                .any(|part| matches!(part, ContentPart::Image { .. }))
162        })
163    }
164}
165
/// Body of a successful `/chat/completions` response.
#[derive(Debug, Deserialize)]
struct CopilotResponse {
    /// Returned choices; `complete` only reads the first one.
    choices: Vec<CopilotChoice>,
    /// Token accounting; may be absent, in which case counts are reported as 0.
    #[serde(default)]
    usage: Option<CopilotUsage>,
}
172
/// One entry of the `choices` array in a completion response.
#[derive(Debug, Deserialize)]
struct CopilotChoice {
    /// The assistant message produced for this choice.
    message: CopilotMessage,
    /// Raw finish reason string (e.g. `"stop"`, `"length"`); may be absent.
    #[serde(default)]
    finish_reason: Option<String>,
}
179
/// Assistant message inside a completion choice.
#[derive(Debug, Deserialize)]
struct CopilotMessage {
    /// Text content, if any.
    #[serde(default)]
    content: Option<String>,
    /// Tool calls requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<CopilotToolCall>>,
}
187
188#[derive(Debug, Deserialize)]
189struct CopilotToolCall {
190    id: String,
191    #[serde(rename = "type")]
192    #[allow(dead_code)]
193    call_type: String,
194    function: CopilotFunction,
195}
196
/// Function payload of a tool call.
#[derive(Debug, Deserialize)]
struct CopilotFunction {
    /// Name of the function the model wants to call.
    name: String,
    /// Arguments as a string; forwarded unchanged into `ContentPart::ToolCall`.
    arguments: String,
}
202
/// Token counts reported with a completion; each missing field defaults to 0.
#[derive(Debug, Deserialize)]
struct CopilotUsage {
    /// Tokens counted for the prompt.
    #[serde(default)]
    prompt_tokens: usize,
    /// Tokens counted for the completion.
    #[serde(default)]
    completion_tokens: usize,
    /// Total token count as reported by the API.
    #[serde(default)]
    total_tokens: usize,
}
212
/// Body of a non-success response, as far as this provider interprets it.
///
/// Both fields are optional, so any JSON object deserializes into this type.
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    /// Structured error detail; preferred when it carries a message.
    error: Option<CopilotErrorDetail>,
    /// Top-level message, used when `error` yields none.
    message: Option<String>,
}
218
/// Nested `error` object of an error response.
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    /// Human-readable description of the failure.
    message: Option<String>,
    /// Error code; appended in parentheses to the message when present.
    code: Option<String>,
}
224
/// Body of a successful `/models` response.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    /// The models returned by the endpoint.
    data: Vec<CopilotModelInfo>,
}
229
/// One model entry returned by the `/models` endpoint.
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    /// Model identifier used in completion requests.
    id: String,
    /// Display name; `list_models` falls back to `id` when absent.
    #[serde(default)]
    name: Option<String>,
    /// When `Some(false)`, `list_models` drops the model.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    /// Policy info; a `"disabled"` state makes `list_models` drop the model.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    /// Limits and feature flags; defaults are applied when absent.
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}
242
/// Policy information attached to a model entry.
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    /// Policy state string; only the value `"disabled"` is acted upon.
    #[serde(default)]
    state: Option<String>,
}
248
/// Capability section of a model entry.
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    /// Token limits, if reported.
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    /// Feature flags, if reported.
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}
256
/// Token limits reported for a model.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    /// Context window size in tokens; `list_models` assumes 128_000 when absent.
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    /// Maximum output tokens; `list_models` assumes 16_384 when absent.
    #[serde(default)]
    max_output_tokens: Option<usize>,
}
264
/// Feature flags reported for a model.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    /// Tool-calling support; `list_models` assumes `true` when absent.
    #[serde(default)]
    tool_calls: Option<bool>,
    /// Image-input support; `list_models` assumes `false` when absent.
    #[serde(default)]
    vision: Option<bool>,
    /// Streaming support; `list_models` assumes `true` when absent.
    #[serde(default)]
    streaming: Option<bool>,
}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277    fn name(&self) -> &str {
278        &self.provider_name
279    }
280
281    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282        let response = self
283            .client
284            .get(format!("{}/models", self.base_url))
285            .header("Authorization", format!("Bearer {}", self.token))
286            .header("Openai-Intent", "conversation-edits")
287            .header("User-Agent", Self::user_agent())
288            .send()
289            .await
290            .context("Failed to fetch Copilot models")?;
291
292        let mut models: Vec<ModelInfo> = if response.status().is_success() {
293            let parsed: CopilotModelsResponse = response
294                .json()
295                .await
296                .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298            parsed
299                .data
300                .into_iter()
301                .filter(|model| {
302                    // Skip models that are disabled in the picker
303                    if model.model_picker_enabled == Some(false) {
304                        return false;
305                    }
306                    // Skip models with a disabled policy state
307                    if let Some(ref policy) = model.policy {
308                        if policy.state.as_deref() == Some("disabled") {
309                            return false;
310                        }
311                    }
312                    true
313                })
314                .map(|model| {
315                    let caps = model.capabilities.as_ref();
316                    let limits = caps.and_then(|c| c.limits.as_ref());
317                    let supports = caps.and_then(|c| c.supports.as_ref());
318
319                    ModelInfo {
320                        id: model.id.clone(),
321                        name: model.name.unwrap_or_else(|| model.id.clone()),
322                        provider: self.provider_name.clone(),
323                        context_window: limits
324                            .and_then(|l| l.max_context_window_tokens)
325                            .unwrap_or(128_000),
326                        max_output_tokens: limits
327                            .and_then(|l| l.max_output_tokens)
328                            .or(Some(16_384)),
329                        supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
330                        supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
331                        supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
332                        input_cost_per_million: None, // Set below per-model
333                        output_cost_per_million: None,
334                    }
335                })
336                .collect()
337        } else {
338            Vec::new()
339        };
340
341        // Enrich API-returned models with known metadata (better names, accurate limits)
342        // AND per-model premium request costs from GitHub Copilot pricing.
343        // Source: https://docs.github.com/en/copilot/concepts/billing/copilot-requests
344        //
345        // Cost model: Premium requests at $0.04/request overflow rate.
346        // We convert multiplier to approximate $/M tokens using ~4K tokens/request avg.
347        // Formula: multiplier * $0.04 / 4K tokens * 1M = multiplier * $10/M tokens.
348        //
349        // Premium request multipliers (Feb 2026):
350        //   0x (included): GPT-4.1, GPT-4o, GPT-5 mini, Raptor mini
351        //   0.25x: Grok Code Fast 1
352        //   0.33x: Claude Haiku 4.5, Gemini 3 Flash, GPT-5.1-Codex-Mini
353        //   1x: Claude Sonnet 4/4.5, Gemini 2.5/3 Pro, GPT-5, GPT-5.x-Codex variants
354        //   3x: Claude Opus 4.5, Claude Opus 4.6
355        //   10x: Claude Opus 4.1
356        //
357        // Tuple: (display_name, context_window, max_output, premium_multiplier)
358        let known_metadata: std::collections::HashMap<&str, (&str, usize, usize, f64)> = [
359            ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000, 3.0)),
360            ("claude-opus-4.6", ("Claude Opus 4.6", 200_000, 64_000, 3.0)),
361            ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000, 10.0)),
362            (
363                "claude-sonnet-4.5",
364                ("Claude Sonnet 4.5", 200_000, 64_000, 1.0),
365            ),
366            ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000, 1.0)),
367            (
368                "claude-haiku-4.5",
369                ("Claude Haiku 4.5", 200_000, 64_000, 0.33),
370            ),
371            ("gpt-5.3-codex", ("GPT-5.3-Codex", 264_000, 64_000, 1.0)),
372            ("gpt-5.2", ("GPT-5.2", 400_000, 128_000, 1.0)),
373            ("gpt-5.2-codex", ("GPT-5.2-Codex", 264_000, 64_000, 1.0)),
374            ("gpt-5.1", ("GPT-5.1", 400_000, 128_000, 1.0)),
375            ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000, 1.0)),
376            (
377                "gpt-5.1-codex-mini",
378                ("GPT-5.1-Codex-Mini", 264_000, 64_000, 0.33),
379            ),
380            (
381                "gpt-5.1-codex-max",
382                ("GPT-5.1-Codex-Max", 264_000, 64_000, 1.0),
383            ),
384            ("gpt-5", ("GPT-5", 400_000, 128_000, 1.0)),
385            ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000, 0.0)),
386            ("gpt-5-codex", ("GPT-5-Codex", 264_000, 64_000, 1.0)),
387            ("gpt-4.1", ("GPT-4.1", 128_000, 32_768, 0.0)),
388            ("gpt-4o", ("GPT-4o", 128_000, 16_384, 0.0)),
389            ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000, 1.0)),
390            (
391                "gemini-3-flash-preview",
392                ("Gemini 3 Flash", 1_000_000, 64_000, 0.33),
393            ),
394            (
395                "gemini-3-pro-preview",
396                ("Gemini 3 Pro", 1_000_000, 64_000, 1.0),
397            ),
398            (
399                "grok-code-fast-1",
400                ("Grok Code Fast 1", 128_000, 32_768, 0.25),
401            ),
402        ]
403        .into_iter()
404        .collect();
405
406        // Apply known metadata to enrich API models that had sparse info,
407        // and set per-model premium request costs.
408        for model in &mut models {
409            if let Some((name, ctx, max_out, premium_mult)) = known_metadata.get(model.id.as_str())
410            {
411                if model.name == model.id {
412                    model.name = name.to_string();
413                }
414                if model.context_window == 128_000 {
415                    model.context_window = *ctx;
416                }
417                if model.max_output_tokens == Some(16_384) {
418                    model.max_output_tokens = Some(*max_out);
419                }
420                // Convert premium request multiplier to approximate $/M tokens.
421                // $0.04/request overflow rate, ~4K tokens/request avg = multiplier * $10/M.
422                // Models at 0.0x are included free on paid plans.
423                let approx_cost = premium_mult * 10.0;
424                model.input_cost_per_million = Some(approx_cost);
425                model.output_cost_per_million = Some(approx_cost);
426            } else {
427                // Unknown Copilot model — assume 1x premium request ($10/M approx)
428                if model.input_cost_per_million.is_none() {
429                    model.input_cost_per_million = Some(10.0);
430                }
431                if model.output_cost_per_million.is_none() {
432                    model.output_cost_per_million = Some(10.0);
433                }
434            }
435        }
436
437        // Filter out legacy/deprecated models that clutter the picker
438        // (embedding models, old GPT-3.5/4/4o variants without picker flag)
439        models.retain(|m| {
440            !m.id.starts_with("text-embedding")
441                && m.id != "gpt-3.5-turbo"
442                && m.id != "gpt-3.5-turbo-0613"
443                && m.id != "gpt-4-0613"
444                && m.id != "gpt-4o-2024-05-13"
445                && m.id != "gpt-4o-2024-08-06"
446                && m.id != "gpt-4o-2024-11-20"
447                && m.id != "gpt-4o-mini-2024-07-18"
448                && m.id != "gpt-4-o-preview"
449                && m.id != "gpt-4.1-2025-04-14"
450        });
451
452        // Deduplicate by id (API sometimes returns duplicates)
453        let mut seen = std::collections::HashSet::new();
454        models.retain(|m| seen.insert(m.id.clone()));
455
456        Ok(models)
457    }
458
459    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
460        let messages = Self::convert_messages(&request.messages);
461        let tools = Self::convert_tools(&request.tools);
462        let is_agent = Self::is_agent_initiated(&request.messages);
463        let has_vision = Self::has_vision_input(&request.messages);
464
465        let mut body = json!({
466            "model": request.model,
467            "messages": messages,
468        });
469
470        if !tools.is_empty() {
471            body["tools"] = json!(tools);
472        }
473        if let Some(temp) = request.temperature {
474            body["temperature"] = json!(temp);
475        }
476        if let Some(top_p) = request.top_p {
477            body["top_p"] = json!(top_p);
478        }
479        if let Some(max) = request.max_tokens {
480            body["max_tokens"] = json!(max);
481        }
482        if !request.stop.is_empty() {
483            body["stop"] = json!(request.stop);
484        }
485
486        let mut req = self
487            .client
488            .post(format!("{}/chat/completions", self.base_url))
489            .header("Authorization", format!("Bearer {}", self.token))
490            .header("Content-Type", "application/json")
491            .header("Openai-Intent", "conversation-edits")
492            .header("User-Agent", Self::user_agent())
493            .header("X-Initiator", if is_agent { "agent" } else { "user" });
494
495        if has_vision {
496            req = req.header("Copilot-Vision-Request", "true");
497        }
498
499        let response = req
500            .json(&body)
501            .send()
502            .await
503            .context("Failed to send Copilot request")?;
504
505        let status = response.status();
506        let text = response
507            .text()
508            .await
509            .context("Failed to read Copilot response")?;
510
511        if !status.is_success() {
512            if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
513                let message = err
514                    .error
515                    .and_then(|detail| {
516                        detail.message.map(|msg| {
517                            if let Some(code) = detail.code {
518                                format!("{} ({})", msg, code)
519                            } else {
520                                msg
521                            }
522                        })
523                    })
524                    .or(err.message)
525                    .unwrap_or_else(|| "Unknown Copilot API error".to_string());
526                anyhow::bail!("Copilot API error: {}", message);
527            }
528            anyhow::bail!("Copilot API error: {} {}", status, text);
529        }
530
531        let response: CopilotResponse = serde_json::from_str(&text).context(format!(
532            "Failed to parse Copilot response: {}",
533            &text[..text.len().min(200)]
534        ))?;
535
536        let choice = response
537            .choices
538            .first()
539            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
540
541        let mut content = Vec::new();
542        let mut has_tool_calls = false;
543
544        if let Some(text) = &choice.message.content {
545            if !text.is_empty() {
546                content.push(ContentPart::Text { text: text.clone() });
547            }
548        }
549
550        if let Some(tool_calls) = &choice.message.tool_calls {
551            has_tool_calls = !tool_calls.is_empty();
552            for tc in tool_calls {
553                content.push(ContentPart::ToolCall {
554                    id: tc.id.clone(),
555                    name: tc.function.name.clone(),
556                    arguments: tc.function.arguments.clone(),
557                });
558            }
559        }
560
561        let finish_reason = if has_tool_calls {
562            FinishReason::ToolCalls
563        } else {
564            match choice.finish_reason.as_deref() {
565                Some("stop") => FinishReason::Stop,
566                Some("length") => FinishReason::Length,
567                Some("tool_calls") => FinishReason::ToolCalls,
568                Some("content_filter") => FinishReason::ContentFilter,
569                _ => FinishReason::Stop,
570            }
571        };
572
573        Ok(CompletionResponse {
574            message: Message {
575                role: Role::Assistant,
576                content,
577            },
578            usage: Usage {
579                prompt_tokens: response
580                    .usage
581                    .as_ref()
582                    .map(|u| u.prompt_tokens)
583                    .unwrap_or(0),
584                completion_tokens: response
585                    .usage
586                    .as_ref()
587                    .map(|u| u.completion_tokens)
588                    .unwrap_or(0),
589                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
590                ..Default::default()
591            },
592            finish_reason,
593        })
594    }
595
596    async fn complete_stream(
597        &self,
598        request: CompletionRequest,
599    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
600        // For now, keep behavior aligned with other non-streaming providers.
601        let response = self.complete(request).await?;
602        let text = response
603            .message
604            .content
605            .iter()
606            .filter_map(|p| match p {
607                ContentPart::Text { text } => Some(text.clone()),
608                _ => None,
609            })
610            .collect::<Vec<_>>()
611            .join("");
612
613        Ok(Box::pin(futures::stream::once(async move {
614            StreamChunk::Text(text)
615        })))
616    }
617}
618
/// Reduces a user-supplied enterprise URL to its bare host part.
///
/// Surrounding whitespace is trimmed first; then every leading `https://`,
/// every leading `http://` (in that order) and every trailing `/` is removed.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let mut rest = input.trim();

    while let Some(stripped) = rest.strip_prefix("https://") {
        rest = stripped;
    }
    while let Some(stripped) = rest.strip_prefix("http://") {
        rest = stripped;
    }
    while let Some(stripped) = rest.strip_suffix('/') {
        rest = stripped;
    }

    rest.to_owned()
}
627
628pub fn enterprise_base_url(enterprise_url: &str) -> String {
629    format!(
630        "https://copilot-api.{}",
631        normalize_enterprise_domain(enterprise_url)
632    )
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme prefixes and a trailing slash are removed; a bare domain is unchanged.
    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let cases = [
            ("https://company.ghe.com/", "company.ghe.com"),
            ("http://company.ghe.com", "company.ghe.com"),
            ("company.ghe.com", "company.ghe.com"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(normalize_enterprise_domain(input), expected);
        }
    }

    /// The enterprise API root lives on the `copilot-api.` subdomain of the host.
    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let actual = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(actual, "https://copilot-api.company.ghe.com");
    }
}