1use super::{
4 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
5 Role, StreamChunk, ToolDefinition, Usage,
6};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use reqwest::Client;
10use serde::Deserialize;
11use serde_json::{Value, json};
12
/// Base URL used by `CopilotProvider::new` (the non-enterprise endpoint).
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";
/// Provider name reported by instances created through `CopilotProvider::new`.
const COPILOT_PROVIDER: &str = "github-copilot";
/// Provider name reported by instances created through `CopilotProvider::enterprise`.
const COPILOT_ENTERPRISE_PROVIDER: &str = "github-copilot-enterprise";
16
/// Chat-completion provider that talks to a Copilot HTTP endpoint.
pub struct CopilotProvider {
    /// HTTP client reused for every request issued by this provider.
    client: Client,
    /// Token sent as `Authorization: Bearer <token>` on each request.
    token: String,
    /// Endpoint root; stored without trailing slashes (see `with_base_url`).
    base_url: String,
    /// Identifier returned by `Provider::name` and stamped on listed models.
    provider_name: String,
}
23
24impl CopilotProvider {
25 pub fn new(token: String) -> Result<Self> {
26 Self::with_base_url(token, DEFAULT_BASE_URL.to_string(), COPILOT_PROVIDER)
27 }
28
29 pub fn enterprise(token: String, enterprise_url: String) -> Result<Self> {
30 let base_url = enterprise_base_url(&enterprise_url);
31 Self::with_base_url(token, base_url, COPILOT_ENTERPRISE_PROVIDER)
32 }
33
34 pub fn with_base_url(token: String, base_url: String, provider_name: &str) -> Result<Self> {
35 Ok(Self {
36 client: Client::new(),
37 token,
38 base_url: base_url.trim_end_matches('/').to_string(),
39 provider_name: provider_name.to_string(),
40 })
41 }
42
43 fn user_agent() -> String {
44 format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"))
45 }
46
47 fn convert_messages(messages: &[Message]) -> Vec<Value> {
48 messages
49 .iter()
50 .map(|msg| {
51 let role = match msg.role {
52 Role::System => "system",
53 Role::User => "user",
54 Role::Assistant => "assistant",
55 Role::Tool => "tool",
56 };
57
58 match msg.role {
59 Role::Tool => {
60 if let Some(ContentPart::ToolResult {
61 tool_call_id,
62 content,
63 }) = msg.content.first()
64 {
65 json!({
66 "role": "tool",
67 "tool_call_id": tool_call_id,
68 "content": content
69 })
70 } else {
71 json!({ "role": role, "content": "" })
72 }
73 }
74 Role::Assistant => {
75 let text: String = msg
76 .content
77 .iter()
78 .filter_map(|p| match p {
79 ContentPart::Text { text } => Some(text.clone()),
80 _ => None,
81 })
82 .collect::<Vec<_>>()
83 .join("");
84
85 let tool_calls: Vec<Value> = msg
86 .content
87 .iter()
88 .filter_map(|p| match p {
89 ContentPart::ToolCall {
90 id,
91 name,
92 arguments,
93 } => Some(json!({
94 "id": id,
95 "type": "function",
96 "function": {
97 "name": name,
98 "arguments": arguments
99 }
100 })),
101 _ => None,
102 })
103 .collect();
104
105 if tool_calls.is_empty() {
106 json!({ "role": "assistant", "content": text })
107 } else {
108 json!({
109 "role": "assistant",
110 "content": if text.is_empty() { "".to_string() } else { text },
111 "tool_calls": tool_calls
112 })
113 }
114 }
115 _ => {
116 let text: String = msg
117 .content
118 .iter()
119 .filter_map(|p| match p {
120 ContentPart::Text { text } => Some(text.clone()),
121 _ => None,
122 })
123 .collect::<Vec<_>>()
124 .join("\n");
125 json!({ "role": role, "content": text })
126 }
127 }
128 })
129 .collect()
130 }
131
132 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
133 tools
134 .iter()
135 .map(|t| {
136 json!({
137 "type": "function",
138 "function": {
139 "name": t.name,
140 "description": t.description,
141 "parameters": t.parameters
142 }
143 })
144 })
145 .collect()
146 }
147
148 fn is_agent_initiated(messages: &[Message]) -> bool {
149 messages
150 .iter()
151 .rev()
152 .find(|msg| msg.role != Role::System)
153 .map(|msg| msg.role != Role::User)
154 .unwrap_or(false)
155 }
156
157 fn has_vision_input(messages: &[Message]) -> bool {
158 messages.iter().any(|msg| {
159 msg.content
160 .iter()
161 .any(|part| matches!(part, ContentPart::Image { .. }))
162 })
163 }
164}
165
/// Body of a successful `/chat/completions` response, as parsed by `complete`.
#[derive(Debug, Deserialize)]
struct CopilotResponse {
    /// Candidate completions; `complete` only reads the first one.
    choices: Vec<CopilotChoice>,
    /// Token accounting; absent usage is reported as zeros by `complete`.
    #[serde(default)]
    usage: Option<CopilotUsage>,
}
172
/// One candidate completion inside a `CopilotResponse`.
#[derive(Debug, Deserialize)]
struct CopilotChoice {
    /// The assistant message produced for this choice.
    message: CopilotMessage,
    /// Raw finish reason string; mapped to `FinishReason` in `complete`.
    #[serde(default)]
    finish_reason: Option<String>,
}
179
/// Assistant message payload of a choice; both fields may be missing.
#[derive(Debug, Deserialize)]
struct CopilotMessage {
    /// Text content, if any; empty strings are skipped by `complete`.
    #[serde(default)]
    content: Option<String>,
    /// Tool invocations requested by the model, if any.
    #[serde(default)]
    tool_calls: Option<Vec<CopilotToolCall>>,
}
187
/// A single tool invocation returned inside a `CopilotMessage`.
#[derive(Debug, Deserialize)]
struct CopilotToolCall {
    /// Identifier of the call, copied into `ContentPart::ToolCall`.
    id: String,
    /// Value of the JSON `type` key. It is required for deserialization but
    /// never read afterwards, hence the `dead_code` allowance.
    #[serde(rename = "type")]
    #[allow(dead_code)]
    call_type: String,
    /// Name and arguments of the function to invoke.
    function: CopilotFunction,
}
196
/// Function name and arguments of a `CopilotToolCall`.
#[derive(Debug, Deserialize)]
struct CopilotFunction {
    /// Name of the function the model wants to call.
    name: String,
    /// Arguments as a string; passed through unchanged by `complete`.
    arguments: String,
}
202
/// Token counters of a completion response; a missing counter deserializes as 0.
#[derive(Debug, Deserialize)]
struct CopilotUsage {
    /// Copied into `Usage::prompt_tokens`.
    #[serde(default)]
    prompt_tokens: usize,
    /// Copied into `Usage::completion_tokens`.
    #[serde(default)]
    completion_tokens: usize,
    /// Copied into `Usage::total_tokens`.
    #[serde(default)]
    total_tokens: usize,
}
212
/// Error body that `complete` tries to parse when the HTTP status is not a success.
#[derive(Debug, Deserialize)]
struct CopilotErrorResponse {
    /// Nested error detail; its message is preferred when present.
    error: Option<CopilotErrorDetail>,
    /// Top-level message, used when the nested detail yields no message.
    message: Option<String>,
}
218
/// Nested `error` object of a `CopilotErrorResponse`.
#[derive(Debug, Deserialize)]
struct CopilotErrorDetail {
    /// Error description.
    message: Option<String>,
    /// Error code; appended in parentheses to the message when both are present.
    code: Option<String>,
}
224
/// Body of the `/models` response, as parsed by `list_models`.
#[derive(Debug, Deserialize)]
struct CopilotModelsResponse {
    /// Models advertised by the endpoint.
    data: Vec<CopilotModelInfo>,
}
229
/// One model entry of a `CopilotModelsResponse`.
#[derive(Debug, Deserialize)]
struct CopilotModelInfo {
    /// Model identifier; also used as the display name when `name` is absent.
    id: String,
    /// Display name, if the endpoint provides one.
    #[serde(default)]
    name: Option<String>,
    /// NOTE(review): deserialized but not read by any code visible in this file.
    #[serde(default)]
    model_picker_enabled: Option<bool>,
    /// NOTE(review): deserialized but not read by any code visible in this file.
    #[serde(default)]
    policy: Option<CopilotModelPolicy>,
    /// Limits and feature flags used to fill in `ModelInfo`.
    #[serde(default)]
    capabilities: Option<CopilotModelCapabilities>,
}
242
/// `policy` object of a model entry.
#[derive(Debug, Deserialize)]
struct CopilotModelPolicy {
    /// Raw policy state string; its possible values are not shown in this file.
    #[serde(default)]
    state: Option<String>,
}
248
/// `capabilities` object of a model entry.
#[derive(Debug, Deserialize)]
struct CopilotModelCapabilities {
    /// Token limits, if reported.
    #[serde(default)]
    limits: Option<CopilotModelLimits>,
    /// Feature flags, if reported.
    #[serde(default)]
    supports: Option<CopilotModelSupports>,
}
256
/// Token limits of a model; `list_models` substitutes defaults for missing values.
#[derive(Debug, Deserialize)]
struct CopilotModelLimits {
    /// Source of `ModelInfo::context_window` (default 128_000 when absent).
    #[serde(default)]
    max_context_window_tokens: Option<usize>,
    /// Source of `ModelInfo::max_output_tokens` (default 16_384 when absent).
    #[serde(default)]
    max_output_tokens: Option<usize>,
}
264
/// Feature flags of a model; `list_models` substitutes defaults for missing values.
#[derive(Debug, Deserialize)]
struct CopilotModelSupports {
    /// Source of `ModelInfo::supports_tools` (treated as `true` when absent).
    #[serde(default)]
    tool_calls: Option<bool>,
    /// Source of `ModelInfo::supports_vision` (treated as `false` when absent).
    #[serde(default)]
    vision: Option<bool>,
    /// Source of `ModelInfo::supports_streaming` (treated as `true` when absent).
    #[serde(default)]
    streaming: Option<bool>,
}
274
275#[async_trait]
276impl Provider for CopilotProvider {
277 fn name(&self) -> &str {
278 &self.provider_name
279 }
280
281 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
282 let response = self
283 .client
284 .get(format!("{}/models", self.base_url))
285 .header("Authorization", format!("Bearer {}", self.token))
286 .header("Openai-Intent", "conversation-edits")
287 .header("User-Agent", Self::user_agent())
288 .send()
289 .await
290 .context("Failed to fetch Copilot models")?;
291
292 let mut models: Vec<ModelInfo> = if response.status().is_success() {
293 let parsed: CopilotModelsResponse = response
294 .json()
295 .await
296 .unwrap_or(CopilotModelsResponse { data: vec![] });
297
298 parsed
299 .data
300 .into_iter()
301 .map(|model| {
302 let caps = model.capabilities.as_ref();
303 let limits = caps.and_then(|c| c.limits.as_ref());
304 let supports = caps.and_then(|c| c.supports.as_ref());
305
306 ModelInfo {
307 id: model.id.clone(),
308 name: model.name.unwrap_or_else(|| model.id.clone()),
309 provider: self.provider_name.clone(),
310 context_window: limits
311 .and_then(|l| l.max_context_window_tokens)
312 .unwrap_or(128_000),
313 max_output_tokens: limits
314 .and_then(|l| l.max_output_tokens)
315 .or(Some(16_384)),
316 supports_vision: supports.and_then(|s| s.vision).unwrap_or(false),
317 supports_tools: supports.and_then(|s| s.tool_calls).unwrap_or(true),
318 supports_streaming: supports.and_then(|s| s.streaming).unwrap_or(true),
319 input_cost_per_million: Some(0.0),
320 output_cost_per_million: Some(0.0),
321 }
322 })
323 .collect()
324 } else {
325 Vec::new()
326 };
327
328 let known_metadata: std::collections::HashMap<&str, (&str, usize, usize)> = [
334 ("claude-opus-4.5", ("Claude Opus 4.5", 200_000, 64_000)),
335 ("claude-opus-41", ("Claude Opus 4.1", 200_000, 64_000)),
336 ("claude-sonnet-4.5", ("Claude Sonnet 4.5", 200_000, 64_000)),
337 ("claude-sonnet-4", ("Claude Sonnet 4", 200_000, 64_000)),
338 ("claude-haiku-4.5", ("Claude Haiku 4.5", 200_000, 64_000)),
339 ("gpt-5.2", ("GPT-5.2", 400_000, 128_000)),
340 ("gpt-5.1", ("GPT-5.1", 400_000, 128_000)),
341 ("gpt-5.1-codex", ("GPT-5.1-Codex", 264_000, 64_000)),
342 ("gpt-5.1-codex-mini", ("GPT-5.1-Codex-Mini", 264_000, 64_000)),
343 ("gpt-5.1-codex-max", ("GPT-5.1-Codex-Max", 264_000, 64_000)),
344 ("gpt-5", ("GPT-5", 400_000, 128_000)),
345 ("gpt-5-mini", ("GPT-5 mini", 264_000, 64_000)),
346 ("gpt-4.1", ("GPT-4.1", 128_000, 32_768)),
347 ("gpt-4o", ("GPT-4o", 128_000, 16_384)),
348 ("gemini-2.5-pro", ("Gemini 2.5 Pro", 1_000_000, 64_000)),
349 ("grok-code-fast-1", ("Grok Code Fast 1", 128_000, 32_768)),
350 ]
351 .into_iter()
352 .collect();
353
354 for model in &mut models {
356 if let Some((name, ctx, max_out)) = known_metadata.get(model.id.as_str()) {
357 if model.name == model.id {
358 model.name = name.to_string();
359 }
360 if model.context_window == 128_000 {
361 model.context_window = *ctx;
362 }
363 if model.max_output_tokens == Some(16_384) {
364 model.max_output_tokens = Some(*max_out);
365 }
366 }
367 }
368
369 models.retain(|m| {
372 !m.id.starts_with("text-embedding")
373 && m.id != "gpt-3.5-turbo"
374 && m.id != "gpt-3.5-turbo-0613"
375 && m.id != "gpt-4-0613"
376 && m.id != "gpt-4o-2024-05-13"
377 && m.id != "gpt-4o-2024-08-06"
378 && m.id != "gpt-4o-2024-11-20"
379 && m.id != "gpt-4o-mini-2024-07-18"
380 && m.id != "gpt-4-o-preview"
381 && m.id != "gpt-4.1-2025-04-14"
382 });
383
384 let mut seen = std::collections::HashSet::new();
386 models.retain(|m| seen.insert(m.id.clone()));
387
388 Ok(models)
389 }
390
391 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
392 let messages = Self::convert_messages(&request.messages);
393 let tools = Self::convert_tools(&request.tools);
394 let is_agent = Self::is_agent_initiated(&request.messages);
395 let has_vision = Self::has_vision_input(&request.messages);
396
397 let mut body = json!({
398 "model": request.model,
399 "messages": messages,
400 });
401
402 if !tools.is_empty() {
403 body["tools"] = json!(tools);
404 }
405 if let Some(temp) = request.temperature {
406 body["temperature"] = json!(temp);
407 }
408 if let Some(top_p) = request.top_p {
409 body["top_p"] = json!(top_p);
410 }
411 if let Some(max) = request.max_tokens {
412 body["max_tokens"] = json!(max);
413 }
414 if !request.stop.is_empty() {
415 body["stop"] = json!(request.stop);
416 }
417
418 let mut req = self
419 .client
420 .post(format!("{}/chat/completions", self.base_url))
421 .header("Authorization", format!("Bearer {}", self.token))
422 .header("Content-Type", "application/json")
423 .header("Openai-Intent", "conversation-edits")
424 .header("User-Agent", Self::user_agent())
425 .header("X-Initiator", if is_agent { "agent" } else { "user" });
426
427 if has_vision {
428 req = req.header("Copilot-Vision-Request", "true");
429 }
430
431 let response = req
432 .json(&body)
433 .send()
434 .await
435 .context("Failed to send Copilot request")?;
436
437 let status = response.status();
438 let text = response
439 .text()
440 .await
441 .context("Failed to read Copilot response")?;
442
443 if !status.is_success() {
444 if let Ok(err) = serde_json::from_str::<CopilotErrorResponse>(&text) {
445 let message = err
446 .error
447 .and_then(|detail| {
448 detail.message.map(|msg| {
449 if let Some(code) = detail.code {
450 format!("{} ({})", msg, code)
451 } else {
452 msg
453 }
454 })
455 })
456 .or(err.message)
457 .unwrap_or_else(|| "Unknown Copilot API error".to_string());
458 anyhow::bail!("Copilot API error: {}", message);
459 }
460 anyhow::bail!("Copilot API error: {} {}", status, text);
461 }
462
463 let response: CopilotResponse = serde_json::from_str(&text).context(format!(
464 "Failed to parse Copilot response: {}",
465 &text[..text.len().min(200)]
466 ))?;
467
468 let choice = response
469 .choices
470 .first()
471 .ok_or_else(|| anyhow::anyhow!("No choices"))?;
472
473 let mut content = Vec::new();
474 let mut has_tool_calls = false;
475
476 if let Some(text) = &choice.message.content {
477 if !text.is_empty() {
478 content.push(ContentPart::Text { text: text.clone() });
479 }
480 }
481
482 if let Some(tool_calls) = &choice.message.tool_calls {
483 has_tool_calls = !tool_calls.is_empty();
484 for tc in tool_calls {
485 content.push(ContentPart::ToolCall {
486 id: tc.id.clone(),
487 name: tc.function.name.clone(),
488 arguments: tc.function.arguments.clone(),
489 });
490 }
491 }
492
493 let finish_reason = if has_tool_calls {
494 FinishReason::ToolCalls
495 } else {
496 match choice.finish_reason.as_deref() {
497 Some("stop") => FinishReason::Stop,
498 Some("length") => FinishReason::Length,
499 Some("tool_calls") => FinishReason::ToolCalls,
500 Some("content_filter") => FinishReason::ContentFilter,
501 _ => FinishReason::Stop,
502 }
503 };
504
505 Ok(CompletionResponse {
506 message: Message {
507 role: Role::Assistant,
508 content,
509 },
510 usage: Usage {
511 prompt_tokens: response
512 .usage
513 .as_ref()
514 .map(|u| u.prompt_tokens)
515 .unwrap_or(0),
516 completion_tokens: response
517 .usage
518 .as_ref()
519 .map(|u| u.completion_tokens)
520 .unwrap_or(0),
521 total_tokens: response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0),
522 ..Default::default()
523 },
524 finish_reason,
525 })
526 }
527
528 async fn complete_stream(
529 &self,
530 request: CompletionRequest,
531 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
532 let response = self.complete(request).await?;
534 let text = response
535 .message
536 .content
537 .iter()
538 .filter_map(|p| match p {
539 ContentPart::Text { text } => Some(text.clone()),
540 _ => None,
541 })
542 .collect::<Vec<_>>()
543 .join("");
544
545 Ok(Box::pin(futures::stream::once(async move {
546 StreamChunk::Text(text)
547 })))
548 }
549}
550
/// Reduces user input describing an enterprise host to a bare domain.
///
/// Surrounding whitespace, a leading `https://` or `http://` scheme (matched
/// case-insensitively, since URL schemes are case-insensitive), and any
/// trailing slashes are removed. Input without a scheme is returned with only
/// whitespace and trailing slashes stripped.
pub fn normalize_enterprise_domain(input: &str) -> String {
    let trimmed = input.trim();

    let without_scheme = ["https://", "http://"]
        .iter()
        .find_map(|scheme| {
            // `get` returns `None` when the input is shorter than the scheme
            // or the cut would split a multi-byte character.
            let head = trimmed.get(..scheme.len())?;
            let rest = trimmed.get(scheme.len()..)?;
            if head.eq_ignore_ascii_case(scheme) {
                Some(rest)
            } else {
                None
            }
        })
        .unwrap_or(trimmed);

    without_scheme.trim_end_matches('/').to_string()
}
559
560pub fn enterprise_base_url(enterprise_url: &str) -> String {
561 format!(
562 "https://copilot-api.{}",
563 normalize_enterprise_domain(enterprise_url)
564 )
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Every accepted spelling of the host must normalize to the bare domain.
    #[test]
    fn normalize_enterprise_domain_handles_scheme_and_trailing_slash() {
        let spellings = [
            "https://company.ghe.com/",
            "http://company.ghe.com",
            "company.ghe.com",
        ];

        for spelling in spellings.iter() {
            assert_eq!(normalize_enterprise_domain(spelling), "company.ghe.com");
        }
    }

    /// The enterprise base URL lives on the `copilot-api.` subdomain.
    #[test]
    fn enterprise_base_url_uses_copilot_api_subdomain() {
        let base_url = enterprise_base_url("https://company.ghe.com/");
        assert_eq!(base_url, "https://copilot-api.company.ghe.com");
    }
}