Skip to main content

codetether_agent/provider/
copilot.rs

1//! GitHub Copilot provider implementation using raw HTTP.
2
3use super::{
4    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5    Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Root URL of the public GitHub Copilot API; used by `CopilotProvider::new`.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";

/// Provider name reported by instances created with `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";

/// Provider name reported by instances created with `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
17pub struct CopilotProvider {
18    client: Client,
19    token: String,
20    base_url: String,
21    provider_name: String,
22}
23
24impl CopilotProvider {
25    pub fn new(token: String) -> Result<Self> {
26        Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27    }
28
29    pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30        let base_url = enterprise_base_url(&enterprise_url);
31        Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32    }
33
34    pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35        Ok(Self {
36            client: Client::new(),
37            token,
38            base_url: base_url.trim_end_matches('/').to_string(),
39            provider_name: provider_name.to_string(),
40        })
41    }
42
43    fn user_agent() -> String {
44        format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45    }
46
47    fn convert_messages(messages: &[Message]) -> Vec<Value> {
48        messages
49            .iter()
50            .map(|msg| {
51                let role = match msg.role {
52                    Role::System => "system",
53                    Role::User => "user",
54                    Role::Assistant => "assistant",
55                    Role::Tool => "tool",
56                };
57
58                match msg.role {
59                    Role::Tool => {
60                        if let Some(ContentPart::ToolResult {
61                            tool_call_id,
62                            content,
63                        }) = msg.content.first()
64                        {
65                            json!({
66                                "role": "tool",
67                                "tool_call_id": tool_call_id,
68                                "content": content
69                            })
70                        } else {
71                            json!({ "role": role, "content": "" })
72                        }
73                    }
74                    Role::Assistant => {
75                        let text: String = msg
76                            .content
77                            .iter()
78                            .filter_map(|p| match p {
79                                ContentPart::Text { text } => Some(text.clone()),
80                                _ => None,
81                            })
82                            .collect::<Vec<_>>()
83                            .join("");
84
85                        let tool_calls: Vec<Value> = msg
86                            .content
87                            .iter()
88                            .filter_map(|p| match p {
89                                ContentPart::ToolCall {
90                                    id,
91                                    name,
92                                    arguments,
93                                } => Some(json!({
94                                    "id": id,
95                                    "type": "function",
96                                    "function": {
97                                        "name": name,
98                                        "arguments": arguments
99                                    }
100                                })),
101                                _ => None,
102                            })
103                            .collect();
104
105                        if tool_calls.is_empty() {
106                            json!({ "role": "assistant", "content": text })
107                        } else {
108                            json!({
109                                "role": "assistant",
110                                "content": if text.is_empty() { "".to_string() } else { text },
111                                "tool_calls": tool_calls
112                            })
113                        }
114                    }
115                    _ => {
116                        let text: String = msg
117                            .content
118                            .iter()
119                            .filter_map(|p| match p {
120                                ContentPart::Text { text } => Some(text.clone()),
121                                _ => None,
122                            })
123                            .collect::<Vec<_>>()
124                            .join("\n");
125                        json!({ "role": role, "content": text })
126                    }
127                }
128            })
129            .collect()
130    }
131
132    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133        tools
134            .iter()
135            .map(|t| {
136                json!({
137                    "type": "function",
138                    "function": {
139                        "name": t.name,
140                        "description": t.description,
141                        "parameters": t.parameters
142                    }
143                })
144            })
145            .collect()
146    }
147
148    fn is_agent_initiated(messages: &[Message]) -> bool {
149        messages
150            .iter()
151            .rev()
152            .find(|msg| msg.role != Role::System)
153            .map(|msg| msg.role != Role::User)
154            .unwrap_or(false)
155    }
156
157    fn has_vision_input(messages: &[Message]) -> bool {
158        messages.iter().any(|msg| {
159            msg.content
160                .iter()
161                .any(|part| matches!(part, ContentPart::Image { .. }))
162        })
163    }
164}
165
166#[derive(Debug, Deserialize)]
167struct CopilotResponse {
168    choices: Vec<CopilotChoice>,
169    #[serde(default)]
170    usage: Option<CopilotUsage>,
171}
172
173#[derive(Debug, Deserialize)]
174struct CopilotChoice {
175    message: CopilotMessage,
176    #[serde(default)]
177    finish_reason: Option<String>,
178}
179
180#[derive(Debug, Deserialize)]
181struct CopilotMessage {
182    #[serde(default)]
183    content: Option<String>,
184    #[serde(default)]
185    tool_calls: Option<Vec<CopilotToolCall>>,
186}
187
188#[derive(Debug, Deserialize)]
189struct CopilotToolCall {
190    id: String,
191    #[serde(rename = "type")]
192    #[allow(dead_code)]
193    call_type: String,
194    function: CopilotFunction,
195}
196
197#[derive(Debug, Deserialize)]
198struct CopilotFunction {
199    name: String,
200    arguments: String,
201}
202
203#[derive(Debug, Deserialize)]
204struct CopilotUsage {
205    #[serde(default)]
206    prompt_tokens: usize,
207    #[serde(default)]
208    completion_tokens: usize,
209    #[serde(default)]
210    total_tokens: usize,
211}
212
213#[derive(Debug, Deserialize)]
214struct CopilotErrorResponse {
215    error: Option<CopilotErrorDetail>,
216    message: Option<String>,
217}
218
219#[derive(Debug, Deserialize)]
220struct CopilotErrorDetail {
221    message: Option<String>,
222    code: Option<String>,
223}
224
225#[derive(Debug, Deserialize)]
226struct CopilotModelsResponse {
227    data: Vec<CopilotModelInfo>,
228}
229
230#[derive(Debug, Deserialize)]
231struct CopilotModelInfo {
232    id: String,
233    #[serde(default)]
234    name: Option<String>,
235    #[serde(default)]
236    model_picker_enabled: Option<bool>,
237    #[serde(default)]
238    policy: Option<CopilotModelPolicy>,
239    #[serde(default)]
240    capabilities: Option<CopilotModelCapabilities>,
241}
242
243#[derive(Debug, Deserialize)]
244struct CopilotModelPolicy {
245    #[serde(default)]
246    state: Option<String>,
247}
248
249#[derive(Debug, Deserialize)]
250struct CopilotModelCapabilities {
251    #[serde(default)]
252    limits: Option<CopilotModelLimits>,
253    #[serde(default)]
254    supports: Option<CopilotModelSupports>,
255}
256
257#[derive(Debug, Deserialize)]
258struct CopilotModelLimits {
259    #[serde(default)]
260    max_context_window_tokens: Option<usize>,
261    #[serde(default)]
262    max_output_tokens: Option<usize>,
263}
264
265#[derive(Debug, Deserialize)]
266struct CopilotModelSupports {
267    #[serde(default)]
268    tool_calls: Option<bool>,
269    #[serde(default)]
270    vision: Option<bool>,
271    #[serde(default)]
272    streaming: Option<bool>,
273}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277    fn name(&self) -> &str {
278        &self.provider_name
279    }
280
281    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282        let response = self
283            .client
284            .get(format!("{}/models", self.base_url))
285            .header("Authorization", format!("Bearer {}", self.token))
286            .header("Openai-Intent", "conversation-edits")
287            .header("User-Agent", Self::user_agent())
288            .send()
289            .await
290            .context("Failed to fetch Copilot models")?;
291
292        let mut models: Vec<ModelInfo> = if response.status().is_success() {
293            let parsed: CopilotModelsResponse = response
294                .json()
295                .await
296                .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298            parsed
299                .data
300                .into_iter()
301                .map(|model| {
302                    let caps = model.capabilities.as_ref();
303                    let limits = caps.and_then(|c| c.limits.as_ref());
304                    let supports = caps.and_then(|c| c.supports.as_ref());
305
306                    ModelInfo {
307                        id: model.id.clone(),
308                        name: model.name.unwrap_or_else(|| model.id.clone()),
309                        provider: self.provider_name.clone(),
310                        context_window: limits
311                            .and_then(|l| l.max_context_window_tokens)
312                            .unwrap_or(128_000),
313                        max_output_tokens: limits
314                            .and_then(|l| l.max_output_tokens)
315                            .or(Some(16_384)),
316                        supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
317                        supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
318                        supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
319                        input_cost_per_million: Some(0.0),
320                        output_cost_per_million: Some(0.0),
321                    }
322                })
323                .collect()
324        } else {
325            Vec::new()
326        };
327
328        // Enrich API-returned models with known metadata (better names, accurate limits).
329        // We do NOT inject models the API doesn't list — those aren't usable yet.
330        // Source: https://docs.github.com/en/copilot/using-github-copilot/ai-models
331        // NOTE: Model IDs use dots for versions, same as the /models API returns
332        // (e.g. claude-opus-4.5, gpt-5.1, gpt-4.1).
333        let known_metadata: std::collections::HashMap<&str, (&str, usize, usize)> = [
334            ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000)),
335            ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000)),
336            ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 200_000, 64_000)),
337            ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000)),
338            ("claude-haiku-4.5", ("Claude Haiku 4.5", 200_000, 64_000)),
339            ("gpt-5.2", ("GPT-5.2", 400_000, 128_000)),
340            ("gpt-5.1", ("GPT-5.1", 400_000, 128_000)),
341            ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000)),
342            ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 264_000, 64_000)),
343            ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 264_000, 64_000)),
344            ("gpt-5", ("GPT-5", 400_000, 128_000)),
345            ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000)),
346            ("gpt-4.1", ("GPT-4.1", 128_000, 32_768)),
347            ("gpt-4o", ("GPT-4o", 128_000, 16_384)),
348            ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000)),
349            ("grok-code-fast-1", ("Grok Code Fast 1", 128_000, 32_768)),
350        ]
351        .into_iter()
352        .collect();
353
354        // Apply known metadata to enrich API models that had sparse info
355        for model in &mut models {
356            if let Some((name, ctx, max_out)) = known_metadata.get(model.id.as_str()) {
357                if model.name == model.id {
358                    model.name = name.to_string();
359                }
360                if model.context_window == 128_000 {
361                    model.context_window = *ctx;
362                }
363                if model.max_output_tokens == Some(16_384) {
364                    model.max_output_tokens = Some(*max_out);
365                }
366            }
367        }
368
369        // Filter out legacy/deprecated models that clutter the picker
370        // (embedding models, old GPT-3.5/4/4o variants without picker flag)
371        models.retain(|m| {
372            !m.id.starts_with("text-embedding")
373                && m.id != "gpt-3.5-turbo"
374                && m.id != "gpt-3.5-turbo-0613"
375                && m.id != "gpt-4-0613"
376                && m.id != "gpt-4o-2024-05-13"
377                && m.id != "gpt-4o-2024-08-06"
378                && m.id != "gpt-4o-2024-11-20"
379                && m.id != "gpt-4o-mini-2024-07-18"
380                && m.id != "gpt-4-o-preview"
381                && m.id != "gpt-4.1-2025-04-14"
382        });
383
384        // Deduplicate by id (API sometimes returns duplicates)
385        let mut seen = std::collections::HashSet::new();
386        models.retain(|m| seen.insert(m.id.clone()));
387
388        Ok(models)
389    }
390
391    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
392        let messages = Self::convert_messages(&request.messages);
393        let tools = Self::convert_tools(&request.tools);
394        let is_agent = Self::is_agent_initiated(&request.messages);
395        let has_vision = Self::has_vision_input(&request.messages);
396
397        let mut body = json!({
398            "model": request.model,
399            "messages": messages,
400        });
401
402        if !tools.is_empty() {
403            body["tools"] = json!(tools);
404        }
405        if let Some(temp) = request.temperature {
406            body["temperature"] = json!(temp);
407        }
408        if let Some(top_p) = request.top_p {
409            body["top_p"] = json!(top_p);
410        }
411        if let Some(max) = request.max_tokens {
412            body["max_tokens"] = json!(max);
413        }
414        if !request.stop.is_empty() {
415            body["stop"] = json!(request.stop);
416        }
417
418        let mut req = self
419            .client
420            .post(format!("{}/chat/completions", self.base_url))
421            .header("Authorization", format!("Bearer {}", self.token))
422            .header("Content-Type", "application/json")
423            .header("Openai-Intent", "conversation-edits")
424            .header("User-Agent", Self::user_agent())
425            .header("X-Initiator", if is_agent { "agent" } else { "user" });
426
427        if has_vision {
428            req = req.header("Copilot-Vision-Request", "true");
429        }
430
431        let response = req
432            .json(&body)
433            .send()
434            .await
435            .context("Failed to send Copilot request")?;
436
437        let status = response.status();
438        let text = response
439            .text()
440            .await
441            .context("Failed to read Copilot response")?;
442
443        if !status.is_success() {
444            if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
445                let message = err
446                    .error
447                    .and_then(|detail| {
448                        detail.message.map(|msg| {
449                            if let Some(code) = detail.code {
450                                format!("{} ({})", msg, code)
451                            } else {
452                                msg
453                            }
454                        })
455                    })
456                    .or(err.message)
457                    .unwrap_or_else(|| "Unknown Copilot API error".to_string());
458                anyhow::bail!("Copilot API error: {}", message);
459            }
460            anyhow::bail!("Copilot API error: {} {}", status, text);
461        }
462
463        let response: CopilotResponse = serde_json::from_str(&text).context(format!(
464            "Failed to parse Copilot response: {}",
465            &text[..text.len().min(200)]
466        ))?;
467
468        let choice = response
469            .choices
470            .first()
471            .ok_or_else(|| anyhow::anyhow!("No choices"))?;
472
473        let mut content = Vec::new();
474        let mut has_tool_calls = false;
475
476        if let Some(text) = &choice.message.content {
477            if !text.is_empty() {
478                content.push(ContentPart::Text { text: text.clone() });
479            }
480        }
481
482        if let Some(tool_calls) = &choice.message.tool_calls {
483            has_tool_calls = !tool_calls.is_empty();
484            for tc in tool_calls {
485                content.push(ContentPart::ToolCall {
486                    id: tc.id.clone(),
487                    name: tc.function.name.clone(),
488                    arguments: tc.function.arguments.clone(),
489                });
490            }
491        }
492
493        let finish_reason = if has_tool_calls {
494            FinishReason::ToolCalls
495        } else {
496            match choice.finish_reason.as_deref() {
497                Some("stop") => FinishReason::Stop,
498                Some("length") => FinishReason::Length,
499                Some("tool_calls") => FinishReason::ToolCalls,
500                Some("content_filter") => FinishReason::ContentFilter,
501                _ => FinishReason::Stop,
502            }
503        };
504
505        Ok(CompletionResponse {
506            message: Message {
507                role: Role::Assistant,
508                content,
509            },
510            usage: Usage {
511                prompt_tokens: response
512                    .usage
513                    .as_ref()
514                    .map(|u| u.prompt_tokens)
515                    .unwrap_or(0),
516                completion_tokens: response
517                    .usage
518                    .as_ref()
519                    .map(|u| u.completion_tokens)
520                    .unwrap_or(0),
521                total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
522                ..Default::default()
523            },
524            finish_reason,
525        })
526    }
527
528    async fn complete_stream(
529        &self,
530        request: CompletionRequest,
531    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
532        // For now, keep behavior aligned with other non-streaming providers.
533        let response = self.complete(request).await?;
534        let text = response
535            .message
536            .content
537            .iter()
538            .filter_map(|p| match p {
539                ContentPart::Text { text } => Some(text.clone()),
540                _ => None,
541            })
542            .collect::<Vec<_>>()
543            .join("");
544
545        Ok(Box::pin(futures::stream::once(async move {
546            StreamChunk::Text(text)
547        })))
548    }
549}
550
/// Reduces a user-supplied enterprise URL to its bare host part.
///
/// Surrounding whitespace, any leading `http://` / `https://` scheme and all
/// trailing `/` characters are removed. The scheme is matched without regard
/// to ASCII case, so `HTTPS://company.ghe.com/` normalizes the same way as
/// `https://company.ghe.com/`. The remainder is returned unchanged (no
/// lower-casing, no path stripping).
pub fn normalize_enterprise_domain(input: &str) -> String {
    /// Returns the text after a leading `http(s)://`, or `None` if absent.
    fn strip_scheme(s: &str) -> Option<&str> {
        ["https://", "http://"].iter().find_map(|scheme| {
            // `get` returns `None` for short inputs or a non-boundary cut
            // instead of panicking.
            let head = s.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                s.get(scheme.len()..)
            } else {
                None
            }
        })
    }

    let mut rest = input.trim();
    // Repeated schemes are stripped as well, as the original trimming did.
    while let Some(tail) = strip_scheme(rest) {
        rest = tail;
    }
    rest.trim_end_matches('/').to_string()
}
559
560pub fn enterprise_base_url(enterprise_url: &str) -> String {
561    format!(
562        "https://copilot-api.{}",
563        normalize_enterprise_domain(enterprise_url)
564    )
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme prefixes and a trailing slash are removed; a bare host is kept.
    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let inputs = [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ];

        for input in inputs.iter() {
            assert_eq!(normalize_enterprise_domain(input), "company.ghe.com");
        }
    }

    /// The enterprise API root lives on the `copilot-api.` subdomain.
    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let base_url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(base_url, "https://copilot-api.company.ghe.com");
    }
}