1use crate::http_util::ProviderHttp;
7use crate::registry;
8use aigw::openai::translate::OpenAIResponseTranslator;
9use aigw::openai::{HttpTransportConfig, OpenAIAuthConfig};
10use aigw::openai_compat::translate::OpenAICompatRequestTranslator;
11use aigw::openai_compat::{OpenAICompatConfig, OpenAICompatProvider, Quirks};
12use aigw_core::translate::{RequestTranslator as _, ResponseTranslator as _};
13use async_trait::async_trait;
14use byokey_auth::AuthManager;
15use byokey_types::{
16 AccountInfo, ByokError, ChatRequest, ProviderId, RateLimitStore,
17 traits::{ProviderExecutor, ProviderResponse, Result},
18};
19use secrecy::SecretString;
20use serde_json::Value;
21use std::{
22 cmp::Ordering as CmpOrdering,
23 collections::{BTreeMap, HashMap},
24 sync::{Arc, LazyLock, Mutex},
25 time::{Duration, Instant},
26};
27
/// Snapshot of one account's Copilot premium-interaction quota, as last fetched.
struct CachedQuota {
    /// `percent_remaining` as reported by the quota endpoint
    /// (presumably on a 0–100 scale — confirm against the API).
    percent_remaining: f64,
    /// `unlimited` flag as reported by the quota endpoint.
    unlimited: bool,
    /// Moment this snapshot was stored; compared against `QUOTA_CACHE_TTL`.
    fetched_at: Instant,
}
34
35struct AccountTracker {
37 current: Option<String>,
39 last_rebalance: Option<Instant>,
41 quotas: HashMap<String, CachedQuota>,
43}
44
45static ACCOUNT_TRACKER: LazyLock<Mutex<AccountTracker>> = LazyLock::new(|| {
47 Mutex::new(AccountTracker {
48 current: None,
49 last_rebalance: None,
50 quotas: HashMap::new(),
51 })
52});
53
/// How long a selected account keeps being reused before quotas are
/// compared again (five minutes).
// NOTE(review): the lint is silenced without a stated reason — presumably to
// keep the explicit seconds form in a `const`; confirm and document.
#[allow(clippy::duration_suboptimal_units)]
const REBALANCE_INTERVAL: Duration = Duration::from_secs(5 * 60);

/// Maximum age of a cached quota snapshot before it is fetched again
/// (five minutes).
// NOTE(review): same unexplained lint suppression as above — confirm.
#[allow(clippy::duration_suboptimal_units)]
const QUOTA_CACHE_TTL: Duration = Duration::from_secs(5 * 60);
62
/// Chat API base URL used when neither the token response nor the
/// configuration supplies one.
const DEFAULT_BASE_URL: &str = "https://api.githubcopilot.com";

/// Endpoint that exchanges a GitHub token for a Copilot API token.
const COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";

/// Endpoint queried for an account's quota snapshots.
const COPILOT_USER_URL: &str = "https://api.github.com/copilot_internal/user";
71
/// Default `user-agent` header value when the builder supplies none.
const USER_AGENT: &str = "GitHubCopilotChat/0.35.0";
/// Default `editor-version` header value when the builder supplies none.
const EDITOR_VERSION: &str = "vscode/1.107.0";
/// Default `editor-plugin-version` header value when the builder supplies none.
const PLUGIN_VERSION: &str = "copilot-chat/0.35.0";
/// Value sent as the `copilot-integration-id` header.
const INTEGRATION_ID: &str = "vscode-chat";
/// Value sent as the `openai-intent` header.
const OPENAI_INTENT: &str = "conversation-panel";
/// Value sent as the `x-github-api-version` header.
const GITHUB_API_VERSION: &str = "2025-04-01";
79
/// Result of one Copilot token exchange, kept until it expires.
struct CachedToken {
    /// Copilot API token taken from the exchange response's `token` field.
    token: String,
    /// API base URL to use with this token, without a trailing slash.
    api_endpoint: String,
    /// Local instant after which this entry is no longer served.
    expires_at: Instant,
    /// `false` only when the response's `copilot_plan` was `"copilot_free"`.
    is_pro: bool,
}
88
89fn quota_score(q: Option<&CachedQuota>) -> f64 {
93 match q {
94 Some(q) if q.unlimited => 100.0,
95 Some(q) => q.percent_remaining,
96 None => 50.0,
97 }
98}
99
100pub struct CopilotExecutor {
102 ph: ProviderHttp,
103 api_key: Option<String>,
104 base_url: Option<String>,
105 auth: Arc<AuthManager>,
106 cache: Mutex<HashMap<String, CachedToken>>,
108 user_agent: String,
109 editor_version: String,
110 plugin_version: String,
111}
112
113#[bon::bon]
114impl CopilotExecutor {
115 #[builder]
117 pub fn new(
118 http: rquest::Client,
119 auth: Arc<AuthManager>,
120 api_key: Option<String>,
121 base_url: Option<String>,
122 ratelimit: Option<Arc<RateLimitStore>>,
123 user_agent: Option<String>,
124 editor_version: Option<String>,
125 plugin_version: Option<String>,
126 ) -> Self {
127 let mut ph = ProviderHttp::new(http);
128 if let Some(store) = ratelimit {
129 ph = ph.with_ratelimit(store, ProviderId::Copilot);
130 }
131 Self {
132 ph,
133 api_key,
134 base_url,
135 auth,
136 cache: Mutex::new(HashMap::new()),
137 user_agent: user_agent.unwrap_or_else(|| USER_AGENT.to_string()),
138 editor_version: editor_version.unwrap_or_else(|| EDITOR_VERSION.to_string()),
139 plugin_version: plugin_version.unwrap_or_else(|| PLUGIN_VERSION.to_string()),
140 }
141 }
142
143 async fn exchange_and_cache(&self, github_token: &str) -> Result<(String, String)> {
147 {
149 let cache = self.cache.lock().unwrap();
150 if let Some(cached) = cache.get(github_token)
151 && cached.expires_at > Instant::now()
152 {
153 return Ok((cached.token.clone(), cached.api_endpoint.clone()));
154 }
155 }
156
157 let resp = self
159 .ph
160 .client()
161 .get(COPILOT_TOKEN_URL)
162 .header("authorization", format!("token {github_token}"))
163 .header("accept", "application/json")
164 .header("user-agent", self.user_agent.as_str())
165 .header("editor-version", self.editor_version.as_str())
166 .header("editor-plugin-version", self.plugin_version.as_str())
167 .send()
168 .await?;
169
170 let status = resp.status();
171 if !status.is_success() {
172 let text = resp.text().await.unwrap_or_default();
173 return Err(ByokError::Auth(format!(
174 "Copilot token exchange {status}: {text}"
175 )));
176 }
177
178 let json: Value = resp.json().await?;
179
180 let api_token = json
181 .get("token")
182 .and_then(Value::as_str)
183 .ok_or_else(|| ByokError::Auth("missing token in Copilot response".into()))?
184 .to_string();
185
186 let expires_at_unix = json.get("expires_at").and_then(Value::as_i64).unwrap_or(0);
187
188 let ttl = if expires_at_unix > 0 {
189 let now_unix = std::time::SystemTime::now()
190 .duration_since(std::time::UNIX_EPOCH)
191 .unwrap_or_default()
192 .as_secs()
193 .cast_signed();
194 let secs = (expires_at_unix - now_unix).max(0).cast_unsigned();
195 Duration::from_secs(secs)
196 } else {
197 Duration::from_mins(25) };
199
200 let default_base = self.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL);
201 let api_endpoint = json
202 .pointer("/endpoints/api")
203 .and_then(Value::as_str)
204 .unwrap_or(default_base)
205 .trim_end_matches('/')
206 .to_string();
207
208 let is_pro = json
210 .get("copilot_plan")
211 .and_then(Value::as_str)
212 .is_none_or(|plan| plan != "copilot_free");
213
214 {
216 let mut cache = self.cache.lock().unwrap();
217 cache.insert(
218 github_token.to_string(),
219 CachedToken {
220 token: api_token.clone(),
221 api_endpoint: api_endpoint.clone(),
222 expires_at: Instant::now() + ttl,
223 is_pro,
224 },
225 );
226 }
227
228 Ok((api_token, api_endpoint))
229 }
230
231 async fn copilot_token_for_account(&self, account_id: &str) -> Result<(String, String)> {
233 let github_token = self
234 .auth
235 .get_token_for(&ProviderId::Copilot, account_id)
236 .await?
237 .access_token;
238 self.exchange_and_cache(&github_token).await
239 }
240
241 async fn fetch_quota(&self, github_token: &str) -> Option<(f64, bool)> {
245 let resp = self
246 .ph
247 .client()
248 .get(COPILOT_USER_URL)
249 .header("authorization", format!("token {github_token}"))
250 .header("accept", "application/json")
251 .header("user-agent", self.user_agent.as_str())
252 .send()
253 .await
254 .ok()?;
255
256 if !resp.status().is_success() {
257 return None;
258 }
259
260 let json: Value = resp.json().await.ok()?;
261 let pi = json.pointer("/quota_snapshots/premium_interactions")?;
262 let unlimited = pi
263 .get("unlimited")
264 .and_then(Value::as_bool)
265 .unwrap_or(false);
266 let percent = pi
267 .get("percent_remaining")
268 .and_then(Value::as_f64)
269 .unwrap_or(0.0);
270 Some((percent, unlimited))
271 }
272
273 async fn refresh_quota_if_stale(&self, account_id: &str) {
275 {
277 let tracker = ACCOUNT_TRACKER.lock().unwrap();
278 if let Some(q) = tracker.quotas.get(account_id)
279 && q.fetched_at.elapsed() < QUOTA_CACHE_TTL
280 {
281 return;
282 }
283 }
284
285 let github_token = match self
287 .auth
288 .get_token_for(&ProviderId::Copilot, account_id)
289 .await
290 {
291 Ok(t) => t.access_token,
292 Err(e) => {
293 tracing::warn!(account_id, error = %e, "failed to get token for quota fetch");
294 return;
295 }
296 };
297
298 if let Some((percent, unlimited)) = self.fetch_quota(&github_token).await {
299 tracing::info!(
300 account_id,
301 percent_remaining = percent,
302 unlimited,
303 "fetched copilot quota"
304 );
305 let mut tracker = ACCOUNT_TRACKER.lock().unwrap();
306 tracker.quotas.insert(
307 account_id.to_string(),
308 CachedQuota {
309 percent_remaining: percent,
310 unlimited,
311 fetched_at: Instant::now(),
312 },
313 );
314 } else {
315 tracing::warn!(account_id, "failed to fetch copilot quota, skipping");
316 }
317 }
318
319 async fn select_account(&self, accounts: &[AccountInfo]) -> Result<String> {
324 {
325 let tracker = ACCOUNT_TRACKER.lock().unwrap();
326
327 if let Some(ref current) = tracker.current
329 && accounts.iter().any(|a| a.account_id == *current)
330 && tracker
331 .last_rebalance
332 .is_some_and(|t| t.elapsed() < REBALANCE_INTERVAL)
333 {
334 return Ok(current.clone());
335 }
336 }
337
338 for account in accounts {
340 self.refresh_quota_if_stale(&account.account_id).await;
341 }
342
343 let mut tracker = ACCOUNT_TRACKER.lock().unwrap();
345 let best = accounts
346 .iter()
347 .max_by(|a, b| {
348 let qa = tracker.quotas.get(&a.account_id);
349 let qb = tracker.quotas.get(&b.account_id);
350 quota_score(qa)
351 .partial_cmp("a_score(qb))
352 .unwrap_or(CmpOrdering::Equal)
353 })
354 .ok_or_else(|| ByokError::Auth("no copilot accounts available".into()))?;
355
356 tracing::info!(
357 account_id = %best.account_id,
358 score = quota_score(tracker.quotas.get(&best.account_id)),
359 "selected copilot account"
360 );
361
362 tracker.current = Some(best.account_id.clone());
363 tracker.last_rebalance = Some(Instant::now());
364 Ok(best.account_id.clone())
365 }
366
367 pub fn invalidate_current_account() {
373 let mut tracker = ACCOUNT_TRACKER.lock().unwrap();
374 tracker.last_rebalance = None;
375 }
376
377 pub async fn copilot_token(&self) -> Result<(String, String)> {
391 if let Some(key) = &self.api_key {
392 let base = self
393 .base_url
394 .as_deref()
395 .unwrap_or(DEFAULT_BASE_URL)
396 .trim_end_matches('/')
397 .to_string();
398 return Ok((key.clone(), base));
399 }
400
401 let accounts = self.auth.list_accounts(&ProviderId::Copilot).await?;
402
403 if accounts.len() > 1 {
404 let account_id = self.select_account(&accounts).await?;
405 return self.copilot_token_for_account(&account_id).await;
406 }
407
408 let github_token = self
410 .auth
411 .get_token(&ProviderId::Copilot)
412 .await?
413 .access_token;
414 self.exchange_and_cache(&github_token).await
415 }
416
417 async fn copilot_creds(&self) -> Result<(String, String)> {
419 self.copilot_token().await
420 }
421
422 fn build_provider(&self, token: &str, base_url: &str) -> Result<OpenAICompatProvider> {
429 let mut default_headers = BTreeMap::new();
430 default_headers.insert("user-agent".to_owned(), self.user_agent.clone());
431 default_headers.insert("editor-version".to_owned(), self.editor_version.clone());
432 default_headers.insert(
433 "editor-plugin-version".to_owned(),
434 self.plugin_version.clone(),
435 );
436 default_headers.insert("openai-intent".to_owned(), OPENAI_INTENT.to_owned());
437 default_headers.insert(
438 "copilot-integration-id".to_owned(),
439 INTEGRATION_ID.to_owned(),
440 );
441 default_headers.insert(
442 "x-github-api-version".to_owned(),
443 GITHUB_API_VERSION.to_owned(),
444 );
445 default_headers.insert("content-type".to_owned(), "application/json".to_owned());
446
447 OpenAICompatProvider::new(OpenAICompatConfig {
448 name: "copilot".to_owned(),
449 http: HttpTransportConfig {
450 base_url: base_url.to_owned(),
451 timeout_seconds: 600,
452 default_headers,
453 },
454 auth: OpenAIAuthConfig {
455 api_key: SecretString::from(token.to_owned()),
456 organization: None,
457 project: None,
458 },
459 quirks: Quirks::default(),
460 })
461 .map_err(|e| ByokError::Config(e.to_string()))
462 }
463
464 pub async fn is_pro(&self) -> bool {
474 let accounts = self
475 .auth
476 .list_accounts(&ProviderId::Copilot)
477 .await
478 .unwrap_or_default();
479
480 if accounts.len() > 1 {
481 let cache = self.cache.lock().unwrap();
483 let now = Instant::now();
484 let mut found_any = false;
485 for cached in cache.values() {
486 if cached.expires_at > now {
487 found_any = true;
488 if cached.is_pro {
489 return true;
490 }
491 }
492 }
493 if found_any {
495 return false;
496 }
497 return true;
499 }
500
501 if let Ok(github_token) = self
503 .auth
504 .get_token(&ProviderId::Copilot)
505 .await
506 .map(|t| t.access_token)
507 {
508 let cache = self.cache.lock().unwrap();
509 if let Some(cached) = cache.get(&github_token)
510 && cached.expires_at > Instant::now()
511 {
512 return cached.is_pro;
513 }
514 }
515 true }
517
518 fn initiator(request: &ChatRequest) -> &'static str {
521 let is_agent = request.messages.iter().any(|m| {
522 matches!(
523 m.get("role").and_then(Value::as_str),
524 Some("assistant" | "tool")
525 )
526 });
527 if is_agent { "agent" } else { "user" }
528 }
529}
530
531#[async_trait]
532impl ProviderExecutor for CopilotExecutor {
533 async fn chat_completion(&self, request: ChatRequest) -> Result<ProviderResponse> {
534 let stream = request.stream;
535 let initiator = Self::initiator(&request);
537
538 let aigw_request: aigw_core::model::ChatRequest =
540 serde_json::from_value(request.into_body())
541 .map_err(|e| ByokError::Translation(e.to_string()))?;
542
543 let accounts = self
544 .auth
545 .list_accounts(&ProviderId::Copilot)
546 .await
547 .unwrap_or_default();
548 let max_attempts = if accounts.len() > 1 {
549 accounts.len().min(3)
550 } else {
551 1
552 };
553
554 let mut last_err = None;
555 for attempt in 0..max_attempts {
556 let creds = self.copilot_creds().await;
557 let (token, endpoint) = match creds {
558 Ok(c) => c,
559 Err(e) => {
560 if max_attempts > 1 {
561 tracing::warn!(attempt, error = %e, "copilot creds failed, trying next account");
562 Self::invalidate_current_account();
563 last_err = Some(e);
564 continue;
565 }
566 return Err(e);
567 }
568 };
569
570 let provider = match self.build_provider(&token, &endpoint) {
572 Ok(p) => p,
573 Err(e) => return Err(e),
574 };
575 let translator = OpenAICompatRequestTranslator::new(&provider)
576 .map_err(|e| ByokError::Config(e.to_string()))?;
577
578 let translated = if stream {
582 translator.translate_stream_request(&aigw_request)
583 } else {
584 translator.translate_request(&aigw_request)
585 }
586 .map_err(|e| ByokError::Translation(e.to_string()))?;
587
588 let mut builder = self.ph.client().post(&translated.url);
590 for (name, value) in &translated.headers {
591 if let Ok(v) = value.to_str() {
592 builder = builder.header(name.as_str(), v);
593 }
594 }
595 builder = builder.header("x-initiator", initiator);
598 builder = builder.header("accept-encoding", "identity");
600 let builder = builder.body(translated.body.to_vec());
602
603 if stream {
604 match self.ph.send_passthrough(builder, true).await {
607 Ok(resp) => return Ok(resp),
608 Err(e) => {
609 if !e.is_retryable() || attempt + 1 >= max_attempts {
610 return Err(e);
611 }
612 tracing::warn!(attempt, error = %e, "copilot stream request failed, trying next account");
613 Self::invalidate_current_account();
614 last_err = Some(e);
615 }
616 }
617 } else {
618 let resp = match self.ph.send(builder).await {
620 Ok(r) => r,
621 Err(e) => {
622 if !e.is_retryable() || attempt + 1 >= max_attempts {
623 return Err(e);
624 }
625 tracing::warn!(attempt, error = %e, "copilot request failed, trying next account");
626 Self::invalidate_current_account();
627 last_err = Some(e);
628 continue;
629 }
630 };
631 let resp_bytes = resp.bytes().await.map_err(ByokError::from)?;
632 let aigw_response = OpenAIResponseTranslator
633 .translate_response(http::StatusCode::OK, &resp_bytes)
634 .map_err(|e: aigw_core::error::TranslateError| {
635 ByokError::Translation(e.to_string())
636 })?;
637 let value = serde_json::to_value(aigw_response)
638 .map_err(|e| ByokError::Translation(e.to_string()))?;
639 return Ok(ProviderResponse::Complete(value));
640 }
641 }
642
643 tracing::error!(
644 attempts = max_attempts,
645 "all copilot accounts exhausted for chat request"
646 );
647 Err(last_err.unwrap_or_else(|| ByokError::Auth("no copilot accounts available".into())))
648 }
649
650 fn supported_models(&self) -> Vec<String> {
651 registry::models_for_provider(&ProviderId::Copilot)
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor wired to the shared test client and auth manager.
    fn make_executor() -> CopilotExecutor {
        let (client, auth) = crate::http_util::test_auth();
        CopilotExecutor::builder().http(client).auth(auth).build()
    }

    /// Builds a `gpt-4o` chat request carrying the given `messages` array.
    fn request_with(messages: Value) -> ChatRequest {
        let body = serde_json::json!({
            "model": "gpt-4o",
            "messages": messages
        });
        serde_json::from_value(body).unwrap()
    }

    #[test]
    fn test_supported_models_non_empty() {
        let models = make_executor().supported_models();
        assert!(!models.is_empty());
    }

    #[test]
    fn test_initiator_user() {
        // Only a user turn: the request originates from the user.
        let req = request_with(serde_json::json!([
            {"role": "user", "content": "hi"}
        ]));
        assert_eq!(CopilotExecutor::initiator(&req), "user");
    }

    #[test]
    fn test_initiator_agent() {
        // An assistant turn in the history marks the request as agent-driven.
        let req = request_with(serde_json::json!([
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"}
        ]));
        assert_eq!(CopilotExecutor::initiator(&req), "agent");
    }
}