Skip to main content

byokey_provider/executor/
copilot.rs

1//! GitHub Copilot executor — OpenAI-compatible API.
2//!
3//! Auth: device code flow → GitHub token → exchange for short-lived Copilot API token.
4//! Format: `OpenAI` passthrough via `aigw::openai_compat` for URL/header/request building.
5//!         Streaming: raw byte passthrough (Option P). Non-streaming: aigw response translator.
6use crate::http_util::ProviderHttp;
7use crate::registry;
8use aigw::openai::translate::OpenAIResponseTranslator;
9use aigw::openai::{HttpTransportConfig, OpenAIAuthConfig};
10use aigw::openai_compat::translate::OpenAICompatRequestTranslator;
11use aigw::openai_compat::{OpenAICompatConfig, OpenAICompatProvider, Quirks};
12use aigw_core::translate::{RequestTranslator as _, ResponseTranslator as _};
13use async_trait::async_trait;
14use byokey_auth::AuthManager;
15use byokey_types::{
16    AccountInfo, ByokError, ChatRequest, ProviderId, RateLimitStore,
17    traits::{ProviderExecutor, ProviderResponse, Result},
18};
19use secrecy::SecretString;
20use serde_json::Value;
21use std::{
22    cmp::Ordering as CmpOrdering,
23    collections::{BTreeMap, HashMap},
24    sync::{Arc, LazyLock, Mutex},
25    time::{Duration, Instant},
26};
27
/// Cached quota snapshot for a single Copilot account.
///
/// Stored in `AccountTracker::quotas` keyed by account id and turned into a
/// comparable number by `quota_score` when accounts are ranked.
struct CachedQuota {
    /// `percent_remaining` of the `premium_interactions` quota snapshot.
    percent_remaining: f64,
    /// The snapshot's `unlimited` flag; such accounts score 100 in `quota_score`.
    unlimited: bool,
    /// When this snapshot was recorded; entries younger than `QUOTA_CACHE_TTL`
    /// are not re-fetched.
    fetched_at: Instant,
}
34
/// Tracks the currently selected account and per-account quota snapshots.
///
/// A single instance lives in the `ACCOUNT_TRACKER` static and is consulted by
/// `select_account` to decide whether to stay on the current account or to
/// re-rank all accounts.
struct AccountTracker {
    /// Currently sticky account id (`None` until the first selection).
    current: Option<String>,
    /// When the last rebalance comparison happened; `None` forces the next
    /// selection to re-rank accounts.
    last_rebalance: Option<Instant>,
    /// Per-account cached quota data, keyed by account id.
    quotas: HashMap<String, CachedQuota>,
}
44
/// Global account tracker for quota-aware multi-account routing.
///
/// Being a `static`, the tracker is shared by every `CopilotExecutor` in the
/// process. It starts empty (no sticky account, no rebalance timestamp, no
/// quota snapshots) and is initialized lazily on first access.
static ACCOUNT_TRACKER: LazyLock<Mutex<AccountTracker>> = LazyLock::new(|| {
    Mutex::new(AccountTracker {
        current: None,
        last_rebalance: None,
        quotas: HashMap::new(),
    })
});
53
/// How often to re-compare quotas across accounts.
// Spelled in seconds because `Duration::from_mins` is not yet a const fn on
// stable; hence the `allow` for the unit-style lint.
#[allow(clippy::duration_suboptimal_units)]
const REBALANCE_INTERVAL: Duration = Duration::from_secs(5 * 60);

/// Quota cache TTL — avoid re-fetching within this window.
// Spelled in seconds for the same const-fn reason as `REBALANCE_INTERVAL`.
#[allow(clippy::duration_suboptimal_units)]
const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5 * 60);

/// Default GitHub Copilot Chat Completions API base URL.
///
/// Used when neither the executor's `base_url` nor the token exchange
/// response (`/endpoints/api`) supplies one.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";

/// Endpoint to exchange a GitHub OAuth token for a short-lived Copilot API token.
const COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";

/// Copilot usage/quota endpoint (returns `quota_snapshots`).
const COPILOT_USER_URL: &str = "https://api.github.com/copilot_internal/user";

// Header values matching the VS Code Copilot Chat extension.
/// Default `user-agent` header value (overridable through the builder).
const USER_AGENT: &str = "GitHubCopilotChat/0.35.0";
/// Default `editor-version` header value (overridable through the builder).
const EDITOR_VERSION: &str = "vscode/1.107.0";
/// Default `editor-plugin-version` header value (overridable through the builder).
const PLUGIN_VERSION: &str = "copilot-chat/0.35.0";
/// Value of the `copilot-integration-id` header.
const INTEGRATION_ID: &str = "vscode-chat";
/// Value of the `openai-intent` header.
const OPENAI_INTENT: &str = "conversation-panel";
/// Value of the `x-github-api-version` header.
const GITHUB_API_VERSION: &str = "2025-04-01";
79
/// A cached Copilot API token with its expiry time.
///
/// Stored in `CopilotExecutor::cache`, keyed by the GitHub token it was
/// exchanged from.
struct CachedToken {
    /// The short-lived Copilot API token (the `token` field of the exchange response).
    token: String,
    /// API base URL to use with this token, without a trailing `/`.
    api_endpoint: String,
    /// Local instant after which the entry is no longer served from the cache.
    expires_at: Instant,
    /// `true` = Pro/Business/Enterprise, `false` = Free tier.
    is_pro: bool,
}
88
89/// Score a cached quota for account comparison.
90///
91/// `unlimited` → 100, known quota → `percent_remaining`, unknown → 50 (neutral).
92fn quota_score(q: Option<&CachedQuota>) -> f64 {
93    match q {
94        Some(q) if q.unlimited => 100.0,
95        Some(q) => q.percent_remaining,
96        None => 50.0,
97    }
98}
99
/// Executor for the GitHub Copilot API.
///
/// Resolves a Copilot API token (directly from `api_key`, or by exchanging a
/// stored GitHub token), then sends OpenAI-compatible chat requests with the
/// Copilot-specific headers.
pub struct CopilotExecutor {
    /// HTTP wrapper used for all outgoing requests (optionally rate-limit aware).
    ph: ProviderHttp,
    /// When set, used directly as the Copilot API token (skips token exchange).
    api_key: Option<String>,
    /// Optional override for the API base URL; falls back to `DEFAULT_BASE_URL`.
    base_url: Option<String>,
    /// Source of stored GitHub tokens and the account list.
    auth: Arc<AuthManager>,
    /// Cache: GitHub token → short-lived Copilot API token.
    cache: Mutex<HashMap<String, CachedToken>>,
    /// Value sent as the `user-agent` header.
    user_agent: String,
    /// Value sent as the `editor-version` header.
    editor_version: String,
    /// Value sent as the `editor-plugin-version` header.
    plugin_version: String,
}
112
113#[bon::bon]
114impl CopilotExecutor {
115    /// Creates a new Copilot executor.
116    #[builder]
117    pub fn new(
118        http: rquest::Client,
119        auth: Arc<AuthManager>,
120        api_key: Option<String>,
121        base_url: Option<String>,
122        ratelimit: Option<Arc<RateLimitStore>>,
123        user_agent: Option<String>,
124        editor_version: Option<String>,
125        plugin_version: Option<String>,
126    ) -> Self {
127        let mut ph = ProviderHttp::new(http);
128        if let Some(store) = ratelimit {
129            ph = ph.with_ratelimit(store, ProviderId::Copilot);
130        }
131        Self {
132            ph,
133            api_key,
134            base_url,
135            auth,
136            cache: Mutex::new(HashMap::new()),
137            user_agent: user_agent.unwrap_or_else(|| USER_AGENT.to_string()),
138            editor_version: editor_version.unwrap_or_else(|| EDITOR_VERSION.to_string()),
139            plugin_version: plugin_version.unwrap_or_else(|| PLUGIN_VERSION.to_string()),
140        }
141    }
142
143    /// Exchange a GitHub token for a Copilot API token and cache the result.
144    ///
145    /// Returns `(copilot_api_token, api_endpoint)`.
146    async fn exchange_and_cache(&self, github_token: &str) -> Result<(String, String)> {
147        // Check cache first
148        {
149            let cache = self.cache.lock().unwrap();
150            if let Some(cached) = cache.get(github_token)
151                && cached.expires_at > Instant::now()
152            {
153                return Ok((cached.token.clone(), cached.api_endpoint.clone()));
154            }
155        }
156
157        // Exchange GitHub token for Copilot API token
158        let resp = self
159            .ph
160            .client()
161            .get(COPILOT_TOKEN_URL)
162            .header("authorization", format!("token {github_token}"))
163            .header("accept", "application/json")
164            .header("user-agent", self.user_agent.as_str())
165            .header("editor-version", self.editor_version.as_str())
166            .header("editor-plugin-version", self.plugin_version.as_str())
167            .send()
168            .await?;
169
170        let status = resp.status();
171        if !status.is_success() {
172            let text = resp.text().await.unwrap_or_default();
173            return Err(ByokError::Auth(format!(
174                "Copilot token exchange {status}: {text}"
175            )));
176        }
177
178        let json: Value = resp.json().await?;
179
180        let api_token = json
181            .get("token")
182            .and_then(Value::as_str)
183            .ok_or_else(|| ByokError::Auth("missing token in Copilot response".into()))?
184            .to_string();
185
186        let expires_at_unix = json.get("expires_at").and_then(Value::as_i64).unwrap_or(0);
187
188        let ttl = if expires_at_unix > 0 {
189            let now_unix = std::time::SystemTime::now()
190                .duration_since(std::time::UNIX_EPOCH)
191                .unwrap_or_default()
192                .as_secs()
193                .cast_signed();
194            let secs = (expires_at_unix - now_unix).max(0).cast_unsigned();
195            Duration::from_secs(secs)
196        } else {
197            Duration::from_mins(25) // default TTL
198        };
199
200        let default_base = self.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL);
201        let api_endpoint = json
202            .pointer("/endpoints/api")
203            .and_then(Value::as_str)
204            .unwrap_or(default_base)
205            .trim_end_matches('/')
206            .to_string();
207
208        // If "copilot_plan" is absent or not "copilot_free", assume Pro+.
209        let is_pro = json
210            .get("copilot_plan")
211            .and_then(Value::as_str)
212            .is_none_or(|plan| plan != "copilot_free");
213
214        // Cache the new token
215        {
216            let mut cache = self.cache.lock().unwrap();
217            cache.insert(
218                github_token.to_string(),
219                CachedToken {
220                    token: api_token.clone(),
221                    api_endpoint: api_endpoint.clone(),
222                    expires_at: Instant::now() + ttl,
223                    is_pro,
224                },
225            );
226        }
227
228        Ok((api_token, api_endpoint))
229    }
230
231    /// Obtain a Copilot API token for a specific account.
232    async fn copilot_token_for_account(&self, account_id: &str) -> Result<(String, String)> {
233        let github_token = self
234            .auth
235            .get_token_for(&ProviderId::Copilot, account_id)
236            .await?
237            .access_token;
238        self.exchange_and_cache(&github_token).await
239    }
240
241    /// Fetch quota snapshot for a single GitHub account.
242    ///
243    /// Returns `(percent_remaining, unlimited)` on success, `None` on any failure.
244    async fn fetch_quota(&self, github_token: &str) -> Option<(f64, bool)> {
245        let resp = self
246            .ph
247            .client()
248            .get(COPILOT_USER_URL)
249            .header("authorization", format!("token {github_token}"))
250            .header("accept", "application/json")
251            .header("user-agent", self.user_agent.as_str())
252            .send()
253            .await
254            .ok()?;
255
256        if !resp.status().is_success() {
257            return None;
258        }
259
260        let json: Value = resp.json().await.ok()?;
261        let pi = json.pointer("/quota_snapshots/premium_interactions")?;
262        let unlimited = pi
263            .get("unlimited")
264            .and_then(Value::as_bool)
265            .unwrap_or(false);
266        let percent = pi
267            .get("percent_remaining")
268            .and_then(Value::as_f64)
269            .unwrap_or(0.0);
270        Some((percent, unlimited))
271    }
272
273    /// Refresh quota for an account if the cached value is stale or missing.
274    async fn refresh_quota_if_stale(&self, account_id: &str) {
275        // Check if we already have a fresh cache entry.
276        {
277            let tracker = ACCOUNT_TRACKER.lock().unwrap();
278            if let Some(q) = tracker.quotas.get(account_id)
279                && q.fetched_at.elapsed() < QUOTA_CACHE_TTL
280            {
281                return;
282            }
283        }
284
285        // Fetch the GitHub token for this account.
286        let github_token = match self
287            .auth
288            .get_token_for(&ProviderId::Copilot, account_id)
289            .await
290        {
291            Ok(t) => t.access_token,
292            Err(e) => {
293                tracing::warn!(account_id, error = %e, "failed to get token for quota fetch");
294                return;
295            }
296        };
297
298        if let Some((percent, unlimited)) = self.fetch_quota(&github_token).await {
299            tracing::info!(
300                account_id,
301                percent_remaining = percent,
302                unlimited,
303                "fetched copilot quota"
304            );
305            let mut tracker = ACCOUNT_TRACKER.lock().unwrap();
306            tracker.quotas.insert(
307                account_id.to_string(),
308                CachedQuota {
309                    percent_remaining: percent,
310                    unlimited,
311                    fetched_at: Instant::now(),
312                },
313            );
314        } else {
315            tracing::warn!(account_id, "failed to fetch copilot quota, skipping");
316        }
317    }
318
319    /// Select the best account based on cached quota data.
320    ///
321    /// Uses sticky selection: keeps the current account until the rebalance
322    /// interval elapses, then re-compares all accounts' quotas.
323    async fn select_account(&self, accounts: &[AccountInfo]) -> Result<String> {
324        {
325            let tracker = ACCOUNT_TRACKER.lock().unwrap();
326
327            // Sticky: current is still valid and rebalance interval hasn't elapsed.
328            if let Some(ref current) = tracker.current
329                && accounts.iter().any(|a| a.account_id == *current)
330                && tracker
331                    .last_rebalance
332                    .is_some_and(|t| t.elapsed() < REBALANCE_INTERVAL)
333            {
334                return Ok(current.clone());
335            }
336        }
337
338        // Fetch quotas (skips accounts with fresh cache).
339        for account in accounts {
340            self.refresh_quota_if_stale(&account.account_id).await;
341        }
342
343        // Pick the account with the highest remaining quota.
344        let mut tracker = ACCOUNT_TRACKER.lock().unwrap();
345        let best = accounts
346            .iter()
347            .max_by(|a, b| {
348                let qa = tracker.quotas.get(&a.account_id);
349                let qb = tracker.quotas.get(&b.account_id);
350                quota_score(qa)
351                    .partial_cmp(&quota_score(qb))
352                    .unwrap_or(CmpOrdering::Equal)
353            })
354            .ok_or_else(|| ByokError::Auth("no copilot accounts available".into()))?;
355
356        tracing::info!(
357            account_id = %best.account_id,
358            score = quota_score(tracker.quotas.get(&best.account_id)),
359            "selected copilot account"
360        );
361
362        tracker.current = Some(best.account_id.clone());
363        tracker.last_rebalance = Some(Instant::now());
364        Ok(best.account_id.clone())
365    }
366
367    /// Force the next `copilot_token()` call to re-evaluate account selection.
368    ///
369    /// # Panics
370    ///
371    /// Panics if the account tracker mutex is poisoned.
372    pub fn invalidate_current_account() {
373        let mut tracker = ACCOUNT_TRACKER.lock().unwrap();
374        tracker.last_rebalance = None;
375    }
376
377    /// Returns the Copilot API token and base endpoint URL (without path suffix).
378    ///
379    /// When `api_key` is set it is used directly (skip token exchange).
380    /// With multiple accounts, selects the account with the most remaining quota.
381    /// Otherwise falls back to the active account.
382    ///
383    /// # Errors
384    ///
385    /// Returns [`ByokError::Auth`] if the token exchange fails.
386    ///
387    /// # Panics
388    ///
389    /// Panics if the internal token cache mutex is poisoned.
390    pub async fn copilot_token(&self) -> Result<(String, String)> {
391        if let Some(key) = &self.api_key {
392            let base = self
393                .base_url
394                .as_deref()
395                .unwrap_or(DEFAULT_BASE_URL)
396                .trim_end_matches('/')
397                .to_string();
398            return Ok((key.clone(), base));
399        }
400
401        let accounts = self.auth.list_accounts(&ProviderId::Copilot).await?;
402
403        if accounts.len() > 1 {
404            let account_id = self.select_account(&accounts).await?;
405            return self.copilot_token_for_account(&account_id).await;
406        }
407
408        // Single or no account: use active account (original behavior).
409        let github_token = self
410            .auth
411            .get_token(&ProviderId::Copilot)
412            .await?
413            .access_token;
414        self.exchange_and_cache(&github_token).await
415    }
416
    /// Obtains the Copilot API token and base endpoint URL.
    ///
    /// Thin private alias for [`Self::copilot_token`]; returns the same
    /// `(token, endpoint)` pair and propagates the same errors.
    async fn copilot_creds(&self) -> Result<(String, String)> {
        self.copilot_token().await
    }
421
422    /// Builds an [`OpenAICompatProvider`] for a single request, given the resolved
423    /// Copilot API token and base endpoint URL.
424    ///
425    /// Static Copilot-specific headers are placed in `default_headers` so aigw
426    /// includes them in every request it builds. The `x-initiator` header is
427    /// **per-request** and must be added separately after translation.
428    fn build_provider(&self, token: &str, base_url: &str) -> Result<OpenAICompatProvider> {
429        let mut default_headers = BTreeMap::new();
430        default_headers.insert("user-agent".to_owned(), self.user_agent.clone());
431        default_headers.insert("editor-version".to_owned(), self.editor_version.clone());
432        default_headers.insert(
433            "editor-plugin-version".to_owned(),
434            self.plugin_version.clone(),
435        );
436        default_headers.insert("openai-intent".to_owned(), OPENAI_INTENT.to_owned());
437        default_headers.insert(
438            "copilot-integration-id".to_owned(),
439            INTEGRATION_ID.to_owned(),
440        );
441        default_headers.insert(
442            "x-github-api-version".to_owned(),
443            GITHUB_API_VERSION.to_owned(),
444        );
445        default_headers.insert("content-type".to_owned(), "application/json".to_owned());
446
447        OpenAICompatProvider::new(OpenAICompatConfig {
448            name: "copilot".to_owned(),
449            http: HttpTransportConfig {
450                base_url: base_url.to_owned(),
451                timeout_seconds: 600,
452                default_headers,
453            },
454            auth: OpenAIAuthConfig {
455                api_key: SecretString::from(token.to_owned()),
456                organization: None,
457                project: None,
458            },
459            quirks: Quirks::default(),
460        })
461        .map_err(|e| ByokError::Config(e.to_string()))
462    }
463
464    /// Returns `true` if any cached Copilot token belongs to a Pro/Business/Enterprise plan.
465    ///
466    /// With multiple accounts, returns `true` if **any** account is Pro+.
467    /// Defaults to `true` (Pro) if the plan cannot be determined (e.g. no cached token yet
468    /// or the `copilot_plan` field was absent in the token exchange response).
469    ///
470    /// # Panics
471    ///
472    /// Panics if the internal token cache mutex is poisoned.
473    pub async fn is_pro(&self) -> bool {
474        let accounts = self
475            .auth
476            .list_accounts(&ProviderId::Copilot)
477            .await
478            .unwrap_or_default();
479
480        if accounts.len() > 1 {
481            // Check all cached tokens: any Pro → true.
482            let cache = self.cache.lock().unwrap();
483            let now = Instant::now();
484            let mut found_any = false;
485            for cached in cache.values() {
486                if cached.expires_at > now {
487                    found_any = true;
488                    if cached.is_pro {
489                        return true;
490                    }
491                }
492            }
493            // If we found cached tokens but none are Pro, return false.
494            if found_any {
495                return false;
496            }
497            // No cached tokens yet: conservative default.
498            return true;
499        }
500
501        // Single account: original behavior.
502        if let Ok(github_token) = self
503            .auth
504            .get_token(&ProviderId::Copilot)
505            .await
506            .map(|t| t.access_token)
507        {
508            let cache = self.cache.lock().unwrap();
509            if let Some(cached) = cache.get(&github_token)
510                && cached.expires_at > Instant::now()
511            {
512                return cached.is_pro;
513            }
514        }
515        true // conservative default: assume Pro
516    }
517
518    /// Returns the `X-Initiator` header value based on whether the request
519    /// contains any assistant/tool messages (agent) or only user messages.
520    fn initiator(request: &ChatRequest) -> &'static str {
521        let is_agent = request.messages.iter().any(|m| {
522            matches!(
523                m.get("role").and_then(Value::as_str),
524                Some("assistant" | "tool")
525            )
526        });
527        if is_agent { "agent" } else { "user" }
528    }
529}
530
531#[async_trait]
532impl ProviderExecutor for CopilotExecutor {
    /// Sends a chat completion request to Copilot, retrying across accounts.
    ///
    /// The request is converted into aigw's canonical `ChatRequest` once. Each
    /// attempt then resolves credentials, builds an aigw provider/translator
    /// for them and posts the translated request. With more than one account
    /// up to `min(accounts, 3)` attempts are made; otherwise exactly one.
    ///
    /// Streaming responses are returned straight from `send_passthrough`;
    /// non-streaming bodies go through `OpenAIResponseTranslator` and come back
    /// as `ProviderResponse::Complete`.
    ///
    /// # Errors
    ///
    /// Returns `ByokError::Translation` / `ByokError::Config` for translation
    /// and provider construction failures, the send error when it is not
    /// retryable or no attempts remain, and otherwise the last recorded error.
    async fn chat_completion(&self, request: ChatRequest) -> Result<ProviderResponse> {
        let stream = request.stream;
        // `x-initiator` is derived from the request message roles before consuming it.
        let initiator = Self::initiator(&request);

        // Translate: BYOKEY ChatRequest → aigw ChatRequest.
        let aigw_request: aigw_core::model::ChatRequest =
            serde_json::from_value(request.into_body())
                .map_err(|e| ByokError::Translation(e.to_string()))?;

        // A failed account listing counts as "no accounts", i.e. a single attempt.
        let accounts = self
            .auth
            .list_accounts(&ProviderId::Copilot)
            .await
            .unwrap_or_default();
        // Retries only make sense when there is another account to fall back to.
        let max_attempts = if accounts.len() > 1 {
            accounts.len().min(3)
        } else {
            1
        };

        let mut last_err = None;
        for attempt in 0..max_attempts {
            // Credentials are resolved per attempt so that a re-evaluated
            // account selection takes effect.
            let creds = self.copilot_creds().await;
            let (token, endpoint) = match creds {
                Ok(c) => c,
                Err(e) => {
                    if max_attempts > 1 {
                        tracing::warn!(attempt, error = %e, "copilot creds failed, trying next account");
                        Self::invalidate_current_account();
                        last_err = Some(e);
                        continue;
                    }
                    return Err(e);
                }
            };

            // Build aigw provider + translator for this token/endpoint combination.
            let provider = match self.build_provider(&token, &endpoint) {
                Ok(p) => p,
                Err(e) => return Err(e),
            };
            let translator = OpenAICompatRequestTranslator::new(&provider)
                .map_err(|e| ByokError::Config(e.to_string()))?;

            // Translate the canonical request to a Copilot HTTP request.
            // aigw handles: URL (`{endpoint}/chat/completions`), static headers,
            // `Authorization: Bearer <token>`, content-type, and body serialization.
            let translated = if stream {
                translator.translate_stream_request(&aigw_request)
            } else {
                translator.translate_request(&aigw_request)
            }
            .map_err(|e| ByokError::Translation(e.to_string()))?;

            // Build rquest from aigw's translated URL and headers.
            let mut builder = self.ph.client().post(&translated.url);
            for (name, value) in &translated.headers {
                // Header values that are not valid visible ASCII are skipped.
                if let Ok(v) = value.to_str() {
                    builder = builder.header(name.as_str(), v);
                }
            }
            // x-initiator is per-request (depends on message roles) so aigw can't
            // include it in default_headers. Append it manually after translation.
            builder = builder.header("x-initiator", initiator);
            // Prevent compressed SSE streams from breaking the line scanner.
            builder = builder.header("accept-encoding", "identity");
            // Attach the translated body (already serialized JSON bytes by aigw).
            let builder = builder.body(translated.body.to_vec());

            if stream {
                // Option P: raw byte passthrough — stream Copilot SSE bytes to caller
                // unchanged. aigw is used only for URL/header/body building.
                match self.ph.send_passthrough(builder, true).await {
                    Ok(resp) => return Ok(resp),
                    Err(e) => {
                        // Give up on non-retryable errors or when this was the last attempt.
                        if !e.is_retryable() || attempt + 1 >= max_attempts {
                            return Err(e);
                        }
                        tracing::warn!(attempt, error = %e, "copilot stream request failed, trying next account");
                        Self::invalidate_current_account();
                        last_err = Some(e);
                        // Falls through to the end of the loop body → next attempt.
                    }
                }
            } else {
                // Non-streaming: use aigw's `OpenAIResponseTranslator`.
                let resp = match self.ph.send(builder).await {
                    Ok(r) => r,
                    Err(e) => {
                        // Give up on non-retryable errors or when this was the last attempt.
                        if !e.is_retryable() || attempt + 1 >= max_attempts {
                            return Err(e);
                        }
                        tracing::warn!(attempt, error = %e, "copilot request failed, trying next account");
                        Self::invalidate_current_account();
                        last_err = Some(e);
                        continue;
                    }
                };
                let resp_bytes = resp.bytes().await.map_err(ByokError::from)?;
                // NOTE(review): the status is hard-coded to `OK`; presumably
                // `ProviderHttp::send` only yields `Ok` for successful
                // responses — confirm.
                let aigw_response = OpenAIResponseTranslator
                    .translate_response(http::StatusCode::OK, &resp_bytes)
                    .map_err(|e: aigw_core::error::TranslateError| {
                        ByokError::Translation(e.to_string())
                    })?;
                let value = serde_json::to_value(aigw_response)
                    .map_err(|e| ByokError::Translation(e.to_string()))?;
                return Ok(ProviderResponse::Complete(value));
            }
        }

        // Only reached when every attempt ended in a retryable failure.
        tracing::error!(
            attempts = max_attempts,
            "all copilot accounts exhausted for chat request"
        );
        Err(last_err.unwrap_or_else(|| ByokError::Auth("no copilot accounts available".into())))
    }
649
    /// Lists the model ids the registry associates with `ProviderId::Copilot`.
    fn supported_models(&self) -> Vec<String> {
        registry::models_for_provider(&ProviderId::Copilot)
    }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor wired to the shared test HTTP client and auth manager.
    fn make_executor() -> CopilotExecutor {
        let (client, auth) = crate::http_util::test_auth();
        CopilotExecutor::builder().http(client).auth(auth).build()
    }

    /// Deserializes a `gpt-4o` chat request carrying the given messages array.
    fn chat_request(messages: Value) -> ChatRequest {
        let body = serde_json::json!({
            "model": "gpt-4o",
            "messages": messages
        });
        serde_json::from_value(body).unwrap()
    }

    /// The registry must expose at least one Copilot model.
    #[test]
    fn test_supported_models_non_empty() {
        let models = make_executor().supported_models();
        assert!(!models.is_empty());
    }

    /// A conversation with only user messages is user-initiated.
    #[test]
    fn test_initiator_user() {
        let req = chat_request(serde_json::json!([
            {"role": "user", "content": "hi"}
        ]));
        assert_eq!(CopilotExecutor::initiator(&req), "user");
    }

    /// Any assistant message marks the request as agent-initiated.
    #[test]
    fn test_initiator_agent() {
        let req = chat_request(serde_json::json!([
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"}
        ]));
        assert_eq!(CopilotExecutor::initiator(&req), "agent");
    }
}