1use crate::error::{Error, Result};
13use crate::http::client::Client;
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
16use async_trait::async_trait;
17use futures::Stream;
18use serde::Deserialize;
19use std::pin::Pin;
20use std::sync::Mutex;
21
22use super::openai::OpenAIProvider;
23
/// Default GitHub API base URL used for the Copilot token exchange.
///
/// Can be replaced per provider through `CopilotProvider::with_github_api_base`.
const GITHUB_API_BASE: &str = "https://api.github.com";

/// Built-in `Editor-Version` header value, used when
/// `PI_COPILOT_EDITOR_VERSION` is unset or empty.
const EDITOR_VERSION: &str = "vscode/1.96.2";

/// Built-in `User-Agent` header value, used when `PI_COPILOT_USER_AGENT` is
/// unset or empty.
const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7";

/// Built-in `X-Github-Api-Version` header value, used when
/// `PI_GITHUB_API_VERSION` is unset or empty.
const GITHUB_API_VERSION: &str = "2025-04-01";

/// A cached session token is only reused while it has more than this many
/// seconds left before its `expires_at` timestamp.
const TOKEN_REFRESH_MARGIN_SECS: i64 = 60;

/// Returns the value of environment variable `var`, or `fallback` when the
/// variable is unset, is not valid Unicode, or is the empty string.
fn env_override_or(var: &str, fallback: &str) -> String {
    match std::env::var(var) {
        Ok(value) if !value.is_empty() => value,
        _ => fallback.to_string(),
    }
}

/// `Editor-Version` header value: `PI_COPILOT_EDITOR_VERSION` if set and
/// non-empty, otherwise [`EDITOR_VERSION`].
fn copilot_editor_version() -> String {
    env_override_or("PI_COPILOT_EDITOR_VERSION", EDITOR_VERSION)
}

/// `User-Agent` header value: `PI_COPILOT_USER_AGENT` if set and non-empty,
/// otherwise [`COPILOT_USER_AGENT`].
fn copilot_user_agent() -> String {
    env_override_or("PI_COPILOT_USER_AGENT", COPILOT_USER_AGENT)
}

/// `X-Github-Api-Version` header value: `PI_GITHUB_API_VERSION` if set and
/// non-empty, otherwise [`GITHUB_API_VERSION`].
fn github_api_version() -> String {
    env_override_or("PI_GITHUB_API_VERSION", GITHUB_API_VERSION)
}
64
65#[derive(Debug, Deserialize)]
69struct CopilotTokenResponse {
70 token: String,
72 expires_at: i64,
74 #[serde(default)]
76 endpoints: CopilotEndpoints,
77}
78
79#[derive(Debug, Default, Deserialize)]
81struct CopilotEndpoints {
82 #[serde(default)]
84 api: String,
85}
86
/// Result of a successful token exchange, kept in memory for reuse until it
/// is close to expiry.
#[derive(Debug, Clone)]
struct CachedToken {
    /// Copilot session token.
    token: String,
    /// Expiry as a Unix timestamp in seconds.
    expires_at: i64,
    /// Fully resolved chat-completions URL that requests are sent to.
    api_endpoint: String,
}
94
95pub struct CopilotProvider {
99 client: Client,
101 github_token: String,
103 model: String,
105 github_api_base: String,
107 provider_name: String,
109 compat: Option<CompatConfig>,
111 cached_token: Mutex<Option<CachedToken>>,
113}
114
115impl CopilotProvider {
116 pub fn new(model: impl Into<String>, github_token: impl Into<String>) -> Self {
118 Self {
119 client: Client::new(),
120 github_token: github_token.into(),
121 model: model.into(),
122 github_api_base: GITHUB_API_BASE.to_string(),
123 provider_name: "github-copilot".to_string(),
124 compat: None,
125 cached_token: Mutex::new(None),
126 }
127 }
128
129 #[must_use]
131 pub fn with_github_api_base(mut self, base: impl Into<String>) -> Self {
132 self.github_api_base = base.into();
133 self
134 }
135
136 #[must_use]
138 pub fn with_provider_name(mut self, name: impl Into<String>) -> Self {
139 self.provider_name = name.into();
140 self
141 }
142
143 #[must_use]
145 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
146 self.compat = compat;
147 self
148 }
149
150 #[must_use]
152 pub fn with_client(mut self, client: Client) -> Self {
153 self.client = client;
154 self
155 }
156
157 async fn ensure_session_token(&self) -> Result<CachedToken> {
159 {
161 let guard = self
162 .cached_token
163 .lock()
164 .unwrap_or_else(std::sync::PoisonError::into_inner);
165 if let Some(cached) = &*guard {
166 let now = chrono::Utc::now().timestamp();
167 if cached.expires_at > now + TOKEN_REFRESH_MARGIN_SECS {
168 return Ok(cached.clone());
169 }
170 }
171 }
172
173 let token_url = format!(
175 "{}/copilot_internal/v2/token",
176 self.github_api_base.trim_end_matches('/')
177 );
178
179 let request = self
180 .client
181 .get(&token_url)
182 .header("Authorization", format!("token {}", self.github_token))
183 .header("Accept", "application/json")
184 .header("Editor-Version", copilot_editor_version())
185 .header("User-Agent", copilot_user_agent())
186 .header("X-Github-Api-Version", github_api_version());
187
188 let response = Box::pin(request.send())
189 .await
190 .map_err(|e| Error::auth(format!("Copilot token exchange failed: {e}")))?;
191
192 let status = response.status();
193 let text = response
194 .text()
195 .await
196 .unwrap_or_else(|_| "<failed to read body>".to_string());
197
198 if !(200..300).contains(&status) {
199 return Err(Error::auth(format!(
200 "Copilot token exchange failed (HTTP {status}). \
201 Verify your GitHub token has Copilot access. Response: {text}"
202 )));
203 }
204
205 let token_response: CopilotTokenResponse = serde_json::from_str(&text)
206 .map_err(|e| Error::auth(format!("Invalid Copilot token response: {e}")))?;
207
208 let api_endpoint = if token_response.endpoints.api.is_empty() {
210 "https://api.githubcopilot.com/chat/completions".to_string()
212 } else {
213 let base = token_response.endpoints.api.trim_end_matches('/');
214 if base.ends_with("/chat/completions") {
215 base.to_string()
216 } else {
217 format!("{base}/chat/completions")
218 }
219 };
220
221 let cached = CachedToken {
222 token: token_response.token,
223 expires_at: token_response.expires_at,
224 api_endpoint,
225 };
226
227 {
229 let mut guard = self
230 .cached_token
231 .lock()
232 .unwrap_or_else(std::sync::PoisonError::into_inner);
233 *guard = Some(cached.clone());
234 }
235
236 Ok(cached)
237 }
238}
239
240#[async_trait]
241impl Provider for CopilotProvider {
242 fn name(&self) -> &str {
243 &self.provider_name
244 }
245
246 fn api(&self) -> &'static str {
247 "openai-completions"
248 }
249
250 fn model_id(&self) -> &str {
251 &self.model
252 }
253
254 #[allow(clippy::too_many_lines)]
255 async fn stream(
256 &self,
257 context: &Context<'_>,
258 options: &StreamOptions,
259 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
260 let session = self.ensure_session_token().await?;
262
263 let inner = OpenAIProvider::new(&self.model)
265 .with_provider_name(&self.provider_name)
266 .with_base_url(&session.api_endpoint)
267 .with_compat(self.compat.clone())
268 .with_client(self.client.clone());
269
270 let mut copilot_options = options.clone();
273 copilot_options.api_key = Some(session.token);
274
275 copilot_options
277 .headers
278 .insert("Editor-Version".to_string(), copilot_editor_version());
279 copilot_options
280 .headers
281 .insert("User-Agent".to_string(), copilot_user_agent());
282 copilot_options
283 .headers
284 .insert("X-Github-Api-Version".to_string(), github_api_version());
285 copilot_options.headers.insert(
286 "Copilot-Integration-Id".to_string(),
287 "vscode-chat".to_string(),
288 );
289
290 inner.stream(context, &copilot_options).await
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vcr::{
        Cassette, Interaction, RecordedRequest, RecordedResponse, VcrMode, VcrRecorder,
    };

    /// A freshly constructed provider reports the built-in defaults.
    #[test]
    fn test_copilot_provider_defaults() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test123");
        assert_eq!(p.name(), "github-copilot");
        assert_eq!(p.api(), "openai-completions");
        assert_eq!(p.model_id(), "gpt-4o");
        assert_eq!(p.github_api_base, GITHUB_API_BASE);
    }

    /// Builder methods override the provider name and the API base.
    #[test]
    fn test_copilot_provider_builder() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_provider_name("copilot-enterprise")
            .with_github_api_base("https://github.example.com/api/v3");

        assert_eq!(p.name(), "copilot-enterprise");
        assert_eq!(p.github_api_base, "https://github.example.com/api/v3");
    }

    /// A full token response parses; unknown fields such as `proxy` are
    /// ignored.
    #[test]
    fn test_copilot_token_response_deserialization() {
        let json = r#"{
            "token": "ghu_session_abc123",
            "expires_at": 1700000000,
            "endpoints": {
                "api": "https://copilot-proxy.githubusercontent.com/v1",
                "proxy": "https://copilot-proxy.githubusercontent.com"
            }
        }"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_session_abc123");
        assert_eq!(resp.expires_at, 1_700_000_000);
        assert_eq!(
            resp.endpoints.api,
            "https://copilot-proxy.githubusercontent.com/v1"
        );
    }

    /// A response without `endpoints` parses and yields an empty `api`.
    #[test]
    fn test_copilot_token_response_missing_endpoints() {
        let json = r#"{"token": "ghu_abc", "expires_at": 1700000000}"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_abc");
        assert!(resp.endpoints.api.is_empty());
    }

    /// The token-exchange URL is the API base (minus trailing slashes) plus
    /// the fixed exchange path.
    #[test]
    fn test_copilot_token_exchange_url_construction() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test");
        let expected = "https://api.github.com/copilot_internal/v2/token";
        let actual = format!(
            "{}/copilot_internal/v2/token",
            p.github_api_base.trim_end_matches('/')
        );
        assert_eq!(actual, expected);

        // A trailing slash on a custom base must not produce `//`.
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_github_api_base("https://github.example.com/api/v3/");
        let actual = format!(
            "{}/copilot_internal/v2/token",
            p.github_api_base.trim_end_matches('/')
        );
        assert_eq!(
            actual,
            "https://github.example.com/api/v3/copilot_internal/v2/token"
        );
    }

    /// Cloning a cached token yields an equal, independent copy.
    #[test]
    fn test_cached_token_clone() {
        let original = CachedToken {
            token: "session-tok".to_string(),
            expires_at: 99999,
            api_endpoint: "https://example.com/chat/completions".to_string(),
        };
        let cloned = original.clone();
        assert_eq!(cloned.token, "session-tok");
        assert_eq!(cloned.expires_at, 99999);
        assert_eq!(cloned.token, original.token);
        assert_eq!(cloned.expires_at, original.expires_at);
        assert_eq!(cloned.api_endpoint, original.api_endpoint);
    }

    /// Writes a one-interaction cassette (a GET to the default token-exchange
    /// URL answered with `response`) and returns a playback client for it.
    ///
    /// The returned `TempDir` holds the cassette file; keep it alive for as
    /// long as the client is in use.
    fn vcr_client_with_response(
        test_name: &str,
        response: RecordedResponse,
    ) -> (Client, tempfile::TempDir) {
        let temp = tempfile::tempdir().expect("tempdir");
        let cassette = Cassette {
            version: "1.0".to_string(),
            test_name: test_name.to_string(),
            recorded_at: "2025-01-01T00:00:00Z".to_string(),
            interactions: vec![Interaction {
                request: RecordedRequest {
                    method: "GET".to_string(),
                    url: "https://api.github.com/copilot_internal/v2/token".to_string(),
                    headers: vec![],
                    body: None,
                    body_text: None,
                },
                response,
            }],
        };
        let serialized = serde_json::to_string_pretty(&cassette).expect("serialize");
        std::fs::write(temp.path().join(format!("{test_name}.json")), serialized)
            .expect("write cassette");
        let recorder = VcrRecorder::new_with(test_name, VcrMode::Playback, temp.path());
        let client = Client::new().with_vcr(recorder);
        (client, temp)
    }

    /// Playback client whose single interaction is a successful (HTTP 200)
    /// token exchange returning the given token, expiry and API endpoint.
    fn vcr_token_exchange_client(
        test_name: &str,
        token: &str,
        expires_at: i64,
        api_endpoint: &str,
    ) -> (Client, tempfile::TempDir) {
        let response_body = serde_json::json!({
            "token": token,
            "expires_at": expires_at,
            "endpoints": {
                "api": api_endpoint
            }
        })
        .to_string();
        vcr_client_with_response(
            test_name,
            RecordedResponse {
                status: 200,
                headers: vec![],
                body_chunks: vec![response_body],
                body_chunks_base64: None,
            },
        )
    }

    /// A successful exchange returns the token, expiry and an endpoint with
    /// `/chat/completions` appended.
    #[test]
    fn test_token_exchange_success_via_vcr() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_success",
                "ghu_session_test",
                far_future,
                "https://copilot-proxy.example.com/v1",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy_token").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("token exchange");
            assert_eq!(cached.token, "ghu_session_test");
            assert_eq!(cached.expires_at, far_future);
            assert_eq!(
                cached.api_endpoint,
                "https://copilot-proxy.example.com/v1/chat/completions"
            );
        });
    }

    /// A second call within the validity window returns the same token; the
    /// cassette records only one exchange.
    #[test]
    fn test_token_exchange_caches_on_second_call() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_cache", "ghu_cached", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let first = provider.ensure_session_token().await.expect("first call");
            assert_eq!(first.token, "ghu_cached");
            let second = provider.ensure_session_token().await.expect("second call");
            assert_eq!(second.token, "ghu_cached");
            assert_eq!(second.expires_at, first.expires_at);
            assert_eq!(second.api_endpoint, first.api_endpoint);
        });
    }

    /// A 401 from the exchange endpoint surfaces as an error mentioning the
    /// status or the server's message.
    #[test]
    fn test_token_exchange_error_returns_auth_error() {
        let (client, _temp) = vcr_client_with_response(
            "copilot_token_error",
            RecordedResponse {
                status: 401,
                headers: vec![],
                body_chunks: vec![r#"{"message":"Bad credentials"}"#.to_string()],
                body_chunks_base64: None,
            },
        );

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let provider = CopilotProvider::new("gpt-4o", "ghp_bad_token").with_client(client);
            let result = provider.ensure_session_token().await;
            assert!(result.is_err());
            let msg = result.unwrap_err().to_string();
            assert!(
                msg.contains("401") || msg.contains("Bad credentials"),
                "expected auth error, got: {msg}"
            );
        });
    }

    /// An empty `endpoints.api` falls back to the built-in endpoint.
    #[test]
    fn test_token_exchange_fallback_endpoint() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_fallback", "ghu_fallback", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider.ensure_session_token().await.expect("fallback");
            assert_eq!(
                cached.api_endpoint,
                "https://api.githubcopilot.com/chat/completions"
            );
        });
    }

    /// An endpoint that already ends in `/chat/completions` is not extended.
    #[test]
    fn test_token_exchange_endpoint_already_has_path() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_full_endpoint",
                "ghu_full",
                far_future,
                "https://custom.proxy.com/chat/completions",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("full endpoint");
            assert_eq!(
                cached.api_endpoint,
                "https://custom.proxy.com/chat/completions"
            );
        });
    }
}