Skip to main content

codetether_agent/provider/
copilot.rs

1//! GitHub Copilot provider implementation using raw HTTP.
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Provider identifier reported for the public GitHub Copilot service.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider identifier reported for GitHub Enterprise Copilot deployments.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
/// API root of the public Copilot service; enterprise deployments use their own host.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
16
17pub struct CopilotProvider {
18    client: Client,
19    token: String,
20    base_url: String,
21    provider_name: String,
22}
23
24impl CopilotProvider {
25    pub fn new(token: String) -> Result<Self> {
26        Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27    }
28
29    pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30        let base_url = enterprise_base_url(&enterprise_url);
31        Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32    }
33
34    pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35        Ok(Self {
36            client: Client::new(),
37            token,
38            base_url: base_url.trim_end_matches('/').to_string(),
39            provider_name: provider_name.to_string(),
40        })
41    }
42
43    fn user_agent() -> String {
44        format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45    }
46
47    fn convert_messages(messages: &[Message]) -> Vec<Value> {
48        messages
49            .iter()
50            .map(|msg| {
51                let role = match msg.role {
52                    Role::System => "system",
53                    Role::User => "user",
54                    Role::Assistant => "assistant",
55                    Role::Tool => "tool",
56                };
57
58                match msg.role {
59                    Role::Tool => {
60                        if let Some(ContentPart::ToolResult {
61                            tool_call_id,
62                            content,
63                        }) = msg.content.first()
64                        {
65                            json!({
66                                "role": "tool",
67                                "tool_call_id": tool_call_id,
68                                "content": content
69                            })
70                        } else {
71                            json!({ "role": role, "content": "" })
72                        }
73                    }
74                    Role::Assistant => {
75                        let text: String = msg
76                            .content
77                            .iter()
78                            .filter_map(|p| match p {
79                                ContentPart::Text { text } => Some(text.clone()),
80                                _ => None,
81                            })
82                            .collect::<Vec<_>>()
83                            .join("");
84
85                        let tool_calls: Vec<Value> = msg
86                            .content
87                            .iter()
88                            .filter_map(|p| match p {
89                                ContentPart::ToolCall {
90                                    id,
91                                    name,
92                                    arguments,
93                                } => Some(json!({
94                                    "id": id,
95                                    "type": "function",
96                                    "function": {
97                                        "name": name,
98                                        "arguments": arguments
99                                    }
100                                })),
101                                _ => None,
102                            })
103                            .collect();
104
105                        if tool_calls.is_empty() {
106                            json!({ "role": "assistant", "content": text })
107                        } else {
108                            json!({
109                                "role": "assistant",
110                                "content": if text.is_empty() { "".to_string() } else { text },
111                                "tool_calls": tool_calls
112                            })
113                        }
114                    }
115                    _ => {
116                        let text: String = msg
117                            .content
118                            .iter()
119                            .filter_map(|p| match p {
120                                ContentPart::Text { text } => Some(text.clone()),
121                                _ => None,
122                            })
123                            .collect::<Vec<_>>()
124                            .join("\n");
125                        json!({ "role": role, "content": text })
126                    }
127                }
128            })
129            .collect()
130    }
131
132    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133        tools
134            .iter()
135            .map(|t| {
136                json!({
137                    "type": "function",
138                    "function": {
139                        "name": t.name,
140                        "description": t.description,
141                        "parameters": t.parameters
142                    }
143                })
144            })
145            .collect()
146    }
147
148    fn is_agent_initiated(messages: &[Message]) -> bool {
149        messages
150            .iter()
151            .rev()
152            .find(|msg| msg.role != Role::System)
153            .map(|msg| msg.role != Role::User)
154            .unwrap_or(false)
155    }
156
157    fn has_vision_input(messages: &[Message]) -> bool {
158        messages.iter().any(|msg| {
159            msg.content
160                .iter()
161                .any(|part| matches!(part, ContentPart::Image { .. }))
162        })
163    }
164}
165
166#[derive(Debug, Deserialize)]
167struct CopilotResponse {
168    choices: Vec<CopilotChoice>,
169    #[serde(default)]
170    usage: Option<CopilotUsage>,
171}
172
173#[derive(Debug, Deserialize)]
174struct CopilotChoice {
175    message: CopilotMessage,
176    #[serde(default)]
177    finish_reason: Option<String>,
178}
179
180#[derive(Debug, Deserialize)]
181struct CopilotMessage {
182    #[serde(default)]
183    content: Option<String>,
184    #[serde(default)]
185    tool_calls: Option<Vec<CopilotToolCall>>,
186}
187
188#[derive(Debug, Deserialize)]
189struct CopilotToolCall {
190    id: String,
191    #[serde(rename = "type")]
192    #[allow(dead_code)]
193    call_type: String,
194    function: CopilotFunction,
195}
196
197#[derive(Debug, Deserialize)]
198struct CopilotFunction {
199    name: String,
200    arguments: String,
201}
202
203#[derive(Debug, Deserialize)]
204struct CopilotUsage {
205    #[serde(default)]
206    prompt_tokens: usize,
207    #[serde(default)]
208    completion_tokens: usize,
209    #[serde(default)]
210    total_tokens: usize,
211}
212
213#[derive(Debug, Deserialize)]
214struct CopilotErrorResponse {
215    error: Option<CopilotErrorDetail>,
216    message: Option<String>,
217}
218
219#[derive(Debug, Deserialize)]
220struct CopilotErrorDetail {
221    message: Option<String>,
222    code: Option<String>,
223}
224
225#[derive(Debug, Deserialize)]
226struct CopilotModelsResponse {
227    data: Vec<CopilotModelInfo>,
228}
229
230#[derive(Debug, Deserialize)]
231struct CopilotModelInfo {
232    id: String,
233    #[serde(default)]
234    name: Option<String>,
235    #[serde(default)]
236    model_picker_enabled: Option<bool>,
237    #[serde(default)]
238    policy: Option<CopilotModelPolicy>,
239    #[serde(default)]
240    capabilities: Option<CopilotModelCapabilities>,
241}
242
243#[derive(Debug, Deserialize)]
244struct CopilotModelPolicy {
245    #[serde(default)]
246    state: Option<String>,
247}
248
249#[derive(Debug, Deserialize)]
250struct CopilotModelCapabilities {
251    #[serde(default)]
252    limits: Option<CopilotModelLimits>,
253    #[serde(default)]
254    supports: Option<CopilotModelSupports>,
255}
256
257#[derive(Debug, Deserialize)]
258struct CopilotModelLimits {
259    #[serde(default)]
260    max_context_window_tokens: Option<usize>,
261    #[serde(default)]
262    max_output_tokens: Option<usize>,
263}
264
265#[derive(Debug, Deserialize)]
266struct CopilotModelSupports {
267    #[serde(default)]
268    tool_calls: Option<bool>,
269    #[serde(default)]
270    vision: Option<bool>,
271    #[serde(default)]
272    streaming: Option<bool>,
273}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277    fn name(&self) -> &str {
278        &self.provider_name
279    }
280
281    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282        let response = self
283            .client
284            .get(format!("{}/models", self.base_url))
285            .header("Authorization", format!("Bearer {}", self.token))
286            .header("Openai-Intent", "conversation-edits")
287            .header("User-Agent", Self::user_agent())
288            .send()
289            .await
290            .context("Failed to fetch Copilot models")?;
291
292        let mut models: Vec<ModelInfo> = if response.status().is_success() {
293            let parsed: CopilotModelsResponse = response
294                .json()
295                .await
296                .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298            parsed
299                .data
300                .into_iter()
301                .map(|model| {
302                    let caps = model.capabilities.as_ref();
303                    let limits = caps.and_then(|c| c.limits.as_ref());
304                    let supports = caps.and_then(|c| c.supports.as_ref());
305
306                    ModelInfo {
307                        id: model.id.clone(),
308                        name: model.name.unwrap_or_else(|| model.id.clone()),
309                        provider: self.provider_name.clone(),
310                        context_window: limits
311                            .and_then(|l| l.max_context_window_tokens)
312                            .unwrap_or(128_000),
313                        max_output_tokens: limits
314                            .and_then(|l| l.max_output_tokens)
315                            .or(Some(16_384)),
316                        supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
317                        supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
318                        supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
319                        input_cost_per_million: Some(0.0),
320                        output_cost_per_million: Some(0.0),
321                    }
322                })
323                .collect()
324        } else {
325            Vec::new()
326        };
327
328        // Enrich API-returned models with known metadata (better names, accurate limits).
329        // We do NOT inject models the API doesn't list — those aren't usable yet.
330        // Source: https://docs.github.com/en/copilot/using-github-copilot/ai-models
331        // NOTE: Model IDs use dots for versions, same as the /models API returns
332        // (e.g. claude-opus-4.5, gpt-5.1, gpt-4.1).
333        let known_metadata: std::collections::HashMap<&str, (&str, usize, usize)> = [
334            ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000)),
335            ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000)),
336            ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 200_000, 64_000)),
337            ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000)),
338            ("claude-haiku-4.5", ("Claude Haiku 4.5", 200_000, 64_000)),
339            ("gpt-5.2", ("GPT-5.2", 400_000, 128_000)),
340            ("gpt-5.1", ("GPT-5.1", 400_000, 128_000)),
341            ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000)),
342            (
343                "gpt-5.1-codex-mini",
344                ("GPT-5.1-Codex-Mini", 264_000, 64_000),
345            ),
346            ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 264_000, 64_000)),
347            ("gpt-5", ("GPT-5", 400_000, 128_000)),
348            ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000)),
349            ("gpt-4.1", ("GPT-4.1", 128_000, 32_768)),
350            ("gpt-4o", ("GPT-4o", 128_000, 16_384)),
351            ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000)),
352            ("grok-code-fast-1", ("Grok Code Fast 1", 128_000, 32_768)),
353        ]
354        .into_iter()
355        .collect();
356
357        // Apply known metadata to enrich API models that had sparse info
358        for model in &mut models {
359            if let Some((name, ctx, max_out)) = known_metadata.get(model.id.as_str()) {
360                if model.name == model.id {
361                    model.name = name.to_string();
362                }
363                if model.context_window == 128_000 {
364                    model.context_window = *ctx;
365                }
366                if model.max_output_tokens == Some(16_384) {
367                    model.max_output_tokens = Some(*max_out);
368                }
369            }
370        }
371
372        // Filter out legacy/deprecated models that clutter the picker
373        // (embedding models, old GPT-3.5/4/4o variants without picker flag)
374        models.retain(|m| {
375            !m.id.starts_with("text-embedding")
376                && m.id != "gpt-3.5-turbo"
377                && m.id != "gpt-3.5-turbo-0613"
378                && m.id != "gpt-4-0613"
379                && m.id != "gpt-4o-2024-05-13"
380                && m.id != "gpt-4o-2024-08-06"
381                && m.id != "gpt-4o-2024-11-20"
382                && m.id != "gpt-4o-mini-2024-07-18"
383                && m.id != "gpt-4-o-preview"
384                && m.id != "gpt-4.1-2025-04-14"
385        });
386
387        // Deduplicate by id (API sometimes returns duplicates)
388        let mut seen = std::collections::HashSet::new();
389        models.retain(|m| seen.insert(m.id.clone()));
390
391        Ok(models)
392    }
393
394    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
395        let messages = Self::convert_messages(&request.messages);
396        let tools = Self::convert_tools(&request.tools);
397        let is_agent = Self::is_agent_initiated(&request.messages);
398        let has_vision = Self::has_vision_input(&request.messages);
399
400        let mut body = json!({
401            "model": request.model,
402            "messages": messages,
403        });
404
405        if !tools.is_empty() {
406            body["tools"] = json!(tools);
407        }
408        if let Some(temp) = request.temperature {
409            body["temperature"] = json!(temp);
410        }
411        if let Some(top_p) = request.top_p {
412            body["top_p"] = json!(top_p);
413        }
414        if let Some(max) = request.max_tokens {
415            body["max_tokens"] = json!(max);
416        }
417        if !request.stop.is_empty() {
418            body["stop"] = json!(request.stop);
419        }
420
421        let mut req = self
422            .client
423            .post(format!("{}/chat/completions", self.base_url))
424            .header("Authorization", format!("Bearer {}", self.token))
425            .header("Content-Type", "application/json")
426            .header("Openai-Intent", "conversation-edits")
427            .header("User-Agent", Self::user_agent())
428            .header("X-Initiator", if is_agent { "agent" } else { "user" });
429
430        if has_vision {
431            req = req.header("Copilot-Vision-Request", "true");
432        }
433
434        let response = req
435            .json(&body)
436            .send()
437            .await
438            .context("Failed to send Copilot request")?;
439
440        let status = response.status();
441        let text = response
442            .text()
443            .await
444            .context("Failed to read Copilot response")?;
445
446        if !status.is_success() {
447            if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
448                let message = err
449                    .error
450                    .and_then(|detail| {
451                        detail.message.map(|msg| {
452                            if let Some(code) = detail.code {
453                                format!("{} ({})", msg, code)
454                            } else {
455                                msg
456                            }
457                        })
458                    })
459                    .or(err.message)
460                    .unwrap_or_else(|| "Unknown Copilot API error".to_string());
461                anyhow::bail!("Copilot API error: {}", message);
462            }
463            anyhow::bail!("Copilot API error: {} {}", status, text);
464        }
465
466        let response: CopilotResponse = serde_json::from_str(&text).context(format!(
467            "Failed to parse Copilot response: {}",
468            &text[..text.len().min(200)]
469        ))?;
470
471        let choice = response
472            .choices
473            .first()
474            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
475
476        let mut content = Vec::new();
477        let mut has_tool_calls = false;
478
479        if let Some(text) = &choice.message.content {
480            if !text.is_empty() {
481                content.push(ContentPart::Text { text: text.clone() });
482            }
483        }
484
485        if let Some(tool_calls) = &choice.message.tool_calls {
486            has_tool_calls = !tool_calls.is_empty();
487            for tc in tool_calls {
488                content.push(ContentPart::ToolCall {
489                    id: tc.id.clone(),
490                    name: tc.function.name.clone(),
491                    arguments: tc.function.arguments.clone(),
492                });
493            }
494        }
495
496        let finish_reason = if has_tool_calls {
497            FinishReason::ToolCalls
498        } else {
499            match choice.finish_reason.as_deref() {
500                Some("stop") => FinishReason::Stop,
501                Some("length") => FinishReason::Length,
502                Some("tool_calls") => FinishReason::ToolCalls,
503                Some("content_filter") => FinishReason::ContentFilter,
504                _ => FinishReason::Stop,
505            }
506        };
507
508        Ok(CompletionResponse {
509            message: Message {
510                role: Role::Assistant,
511                content,
512            },
513            usage: Usage {
514                prompt_tokens: response
515                    .usage
516                    .as_ref()
517                    .map(|u| u.prompt_tokens)
518                    .unwrap_or(0),
519                completion_tokens: response
520                    .usage
521                    .as_ref()
522                    .map(|u| u.completion_tokens)
523                    .unwrap_or(0),
524                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
525                ..Default::default()
526            },
527            finish_reason,
528        })
529    }
530
531    async fn complete_stream(
532        &self,
533        request: CompletionRequest,
534    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
535        // For now, keep behavior aligned with other non-streaming providers.
536        let response = self.complete(request).await?;
537        let text = response
538            .message
539            .content
540            .iter()
541            .filter_map(|p| match p {
542                ContentPart::Text { text } => Some(text.clone()),
543                _ => None,
544            })
545            .collect::<Vec<_>>()
546            .join("");
547
548        Ok(Box::pin(futures::stream::once(async move {
549            StreamChunk::Text(text)
550        })))
551    }
552}
553
/// Strips `prefix` from the start of `s`, comparing ASCII case-insensitively.
///
/// Returns `None` when `s` does not start with `prefix`.
fn strip_prefix_ignore_ascii_case<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    // `get` returns `None` instead of panicking if the cut is not on a char boundary.
    let head = s.get(..prefix.len())?;
    head.eq_ignore_ascii_case(prefix)
        .then(|| &s[prefix.len()..])
}

/// Reduces a user-supplied GitHub Enterprise URL to its bare host.
///
/// The input is processed as follows:
/// - surrounding whitespace is trimmed;
/// - a leading `https://` or `http://` is removed, in any letter case;
/// - everything from the first `/` onward (a trailing slash or a path) is dropped.
///
/// For example, `"https://company.ghe.com/"`, `"HTTPS://company.ghe.com/api/v3"` and
/// `"company.ghe.com"` all become `"company.ghe.com"`.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let trimmed = input.trim();
    let without_scheme = ["https://", "http://"]
        .into_iter()
        .find_map(|scheme| strip_prefix_ignore_ascii_case(trimmed, scheme))
        .unwrap_or(trimmed);
    // `split` always yields at least one (possibly empty) segment.
    without_scheme
        .split('/')
        .next()
        .unwrap_or("")
        .trim()
        .to_string()
}
562
563pub fn enterprise_base_url(enterprise_url: &str) -> String {
564    format!(
565        "https://copilot-api.{}",
566        normalize_enterprise_domain(enterprise_url)
567    )
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let cases = [
            ("https://company.ghe.com/", "company.ghe.com"),
            ("http://company.ghe.com", "company.ghe.com"),
            ("company.ghe.com", "company.ghe.com"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_enterprise_domain(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }
}