1use super::{
4 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5 Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Default endpoint for the public (non-enterprise) GitHub Copilot API.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name used by instances created with `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name used by instances created with `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";

/// Chat-completion provider backed by the GitHub Copilot API (public or
/// GitHub Enterprise deployments).
pub struct CopilotProvider {
    /// HTTP client reused for every request.
    client: Client,
    /// Token sent as `Authorization: Bearer <token>`.
    token: String,
    /// API root with trailing slashes removed; endpoint paths are appended.
    base_url: String,
    /// Name reported by `Provider::name` and stamped onto listed models.
    provider_name: String,
}
23
24impl CopilotProvider {
25 pub fn new(token: String) -> Result<Self> {
26 Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27 }
28
29 pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30 let base_url = enterprise_base_url(&enterprise_url);
31 Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32 }
33
34 pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35 Ok(Self {
36 client: Client::new(),
37 token,
38 base_url: base_url.trim_end_matches('/').to_string(),
39 provider_name: provider_name.to_string(),
40 })
41 }
42
43 fn user_agent() -> String {
44 format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45 }
46
47 fn convert_messages(messages: &[Message]) -> Vec<Value> {
48 messages
49 .iter()
50 .map(|msg| {
51 let role = match msg.role {
52 Role::System => "system",
53 Role::User => "user",
54 Role::Assistant => "assistant",
55 Role::Tool => "tool",
56 };
57
58 match msg.role {
59 Role::Tool => {
60 if let Some(ContentPart::ToolResult {
61 tool_call_id,
62 content,
63 }) = msg.content.first()
64 {
65 json!({
66 "role": "tool",
67 "tool_call_id": tool_call_id,
68 "content": content
69 })
70 } else {
71 json!({ "role": role, "content": "" })
72 }
73 }
74 Role::Assistant => {
75 let text: String = msg
76 .content
77 .iter()
78 .filter_map(|p| match p {
79 ContentPart::Text { text } => Some(text.clone()),
80 _ => None,
81 })
82 .collect::<Vec<_>>()
83 .join("");
84
85 let tool_calls: Vec<Value> = msg
86 .content
87 .iter()
88 .filter_map(|p| match p {
89 ContentPart::ToolCall {
90 id,
91 name,
92 arguments,
93 } => Some(json!({
94 "id": id,
95 "type": "function",
96 "function": {
97 "name": name,
98 "arguments": arguments
99 }
100 })),
101 _ => None,
102 })
103 .collect();
104
105 if tool_calls.is_empty() {
106 json!({ "role": "assistant", "content": text })
107 } else {
108 json!({
109 "role": "assistant",
110 "content": if text.is_empty() { "".to_string() } else { text },
111 "tool_calls": tool_calls
112 })
113 }
114 }
115 _ => {
116 let text: String = msg
117 .content
118 .iter()
119 .filter_map(|p| match p {
120 ContentPart::Text { text } => Some(text.clone()),
121 _ => None,
122 })
123 .collect::<Vec<_>>()
124 .join("\n");
125 json!({ "role": role, "content": text })
126 }
127 }
128 })
129 .collect()
130 }
131
132 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133 tools
134 .iter()
135 .map(|t| {
136 json!({
137 "type": "function",
138 "function": {
139 "name": t.name,
140 "description": t.description,
141 "parameters": t.parameters
142 }
143 })
144 })
145 .collect()
146 }
147
148 fn is_agent_initiated(messages: &[Message]) -> bool {
149 messages
150 .iter()
151 .rev()
152 .find(|msg| msg.role != Role::System)
153 .map(|msg| msg.role != Role::User)
154 .unwrap_or(false)
155 }
156
157 fn has_vision_input(messages: &[Message]) -> bool {
158 messages.iter().any(|msg| {
159 msg.content
160 .iter()
161 .any(|part| matches!(part, ContentPart::Image { .. }))
162 })
163 }
164}
165
166#[derive(Debug, Deserialize)]
167struct CopilotResponse {
168 choices: Vec<CopilotChoice>,
169 #[serde(default)]
170 usage: Option<CopilotUsage>,
171}
172
173#[derive(Debug, Deserialize)]
174struct CopilotChoice {
175 message: CopilotMessage,
176 #[serde(default)]
177 finish_reason: Option<String>,
178}
179
180#[derive(Debug, Deserialize)]
181struct CopilotMessage {
182 #[serde(default)]
183 content: Option<String>,
184 #[serde(default)]
185 tool_calls: Option<Vec<CopilotToolCall>>,
186}
187
188#[derive(Debug, Deserialize)]
189struct CopilotToolCall {
190 id: String,
191 #[serde(rename = "type")]
192 #[allow(dead_code)]
193 call_type: String,
194 function: CopilotFunction,
195}
196
197#[derive(Debug, Deserialize)]
198struct CopilotFunction {
199 name: String,
200 arguments: String,
201}
202
/// Token accounting of a completion response; missing counters default to 0.
#[derive(Debug, Deserialize)]
struct CopilotUsage {
    #[serde(default)]
    prompt_tokens: usize,
    #[serde(default)]
    completion_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

/// Error body of a non-success response. A message may appear nested under
/// `error` or at the top level, so both fields are optional.
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    error: Option<CopilotErrorDetail>,
    message: Option<String>,
}

/// Nested `error` object of `CopilotErrorResponse`.
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    message: Option<String>,
    // NOTE(review): assumed to be a string; a numeric `code` would make the
    // whole error body fail to deserialize — confirm against the API.
    code: Option<String>,
}
224
/// Body of the `/models` listing.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    data: Vec<CopilotModelInfo>,
}

/// One entry of the `/models` listing; everything except `id` is optional.
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    id: String,
    /// Display name; `list_models` falls back to other sources when absent.
    #[serde(default)]
    name: Option<String>,
    /// `Some(false)` hides the model from `list_models`.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    /// A policy state of "disabled" hides the model from `list_models`.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}

/// Per-model policy information.
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    #[serde(default)]
    state: Option<String>,
}

/// Capability block of a listed model.
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}

/// Token limits reported for a model; absent values are filled in by
/// `list_models`.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    #[serde(default)]
    max_output_tokens: Option<usize>,
}

/// Feature flags reported for a model. When a flag is absent, `list_models`
/// assumes: vision off, tool calls on, streaming on.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    #[serde(default)]
    tool_calls: Option<bool>,
    #[serde(default)]
    vision: Option<bool>,
    #[serde(default)]
    streaming: Option<bool>,
}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277 fn name(&self) -> &str {
278 &self.provider_name
279 }
280
281 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282 let response = self
283 .client
284 .get(format!("{}/models", self.base_url))
285 .header("Authorization", format!("Bearer {}", self.token))
286 .header("Openai-Intent", "conversation-edits")
287 .header("User-Agent", Self::user_agent())
288 .send()
289 .await
290 .context("Failed to fetch Copilot models")?;
291
292 let mut models: Vec<ModelInfo> = if response.status().is_success() {
293 let parsed: CopilotModelsResponse = response
294 .json()
295 .await
296 .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298 parsed
299 .data
300 .into_iter()
301 .filter(|model| {
302 if model.model_picker_enabled == Some(false) {
304 return false;
305 }
306 if let Some(ref policy) = model.policy {
308 if policy.state.as_deref() == Some("disabled") {
309 return false;
310 }
311 }
312 true
313 })
314 .map(|model| {
315 let caps = model.capabilities.as_ref();
316 let limits = caps.and_then(|c| c.limits.as_ref());
317 let supports = caps.and_then(|c| c.supports.as_ref());
318
319 ModelInfo {
320 id: model.id.clone(),
321 name: model.name.unwrap_or_else(|| model.id.clone()),
322 provider: self.provider_name.clone(),
323 context_window: limits
324 .and_then(|l| l.max_context_window_tokens)
325 .unwrap_or(128_000),
326 max_output_tokens: limits
327 .and_then(|l| l.max_output_tokens)
328 .or(Some(16_384)),
329 supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
330 supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
331 supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
332 input_cost_per_million: None, output_cost_per_million: None,
334 }
335 })
336 .collect()
337 } else {
338 Vec::new()
339 };
340
341 let known_metadata: std::collections::HashMap<&str, (&str, usize, usize, f64)> = [
359 ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000, 3.0)),
360 ("claude-opus-4.6", ("Claude Opus 4.6", 200_000, 64_000, 3.0)),
361 ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000, 10.0)),
362 (
363 "claude-sonnet-4.5",
364 ("Claude Sonnet 4.5", 200_000, 64_000, 1.0),
365 ),
366 ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000, 1.0)),
367 (
368 "claude-haiku-4.5",
369 ("Claude Haiku 4.5", 200_000, 64_000, 0.33),
370 ),
371 ("gpt-5.3-codex", ("GPT-5.3-Codex", 264_000, 64_000, 1.0)),
372 ("gpt-5.2", ("GPT-5.2", 400_000, 128_000, 1.0)),
373 ("gpt-5.2-codex", ("GPT-5.2-Codex", 264_000, 64_000, 1.0)),
374 ("gpt-5.1", ("GPT-5.1", 400_000, 128_000, 1.0)),
375 ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000, 1.0)),
376 (
377 "gpt-5.1-codex-mini",
378 ("GPT-5.1-Codex-Mini", 264_000, 64_000, 0.33),
379 ),
380 (
381 "gpt-5.1-codex-max",
382 ("GPT-5.1-Codex-Max", 264_000, 64_000, 1.0),
383 ),
384 ("gpt-5", ("GPT-5", 400_000, 128_000, 1.0)),
385 ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000, 0.0)),
386 ("gpt-5-codex", ("GPT-5-Codex", 264_000, 64_000, 1.0)),
387 ("gpt-4.1", ("GPT-4.1", 128_000, 32_768, 0.0)),
388 ("gpt-4o", ("GPT-4o", 128_000, 16_384, 0.0)),
389 ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000, 1.0)),
390 (
391 "gemini-3-flash-preview",
392 ("Gemini 3 Flash", 1_000_000, 64_000, 0.33),
393 ),
394 (
395 "gemini-3-pro-preview",
396 ("Gemini 3 Pro", 1_000_000, 64_000, 1.0),
397 ),
398 (
399 "grok-code-fast-1",
400 ("Grok Code Fast 1", 128_000, 32_768, 0.25),
401 ),
402 ]
403 .into_iter()
404 .collect();
405
406 for model in &mut models {
409 if let Some((name, ctx, max_out, premium_mult)) = known_metadata.get(model.id.as_str())
410 {
411 if model.name == model.id {
412 model.name = name.to_string();
413 }
414 if model.context_window == 128_000 {
415 model.context_window = *ctx;
416 }
417 if model.max_output_tokens == Some(16_384) {
418 model.max_output_tokens = Some(*max_out);
419 }
420 let approx_cost = premium_mult * 10.0;
424 model.input_cost_per_million = Some(approx_cost);
425 model.output_cost_per_million = Some(approx_cost);
426 } else {
427 if model.input_cost_per_million.is_none() {
429 model.input_cost_per_million = Some(10.0);
430 }
431 if model.output_cost_per_million.is_none() {
432 model.output_cost_per_million = Some(10.0);
433 }
434 }
435 }
436
437 models.retain(|m| {
440 !m.id.starts_with("text-embedding")
441 && m.id != "gpt-3.5-turbo"
442 && m.id != "gpt-3.5-turbo-0613"
443 && m.id != "gpt-4-0613"
444 && m.id != "gpt-4o-2024-05-13"
445 && m.id != "gpt-4o-2024-08-06"
446 && m.id != "gpt-4o-2024-11-20"
447 && m.id != "gpt-4o-mini-2024-07-18"
448 && m.id != "gpt-4-o-preview"
449 && m.id != "gpt-4.1-2025-04-14"
450 });
451
452 let mut seen = std::collections::HashSet::new();
454 models.retain(|m| seen.insert(m.id.clone()));
455
456 Ok(models)
457 }
458
459 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
460 let messages = Self::convert_messages(&request.messages);
461 let tools = Self::convert_tools(&request.tools);
462 let is_agent = Self::is_agent_initiated(&request.messages);
463 let has_vision = Self::has_vision_input(&request.messages);
464
465 let mut body = json!({
466 "model": request.model,
467 "messages": messages,
468 });
469
470 if !tools.is_empty() {
471 body["tools"] = json!(tools);
472 }
473 if let Some(temp) = request.temperature {
474 body["temperature"] = json!(temp);
475 }
476 if let Some(top_p) = request.top_p {
477 body["top_p"] = json!(top_p);
478 }
479 if let Some(max) = request.max_tokens {
480 body["max_tokens"] = json!(max);
481 }
482 if !request.stop.is_empty() {
483 body["stop"] = json!(request.stop);
484 }
485
486 let mut req = self
487 .client
488 .post(format!("{}/chat/completions", self.base_url))
489 .header("Authorization", format!("Bearer {}", self.token))
490 .header("Content-Type", "application/json")
491 .header("Openai-Intent", "conversation-edits")
492 .header("User-Agent", Self::user_agent())
493 .header("X-Initiator", if is_agent { "agent" } else { "user" });
494
495 if has_vision {
496 req = req.header("Copilot-Vision-Request", "true");
497 }
498
499 let response = req
500 .json(&body)
501 .send()
502 .await
503 .context("Failed to send Copilot request")?;
504
505 let status = response.status();
506 let text = response
507 .text()
508 .await
509 .context("Failed to read Copilot response")?;
510
511 if !status.is_success() {
512 if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
513 let message = err
514 .error
515 .and_then(|detail| {
516 detail.message.map(|msg| {
517 if let Some(code) = detail.code {
518 format!("{} ({})", msg, code)
519 } else {
520 msg
521 }
522 })
523 })
524 .or(err.message)
525 .unwrap_or_else(|| "Unknown Copilot API error".to_string());
526 anyhow::bail!("Copilot API error: {}", message);
527 }
528 anyhow::bail!("Copilot API error: {} {}", status, text);
529 }
530
531 let response: CopilotResponse = serde_json::from_str(&text).context(format!(
532 "Failed to parse Copilot response: {}",
533 &text[..text.len().min(200)]
534 ))?;
535
536 let choice = response
537 .choices
538 .first()
539 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
540
541 let mut content = Vec::new();
542 let mut has_tool_calls = false;
543
544 if let Some(text) = &choice.message.content {
545 if !text.is_empty() {
546 content.push(ContentPart::Text { text: text.clone() });
547 }
548 }
549
550 if let Some(tool_calls) = &choice.message.tool_calls {
551 has_tool_calls = !tool_calls.is_empty();
552 for tc in tool_calls {
553 content.push(ContentPart::ToolCall {
554 id: tc.id.clone(),
555 name: tc.function.name.clone(),
556 arguments: tc.function.arguments.clone(),
557 });
558 }
559 }
560
561 let finish_reason = if has_tool_calls {
562 FinishReason::ToolCalls
563 } else {
564 match choice.finish_reason.as_deref() {
565 Some("stop") => FinishReason::Stop,
566 Some("length") => FinishReason::Length,
567 Some("tool_calls") => FinishReason::ToolCalls,
568 Some("content_filter") => FinishReason::ContentFilter,
569 _ => FinishReason::Stop,
570 }
571 };
572
573 Ok(CompletionResponse {
574 message: Message {
575 role: Role::Assistant,
576 content,
577 },
578 usage: Usage {
579 prompt_tokens: response
580 .usage
581 .as_ref()
582 .map(|u| u.prompt_tokens)
583 .unwrap_or(0),
584 completion_tokens: response
585 .usage
586 .as_ref()
587 .map(|u| u.completion_tokens)
588 .unwrap_or(0),
589 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
590 ..Default::default()
591 },
592 finish_reason,
593 })
594 }
595
596 async fn complete_stream(
597 &self,
598 request: CompletionRequest,
599 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
600 let response = self.complete(request).await?;
602 let text = response
603 .message
604 .content
605 .iter()
606 .filter_map(|p| match p {
607 ContentPart::Text { text } => Some(text.clone()),
608 _ => None,
609 })
610 .collect::<Vec<_>>()
611 .join("");
612
613 Ok(Box::pin(futures::stream::once(async move {
614 StreamChunk::Text(text)
615 })))
616 }
617}
618
/// Reduces a user-supplied GitHub Enterprise URL to its bare host.
///
/// Surrounding whitespace, leading `https://` / `http://` schemes (matched
/// case-insensitively, since URL schemes are case-insensitive), and trailing
/// slashes are removed. Anything else, including letter case in the host, is
/// kept as given.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let without_https = trim_start_matches_ignore_ascii_case(input.trim(), "https://");
    let without_scheme = trim_start_matches_ignore_ascii_case(without_https, "http://");
    without_scheme.trim_end_matches('/').to_string()
}

/// Strips every leading occurrence of `prefix`, comparing ASCII letters
/// case-insensitively. An empty `prefix` leaves `text` unchanged.
fn trim_start_matches_ignore_ascii_case<'a>(mut text: &'a str, prefix: &str) -> &'a str {
    if prefix.is_empty() {
        return text;
    }
    // `get` returns `None` when `text` is too short or the cut would not land
    // on a char boundary, so the slicing below cannot panic.
    while let Some(head) = text.get(..prefix.len()) {
        if !head.eq_ignore_ascii_case(prefix) {
            break;
        }
        text = &text[prefix.len()..];
    }
    text
}

/// Builds the Copilot API root for an enterprise host:
/// `https://copilot-api.<normalized host>`.
pub fn enterprise_base_url(enterprise_url: &str) -> String {
    format!(
        "https://copilot-api.{}",
        normalize_enterprise_domain(enterprise_url)
    )
}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// Scheme prefixes and a trailing slash are stripped; a bare domain is
    /// returned unchanged.
    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let cases = [
            ("https://company.ghe.com/", "company.ghe.com"),
            ("http://company.ghe.com", "company.ghe.com"),
            ("company.ghe.com", "company.ghe.com"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(
                normalize_enterprise_domain(input),
                expected,
                "normalizing {:?}",
                input
            );
        }
    }

    /// Enterprise hosts are reached through the `copilot-api.` subdomain.
    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }
}