1use super::util;
4use super::{
5 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
6 Role, StreamChunk, ToolDefinition, Usage,
7};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::{Value, json};
13
/// Base URL of the public GitHub Copilot API, used by `CopilotProvider::new`.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name given to instances created with `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name given to instances created with `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
17
18pub struct CopilotProvider {
19 client: Client,
20 token: String,
21 base_url: String,
22 provider_name: String,
23}
24
25impl CopilotProvider {
26 pub fn new(token: String) -> Result<Self> {
27 Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
28 }
29
30 pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
31 let base_url = enterprise_base_url(&enterprise_url);
32 Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
33 }
34
35 pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
36 Ok(Self {
37 client: crate::provider::shared_http::shared_client().clone(),
38 token,
39 base_url: base_url.trim_end_matches('/').to_string(),
40 provider_name: provider_name.to_string(),
41 })
42 }
43
44 fn user_agent() -> String {
45 format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
46 }
47
48 fn convert_messages(messages: &[Message]) -> Vec<Value> {
49 messages
50 .iter()
51 .map(|msg| {
52 let role = match msg.role {
53 Role::System => "system",
54 Role::User => "user",
55 Role::Assistant => "assistant",
56 Role::Tool => "tool",
57 };
58
59 match msg.role {
60 Role::Tool => {
61 if let Some(ContentPart::ToolResult {
62 tool_call_id,
63 content,
64 }) = msg.content.first()
65 {
66 json!({
67 "role": "tool",
68 "tool_call_id": tool_call_id,
69 "content": content
70 })
71 } else {
72 json!({ "role": role, "content": "" })
73 }
74 }
75 Role::Assistant => {
76 let text: String = msg
77 .content
78 .iter()
79 .filter_map(|p| match p {
80 ContentPart::Text { text } => Some(text.clone()),
81 _ => None,
82 })
83 .collect::<Vec<_>>()
84 .join("");
85
86 let tool_calls: Vec<Value> = msg
87 .content
88 .iter()
89 .filter_map(|p| match p {
90 ContentPart::ToolCall {
91 id,
92 name,
93 arguments,
94 ..
95 } => Some(json!({
96 "id": id,
97 "type": "function",
98 "function": {
99 "name": name,
100 "arguments": arguments
101 }
102 })),
103 _ => None,
104 })
105 .collect();
106
107 if tool_calls.is_empty() {
108 json!({ "role": "assistant", "content": text })
109 } else {
110 json!({
111 "role": "assistant",
112 "content": if text.is_empty() { "".to_string() } else { text },
113 "tool_calls": tool_calls
114 })
115 }
116 }
117 _ => {
118 let text: String = msg
119 .content
120 .iter()
121 .filter_map(|p| match p {
122 ContentPart::Text { text } => Some(text.clone()),
123 _ => None,
124 })
125 .collect::<Vec<_>>()
126 .join("\n");
127 json!({ "role": role, "content": text })
128 }
129 }
130 })
131 .collect()
132 }
133
134 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
135 tools
136 .iter()
137 .map(|t| {
138 json!({
139 "type": "function",
140 "function": {
141 "name": t.name,
142 "description": t.description,
143 "parameters": t.parameters
144 }
145 })
146 })
147 .collect()
148 }
149
150 fn is_agent_initiated(messages: &[Message]) -> bool {
151 messages
152 .iter()
153 .rev()
154 .find(|msg| msg.role != Role::System)
155 .map(|msg| msg.role != Role::User)
156 .unwrap_or(false)
157 }
158
159 fn has_vision_input(messages: &[Message]) -> bool {
160 messages.iter().any(|msg| {
161 msg.content
162 .iter()
163 .any(|part| matches!(part, ContentPart::Image { .. }))
164 })
165 }
166
167 async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
169 let response = match self
170 .client
171 .get(format!("{}/models", self.base_url))
172 .header("Authorization", format!("Bearer {}", self.token))
173 .header("User-Agent", Self::user_agent())
174 .send()
175 .await
176 {
177 Ok(r) => r,
178 Err(e) => {
179 tracing::warn!(provider = %self.provider_name, error = %e, "Failed to fetch Copilot models endpoint");
180 return Vec::new();
181 }
182 };
183
184 let status = response.status();
185 if !status.is_success() {
186 let body = response.text().await.unwrap_or_default();
187 tracing::warn!(
188 provider = %self.provider_name,
189 status = %status,
190 body = %body.chars().take(200).collect::<String>(),
191 "Copilot /models endpoint returned non-success"
192 );
193 return Vec::new();
194 }
195
196 let parsed: CopilotModelsResponse = match crate::provider::body_cap::json_capped(
197 response,
198 crate::provider::body_cap::PROVIDER_METADATA_BODY_CAP,
199 )
200 .await
201 {
202 Ok(p) => p,
203 Err(e) => {
204 tracing::warn!(provider = %self.provider_name, error = %e, "Failed to parse Copilot models response (or exceeded body cap)");
205 return Vec::new();
206 }
207 };
208
209 let models: Vec<ModelInfo> = parsed
210 .data
211 .into_iter()
212 .filter(|model| {
213 if model.model_picker_enabled == Some(false) {
215 return false;
216 }
217 if let Some(ref policy) = model.policy
219 && policy.state.as_deref() == Some("disabled")
220 {
221 return false;
222 }
223 true
224 })
225 .map(|model| {
226 let caps = model.capabilities.as_ref();
227 let limits = caps.and_then(|c| c.limits.as_ref());
228 let supports = caps.and_then(|c| c.supports.as_ref());
229
230 ModelInfo {
231 id: model.id.clone(),
232 name: model.name.unwrap_or_else(|| model.id.clone()),
233 provider: self.provider_name.clone(),
234 context_window: limits
235 .and_then(|l| l.max_context_window_tokens)
236 .unwrap_or(128_000),
237 max_output_tokens: limits.and_then(|l| l.max_output_tokens).or(Some(16_384)),
238 supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
239 supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
240 supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
241 input_cost_per_million: None,
242 output_cost_per_million: None,
243 }
244 })
245 .collect();
246
247 tracing::info!(
248 provider = %self.provider_name,
249 count = models.len(),
250 "Discovered models from Copilot API"
251 );
252 models
253 }
254
255 fn enrich_with_pricing(&self, models: &mut [ModelInfo]) {
262 let pricing: std::collections::HashMap<&str, (&str, f64)> = [
264 ("claude-opus-4.5", ("Claude Opus 4.5", 3.0)),
265 ("claude-opus-4.6", ("Claude Opus 4.6", 3.0)),
266 ("claude-opus-41", ("Claude Opus 4.1", 10.0)),
267 ("claude-sonnet-4-6", ("Claude Sonnet 4.6", 1.0)),
268 ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 1.0)),
269 ("claude-sonnet-4", ("Claude Sonnet 4", 1.0)),
270 ("claude-haiku-4.5", ("Claude Haiku 4.5", 0.33)),
271 ("gpt-5.3-codex", ("GPT-5.3-Codex", 1.0)),
272 ("gpt-5.2", ("GPT-5.2", 1.0)),
273 ("gpt-5.2-codex", ("GPT-5.2-Codex", 1.0)),
274 ("gpt-5.1", ("GPT-5.1", 1.0)),
275 ("gpt-5.1-codex", ("GPT-5.1-Codex", 1.0)),
276 ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 0.33)),
277 ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 1.0)),
278 ("gpt-5", ("GPT-5", 1.0)),
279 ("gpt-5-mini", ("GPT-5 mini", 0.0)),
280 ("gpt-5-codex", ("GPT-5-Codex", 1.0)),
281 ("gpt-4.1", ("GPT-4.1", 0.0)),
282 ("gpt-4o", ("GPT-4o", 0.0)),
283 ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1.0)),
284 ("gemini-3.1-pro-preview", ("Gemini 3.1 Pro Preview", 1.0)),
285 (
286 "gemini-3.1-pro-preview-customtools",
287 ("Gemini 3.1 Pro Preview (Custom Tools)", 1.0),
288 ),
289 ("gemini-3-flash-preview", ("Gemini 3 Flash Preview", 0.33)),
290 ("gemini-3-pro-preview", ("Gemini 3 Pro Preview", 1.0)),
291 (
292 "gemini-3-pro-image-preview",
293 ("Gemini 3 Pro Image Preview", 1.0),
294 ),
295 ("grok-code-fast-1", ("Grok Code Fast 1", 0.25)),
296 ]
297 .into_iter()
298 .collect();
299
300 for model in models.iter_mut() {
301 if let Some((display_name, premium_mult)) = pricing.get(model.id.as_str()) {
302 if model.name == model.id {
304 model.name = display_name.to_string();
305 }
306 let approx_cost = premium_mult * 10.0;
307 model.input_cost_per_million = Some(approx_cost);
308 model.output_cost_per_million = Some(approx_cost);
309 } else {
310 if model.input_cost_per_million.is_none() {
312 model.input_cost_per_million = Some(10.0);
313 }
314 if model.output_cost_per_million.is_none() {
315 model.output_cost_per_million = Some(10.0);
316 }
317 }
318 }
319 }
320
321 fn known_models(&self) -> Vec<ModelInfo> {
323 let entries: &[(&str, &str, usize, usize, bool)] = &[
324 ("gpt-4o", "GPT-4o", 128_000, 16_384, true),
325 ("gpt-4.1", "GPT-4.1", 128_000, 32_768, false),
326 ("gpt-5", "GPT-5", 400_000, 128_000, false),
327 ("gpt-5-mini", "GPT-5 mini", 264_000, 64_000, false),
328 ("claude-sonnet-4", "Claude Sonnet 4", 200_000, 64_000, false),
329 (
330 "claude-sonnet-4.5",
331 "Claude Sonnet 4.5",
332 200_000,
333 64_000,
334 false,
335 ),
336 (
337 "claude-sonnet-4-6",
338 "Claude Sonnet 4.6",
339 200_000,
340 128_000,
341 false,
342 ),
343 (
344 "claude-haiku-4.5",
345 "Claude Haiku 4.5",
346 200_000,
347 64_000,
348 false,
349 ),
350 ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_000_000, 64_000, false),
351 (
352 "gemini-3.1-pro-preview",
353 "Gemini 3.1 Pro Preview",
354 1_048_576,
355 65_536,
356 false,
357 ),
358 (
359 "gemini-3.1-pro-preview-customtools",
360 "Gemini 3.1 Pro Preview (Custom Tools)",
361 1_048_576,
362 65_536,
363 false,
364 ),
365 (
366 "gemini-3-pro-preview",
367 "Gemini 3 Pro Preview",
368 1_048_576,
369 65_536,
370 false,
371 ),
372 (
373 "gemini-3-flash-preview",
374 "Gemini 3 Flash Preview",
375 1_048_576,
376 65_536,
377 false,
378 ),
379 (
380 "gemini-3-pro-image-preview",
381 "Gemini 3 Pro Image Preview",
382 65_536,
383 32_768,
384 false,
385 ),
386 ];
387
388 entries
389 .iter()
390 .map(|(id, name, ctx, max_out, vision)| ModelInfo {
391 id: id.to_string(),
392 name: name.to_string(),
393 provider: self.provider_name.clone(),
394 context_window: *ctx,
395 max_output_tokens: Some(*max_out),
396 supports_vision: *vision,
397 supports_tools: true,
398 supports_streaming: true,
399 input_cost_per_million: None,
400 output_cost_per_million: None,
401 })
402 .collect()
403 }
404}
405
406#[derive(Debug, Deserialize)]
407struct CopilotResponse {
408 choices: Vec<CopilotChoice>,
409 #[serde(default)]
410 usage: Option<CopilotUsage>,
411}
412
413#[derive(Debug, Deserialize)]
414struct CopilotChoice {
415 message: CopilotMessage,
416 #[serde(default)]
417 finish_reason: Option<String>,
418}
419
420#[derive(Debug, Deserialize)]
421struct CopilotMessage {
422 #[serde(default)]
423 content: Option<String>,
424 #[serde(default)]
425 tool_calls: Option<Vec<CopilotToolCall>>,
426}
427
428#[derive(Debug, Deserialize)]
429struct CopilotToolCall {
430 id: String,
431 #[serde(rename = "type")]
432 #[allow(dead_code)]
433 call_type: String,
434 function: CopilotFunction,
435}
436
437#[derive(Debug, Deserialize)]
438struct CopilotFunction {
439 name: String,
440 arguments: String,
441}
442
443#[derive(Debug, Deserialize)]
444struct CopilotUsage {
445 #[serde(default)]
446 prompt_tokens: usize,
447 #[serde(default)]
448 completion_tokens: usize,
449 #[serde(default)]
450 total_tokens: usize,
451}
452
453#[derive(Debug, Deserialize)]
454struct CopilotErrorResponse {
455 error: Option<CopilotErrorDetail>,
456 message: Option<String>,
457}
458
459#[derive(Debug, Deserialize)]
460struct CopilotErrorDetail {
461 message: Option<String>,
462 code: Option<String>,
463}
464
465#[derive(Debug, Deserialize)]
466struct CopilotModelsResponse {
467 data: Vec<CopilotModelInfo>,
468}
469
470#[derive(Debug, Deserialize)]
471struct CopilotModelInfo {
472 id: String,
473 #[serde(default)]
474 name: Option<String>,
475 #[serde(default)]
476 model_picker_enabled: Option<bool>,
477 #[serde(default)]
478 policy: Option<CopilotModelPolicy>,
479 #[serde(default)]
480 capabilities: Option<CopilotModelCapabilities>,
481}
482
483#[derive(Debug, Deserialize)]
484struct CopilotModelPolicy {
485 #[serde(default)]
486 state: Option<String>,
487}
488
489#[derive(Debug, Deserialize)]
490struct CopilotModelCapabilities {
491 #[serde(default)]
492 limits: Option<CopilotModelLimits>,
493 #[serde(default)]
494 supports: Option<CopilotModelSupports>,
495}
496
497#[derive(Debug, Deserialize)]
498struct CopilotModelLimits {
499 #[serde(default)]
500 max_context_window_tokens: Option<usize>,
501 #[serde(default)]
502 max_output_tokens: Option<usize>,
503}
504
505#[derive(Debug, Deserialize)]
506struct CopilotModelSupports {
507 #[serde(default)]
508 tool_calls: Option<bool>,
509 #[serde(default)]
510 vision: Option<bool>,
511 #[serde(default)]
512 streaming: Option<bool>,
513}
514
515#[async_trait]
516impl Provider for CopilotProvider {
517 fn name(&self) -> &str {
518 &self.provider_name
519 }
520
521 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
522 let mut models = self.discover_models_from_api().await;
523
524 if models.is_empty() {
526 tracing::info!(provider = %self.provider_name, "No models from API, using known model catalog");
527 models = self.known_models();
528 }
529
530 self.enrich_with_pricing(&mut models);
532
533 models.retain(|m| {
535 !m.id.starts_with("text-embedding")
536 && !m.id.contains("-embedding-")
537 && !is_dated_model_variant(&m.id)
538 });
539
540 let mut seen = std::collections::HashSet::new();
542 models.retain(|m| seen.insert(m.id.clone()));
543
544 Ok(models)
545 }
546
547 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
548 let messages = Self::convert_messages(&request.messages);
549 let tools = Self::convert_tools(&request.tools);
550 let is_agent = Self::is_agent_initiated(&request.messages);
551 let has_vision = Self::has_vision_input(&request.messages);
552
553 let mut body = json!({
554 "model": request.model,
555 "messages": messages,
556 });
557
558 if !tools.is_empty() {
559 body["tools"] = json!(tools);
560 }
561 if let Some(temp) = request.temperature {
562 body["temperature"] = json!(temp);
563 }
564 if let Some(top_p) = request.top_p {
565 body["top_p"] = json!(top_p);
566 }
567 if let Some(max) = request.max_tokens {
568 body["max_tokens"] = json!(max);
569 }
570 if !request.stop.is_empty() {
571 body["stop"] = json!(request.stop);
572 }
573
574 let mut req = self
575 .client
576 .post(format!("{}/chat/completions", self.base_url))
577 .header("Authorization", format!("Bearer {}", self.token))
578 .header("Content-Type", "application/json")
579 .header("Openai-Intent", "conversation-edits")
580 .header("User-Agent", Self::user_agent())
581 .header("X-Initiator", if is_agent { "agent" } else { "user" });
582
583 if has_vision {
584 req = req.header("Copilot-Vision-Request", "true");
585 }
586
587 let response = req
588 .json(&body)
589 .send()
590 .await
591 .context("Failed to send Copilot request")?;
592
593 let status = response.status();
594 let text = response
595 .text()
596 .await
597 .context("Failed to read Copilot response")?;
598
599 if !status.is_success() {
600 if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
601 let message = err
602 .error
603 .and_then(|detail| {
604 detail.message.map(|msg| {
605 if let Some(code) = detail.code {
606 format!("{} ({})", msg, code)
607 } else {
608 msg
609 }
610 })
611 })
612 .or(err.message)
613 .unwrap_or_else(|| "Unknown Copilot API error".to_string());
614 anyhow::bail!("Copilot API error: {}", message);
615 }
616 anyhow::bail!("Copilot API error: {} {}", status, text);
617 }
618
619 let response: CopilotResponse = serde_json::from_str(&text).context(format!(
620 "Failed to parse Copilot response: {}",
621 util::truncate_bytes_safe(&text, 200)
622 ))?;
623
624 let choice = response
625 .choices
626 .first()
627 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
628
629 let mut content = Vec::new();
630 let mut has_tool_calls = false;
631
632 if let Some(text) = &choice.message.content
633 && !text.is_empty()
634 {
635 content.push(ContentPart::Text { text: text.clone() });
636 }
637
638 if let Some(tool_calls) = &choice.message.tool_calls {
639 has_tool_calls = !tool_calls.is_empty();
640 for tc in tool_calls {
641 content.push(ContentPart::ToolCall {
642 id: tc.id.clone(),
643 name: tc.function.name.clone(),
644 arguments: tc.function.arguments.clone(),
645 thought_signature: None,
646 });
647 }
648 }
649
650 let finish_reason = if has_tool_calls {
651 FinishReason::ToolCalls
652 } else {
653 match choice.finish_reason.as_deref() {
654 Some("stop") => FinishReason::Stop,
655 Some("length") => FinishReason::Length,
656 Some("tool_calls") => FinishReason::ToolCalls,
657 Some("content_filter") => FinishReason::ContentFilter,
658 _ => FinishReason::Stop,
659 }
660 };
661
662 Ok(CompletionResponse {
663 message: Message {
664 role: Role::Assistant,
665 content,
666 },
667 usage: Usage {
668 prompt_tokens: response
669 .usage
670 .as_ref()
671 .map(|u| u.prompt_tokens)
672 .unwrap_or(0),
673 completion_tokens: response
674 .usage
675 .as_ref()
676 .map(|u| u.completion_tokens)
677 .unwrap_or(0),
678 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
679 ..Default::default()
680 },
681 finish_reason,
682 })
683 }
684
685 async fn complete_stream(
686 &self,
687 request: CompletionRequest,
688 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
689 let response = self.complete(request).await?;
691 let text = response
692 .message
693 .content
694 .iter()
695 .filter_map(|p| match p {
696 ContentPart::Text { text } => Some(text.clone()),
697 _ => None,
698 })
699 .collect::<Vec<_>>()
700 .join("");
701
702 Ok(Box::pin(futures::stream::once(async move {
703 StreamChunk::Text(text)
704 })))
705 }
706}
707
/// Returns `true` when `id` ends with a date suffix of the form `-YYYY-MM-DD`
/// (for example `gpt-4o-2024-05-13`).
///
/// Only the shape is checked (dashes and ASCII digits); the date itself is not
/// validated. The check works on bytes, so ids containing non-ASCII characters
/// are handled without slicing the string at an invalid char boundary.
fn is_dated_model_variant(id: &str) -> bool {
    const SUFFIX_LEN: usize = "-YYYY-MM-DD".len();

    let bytes = id.as_bytes();
    if bytes.len() < SUFFIX_LEN {
        return false;
    }

    let tail = &bytes[bytes.len() - SUFFIX_LEN..];
    tail.iter().enumerate().all(|(pos, byte)| match pos {
        // Positions of the three dashes in "-YYYY-MM-DD".
        0 | 5 | 8 => *byte == b'-',
        _ => byte.is_ascii_digit(),
    })
}
725
/// Reduces a user-supplied enterprise URL to its bare domain.
///
/// Surrounding whitespace, any leading `https://` / `http://` scheme and
/// trailing `/` characters are removed. Schemes are matched
/// ASCII-case-insensitively (URL schemes are case-insensitive), so
/// `HTTPS://company.ghe.com/` normalizes like its lowercase form. Repeated
/// leading schemes are all stripped. Anything else (including a path after the
/// host) is returned unchanged.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let mut rest = input.trim();
    loop {
        let stripped = strip_scheme_ignore_ascii_case(rest, "https://")
            .or_else(|| strip_scheme_ignore_ascii_case(rest, "http://"));
        match stripped {
            Some(remainder) => rest = remainder,
            None => break,
        }
    }
    rest.trim_end_matches('/').to_string()
}

/// Returns `s` without its leading `scheme` when `s` starts with it, comparing
/// ASCII letters case-insensitively; `None` otherwise.
fn strip_scheme_ignore_ascii_case<'a>(s: &'a str, scheme: &str) -> Option<&'a str> {
    // `get` yields `None` for short inputs and non-char-boundary offsets.
    let head = s.get(..scheme.len())?;
    if head.eq_ignore_ascii_case(scheme) {
        s.get(scheme.len()..)
    } else {
        None
    }
}
734
735pub fn enterprise_base_url(enterprise_url: &str) -> String {
736 format!(
737 "https://copilot-api.{}",
738 normalize_enterprise_domain(enterprise_url)
739 )
740}
741
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme and trailing slash are removed; a bare domain is left untouched.
    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let inputs = [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ];
        for input in inputs {
            assert_eq!(
                normalize_enterprise_domain(input),
                "company.ghe.com",
                "input: {}",
                input
            );
        }
    }

    /// The enterprise API root lives on the `copilot-api.` subdomain.
    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }

    /// Ids ending in `-YYYY-MM-DD` are dated variants; plain ids are not.
    #[test]
    fn is_dated_model_variant_detects_date_suffix() {
        let dated = [
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4.1-2025-04-14",
            "gpt-4o-mini-2024-07-18",
        ];
        for id in dated {
            assert!(is_dated_model_variant(id), "expected dated: {}", id);
        }

        let undated = ["gpt-4o", "gpt-5", "claude-sonnet-4", "gemini-2.5-pro"];
        for id in undated {
            assert!(!is_dated_model_variant(id), "expected undated: {}", id);
        }
    }

    /// The fallback catalog has entries and all of them support tools.
    #[test]
    fn known_models_fallback_is_non_empty() {
        let provider = CopilotProvider::new("test-token".to_string()).unwrap();
        let models = provider.known_models();
        assert!(!models.is_empty());
        assert!(models.iter().all(|m| m.supports_tools));
    }

    /// A known id gets its display name and table-derived costs.
    #[test]
    fn enrich_with_pricing_sets_costs() {
        let provider = CopilotProvider::new("test-token".to_string()).unwrap();
        let gpt_4o = ModelInfo {
            id: "gpt-4o".to_string(),
            name: "gpt-4o".to_string(),
            provider: "github-copilot".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: true,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        };
        let mut models = vec![gpt_4o];

        provider.enrich_with_pricing(&mut models);

        let enriched = &models[0];
        assert_eq!(enriched.input_cost_per_million, Some(0.0));
        assert_eq!(enriched.output_cost_per_million, Some(0.0));
        assert_eq!(enriched.name, "GPT-4o");
    }
}