1use async_trait::async_trait;
2use futures::Stream;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::{SystemTime, UNIX_EPOCH};
10
11use super::{
12 FinishReason, GenerateOptions, GenerateResult, LanguageModel, Message, MessageContent,
13 MessagePart, MessageRole, StreamChunk, StreamOptions, ToolCall, ToolDefinition, Usage,
14};
15use crate::auth::{Auth, AuthCredentials};
16
17pub struct GitHubCopilotProvider {
19 auth: Box<dyn Auth>,
20 client: Client,
21 models: HashMap<String, GitHubCopilotModel>,
22}
23
/// Static description of one model served through GitHub Copilot.
#[derive(Clone, Debug)]
pub struct GitHubCopilotModel {
    /// Identifier sent as `model` in chat-completion requests.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// `max_tokens` used when the caller does not supply one.
    pub max_tokens: u32,
    /// Whether the model accepts tool definitions.
    pub supports_tools: bool,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports prompt caching.
    pub supports_caching: bool,
}
33
34#[derive(Debug, Serialize, Deserialize)]
35struct DeviceCodeRequest {
36 client_id: String,
37 scope: String,
38}
39
40#[derive(Debug, Deserialize)]
41struct DeviceCodeResponse {
42 device_code: String,
43 user_code: String,
44 verification_uri: String,
45 expires_in: u32,
46 interval: u32,
47}
48
49#[derive(Debug, Serialize)]
50struct AccessTokenRequest {
51 client_id: String,
52 device_code: String,
53 grant_type: String,
54}
55
56#[derive(Debug, Deserialize)]
57struct AccessTokenResponse {
58 access_token: Option<String>,
59 error: Option<String>,
60 error_description: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
64struct CopilotTokenResponse {
65 token: String,
66 expires_at: u64,
67 refresh_in: u64,
68 endpoints: CopilotEndpoints,
69}
70
71#[derive(Debug, Deserialize)]
72struct CopilotEndpoints {
73 api: String,
74}
75
76#[derive(Debug, Serialize)]
77struct CopilotRequest {
78 model: String,
79 messages: Vec<CopilotMessage>,
80 max_tokens: u32,
81 #[serde(skip_serializing_if = "Option::is_none")]
82 temperature: Option<f32>,
83 #[serde(skip_serializing_if = "Vec::is_empty")]
84 tools: Vec<CopilotTool>,
85 #[serde(skip_serializing_if = "Vec::is_empty")]
86 stop: Vec<String>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 stream: Option<bool>,
89}
90
91#[derive(Debug, Serialize, Deserialize)]
92struct CopilotMessage {
93 role: String,
94 content: CopilotContent,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 name: Option<String>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 tool_calls: Option<Vec<CopilotToolCall>>,
99 #[serde(skip_serializing_if = "Option::is_none")]
100 tool_call_id: Option<String>,
101}
102
103#[derive(Debug, Serialize, Deserialize)]
104#[serde(untagged)]
105enum CopilotContent {
106 Text(String),
107 Parts(Vec<CopilotContentPart>),
108}
109
110#[derive(Debug, Serialize, Deserialize)]
111#[serde(tag = "type")]
112enum CopilotContentPart {
113 #[serde(rename = "text")]
114 Text { text: String },
115 #[serde(rename = "image_url")]
116 ImageUrl { image_url: CopilotImageUrl },
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120struct CopilotImageUrl {
121 url: String,
122 detail: Option<String>,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
126struct CopilotTool {
127 #[serde(rename = "type")]
128 tool_type: String,
129 function: CopilotFunction,
130}
131
132#[derive(Debug, Serialize, Deserialize)]
133struct CopilotFunction {
134 name: String,
135 description: String,
136 parameters: Value,
137}
138
139#[derive(Debug, Serialize, Deserialize)]
140struct CopilotToolCall {
141 id: String,
142 #[serde(rename = "type")]
143 tool_type: String,
144 function: CopilotFunctionCall,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148struct CopilotFunctionCall {
149 name: String,
150 arguments: String,
151}
152
153#[derive(Debug, Deserialize)]
154struct CopilotResponse {
155 choices: Vec<CopilotChoice>,
156 usage: CopilotUsage,
157}
158
159#[derive(Debug, Deserialize)]
160struct CopilotChoice {
161 message: CopilotMessage,
162 finish_reason: Option<String>,
163}
164
165#[derive(Debug, Deserialize)]
166struct CopilotUsage {
167 prompt_tokens: u32,
168 completion_tokens: u32,
169 total_tokens: u32,
170}
171
172impl GitHubCopilotProvider {
173 const CLIENT_ID: &'static str = "Iv1.b507a08c87ecfe98";
174 const DEVICE_CODE_URL: &'static str = "https://github.com/login/device/code";
175 const ACCESS_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
176 const COPILOT_TOKEN_URL: &'static str = "https://api.github.com/copilot_internal/v2/token";
177 const API_BASE: &'static str = "https://api.githubcopilot.com";
178
179 pub fn new(auth: Box<dyn Auth>) -> Self {
180 let client = Client::new();
181 let models = Self::default_models();
182
183 Self {
184 auth,
185 client,
186 models,
187 }
188 }
189
190 fn default_models() -> HashMap<String, GitHubCopilotModel> {
191 let mut models = HashMap::new();
192
193 models.insert(
194 "gpt-4o".to_string(),
195 GitHubCopilotModel {
196 id: "gpt-4o".to_string(),
197 name: "GPT-4o".to_string(),
198 max_tokens: 4096,
199 supports_tools: true,
200 supports_vision: true,
201 supports_caching: false,
202 },
203 );
204
205 models.insert(
206 "gpt-4o-mini".to_string(),
207 GitHubCopilotModel {
208 id: "gpt-4o-mini".to_string(),
209 name: "GPT-4o Mini".to_string(),
210 max_tokens: 4096,
211 supports_tools: true,
212 supports_vision: true,
213 supports_caching: false,
214 },
215 );
216
217 models.insert(
218 "o1-preview".to_string(),
219 GitHubCopilotModel {
220 id: "o1-preview".to_string(),
221 name: "OpenAI o1 Preview".to_string(),
222 max_tokens: 32768,
223 supports_tools: false,
224 supports_vision: false,
225 supports_caching: false,
226 },
227 );
228
229 models
230 }
231
232 pub async fn start_device_flow() -> crate::Result<DeviceCodeResponse> {
234 let client = Client::new();
235 let request = DeviceCodeRequest {
236 client_id: Self::CLIENT_ID.to_string(),
237 scope: "read:user".to_string(),
238 };
239
240 let response = client
241 .post(Self::DEVICE_CODE_URL)
242 .header("Accept", "application/json")
243 .header("Content-Type", "application/json")
244 .header("User-Agent", "GitHubCopilotChat/0.26.7")
245 .json(&request)
246 .send()
247 .await
248 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Device code request failed: {}", e)))?;
249
250 if !response.status().is_success() {
251 return Err(crate::Error::Other(anyhow::anyhow!(
252 "Device code request failed with status: {}",
253 response.status()
254 )));
255 }
256
257 let device_response: DeviceCodeResponse = response
258 .json()
259 .await
260 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse device code response: {}", e)))?;
261
262 Ok(device_response)
263 }
264
265 pub async fn poll_for_token(device_code: &str) -> crate::Result<Option<String>> {
267 let client = Client::new();
268 let request = AccessTokenRequest {
269 client_id: Self::CLIENT_ID.to_string(),
270 device_code: device_code.to_string(),
271 grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
272 };
273
274 let response = client
275 .post(Self::ACCESS_TOKEN_URL)
276 .header("Accept", "application/json")
277 .header("Content-Type", "application/json")
278 .header("User-Agent", "GitHubCopilotChat/0.26.7")
279 .json(&request)
280 .send()
281 .await
282 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Token poll request failed: {}", e)))?;
283
284 if !response.status().is_success() {
285 return Ok(None);
286 }
287
288 let token_response: AccessTokenResponse = response
289 .json()
290 .await
291 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse token response: {}", e)))?;
292
293 if let Some(access_token) = token_response.access_token {
294 Ok(Some(access_token))
295 } else if token_response.error.as_deref() == Some("authorization_pending") {
296 Ok(None)
297 } else {
298 Err(crate::Error::Other(anyhow::anyhow!(
299 "Token exchange failed: {:?}",
300 token_response.error
301 )))
302 }
303 }
304
305 pub async fn get_copilot_token(github_token: &str) -> crate::Result<AuthCredentials> {
307 let client = Client::new();
308
309 let response = client
310 .get(Self::COPILOT_TOKEN_URL)
311 .header("Accept", "application/json")
312 .header("Authorization", format!("Bearer {}", github_token))
313 .header("User-Agent", "GitHubCopilotChat/0.26.7")
314 .header("Editor-Version", "vscode/1.99.3")
315 .header("Editor-Plugin-Version", "copilot-chat/0.26.7")
316 .send()
317 .await
318 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Copilot token request failed: {}", e)))?;
319
320 if !response.status().is_success() {
321 return Err(crate::Error::Other(anyhow::anyhow!(
322 "Copilot token request failed with status: {}",
323 response.status()
324 )));
325 }
326
327 let token_response: CopilotTokenResponse = response
328 .json()
329 .await
330 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse copilot token response: {}", e)))?;
331
332 Ok(AuthCredentials::OAuth {
333 access_token: token_response.token,
334 refresh_token: Some(github_token.to_string()), expires_at: Some(token_response.expires_at),
336 })
337 }
338
339 async fn get_auth_headers(&self) -> crate::Result<HashMap<String, String>> {
340 let credentials = self.auth.get_credentials().await?;
341
342 let mut headers = HashMap::new();
343 headers.insert("User-Agent".to_string(), "GitHubCopilotChat/0.26.7".to_string());
344 headers.insert("Editor-Version".to_string(), "vscode/1.99.3".to_string());
345 headers.insert("Editor-Plugin-Version".to_string(), "copilot-chat/0.26.7".to_string());
346 headers.insert("Openai-Intent".to_string(), "conversation-edits".to_string());
347
348 match credentials {
349 AuthCredentials::OAuth { access_token, refresh_token, expires_at } => {
350 if let Some(exp) = expires_at {
352 let now = SystemTime::now()
353 .duration_since(UNIX_EPOCH)
354 .unwrap()
355 .as_secs();
356
357 if now >= exp {
358 if let Some(github_token) = refresh_token {
359 let new_creds = Self::get_copilot_token(&github_token).await?;
360 self.auth.set_credentials(new_creds.clone()).await?;
361
362 if let AuthCredentials::OAuth { access_token, .. } = new_creds {
363 headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
364 }
365 } else {
366 return Err(crate::Error::Other(anyhow::anyhow!("Token expired and no refresh token available")));
367 }
368 } else {
369 headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
370 }
371 } else {
372 headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
373 }
374 }
375 _ => {
376 return Err(crate::Error::Other(anyhow::anyhow!(
377 "Invalid credentials for GitHub Copilot"
378 )));
379 }
380 }
381
382 Ok(headers)
383 }
384
385 fn convert_messages(&self, messages: Vec<Message>) -> Vec<CopilotMessage> {
386 messages
387 .into_iter()
388 .map(|msg| self.convert_message(msg))
389 .collect()
390 }
391
392 fn convert_message(&self, message: Message) -> CopilotMessage {
393 let role = match message.role {
394 MessageRole::System => "system",
395 MessageRole::User => "user",
396 MessageRole::Assistant => "assistant",
397 MessageRole::Tool => "tool",
398 }
399 .to_string();
400
401 let content = match message.content {
402 MessageContent::Text(text) => CopilotContent::Text(text),
403 MessageContent::Parts(parts) => {
404 let copilot_parts: Vec<CopilotContentPart> = parts
405 .into_iter()
406 .filter_map(|part| match part {
407 MessagePart::Text { text } => Some(CopilotContentPart::Text { text }),
408 MessagePart::Image { image } => {
409 if let Some(url) = image.url {
410 Some(CopilotContentPart::ImageUrl {
411 image_url: CopilotImageUrl {
412 url,
413 detail: Some("auto".to_string()),
414 },
415 })
416 } else if let Some(base64) = image.base64 {
417 Some(CopilotContentPart::ImageUrl {
418 image_url: CopilotImageUrl {
419 url: format!("data:{};base64,{}", image.mime_type, base64),
420 detail: Some("auto".to_string()),
421 },
422 })
423 } else {
424 None
425 }
426 }
427 })
428 .collect();
429 CopilotContent::Parts(copilot_parts)
430 }
431 };
432
433 let tool_calls = message.tool_calls.map(|calls| {
434 calls
435 .into_iter()
436 .map(|call| CopilotToolCall {
437 id: call.id,
438 tool_type: "function".to_string(),
439 function: CopilotFunctionCall {
440 name: call.name,
441 arguments: call.arguments.to_string(),
442 },
443 })
444 .collect()
445 });
446
447 CopilotMessage {
448 role,
449 content,
450 name: message.name,
451 tool_calls,
452 tool_call_id: message.tool_call_id,
453 }
454 }
455
456 fn convert_tools(&self, tools: Vec<ToolDefinition>) -> Vec<CopilotTool> {
457 tools
458 .into_iter()
459 .map(|tool| CopilotTool {
460 tool_type: "function".to_string(),
461 function: CopilotFunction {
462 name: tool.name,
463 description: tool.description,
464 parameters: tool.parameters,
465 },
466 })
467 .collect()
468 }
469
470 fn parse_finish_reason(&self, reason: Option<String>) -> FinishReason {
471 match reason.as_deref() {
472 Some("stop") => FinishReason::Stop,
473 Some("length") => FinishReason::Length,
474 Some("tool_calls") => FinishReason::ToolCalls,
475 Some("content_filter") => FinishReason::ContentFilter,
476 _ => FinishReason::Stop,
477 }
478 }
479}
480
481pub struct GitHubCopilotModelWithProvider {
482 model: GitHubCopilotModel,
483 provider: GitHubCopilotProvider,
484}
485
486impl GitHubCopilotModelWithProvider {
487 pub fn new(model: GitHubCopilotModel, provider: GitHubCopilotProvider) -> Self {
488 Self { model, provider }
489 }
490}
491
492#[async_trait]
493impl LanguageModel for GitHubCopilotModelWithProvider {
494 async fn generate(
495 &self,
496 messages: Vec<Message>,
497 options: GenerateOptions,
498 ) -> crate::Result<GenerateResult> {
499 let headers = self.provider.get_auth_headers().await?;
500 let copilot_messages = self.provider.convert_messages(messages);
501 let tools = self.provider.convert_tools(options.tools);
502
503 let request = CopilotRequest {
504 model: self.model.id.clone(),
505 messages: copilot_messages,
506 max_tokens: options.max_tokens.unwrap_or(self.model.max_tokens),
507 temperature: options.temperature,
508 tools,
509 stop: options.stop_sequences,
510 stream: Some(false),
511 };
512
513 let mut req_builder = self
514 .provider
515 .client
516 .post(&format!("{}/v1/chat/completions", GitHubCopilotProvider::API_BASE))
517 .header("Content-Type", "application/json");
518
519 for (key, value) in headers {
520 req_builder = req_builder.header(&key, &value);
521 }
522
523 let response = req_builder
524 .json(&request)
525 .send()
526 .await
527 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Request failed: {}", e)))?;
528
529 if !response.status().is_success() {
530 let status = response.status();
531 let body = response.text().await.unwrap_or_default();
532 return Err(crate::Error::Other(anyhow::anyhow!(
533 "API request failed with status {}: {}",
534 status,
535 body
536 )));
537 }
538
539 let copilot_response: CopilotResponse = response
540 .json()
541 .await
542 .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse response: {}", e)))?;
543
544 let choice = copilot_response
545 .choices
546 .into_iter()
547 .next()
548 .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("No choices in response")))?;
549
550 let content = match choice.message.content {
551 CopilotContent::Text(text) => text,
552 CopilotContent::Parts(parts) => {
553 parts
554 .into_iter()
555 .filter_map(|part| match part {
556 CopilotContentPart::Text { text } => Some(text),
557 _ => None,
558 })
559 .collect::<Vec<_>>()
560 .join("")
561 }
562 };
563
564 let tool_calls = choice
565 .message
566 .tool_calls
567 .unwrap_or_default()
568 .into_iter()
569 .map(|call| ToolCall {
570 id: call.id,
571 name: call.function.name,
572 arguments: serde_json::from_str(&call.function.arguments)
573 .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
574 })
575 .collect();
576
577 Ok(GenerateResult {
578 content,
579 tool_calls,
580 usage: Usage {
581 prompt_tokens: copilot_response.usage.prompt_tokens,
582 completion_tokens: copilot_response.usage.completion_tokens,
583 total_tokens: copilot_response.usage.total_tokens,
584 },
585 finish_reason: self.provider.parse_finish_reason(choice.finish_reason),
586 })
587 }
588
589 async fn stream(
590 &self,
591 messages: Vec<Message>,
592 options: StreamOptions,
593 ) -> crate::Result<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
594 Err(crate::Error::Other(anyhow::anyhow!(
597 "Streaming not yet implemented for GitHub Copilot"
598 )))
599 }
600
601 fn supports_tools(&self) -> bool {
602 self.model.supports_tools
603 }
604
605 fn supports_vision(&self) -> bool {
606 self.model.supports_vision
607 }
608
609 fn supports_caching(&self) -> bool {
610 self.model.supports_caching
611 }
612}
613
614#[derive(Debug, thiserror::Error)]
615pub enum GitHubCopilotError {
616 #[error("Device code flow failed")]
617 DeviceCodeFailed,
618
619 #[error("Token exchange failed")]
620 TokenExchangeFailed,
621
622 #[error("Authentication expired")]
623 AuthenticationExpired,
624
625 #[error("Copilot token request failed")]
626 CopilotTokenFailed,
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_models() {
        let catalog = GitHubCopilotProvider::default_models();
        assert!(!catalog.is_empty());
        for id in ["gpt-4o", "gpt-4o-mini", "o1-preview"] {
            assert!(catalog.contains_key(id), "missing default model {}", id);
        }
    }

    #[test]
    fn test_model_capabilities() {
        let catalog = GitHubCopilotProvider::default_models();

        let gpt4o = catalog.get("gpt-4o").expect("gpt-4o is a default model");
        assert!(gpt4o.supports_tools);
        assert!(gpt4o.supports_vision);

        let o1 = catalog.get("o1-preview").expect("o1-preview is a default model");
        assert!(!o1.supports_tools);
        assert!(!o1.supports_vision);
    }
}
653}