bamboo_agent/agent/llm/providers/copilot/auth/handler.rs
1//! GitHub Copilot authentication handler.
2//!
3//! This module provides authentication handling for GitHub Copilot,
4//! including device code flow, token caching, and automatic refresh.
5//!
6//! # Authentication Flow
7//!
8//! The authentication process follows GitHub's OAuth device flow:
9//!
10//! 1. **Device Code Request**: The client requests a device code from GitHub
11//! 2. **User Authorization**: The user visits the verification URL and enters the code
12//! 3. **Token Polling**: The client polls GitHub for an access token
13//! 4. **Copilot Token Exchange**: The access token is exchanged for a Copilot API token
14//! 5. **Token Caching**: Tokens are cached locally for future use
15//!
16//! # Example Usage
17//!
18//! ```rust,ignore
19//! use std::sync::Arc;
20//! use reqwest_middleware::ClientWithMiddleware;
21//! use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
22//!
23//! async fn authenticate() -> anyhow::Result<String> {
24//! // Create HTTP client with middleware
25//! let client = Arc::new(ClientWithMiddleware::new(/* ... */));
26//!
27//! // Create auth handler with data directory
28//! let handler = CopilotAuthHandler::new(
29//! client,
30//! std::path::PathBuf::from("~/.bamboo"),
31//! false, // Set to true for CLI mode
32//! );
33//!
34//! // Get token (will trigger device flow if needed)
35//! let token = handler.get_token().await?;
36//! Ok(token)
37//! }
38//! ```
39//!
40//! # Token Caching Strategy
41//!
42//! The handler implements a multi-level token caching strategy:
43//!
44//! 1. **Copilot Token Cache**: Checks `.copilot_token.json` for valid tokens
45//! 2. **Environment Variable**: Falls back to `COPILOT_API_KEY` if set
46//! 3. **Access Token Cache**: Uses cached GitHub access token to request new Copilot token
47//! 4. **Interactive Flow**: Only triggers device flow if all silent methods fail
48//!
49//! # Token Validation
50//!
51//! Tokens are validated with a 60-second buffer to ensure they don't expire
52//! during use. This proactive refresh ensures seamless operation.
53
54use crate::agent::llm::ProxyAuthRequiredError;
55use anyhow::anyhow;
56use lazy_static::lazy_static;
57use log::error;
58use reqwest::StatusCode;
59use reqwest_middleware::ClientWithMiddleware;
60use serde::{Deserialize, Serialize};
61use std::{
62 fs::{read_to_string, File},
63 io::Write,
64 path::PathBuf,
65 sync::Arc,
66 time::{Duration, SystemTime, UNIX_EPOCH},
67};
68use tokio::sync::Mutex;
69use tokio::time::sleep;
70
71use super::device_code::DeviceCodeResponse;
72
73/// Copilot API configuration returned from GitHub.
74///
75/// Contains the authentication token, feature flags, and endpoint URLs
76/// for the Copilot service.
77///
78/// This configuration is obtained by exchanging a GitHub access token
79/// for a Copilot-specific token via the `/copilot_internal/v2/token` endpoint.
80///
81/// # Fields
82///
83/// - `token`: The Copilot API token used for authentication
84/// - `expires_at`: Unix timestamp when the token expires
85/// - `refresh_in`: Suggested refresh interval in seconds
86/// - `endpoints`: API endpoints for Copilot services
87/// - Various feature flags controlling available functionality
88#[derive(Debug, Serialize, Deserialize, Clone)]
89pub struct CopilotConfig {
90 /// The Copilot API authentication token
91 pub token: String,
92 /// Whether code annotations are enabled for this account
93 pub annotations_enabled: bool,
94 /// Whether Copilot Chat is enabled for this account
95 pub chat_enabled: bool,
96 /// Whether Copilot Chat is enabled for JetBrains IDEs
97 pub chat_jetbrains_enabled: bool,
98 /// Whether code quote feature is enabled
99 pub code_quote_enabled: bool,
100 /// Whether code review feature is enabled
101 pub code_review_enabled: bool,
102 /// Whether code search is enabled
103 pub codesearch: bool,
104 /// Whether .copilotignore file support is enabled
105 pub copilotignore_enabled: bool,
106 /// API endpoints for various Copilot services
107 pub endpoints: Endpoints,
108 /// Unix timestamp when this token expires
109 pub expires_at: u64,
110 /// Whether this is an individual (vs enterprise) account
111 pub individual: bool,
112 /// User quota limit identifier, if applicable
113 pub limited_user_quotas: Option<String>,
114 /// Date when user quota resets, if applicable
115 pub limited_user_reset_date: Option<String>,
116 /// Whether 8k context prompts are enabled
117 pub prompt_8k: bool,
118 /// Public suggestions mode ("enabled", "disabled", etc.)
119 pub public_suggestions: String,
120 /// Recommended interval in seconds before refreshing token
121 pub refresh_in: u64,
122 /// Service SKU identifier
123 pub sku: String,
124 /// Internal load testing flag
125 pub snippy_load_test_enabled: bool,
126 /// Telemetry mode ("enabled", "disabled", etc.)
127 pub telemetry: String,
128 /// Unique tracking identifier for analytics
129 pub tracking_id: String,
130 /// Whether VS Code Electron fetcher v2 is enabled
131 pub vsc_electron_fetcher_v2: bool,
132 /// Whether Xcode support is enabled
133 pub xcode: bool,
134 /// Whether Xcode Chat is enabled
135 pub xcode_chat: bool,
136}
137
138#[cfg(test)]
139mod tests {
140 use super::*;
141 use tempfile::tempdir;
142
143 /// Creates a test HTTP client without proxy for unit tests.
144 fn test_http_client() -> Arc<ClientWithMiddleware> {
145 use reqwest::Client as ReqwestClient;
146 use reqwest_middleware::ClientBuilder;
147 let client = ReqwestClient::builder().no_proxy().build().expect("client");
148 Arc::new(ClientBuilder::new(client).build())
149 }
150
151 /// Creates a sample CopilotConfig for testing with specified expiration time.
152 fn sample_config(expires_at: u64) -> CopilotConfig {
153 CopilotConfig {
154 token: "cached-token".to_string(),
155 annotations_enabled: false,
156 chat_enabled: true,
157 chat_jetbrains_enabled: false,
158 code_quote_enabled: false,
159 code_review_enabled: false,
160 codesearch: false,
161 copilotignore_enabled: false,
162 endpoints: Endpoints {
163 api: Some("https://api.example.com".to_string()),
164 origin_tracker: None,
165 proxy: None,
166 telemetry: None,
167 },
168 expires_at,
169 individual: true,
170 limited_user_quotas: None,
171 limited_user_reset_date: None,
172 prompt_8k: false,
173 public_suggestions: "disabled".to_string(),
174 refresh_in: 300,
175 sku: "test".to_string(),
176 snippy_load_test_enabled: false,
177 telemetry: "disabled".to_string(),
178 tracking_id: "test".to_string(),
179 vsc_electron_fetcher_v2: false,
180 xcode: false,
181 xcode_chat: false,
182 }
183 }
184
185 /// Tests that read_access_token properly trims whitespace and newlines.
186 #[test]
187 fn read_access_token_trims() {
188 let dir = tempdir().expect("tempdir");
189 let token_path = dir.path().join(".token");
190 std::fs::write(&token_path, " token-value \n").expect("write token");
191
192 let token = CopilotAuthHandler::read_access_token(&token_path);
193 assert_eq!(token.as_deref(), Some("token-value"));
194 }
195
196 /// Tests that CopilotConfig can be written to and read from cache.
197 #[test]
198 fn cached_copilot_config_round_trip() {
199 let dir = tempdir().expect("tempdir");
200 let handler = CopilotAuthHandler::new(test_http_client(), dir.path().to_path_buf(), false);
201 let token_path = dir.path().join(".copilot_token.json");
202 let config = sample_config(1234567890);
203
204 handler
205 .write_cached_copilot_config(&token_path, &config)
206 .expect("write cache");
207 let loaded = handler
208 .read_cached_copilot_config(&token_path)
209 .expect("read cache");
210
211 assert_eq!(loaded.token, config.token);
212 assert_eq!(loaded.expires_at, config.expires_at);
213 }
214
215 /// Tests that token validation uses a 60-second buffer.
216 ///
217 /// Tokens expiring within 60 seconds should be considered invalid
218 /// to ensure proactive refresh.
219 #[test]
220 fn copilot_token_expiry_buffer() {
221 let dir = tempdir().expect("tempdir");
222 let handler = CopilotAuthHandler::new(test_http_client(), dir.path().to_path_buf(), false);
223 let now = SystemTime::now()
224 .duration_since(UNIX_EPOCH)
225 .map(|duration| duration.as_secs())
226 .unwrap_or(0);
227
228 let valid = sample_config(now + 120);
229 let stale = sample_config(now + 30);
230
231 assert!(handler.is_copilot_token_valid(&valid));
232 assert!(!handler.is_copilot_token_valid(&stale));
233 }
234
235 #[test]
236 fn access_token_should_only_be_discarded_on_auth_errors() {
237 let err_401 =
238 anyhow::Error::msg("Copilot token request failed: HTTP 401 - bad credentials");
239 assert!(CopilotAuthHandler::should_discard_access_token(&err_401));
240
241 let err_403 = anyhow::Error::msg("Copilot token request failed: HTTP 403 - forbidden");
242 assert!(CopilotAuthHandler::should_discard_access_token(&err_403));
243
244 let err_407 = anyhow::Error::new(ProxyAuthRequiredError);
245 assert!(!CopilotAuthHandler::should_discard_access_token(&err_407));
246
247 let err_503 =
248 anyhow::Error::msg("Copilot token request failed: HTTP 503 - service unavailable");
249 assert!(!CopilotAuthHandler::should_discard_access_token(&err_503));
250 }
251}
252
253/// API endpoint configuration for Copilot services.
254///
255/// Contains URLs for various Copilot API endpoints returned during
256/// the token exchange process.
257///
258/// # Fields
259///
260/// - `api`: Primary API endpoint for Copilot requests
261/// - `origin_tracker`: Endpoint for tracking request origins
262/// - `proxy`: Proxy endpoint for proxied requests
263/// - `telemetry`: Endpoint for sending telemetry data
264#[derive(Debug, Serialize, Deserialize, Clone)]
265pub struct Endpoints {
266 /// Primary Copilot API endpoint (e.g., "https://api.githubcopilot.com")
267 pub api: Option<String>,
268 /// Origin tracking service endpoint
269 pub origin_tracker: Option<String>,
270 /// Proxy service endpoint for proxied API calls
271 pub proxy: Option<String>,
272 /// Telemetry collection endpoint
273 pub telemetry: Option<String>,
274}
275
276/// Access token response from GitHub OAuth.
277///
278/// Contains the access token or error information from the OAuth device flow.
279/// This is the response from GitHub's `/login/oauth/access_token` endpoint
280/// when polling for authorization completion.
281///
282/// # Fields
283///
284/// - `access_token`: The OAuth access token on successful authorization
285/// - `token_type`: Token type (typically "bearer")
286/// - `scope`: OAuth scopes granted to the token
287/// - `error`: Error code if authorization failed or is pending
288/// - `error_description`: Human-readable error description
289#[derive(Debug, Deserialize)]
290pub(crate) struct AccessTokenResponse {
291 /// The OAuth access token (present on successful authorization)
292 pub access_token: Option<String>,
293 /// Token type (typically "bearer")
294 #[allow(dead_code)] // Needed for JSON deserialization from GitHub API
295 pub token_type: Option<String>,
296 /// OAuth scopes granted to this token
297 #[allow(dead_code)] // Needed for JSON deserialization from GitHub API
298 pub scope: Option<String>,
299 /// Error code (e.g., "authorization_pending", "slow_down", "expired_token")
300 pub error: Option<String>,
301 /// Human-readable error description
302 #[serde(rename = "error_description")]
303 pub error_description: Option<String>,
304}
305
306impl AccessTokenResponse {
307 /// Creates a new access token response from a token string.
308 ///
309 /// This is a convenience constructor for creating an `AccessTokenResponse`
310 /// from a previously cached token string.
311 ///
312 /// # Arguments
313 ///
314 /// * `token` - The access token string
315 ///
316 /// # Example
317 ///
318 /// ```ignore
319 /// use bamboo_agent::agent::llm::providers::copilot::auth::handler::AccessTokenResponse;
320 ///
321 /// let response = AccessTokenResponse::from_token("gho_xxxx".to_string());
322 /// assert_eq!(response.access_token, Some("gho_xxxx".to_string()));
323 /// ```
324 pub(crate) fn from_token(token: String) -> Self {
325 Self {
326 access_token: Some(token),
327 token_type: None,
328 scope: None,
329 error: None,
330 error_description: None,
331 }
332 }
333}
334
335// Global lock for chat token operations.
336//
337// This mutex ensures that only one token request can be in flight at a time
338// across the entire application. This prevents race conditions where multiple
339// concurrent requests could trigger separate authentication flows.
340//
341// The lock is acquired in `CopilotAuthHandler::get_chat_token` before
342// attempting silent authentication or starting a new device flow.
343lazy_static! {
344 static ref CHAT_TOKEN_LOCK: Mutex<()> = Mutex::new(());
345}
346
347/// Handler for GitHub Copilot authentication.
348///
349/// Manages the complete authentication lifecycle including:
350/// - Device code flow for initial authentication
351/// - Token caching and validation
352/// - Automatic token refresh
353/// - Silent authentication attempts
354///
355/// # Architecture
356///
357/// The handler implements a hierarchical token resolution strategy:
358///
359/// 1. **Cached Copilot Token**: Check local cache for valid token
360/// 2. **Environment Variable**: Check `COPILOT_API_KEY`
361/// 3. **Cached Access Token**: Use cached GitHub token to fetch new Copilot token
362/// 4. **Interactive Flow**: Prompt user via device code flow
363///
364/// # Thread Safety
365///
366/// The handler is thread-safe and can be cloned cheaply. Authentication
367/// operations are protected by a global lock to prevent concurrent flows.
368///
369/// # Example
370///
371/// ```rust,ignore
372/// use std::sync::Arc;
373/// use reqwest_middleware::ClientWithMiddleware;
374/// use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
375///
376/// async fn example() -> anyhow::Result<()> {
377/// let client = Arc::new(ClientWithMiddleware::new(/* ... */));
378/// let handler = CopilotAuthHandler::new(
379/// client,
380/// std::path::PathBuf::from("~/.bamboo"),
381/// false,
382/// );
383///
384/// // Will use cached token or trigger device flow
385/// let token = handler.get_token().await?;
386/// println!("Got token: {}", token);
387/// Ok(())
388/// }
389/// ```
390#[derive(Debug, Clone)]
391pub struct CopilotAuthHandler {
392 /// HTTP client with middleware for retry logic
393 client: Arc<ClientWithMiddleware>,
394 /// Directory for storing cached tokens
395 app_data_dir: PathBuf,
396 /// Whether to print authentication instructions to console
397 headless_auth: bool,
398 /// GitHub API base URL (customizable for testing)
399 github_api_base_url: String,
400 /// GitHub login base URL (customizable for testing)
401 github_login_base_url: String,
402}
403
404impl CopilotAuthHandler {
405 /// Creates a new authentication handler.
406 ///
407 /// # Arguments
408 ///
409 /// * `client` - HTTP client with middleware for retry logic and error handling
410 /// * `app_data_dir` - Directory for storing cached tokens (`.token` and `.copilot_token.json`)
411 /// * `headless_auth` - Whether to print authentication instructions to console.
412 /// Set to `true` for CLI applications, `false` for GUI applications.
413 ///
414 /// # Example
415 ///
416 /// ```rust,ignore
417 /// use std::sync::Arc;
418 /// use reqwest_middleware::ClientWithMiddleware;
419 /// use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
420 ///
421 /// let client = Arc::new(ClientWithMiddleware::new(/* ... */));
422 /// let handler = CopilotAuthHandler::new(
423 /// client,
424 /// std::path::PathBuf::from("~/.bamboo"),
425 /// true, // CLI mode
426 /// );
427 /// ```
428 pub fn new(
429 client: Arc<ClientWithMiddleware>,
430 app_data_dir: PathBuf,
431 headless_auth: bool,
432 ) -> Self {
433 CopilotAuthHandler {
434 client,
435 app_data_dir,
436 headless_auth,
437 github_api_base_url: "https://api.github.com".to_string(),
438 github_login_base_url: "https://github.com".to_string(),
439 }
440 }
441
442 /// Returns the application data directory path.
443 ///
444 /// This directory contains cached tokens:
445 /// - `.token`: GitHub OAuth access token
446 /// - `.copilot_token.json`: Copilot API configuration
447 pub fn app_data_dir(&self) -> &PathBuf {
448 &self.app_data_dir
449 }
450
451 /// Sets a custom GitHub API base URL for testing.
452 ///
453 /// This allows tests to mock GitHub's API without hitting production.
454 #[cfg(test)]
455 fn with_github_api_base_url(mut self, url: impl Into<String>) -> Self {
456 self.github_api_base_url = url.into();
457 self
458 }
459
460 /// Sets a custom GitHub login base URL for testing.
461 ///
462 /// This allows tests to mock GitHub's OAuth endpoints without hitting production.
463 #[cfg(test)]
464 fn with_github_login_base_url(mut self, url: impl Into<String>) -> Self {
465 self.github_login_base_url = url.into();
466 self
467 }
468
469 /// Performs authentication and returns an access token.
470 ///
471 /// This is the primary entry point for authentication. It will attempt
472 /// silent authentication first, then fall back to interactive device flow
473 /// if necessary.
474 ///
475 /// # Returns
476 ///
477 /// A Copilot API token on success.
478 ///
479 /// # Errors
480 ///
481 /// Returns an error if:
482 /// - All authentication methods fail
483 /// - User denies authorization during device flow
484 /// - Device code expires before authorization
485 /// - Network errors occur
486 pub async fn authenticate(&self) -> anyhow::Result<String> {
487 self.get_chat_token().await
488 }
489
490 /// Ensures the handler is authenticated, without returning the token.
491 ///
492 /// This is useful for pre-authenticating or verifying credentials
493 /// without needing the actual token value.
494 ///
495 /// # Example
496 ///
497 /// ```rust,ignore
498 /// # use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
499 /// # async fn example(handler: CopilotAuthHandler) -> anyhow::Result<()> {
500 /// // Pre-authenticate before starting the application
501 /// handler.ensure_authenticated().await?;
502 /// println!("Authentication successful!");
503 /// # Ok(())
504 /// # }
505 /// ```
506 pub async fn ensure_authenticated(&self) -> anyhow::Result<()> {
507 self.get_chat_token().await.map(|_| ())
508 }
509
510 /// Gets the current access token, authenticating if necessary.
511 ///
512 /// Alias for [`authenticate`](Self::authenticate).
513 pub async fn get_token(&self) -> anyhow::Result<String> {
514 self.get_chat_token().await
515 }
516
517 /// Gets a chat token, using cached credentials or triggering device flow.
518 ///
519 /// This method attempts silent authentication first, then falls back
520 /// to interactive device code flow if necessary.
521 ///
522 /// # Silent Authentication Priority
523 ///
524 /// 1. Check cached Copilot token (`.copilot_token.json`)
525 /// 2. Check `COPILOT_API_KEY` environment variable
526 /// 3. Check cached GitHub access token (`.token`) and exchange for new Copilot token
527 ///
528 /// # Thread Safety
529 ///
530 /// This method acquires a global lock to prevent concurrent authentication
531 /// flows. Only one authentication attempt can be in progress at a time.
532 ///
533 /// # Returns
534 ///
535 /// A valid Copilot API token.
536 ///
537 /// # Errors
538 ///
539 /// Returns an error if all authentication methods fail.
540 pub async fn get_chat_token(&self) -> anyhow::Result<String> {
541 // Acquire global lock to ensure sequential execution
542 let _guard = CHAT_TOKEN_LOCK.lock().await;
543
544 // Try silent authentication first
545 if let Some(token) = self.try_get_chat_token_silent().await? {
546 return Ok(token);
547 }
548
549 // Need interactive authentication
550 let device_code = self.start_authentication().await?;
551 let copilot_config = self.complete_authentication(&device_code).await?;
552 Ok(copilot_config.token)
553 }
554
555 /// Reads an access token from a file, trimming whitespace.
556 ///
557 /// This utility function reads a token from a file and trims any
558 /// leading/trailing whitespace or newlines.
559 ///
560 /// # Arguments
561 ///
562 /// * `token_path` - Path to the token file
563 ///
564 /// # Returns
565 ///
566 /// - `Some(token)` if the file exists and contains non-whitespace content
567 /// - `None` if the file doesn't exist or is empty/whitespace only
568 fn read_access_token(token_path: &PathBuf) -> Option<String> {
569 if !token_path.exists() {
570 return None;
571 }
572 let access_token_str = read_to_string(token_path).ok()?;
573 let trimmed = access_token_str.trim();
574 if trimmed.is_empty() {
575 None
576 } else {
577 Some(trimmed.to_string())
578 }
579 }
580
581 /// Reads a cached Copilot configuration from a file.
582 ///
583 /// Attempts to deserialize a JSON-formatted Copilot configuration
584 /// from the specified file.
585 ///
586 /// # Arguments
587 ///
588 /// * `token_path` - Path to the JSON cache file
589 ///
590 /// # Returns
591 ///
592 /// - `Some(config)` if the file exists and contains valid JSON
593 /// - `None` if the file doesn't exist or has invalid JSON
594 fn read_cached_copilot_config(&self, token_path: &PathBuf) -> Option<CopilotConfig> {
595 let cached_str = read_to_string(token_path).ok()?;
596 serde_json::from_str::<CopilotConfig>(&cached_str).ok()
597 }
598
599 /// Writes a Copilot configuration to a cache file.
600 ///
601 /// Serializes the configuration as JSON and writes it to the specified file.
602 ///
603 /// # Arguments
604 ///
605 /// * `token_path` - Path where the JSON should be written
606 /// * `copilot_config` - Configuration to cache
607 ///
608 /// # Errors
609 ///
610 /// Returns an error if:
611 /// - JSON serialization fails
612 /// - File creation fails
613 /// - Writing to file fails
614 fn write_cached_copilot_config(
615 &self,
616 token_path: &PathBuf,
617 copilot_config: &CopilotConfig,
618 ) -> anyhow::Result<()> {
619 let serialized = serde_json::to_string(copilot_config)?;
620 let mut file = File::create(token_path)?;
621 file.write_all(serialized.as_bytes())?;
622 Ok(())
623 }
624
625 /// Checks if a Copilot token is valid with a 60-second buffer.
626 ///
627 /// This method checks whether the token has expired, with a 60-second
628 /// buffer to ensure tokens are refreshed before they actually expire.
629 ///
630 /// # Arguments
631 ///
632 /// * `copilot_config` - Configuration containing the token expiration time
633 ///
634 /// # Returns
635 ///
636 /// - `true` if the token is valid for at least 60 more seconds
637 /// - `false` if the token has expired or will expire within 60 seconds
638 ///
639 /// # Example
640 ///
641 /// ```rust,ignore
642 /// # use bamboo_agent::agent::llm::providers::copilot::auth::handler::{CopilotAuthHandler, CopilotConfig};
643 /// # fn example(handler: CopilotAuthHandler, config: CopilotConfig) {
644 /// if handler.is_copilot_token_valid(&config) {
645 /// println!("Token is valid");
646 /// } else {
647 /// println!("Token needs refresh");
648 /// }
649 /// # }
650 /// ```
651 fn is_copilot_token_valid(&self, copilot_config: &CopilotConfig) -> bool {
652 let now = SystemTime::now()
653 .duration_since(UNIX_EPOCH)
654 .map(|duration| duration.as_secs())
655 .unwrap_or(0);
656 copilot_config.expires_at.saturating_sub(60) > now
657 }
658
659 /// Requests a device code from GitHub for OAuth flow.
660 ///
661 /// This is the first step in the OAuth device flow. It requests a
662 /// device code and user code from GitHub that the user must enter
663 /// at the verification URL.
664 ///
665 /// # Returns
666 ///
667 /// A [`DeviceCodeResponse`] containing:
668 /// - `device_code`: Unique identifier for this authentication session
669 /// - `user_code`: Code the user must enter at the verification URL
670 /// - `verification_uri`: URL where user should enter the code
671 /// - `expires_in`: Seconds until the device code expires
672 /// - `interval`: Recommended polling interval in seconds
673 ///
674 /// # Errors
675 ///
676 /// Returns an error if:
677 /// - GitHub API is unreachable
678 /// - Proxy authentication is required
679 /// - API returns an error response
680 ///
681 /// # Example
682 ///
683 /// ```rust,ignore
684 /// # use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
685 /// # async fn example(handler: CopilotAuthHandler) -> anyhow::Result<()> {
686 /// let device_code = handler.get_device_code().await?;
687 /// println!("Visit: {}", device_code.verification_uri);
688 /// println!("Enter code: {}", device_code.user_code);
689 /// # Ok(())
690 /// # }
691 /// ```
692 pub(super) async fn get_device_code(&self) -> anyhow::Result<DeviceCodeResponse> {
693 let params = [
694 ("client_id", "Iv1.b507a08c87ecfe98"),
695 ("scope", "read:user"),
696 ];
697 let url = format!("{}/login/device/code", self.github_login_base_url);
698
699 let response = self
700 .client
701 .post(&url)
702 .header("Accept", "application/json")
703 .header("User-Agent", "BambooCopilot/1.0")
704 .form(¶ms)
705 .send()
706 .await?;
707
708 if response.status() == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
709 return Err(anyhow!(ProxyAuthRequiredError));
710 }
711
712 let status = response.status();
713 if !status.is_success() {
714 let text = response.text().await.unwrap_or_default();
715 return Err(anyhow!(
716 "Device code request failed: HTTP {} - {}",
717 status,
718 text
719 ));
720 }
721
722 Ok(response.json::<DeviceCodeResponse>().await?)
723 }
724
725 /// Starts the authentication process by getting a device code.
726 ///
727 /// This method initiates the OAuth device flow by requesting a device
728 /// code from GitHub. If `headless_auth` is `false`, it prints user-friendly
729 /// instructions to the console.
730 ///
731 /// # Display Behavior
732 ///
733 /// - **Headless mode (`headless_auth = true`)**: Prints full instructions with ASCII art
734 /// - **GUI mode (`headless_auth = false`)**: Returns device code for custom UI
735 ///
736 /// # Returns
737 ///
738 /// A [`DeviceCodeResponse`] with the device code and verification information.
739 ///
740 /// # Example
741 ///
742 /// ```rust,ignore
743 /// # use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
744 /// # async fn example(handler: CopilotAuthHandler) -> anyhow::Result<()> {
745 /// let device_code = handler.start_authentication().await?;
746 /// // In GUI mode, display these values to the user
747 /// println!("URL: {}", device_code.verification_uri);
748 /// println!("Code: {}", device_code.user_code);
749 /// # Ok(())
750 /// # }
751 /// ```
752 pub async fn start_authentication(&self) -> anyhow::Result<DeviceCodeResponse> {
753 let device_code = self.get_device_code().await?;
754
755 if self.headless_auth {
756 // CLI mode: print to console
757 println!("\n╔════════════════════════════════════════════════════════════╗");
758 println!("║ 🔐 GitHub Copilot Authorization Required ║");
759 println!("╚════════════════════════════════════════════════════════════╝");
760 println!();
761 println!(" 1. Open your browser and navigate to:");
762 println!(" {}", device_code.verification_uri);
763 println!();
764 println!(" 2. Enter the following code:");
765 println!();
766 println!(" ┌─────────────────────────┐");
767 println!(" │ {:^23} │", device_code.user_code);
768 println!(" └─────────────────────────┘");
769 println!();
770 println!(" 3. Click 'Authorize' and wait...");
771 println!();
772 println!(
773 " ⏳ Waiting for authorization (expires in {} seconds)...",
774 device_code.expires_in
775 );
776 println!();
777 }
778
779 Ok(device_code)
780 }
781
782 /// Completes authentication by polling for access token and exchanging for Copilot token.
783 ///
784 /// This method completes the OAuth flow by:
785 /// 1. Polling GitHub for the access token (waits for user authorization)
786 /// 2. Exchanging the access token for a Copilot API token
787 /// 3. Caching both tokens to disk for future use
788 ///
789 /// # Arguments
790 ///
791 /// * `device_code` - Device code response from [`start_authentication`](Self::start_authentication)
792 ///
793 /// # Returns
794 ///
795 /// A [`CopilotConfig`] containing the Copilot API token and configuration.
796 ///
797 /// # Side Effects
798 ///
799 /// Writes the following files to `app_data_dir`:
800 /// - `.token`: GitHub OAuth access token
801 /// - `.copilot_token.json`: Copilot API configuration
802 ///
803 /// # Errors
804 ///
805 /// Returns an error if:
806 /// - User denies authorization
807 /// - Device code expires before authorization
808 /// - Token exchange fails
809 /// - File writing fails
810 pub async fn complete_authentication(
811 &self,
812 device_code: &DeviceCodeResponse,
813 ) -> anyhow::Result<CopilotConfig> {
814 let access_token = self.get_access_token(device_code).await?;
815
816 // Extract access token string before passing to get_copilot_token
817 let access_token_str = access_token
818 .access_token
819 .clone()
820 .ok_or_else(|| anyhow!("Access token not found"))?;
821
822 let copilot_config = self.get_copilot_token(access_token).await?;
823
824 // Write tokens to disk
825 let token_path = self.app_data_dir.join(".token");
826 let copilot_token_path = self.app_data_dir.join(".copilot_token.json");
827
828 // Write access token
829 let mut file = File::create(&token_path)?;
830 file.write_all(access_token_str.as_bytes())?;
831
832 // Write copilot config
833 self.write_cached_copilot_config(&copilot_token_path, &copilot_config)?;
834
835 Ok(copilot_config)
836 }
837
838 /// Attempts silent authentication without user interaction.
839 ///
840 /// This method tries multiple authentication strategies in order of preference,
841 /// all of which can succeed without requiring user interaction:
842 ///
843 /// 1. **Cached Copilot Token**: Load from `.copilot_token.json` if still valid
844 /// 2. **Environment Variable**: Check `COPILOT_API_KEY`
845 /// 3. **Cached Access Token**: Use cached GitHub token to fetch new Copilot token
846 ///
847 /// # Returns
848 ///
849 /// - `Ok(Some(token))` if silent authentication succeeded
850 /// - `Ok(None)` if silent authentication is not possible (triggers interactive flow)
851 /// - `Err(...)` if an unexpected error occurred
852 ///
853 /// # Side Effects
854 ///
855 /// If using a cached access token, this method will:
856 /// - Fetch a new Copilot token from GitHub
857 /// - Cache the new Copilot token to `.copilot_token.json`
858 /// - Remove the cached access token if it's invalid
859 ///
860 /// # Example
861 ///
862 /// ```rust,ignore
863 /// # use bamboo_agent::agent::llm::providers::copilot::auth::handler::CopilotAuthHandler;
864 /// # async fn example(handler: CopilotAuthHandler) -> anyhow::Result<()> {
865 /// match handler.try_get_chat_token_silent().await? {
866 /// Some(token) => println!("Got token silently: {}", token),
867 /// None => println!("Need interactive authentication"),
868 /// }
869 /// # Ok(())
870 /// # }
871 /// ```
872 pub async fn try_get_chat_token_silent(&self) -> anyhow::Result<Option<String>> {
873 let copilot_token_path = self.app_data_dir.join(".copilot_token.json");
874
875 // Check cached copilot token
876 if let Some(cached_config) = self.read_cached_copilot_config(&copilot_token_path) {
877 if self.is_copilot_token_valid(&cached_config) {
878 return Ok(Some(cached_config.token));
879 }
880 }
881
882 // Check env var
883 if let Ok(token) = std::env::var("COPILOT_API_KEY") {
884 let trimmed = token.trim();
885 if !trimmed.is_empty() {
886 return Ok(Some(trimmed.to_string()));
887 }
888 }
889
890 // Check access token file and try to exchange
891 let token_path = self.app_data_dir.join(".token");
892 if let Some(access_token_str) = Self::read_access_token(&token_path) {
893 let access_token = AccessTokenResponse::from_token(access_token_str);
894 match self.get_copilot_token(access_token).await {
895 Ok(copilot_config) => {
896 self.write_cached_copilot_config(&copilot_token_path, &copilot_config)?;
897 return Ok(Some(copilot_config.token));
898 }
899 Err(e) => {
900 // Only discard the cached access token when we are confident it is invalid.
901 // Copilot tokens are short-lived; the GitHub OAuth access token should be
902 // long-lived, so removing it on transient failures causes unnecessary re-auth.
903 if Self::should_discard_access_token(&e) {
904 let _ = std::fs::remove_file(&token_path);
905 }
906 }
907 }
908 }
909
910 Ok(None)
911 }
912
913 /// Force refresh a Copilot token using the cached GitHub OAuth access token.
914 ///
915 /// This bypasses the `.copilot_token.json` cache and is useful when the cached
916 /// Copilot token is rejected early (e.g. revoked) even if it hasn't reached
917 /// `expires_at` yet.
918 ///
919 /// Returns:
920 /// - `Ok(Some(token))` if the refresh succeeded
921 /// - `Ok(None)` if no cached access token exists
922 pub async fn force_refresh_chat_token(&self) -> anyhow::Result<Option<String>> {
923 let token_path = self.app_data_dir.join(".token");
924 let Some(access_token_str) = Self::read_access_token(&token_path) else {
925 return Ok(None);
926 };
927
928 let access_token = AccessTokenResponse::from_token(access_token_str);
929 match self.get_copilot_token(access_token).await {
930 Ok(copilot_config) => {
931 let copilot_token_path = self.app_data_dir.join(".copilot_token.json");
932 self.write_cached_copilot_config(&copilot_token_path, &copilot_config)?;
933 Ok(Some(copilot_config.token))
934 }
935 Err(e) => {
936 if Self::should_discard_access_token(&e) {
937 let _ = std::fs::remove_file(&token_path);
938 }
939 Err(e)
940 }
941 }
942 }
943
/// Returns `true` when an error message reports that GitHub rejected the access
/// token outright (HTTP 401 Unauthorized or 403 Forbidden).
///
/// `get_copilot_token` formats non-success responses as
/// `"Copilot token request failed: HTTP {status} - {text}"`, where `{status}` is
/// the status code's display form (e.g. `401 Unauthorized`) and `{text}` is the raw
/// response body. Only the status following the first `"HTTP "` marker is examined:
/// later occurrences come from the server-controlled body, and matching them would
/// delete a valid token on e.g. a 5xx whose body happens to mention "HTTP 401".
/// The code must match exactly, so `"HTTP 4010"` is not treated as a 401.
fn should_discard_access_token_message(msg: &str) -> bool {
    let Some((_, after_marker)) = msg.split_once("HTTP ") else {
        return false;
    };
    // ASCII digits are single bytes, so this length is always a valid char boundary.
    let digit_len = after_marker
        .bytes()
        .take_while(u8::is_ascii_digit)
        .count();
    matches!(&after_marker[..digit_len], "401" | "403")
}
949
950 fn should_discard_access_token(err: &anyhow::Error) -> bool {
951 if err.downcast_ref::<ProxyAuthRequiredError>().is_some() {
952 return false;
953 }
954 Self::should_discard_access_token_message(&err.to_string())
955 }
956
957 /// Polls GitHub for an access token after user completes device flow.
958 ///
959 /// This method continuously polls GitHub's OAuth endpoint until either:
960 /// - The user authorizes the application (success)
961 /// - The device code expires (error)
962 /// - The user denies authorization (error)
963 ///
964 /// # Polling Behavior
965 ///
966 /// The method polls at the interval specified in the device code response
967 /// (minimum 5 seconds). It handles various OAuth states:
968 ///
969 /// - `authorization_pending`: User hasn't authorized yet, keep polling
970 /// - `slow_down`: Server requested slower polling, increase interval
971 /// - `expired_token`: Device code expired, return error
972 /// - `access_denied`: User denied authorization, return error
973 ///
974 /// # Arguments
975 ///
976 /// * `device_code` - Device code response from [`get_device_code`](Self::get_device_code)
977 ///
978 /// # Returns
979 ///
980 /// An [`AccessTokenResponse`] containing the GitHub OAuth access token.
981 ///
982 /// # Errors
983 ///
984 /// Returns an error if:
985 /// - Device code expires before user authorizes
986 /// - User denies authorization
987 /// - Proxy authentication is required
988 /// - Network errors occur
989 ///
990 /// # Display Output
991 ///
992 /// In headless mode, prints progress dots. In GUI mode, shows polling status.
993 pub(super) async fn get_access_token(
994 &self,
995 device_code: &DeviceCodeResponse,
996 ) -> anyhow::Result<AccessTokenResponse> {
997 let params = [
998 ("client_id", "Iv1.b507a08c87ecfe98"),
999 ("device_code", &device_code.device_code),
1000 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
1001 ];
1002
1003 let poll_interval = Duration::from_secs(device_code.interval.max(5));
1004 let max_duration = Duration::from_secs(device_code.expires_in);
1005 let start = std::time::Instant::now();
1006
1007 if !self.headless_auth {
1008 println!(" 🔄 Polling for authorization...");
1009 }
1010
1011 loop {
1012 if start.elapsed() > max_duration {
1013 return Err(anyhow!("❌ Device code expired. Please try again."));
1014 }
1015
1016 let url = format!("{}/login/oauth/access_token", self.github_login_base_url);
1017 let response = self
1018 .client
1019 .post(&url)
1020 .header("Accept", "application/json")
1021 .header("User-Agent", "BambooCopilot/1.0")
1022 .form(¶ms)
1023 .send()
1024 .await?;
1025
1026 if response.status() == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
1027 return Err(anyhow!(ProxyAuthRequiredError));
1028 }
1029
1030 let response = response.json::<AccessTokenResponse>().await?;
1031
1032 if let Some(token) = response.access_token {
1033 if !self.headless_auth {
1034 println!(" ✅ Access token received!");
1035 }
1036 return Ok(AccessTokenResponse::from_token(token));
1037 }
1038
1039 if let Some(error) = &response.error {
1040 match error.as_str() {
1041 "authorization_pending" => {
1042 if self.headless_auth {
1043 print!(".");
1044 std::io::Write::flush(&mut std::io::stdout()).ok();
1045 }
1046 }
1047 "slow_down" => {
1048 if !self.headless_auth {
1049 println!("\n ⚠️ Server requested slower polling...");
1050 }
1051 sleep(Duration::from_secs(device_code.interval + 5)).await;
1052 continue;
1053 }
1054 "expired_token" => {
1055 return Err(anyhow!("❌ Device code expired. Please try again."));
1056 }
1057 "access_denied" => {
1058 return Err(anyhow!("❌ Authorization denied by user."));
1059 }
1060 _ => {
1061 let desc = response.error_description.as_deref().unwrap_or("");
1062 return Err(anyhow!("❌ Auth error: {} - {}", error, desc));
1063 }
1064 }
1065 }
1066
1067 sleep(poll_interval).await;
1068 }
1069 }
1070
1071 /// Exchanges a GitHub access token for a Copilot API token.
1072 ///
1073 /// This method exchanges a GitHub OAuth access token for a Copilot-specific
1074 /// API token by calling GitHub's `/copilot_internal/v2/token` endpoint.
1075 ///
1076 /// # Arguments
1077 ///
1078 /// * `access_token` - GitHub OAuth access token response
1079 ///
1080 /// # Returns
1081 ///
1082 /// A [`CopilotConfig`] containing:
1083 /// - Copilot API token
1084 /// - Token expiration time
1085 /// - Feature flags and settings
1086 /// - API endpoints
1087 ///
1088 /// # Errors
1089 ///
1090 /// Returns an error if:
1091 /// - Access token is invalid or expired
1092 /// - Copilot is not enabled for the GitHub account
1093 /// - Proxy authentication is required
1094 /// - Network errors occur
1095 /// - Response parsing fails
1096 ///
1097 /// # Example
1098 ///
1099 /// ```rust,ignore
1100 /// # use bamboo_agent::agent::llm::providers::copilot::auth::handler::{CopilotAuthHandler, AccessTokenResponse};
1101 /// # async fn example(handler: CopilotAuthHandler, access_token: AccessTokenResponse) -> anyhow::Result<()> {
1102 /// let config = handler.get_copilot_token(access_token).await?;
1103 /// println!("Got Copilot token, expires at: {}", config.expires_at);
1104 /// # Ok(())
1105 /// # }
1106 /// ```
1107 pub(super) async fn get_copilot_token(
1108 &self,
1109 access_token: AccessTokenResponse,
1110 ) -> anyhow::Result<CopilotConfig> {
1111 let url = format!("{}/copilot_internal/v2/token", self.github_api_base_url);
1112 let actual_github_token = access_token
1113 .access_token
1114 .ok_or_else(|| anyhow!("Access token not found"))?;
1115
1116 let response = self
1117 .client
1118 .get(url)
1119 .header("Authorization", format!("token {}", actual_github_token))
1120 .header("Accept", "application/json")
1121 .header("User-Agent", "BambooCopilot/1.0")
1122 .send()
1123 .await?;
1124
1125 if response.status() == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
1126 return Err(anyhow!(ProxyAuthRequiredError));
1127 }
1128
1129 let status = response.status();
1130 if !status.is_success() {
1131 let text = response.text().await.unwrap_or_default();
1132 return Err(anyhow!(
1133 "Copilot token request failed: HTTP {} - {}",
1134 status,
1135 text
1136 ));
1137 }
1138
1139 let body = response.bytes().await?;
1140 match serde_json::from_slice::<CopilotConfig>(&body) {
1141 Ok(copilot_config) => {
1142 if !copilot_config.chat_enabled {
1143 return Err(anyhow!("❌ Copilot chat is not enabled for this account."));
1144 }
1145 if !self.headless_auth {
1146 println!(" ✅ Copilot token received!");
1147 }
1148 Ok(copilot_config)
1149 }
1150 Err(_) => {
1151 let body_str = String::from_utf8_lossy(&body);
1152 let error_msg = format!("Failed to get copilot config: {body_str}");
1153 error!("{error_msg}");
1154 Err(anyhow!(error_msg))
1155 }
1156 }
1157 }
1158}
1159
/// Integration tests for authentication retry logic.
///
/// These tests verify that authentication requests are retried on transient
/// failures (e.g. 503 errors) and fail fast on authentication errors (e.g. 401
/// Unauthorized). They also check that the resulting errors are classified
/// correctly when deciding whether to delete the cached GitHub access token.
#[cfg(test)]
mod retry_tests {
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex as StdMutex;

    use reqwest::Method;
    use reqwest_middleware::{ClientBuilder, Middleware, Next, Result as MiddlewareResult};
    use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};

    /// Canned HTTP response served by [`MockResponder`].
    #[derive(Clone)]
    struct MockReply {
        /// HTTP status code
        status: u16,
        /// Response body
        body: String,
        /// Content-Type header value
        content_type: Option<&'static str>,
    }

    impl MockReply {
        /// Creates a response whose body is sent verbatim.
        ///
        /// The response is still labelled `application/json`, because the GitHub
        /// endpoints under test return JSON error payloads.
        fn text(status: u16, body: impl Into<String>) -> Self {
            Self {
                status,
                body: body.into(),
                content_type: Some("application/json"),
            }
        }

        /// Creates a JSON response with the given status and JSON value.
        fn json(status: u16, value: serde_json::Value) -> Self {
            Self {
                status,
                body: value.to_string(),
                content_type: Some("application/json"),
            }
        }
    }

    /// Middleware that short-circuits the HTTP stack with canned responses.
    ///
    /// Replies are served in FIFO order, which lets tests simulate retry
    /// scenarios (e.g. return 503 twice, then 200). The middleware panics in two
    /// cases: when a request's method or path differs from the expected one, and
    /// when more requests arrive than replies were queued.
    #[derive(Clone)]
    struct MockResponder {
        /// Expected HTTP method
        expected_method: Method,
        /// Expected URL path
        expected_path: String,
        /// Counter for number of calls
        call_count: Arc<AtomicUsize>,
        /// Queue of responses to return, front first
        replies: Arc<StdMutex<VecDeque<MockReply>>>,
    }

    impl MockResponder {
        /// Creates a new mock responder.
        ///
        /// # Arguments
        ///
        /// * `expected_method` - HTTP method to expect
        /// * `expected_path` - URL path to expect
        /// * `call_count` - Counter to track number of calls
        /// * `replies` - Responses to return, in order
        fn new(
            expected_method: Method,
            expected_path: impl Into<String>,
            call_count: Arc<AtomicUsize>,
            replies: Vec<MockReply>,
        ) -> Self {
            Self {
                expected_method,
                expected_path: expected_path.into(),
                call_count,
                replies: Arc::new(StdMutex::new(VecDeque::from(replies))),
            }
        }
    }

    #[async_trait::async_trait]
    impl Middleware for MockResponder {
        async fn handle(
            &self,
            req: reqwest::Request,
            _extensions: &mut http::Extensions,
            _next: Next<'_>,
        ) -> MiddlewareResult<reqwest::Response> {
            assert_eq!(
                req.method(),
                &self.expected_method,
                "unexpected method for {}",
                req.url()
            );
            assert_eq!(
                req.url().path(),
                self.expected_path.as_str(),
                "unexpected path for {}",
                req.url()
            );

            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let reply = self
                .replies
                .lock()
                .expect("lock")
                .pop_front()
                .unwrap_or_else(|| panic!("no mock reply left for call #{idx}"));

            let mut builder = http::Response::builder().status(reply.status);
            if let Some(ct) = reply.content_type {
                builder = builder.header("content-type", ct);
            }

            let http_response = builder.body(reply.body).expect("http response");
            Ok(reqwest::Response::from(http_response))
        }
    }

    /// Creates a test HTTP client with retry middleware and mock responder.
    fn create_test_client_with_retry(mock: MockResponder) -> Arc<ClientWithMiddleware> {
        use reqwest::Client as ReqwestClient;

        // Use a zero-delay retry policy to keep tests fast and deterministic.
        let retry_policy = ExponentialBackoff::builder()
            .retry_bounds(Duration::from_millis(0), Duration::from_millis(0))
            .build_with_max_retries(3);

        let client = ReqwestClient::builder().build().expect("client");

        Arc::new(
            ClientBuilder::new(client)
                .with(RetryTransientMiddleware::new_with_policy(retry_policy))
                .with(mock)
                .build(),
        )
    }

    /// Creates a sample CopilotConfig for testing with specified expiration time.
    fn sample_config(expires_at: u64) -> CopilotConfig {
        CopilotConfig {
            token: "cached-token".to_string(),
            annotations_enabled: false,
            chat_enabled: true,
            chat_jetbrains_enabled: false,
            code_quote_enabled: false,
            code_review_enabled: false,
            codesearch: false,
            copilotignore_enabled: false,
            endpoints: Endpoints {
                api: Some("https://api.example.com".to_string()),
                origin_tracker: None,
                proxy: None,
                telemetry: None,
            },
            expires_at,
            individual: true,
            limited_user_quotas: None,
            limited_user_reset_date: None,
            prompt_8k: false,
            public_suggestions: "disabled".to_string(),
            refresh_in: 300,
            sku: "test".to_string(),
            snippy_load_test_enabled: false,
            telemetry: "disabled".to_string(),
            tracking_id: "test".to_string(),
            vsc_electron_fetcher_v2: false,
            xcode: false,
            xcode_chat: false,
        }
    }

    /// Test that auth requests are retried on transient failures.
    ///
    /// Simulates a scenario where the Copilot token endpoint returns
    /// 503 (Service Unavailable) twice before succeeding. Verifies that:
    /// - The request is retried automatically
    /// - Eventually succeeds after retries
    /// - Total call count is 3 (2 failures + 1 success)
    #[tokio::test]
    async fn test_auth_retry_on_server_error() {
        let request_count = Arc::new(AtomicUsize::new(0));

        let mock = MockResponder::new(
            Method::GET,
            "/copilot_internal/v2/token",
            request_count.clone(),
            vec![
                MockReply::text(503, r#"{"error":"Service Unavailable"}"#),
                MockReply::text(503, r#"{"error":"Service Unavailable"}"#),
                MockReply::json(
                    200,
                    serde_json::json!({
                        "token": "test-copilot-token",
                        "expires_at": (SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() + 3600),
                        "annotations_enabled": true,
                        "chat_enabled": true,
                        "chat_jetbrains_enabled": false,
                        "code_quote_enabled": true,
                        "code_review_enabled": false,
                        "codesearch": false,
                        "copilotignore_enabled": true,
                        "endpoints": {
                            "api": "https://api.githubcopilot.com"
                        },
                        "individual": true,
                        "prompt_8k": true,
                        "public_suggestions": "disabled",
                        "refresh_in": 300,
                        "sku": "copilot_individual",
                        "snippy_load_test_enabled": false,
                        "telemetry": "disabled",
                        "tracking_id": "test-tracking-id",
                        "vsc_electron_fetcher_v2": true,
                        "xcode": false,
                        "xcode_chat": false
                    }),
                ),
            ],
        );

        let client = create_test_client_with_retry(mock);
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let handler = CopilotAuthHandler::new(client, temp_dir.path().to_path_buf(), true)
            .with_github_api_base_url("http://mock.local");

        // Create a valid access token
        let access_token = AccessTokenResponse {
            access_token: Some("test-github-token".to_string()),
            token_type: Some("bearer".to_string()),
            scope: Some("read:user".to_string()),
            error: None,
            error_description: None,
        };

        // This should retry and eventually succeed
        let result = handler.get_copilot_token(access_token).await;
        assert!(
            result.is_ok(),
            "Should succeed after retries: {:?}",
            result.err()
        );
        assert_eq!(request_count.load(Ordering::SeqCst), 3);

        let config = result.unwrap();
        assert_eq!(config.token, "test-copilot-token");
    }

    /// Test that auth requests fail fast on 401 (no retry).
    ///
    /// Verifies two things. First, authentication errors (401 Unauthorized) are
    /// not retried, since retrying cannot fix them. Second, the resulting error is
    /// recognised as grounds for deleting the cached access token.
    #[tokio::test]
    async fn test_auth_no_retry_on_unauthorized() {
        let request_count = Arc::new(AtomicUsize::new(0));

        let mock = MockResponder::new(
            Method::GET,
            "/copilot_internal/v2/token",
            request_count.clone(),
            vec![MockReply::text(401, r#"{"error":"Unauthorized"}"#)],
        );

        let client = create_test_client_with_retry(mock);
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let handler = CopilotAuthHandler::new(client, temp_dir.path().to_path_buf(), true)
            .with_github_api_base_url("http://mock.local");

        let access_token = AccessTokenResponse {
            access_token: Some("invalid-token".to_string()),
            token_type: Some("bearer".to_string()),
            scope: Some("read:user".to_string()),
            error: None,
            error_description: None,
        };

        let Err(err) = handler.get_copilot_token(access_token).await else {
            panic!("401 must surface as an error");
        };
        assert_eq!(request_count.load(Ordering::SeqCst), 1);
        // The error text must keep the shape the discard logic parses; otherwise a
        // revoked OAuth token would never be purged from disk.
        assert!(
            CopilotAuthHandler::should_discard_access_token(&err),
            "401 should mark the cached access token for removal: {err}"
        );
    }

    /// Test that exhausted transient failures keep the cached access token.
    ///
    /// The Copilot token endpoint keeps answering 503, so the retry middleware
    /// gives up after its three retries and `get_copilot_token` reports the 503.
    /// That error says nothing about the GitHub OAuth token, so it must not be
    /// classified as a reason to delete it.
    #[tokio::test]
    async fn test_transient_failure_keeps_access_token() {
        let request_count = Arc::new(AtomicUsize::new(0));
        let unavailable = MockReply::text(503, r#"{"error":"Service Unavailable"}"#);

        let mock = MockResponder::new(
            Method::GET,
            "/copilot_internal/v2/token",
            request_count.clone(),
            vec![unavailable; 4],
        );

        let client = create_test_client_with_retry(mock);
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let handler = CopilotAuthHandler::new(client, temp_dir.path().to_path_buf(), true)
            .with_github_api_base_url("http://mock.local");

        let access_token = AccessTokenResponse {
            access_token: Some("test-github-token".to_string()),
            token_type: Some("bearer".to_string()),
            scope: Some("read:user".to_string()),
            error: None,
            error_description: None,
        };

        let Err(err) = handler.get_copilot_token(access_token).await else {
            panic!("persistent 503 must surface as an error");
        };
        // One initial attempt plus the three retries allowed by the test policy.
        assert_eq!(request_count.load(Ordering::SeqCst), 4);
        assert!(
            !CopilotAuthHandler::should_discard_access_token(&err),
            "transient failures must not discard the cached access token: {err}"
        );
    }

    /// Test device code endpoint retry.
    ///
    /// Simulates transient failures when requesting a device code
    /// and verifies that the request is retried until success.
    #[tokio::test]
    async fn test_device_code_retry() {
        let request_count = Arc::new(AtomicUsize::new(0));

        let mock = MockResponder::new(
            Method::POST,
            "/login/device/code",
            request_count.clone(),
            vec![
                MockReply::text(503, ""),
                MockReply::text(503, ""),
                MockReply::json(
                    200,
                    serde_json::json!({
                        "device_code": "test-device-code",
                        "user_code": "ABCD-EFGH",
                        "verification_uri": "https://github.com/login/device",
                        "expires_in": 900,
                        "interval": 5
                    }),
                ),
            ],
        );

        let client = create_test_client_with_retry(mock);
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let handler = CopilotAuthHandler::new(client, temp_dir.path().to_path_buf(), true)
            .with_github_login_base_url("http://mock.local");

        // Call the actual method - it should retry and eventually succeed
        let result = handler.get_device_code().await;

        assert!(
            result.is_ok(),
            "Should succeed after retries: {:?}",
            result.err()
        );
        assert_eq!(request_count.load(Ordering::SeqCst), 3);

        let device_code = result.unwrap();
        assert_eq!(device_code.device_code, "test-device-code");
        assert_eq!(device_code.user_code, "ABCD-EFGH");
    }

    /// Test token cache validation.
    ///
    /// Verifies that the 60-second buffer for token validation works correctly:
    /// - Tokens valid for > 60 seconds are considered valid
    /// - Tokens expired or expiring within 60 seconds are considered invalid
    #[test]
    fn test_token_cache_validation() {
        let temp_dir = tempfile::tempdir().expect("tempdir");
        let client = create_test_client_with_retry(MockResponder::new(
            Method::GET,
            "/__unused__",
            Arc::new(AtomicUsize::new(0)),
            vec![],
        ));
        let handler = CopilotAuthHandler::new(client, temp_dir.path().to_path_buf(), true);

        // Valid token (expires in 1 hour)
        let valid_config = sample_config(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 3600,
        );
        assert!(handler.is_copilot_token_valid(&valid_config));

        // Expired token (expired 1 hour ago)
        let expired_config = sample_config(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                - 3600,
        );
        assert!(!handler.is_copilot_token_valid(&expired_config));

        // Token expiring soon (30 seconds left, but we use 60s buffer)
        let expiring_soon_config = sample_config(
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                + 30,
        );
        assert!(!handler.is_copilot_token_valid(&expiring_soon_config));
    }

    /// Test cached config round-trip with retry client.
    ///
    /// Verifies that CopilotConfig can be written to disk and read back
    /// correctly when using an HTTP client with retry middleware.
    #[test]
    fn test_cached_copilot_config_with_retry_client() {
        let dir = tempfile::tempdir().expect("tempdir");
        let client = create_test_client_with_retry(MockResponder::new(
            Method::GET,
            "/__unused__",
            Arc::new(AtomicUsize::new(0)),
            vec![],
        ));
        let handler = CopilotAuthHandler::new(client, dir.path().to_path_buf(), false);
        let token_path = dir.path().join(".copilot_token.json");

        let expires_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
            + 3600;
        let config = sample_config(expires_at);

        handler
            .write_cached_copilot_config(&token_path, &config)
            .expect("write cache");
        let loaded = handler
            .read_cached_copilot_config(&token_path)
            .expect("read cache");

        assert_eq!(loaded.token, config.token);
        assert_eq!(loaded.expires_at, config.expires_at);
    }
}