1use super::{
4 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5 Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Public GitHub Copilot API root used when no enterprise host is configured.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name reported by a provider built with `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name reported by a provider built with `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
17pub struct CopilotProvider {
18 client: Client,
19 token: String,
20 base_url: String,
21 provider_name: String,
22}
23
24impl CopilotProvider {
25 pub fn new(token: String) -> Result<Self> {
26 Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27 }
28
29 pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30 let base_url = enterprise_base_url(&enterprise_url);
31 Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32 }
33
34 pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35 Ok(Self {
36 client: Client::new(),
37 token,
38 base_url: base_url.trim_end_matches('/').to_string(),
39 provider_name: provider_name.to_string(),
40 })
41 }
42
43 fn user_agent() -> String {
44 format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45 }
46
47 fn convert_messages(messages: &[Message]) -> Vec<Value> {
48 messages
49 .iter()
50 .map(|msg| {
51 let role = match msg.role {
52 Role::System => "system",
53 Role::User => "user",
54 Role::Assistant => "assistant",
55 Role::Tool => "tool",
56 };
57
58 match msg.role {
59 Role::Tool => {
60 if let Some(ContentPart::ToolResult {
61 tool_call_id,
62 content,
63 }) = msg.content.first()
64 {
65 json!({
66 "role": "tool",
67 "tool_call_id": tool_call_id,
68 "content": content
69 })
70 } else {
71 json!({ "role": role, "content": "" })
72 }
73 }
74 Role::Assistant => {
75 let text: String = msg
76 .content
77 .iter()
78 .filter_map(|p| match p {
79 ContentPart::Text { text } => Some(text.clone()),
80 _ => None,
81 })
82 .collect::<Vec<_>>()
83 .join("");
84
85 let tool_calls: Vec<Value> = msg
86 .content
87 .iter()
88 .filter_map(|p| match p {
89 ContentPart::ToolCall {
90 id,
91 name,
92 arguments,
93 } => Some(json!({
94 "id": id,
95 "type": "function",
96 "function": {
97 "name": name,
98 "arguments": arguments
99 }
100 })),
101 _ => None,
102 })
103 .collect();
104
105 if tool_calls.is_empty() {
106 json!({ "role": "assistant", "content": text })
107 } else {
108 json!({
109 "role": "assistant",
110 "content": if text.is_empty() { "".to_string() } else { text },
111 "tool_calls": tool_calls
112 })
113 }
114 }
115 _ => {
116 let text: String = msg
117 .content
118 .iter()
119 .filter_map(|p| match p {
120 ContentPart::Text { text } => Some(text.clone()),
121 _ => None,
122 })
123 .collect::<Vec<_>>()
124 .join("\n");
125 json!({ "role": role, "content": text })
126 }
127 }
128 })
129 .collect()
130 }
131
132 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133 tools
134 .iter()
135 .map(|t| {
136 json!({
137 "type": "function",
138 "function": {
139 "name": t.name,
140 "description": t.description,
141 "parameters": t.parameters
142 }
143 })
144 })
145 .collect()
146 }
147
148 fn is_agent_initiated(messages: &[Message]) -> bool {
149 messages
150 .iter()
151 .rev()
152 .find(|msg| msg.role != Role::System)
153 .map(|msg| msg.role != Role::User)
154 .unwrap_or(false)
155 }
156
157 fn has_vision_input(messages: &[Message]) -> bool {
158 messages.iter().any(|msg| {
159 msg.content
160 .iter()
161 .any(|part| matches!(part, ContentPart::Image { .. }))
162 })
163 }
164}
165
166#[derive(Debug, Deserialize)]
167struct CopilotResponse {
168 choices: Vec<CopilotChoice>,
169 #[serde(default)]
170 usage: Option<CopilotUsage>,
171}
172
173#[derive(Debug, Deserialize)]
174struct CopilotChoice {
175 message: CopilotMessage,
176 #[serde(default)]
177 finish_reason: Option<String>,
178}
179
180#[derive(Debug, Deserialize)]
181struct CopilotMessage {
182 #[serde(default)]
183 content: Option<String>,
184 #[serde(default)]
185 tool_calls: Option<Vec<CopilotToolCall>>,
186}
187
188#[derive(Debug, Deserialize)]
189struct CopilotToolCall {
190 id: String,
191 #[serde(rename = "type")]
192 #[allow(dead_code)]
193 call_type: String,
194 function: CopilotFunction,
195}
196
197#[derive(Debug, Deserialize)]
198struct CopilotFunction {
199 name: String,
200 arguments: String,
201}
202
203#[derive(Debug, Deserialize)]
204struct CopilotUsage {
205 #[serde(default)]
206 prompt_tokens: usize,
207 #[serde(default)]
208 completion_tokens: usize,
209 #[serde(default)]
210 total_tokens: usize,
211}
212
213#[derive(Debug, Deserialize)]
214struct CopilotErrorResponse {
215 error: Option<CopilotErrorDetail>,
216 message: Option<String>,
217}
218
219#[derive(Debug, Deserialize)]
220struct CopilotErrorDetail {
221 message: Option<String>,
222 code: Option<String>,
223}
224
225#[derive(Debug, Deserialize)]
226struct CopilotModelsResponse {
227 data: Vec<CopilotModelInfo>,
228}
229
230#[derive(Debug, Deserialize)]
231struct CopilotModelInfo {
232 id: String,
233 #[serde(default)]
234 name: Option<String>,
235 #[serde(default)]
236 model_picker_enabled: Option<bool>,
237 #[serde(default)]
238 policy: Option<CopilotModelPolicy>,
239 #[serde(default)]
240 capabilities: Option<CopilotModelCapabilities>,
241}
242
243#[derive(Debug, Deserialize)]
244struct CopilotModelPolicy {
245 #[serde(default)]
246 state: Option<String>,
247}
248
249#[derive(Debug, Deserialize)]
250struct CopilotModelCapabilities {
251 #[serde(default)]
252 limits: Option<CopilotModelLimits>,
253 #[serde(default)]
254 supports: Option<CopilotModelSupports>,
255}
256
257#[derive(Debug, Deserialize)]
258struct CopilotModelLimits {
259 #[serde(default)]
260 max_context_window_tokens: Option<usize>,
261 #[serde(default)]
262 max_output_tokens: Option<usize>,
263}
264
265#[derive(Debug, Deserialize)]
266struct CopilotModelSupports {
267 #[serde(default)]
268 tool_calls: Option<bool>,
269 #[serde(default)]
270 vision: Option<bool>,
271 #[serde(default)]
272 streaming: Option<bool>,
273}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277 fn name(&self) -> &str {
278 &self.provider_name
279 }
280
281 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282 let response = self
283 .client
284 .get(format!("{}/models", self.base_url))
285 .header("Authorization", format!("Bearer {}", self.token))
286 .header("Openai-Intent", "conversation-edits")
287 .header("User-Agent", Self::user_agent())
288 .send()
289 .await
290 .context("Failed to fetch Copilot models")?;
291
292 let mut models: Vec<ModelInfo> = if response.status().is_success() {
293 let parsed: CopilotModelsResponse = response
294 .json()
295 .await
296 .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298 parsed
299 .data
300 .into_iter()
301 .map(|model| {
302 let caps = model.capabilities.as_ref();
303 let limits = caps.and_then(|c| c.limits.as_ref());
304 let supports = caps.and_then(|c| c.supports.as_ref());
305
306 ModelInfo {
307 id: model.id.clone(),
308 name: model.name.unwrap_or_else(|| model.id.clone()),
309 provider: self.provider_name.clone(),
310 context_window: limits
311 .and_then(|l| l.max_context_window_tokens)
312 .unwrap_or(128_000),
313 max_output_tokens: limits
314 .and_then(|l| l.max_output_tokens)
315 .or(Some(16_384)),
316 supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
317 supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
318 supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
319 input_cost_per_million: None, output_cost_per_million: None,
321 }
322 })
323 .collect()
324 } else {
325 Vec::new()
326 };
327
328 let known_metadata: std::collections::HashMap<&str, (&str, usize, usize, f64)> = [
346 ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000, 3.0)),
347 ("claude-opus-4.6", ("Claude Opus 4.6", 200_000, 64_000, 3.0)),
348 ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000, 10.0)),
349 (
350 "claude-sonnet-4.5",
351 ("Claude Sonnet 4.5", 200_000, 64_000, 1.0),
352 ),
353 ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000, 1.0)),
354 (
355 "claude-haiku-4.5",
356 ("Claude Haiku 4.5", 200_000, 64_000, 0.33),
357 ),
358 ("gpt-5.2", ("GPT-5.2", 400_000, 128_000, 1.0)),
359 ("gpt-5.1", ("GPT-5.1", 400_000, 128_000, 1.0)),
360 ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000, 1.0)),
361 (
362 "gpt-5.1-codex-mini",
363 ("GPT-5.1-Codex-Mini", 264_000, 64_000, 0.33),
364 ),
365 (
366 "gpt-5.1-codex-max",
367 ("GPT-5.1-Codex-Max", 264_000, 64_000, 1.0),
368 ),
369 ("gpt-5", ("GPT-5", 400_000, 128_000, 1.0)),
370 ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000, 0.0)),
371 ("gpt-5-codex", ("GPT-5-Codex", 264_000, 64_000, 1.0)),
372 ("gpt-4.1", ("GPT-4.1", 128_000, 32_768, 0.0)),
373 ("gpt-4o", ("GPT-4o", 128_000, 16_384, 0.0)),
374 ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000, 1.0)),
375 (
376 "gemini-3-flash-preview",
377 ("Gemini 3 Flash", 1_000_000, 64_000, 0.33),
378 ),
379 (
380 "gemini-3-pro-preview",
381 ("Gemini 3 Pro", 1_000_000, 64_000, 1.0),
382 ),
383 (
384 "grok-code-fast-1",
385 ("Grok Code Fast 1", 128_000, 32_768, 0.25),
386 ),
387 ]
388 .into_iter()
389 .collect();
390
391 for model in &mut models {
394 if let Some((name, ctx, max_out, premium_mult)) = known_metadata.get(model.id.as_str())
395 {
396 if model.name == model.id {
397 model.name = name.to_string();
398 }
399 if model.context_window == 128_000 {
400 model.context_window = *ctx;
401 }
402 if model.max_output_tokens == Some(16_384) {
403 model.max_output_tokens = Some(*max_out);
404 }
405 let approx_cost = premium_mult * 10.0;
409 model.input_cost_per_million = Some(approx_cost);
410 model.output_cost_per_million = Some(approx_cost);
411 } else {
412 if model.input_cost_per_million.is_none() {
414 model.input_cost_per_million = Some(10.0);
415 }
416 if model.output_cost_per_million.is_none() {
417 model.output_cost_per_million = Some(10.0);
418 }
419 }
420 }
421
422 models.retain(|m| {
425 !m.id.starts_with("text-embedding")
426 && m.id != "gpt-3.5-turbo"
427 && m.id != "gpt-3.5-turbo-0613"
428 && m.id != "gpt-4-0613"
429 && m.id != "gpt-4o-2024-05-13"
430 && m.id != "gpt-4o-2024-08-06"
431 && m.id != "gpt-4o-2024-11-20"
432 && m.id != "gpt-4o-mini-2024-07-18"
433 && m.id != "gpt-4-o-preview"
434 && m.id != "gpt-4.1-2025-04-14"
435 });
436
437 let mut seen = std::collections::HashSet::new();
439 models.retain(|m| seen.insert(m.id.clone()));
440
441 Ok(models)
442 }
443
444 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
445 let messages = Self::convert_messages(&request.messages);
446 let tools = Self::convert_tools(&request.tools);
447 let is_agent = Self::is_agent_initiated(&request.messages);
448 let has_vision = Self::has_vision_input(&request.messages);
449
450 let mut body = json!({
451 "model": request.model,
452 "messages": messages,
453 });
454
455 if !tools.is_empty() {
456 body["tools"] = json!(tools);
457 }
458 if let Some(temp) = request.temperature {
459 body["temperature"] = json!(temp);
460 }
461 if let Some(top_p) = request.top_p {
462 body["top_p"] = json!(top_p);
463 }
464 if let Some(max) = request.max_tokens {
465 body["max_tokens"] = json!(max);
466 }
467 if !request.stop.is_empty() {
468 body["stop"] = json!(request.stop);
469 }
470
471 let mut req = self
472 .client
473 .post(format!("{}/chat/completions", self.base_url))
474 .header("Authorization", format!("Bearer {}", self.token))
475 .header("Content-Type", "application/json")
476 .header("Openai-Intent", "conversation-edits")
477 .header("User-Agent", Self::user_agent())
478 .header("X-Initiator", if is_agent { "agent" } else { "user" });
479
480 if has_vision {
481 req = req.header("Copilot-Vision-Request", "true");
482 }
483
484 let response = req
485 .json(&body)
486 .send()
487 .await
488 .context("Failed to send Copilot request")?;
489
490 let status = response.status();
491 let text = response
492 .text()
493 .await
494 .context("Failed to read Copilot response")?;
495
496 if !status.is_success() {
497 if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
498 let message = err
499 .error
500 .and_then(|detail| {
501 detail.message.map(|msg| {
502 if let Some(code) = detail.code {
503 format!("{} ({})", msg, code)
504 } else {
505 msg
506 }
507 })
508 })
509 .or(err.message)
510 .unwrap_or_else(|| "Unknown Copilot API error".to_string());
511 anyhow::bail!("Copilot API error: {}", message);
512 }
513 anyhow::bail!("Copilot API error: {} {}", status, text);
514 }
515
516 let response: CopilotResponse = serde_json::from_str(&text).context(format!(
517 "Failed to parse Copilot response: {}",
518 &text[..text.len().min(200)]
519 ))?;
520
521 let choice = response
522 .choices
523 .first()
524 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
525
526 let mut content = Vec::new();
527 let mut has_tool_calls = false;
528
529 if let Some(text) = &choice.message.content {
530 if !text.is_empty() {
531 content.push(ContentPart::Text { text: text.clone() });
532 }
533 }
534
535 if let Some(tool_calls) = &choice.message.tool_calls {
536 has_tool_calls = !tool_calls.is_empty();
537 for tc in tool_calls {
538 content.push(ContentPart::ToolCall {
539 id: tc.id.clone(),
540 name: tc.function.name.clone(),
541 arguments: tc.function.arguments.clone(),
542 });
543 }
544 }
545
546 let finish_reason = if has_tool_calls {
547 FinishReason::ToolCalls
548 } else {
549 match choice.finish_reason.as_deref() {
550 Some("stop") => FinishReason::Stop,
551 Some("length") => FinishReason::Length,
552 Some("tool_calls") => FinishReason::ToolCalls,
553 Some("content_filter") => FinishReason::ContentFilter,
554 _ => FinishReason::Stop,
555 }
556 };
557
558 Ok(CompletionResponse {
559 message: Message {
560 role: Role::Assistant,
561 content,
562 },
563 usage: Usage {
564 prompt_tokens: response
565 .usage
566 .as_ref()
567 .map(|u| u.prompt_tokens)
568 .unwrap_or(0),
569 completion_tokens: response
570 .usage
571 .as_ref()
572 .map(|u| u.completion_tokens)
573 .unwrap_or(0),
574 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
575 ..Default::default()
576 },
577 finish_reason,
578 })
579 }
580
581 async fn complete_stream(
582 &self,
583 request: CompletionRequest,
584 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
585 let response = self.complete(request).await?;
587 let text = response
588 .message
589 .content
590 .iter()
591 .filter_map(|p| match p {
592 ContentPart::Text { text } => Some(text.clone()),
593 _ => None,
594 })
595 .collect::<Vec<_>>()
596 .join("");
597
598 Ok(Box::pin(futures::stream::once(async move {
599 StreamChunk::Text(text)
600 })))
601 }
602}
603
/// Reduces a user-supplied GitHub Enterprise URL or domain to its bare host.
///
/// Removes the following:
/// - surrounding whitespace;
/// - one leading `https://` or `http://` scheme, matched ASCII
///   case-insensitively;
/// - everything from the first `/`, `?` or `#` onward, so a pasted page URL
///   cannot leak a path into the API host.
///
/// For example, `"https://company.ghe.com/"`,
/// `"HTTPS://company.ghe.com/orgs/acme"` and `"company.ghe.com"` all yield
/// `"company.ghe.com"`. A port such as `:8443` is kept.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let trimmed = input.trim();
    let without_scheme = ["https://", "http://"]
        .iter()
        .find_map(|scheme| strip_prefix_ignore_ascii_case(trimmed, scheme))
        .unwrap_or(trimmed);
    without_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or_default()
        .to_string()
}

/// Like `str::strip_prefix`, but compares ASCII letters case-insensitively.
/// Returns `None` when `text` is shorter than `prefix` or does not start with it.
fn strip_prefix_ignore_ascii_case<'a>(text: &'a str, prefix: &str) -> Option<&'a str> {
    // `get` returns `None` instead of panicking on a non-char boundary.
    let head = text.get(..prefix.len())?;
    head.eq_ignore_ascii_case(prefix)
        .then(|| &text[prefix.len()..])
}
612
613pub fn enterprise_base_url(enterprise_url: &str) -> String {
614 format!(
615 "https://copilot-api.{}",
616 normalize_enterprise_domain(enterprise_url)
617 )
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let cases = [
            ("https://company.ghe.com/", "company.ghe.com"),
            ("http://company.ghe.com", "company.ghe.com"),
            ("company.ghe.com", "company.ghe.com"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_enterprise_domain(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(url, "https://copilot-api.company.ghe.com");
    }
}