// codetether_agent/provider/copilot.rs
1//! GitHub Copilot provider implementation using raw HTTP.
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Public Copilot API root used by `CopilotProvider::new`.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name reported for the public Copilot endpoint.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name reported for GitHub Enterprise deployments.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
/// GitHub Copilot chat-completions provider speaking the OpenAI-compatible
/// HTTP API, for both the public endpoint and enterprise deployments.
pub struct CopilotProvider {
    // Shared HTTP client reused across requests.
    client: Client,
    // Bearer token sent in the `Authorization` header on every call.
    token: String,
    // API root with any trailing slash stripped (see `with_base_url`).
    base_url: String,
    // Name reported via `Provider::name`: "github-copilot" or
    // "github-copilot-enterprise".
    provider_name: String,
}
23
24impl CopilotProvider {
25    pub fn new(token: String) -> Result<Self> {
26        Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27    }
28
29    pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30        let base_url = enterprise_base_url(&enterprise_url);
31        Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32    }
33
34    pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35        Ok(Self {
36            client: Client::new(),
37            token,
38            base_url: base_url.trim_end_matches('/').to_string(),
39            provider_name: provider_name.to_string(),
40        })
41    }
42
43    fn user_agent() -> String {
44        format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45    }
46
47    fn convert_messages(messages: &[Message]) -> Vec<Value> {
48        messages
49            .iter()
50            .map(|msg| {
51                let role = match msg.role {
52                    Role::System => "system",
53                    Role::User => "user",
54                    Role::Assistant => "assistant",
55                    Role::Tool => "tool",
56                };
57
58                match msg.role {
59                    Role::Tool => {
60                        if let Some(ContentPart::ToolResult {
61                            tool_call_id,
62                            content,
63                        }) = msg.content.first()
64                        {
65                            json!({
66                                "role": "tool",
67                                "tool_call_id": tool_call_id,
68                                "content": content
69                            })
70                        } else {
71                            json!({ "role": role, "content": "" })
72                        }
73                    }
74                    Role::Assistant => {
75                        let text: String = msg
76                            .content
77                            .iter()
78                            .filter_map(|p| match p {
79                                ContentPart::Text { text } => Some(text.clone()),
80                                _ => None,
81                            })
82                            .collect::<Vec<_>>()
83                            .join("");
84
85                        let tool_calls: Vec<Value> = msg
86                            .content
87                            .iter()
88                            .filter_map(|p| match p {
89                                ContentPart::ToolCall {
90                                    id,
91                                    name,
92                                    arguments,
93                                } => Some(json!({
94                                    "id": id,
95                                    "type": "function",
96                                    "function": {
97                                        "name": name,
98                                        "arguments": arguments
99                                    }
100                                })),
101                                _ => None,
102                            })
103                            .collect();
104
105                        if tool_calls.is_empty() {
106                            json!({ "role": "assistant", "content": text })
107                        } else {
108                            json!({
109                                "role": "assistant",
110                                "content": if text.is_empty() { "".to_string() } else { text },
111                                "tool_calls": tool_calls
112                            })
113                        }
114                    }
115                    _ => {
116                        let text: String = msg
117                            .content
118                            .iter()
119                            .filter_map(|p| match p {
120                                ContentPart::Text { text } => Some(text.clone()),
121                                _ => None,
122                            })
123                            .collect::<Vec<_>>()
124                            .join("\n");
125                        json!({ "role": role, "content": text })
126                    }
127                }
128            })
129            .collect()
130    }
131
132    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133        tools
134            .iter()
135            .map(|t| {
136                json!({
137                    "type": "function",
138                    "function": {
139                        "name": t.name,
140                        "description": t.description,
141                        "parameters": t.parameters
142                    }
143                })
144            })
145            .collect()
146    }
147
148    fn is_agent_initiated(messages: &[Message]) -> bool {
149        messages
150            .iter()
151            .rev()
152            .find(|msg| msg.role != Role::System)
153            .map(|msg| msg.role != Role::User)
154            .unwrap_or(false)
155    }
156
157    fn has_vision_input(messages: &[Message]) -> bool {
158        messages.iter().any(|msg| {
159            msg.content
160                .iter()
161                .any(|part| matches!(part, ContentPart::Image { .. }))
162        })
163    }
164}
165
/// Top-level chat-completions response body.
#[derive(Debug, Deserialize)]
struct CopilotResponse {
    choices: Vec<CopilotChoice>,
    // Token accounting; absent on some responses, hence `Option`.
    #[serde(default)]
    usage: Option<CopilotUsage>,
}
172
/// One completion candidate within a response.
#[derive(Debug, Deserialize)]
struct CopilotChoice {
    message: CopilotMessage,
    // e.g. "stop", "length", "tool_calls", "content_filter" (see `complete`).
    #[serde(default)]
    finish_reason: Option<String>,
}
179
/// Assistant message payload: text content and/or requested tool calls,
/// either of which may be missing.
#[derive(Debug, Deserialize)]
struct CopilotMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<CopilotToolCall>>,
}
187
/// A single tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct CopilotToolCall {
    id: String,
    // Presumably always "function" in the OpenAI-compatible schema;
    // deserialized for shape-compatibility but never read.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    function: CopilotFunction,
}
196
/// Function name plus its JSON-encoded arguments string.
#[derive(Debug, Deserialize)]
struct CopilotFunction {
    name: String,
    arguments: String,
}
202
/// Token accounting; each counter defaults to 0 when the field is absent.
#[derive(Debug, Deserialize)]
struct CopilotUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
212
/// Error body; the API returns either a nested `error` detail object or a
/// flat top-level `message` (see `complete`'s error path).
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    error: Option<CopilotErrorDetail>,
    message: Option<String>,
}
218
/// Nested error detail: human-readable message plus an optional error code.
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    message: Option<String>,
    code: Option<String>,
}
224
/// Wrapper for the `/models` listing endpoint.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    data: Vec<CopilotModelInfo>,
}
229
/// One entry from the `/models` listing.
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    id: String,
    #[serde(default)]
    name: Option<String>,
    // Deserialized but currently unread by `list_models`.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    // Deserialized but currently unread by `list_models`.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}
242
/// Per-model policy blob; only the state string is captured.
/// NOTE(review): the value set is undocumented here — presumably
/// "enabled"/"disabled"; confirm against the API before relying on it.
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    #[serde(default)]
    state: Option<String>,
}
248
/// Capability container: token limits and feature-support flags.
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}
256
/// Token limits reported per model; `list_models` substitutes defaults
/// (128k context / 16,384 output) when these are absent.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    #[serde(default)]
    max_output_tokens: Option<usize>,
}
264
/// Feature-support flags; `list_models` defaults vision to false and
/// tools/streaming to true when these are absent.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    #[serde(default)]
    tool_calls: Option<bool>,
    #[serde(default)]
    vision: Option<bool>,
    #[serde(default)]
    streaming: Option<bool>,
}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277    fn name(&self) -> &str {
278        &self.provider_name
279    }
280
281    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282        let response = self
283            .client
284            .get(format!("{}/models", self.base_url))
285            .header("Authorization", format!("Bearer {}", self.token))
286            .header("Openai-Intent", "conversation-edits")
287            .header("User-Agent", Self::user_agent())
288            .send()
289            .await
290            .context("Failed to fetch Copilot models")?;
291
292        let mut models: Vec<ModelInfo> = if response.status().is_success() {
293            let parsed: CopilotModelsResponse = response
294                .json()
295                .await
296                .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298            parsed
299                .data
300                .into_iter()
301                .map(|model| {
302                    let caps = model.capabilities.as_ref();
303                    let limits = caps.and_then(|c| c.limits.as_ref());
304                    let supports = caps.and_then(|c| c.supports.as_ref());
305
306                    ModelInfo {
307                        id: model.id.clone(),
308                        name: model.name.unwrap_or_else(|| model.id.clone()),
309                        provider: self.provider_name.clone(),
310                        context_window: limits
311                            .and_then(|l| l.max_context_window_tokens)
312                            .unwrap_or(128_000),
313                        max_output_tokens: limits
314                            .and_then(|l| l.max_output_tokens)
315                            .or(Some(16_384)),
316                        supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
317                        supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
318                        supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
319                        input_cost_per_million: None, // Set below per-model
320                        output_cost_per_million: None,
321                    }
322                })
323                .collect()
324        } else {
325            Vec::new()
326        };
327
328        // Enrich API-returned models with known metadata (better names, accurate limits)
329        // AND per-model premium request costs from GitHub Copilot pricing.
330        // Source: https://docs.github.com/en/copilot/concepts/billing/copilot-requests
331        //
332        // Cost model: Premium requests at $0.04/request overflow rate.
333        // We convert multiplier to approximate $/M tokens using ~4K tokens/request avg.
334        // Formula: multiplier * $0.04 / 4K tokens * 1M = multiplier * $10/M tokens.
335        //
336        // Premium request multipliers (Feb 2026):
337        //   0x (included): GPT-4.1, GPT-4o, GPT-5 mini, Raptor mini
338        //   0.25x: Grok Code Fast 1
339        //   0.33x: Claude Haiku 4.5, Gemini 3 Flash, GPT-5.1-Codex-Mini
340        //   1x: Claude Sonnet 4/4.5, Gemini 2.5/3 Pro, GPT-5, GPT-5.x-Codex variants
341        //   3x: Claude Opus 4.5, Claude Opus 4.6
342        //   10x: Claude Opus 4.1
343        //
344        // Tuple: (display_name, context_window, max_output, premium_multiplier)
345        let known_metadata: std::collections::HashMap<&str, (&str, usize, usize, f64)> = [
346            ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000, 3.0)),
347            ("claude-opus-4.6", ("Claude Opus 4.6", 200_000, 64_000, 3.0)),
348            ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000, 10.0)),
349            (
350                "claude-sonnet-4.5",
351                ("Claude Sonnet 4.5", 200_000, 64_000, 1.0),
352            ),
353            ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000, 1.0)),
354            (
355                "claude-haiku-4.5",
356                ("Claude Haiku 4.5", 200_000, 64_000, 0.33),
357            ),
358            ("gpt-5.2", ("GPT-5.2", 400_000, 128_000, 1.0)),
359            ("gpt-5.1", ("GPT-5.1", 400_000, 128_000, 1.0)),
360            ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000, 1.0)),
361            (
362                "gpt-5.1-codex-mini",
363                ("GPT-5.1-Codex-Mini", 264_000, 64_000, 0.33),
364            ),
365            (
366                "gpt-5.1-codex-max",
367                ("GPT-5.1-Codex-Max", 264_000, 64_000, 1.0),
368            ),
369            ("gpt-5", ("GPT-5", 400_000, 128_000, 1.0)),
370            ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000, 0.0)),
371            ("gpt-5-codex", ("GPT-5-Codex", 264_000, 64_000, 1.0)),
372            ("gpt-4.1", ("GPT-4.1", 128_000, 32_768, 0.0)),
373            ("gpt-4o", ("GPT-4o", 128_000, 16_384, 0.0)),
374            ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000, 1.0)),
375            (
376                "gemini-3-flash-preview",
377                ("Gemini 3 Flash", 1_000_000, 64_000, 0.33),
378            ),
379            (
380                "gemini-3-pro-preview",
381                ("Gemini 3 Pro", 1_000_000, 64_000, 1.0),
382            ),
383            (
384                "grok-code-fast-1",
385                ("Grok Code Fast 1", 128_000, 32_768, 0.25),
386            ),
387        ]
388        .into_iter()
389        .collect();
390
391        // Apply known metadata to enrich API models that had sparse info,
392        // and set per-model premium request costs.
393        for model in &mut models {
394            if let Some((name, ctx, max_out, premium_mult)) = known_metadata.get(model.id.as_str())
395            {
396                if model.name == model.id {
397                    model.name = name.to_string();
398                }
399                if model.context_window == 128_000 {
400                    model.context_window = *ctx;
401                }
402                if model.max_output_tokens == Some(16_384) {
403                    model.max_output_tokens = Some(*max_out);
404                }
405                // Convert premium request multiplier to approximate $/M tokens.
406                // $0.04/request overflow rate, ~4K tokens/request avg = multiplier * $10/M.
407                // Models at 0.0x are included free on paid plans.
408                let approx_cost = premium_mult * 10.0;
409                model.input_cost_per_million = Some(approx_cost);
410                model.output_cost_per_million = Some(approx_cost);
411            } else {
412                // Unknown Copilot model — assume 1x premium request ($10/M approx)
413                if model.input_cost_per_million.is_none() {
414                    model.input_cost_per_million = Some(10.0);
415                }
416                if model.output_cost_per_million.is_none() {
417                    model.output_cost_per_million = Some(10.0);
418                }
419            }
420        }
421
422        // Filter out legacy/deprecated models that clutter the picker
423        // (embedding models, old GPT-3.5/4/4o variants without picker flag)
424        models.retain(|m| {
425            !m.id.starts_with("text-embedding")
426                && m.id != "gpt-3.5-turbo"
427                && m.id != "gpt-3.5-turbo-0613"
428                && m.id != "gpt-4-0613"
429                && m.id != "gpt-4o-2024-05-13"
430                && m.id != "gpt-4o-2024-08-06"
431                && m.id != "gpt-4o-2024-11-20"
432                && m.id != "gpt-4o-mini-2024-07-18"
433                && m.id != "gpt-4-o-preview"
434                && m.id != "gpt-4.1-2025-04-14"
435        });
436
437        // Deduplicate by id (API sometimes returns duplicates)
438        let mut seen = std::collections::HashSet::new();
439        models.retain(|m| seen.insert(m.id.clone()));
440
441        Ok(models)
442    }
443
444    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
445        let messages = Self::convert_messages(&request.messages);
446        let tools = Self::convert_tools(&request.tools);
447        let is_agent = Self::is_agent_initiated(&request.messages);
448        let has_vision = Self::has_vision_input(&request.messages);
449
450        let mut body = json!({
451            "model": request.model,
452            "messages": messages,
453        });
454
455        if !tools.is_empty() {
456            body["tools"] = json!(tools);
457        }
458        if let Some(temp) = request.temperature {
459            body["temperature"] = json!(temp);
460        }
461        if let Some(top_p) = request.top_p {
462            body["top_p"] = json!(top_p);
463        }
464        if let Some(max) = request.max_tokens {
465            body["max_tokens"] = json!(max);
466        }
467        if !request.stop.is_empty() {
468            body["stop"] = json!(request.stop);
469        }
470
471        let mut req = self
472            .client
473            .post(format!("{}/chat/completions", self.base_url))
474            .header("Authorization", format!("Bearer {}", self.token))
475            .header("Content-Type", "application/json")
476            .header("Openai-Intent", "conversation-edits")
477            .header("User-Agent", Self::user_agent())
478            .header("X-Initiator", if is_agent { "agent" } else { "user" });
479
480        if has_vision {
481            req = req.header("Copilot-Vision-Request", "true");
482        }
483
484        let response = req
485            .json(&body)
486            .send()
487            .await
488            .context("Failed to send Copilot request")?;
489
490        let status = response.status();
491        let text = response
492            .text()
493            .await
494            .context("Failed to read Copilot response")?;
495
496        if !status.is_success() {
497            if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
498                let message = err
499                    .error
500                    .and_then(|detail| {
501                        detail.message.map(|msg| {
502                            if let Some(code) = detail.code {
503                                format!("{} ({})", msg, code)
504                            } else {
505                                msg
506                            }
507                        })
508                    })
509                    .or(err.message)
510                    .unwrap_or_else(|| "Unknown Copilot API error".to_string());
511                anyhow::bail!("Copilot API error: {}", message);
512            }
513            anyhow::bail!("Copilot API error: {} {}", status, text);
514        }
515
516        let response: CopilotResponse = serde_json::from_str(&text).context(format!(
517            "Failed to parse Copilot response: {}",
518            &text[..text.len().min(200)]
519        ))?;
520
521        let choice = response
522            .choices
523            .first()
524            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
525
526        let mut content = Vec::new();
527        let mut has_tool_calls = false;
528
529        if let Some(text) = &choice.message.content {
530            if !text.is_empty() {
531                content.push(ContentPart::Text { text: text.clone() });
532            }
533        }
534
535        if let Some(tool_calls) = &choice.message.tool_calls {
536            has_tool_calls = !tool_calls.is_empty();
537            for tc in tool_calls {
538                content.push(ContentPart::ToolCall {
539                    id: tc.id.clone(),
540                    name: tc.function.name.clone(),
541                    arguments: tc.function.arguments.clone(),
542                });
543            }
544        }
545
546        let finish_reason = if has_tool_calls {
547            FinishReason::ToolCalls
548        } else {
549            match choice.finish_reason.as_deref() {
550                Some("stop") => FinishReason::Stop,
551                Some("length") => FinishReason::Length,
552                Some("tool_calls") => FinishReason::ToolCalls,
553                Some("content_filter") => FinishReason::ContentFilter,
554                _ => FinishReason::Stop,
555            }
556        };
557
558        Ok(CompletionResponse {
559            message: Message {
560                role: Role::Assistant,
561                content,
562            },
563            usage: Usage {
564                prompt_tokens: response
565                    .usage
566                    .as_ref()
567                    .map(|u| u.prompt_tokens)
568                    .unwrap_or(0),
569                completion_tokens: response
570                    .usage
571                    .as_ref()
572                    .map(|u| u.completion_tokens)
573                    .unwrap_or(0),
574                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
575                ..Default::default()
576            },
577            finish_reason,
578        })
579    }
580
581    async fn complete_stream(
582        &self,
583        request: CompletionRequest,
584    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
585        // For now, keep behavior aligned with other non-streaming providers.
586        let response = self.complete(request).await?;
587        let text = response
588            .message
589            .content
590            .iter()
591            .filter_map(|p| match p {
592                ContentPart::Text { text } => Some(text.clone()),
593                _ => None,
594            })
595            .collect::<Vec<_>>()
596            .join("");
597
598        Ok(Box::pin(futures::stream::once(async move {
599            StreamChunk::Text(text)
600        })))
601    }
602}
603
/// Reduce an enterprise host string to its bare domain: surrounding
/// whitespace, a leading `https://`/`http://` scheme, and trailing
/// slashes are all stripped.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let trimmed = input.trim();
    let without_scheme = trimmed
        .trim_start_matches("https://")
        .trim_start_matches("http://");
    without_scheme.trim_end_matches('/').to_string()
}
612
613pub fn enterprise_base_url(enterprise_url: &str) -> String {
614    format!(
615        "https://copilot-api.{}",
616        normalize_enterprise_domain(enterprise_url)
617    )
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        // Table-driven: each input must normalize to the bare domain.
        let cases = [
            ("https://company.ghe.com/", "company.ghe.com"),
            ("http://company.ghe.com", "company.ghe.com"),
            ("company.ghe.com", "company.ghe.com"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_enterprise_domain(input), expected);
        }
    }

    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }
}