// bamboo_agent/server/handlers/copilot_auth.rs

1use crate::server::{app_state::AppState, error::AppError};
2use actix_web::{web, HttpResponse};
3use serde::{Deserialize, Serialize};
4
5#[derive(Serialize)]
6pub struct AuthStatus {
7    authenticated: bool,
8    message: Option<String>,
9}
10
11#[derive(Serialize)]
12pub struct DeviceCodeInfo {
13    device_code: String, // The actual device code for polling
14    user_code: String,   // The code user enters in browser
15    verification_uri: String,
16    expires_in: u64,
17    interval: u64, // Polling interval in seconds
18}
19
20#[derive(Deserialize)]
21pub struct CompleteAuthRequest {
22    device_code: String,
23    interval: u64,
24    expires_in: u64,
25}
26
27/// Start Copilot authentication - returns device code info
28pub async fn start_copilot_auth(app_state: web::Data<AppState>) -> Result<HttpResponse, AppError> {
29    use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
30    use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
31    use std::sync::Arc;
32    use std::time::Duration;
33
34    // Get config
35    let config = app_state.config.read().await.clone();
36    let app_data_dir = app_state.app_data_dir.clone();
37
38    // Resolve headless_auth from providers.copilot, with fallback to deprecated root field.
39    let headless_auth = config
40        .providers
41        .copilot
42        .as_ref()
43        .map(|c| c.headless_auth)
44        .unwrap_or(config.headless_auth);
45
46    // Build retry client
47    let retry_policy = ExponentialBackoff::builder()
48        .retry_bounds(Duration::from_millis(100), Duration::from_secs(5))
49        .build_with_max_retries(3);
50
51    let client = match crate::agent::llm::http_client::build_http_client(&config) {
52        Ok(client) => client,
53        Err(e) => {
54            log::error!("Failed to build Copilot auth HTTP client (proxy?): {}", e);
55            return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
56                "success": false,
57                "error": format!("Failed to build HTTP client: {}", e),
58            })));
59        }
60    };
61    let client_with_middleware: Arc<ClientWithMiddleware> = Arc::new(
62        ClientBuilder::new(client.clone())
63            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
64            .build(),
65    );
66
67    // Create auth handler
68    let handler = crate::agent::llm::providers::copilot::auth::CopilotAuthHandler::new(
69        client_with_middleware,
70        app_data_dir,
71        headless_auth,
72    );
73
74    match handler.start_authentication().await {
75        Ok(device_code) => {
76            log::info!("Device code obtained: {}", device_code.user_code);
77            Ok(HttpResponse::Ok().json(DeviceCodeInfo {
78                device_code: device_code.device_code,
79                user_code: device_code.user_code,
80                verification_uri: device_code.verification_uri,
81                expires_in: device_code.expires_in,
82                interval: device_code.interval,
83            }))
84        }
85        Err(e) => {
86            log::error!("Failed to get device code: {}", e);
87            Ok(HttpResponse::InternalServerError().json(serde_json::json!({
88                "success": false,
89                "error": format!("Failed to get device code: {}", e)
90            })))
91        }
92    }
93}
94
95/// Complete Copilot authentication after user enters device code
96pub async fn complete_copilot_auth(
97    app_state: web::Data<AppState>,
98    payload: web::Json<CompleteAuthRequest>,
99) -> Result<HttpResponse, AppError> {
100    use crate::agent::llm::providers::copilot::auth::{CopilotAuthHandler, DeviceCodeResponse};
101    use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
102    use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
103    use std::sync::Arc;
104    use std::time::Duration;
105
106    // Get config
107    let config = app_state.config.read().await.clone();
108    let app_data_dir = app_state.app_data_dir.clone();
109
110    // Resolve headless_auth from providers.copilot, with fallback to deprecated root field.
111    let headless_auth = config
112        .providers
113        .copilot
114        .as_ref()
115        .map(|c| c.headless_auth)
116        .unwrap_or(config.headless_auth);
117
118    // Build retry client
119    let retry_policy = ExponentialBackoff::builder()
120        .retry_bounds(Duration::from_millis(100), Duration::from_secs(5))
121        .build_with_max_retries(3);
122
123    let client = match crate::agent::llm::http_client::build_http_client(&config) {
124        Ok(client) => client,
125        Err(e) => {
126            log::error!("Failed to build Copilot auth HTTP client (proxy?): {}", e);
127            return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
128                "success": false,
129                "error": format!("Failed to build HTTP client: {}", e),
130            })));
131        }
132    };
133    let client_with_middleware: Arc<ClientWithMiddleware> = Arc::new(
134        ClientBuilder::new(client.clone())
135            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
136            .build(),
137    );
138
139    // Create auth handler
140    let handler = CopilotAuthHandler::new(client_with_middleware, app_data_dir, headless_auth);
141
142    // Create device code response from request
143    let device_code = DeviceCodeResponse {
144        device_code: payload.device_code.clone(),
145        user_code: String::new(), // Not needed for completion
146        verification_uri: String::new(),
147        expires_in: payload.expires_in,
148        interval: payload.interval,
149    };
150
151    match handler.complete_authentication(&device_code).await {
152        Ok(_) => {
153            log::info!("Copilot authentication completed successfully");
154
155            // Reload the provider in AppState with the authenticated provider
156            app_state.reload_provider().await.map_err(|e| {
157                AppError::InternalError(anyhow::anyhow!(
158                    "Failed to reload provider after authentication: {}",
159                    e
160                ))
161            })?;
162
163            Ok(HttpResponse::Ok().json(serde_json::json!({
164                "success": true,
165                "message": "Copilot authenticated successfully"
166            })))
167        }
168        Err(e) => {
169            log::error!("Copilot authentication completion failed: {}", e);
170            Ok(HttpResponse::InternalServerError().json(serde_json::json!({
171                "success": false,
172                "error": format!("Authentication failed: {}", e)
173            })))
174        }
175    }
176}
177
178/// Trigger Copilot authentication flow (legacy, for backward compatibility)
179pub async fn authenticate_copilot(
180    app_state: web::Data<AppState>,
181) -> Result<HttpResponse, AppError> {
182    // Get the current config
183    let config = app_state.config.read().await.clone();
184    let app_data_dir = app_state.app_data_dir.clone();
185
186    // Resolve headless_auth from providers.copilot, with fallback to deprecated root field.
187    let headless_auth = config
188        .providers
189        .copilot
190        .as_ref()
191        .map(|c| c.headless_auth)
192        .unwrap_or(config.headless_auth);
193
194    // Check if provider is copilot
195    if config.provider != "copilot" {
196        return Ok(HttpResponse::BadRequest().json(serde_json::json!({
197            "success": false,
198            "error": "Current provider is not Copilot"
199        })));
200    }
201
202    // Create a Copilot provider that respects configured proxy settings.
203    let http_client = match crate::agent::llm::http_client::build_http_client(&config) {
204        Ok(client) => client,
205        Err(e) => {
206            log::error!("Failed to build Copilot HTTP client (proxy?): {}", e);
207            return Ok(HttpResponse::InternalServerError().json(serde_json::json!({
208                "success": false,
209                "error": format!("Failed to build HTTP client: {}", e),
210            })));
211        }
212    };
213    let mut provider = crate::agent::llm::providers::CopilotProvider::with_auth_handler(
214        http_client,
215        app_data_dir,
216        headless_auth,
217    );
218
219    match provider.authenticate().await {
220        Ok(_) => {
221            log::info!("Copilot authentication successful");
222
223            // Reload the provider in AppState with the authenticated provider
224            app_state.reload_provider().await.map_err(|e| {
225                AppError::InternalError(anyhow::anyhow!(
226                    "Failed to reload provider after authentication: {}",
227                    e
228                ))
229            })?;
230
231            Ok(HttpResponse::Ok().json(serde_json::json!({
232                "success": true,
233                "message": "Copilot authenticated successfully"
234            })))
235        }
236        Err(e) => {
237            log::error!("Copilot authentication failed: {}", e);
238            Ok(HttpResponse::InternalServerError().json(serde_json::json!({
239                "success": false,
240                "error": format!("Authentication failed: {}", e)
241            })))
242        }
243    }
244}
245
246/// Check Copilot authentication status
247pub async fn get_copilot_auth_status(
248    app_state: web::Data<AppState>,
249) -> Result<HttpResponse, AppError> {
250    use std::fs;
251
252    let app_data_dir = app_state.app_data_dir.clone();
253    let copilot_token_path = app_data_dir.join(".copilot_token.json");
254
255    // Try to load cached token
256    if copilot_token_path.exists() {
257        if let Ok(content) = fs::read_to_string(&copilot_token_path) {
258            if let Ok(token_data) = serde_json::from_str::<serde_json::Value>(&content) {
259                if let Some(expires_at) = token_data.get("expires_at").and_then(|v| v.as_u64()) {
260                    let now = std::time::SystemTime::now()
261                        .duration_since(std::time::UNIX_EPOCH)
262                        .unwrap_or_default()
263                        .as_secs();
264
265                    if expires_at.saturating_sub(60) > now {
266                        let remaining = expires_at.saturating_sub(now);
267                        return Ok(HttpResponse::Ok().json(AuthStatus {
268                            authenticated: true,
269                            message: Some(format!("Token expires in {} minutes", remaining / 60)),
270                        }));
271                    } else {
272                        return Ok(HttpResponse::Ok().json(AuthStatus {
273                            authenticated: false,
274                            message: Some("Token expired".to_string()),
275                        }));
276                    }
277                }
278            }
279        }
280    }
281
282    Ok(HttpResponse::Ok().json(AuthStatus {
283        authenticated: false,
284        message: Some("No cached token found".to_string()),
285    }))
286}
287
288/// Logout from Copilot (delete cached token)
289pub async fn logout_copilot(app_state: web::Data<AppState>) -> Result<HttpResponse, AppError> {
290    use std::fs;
291
292    let app_data_dir = app_state.app_data_dir.clone();
293
294    let token_path = app_data_dir.join(".token");
295    let copilot_token_path = app_data_dir.join(".copilot_token.json");
296
297    let mut success = true;
298    let mut messages = Vec::new();
299
300    if token_path.exists() {
301        match fs::remove_file(&token_path) {
302            Ok(_) => messages.push("Deleted .token".to_string()),
303            Err(e) => {
304                success = false;
305                messages.push(format!("Failed to delete .token: {}", e));
306            }
307        }
308    }
309
310    if copilot_token_path.exists() {
311        match fs::remove_file(&copilot_token_path) {
312            Ok(_) => messages.push("Deleted .copilot_token.json".to_string()),
313            Err(e) => {
314                success = false;
315                messages.push(format!("Failed to delete .copilot_token.json: {}", e));
316            }
317        }
318    }
319
320    if success {
321        log::info!("Copilot logged out successfully");
322        Ok(HttpResponse::Ok().json(serde_json::json!({
323            "success": true,
324            "message": "Logged out successfully"
325        })))
326    } else {
327        log::error!("Failed to logout: {}", messages.join(", "));
328        Ok(HttpResponse::InternalServerError().json(serde_json::json!({
329            "success": false,
330            "error": messages.join(", ")
331        })))
332    }
333}
334
335pub fn config(cfg: &mut web::ServiceConfig) {
336    cfg.route(
337        "/bamboo/copilot/auth/start",
338        web::post().to(start_copilot_auth),
339    )
340    .route(
341        "/bamboo/copilot/auth/complete",
342        web::post().to(complete_copilot_auth),
343    )
344    .route(
345        "/bamboo/copilot/authenticate",
346        web::post().to(authenticate_copilot),
347    )
348    .route(
349        "/bamboo/copilot/auth/status",
350        web::post().to(get_copilot_auth_status),
351    )
352    .route("/bamboo/copilot/logout", web::post().to(logout_copilot));
353}