1use crate::error::{Error, Result};
13use crate::http::client::Client;
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
16use async_trait::async_trait;
17use futures::Stream;
18use serde::Deserialize;
19use std::pin::Pin;
20use std::sync::Mutex;
21
22use super::openai::OpenAIProvider;
23
/// Default GitHub REST API base. The Copilot token exchange is requested from
/// `{base}/copilot_internal/v2/token`; override with `with_github_api_base`.
const GITHUB_API_BASE: &str = "https://api.github.com";

/// Value of the `Editor-Version` header, sent on the token exchange and on chat requests.
const EDITOR_VERSION: &str = "vscode/1.96.2";

/// Value of the `User-Agent` header, sent on the token exchange and on chat requests.
const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7";

/// Value of the `X-Github-Api-Version` header, sent on the token exchange and on chat requests.
const GITHUB_API_VERSION: &str = "2025-04-01";

/// A cached session token is reused only while its `expires_at` is more than this many
/// seconds in the future; otherwise a fresh token exchange is performed.
const TOKEN_REFRESH_MARGIN_SECS: i64 = 60;
40
41#[derive(Debug, Deserialize)]
45struct CopilotTokenResponse {
46 token: String,
48 expires_at: i64,
50 #[serde(default)]
52 endpoints: CopilotEndpoints,
53}
54
/// The `endpoints` object of the token-exchange response. Other keys in that object
/// (such as `proxy`) are not modelled and are ignored during deserialization.
#[derive(Debug, Default, Deserialize)]
struct CopilotEndpoints {
    /// API base URL for chat requests; empty string when the key is absent.
    #[serde(default)]
    api: String,
}
62
/// A Copilot session token plus the chat endpoint it was issued for, as cached by
/// `CopilotProvider` between requests.
///
/// `Debug` is implemented by hand so the bearer token is redacted; the other fields
/// are printed normally.
#[derive(Clone)]
struct CachedToken {
    /// Session token sent as the API key on chat requests.
    token: String,
    /// Expiry of `token` as a Unix timestamp in seconds.
    expires_at: i64,
    /// Full chat-completions URL to send requests to.
    api_endpoint: String,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
70
/// [`Provider`] for GitHub Copilot's chat API.
///
/// A GitHub token is exchanged for a short-lived Copilot session token, which is cached
/// and re-exchanged when close to expiry; streaming itself is delegated to an
/// OpenAI-compatible provider pointed at the session's chat endpoint.
pub struct CopilotProvider {
    /// HTTP client used for the token exchange and handed to the delegated provider.
    client: Client,
    /// GitHub token, sent as `Authorization: token <github_token>` on the token exchange.
    github_token: String,
    /// Model identifier reported by `model_id` and passed to the delegated provider.
    model: String,
    /// Base URL of the GitHub API that hosts the token-exchange endpoint.
    github_api_base: String,
    /// Name reported by `Provider::name` and passed to the delegated provider.
    provider_name: String,
    /// Optional compat settings forwarded to the delegated provider.
    compat: Option<CompatConfig>,
    /// Most recently exchanged session token, if any.
    cached_token: Mutex<Option<CachedToken>>,
}
90
91impl CopilotProvider {
92 pub fn new(model: impl Into<String>, github_token: impl Into<String>) -> Self {
94 Self {
95 client: Client::new(),
96 github_token: github_token.into(),
97 model: model.into(),
98 github_api_base: GITHUB_API_BASE.to_string(),
99 provider_name: "github-copilot".to_string(),
100 compat: None,
101 cached_token: Mutex::new(None),
102 }
103 }
104
105 #[must_use]
107 pub fn with_github_api_base(mut self, base: impl Into<String>) -> Self {
108 self.github_api_base = base.into();
109 self
110 }
111
112 #[must_use]
114 pub fn with_provider_name(mut self, name: impl Into<String>) -> Self {
115 self.provider_name = name.into();
116 self
117 }
118
119 #[must_use]
121 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
122 self.compat = compat;
123 self
124 }
125
126 #[must_use]
128 pub fn with_client(mut self, client: Client) -> Self {
129 self.client = client;
130 self
131 }
132
133 async fn ensure_session_token(&self) -> Result<CachedToken> {
135 {
137 let guard = self
138 .cached_token
139 .lock()
140 .unwrap_or_else(std::sync::PoisonError::into_inner);
141 if let Some(cached) = &*guard {
142 let now = chrono::Utc::now().timestamp();
143 if cached.expires_at > now + TOKEN_REFRESH_MARGIN_SECS {
144 return Ok(cached.clone());
145 }
146 }
147 }
148
149 let token_url = format!(
151 "{}/copilot_internal/v2/token",
152 self.github_api_base.trim_end_matches('/')
153 );
154
155 let request = self
156 .client
157 .get(&token_url)
158 .header("Authorization", format!("token {}", self.github_token))
159 .header("Accept", "application/json")
160 .header("Editor-Version", EDITOR_VERSION)
161 .header("User-Agent", COPILOT_USER_AGENT)
162 .header("X-Github-Api-Version", GITHUB_API_VERSION);
163
164 let response = Box::pin(request.send())
165 .await
166 .map_err(|e| Error::auth(format!("Copilot token exchange failed: {e}")))?;
167
168 let status = response.status();
169 let text = response
170 .text()
171 .await
172 .unwrap_or_else(|_| "<failed to read body>".to_string());
173
174 if !(200..300).contains(&status) {
175 return Err(Error::auth(format!(
176 "Copilot token exchange failed (HTTP {status}). \
177 Verify your GitHub token has Copilot access. Response: {text}"
178 )));
179 }
180
181 let token_response: CopilotTokenResponse = serde_json::from_str(&text)
182 .map_err(|e| Error::auth(format!("Invalid Copilot token response: {e}")))?;
183
184 let api_endpoint = if token_response.endpoints.api.is_empty() {
186 "https://api.githubcopilot.com/chat/completions".to_string()
188 } else {
189 let base = token_response.endpoints.api.trim_end_matches('/');
190 if base.ends_with("/chat/completions") {
191 base.to_string()
192 } else {
193 format!("{base}/chat/completions")
194 }
195 };
196
197 let cached = CachedToken {
198 token: token_response.token,
199 expires_at: token_response.expires_at,
200 api_endpoint,
201 };
202
203 {
205 let mut guard = self
206 .cached_token
207 .lock()
208 .unwrap_or_else(std::sync::PoisonError::into_inner);
209 *guard = Some(cached.clone());
210 }
211
212 Ok(cached)
213 }
214}
215
216#[async_trait]
217impl Provider for CopilotProvider {
218 fn name(&self) -> &str {
219 &self.provider_name
220 }
221
222 fn api(&self) -> &'static str {
223 "openai-completions"
224 }
225
226 fn model_id(&self) -> &str {
227 &self.model
228 }
229
230 #[allow(clippy::too_many_lines)]
231 async fn stream(
232 &self,
233 context: &Context<'_>,
234 options: &StreamOptions,
235 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
236 let session = self.ensure_session_token().await?;
238
239 let inner = OpenAIProvider::new(&self.model)
241 .with_provider_name(&self.provider_name)
242 .with_base_url(&session.api_endpoint)
243 .with_compat(self.compat.clone())
244 .with_client(self.client.clone());
245
246 let mut copilot_options = options.clone();
249 copilot_options.api_key = Some(session.token);
250
251 copilot_options
253 .headers
254 .insert("Editor-Version".to_string(), EDITOR_VERSION.to_string());
255 copilot_options
256 .headers
257 .insert("User-Agent".to_string(), COPILOT_USER_AGENT.to_string());
258 copilot_options.headers.insert(
259 "X-Github-Api-Version".to_string(),
260 GITHUB_API_VERSION.to_string(),
261 );
262 copilot_options.headers.insert(
263 "Copilot-Integration-Id".to_string(),
264 "vscode-chat".to_string(),
265 );
266
267 inner.stream(context, &copilot_options).await
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vcr::{
        Cassette, Interaction, RecordedRequest, RecordedResponse, VcrMode, VcrRecorder,
    };

    /// Token-exchange URL derived from the default `GITHUB_API_BASE`. Every playback
    /// cassette below records a `GET` of this URL.
    const DEFAULT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";

    #[test]
    fn test_copilot_provider_defaults() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test123");
        assert_eq!(p.name(), "github-copilot");
        assert_eq!(p.api(), "openai-completions");
        assert_eq!(p.model_id(), "gpt-4o");
        assert_eq!(p.github_api_base, GITHUB_API_BASE);
    }

    #[test]
    fn test_copilot_provider_builder() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_provider_name("copilot-enterprise")
            .with_github_api_base("https://github.example.com/api/v3");

        assert_eq!(p.name(), "copilot-enterprise");
        assert_eq!(p.github_api_base, "https://github.example.com/api/v3");
    }

    #[test]
    fn test_copilot_token_response_deserialization() {
        let json = r#"{
            "token": "ghu_session_abc123",
            "expires_at": 1700000000,
            "endpoints": {
                "api": "https://copilot-proxy.githubusercontent.com/v1",
                "proxy": "https://copilot-proxy.githubusercontent.com"
            }
        }"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_session_abc123");
        assert_eq!(resp.expires_at, 1_700_000_000);
        assert_eq!(
            resp.endpoints.api,
            "https://copilot-proxy.githubusercontent.com/v1"
        );
    }

    #[test]
    fn test_copilot_token_response_missing_endpoints() {
        let json = r#"{"token": "ghu_abc", "expires_at": 1700000000}"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_abc");
        assert!(resp.endpoints.api.is_empty());
    }

    #[test]
    fn test_copilot_token_exchange_url_construction() {
        // Mirrors the URL formatting in `ensure_session_token`.
        let token_url = |p: &CopilotProvider| {
            format!(
                "{}/copilot_internal/v2/token",
                p.github_api_base.trim_end_matches('/')
            )
        };

        let p = CopilotProvider::new("gpt-4o", "ghp_test");
        assert_eq!(token_url(&p), DEFAULT_TOKEN_URL);

        // A trailing slash on the base must not produce `//` in the URL.
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_github_api_base("https://github.example.com/api/v3/");
        assert_eq!(
            token_url(&p),
            "https://github.example.com/api/v3/copilot_internal/v2/token"
        );
    }

    #[test]
    fn test_cached_token_clone() {
        let original = CachedToken {
            token: "session-tok".to_string(),
            expires_at: 99999,
            api_endpoint: "https://example.com/chat/completions".to_string(),
        };
        let cloned = original.clone();
        assert_eq!(cloned.token, "session-tok");
        assert_eq!(cloned.expires_at, 99999);
        assert_eq!(cloned.token, original.token);
        assert_eq!(cloned.expires_at, original.expires_at);
        assert_eq!(cloned.api_endpoint, original.api_endpoint);
    }

    /// Write a single-interaction playback cassette in which a `GET` of
    /// `DEFAULT_TOKEN_URL` answers with `response`. Return a VCR-backed client for it.
    ///
    /// The returned `TempDir` owns the cassette directory; keep it alive for as long as
    /// the client is used.
    fn playback_client(
        test_name: &str,
        response: RecordedResponse,
    ) -> (Client, tempfile::TempDir) {
        let temp = tempfile::tempdir().expect("tempdir");
        let cassette = Cassette {
            version: "1".to_string(),
            test_name: test_name.to_string(),
            recorded_at: "2025-01-01T00:00:00Z".to_string(),
            interactions: vec![Interaction {
                request: RecordedRequest {
                    method: "GET".to_string(),
                    url: DEFAULT_TOKEN_URL.to_string(),
                    headers: vec![],
                    body: None,
                    body_text: None,
                },
                response,
            }],
        };
        let serialized = serde_json::to_string_pretty(&cassette).expect("serialize");
        std::fs::write(temp.path().join(format!("{test_name}.json")), serialized)
            .expect("write cassette");
        let recorder = VcrRecorder::new_with(test_name, VcrMode::Playback, temp.path());
        (Client::new().with_vcr(recorder), temp)
    }

    /// Playback client whose token exchange succeeds (HTTP 200) with the given token,
    /// expiry and `endpoints.api` value.
    fn vcr_token_exchange_client(
        test_name: &str,
        token: &str,
        expires_at: i64,
        api_endpoint: &str,
    ) -> (Client, tempfile::TempDir) {
        let response_body = serde_json::json!({
            "token": token,
            "expires_at": expires_at,
            "endpoints": {
                "api": api_endpoint
            }
        })
        .to_string();
        playback_client(
            test_name,
            RecordedResponse {
                status: 200,
                headers: vec![],
                body_chunks: vec![response_body],
                body_chunks_base64: None,
            },
        )
    }

    #[test]
    fn test_token_exchange_success_via_vcr() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_success",
                "ghu_session_test",
                far_future,
                "https://copilot-proxy.example.com/v1",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy_token").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("token exchange");
            assert_eq!(cached.token, "ghu_session_test");
            assert_eq!(cached.expires_at, far_future);
            assert_eq!(
                cached.api_endpoint,
                "https://copilot-proxy.example.com/v1/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_caches_on_second_call() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_cache", "ghu_cached", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let first = provider.ensure_session_token().await.expect("first call");
            assert_eq!(first.token, "ghu_cached");
            // The cassette holds a single interaction; the second call is expected to
            // be served from the in-memory cache.
            let second = provider.ensure_session_token().await.expect("second call");
            assert_eq!(second.token, "ghu_cached");
        });
    }

    #[test]
    fn test_token_exchange_error_returns_auth_error() {
        let (client, _temp) = playback_client(
            "copilot_token_error",
            RecordedResponse {
                status: 401,
                headers: vec![],
                body_chunks: vec![r#"{"message":"Bad credentials"}"#.to_string()],
                body_chunks_base64: None,
            },
        );

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let provider = CopilotProvider::new("gpt-4o", "ghp_bad_token").with_client(client);
            let result = provider.ensure_session_token().await;
            assert!(result.is_err());
            let msg = result.unwrap_err().to_string();
            assert!(
                msg.contains("401") || msg.contains("Bad credentials"),
                "expected auth error, got: {msg}"
            );
        });
    }

    #[test]
    fn test_token_exchange_fallback_endpoint() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_fallback", "ghu_fallback", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider.ensure_session_token().await.expect("fallback");
            assert_eq!(
                cached.api_endpoint,
                "https://api.githubcopilot.com/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_endpoint_already_has_path() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_full_endpoint",
                "ghu_full",
                far_future,
                "https://custom.proxy.com/chat/completions",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("full endpoint");
            // An endpoint that already ends in `/chat/completions` is used verbatim.
            assert_eq!(
                cached.api_endpoint,
                "https://custom.proxy.com/chat/completions"
            );
        });
    }
}