1use super::{
4 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5 Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Root of the public GitHub Copilot API; used by `CopilotProvider::new`.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name reported (and stamped on models) for the public Copilot endpoint.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name reported (and stamped on models) for GitHub Enterprise endpoints.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
/// Chat-completion provider backed by the GitHub Copilot API (public or enterprise).
pub struct CopilotProvider {
    /// HTTP client used for the `/models` and `/chat/completions` requests.
    client: Client,
    /// Token sent as `Authorization: Bearer <token>`.
    token: String,
    /// API root with trailing slashes removed; endpoint paths are appended to it.
    base_url: String,
    /// Name returned by `Provider::name` and recorded on every discovered model.
    provider_name: String,
}
23
24impl CopilotProvider {
25 pub fn new(token: String) -> Result<Self> {
26 Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27 }
28
29 pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30 let base_url = enterprise_base_url(&enterprise_url);
31 Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32 }
33
34 pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35 Ok(Self {
36 client: Client::new(),
37 token,
38 base_url: base_url.trim_end_matches('/').to_string(),
39 provider_name: provider_name.to_string(),
40 })
41 }
42
43 fn user_agent() -> String {
44 format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45 }
46
47 fn convert_messages(messages: &[Message]) -> Vec<Value> {
48 messages
49 .iter()
50 .map(|msg| {
51 let role = match msg.role {
52 Role::System => "system",
53 Role::User => "user",
54 Role::Assistant => "assistant",
55 Role::Tool => "tool",
56 };
57
58 match msg.role {
59 Role::Tool => {
60 if let Some(ContentPart::ToolResult {
61 tool_call_id,
62 content,
63 }) = msg.content.first()
64 {
65 json!({
66 "role": "tool",
67 "tool_call_id": tool_call_id,
68 "content": content
69 })
70 } else {
71 json!({ "role": role, "content": "" })
72 }
73 }
74 Role::Assistant => {
75 let text: String = msg
76 .content
77 .iter()
78 .filter_map(|p| match p {
79 ContentPart::Text { text } => Some(text.clone()),
80 _ => None,
81 })
82 .collect::<Vec<_>>()
83 .join("");
84
85 let tool_calls: Vec<Value> = msg
86 .content
87 .iter()
88 .filter_map(|p| match p {
89 ContentPart::ToolCall {
90 id,
91 name,
92 arguments,
93 ..
94 } => Some(json!({
95 "id": id,
96 "type": "function",
97 "function": {
98 "name": name,
99 "arguments": arguments
100 }
101 })),
102 _ => None,
103 })
104 .collect();
105
106 if tool_calls.is_empty() {
107 json!({ "role": "assistant", "content": text })
108 } else {
109 json!({
110 "role": "assistant",
111 "content": if text.is_empty() { "".to_string() } else { text },
112 "tool_calls": tool_calls
113 })
114 }
115 }
116 _ => {
117 let text: String = msg
118 .content
119 .iter()
120 .filter_map(|p| match p {
121 ContentPart::Text { text } => Some(text.clone()),
122 _ => None,
123 })
124 .collect::<Vec<_>>()
125 .join("\n");
126 json!({ "role": role, "content": text })
127 }
128 }
129 })
130 .collect()
131 }
132
133 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
134 tools
135 .iter()
136 .map(|t| {
137 json!({
138 "type": "function",
139 "function": {
140 "name": t.name,
141 "description": t.description,
142 "parameters": t.parameters
143 }
144 })
145 })
146 .collect()
147 }
148
149 fn is_agent_initiated(messages: &[Message]) -> bool {
150 messages
151 .iter()
152 .rev()
153 .find(|msg| msg.role != Role::System)
154 .map(|msg| msg.role != Role::User)
155 .unwrap_or(false)
156 }
157
158 fn has_vision_input(messages: &[Message]) -> bool {
159 messages.iter().any(|msg| {
160 msg.content
161 .iter()
162 .any(|part| matches!(part, ContentPart::Image { .. }))
163 })
164 }
165
166 async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
168 let response = match self
169 .client
170 .get(format!("{}/models", self.base_url))
171 .header("Authorization", format!("Bearer {}", self.token))
172 .header("User-Agent", Self::user_agent())
173 .send()
174 .await
175 {
176 Ok(r) => r,
177 Err(e) => {
178 tracing::warn!(provider = %self.provider_name, error = %e, "Failed to fetch Copilot models endpoint");
179 return Vec::new();
180 }
181 };
182
183 let status = response.status();
184 if !status.is_success() {
185 let body = response.text().await.unwrap_or_default();
186 tracing::warn!(
187 provider = %self.provider_name,
188 status = %status,
189 body = %body.chars().take(200).collect::<String>(),
190 "Copilot /models endpoint returned non-success"
191 );
192 return Vec::new();
193 }
194
195 let parsed: CopilotModelsResponse = match response.json().await {
196 Ok(p) => p,
197 Err(e) => {
198 tracing::warn!(provider = %self.provider_name, error = %e, "Failed to parse Copilot models response");
199 return Vec::new();
200 }
201 };
202
203 let models: Vec<ModelInfo> = parsed
204 .data
205 .into_iter()
206 .filter(|model| {
207 if model.model_picker_enabled == Some(false) {
209 return false;
210 }
211 if let Some(ref policy) = model.policy {
213 if policy.state.as_deref() == Some("disabled") {
214 return false;
215 }
216 }
217 true
218 })
219 .map(|model| {
220 let caps = model.capabilities.as_ref();
221 let limits = caps.and_then(|c| c.limits.as_ref());
222 let supports = caps.and_then(|c| c.supports.as_ref());
223
224 ModelInfo {
225 id: model.id.clone(),
226 name: model.name.unwrap_or_else(|| model.id.clone()),
227 provider: self.provider_name.clone(),
228 context_window: limits
229 .and_then(|l| l.max_context_window_tokens)
230 .unwrap_or(128_000),
231 max_output_tokens: limits.and_then(|l| l.max_output_tokens).or(Some(16_384)),
232 supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
233 supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
234 supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
235 input_cost_per_million: None,
236 output_cost_per_million: None,
237 }
238 })
239 .collect();
240
241 tracing::info!(
242 provider = %self.provider_name,
243 count = models.len(),
244 "Discovered models from Copilot API"
245 );
246 models
247 }
248
249 fn enrich_with_pricing(&self, models: &mut [ModelInfo]) {
256 let pricing: std::collections::HashMap<&str, (&str, f64)> = [
258 ("claude-opus-4.5", ("Claude Opus 4.5", 3.0)),
259 ("claude-opus-4.6", ("Claude Opus 4.6", 3.0)),
260 ("claude-opus-41", ("Claude Opus 4.1", 10.0)),
261 ("claude-sonnet-4-6", ("Claude Sonnet 4.6", 1.0)),
262 ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 1.0)),
263 ("claude-sonnet-4", ("Claude Sonnet 4", 1.0)),
264 ("claude-haiku-4.5", ("Claude Haiku 4.5", 0.33)),
265 ("gpt-5.3-codex", ("GPT-5.3-Codex", 1.0)),
266 ("gpt-5.2", ("GPT-5.2", 1.0)),
267 ("gpt-5.2-codex", ("GPT-5.2-Codex", 1.0)),
268 ("gpt-5.1", ("GPT-5.1", 1.0)),
269 ("gpt-5.1-codex", ("GPT-5.1-Codex", 1.0)),
270 ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 0.33)),
271 ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 1.0)),
272 ("gpt-5", ("GPT-5", 1.0)),
273 ("gpt-5-mini", ("GPT-5 mini", 0.0)),
274 ("gpt-5-codex", ("GPT-5-Codex", 1.0)),
275 ("gpt-4.1", ("GPT-4.1", 0.0)),
276 ("gpt-4o", ("GPT-4o", 0.0)),
277 ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1.0)),
278 ("gemini-3.1-pro-preview", ("Gemini 3.1 Pro Preview", 1.0)),
279 (
280 "gemini-3.1-pro-preview-customtools",
281 ("Gemini 3.1 Pro Preview (Custom Tools)", 1.0),
282 ),
283 ("gemini-3-flash-preview", ("Gemini 3 Flash Preview", 0.33)),
284 ("gemini-3-pro-preview", ("Gemini 3 Pro Preview", 1.0)),
285 (
286 "gemini-3-pro-image-preview",
287 ("Gemini 3 Pro Image Preview", 1.0),
288 ),
289 ("grok-code-fast-1", ("Grok Code Fast 1", 0.25)),
290 ]
291 .into_iter()
292 .collect();
293
294 for model in models.iter_mut() {
295 if let Some((display_name, premium_mult)) = pricing.get(model.id.as_str()) {
296 if model.name == model.id {
298 model.name = display_name.to_string();
299 }
300 let approx_cost = premium_mult * 10.0;
301 model.input_cost_per_million = Some(approx_cost);
302 model.output_cost_per_million = Some(approx_cost);
303 } else {
304 if model.input_cost_per_million.is_none() {
306 model.input_cost_per_million = Some(10.0);
307 }
308 if model.output_cost_per_million.is_none() {
309 model.output_cost_per_million = Some(10.0);
310 }
311 }
312 }
313 }
314
315 fn known_models(&self) -> Vec<ModelInfo> {
317 let entries: &[(&str, &str, usize, usize, bool)] = &[
318 ("gpt-4o", "GPT-4o", 128_000, 16_384, true),
319 ("gpt-4.1", "GPT-4.1", 128_000, 32_768, false),
320 ("gpt-5", "GPT-5", 400_000, 128_000, false),
321 ("gpt-5-mini", "GPT-5 mini", 264_000, 64_000, false),
322 ("claude-sonnet-4", "Claude Sonnet 4", 200_000, 64_000, false),
323 (
324 "claude-sonnet-4.5",
325 "Claude Sonnet 4.5",
326 200_000,
327 64_000,
328 false,
329 ),
330 (
331 "claude-sonnet-4-6",
332 "Claude Sonnet 4.6",
333 200_000,
334 128_000,
335 false,
336 ),
337 (
338 "claude-haiku-4.5",
339 "Claude Haiku 4.5",
340 200_000,
341 64_000,
342 false,
343 ),
344 ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_000_000, 64_000, false),
345 (
346 "gemini-3.1-pro-preview",
347 "Gemini 3.1 Pro Preview",
348 1_048_576,
349 65_536,
350 false,
351 ),
352 (
353 "gemini-3.1-pro-preview-customtools",
354 "Gemini 3.1 Pro Preview (Custom Tools)",
355 1_048_576,
356 65_536,
357 false,
358 ),
359 (
360 "gemini-3-pro-preview",
361 "Gemini 3 Pro Preview",
362 1_048_576,
363 65_536,
364 false,
365 ),
366 (
367 "gemini-3-flash-preview",
368 "Gemini 3 Flash Preview",
369 1_048_576,
370 65_536,
371 false,
372 ),
373 (
374 "gemini-3-pro-image-preview",
375 "Gemini 3 Pro Image Preview",
376 65_536,
377 32_768,
378 false,
379 ),
380 ];
381
382 entries
383 .iter()
384 .map(|(id, name, ctx, max_out, vision)| ModelInfo {
385 id: id.to_string(),
386 name: name.to_string(),
387 provider: self.provider_name.clone(),
388 context_window: *ctx,
389 max_output_tokens: Some(*max_out),
390 supports_vision: *vision,
391 supports_tools: true,
392 supports_streaming: true,
393 input_cost_per_million: None,
394 output_cost_per_million: None,
395 })
396 .collect()
397 }
398}
399
400#[derive(Debug, Deserialize)]
401struct CopilotResponse {
402 choices: Vec<CopilotChoice>,
403 #[serde(default)]
404 usage: Option<CopilotUsage>,
405}
406
407#[derive(Debug, Deserialize)]
408struct CopilotChoice {
409 message: CopilotMessage,
410 #[serde(default)]
411 finish_reason: Option<String>,
412}
413
414#[derive(Debug, Deserialize)]
415struct CopilotMessage {
416 #[serde(default)]
417 content: Option<String>,
418 #[serde(default)]
419 tool_calls: Option<Vec<CopilotToolCall>>,
420}
421
422#[derive(Debug, Deserialize)]
423struct CopilotToolCall {
424 id: String,
425 #[serde(rename = "type")]
426 #[allow(dead_code)]
427 call_type: String,
428 function: CopilotFunction,
429}
430
431#[derive(Debug, Deserialize)]
432struct CopilotFunction {
433 name: String,
434 arguments: String,
435}
436
437#[derive(Debug, Deserialize)]
438struct CopilotUsage {
439 #[serde(default)]
440 prompt_tokens: usize,
441 #[serde(default)]
442 completion_tokens: usize,
443 #[serde(default)]
444 total_tokens: usize,
445}
446
/// Error body parsed from non-2xx `/chat/completions` responses.
///
/// Accepts either a nested `{"error": {"message", "code"}}` object or a top-level
/// `{"message": ...}`; every field is optional.
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    /// Nested error object; its `message` is preferred by `complete`.
    error: Option<CopilotErrorDetail>,
    /// Top-level message, used by `complete` when no nested message is present.
    message: Option<String>,
}

/// Contents of the nested `error` object.
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    message: Option<String>,
    /// Error code string; `complete` appends it as `"<message> (<code>)"` when a
    /// message is also present.
    code: Option<String>,
}
458
/// Body of `GET /models`.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    data: Vec<CopilotModelInfo>,
}

/// One model entry from `/models`; everything except `id` is optional.
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    id: String,
    /// Display name; the id is used when absent.
    #[serde(default)]
    name: Option<String>,
    /// `Some(false)` causes the model to be filtered out during discovery.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    /// Access policy; a `"disabled"` state causes the model to be filtered out.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}

/// Policy attached to a model entry.
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    /// Policy state string; only `"disabled"` is acted upon.
    #[serde(default)]
    state: Option<String>,
}

/// Capability block of a model entry.
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}

/// Token limits; discovery defaults to 128,000 context / 16,384 output when absent.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    #[serde(default)]
    max_output_tokens: Option<usize>,
}

/// Feature flags; discovery defaults to tools and streaming on, vision off.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    #[serde(default)]
    tool_calls: Option<bool>,
    #[serde(default)]
    vision: Option<bool>,
    #[serde(default)]
    streaming: Option<bool>,
}
508
509#[async_trait]
510impl Provider for CopilotProvider {
511 fn name(&self) -> &str {
512 &self.provider_name
513 }
514
515 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
516 let mut models = self.discover_models_from_api().await;
517
518 if models.is_empty() {
520 tracing::info!(provider = %self.provider_name, "No models from API, using known model catalog");
521 models = self.known_models();
522 }
523
524 self.enrich_with_pricing(&mut models);
526
527 models.retain(|m| {
529 !m.id.starts_with("text-embedding")
530 && !m.id.contains("-embedding-")
531 && !is_dated_model_variant(&m.id)
532 });
533
534 let mut seen = std::collections::HashSet::new();
536 models.retain(|m| seen.insert(m.id.clone()));
537
538 Ok(models)
539 }
540
541 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
542 let messages = Self::convert_messages(&request.messages);
543 let tools = Self::convert_tools(&request.tools);
544 let is_agent = Self::is_agent_initiated(&request.messages);
545 let has_vision = Self::has_vision_input(&request.messages);
546
547 let mut body = json!({
548 "model": request.model,
549 "messages": messages,
550 });
551
552 if !tools.is_empty() {
553 body["tools"] = json!(tools);
554 }
555 if let Some(temp) = request.temperature {
556 body["temperature"] = json!(temp);
557 }
558 if let Some(top_p) = request.top_p {
559 body["top_p"] = json!(top_p);
560 }
561 if let Some(max) = request.max_tokens {
562 body["max_tokens"] = json!(max);
563 }
564 if !request.stop.is_empty() {
565 body["stop"] = json!(request.stop);
566 }
567
568 let mut req = self
569 .client
570 .post(format!("{}/chat/completions", self.base_url))
571 .header("Authorization", format!("Bearer {}", self.token))
572 .header("Content-Type", "application/json")
573 .header("Openai-Intent", "conversation-edits")
574 .header("User-Agent", Self::user_agent())
575 .header("X-Initiator", if is_agent { "agent" } else { "user" });
576
577 if has_vision {
578 req = req.header("Copilot-Vision-Request", "true");
579 }
580
581 let response = req
582 .json(&body)
583 .send()
584 .await
585 .context("Failed to send Copilot request")?;
586
587 let status = response.status();
588 let text = response
589 .text()
590 .await
591 .context("Failed to read Copilot response")?;
592
593 if !status.is_success() {
594 if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
595 let message = err
596 .error
597 .and_then(|detail| {
598 detail.message.map(|msg| {
599 if let Some(code) = detail.code {
600 format!("{} ({})", msg, code)
601 } else {
602 msg
603 }
604 })
605 })
606 .or(err.message)
607 .unwrap_or_else(|| "Unknown Copilot API error".to_string());
608 anyhow::bail!("Copilot API error: {}", message);
609 }
610 anyhow::bail!("Copilot API error: {} {}", status, text);
611 }
612
613 let response: CopilotResponse = serde_json::from_str(&text).context(format!(
614 "Failed to parse Copilot response: {}",
615 &text[..text.len().min(200)]
616 ))?;
617
618 let choice = response
619 .choices
620 .first()
621 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
622
623 let mut content = Vec::new();
624 let mut has_tool_calls = false;
625
626 if let Some(text) = &choice.message.content {
627 if !text.is_empty() {
628 content.push(ContentPart::Text { text: text.clone() });
629 }
630 }
631
632 if let Some(tool_calls) = &choice.message.tool_calls {
633 has_tool_calls = !tool_calls.is_empty();
634 for tc in tool_calls {
635 content.push(ContentPart::ToolCall {
636 id: tc.id.clone(),
637 name: tc.function.name.clone(),
638 arguments: tc.function.arguments.clone(),
639 thought_signature: None,
640 });
641 }
642 }
643
644 let finish_reason = if has_tool_calls {
645 FinishReason::ToolCalls
646 } else {
647 match choice.finish_reason.as_deref() {
648 Some("stop") => FinishReason::Stop,
649 Some("length") => FinishReason::Length,
650 Some("tool_calls") => FinishReason::ToolCalls,
651 Some("content_filter") => FinishReason::ContentFilter,
652 _ => FinishReason::Stop,
653 }
654 };
655
656 Ok(CompletionResponse {
657 message: Message {
658 role: Role::Assistant,
659 content,
660 },
661 usage: Usage {
662 prompt_tokens: response
663 .usage
664 .as_ref()
665 .map(|u| u.prompt_tokens)
666 .unwrap_or(0),
667 completion_tokens: response
668 .usage
669 .as_ref()
670 .map(|u| u.completion_tokens)
671 .unwrap_or(0),
672 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
673 ..Default::default()
674 },
675 finish_reason,
676 })
677 }
678
679 async fn complete_stream(
680 &self,
681 request: CompletionRequest,
682 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
683 let response = self.complete(request).await?;
685 let text = response
686 .message
687 .content
688 .iter()
689 .filter_map(|p| match p {
690 ContentPart::Text { text } => Some(text.clone()),
691 _ => None,
692 })
693 .collect::<Vec<_>>()
694 .join("");
695
696 Ok(Box::pin(futures::stream::once(async move {
697 StreamChunk::Text(text)
698 })))
699 }
700}
701
/// Returns `true` if `id` ends with a `-YYYY-MM-DD`-shaped suffix, e.g.
/// `gpt-4o-2024-05-13`.
///
/// Only the shape is checked (dash and ASCII-digit positions), not whether the date is
/// valid. The check runs on raw bytes, so ids containing multi-byte UTF-8 characters
/// (model ids come from the remote `/models` response) can never trigger a
/// string-slicing panic.
fn is_dated_model_variant(id: &str) -> bool {
    const SUFFIX_LEN: usize = "-YYYY-MM-DD".len();

    let bytes = id.as_bytes();
    if bytes.len() < SUFFIX_LEN {
        return false;
    }

    bytes[bytes.len() - SUFFIX_LEN..]
        .iter()
        .enumerate()
        .all(|(i, &b)| match i {
            // Separators in front of the year, month and day.
            0 | 5 | 8 => b == b'-',
            _ => b.is_ascii_digit(),
        })
}
719
/// Reduces a user-supplied GitHub Enterprise URL to its bare domain.
///
/// Surrounding whitespace is trimmed, every leading `https://` and then every leading
/// `http://` is removed, and trailing `/` characters are dropped.
/// `"https://company.ghe.com/"` becomes `"company.ghe.com"`.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let without_scheme = ["https://", "http://"]
        .iter()
        .fold(input.trim(), |rest, scheme| rest.trim_start_matches(scheme));
    without_scheme.trim_end_matches('/').to_owned()
}

/// Builds the Copilot API root for an enterprise domain:
/// `https://copilot-api.<normalized domain>`.
pub fn enterprise_base_url(enterprise_url: &str) -> String {
    let domain = normalize_enterprise_domain(enterprise_url);
    format!("https://copilot-api.{}", domain)
}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider instance shared by the tests; the token is never sent anywhere.
    fn test_provider() -> CopilotProvider {
        CopilotProvider::new("test-token".to_string()).unwrap()
    }

    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        for input in [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ] {
            assert_eq!(
                normalize_enterprise_domain(input),
                "company.ghe.com",
                "input: {input}"
            );
        }
    }

    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }

    #[test]
    fn is_dated_model_variant_detects_date_suffix() {
        let dated = [
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4.1-2025-04-14",
            "gpt-4o-mini-2024-07-18",
        ];
        let undated = ["gpt-4o", "gpt-5", "claude-sonnet-4", "gemini-2.5-pro"];

        for id in dated {
            assert!(is_dated_model_variant(id), "{id} should be dated");
        }
        for id in undated {
            assert!(!is_dated_model_variant(id), "{id} should not be dated");
        }
    }

    #[test]
    fn known_models_fallback_is_non_empty() {
        let models = test_provider().known_models();
        assert!(!models.is_empty());
        assert!(models.iter().all(|m| m.supports_tools));
    }

    #[test]
    fn enrich_with_pricing_sets_costs() {
        let provider = test_provider();
        let mut models = vec![ModelInfo {
            id: "gpt-4o".to_string(),
            name: "gpt-4o".to_string(),
            provider: "github-copilot".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: true,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        }];

        provider.enrich_with_pricing(&mut models);

        // gpt-4o has a 0.0 multiplier and its placeholder name is replaced.
        let enriched = &models[0];
        assert_eq!(enriched.input_cost_per_million, Some(0.0));
        assert_eq!(enriched.output_cost_per_million, Some(0.0));
        assert_eq!(enriched.name, "GPT-4o");
    }
}