1use super::util;
4use super::{
5 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
6 Role, StreamChunk, ToolDefinition, Usage,
7};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::{Value, json};
13
/// Base URL of the public GitHub Copilot API; used by `CopilotProvider::new`.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name used for providers built with `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name used for providers built with `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
17
/// Chat-completion provider that talks to a GitHub Copilot HTTP endpoint.
pub struct CopilotProvider {
    /// HTTP client used for every request (a clone of the crate's shared client).
    client: Client,
    /// Token sent as `Authorization: Bearer <token>` on each request.
    token: String,
    /// API base URL, stored without trailing `/` characters.
    base_url: String,
    /// Name returned by `Provider::name` and stamped on every listed model.
    provider_name: String,
}
24
25impl CopilotProvider {
26 pub fn new(token: String) -> Result<Self> {
27 Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
28 }
29
30 pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
31 let base_url = enterprise_base_url(&enterprise_url);
32 Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
33 }
34
35 pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
36 Ok(Self {
37 client: crate::provider::shared_http::shared_client().clone(),
38 token,
39 base_url: base_url.trim_end_matches('/').to_string(),
40 provider_name: provider_name.to_string(),
41 })
42 }
43
44 fn user_agent() -> String {
45 format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
46 }
47
48 fn convert_messages(messages: &[Message]) -> Vec<Value> {
49 messages
50 .iter()
51 .map(|msg| {
52 let role = match msg.role {
53 Role::System => "system",
54 Role::User => "user",
55 Role::Assistant => "assistant",
56 Role::Tool => "tool",
57 };
58
59 match msg.role {
60 Role::Tool => {
61 if let Some(ContentPart::ToolResult {
62 tool_call_id,
63 content,
64 }) = msg.content.first()
65 {
66 json!({
67 "role": "tool",
68 "tool_call_id": tool_call_id,
69 "content": content
70 })
71 } else {
72 json!({ "role": role, "content": "" })
73 }
74 }
75 Role::Assistant => {
76 let text: String = msg
77 .content
78 .iter()
79 .filter_map(|p| match p {
80 ContentPart::Text { text } => Some(text.clone()),
81 _ => None,
82 })
83 .collect::<Vec<_>>()
84 .join("");
85
86 let tool_calls: Vec<Value> = msg
87 .content
88 .iter()
89 .filter_map(|p| match p {
90 ContentPart::ToolCall {
91 id,
92 name,
93 arguments,
94 ..
95 } => Some(json!({
96 "id": id,
97 "type": "function",
98 "function": {
99 "name": name,
100 "arguments": arguments
101 }
102 })),
103 _ => None,
104 })
105 .collect();
106
107 if tool_calls.is_empty() {
108 json!({ "role": "assistant", "content": text })
109 } else {
110 json!({
111 "role": "assistant",
112 "content": if text.is_empty() { "".to_string() } else { text },
113 "tool_calls": tool_calls
114 })
115 }
116 }
117 _ => {
118 let text: String = msg
119 .content
120 .iter()
121 .filter_map(|p| match p {
122 ContentPart::Text { text } => Some(text.clone()),
123 _ => None,
124 })
125 .collect::<Vec<_>>()
126 .join("\n");
127 json!({ "role": role, "content": text })
128 }
129 }
130 })
131 .collect()
132 }
133
134 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
135 tools
136 .iter()
137 .map(|t| {
138 json!({
139 "type": "function",
140 "function": {
141 "name": t.name,
142 "description": t.description,
143 "parameters": t.parameters
144 }
145 })
146 })
147 .collect()
148 }
149
150 fn is_agent_initiated(messages: &[Message]) -> bool {
151 messages
152 .iter()
153 .rev()
154 .find(|msg| msg.role != Role::System)
155 .map(|msg| msg.role != Role::User)
156 .unwrap_or(false)
157 }
158
159 fn has_vision_input(messages: &[Message]) -> bool {
160 messages.iter().any(|msg| {
161 msg.content
162 .iter()
163 .any(|part| matches!(part, ContentPart::Image { .. }))
164 })
165 }
166
167 async fn discover_models_from_api(&self) -> Vec<ModelInfo> {
169 let response = match self
170 .client
171 .get(format!("{}/models", self.base_url))
172 .header("Authorization", format!("Bearer {}", self.token))
173 .header("User-Agent", Self::user_agent())
174 .send()
175 .await
176 {
177 Ok(r) => r,
178 Err(e) => {
179 tracing::warn!(provider = %self.provider_name, error = %e, "Failed to fetch Copilot models endpoint");
180 return Vec::new();
181 }
182 };
183
184 let status = response.status();
185 if !status.is_success() {
186 let body = response.text().await.unwrap_or_default();
187 tracing::warn!(
188 provider = %self.provider_name,
189 status = %status,
190 body = %body.chars().take(200).collect::<String>(),
191 "Copilot /models endpoint returned non-success"
192 );
193 return Vec::new();
194 }
195
196 let parsed: CopilotModelsResponse = match crate::provider::body_cap::json_capped(
197 response,
198 crate::provider::body_cap::PROVIDER_METADATA_BODY_CAP,
199 )
200 .await
201 {
202 Ok(p) => p,
203 Err(e) => {
204 tracing::warn!(provider = %self.provider_name, error = %e, "Failed to parse Copilot models response (or exceeded body cap)");
205 return Vec::new();
206 }
207 };
208
209 let models: Vec<ModelInfo> = parsed
210 .data
211 .into_iter()
212 .filter(|model| {
213 if model.model_picker_enabled == Some(false) {
215 return false;
216 }
217 if let Some(ref policy) = model.policy
219 && policy.state.as_deref() == Some("disabled")
220 {
221 return false;
222 }
223 true
224 })
225 .map(|model| {
226 let caps = model.capabilities.as_ref();
227 let limits = caps.and_then(|c| c.limits.as_ref());
228 let supports = caps.and_then(|c| c.supports.as_ref());
229
230 ModelInfo {
231 id: model.id.clone(),
232 name: model.name.unwrap_or_else(|| model.id.clone()),
233 provider: self.provider_name.clone(),
234 context_window: limits
235 .and_then(|l| l.max_context_window_tokens)
236 .unwrap_or(128_000),
237 max_output_tokens: limits.and_then(|l| l.max_output_tokens).or(Some(16_384)),
238 supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
239 supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
240 supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
241 input_cost_per_million: None,
242 output_cost_per_million: None,
243 }
244 })
245 .collect();
246
247 tracing::info!(
248 provider = %self.provider_name,
249 count = models.len(),
250 "Discovered models from Copilot API"
251 );
252 models
253 }
254
255 fn enrich_with_pricing(&self, models: &mut [ModelInfo]) {
262 let pricing: std::collections::HashMap<&str, (&str, f64)> = [
264 ("claude-opus-4.5", ("Claude Opus 4.5", 3.0)),
265 ("claude-opus-4.6", ("Claude Opus 4.6", 3.0)),
266 ("claude-opus-41", ("Claude Opus 4.1", 10.0)),
267 ("claude-sonnet-4-6", ("Claude Sonnet 4.6", 1.0)),
268 ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 1.0)),
269 ("claude-sonnet-4", ("Claude Sonnet 4", 1.0)),
270 ("claude-haiku-4.5", ("Claude Haiku 4.5", 0.33)),
271 ("gpt-5.3-codex", ("GPT-5.3-Codex", 1.0)),
272 ("gpt-5.2", ("GPT-5.2", 1.0)),
273 ("gpt-5.2-codex", ("GPT-5.2-Codex", 1.0)),
274 ("gpt-5.1", ("GPT-5.1", 1.0)),
275 ("gpt-5.1-codex", ("GPT-5.1-Codex", 1.0)),
276 ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 0.33)),
277 ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 1.0)),
278 ("gpt-5", ("GPT-5", 1.0)),
279 ("gpt-5-mini", ("GPT-5 mini", 0.0)),
280 ("gpt-5-codex", ("GPT-5-Codex", 1.0)),
281 ("gpt-4.1", ("GPT-4.1", 0.0)),
282 ("gpt-4o", ("GPT-4o", 0.0)),
283 ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1.0)),
284 ("gemini-3.1-pro-preview", ("Gemini 3.1 Pro Preview", 1.0)),
285 (
286 "gemini-3.1-pro-preview-customtools",
287 ("Gemini 3.1 Pro Preview (Custom Tools)", 1.0),
288 ),
289 ("gemini-3-flash-preview", ("Gemini 3 Flash Preview", 0.33)),
290 ("gemini-3-pro-preview", ("Gemini 3 Pro Preview", 1.0)),
291 (
292 "gemini-3-pro-image-preview",
293 ("Gemini 3 Pro Image Preview", 1.0),
294 ),
295 ("grok-code-fast-1", ("Grok Code Fast 1", 0.25)),
296 ]
297 .into_iter()
298 .collect();
299
300 for model in models.iter_mut() {
301 if let Some((display_name, premium_mult)) = pricing.get(model.id.as_str()) {
302 if model.name == model.id {
304 model.name = display_name.to_string();
305 }
306 let approx_cost = premium_mult * 10.0;
307 model.input_cost_per_million = Some(approx_cost);
308 model.output_cost_per_million = Some(approx_cost);
309 } else {
310 if model.input_cost_per_million.is_none() {
312 model.input_cost_per_million = Some(10.0);
313 }
314 if model.output_cost_per_million.is_none() {
315 model.output_cost_per_million = Some(10.0);
316 }
317 }
318 }
319 }
320
321 fn known_models(&self) -> Vec<ModelInfo> {
323 let entries: &[(&str, &str, usize, usize, bool)] = &[
324 ("gpt-4o", "GPT-4o", 128_000, 16_384, true),
325 ("gpt-4.1", "GPT-4.1", 128_000, 32_768, false),
326 ("gpt-5", "GPT-5", 400_000, 128_000, false),
327 ("gpt-5-mini", "GPT-5 mini", 264_000, 64_000, false),
328 ("claude-sonnet-4", "Claude Sonnet 4", 200_000, 64_000, false),
329 (
330 "claude-sonnet-4.5",
331 "Claude Sonnet 4.5",
332 200_000,
333 64_000,
334 false,
335 ),
336 (
337 "claude-sonnet-4-6",
338 "Claude Sonnet 4.6",
339 200_000,
340 128_000,
341 false,
342 ),
343 (
344 "claude-haiku-4.5",
345 "Claude Haiku 4.5",
346 200_000,
347 64_000,
348 false,
349 ),
350 ("gemini-2.5-pro", "Gemini 2.5 Pro", 1_000_000, 64_000, false),
351 (
352 "gemini-3.1-pro-preview",
353 "Gemini 3.1 Pro Preview",
354 1_048_576,
355 65_536,
356 false,
357 ),
358 (
359 "gemini-3.1-pro-preview-customtools",
360 "Gemini 3.1 Pro Preview (Custom Tools)",
361 1_048_576,
362 65_536,
363 false,
364 ),
365 (
366 "gemini-3-pro-preview",
367 "Gemini 3 Pro Preview",
368 1_048_576,
369 65_536,
370 false,
371 ),
372 (
373 "gemini-3-flash-preview",
374 "Gemini 3 Flash Preview",
375 1_048_576,
376 65_536,
377 false,
378 ),
379 (
380 "gemini-3-pro-image-preview",
381 "Gemini 3 Pro Image Preview",
382 65_536,
383 32_768,
384 false,
385 ),
386 ];
387
388 entries
389 .iter()
390 .map(|(id, name, ctx, max_out, vision)| ModelInfo {
391 id: id.to_string(),
392 name: name.to_string(),
393 provider: self.provider_name.clone(),
394 context_window: *ctx,
395 max_output_tokens: Some(*max_out),
396 supports_vision: *vision,
397 supports_tools: true,
398 supports_streaming: true,
399 input_cost_per_million: None,
400 output_cost_per_million: None,
401 })
402 .collect()
403 }
404}
405
/// Body of a successful `/chat/completions` response, as parsed by `complete`.
#[derive(Debug, Deserialize)]
struct CopilotResponse {
    /// Candidate completions; `complete` only uses the first one.
    choices: Vec<CopilotChoice>,
    /// Token accounting; may be absent from the response.
    #[serde(default)]
    usage: Option<CopilotUsage>,
}
412
/// One candidate completion inside a [`CopilotResponse`].
#[derive(Debug, Deserialize)]
struct CopilotChoice {
    /// The assistant message produced for this choice.
    message: CopilotMessage,
    /// Raw finish reason string (e.g. `"stop"`, `"length"`), if provided.
    #[serde(default)]
    finish_reason: Option<String>,
}
419
/// Assistant message payload of a [`CopilotChoice`].
#[derive(Debug, Deserialize)]
struct CopilotMessage {
    /// Text content; may be missing.
    #[serde(default)]
    content: Option<String>,
    /// Tool invocations requested by the model; may be missing.
    #[serde(default)]
    tool_calls: Option<Vec<CopilotToolCall>>,
}
427
/// A single tool invocation requested by the model.
#[derive(Debug, Deserialize)]
struct CopilotToolCall {
    /// Identifier that links a later tool result back to this call.
    id: String,
    /// Value of the JSON `type` field. It is required during deserialization
    /// but never read afterwards, hence the `dead_code` allowance.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    /// Name and arguments of the function to invoke.
    function: CopilotFunction,
}
436
/// Function name and arguments of a [`CopilotToolCall`].
#[derive(Debug, Deserialize)]
struct CopilotFunction {
    /// Name of the function to call.
    name: String,
    /// Arguments as received in string form; passed through without parsing.
    arguments: String,
}
442
/// Token usage reported with a completion; counters missing from the JSON
/// deserialize as `0`.
#[derive(Debug, Deserialize)]
struct CopilotUsage {
    /// Tokens consumed by the prompt.
    #[serde(default)]
    prompt_tokens: usize,
    /// Tokens produced in the completion.
    #[serde(default)]
    completion_tokens: usize,
    /// Total token count as reported by the API.
    #[serde(default)]
    total_tokens: usize,
}
452
/// Error body of a non-success response. `complete` prefers the nested
/// `error.message` and falls back to the top-level `message`.
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    /// Structured error detail, when present.
    error: Option<CopilotErrorDetail>,
    /// Top-level error message, when present.
    message: Option<String>,
}
458
/// Structured error detail nested inside a [`CopilotErrorResponse`].
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    /// Human-readable error description.
    message: Option<String>,
    /// Error code; `complete` appends it to the message in parentheses.
    code: Option<String>,
}
464
/// Body of the `/models` response read by `discover_models_from_api`.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    /// All models returned by the endpoint, before any filtering.
    data: Vec<CopilotModelInfo>,
}
469
/// One model entry of a [`CopilotModelsResponse`].
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    /// Model identifier used in completion requests.
    id: String,
    /// Display name; discovery falls back to `id` when it is absent.
    #[serde(default)]
    name: Option<String>,
    /// Discovery skips the model when this is `Some(false)`.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    /// Discovery skips the model when the policy state is `"disabled"`.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    /// Token limits and feature flags, when reported.
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}
482
/// Policy information attached to a [`CopilotModelInfo`].
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    /// Policy state string; only the value `"disabled"` is checked in this file.
    #[serde(default)]
    state: Option<String>,
}
488
/// Capability block of a [`CopilotModelInfo`].
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    /// Token limits, when reported.
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    /// Feature flags, when reported.
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}
496
/// Token limits reported for a model.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    /// Context window size in tokens; discovery defaults to 128 000 when absent.
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    /// Maximum output tokens; discovery defaults to 16 384 when absent.
    #[serde(default)]
    max_output_tokens: Option<usize>,
}
504
/// Feature flags reported for a model.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    /// Tool calling support; discovery assumes `true` when absent.
    #[serde(default)]
    tool_calls: Option<bool>,
    /// Image input support; discovery assumes `false` when absent.
    #[serde(default)]
    vision: Option<bool>,
    /// Streaming support; discovery assumes `true` when absent.
    #[serde(default)]
    streaming: Option<bool>,
}
514
515#[async_trait]
516impl Provider for CopilotProvider {
517 fn name(&self) -> &str {
518 &self.provider_name
519 }
520
521 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
522 let mut models = self.discover_models_from_api().await;
523
524 if models.is_empty() {
526 tracing::info!(provider = %self.provider_name, "No models from API, using known model catalog");
527 models = self.known_models();
528 }
529
530 self.enrich_with_pricing(&mut models);
532
533 models.retain(|m| {
535 !m.id.starts_with("text-embedding")
536 && !m.id.contains("-embedding-")
537 && !is_dated_model_variant(&m.id)
538 });
539
540 let mut seen = std::collections::HashSet::new();
542 models.retain(|m| seen.insert(m.id.clone()));
543
544 Ok(models)
545 }
546
547 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
548 let messages = Self::convert_messages(&request.messages);
549 let tools = Self::convert_tools(&request.tools);
550 let is_agent = Self::is_agent_initiated(&request.messages);
551 let has_vision = Self::has_vision_input(&request.messages);
552
553 let mut body = json!({
554 "model": request.model,
555 "messages": messages,
556 });
557
558 if !tools.is_empty() {
559 body["tools"] = json!(tools);
560 }
561 if let Some(temp) = request.temperature {
562 body["temperature"] = json!(temp);
563 }
564 if let Some(top_p) = request.top_p {
565 body["top_p"] = json!(top_p);
566 }
567 if let Some(max) = request.max_tokens {
568 let key = if request.model.starts_with("gpt-5") {
569 "max_completion_tokens"
570 } else {
571 "max_tokens"
572 };
573 body[key] = json!(max);
574 }
575 if !request.stop.is_empty() {
576 body["stop"] = json!(request.stop);
577 }
578
579 let mut req = self
580 .client
581 .post(format!("{}/chat/completions", self.base_url))
582 .header("Authorization", format!("Bearer {}", self.token))
583 .header("Content-Type", "application/json")
584 .header("Openai-Intent", "conversation-edits")
585 .header("User-Agent", Self::user_agent())
586 .header("X-Initiator", if is_agent { "agent" } else { "user" });
587
588 if has_vision {
589 req = req.header("Copilot-Vision-Request", "true");
590 }
591
592 let response = req
593 .json(&body)
594 .send()
595 .await
596 .context("Failed to send Copilot request")?;
597
598 let status = response.status();
599 let text = response
600 .text()
601 .await
602 .context("Failed to read Copilot response")?;
603
604 if !status.is_success() {
605 if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
606 let message = err
607 .error
608 .and_then(|detail| {
609 detail.message.map(|msg| {
610 if let Some(code) = detail.code {
611 format!("{} ({})", msg, code)
612 } else {
613 msg
614 }
615 })
616 })
617 .or(err.message)
618 .unwrap_or_else(|| "Unknown Copilot API error".to_string());
619 anyhow::bail!("Copilot API error: {}", message);
620 }
621 anyhow::bail!("Copilot API error: {} {}", status, text);
622 }
623
624 let response: CopilotResponse = serde_json::from_str(&text).context(format!(
625 "Failed to parse Copilot response: {}",
626 util::truncate_bytes_safe(&text, 200)
627 ))?;
628
629 let choice = response
630 .choices
631 .first()
632 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
633
634 let mut content = Vec::new();
635 let mut has_tool_calls = false;
636
637 if let Some(text) = &choice.message.content
638 && !text.is_empty()
639 {
640 content.push(ContentPart::Text { text: text.clone() });
641 }
642
643 if let Some(tool_calls) = &choice.message.tool_calls {
644 has_tool_calls = !tool_calls.is_empty();
645 for tc in tool_calls {
646 content.push(ContentPart::ToolCall {
647 id: tc.id.clone(),
648 name: tc.function.name.clone(),
649 arguments: tc.function.arguments.clone(),
650 thought_signature: None,
651 });
652 }
653 }
654
655 let finish_reason = if has_tool_calls {
656 FinishReason::ToolCalls
657 } else {
658 match choice.finish_reason.as_deref() {
659 Some("stop") => FinishReason::Stop,
660 Some("length") => FinishReason::Length,
661 Some("tool_calls") => FinishReason::ToolCalls,
662 Some("content_filter") => FinishReason::ContentFilter,
663 _ => FinishReason::Stop,
664 }
665 };
666
667 Ok(CompletionResponse {
668 message: Message {
669 role: Role::Assistant,
670 content,
671 },
672 usage: Usage {
673 prompt_tokens: response
674 .usage
675 .as_ref()
676 .map(|u| u.prompt_tokens)
677 .unwrap_or(0),
678 completion_tokens: response
679 .usage
680 .as_ref()
681 .map(|u| u.completion_tokens)
682 .unwrap_or(0),
683 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
684 ..Default::default()
685 },
686 finish_reason,
687 })
688 }
689
690 async fn complete_stream(
691 &self,
692 request: CompletionRequest,
693 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
694 let response = self.complete(request).await?;
696 let text = response
697 .message
698 .content
699 .iter()
700 .filter_map(|p| match p {
701 ContentPart::Text { text } => Some(text.clone()),
702 _ => None,
703 })
704 .collect::<Vec<_>>()
705 .join("");
706
707 Ok(Box::pin(futures::stream::once(async move {
708 StreamChunk::Text(text)
709 })))
710 }
711}
712
/// Returns `true` when `id` ends with a `-YYYY-MM-DD` date suffix
/// (for example `gpt-4o-2024-05-13`).
///
/// Only the shape is checked (dashes and ASCII digits); the date itself is not
/// validated. The check works on raw bytes, so ids containing multi-byte UTF-8
/// characters are handled without panicking on a char-boundary slice.
fn is_dated_model_variant(id: &str) -> bool {
    const SUFFIX_LEN: usize = "-YYYY-MM-DD".len();

    let bytes = id.as_bytes();
    let Some(start) = bytes.len().checked_sub(SUFFIX_LEN) else {
        // Too short to carry a date suffix.
        return false;
    };

    bytes[start..].iter().enumerate().all(|(pos, byte)| match pos {
        // Dash positions within "-YYYY-MM-DD".
        0 | 5 | 8 => *byte == b'-',
        _ => byte.is_ascii_digit(),
    })
}
730
/// Reduces a user-supplied enterprise URL to its bare host.
///
/// Surrounding whitespace is trimmed, leading `http://` / `https://` schemes
/// are removed regardless of letter case, and everything from the first `/`
/// onwards (trailing slash or pasted path) is dropped. For example,
/// `"HTTPS://company.ghe.com/enterprises/acme"` becomes `"company.ghe.com"`.
pub fn normalize_enterprise_domain(input: &str) -> String {
    // Strips one leading scheme, compared case-insensitively, if present.
    fn strip_scheme(url: &str) -> Option<&str> {
        ["https://", "http://"].iter().find_map(|scheme| {
            // `get` returns `None` for short inputs and non-char-boundary offsets.
            let head = url.get(..scheme.len())?;
            if head.eq_ignore_ascii_case(scheme) {
                Some(&url[scheme.len()..])
            } else {
                None
            }
        })
    }

    let mut host = input.trim();
    // Repeated schemes are all removed, as the previous implementation did.
    while let Some(rest) = strip_scheme(host) {
        host = rest;
    }

    host.split('/').next().unwrap_or("").to_string()
}
739
740pub fn enterprise_base_url(enterprise_url: &str) -> String {
741 format!(
742 "https://copilot-api.{}",
743 normalize_enterprise_domain(enterprise_url)
744 )
745}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with a placeholder token for tests that need `&self`.
    fn test_provider() -> CopilotProvider {
        CopilotProvider::new("test-token".to_string()).unwrap()
    }

    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let inputs = [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ];
        for input in inputs {
            assert_eq!(normalize_enterprise_domain(input), "company.ghe.com");
        }
    }

    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        assert_eq!(
            enterprise_base_url("https://company.ghe.com/"),
            "https://copilot-api.company.ghe.com"
        );
    }

    #[test]
    fn is_dated_model_variant_detects_date_suffix() {
        let dated = [
            "gpt-4o-2024-05-13",
            "gpt-4o-2024-08-06",
            "gpt-4.1-2025-04-14",
            "gpt-4o-mini-2024-07-18",
        ];
        let undated = ["gpt-4o", "gpt-5", "claude-sonnet-4", "gemini-2.5-pro"];

        for id in dated {
            assert!(is_dated_model_variant(id));
        }
        for id in undated {
            assert!(!is_dated_model_variant(id));
        }
    }

    #[test]
    fn known_models_fallback_is_non_empty() {
        let models = test_provider().known_models();
        assert!(!models.is_empty());
        assert!(models.iter().all(|m| m.supports_tools));
    }

    #[test]
    fn enrich_with_pricing_sets_costs() {
        let provider = test_provider();
        let gpt_4o = ModelInfo {
            id: "gpt-4o".to_string(),
            name: "gpt-4o".to_string(),
            provider: "github-copilot".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: true,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        };
        let mut models = vec![gpt_4o];

        provider.enrich_with_pricing(&mut models);

        let enriched = &models[0];
        // gpt-4o has a 0.0 multiplier in the pricing table, so both costs are zero.
        assert_eq!(enriched.input_cost_per_million, Some(0.0));
        assert_eq!(enriched.output_cost_per_million, Some(0.0));
        // The raw id is replaced by the table's display name.
        assert_eq!(enriched.name, "GPT-4o");
    }
}