Skip to main content

pi/providers/
copilot.rs

1//! GitHub Copilot provider implementation.
2//!
3//! Copilot uses a two-step authentication flow:
4//! 1. Exchange a GitHub OAuth/PAT token for a short-lived Copilot session token
5//!    via `https://api.github.com/copilot_internal/v2/token`.
6//! 2. Use the session token to make OpenAI-compatible chat completion requests
7//!    to the Copilot proxy endpoint.
8//!
9//! The session token is cached and automatically refreshed when it expires.
10//! GitHub Enterprise Server is supported via a configurable base URL.
11
12use crate::error::{Error, Result};
13use crate::http::client::Client;
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
16use async_trait::async_trait;
17use futures::Stream;
18use serde::Deserialize;
19use std::pin::Pin;
20use std::sync::Mutex;
21
22use super::openai::OpenAIProvider;
23
// ── Constants ────────────────────────────────────────────────────

/// Base URL of the public GitHub REST API. The session-token exchange goes
/// here unless the provider is pointed at a GitHub Enterprise Server instance.
const GITHUB_API_BASE: &str = "https://api.github.com";

// The next three values are sent as request headers on both the token
// exchange and the chat-completion requests.

/// Value for the `Editor-Version` header.
const EDITOR_VERSION: &str = "vscode/1.96.2";

/// Value for the `User-Agent` header.
const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7";

/// Value for the `X-Github-Api-Version` header.
const GITHUB_API_VERSION: &str = "2025-04-01";

/// A cached session token counts as stale once fewer than this many seconds
/// remain before its `expires_at`, so it is replaced ahead of actual expiry.
const TOKEN_REFRESH_MARGIN_SECS: i64 = 60;
40
41// ── Token exchange types ─────────────────────────────────────────
42
43/// Response from the Copilot token exchange endpoint.
44#[derive(Debug, Deserialize)]
45struct CopilotTokenResponse {
46    /// The short-lived session token.
47    token: String,
48    /// Unix timestamp (seconds) when the token expires.
49    expires_at: i64,
50    /// Endpoints returned by the API.
51    #[serde(default)]
52    endpoints: CopilotEndpoints,
53}
54
55/// Endpoint URLs returned alongside the session token.
56#[derive(Debug, Default, Deserialize)]
57struct CopilotEndpoints {
58    /// The API endpoint for chat completions.
59    #[serde(default)]
60    api: String,
61}
62
/// A Copilot session token held in the provider's cache, together with its
/// expiry and the chat-completions URL it was issued for.
///
/// `Debug` is implemented by hand so the bearer token never ends up in logs
/// or panic messages (e.g. via `Result::unwrap_err`, which requires `Debug`).
#[derive(Clone)]
struct CachedToken {
    /// Session token sent as the API key on chat-completion requests.
    token: String,
    /// Unix timestamp (seconds) after which the token is no longer valid.
    expires_at: i64,
    /// Full chat-completions URL to send requests to.
    api_endpoint: String,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            // Secret: never print the bearer token itself.
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
70
71// ── Provider ─────────────────────────────────────────────────────
72
73/// GitHub Copilot provider that wraps OpenAI-compatible streaming.
74pub struct CopilotProvider {
75    /// HTTP client for token exchange and API requests.
76    client: Client,
77    /// The GitHub OAuth token or PAT used for token exchange.
78    github_token: String,
79    /// The model ID to request (e.g., "gpt-4o", "claude-3.5-sonnet").
80    model: String,
81    /// GitHub API base URL (supports Enterprise: `https://github.example.com/api/v3`).
82    github_api_base: String,
83    /// Provider name for event attribution.
84    provider_name: String,
85    /// Compatibility overrides passed to the underlying OpenAI provider.
86    compat: Option<CompatConfig>,
87    /// Cached session token (refreshed automatically).
88    cached_token: Mutex<Option<CachedToken>>,
89}
90
91impl CopilotProvider {
92    /// Create a new Copilot provider.
93    pub fn new(model: impl Into<String>, github_token: impl Into<String>) -> Self {
94        Self {
95            client: Client::new(),
96            github_token: github_token.into(),
97            model: model.into(),
98            github_api_base: GITHUB_API_BASE.to_string(),
99            provider_name: "github-copilot".to_string(),
100            compat: None,
101            cached_token: Mutex::new(None),
102        }
103    }
104
105    /// Set the GitHub API base URL (for Enterprise).
106    #[must_use]
107    pub fn with_github_api_base(mut self, base: impl Into<String>) -> Self {
108        self.github_api_base = base.into();
109        self
110    }
111
112    /// Set the provider name for event attribution.
113    #[must_use]
114    pub fn with_provider_name(mut self, name: impl Into<String>) -> Self {
115        self.provider_name = name.into();
116        self
117    }
118
119    /// Attach compatibility overrides.
120    #[must_use]
121    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
122        self.compat = compat;
123        self
124    }
125
126    /// Inject a custom HTTP client (for testing / VCR).
127    #[must_use]
128    pub fn with_client(mut self, client: Client) -> Self {
129        self.client = client;
130        self
131    }
132
133    /// Get a valid session token, refreshing if necessary.
134    async fn ensure_session_token(&self) -> Result<CachedToken> {
135        // Check cache first.
136        {
137            let guard = self
138                .cached_token
139                .lock()
140                .unwrap_or_else(std::sync::PoisonError::into_inner);
141            if let Some(cached) = &*guard {
142                let now = chrono::Utc::now().timestamp();
143                if cached.expires_at > now + TOKEN_REFRESH_MARGIN_SECS {
144                    return Ok(cached.clone());
145                }
146            }
147        }
148
149        // Exchange GitHub token for a Copilot session token.
150        let token_url = format!(
151            "{}/copilot_internal/v2/token",
152            self.github_api_base.trim_end_matches('/')
153        );
154
155        let request = self
156            .client
157            .get(&token_url)
158            .header("Authorization", format!("token {}", self.github_token))
159            .header("Accept", "application/json")
160            .header("Editor-Version", EDITOR_VERSION)
161            .header("User-Agent", COPILOT_USER_AGENT)
162            .header("X-Github-Api-Version", GITHUB_API_VERSION);
163
164        let response = Box::pin(request.send())
165            .await
166            .map_err(|e| Error::auth(format!("Copilot token exchange failed: {e}")))?;
167
168        let status = response.status();
169        let text = response
170            .text()
171            .await
172            .unwrap_or_else(|_| "<failed to read body>".to_string());
173
174        if !(200..300).contains(&status) {
175            return Err(Error::auth(format!(
176                "Copilot token exchange failed (HTTP {status}). \
177                 Verify your GitHub token has Copilot access. Response: {text}"
178            )));
179        }
180
181        let token_response: CopilotTokenResponse = serde_json::from_str(&text)
182            .map_err(|e| Error::auth(format!("Invalid Copilot token response: {e}")))?;
183
184        // Determine the API endpoint.
185        let api_endpoint = if token_response.endpoints.api.is_empty() {
186            // Fallback: use the standard Copilot proxy URL.
187            "https://api.githubcopilot.com/chat/completions".to_string()
188        } else {
189            let base = token_response.endpoints.api.trim_end_matches('/');
190            if base.ends_with("/chat/completions") {
191                base.to_string()
192            } else {
193                format!("{base}/chat/completions")
194            }
195        };
196
197        let cached = CachedToken {
198            token: token_response.token,
199            expires_at: token_response.expires_at,
200            api_endpoint,
201        };
202
203        // Store in cache.
204        {
205            let mut guard = self
206                .cached_token
207                .lock()
208                .unwrap_or_else(std::sync::PoisonError::into_inner);
209            *guard = Some(cached.clone());
210        }
211
212        Ok(cached)
213    }
214}
215
216#[async_trait]
217impl Provider for CopilotProvider {
218    fn name(&self) -> &str {
219        &self.provider_name
220    }
221
222    fn api(&self) -> &'static str {
223        "openai-completions"
224    }
225
226    fn model_id(&self) -> &str {
227        &self.model
228    }
229
230    #[allow(clippy::too_many_lines)]
231    async fn stream(
232        &self,
233        context: &Context<'_>,
234        options: &StreamOptions,
235    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
236        // Get a valid session token.
237        let session = self.ensure_session_token().await?;
238
239        // Build an OpenAI provider pointed at the Copilot endpoint.
240        let inner = OpenAIProvider::new(&self.model)
241            .with_provider_name(&self.provider_name)
242            .with_base_url(&session.api_endpoint)
243            .with_compat(self.compat.clone())
244            .with_client(self.client.clone());
245
246        // Override the authorization: Copilot uses the session token,
247        // not the GitHub OAuth token.
248        let mut copilot_options = options.clone();
249        copilot_options.api_key = Some(session.token);
250
251        // Add Copilot-specific headers.
252        copilot_options
253            .headers
254            .insert("Editor-Version".to_string(), EDITOR_VERSION.to_string());
255        copilot_options
256            .headers
257            .insert("User-Agent".to_string(), COPILOT_USER_AGENT.to_string());
258        copilot_options.headers.insert(
259            "X-Github-Api-Version".to_string(),
260            GITHUB_API_VERSION.to_string(),
261        );
262        copilot_options.headers.insert(
263            "Copilot-Integration-Id".to_string(),
264            "vscode-chat".to_string(),
265        );
266
267        inner.stream(context, &copilot_options).await
268    }
269}
270
// ── Tests ────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::vcr::{
        Cassette, Interaction, RecordedRequest, RecordedResponse, VcrMode, VcrRecorder,
    };

    /// Token-exchange URL for the default (non-Enterprise) GitHub API base.
    const DEFAULT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";

    #[test]
    fn test_copilot_provider_defaults() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test123");
        assert_eq!(p.name(), "github-copilot");
        assert_eq!(p.api(), "openai-completions");
        assert_eq!(p.model_id(), "gpt-4o");
        assert_eq!(p.github_api_base, GITHUB_API_BASE);
    }

    #[test]
    fn test_copilot_provider_builder() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_provider_name("copilot-enterprise")
            .with_github_api_base("https://github.example.com/api/v3");

        assert_eq!(p.name(), "copilot-enterprise");
        assert_eq!(p.github_api_base, "https://github.example.com/api/v3");
    }

    #[test]
    fn test_copilot_token_response_deserialization() {
        let json = r#"{
            "token": "ghu_session_abc123",
            "expires_at": 1700000000,
            "endpoints": {
                "api": "https://copilot-proxy.githubusercontent.com/v1",
                "proxy": "https://copilot-proxy.githubusercontent.com"
            }
        }"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_session_abc123");
        assert_eq!(resp.expires_at, 1_700_000_000);
        assert_eq!(
            resp.endpoints.api,
            "https://copilot-proxy.githubusercontent.com/v1"
        );
    }

    #[test]
    fn test_copilot_token_response_missing_endpoints() {
        let json = r#"{"token": "ghu_abc", "expires_at": 1700000000}"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_abc");
        assert!(resp.endpoints.api.is_empty());
    }

    #[test]
    fn test_copilot_token_exchange_url_construction() {
        // Standard GitHub
        let p = CopilotProvider::new("gpt-4o", "ghp_test");
        let actual = format!(
            "{}/copilot_internal/v2/token",
            p.github_api_base.trim_end_matches('/')
        );
        assert_eq!(actual, DEFAULT_TOKEN_URL);

        // Enterprise with trailing slash
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_github_api_base("https://github.example.com/api/v3/");
        let actual = format!(
            "{}/copilot_internal/v2/token",
            p.github_api_base.trim_end_matches('/')
        );
        assert_eq!(
            actual,
            "https://github.example.com/api/v3/copilot_internal/v2/token"
        );
    }

    #[test]
    fn test_cached_token_clone() {
        let original = CachedToken {
            token: "session-tok".to_string(),
            expires_at: 99999,
            api_endpoint: "https://example.com/chat/completions".to_string(),
        };
        let cloned = original.clone();
        assert_eq!(cloned.token, "session-tok");
        assert_eq!(cloned.expires_at, 99999);
        assert_eq!(cloned.token, original.token);
        assert_eq!(cloned.api_endpoint, original.api_endpoint);
    }

    /// Build a playback-mode VCR client whose cassette answers a single GET to
    /// the default token-exchange URL with `response`.
    ///
    /// The returned `TempDir` holds the cassette file; callers keep it alive
    /// for the duration of the test.
    fn vcr_client(test_name: &str, response: RecordedResponse) -> (Client, tempfile::TempDir) {
        let temp = tempfile::tempdir().expect("tempdir");
        let cassette = Cassette {
            version: "1".to_string(),
            test_name: test_name.to_string(),
            recorded_at: "2025-01-01T00:00:00Z".to_string(),
            interactions: vec![Interaction {
                request: RecordedRequest {
                    method: "GET".to_string(),
                    url: DEFAULT_TOKEN_URL.to_string(),
                    headers: vec![],
                    body: None,
                    body_text: None,
                },
                response,
            }],
        };
        let serialized = serde_json::to_string_pretty(&cassette).expect("serialize");
        std::fs::write(temp.path().join(format!("{test_name}.json")), serialized)
            .expect("write cassette");
        let recorder = VcrRecorder::new_with(test_name, VcrMode::Playback, temp.path());
        let client = Client::new().with_vcr(recorder);
        (client, temp)
    }

    /// Build a VCR client that returns a successful token exchange response.
    fn vcr_token_exchange_client(
        test_name: &str,
        token: &str,
        expires_at: i64,
        api_endpoint: &str,
    ) -> (Client, tempfile::TempDir) {
        let response_body = serde_json::json!({
            "token": token,
            "expires_at": expires_at,
            "endpoints": {
                "api": api_endpoint
            }
        })
        .to_string();
        vcr_client(
            test_name,
            RecordedResponse {
                status: 200,
                headers: vec![],
                body_chunks: vec![response_body],
                body_chunks_base64: None,
            },
        )
    }

    #[test]
    fn test_token_exchange_success_via_vcr() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_success",
                "ghu_session_test",
                far_future,
                "https://copilot-proxy.example.com/v1",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy_token").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("token exchange");
            assert_eq!(cached.token, "ghu_session_test");
            assert_eq!(cached.expires_at, far_future);
            assert_eq!(
                cached.api_endpoint,
                "https://copilot-proxy.example.com/v1/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_caches_on_second_call() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_cache", "ghu_cached", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);

            // First call populates the cache.
            let first = provider.ensure_session_token().await.expect("first call");
            assert_eq!(first.token, "ghu_cached");

            // Swap the cached token for a sentinel. Another network exchange
            // would return "ghu_cached" again, so getting the sentinel back
            // proves the second call was served from the cache.
            {
                let mut guard = provider
                    .cached_token
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                guard.as_mut().expect("cache populated by first call").token =
                    "ghu_sentinel".to_string();
            }

            let second = provider.ensure_session_token().await.expect("second call");
            assert_eq!(second.token, "ghu_sentinel");
        });
    }

    #[test]
    fn test_token_exchange_refreshes_expiring_cache() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let now = chrono::Utc::now().timestamp();
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_refresh", "ghu_fresh", now + 3600, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);

            // Seed the cache with a token that expires inside the refresh margin.
            *provider
                .cached_token
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(CachedToken {
                token: "ghu_stale".to_string(),
                expires_at: now + TOKEN_REFRESH_MARGIN_SECS / 2,
                api_endpoint: "https://stale.example.com/chat/completions".to_string(),
            });

            let refreshed = provider.ensure_session_token().await.expect("refresh");
            assert_eq!(refreshed.token, "ghu_fresh");
            assert_eq!(
                refreshed.api_endpoint,
                "https://api.githubcopilot.com/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_error_returns_auth_error() {
        let (client, _temp) = vcr_client(
            "copilot_token_error",
            RecordedResponse {
                status: 401,
                headers: vec![],
                body_chunks: vec![r#"{"message":"Bad credentials"}"#.to_string()],
                body_chunks_base64: None,
            },
        );

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let provider = CopilotProvider::new("gpt-4o", "ghp_bad_token").with_client(client);
            let result = provider.ensure_session_token().await;
            assert!(result.is_err());
            let msg = result.unwrap_err().to_string();
            assert!(
                msg.contains("401") || msg.contains("Bad credentials"),
                "expected auth error, got: {msg}"
            );
        });
    }

    #[test]
    fn test_token_exchange_fallback_endpoint() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            // Empty api endpoint → should fall back to default.
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_fallback", "ghu_fallback", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider.ensure_session_token().await.expect("fallback");
            assert_eq!(
                cached.api_endpoint,
                "https://api.githubcopilot.com/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_endpoint_already_has_path() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_full_endpoint",
                "ghu_full",
                far_future,
                "https://custom.proxy.com/chat/completions",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("full endpoint");
            // Endpoint already includes /chat/completions; should not be duplicated.
            assert_eq!(
                cached.api_endpoint,
                "https://custom.proxy.com/chat/completions"
            );
        });
    }
}