1use crate::providers::traits::{
15 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
16 Provider, TokenUsage, ToolCall as ProviderToolCall,
17};
18use crate::tools::ToolSpec;
19use async_trait::async_trait;
20use reqwest::Client;
21use serde::{Deserialize, Serialize};
22use std::path::{Path, PathBuf};
23use std::sync::Arc;
24use std::time::Duration;
25use tokio::sync::Mutex;
26use tracing::warn;
27
/// OAuth client ID sent with the device-code and access-token requests.
// NOTE(review): presumably the public client ID of GitHub's Copilot editor
// integration — confirm before changing.
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// Endpoint that starts the OAuth device-code flow.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Endpoint polled to exchange the device code for a GitHub access token.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Endpoint that exchanges a GitHub access token for a Copilot API key.
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// Chat API base URL used when the API-key response does not advertise one.
const DEFAULT_API: &str = "https://api.githubcopilot.com";
34
/// Response from the device-code endpoint (`GITHUB_DEVICE_CODE_URL`).
///
/// `interval` and `expires_in` are in seconds and fall back to
/// `default_interval()` / `default_expires_in()` when absent.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    /// Opaque code sent back while polling for the access token.
    device_code: String,
    /// Code the user is asked to enter at `verification_uri`.
    user_code: String,
    /// URL shown to the user to authorize this device.
    verification_uri: String,
    /// Polling interval in seconds (the caller clamps it to at least 5).
    #[serde(default = "default_interval")]
    interval: u64,
    /// Seconds until the device code expires (the caller clamps it to at least 1).
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}
47
/// Poll interval, in seconds, assumed when the device-code response omits
/// `interval`.
fn default_interval() -> u64 {
    const FALLBACK_POLL_SECS: u64 = 5;
    FALLBACK_POLL_SECS
}
51
/// Device-code lifetime, in seconds, assumed when the response omits
/// `expires_in` (fifteen minutes).
fn default_expires_in() -> u64 {
    const FALLBACK_LIFETIME_SECS: u64 = 15 * 60;
    FALLBACK_LIFETIME_SECS
}
55
/// One poll response from the OAuth access-token endpoint.
///
/// On success `access_token` is set; otherwise `error` carries a code such as
/// `authorization_pending`, `slow_down` or `expired_token`. If neither field is
/// present the poll is treated as still pending.
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    error: Option<String>,
}
61
/// Copilot API key as returned by `GITHUB_API_KEY_URL`.
///
/// The same shape is serialized to `api-key.json` in the token directory as an
/// on-disk cache.
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
    /// Bearer token for the chat endpoint.
    token: String,
    /// Expiry as Unix seconds (compared against `chrono::Utc::now().timestamp()`).
    expires_at: i64,
    /// Optional endpoint overrides; `DEFAULT_API` is used when absent.
    #[serde(default)]
    endpoints: Option<ApiEndpoints>,
}
69
/// Endpoint URLs advertised alongside a Copilot API key.
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
    /// Base URL for `/chat/completions`; `DEFAULT_API` is used when `None`.
    api: Option<String>,
}
74
/// In-memory copy of the current API key with its endpoint already resolved.
struct CachedApiKey {
    token: String,
    /// Chat base URL (advertised endpoint or `DEFAULT_API`).
    api_endpoint: String,
    /// Expiry as Unix seconds.
    expires_at: i64,
}
80
/// JSON body posted to `{endpoint}/chat/completions`.
///
/// Tool specs borrow from the caller's `ToolSpec`s for lifetime `'a`.
#[derive(Debug, Serialize)]
struct ApiChatRequest<'a> {
    model: String,
    messages: Vec<ApiMessage>,
    temperature: f64,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec<'a>>>,
    /// Omitted when `None`; `send_chat_request` sets `"auto"` whenever tools are sent.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
93
/// One outgoing chat message; `None` fields are left out of the JSON.
#[derive(Debug, Serialize)]
struct ApiMessage {
    /// `system`, `user`, `assistant`, `tool`, or whatever role the history carries.
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<ApiContent>,
    /// Set on `tool` messages to reference the call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// Set on `assistant` messages that requested tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}
104
/// Entry of the request's `tools` array.
#[derive(Debug, Serialize)]
struct NativeToolSpec<'a> {
    /// Serialized as `type`; `convert_tools` always sets `"function"`.
    #[serde(rename = "type")]
    kind: &'static str,
    function: NativeToolFunctionSpec<'a>,
}
111
/// Function description of a tool, borrowed from a `ToolSpec`.
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec<'a> {
    name: &'a str,
    description: &'a str,
    /// Parameter schema forwarded as-is.
    // NOTE(review): presumably a JSON Schema object — confirm with `ToolSpec`.
    parameters: &'a serde_json::Value,
}
118
/// Tool call as it appears on the wire.
///
/// Deserialized from responses and serialized when replaying assistant turns.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    /// Call id; responses without one get a generated UUID downstream.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    /// Serialized as `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}
127
/// Function name and arguments of a tool call.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    /// Arguments passed through unchanged as a string.
    // NOTE(review): presumably JSON-encoded text — confirm against callers.
    arguments: String,
}
133
/// Message content: untagged, so `Text` serializes as a bare JSON string and
/// `Parts` as an array of typed parts.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
enum ApiContent {
    Text(String),
    Parts(Vec<ContentPart>),
}
141
/// One element of multi-part content, tagged by a `type` field of `text` or
/// `image_url`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
enum ContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlDetail },
}
150
/// Image reference of an `image_url` part.
#[derive(Debug, Clone, Serialize)]
struct ImageUrlDetail {
    /// URL or data URI taken from an image marker (e.g. `data:image/png;base64,...`).
    url: String,
}
155
/// Decoded `/chat/completions` response; only the first choice is used.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
    /// Token accounting; absent in some responses.
    #[serde(default)]
    usage: Option<UsageInfo>,
}
162
/// Token counts reported by the API; each may be missing.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    /// Mapped to `TokenUsage::input_tokens`.
    #[serde(default)]
    prompt_tokens: Option<u64>,
    /// Mapped to `TokenUsage::output_tokens`.
    #[serde(default)]
    completion_tokens: Option<u64>,
}
170
/// One completion choice.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
175
/// Assistant message of a choice: text, tool calls, or both.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
183
/// GitHub Copilot chat provider.
///
/// Obtains a GitHub OAuth token (from configuration, an on-disk cache, or an
/// interactive device-code login), exchanges it for an expiring Copilot API
/// key, and sends chat requests to the key's `/chat/completions` endpoint.
pub struct CopilotProvider {
    /// Token supplied by configuration; `None` when absent or empty.
    github_token: Option<String>,
    /// In-memory API-key cache; holding the mutex also serializes refreshes.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    /// Directory for the `access-token` and `api-key.json` cache files.
    token_dir: PathBuf,
}
198
199impl CopilotProvider {
200 pub fn new(github_token: Option<&str>) -> Self {
201 let token_dir = directories::ProjectDirs::from("", "", "construct")
202 .map(|dir| dir.config_dir().join("copilot"))
203 .unwrap_or_else(|| {
204 let user = std::env::var("USER")
207 .or_else(|_| std::env::var("USERNAME"))
208 .unwrap_or_else(|_| "unknown".to_string());
209 std::env::temp_dir().join(format!("construct-copilot-{user}"))
210 });
211
212 if let Err(err) = std::fs::create_dir_all(&token_dir) {
213 warn!(
214 "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
215 token_dir
216 );
217 } else {
218 #[cfg(unix)]
219 {
220 use std::os::unix::fs::PermissionsExt;
221
222 if let Err(err) =
223 std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
224 {
225 warn!(
226 "Failed to set Copilot token directory permissions on {:?}: {err}",
227 token_dir
228 );
229 }
230 }
231 }
232
233 Self {
234 github_token: github_token
235 .filter(|token| !token.is_empty())
236 .map(String::from),
237 refresh_lock: Arc::new(Mutex::new(None)),
238 token_dir,
239 }
240 }
241
242 fn http_client(&self) -> Client {
243 crate::config::build_runtime_proxy_client_with_timeouts("provider.copilot", 120, 10)
244 }
245
246 const COPILOT_HEADERS: [(&str, &str); 4] = [
248 ("Editor-Version", "vscode/1.85.1"),
249 ("Editor-Plugin-Version", "copilot/1.155.0"),
250 ("User-Agent", "GithubCopilot/1.155.0"),
251 ("Accept", "application/json"),
252 ];
253
254 fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec<'_>>> {
255 tools.map(|items| {
256 items
257 .iter()
258 .map(|tool| NativeToolSpec {
259 kind: "function",
260 function: NativeToolFunctionSpec {
261 name: &tool.name,
262 description: &tool.description,
263 parameters: &tool.parameters,
264 },
265 })
266 .collect()
267 })
268 }
269
270 fn to_api_content(role: &str, content: &str) -> Option<ApiContent> {
273 if role != "user" {
274 return Some(ApiContent::Text(content.to_string()));
275 }
276
277 let (cleaned_text, image_refs) = crate::multimodal::parse_image_markers(content);
278 if image_refs.is_empty() {
279 return Some(ApiContent::Text(content.to_string()));
280 }
281
282 let mut parts = Vec::with_capacity(image_refs.len() + 1);
283 let trimmed = cleaned_text.trim();
284 if !trimmed.is_empty() {
285 parts.push(ContentPart::Text {
286 text: trimmed.to_string(),
287 });
288 }
289 for image_ref in image_refs {
290 parts.push(ContentPart::ImageUrl {
291 image_url: ImageUrlDetail { url: image_ref },
292 });
293 }
294
295 Some(ApiContent::Parts(parts))
296 }
297
298 fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
299 messages
300 .iter()
301 .map(|message| {
302 if message.role == "assistant" {
303 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
304 if let Some(tool_calls_value) = value.get("tool_calls") {
305 if let Ok(parsed_calls) =
306 serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
307 {
308 let tool_calls = parsed_calls
309 .into_iter()
310 .map(|tool_call| NativeToolCall {
311 id: Some(tool_call.id),
312 kind: Some("function".to_string()),
313 function: NativeFunctionCall {
314 name: tool_call.name,
315 arguments: tool_call.arguments,
316 },
317 })
318 .collect::<Vec<_>>();
319
320 let content = value
321 .get("content")
322 .and_then(serde_json::Value::as_str)
323 .map(|s| ApiContent::Text(s.to_string()));
324
325 return ApiMessage {
326 role: "assistant".to_string(),
327 content,
328 tool_call_id: None,
329 tool_calls: Some(tool_calls),
330 };
331 }
332 }
333 }
334 }
335
336 if message.role == "tool" {
337 if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
338 let tool_call_id = value
339 .get("tool_call_id")
340 .and_then(serde_json::Value::as_str)
341 .map(ToString::to_string);
342 let content = value
343 .get("content")
344 .and_then(serde_json::Value::as_str)
345 .map(|s| ApiContent::Text(s.to_string()));
346
347 return ApiMessage {
348 role: "tool".to_string(),
349 content,
350 tool_call_id,
351 tool_calls: None,
352 };
353 }
354 }
355
356 ApiMessage {
357 role: message.role.clone(),
358 content: Self::to_api_content(&message.role, &message.content),
359 tool_call_id: None,
360 tool_calls: None,
361 }
362 })
363 .collect()
364 }
365
366 async fn send_chat_request(
368 &self,
369 messages: Vec<ApiMessage>,
370 tools: Option<&[ToolSpec]>,
371 model: &str,
372 temperature: f64,
373 ) -> anyhow::Result<ProviderChatResponse> {
374 let (token, endpoint) = self.get_api_key().await?;
375 let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
376
377 let native_tools = Self::convert_tools(tools);
378 let request = ApiChatRequest {
379 model: model.to_string(),
380 messages,
381 temperature,
382 tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
383 tools: native_tools,
384 };
385
386 let mut req = self
387 .http_client()
388 .post(&url)
389 .header("Authorization", format!("Bearer {token}"))
390 .json(&request);
391
392 for (header, value) in &Self::COPILOT_HEADERS {
393 req = req.header(*header, *value);
394 }
395
396 let response = req.send().await?;
397
398 if !response.status().is_success() {
399 return Err(super::api_error("GitHub Copilot", response).await);
400 }
401
402 let api_response: ApiChatResponse = response.json().await?;
403 let usage = api_response.usage.map(|u| TokenUsage {
404 input_tokens: u.prompt_tokens,
405 output_tokens: u.completion_tokens,
406 cached_input_tokens: None,
407 });
408 let choice = api_response
409 .choices
410 .into_iter()
411 .next()
412 .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
413
414 let tool_calls = choice
415 .message
416 .tool_calls
417 .unwrap_or_default()
418 .into_iter()
419 .map(|tool_call| ProviderToolCall {
420 id: tool_call
421 .id
422 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
423 name: tool_call.function.name,
424 arguments: tool_call.function.arguments,
425 })
426 .collect();
427
428 Ok(ProviderChatResponse {
429 text: choice.message.content,
430 tool_calls,
431 usage,
432 reasoning_content: None,
433 })
434 }
435
436 async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
439 let mut cached = self.refresh_lock.lock().await;
440
441 if let Some(cached_key) = cached.as_ref() {
442 if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
443 return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
444 }
445 }
446
447 if let Some(info) = self.load_api_key_from_disk().await {
448 if chrono::Utc::now().timestamp() + 120 < info.expires_at {
449 let endpoint = info
450 .endpoints
451 .as_ref()
452 .and_then(|e| e.api.clone())
453 .unwrap_or_else(|| DEFAULT_API.to_string());
454 let token = info.token;
455
456 *cached = Some(CachedApiKey {
457 token: token.clone(),
458 api_endpoint: endpoint.clone(),
459 expires_at: info.expires_at,
460 });
461 return Ok((token, endpoint));
462 }
463 }
464
465 let access_token = self.get_github_access_token().await?;
466 let api_key_info = self.exchange_for_api_key(&access_token).await?;
467 self.save_api_key_to_disk(&api_key_info).await;
468
469 let endpoint = api_key_info
470 .endpoints
471 .as_ref()
472 .and_then(|e| e.api.clone())
473 .unwrap_or_else(|| DEFAULT_API.to_string());
474
475 *cached = Some(CachedApiKey {
476 token: api_key_info.token.clone(),
477 api_endpoint: endpoint.clone(),
478 expires_at: api_key_info.expires_at,
479 });
480
481 Ok((api_key_info.token, endpoint))
482 }
483
484 async fn get_github_access_token(&self) -> anyhow::Result<String> {
486 if let Some(token) = &self.github_token {
487 return Ok(token.clone());
488 }
489
490 let access_token_path = self.token_dir.join("access-token");
491 if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
492 let token = cached.trim();
493 if !token.is_empty() {
494 return Ok(token.to_string());
495 }
496 }
497
498 let token = self.device_code_login().await?;
499 write_file_secure(&access_token_path, &token).await;
500 Ok(token)
501 }
502
503 async fn device_code_login(&self) -> anyhow::Result<String> {
505 let response: DeviceCodeResponse = self
506 .http_client()
507 .post(GITHUB_DEVICE_CODE_URL)
508 .header("Accept", "application/json")
509 .json(&serde_json::json!({
510 "client_id": GITHUB_CLIENT_ID,
511 "scope": "read:user"
512 }))
513 .send()
514 .await?
515 .error_for_status()?
516 .json()
517 .await?;
518
519 let mut poll_interval = Duration::from_secs(response.interval.max(5));
520 let expires_in = response.expires_in.max(1);
521 let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
522
523 eprintln!(
524 "\nGitHub Copilot authentication is required.\n\
525 Visit: {}\n\
526 Code: {}\n\
527 Waiting for authorization...\n",
528 response.verification_uri, response.user_code
529 );
530
531 while tokio::time::Instant::now() < expires_at {
532 tokio::time::sleep(poll_interval).await;
533
534 let token_response: AccessTokenResponse = self
535 .http_client()
536 .post(GITHUB_ACCESS_TOKEN_URL)
537 .header("Accept", "application/json")
538 .json(&serde_json::json!({
539 "client_id": GITHUB_CLIENT_ID,
540 "device_code": response.device_code,
541 "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
542 }))
543 .send()
544 .await?
545 .json()
546 .await?;
547
548 if let Some(token) = token_response.access_token {
549 eprintln!("Authentication succeeded.\n");
550 return Ok(token);
551 }
552
553 match token_response.error.as_deref() {
554 Some("slow_down") => {
555 poll_interval += Duration::from_secs(5);
556 }
557 Some("authorization_pending") | None => {}
558 Some("expired_token") => {
559 anyhow::bail!("GitHub device authorization expired")
560 }
561 Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
562 }
563 }
564
565 anyhow::bail!("Timed out waiting for GitHub authorization")
566 }
567
568 async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
570 let mut request = self.http_client().get(GITHUB_API_KEY_URL);
571 for (header, value) in &Self::COPILOT_HEADERS {
572 request = request.header(*header, *value);
573 }
574 request = request.header("Authorization", format!("token {access_token}"));
575
576 let response = request.send().await?;
577
578 if !response.status().is_success() {
579 let status = response.status();
580 let body = response.text().await.unwrap_or_default();
581 let sanitized = super::sanitize_api_error(&body);
582
583 if status.as_u16() == 401 || status.as_u16() == 403 {
584 let access_token_path = self.token_dir.join("access-token");
585 tokio::fs::remove_file(&access_token_path).await.ok();
586 }
587
588 anyhow::bail!(
589 "Failed to get Copilot API key ({status}): {sanitized}. \
590 Ensure your GitHub account has an active Copilot subscription."
591 );
592 }
593
594 let info: ApiKeyInfo = response.json().await?;
595 Ok(info)
596 }
597
598 async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
599 let path = self.token_dir.join("api-key.json");
600 let data = tokio::fs::read_to_string(&path).await.ok()?;
601 serde_json::from_str(&data).ok()
602 }
603
604 async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
605 let path = self.token_dir.join("api-key.json");
606 if let Ok(json) = serde_json::to_string_pretty(info) {
607 write_file_secure(&path, &json).await;
608 }
609 }
610}
611
612async fn write_file_secure(path: &Path, content: &str) {
615 let path = path.to_path_buf();
616 let content = content.to_string();
617
618 let result = tokio::task::spawn_blocking(move || {
619 #[cfg(unix)]
620 {
621 use std::io::Write;
622 use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
623
624 let mut file = std::fs::OpenOptions::new()
625 .write(true)
626 .create(true)
627 .truncate(true)
628 .mode(0o600)
629 .open(&path)?;
630 file.write_all(content.as_bytes())?;
631
632 std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
633 Ok::<(), std::io::Error>(())
634 }
635 #[cfg(not(unix))]
636 {
637 std::fs::write(&path, &content)?;
638 Ok::<(), std::io::Error>(())
639 }
640 })
641 .await;
642
643 match result {
644 Ok(Ok(())) => {}
645 Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
646 Err(err) => warn!("Failed to spawn blocking write: {err}"),
647 }
648}
649
650#[async_trait]
651impl Provider for CopilotProvider {
652 async fn chat_with_system(
653 &self,
654 system_prompt: Option<&str>,
655 message: &str,
656 model: &str,
657 temperature: f64,
658 ) -> anyhow::Result<String> {
659 let mut messages = Vec::new();
660 if let Some(system) = system_prompt {
661 messages.push(ApiMessage {
662 role: "system".to_string(),
663 content: Some(ApiContent::Text(system.to_string())),
664 tool_call_id: None,
665 tool_calls: None,
666 });
667 }
668 messages.push(ApiMessage {
669 role: "user".to_string(),
670 content: Self::to_api_content("user", message),
671 tool_call_id: None,
672 tool_calls: None,
673 });
674
675 let response = self
676 .send_chat_request(messages, None, model, temperature)
677 .await?;
678 Ok(response.text.unwrap_or_default())
679 }
680
681 async fn chat_with_history(
682 &self,
683 messages: &[ChatMessage],
684 model: &str,
685 temperature: f64,
686 ) -> anyhow::Result<String> {
687 let response = self
688 .send_chat_request(Self::convert_messages(messages), None, model, temperature)
689 .await?;
690 Ok(response.text.unwrap_or_default())
691 }
692
693 async fn chat(
694 &self,
695 request: ProviderChatRequest<'_>,
696 model: &str,
697 temperature: f64,
698 ) -> anyhow::Result<ProviderChatResponse> {
699 self.send_chat_request(
700 Self::convert_messages(request.messages),
701 request.tools,
702 model,
703 temperature,
704 )
705 .await
706 }
707
708 fn supports_native_tools(&self) -> bool {
709 true
710 }
711
712 async fn warmup(&self) -> anyhow::Result<()> {
713 let _ = self.get_api_key().await?;
714 Ok(())
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_without_token() {
        assert!(CopilotProvider::new(None).github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token, Some(String::from("ghp_test")));
    }

    #[test]
    fn empty_token_treated_as_none() {
        assert_eq!(CopilotProvider::new(Some("")).github_token, None);
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        let provider = CopilotProvider::new(None);
        assert!(provider.refresh_lock.lock().await.is_none());
    }

    #[test]
    fn copilot_headers_include_required_fields() {
        let names: Vec<&str> = CopilotProvider::COPILOT_HEADERS
            .iter()
            .map(|(name, _)| *name)
            .collect();
        for required in ["Editor-Version", "Editor-Plugin-Version", "User-Agent"] {
            assert!(names.contains(&required), "missing required header {required}");
        }
    }

    #[test]
    fn default_interval_and_expiry() {
        assert_eq!((default_interval(), default_expires_in()), (5, 900));
    }

    #[test]
    fn supports_native_tools() {
        assert!(CopilotProvider::new(None).supports_native_tools());
    }

    #[test]
    fn api_response_parses_usage() {
        let json = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 200, "completion_tokens": 80}
        }"#;
        let usage = serde_json::from_str::<ApiChatResponse>(json)
            .unwrap()
            .usage
            .unwrap();
        assert_eq!(usage.prompt_tokens, Some(200));
        assert_eq!(usage.completion_tokens, Some(80));
    }

    #[test]
    fn api_response_parses_without_usage() {
        let parsed: ApiChatResponse =
            serde_json::from_str(r#"{"choices": [{"message": {"content": "Hello"}}]}"#).unwrap();
        assert!(parsed.usage.is_none());
    }

    #[test]
    fn to_api_content_user_with_image_returns_parts() {
        let content = "describe this [IMAGE:data:image/png;base64,abc123]";
        let parts = match CopilotProvider::to_api_content("user", content).unwrap() {
            ApiContent::Parts(parts) => parts,
            ApiContent::Text(_) => {
                panic!("expected ApiContent::Parts for user message with image marker")
            }
        };
        assert_eq!(parts.len(), 2);
        assert!(matches!(&parts[0], ContentPart::Text { text } if text == "describe this"));
        assert!(matches!(
            &parts[1],
            ContentPart::ImageUrl { image_url } if image_url.url == "data:image/png;base64,abc123"
        ));
    }

    #[test]
    fn to_api_content_user_plain_returns_text() {
        let result = CopilotProvider::to_api_content("user", "hello world");
        assert!(matches!(result, Some(ApiContent::Text(ref s)) if s == "hello world"));
    }

    #[test]
    fn to_api_content_non_user_returns_text() {
        for (role, text) in [("system", "you are helpful"), ("assistant", "sure")] {
            let result = CopilotProvider::to_api_content(role, text).unwrap();
            assert!(matches!(result, ApiContent::Text(ref s) if s == text));
        }
    }
}