Skip to main content

pi/providers/
copilot.rs

1//! GitHub Copilot provider implementation.
2//!
3//! Copilot uses a two-step authentication flow:
4//! 1. Exchange a GitHub OAuth/PAT token for a short-lived Copilot session token
5//!    via `https://api.github.com/copilot_internal/v2/token`.
6//! 2. Use the session token to make OpenAI-compatible chat completion requests
7//!    to the Copilot proxy endpoint.
8//!
9//! The session token is cached and automatically refreshed when it expires.
10//! GitHub Enterprise Server is supported via a configurable base URL.
11
12use crate::error::{Error, Result};
13use crate::http::client::Client;
14use crate::models::CompatConfig;
15use crate::provider::{Context, Provider, StreamEvent, StreamOptions};
16use async_trait::async_trait;
17use futures::Stream;
18use serde::Deserialize;
19use std::pin::Pin;
20use std::sync::Mutex;
21
22use super::openai::OpenAIProvider;
23
24// ── Constants ────────────────────────────────────────────────────
25
/// Default GitHub API base for token exchange.
const GITHUB_API_BASE: &str = "https://api.github.com";

/// Editor version header value (required by Copilot API).
/// Override via `PI_COPILOT_EDITOR_VERSION`.
const EDITOR_VERSION: &str = "vscode/1.96.2";

/// User-Agent header value (required by Copilot API).
/// Override via `PI_COPILOT_USER_AGENT`.
const COPILOT_USER_AGENT: &str = "GitHubCopilotChat/0.26.7";

/// GitHub API version header.
/// Override via `PI_GITHUB_API_VERSION`.
const GITHUB_API_VERSION: &str = "2025-04-01";

/// Safety margin: refresh the session token this many seconds before expiry.
const TOKEN_REFRESH_MARGIN_SECS: i64 = 60;

/// Read a header-value override from environment variable `var`.
///
/// Falls back to `default` when the variable is unset, not valid Unicode,
/// or blank. Surrounding whitespace is trimmed: HTTP header values must not
/// carry leading/trailing whitespace, and a whitespace-only override would
/// otherwise be sent as an effectively empty header.
fn env_override(var: &str, default: &str) -> String {
    std::env::var(var)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| default.to_string())
}

/// `Editor-Version` header value: `PI_COPILOT_EDITOR_VERSION` or [`EDITOR_VERSION`].
fn copilot_editor_version() -> String {
    env_override("PI_COPILOT_EDITOR_VERSION", EDITOR_VERSION)
}

/// `User-Agent` header value: `PI_COPILOT_USER_AGENT` or [`COPILOT_USER_AGENT`].
fn copilot_user_agent() -> String {
    env_override("PI_COPILOT_USER_AGENT", COPILOT_USER_AGENT)
}

/// `X-Github-Api-Version` header value: `PI_GITHUB_API_VERSION` or [`GITHUB_API_VERSION`].
fn github_api_version() -> String {
    env_override("PI_GITHUB_API_VERSION", GITHUB_API_VERSION)
}
64
65// ── Token exchange types ─────────────────────────────────────────
66
67/// Response from the Copilot token exchange endpoint.
68#[derive(Debug, Deserialize)]
69struct CopilotTokenResponse {
70    /// The short-lived session token.
71    token: String,
72    /// Unix timestamp (seconds) when the token expires.
73    expires_at: i64,
74    /// Endpoints returned by the API.
75    #[serde(default)]
76    endpoints: CopilotEndpoints,
77}
78
79/// Endpoint URLs returned alongside the session token.
80#[derive(Debug, Default, Deserialize)]
81struct CopilotEndpoints {
82    /// The API endpoint for chat completions.
83    #[serde(default)]
84    api: String,
85}
86
/// Cached Copilot session token together with its expiry and chat endpoint.
///
/// `Debug` is implemented by hand (not derived) so the bearer token never
/// reaches logs, `unwrap`/`expect` panic messages, or test failure output.
#[derive(Clone)]
struct CachedToken {
    /// Short-lived Copilot session token, sent as the API key.
    token: String,
    /// Unix timestamp (seconds) at which the token expires.
    expires_at: i64,
    /// Full chat completions URL that requests are sent to.
    api_endpoint: String,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .field("api_endpoint", &self.api_endpoint)
            .finish()
    }
}
94
95// ── Provider ─────────────────────────────────────────────────────
96
97/// GitHub Copilot provider that wraps OpenAI-compatible streaming.
98pub struct CopilotProvider {
99    /// HTTP client for token exchange and API requests.
100    client: Client,
101    /// The GitHub OAuth token or PAT used for token exchange.
102    github_token: String,
103    /// The model ID to request (e.g., "gpt-4o", "claude-3.5-sonnet").
104    model: String,
105    /// GitHub API base URL (supports Enterprise: `https://github.example.com/api/v3`).
106    github_api_base: String,
107    /// Provider name for event attribution.
108    provider_name: String,
109    /// Compatibility overrides passed to the underlying OpenAI provider.
110    compat: Option<CompatConfig>,
111    /// Cached session token (refreshed automatically).
112    cached_token: Mutex<Option<CachedToken>>,
113}
114
115impl CopilotProvider {
116    /// Create a new Copilot provider.
117    pub fn new(model: impl Into<String>, github_token: impl Into<String>) -> Self {
118        Self {
119            client: Client::new(),
120            github_token: github_token.into(),
121            model: model.into(),
122            github_api_base: GITHUB_API_BASE.to_string(),
123            provider_name: "github-copilot".to_string(),
124            compat: None,
125            cached_token: Mutex::new(None),
126        }
127    }
128
129    /// Set the GitHub API base URL (for Enterprise).
130    #[must_use]
131    pub fn with_github_api_base(mut self, base: impl Into<String>) -> Self {
132        self.github_api_base = base.into();
133        self
134    }
135
136    /// Set the provider name for event attribution.
137    #[must_use]
138    pub fn with_provider_name(mut self, name: impl Into<String>) -> Self {
139        self.provider_name = name.into();
140        self
141    }
142
143    /// Attach compatibility overrides.
144    #[must_use]
145    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
146        self.compat = compat;
147        self
148    }
149
150    /// Inject a custom HTTP client (for testing / VCR).
151    #[must_use]
152    pub fn with_client(mut self, client: Client) -> Self {
153        self.client = client;
154        self
155    }
156
157    /// Get a valid session token, refreshing if necessary.
158    async fn ensure_session_token(&self) -> Result<CachedToken> {
159        // Check cache first.
160        {
161            let guard = self
162                .cached_token
163                .lock()
164                .unwrap_or_else(std::sync::PoisonError::into_inner);
165            if let Some(cached) = &*guard {
166                let now = chrono::Utc::now().timestamp();
167                if cached.expires_at > now + TOKEN_REFRESH_MARGIN_SECS {
168                    return Ok(cached.clone());
169                }
170            }
171        }
172
173        // Exchange GitHub token for a Copilot session token.
174        let token_url = format!(
175            "{}/copilot_internal/v2/token",
176            self.github_api_base.trim_end_matches('/')
177        );
178
179        let request = self
180            .client
181            .get(&token_url)
182            .header("Authorization", format!("token {}", self.github_token))
183            .header("Accept", "application/json")
184            .header("Editor-Version", copilot_editor_version())
185            .header("User-Agent", copilot_user_agent())
186            .header("X-Github-Api-Version", github_api_version());
187
188        let response = Box::pin(request.send())
189            .await
190            .map_err(|e| Error::auth(format!("Copilot token exchange failed: {e}")))?;
191
192        let status = response.status();
193        let text = response
194            .text()
195            .await
196            .unwrap_or_else(|_| "<failed to read body>".to_string());
197
198        if !(200..300).contains(&status) {
199            return Err(Error::auth(format!(
200                "Copilot token exchange failed (HTTP {status}). \
201                 Verify your GitHub token has Copilot access. Response: {text}"
202            )));
203        }
204
205        let token_response: CopilotTokenResponse = serde_json::from_str(&text)
206            .map_err(|e| Error::auth(format!("Invalid Copilot token response: {e}")))?;
207
208        // Determine the API endpoint.
209        let api_endpoint = if token_response.endpoints.api.is_empty() {
210            // Fallback: use the standard Copilot proxy URL.
211            "https://api.githubcopilot.com/chat/completions".to_string()
212        } else {
213            let base = token_response.endpoints.api.trim_end_matches('/');
214            if base.ends_with("/chat/completions") {
215                base.to_string()
216            } else {
217                format!("{base}/chat/completions")
218            }
219        };
220
221        let cached = CachedToken {
222            token: token_response.token,
223            expires_at: token_response.expires_at,
224            api_endpoint,
225        };
226
227        // Store in cache.
228        {
229            let mut guard = self
230                .cached_token
231                .lock()
232                .unwrap_or_else(std::sync::PoisonError::into_inner);
233            *guard = Some(cached.clone());
234        }
235
236        Ok(cached)
237    }
238}
239
240#[async_trait]
241impl Provider for CopilotProvider {
242    fn name(&self) -> &str {
243        &self.provider_name
244    }
245
246    fn api(&self) -> &'static str {
247        "openai-completions"
248    }
249
250    fn model_id(&self) -> &str {
251        &self.model
252    }
253
254    #[allow(clippy::too_many_lines)]
255    async fn stream(
256        &self,
257        context: &Context<'_>,
258        options: &StreamOptions,
259    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
260        // Get a valid session token.
261        let session = self.ensure_session_token().await?;
262
263        // Build an OpenAI provider pointed at the Copilot endpoint.
264        let inner = OpenAIProvider::new(&self.model)
265            .with_provider_name(&self.provider_name)
266            .with_base_url(&session.api_endpoint)
267            .with_compat(self.compat.clone())
268            .with_client(self.client.clone());
269
270        // Override the authorization: Copilot uses the session token,
271        // not the GitHub OAuth token.
272        let mut copilot_options = options.clone();
273        copilot_options.api_key = Some(session.token);
274
275        // Add Copilot-specific headers.
276        copilot_options
277            .headers
278            .insert("Editor-Version".to_string(), copilot_editor_version());
279        copilot_options
280            .headers
281            .insert("User-Agent".to_string(), copilot_user_agent());
282        copilot_options
283            .headers
284            .insert("X-Github-Api-Version".to_string(), github_api_version());
285        copilot_options.headers.insert(
286            "Copilot-Integration-Id".to_string(),
287            "vscode-chat".to_string(),
288        );
289
290        inner.stream(context, &copilot_options).await
291    }
292}
293
294// ── Tests ────────────────────────────────────────────────────────
295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vcr::{
        Cassette, Interaction, RecordedRequest, RecordedResponse, VcrMode, VcrRecorder,
    };

    /// Token-exchange URL produced for the default (public GitHub) API base.
    const TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";

    #[test]
    fn test_copilot_provider_defaults() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test123");
        assert_eq!(p.name(), "github-copilot");
        assert_eq!(p.api(), "openai-completions");
        assert_eq!(p.model_id(), "gpt-4o");
        assert_eq!(p.github_api_base, GITHUB_API_BASE);
    }

    #[test]
    fn test_copilot_provider_builder() {
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_provider_name("copilot-enterprise")
            .with_github_api_base("https://github.example.com/api/v3");

        assert_eq!(p.name(), "copilot-enterprise");
        assert_eq!(p.github_api_base, "https://github.example.com/api/v3");
    }

    #[test]
    fn test_copilot_token_response_deserialization() {
        let json = r#"{
            "token": "ghu_session_abc123",
            "expires_at": 1700000000,
            "endpoints": {
                "api": "https://copilot-proxy.githubusercontent.com/v1",
                "proxy": "https://copilot-proxy.githubusercontent.com"
            }
        }"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_session_abc123");
        assert_eq!(resp.expires_at, 1_700_000_000);
        assert_eq!(
            resp.endpoints.api,
            "https://copilot-proxy.githubusercontent.com/v1"
        );
    }

    #[test]
    fn test_copilot_token_response_missing_endpoints() {
        let json = r#"{"token": "ghu_abc", "expires_at": 1700000000}"#;

        let resp: CopilotTokenResponse = serde_json::from_str(json).expect("parse");
        assert_eq!(resp.token, "ghu_abc");
        assert!(resp.endpoints.api.is_empty());
    }

    #[test]
    fn test_copilot_token_exchange_url_construction() {
        // Standard GitHub
        let p = CopilotProvider::new("gpt-4o", "ghp_test");
        let actual = format!(
            "{}/copilot_internal/v2/token",
            p.github_api_base.trim_end_matches('/')
        );
        assert_eq!(actual, TOKEN_URL);

        // Enterprise with trailing slash
        let p = CopilotProvider::new("gpt-4o", "ghp_test")
            .with_github_api_base("https://github.example.com/api/v3/");
        let actual = format!(
            "{}/copilot_internal/v2/token",
            p.github_api_base.trim_end_matches('/')
        );
        assert_eq!(
            actual,
            "https://github.example.com/api/v3/copilot_internal/v2/token"
        );
    }

    #[test]
    fn test_cached_token_clone() {
        let original = CachedToken {
            token: "session-tok".to_string(),
            expires_at: 99999,
            api_endpoint: "https://example.com/chat/completions".to_string(),
        };
        let cloned = original.clone();
        assert_eq!(cloned.token, "session-tok");
        assert_eq!(cloned.expires_at, 99999);
        assert_eq!(cloned.api_endpoint, original.api_endpoint);
    }

    /// Write a single-interaction cassette into a fresh temp dir and return a
    /// playback-mode client for it.
    ///
    /// The cassette's one interaction is a GET of [`TOKEN_URL`] answered by
    /// `response`. The `TempDir` is returned so the caller keeps the cassette
    /// on disk for the duration of the test.
    fn vcr_client(test_name: &str, response: RecordedResponse) -> (Client, tempfile::TempDir) {
        let temp = tempfile::tempdir().expect("tempdir");
        let cassette = Cassette {
            version: "1.0".to_string(),
            test_name: test_name.to_string(),
            recorded_at: "2025-01-01T00:00:00Z".to_string(),
            interactions: vec![Interaction {
                request: RecordedRequest {
                    method: "GET".to_string(),
                    url: TOKEN_URL.to_string(),
                    headers: vec![],
                    body: None,
                    body_text: None,
                },
                response,
            }],
        };
        let serialized = serde_json::to_string_pretty(&cassette).expect("serialize");
        std::fs::write(temp.path().join(format!("{test_name}.json")), serialized)
            .expect("write cassette");
        let recorder = VcrRecorder::new_with(test_name, VcrMode::Playback, temp.path());
        (Client::new().with_vcr(recorder), temp)
    }

    /// Build a VCR client that returns a successful token exchange response.
    fn vcr_token_exchange_client(
        test_name: &str,
        token: &str,
        expires_at: i64,
        api_endpoint: &str,
    ) -> (Client, tempfile::TempDir) {
        let response_body = serde_json::json!({
            "token": token,
            "expires_at": expires_at,
            "endpoints": {
                "api": api_endpoint
            }
        })
        .to_string();
        vcr_client(
            test_name,
            RecordedResponse {
                status: 200,
                headers: vec![],
                body_chunks: vec![response_body],
                body_chunks_base64: None,
            },
        )
    }

    #[test]
    fn test_token_exchange_success_via_vcr() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_success",
                "ghu_session_test",
                far_future,
                "https://copilot-proxy.example.com/v1",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy_token").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("token exchange");
            assert_eq!(cached.token, "ghu_session_test");
            assert_eq!(cached.expires_at, far_future);
            assert_eq!(
                cached.api_endpoint,
                "https://copilot-proxy.example.com/v1/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_caches_on_second_call() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_cache", "ghu_cached", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);

            // First call populates the cache.
            let first = provider.ensure_session_token().await.expect("first call");
            assert_eq!(first.token, "ghu_cached");

            // Tamper with the cached entry. A second network exchange would
            // return "ghu_cached" again, so seeing the marker proves the cache
            // was used.
            {
                let mut guard = provider.cached_token.lock().expect("cache lock");
                match &mut *guard {
                    Some(entry) => entry.token = "cache-marker".to_string(),
                    None => panic!("first call should populate the cache"),
                }
            }

            let second = provider.ensure_session_token().await.expect("second call");
            assert_eq!(second.token, "cache-marker");
        });
    }

    #[test]
    fn test_token_exchange_error_returns_auth_error() {
        let (client, _temp) = vcr_client(
            "copilot_token_error",
            RecordedResponse {
                status: 401,
                headers: vec![],
                body_chunks: vec![r#"{"message":"Bad credentials"}"#.to_string()],
                body_chunks_base64: None,
            },
        );

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let provider = CopilotProvider::new("gpt-4o", "ghp_bad_token").with_client(client);
            let result = provider.ensure_session_token().await;
            assert!(result.is_err());
            let msg = result.unwrap_err().to_string();
            assert!(
                msg.contains("401") || msg.contains("Bad credentials"),
                "expected auth error, got: {msg}"
            );
        });
    }

    #[test]
    fn test_token_exchange_fallback_endpoint() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            // Empty api endpoint → should fall back to default.
            let (client, _temp) =
                vcr_token_exchange_client("copilot_token_fallback", "ghu_fallback", far_future, "");
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider.ensure_session_token().await.expect("fallback");
            assert_eq!(
                cached.api_endpoint,
                "https://api.githubcopilot.com/chat/completions"
            );
        });
    }

    #[test]
    fn test_token_exchange_endpoint_already_has_path() {
        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let far_future = chrono::Utc::now().timestamp() + 3600;
            let (client, _temp) = vcr_token_exchange_client(
                "copilot_token_full_endpoint",
                "ghu_full",
                far_future,
                "https://custom.proxy.com/chat/completions",
            );
            let provider = CopilotProvider::new("gpt-4o", "ghp_dummy").with_client(client);
            let cached = provider
                .ensure_session_token()
                .await
                .expect("full endpoint");
            // Endpoint already includes /chat/completions; should not be duplicated.
            assert_eq!(
                cached.api_endpoint,
                "https://custom.proxy.com/chat/completions"
            );
        });
    }
}