Skip to main content

codetether_agent/cli/
auth.rs

1//! Provider authentication commands.
2
3use super::{AuthArgs, AuthCommand, CopilotAuthArgs, LoginAuthArgs, RegisterAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::io::{self, Write};
12use std::path::PathBuf;
13use tokio::time::{Duration, sleep};
14
/// Host used for the GitHub device flow when no `--enterprise-url` is supplied.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay, in milliseconds, added on top of every OAuth polling interval.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3_000;
17
18#[derive(Debug, Deserialize)]
19struct DeviceCodeResponse {
20    device_code: String,
21    user_code: String,
22    verification_uri: String,
23    #[serde(default)]
24    verification_uri_complete: Option<String>,
25    #[serde(default)]
26    interval: Option<u64>,
27}
28
29#[derive(Debug, Deserialize)]
30struct AccessTokenResponse {
31    #[serde(default)]
32    access_token: Option<String>,
33    #[serde(default)]
34    error: Option<String>,
35    #[serde(default)]
36    error_description: Option<String>,
37    #[serde(default)]
38    interval: Option<u64>,
39}
40
41pub async fn execute(args: AuthArgs) -> Result<()> {
42    match args.command {
43        AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
44        AuthCommand::Register(register_args) => authenticate_register(register_args).await,
45        AuthCommand::Login(login_args) => authenticate_login(login_args).await,
46    }
47}
48
49#[derive(Debug, Deserialize)]
50struct LoginResponsePayload {
51    access_token: String,
52    expires_at: String,
53    user: serde_json::Value,
54}
55
56async fn login_with_password(
57    client: &Client,
58    server_url: &str,
59    email: &str,
60    password: &str,
61) -> Result<LoginResponsePayload> {
62    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
63
64    let resp = client
65        .post(format!("{}/v1/users/login", server_url))
66        .header("User-Agent", &user_agent)
67        .header("Content-Type", "application/json")
68        .json(&json!({
69            "email": email,
70            "password": password,
71        }))
72        .send()
73        .await
74        .context("Failed to connect to CodeTether server")?;
75
76    if !resp.status().is_success() {
77        let status = resp.status();
78        let body: serde_json::Value = resp.json().await.unwrap_or_default();
79        let detail = body
80            .get("detail")
81            .and_then(|v| v.as_str())
82            .unwrap_or("Authentication failed");
83        anyhow::bail!("Login failed ({}): {}", status, detail);
84    }
85
86    let login: LoginResponsePayload = resp
87        .json()
88        .await
89        .context("Failed to parse login response")?;
90
91    Ok(login)
92}
93
94fn write_saved_credentials(
95    server_url: &str,
96    email: &str,
97    login: &LoginResponsePayload,
98) -> Result<PathBuf> {
99    // Store token to ~/.config/codetether-agent/credentials.json
100    let cred_path = credential_file_path()?;
101    if let Some(parent) = cred_path.parent() {
102        std::fs::create_dir_all(parent)
103            .with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
104    }
105
106    let creds = json!({
107        "server": server_url,
108        "access_token": login.access_token,
109        "expires_at": login.expires_at,
110        "email": email,
111    });
112
113    // Write with restrictive permissions (owner-only read/write)
114    #[cfg(unix)]
115    {
116        use std::os::unix::fs::OpenOptionsExt;
117        let file = std::fs::OpenOptions::new()
118            .write(true)
119            .create(true)
120            .truncate(true)
121            .mode(0o600)
122            .open(&cred_path)
123            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
124        serde_json::to_writer_pretty(file, &creds)?;
125    }
126    #[cfg(not(unix))]
127    {
128        let file = std::fs::File::create(&cred_path)
129            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
130        serde_json::to_writer_pretty(file, &creds)?;
131    }
132
133    Ok(cred_path)
134}
135
136async fn authenticate_register(args: RegisterAuthArgs) -> Result<()> {
137    #[derive(Debug, Deserialize)]
138    struct RegisterResponse {
139        user_id: String,
140        email: String,
141        message: String,
142        #[serde(default)]
143        instance_url: Option<String>,
144        #[serde(default)]
145        instance_namespace: Option<String>,
146        #[serde(default)]
147        provisioning_status: Option<String>,
148    }
149
150    let server_url = args.server.trim_end_matches('/').to_string();
151
152    let email = match args.email {
153        Some(e) => e,
154        None => {
155            print!("Email: ");
156            io::stdout().flush()?;
157            let mut email = String::new();
158            io::stdin().read_line(&mut email)?;
159            email.trim().to_string()
160        }
161    };
162
163    if email.is_empty() {
164        anyhow::bail!("Email is required");
165    }
166
167    let password = rpassword_prompt("Password (min 8 chars): ")?;
168    if password.trim().len() < 8 {
169        anyhow::bail!("Password must be at least 8 characters");
170    }
171    let confirm = rpassword_prompt("Confirm password: ")?;
172    if password != confirm {
173        anyhow::bail!("Passwords do not match");
174    }
175
176    println!("Registering with {}...", server_url);
177
178    let client = Client::new();
179    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
180
181    let resp = client
182        .post(format!("{}/v1/users/register", server_url))
183        .header("User-Agent", &user_agent)
184        .header("Content-Type", "application/json")
185        .json(&json!({
186            "email": email,
187            "password": password,
188            "first_name": args.first_name,
189            "last_name": args.last_name,
190            "referral_source": args.referral_source,
191        }))
192        .send()
193        .await
194        .context("Failed to connect to CodeTether server")?;
195
196    if !resp.status().is_success() {
197        let status = resp.status();
198        let body: serde_json::Value = resp.json().await.unwrap_or_default();
199        let detail = body
200            .get("detail")
201            .and_then(|v| v.as_str())
202            .unwrap_or("Registration failed");
203        anyhow::bail!("Registration failed ({}): {}", status, detail);
204    }
205
206    let reg: RegisterResponse = resp
207        .json()
208        .await
209        .context("Failed to parse registration response")?;
210
211    println!(
212        "Account created for {} (user_id={})",
213        reg.email, reg.user_id
214    );
215    println!("{}", reg.message);
216    if let Some(status) = reg.provisioning_status.as_deref() {
217        println!("Provisioning status: {}", status);
218    }
219    if let Some(url) = reg.instance_url.as_deref() {
220        println!("Instance URL: {}", url);
221    }
222    if let Some(ns) = reg.instance_namespace.as_deref() {
223        println!("Instance namespace: {}", ns);
224    }
225
226    // Auto-login and save credentials for the worker.
227    println!("Logging in...");
228    let login = login_with_password(&client, &server_url, &reg.email, &password).await?;
229    let cred_path = write_saved_credentials(&server_url, &reg.email, &login)?;
230
231    let user_email = login
232        .user
233        .get("email")
234        .and_then(|v| v.as_str())
235        .unwrap_or(&reg.email);
236
237    println!("Logged in as {} (expires {})", user_email, login.expires_at);
238    println!("Credentials saved to {}", cred_path.display());
239    println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
240
241    Ok(())
242}
243
244async fn authenticate_login(args: LoginAuthArgs) -> Result<()> {
245    let server_url = args.server.trim_end_matches('/').to_string();
246
247    // Prompt for email if not provided
248    let email = match args.email {
249        Some(e) => e,
250        None => {
251            print!("Email: ");
252            io::stdout().flush()?;
253            let mut email = String::new();
254            io::stdin().read_line(&mut email)?;
255            email.trim().to_string()
256        }
257    };
258
259    if email.is_empty() {
260        anyhow::bail!("Email is required");
261    }
262
263    // Prompt for password (no echo)
264    let password = rpassword_prompt("Password: ")?;
265    if password.is_empty() {
266        anyhow::bail!("Password is required");
267    }
268
269    println!("Authenticating with {}...", server_url);
270
271    let client = Client::new();
272
273    let login = login_with_password(&client, &server_url, &email, &password).await?;
274    let cred_path = write_saved_credentials(&server_url, &email, &login)?;
275
276    let user_email = login
277        .user
278        .get("email")
279        .and_then(|v| v.as_str())
280        .unwrap_or(&email);
281
282    println!("Logged in as {} (expires {})", user_email, login.expires_at);
283    println!("Credentials saved to {}", cred_path.display());
284    println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
285
286    Ok(())
287}
288
289/// Read password from terminal without echo.
290fn rpassword_prompt(prompt: &str) -> Result<String> {
291    print!("{}", prompt);
292    io::stdout().flush()?;
293
294    // Disable echo on Unix
295    #[cfg(unix)]
296    {
297        use std::io::BufRead;
298        // Save terminal state
299        let fd = 0; // stdin
300        let orig = unsafe {
301            let mut termios = std::mem::zeroed::<libc::termios>();
302            libc::tcgetattr(fd, &mut termios);
303            termios
304        };
305
306        // Disable echo
307        unsafe {
308            let mut termios = orig;
309            termios.c_lflag &= !libc::ECHO;
310            libc::tcsetattr(fd, libc::TCSANOW, &termios);
311        }
312
313        let mut password = String::new();
314        let result = io::stdin().lock().read_line(&mut password);
315
316        // Restore terminal state
317        unsafe {
318            libc::tcsetattr(fd, libc::TCSANOW, &orig);
319        }
320        println!(); // newline after password entry
321
322        result?;
323        Ok(password.trim().to_string())
324    }
325
326    #[cfg(not(unix))]
327    {
328        let mut password = String::new();
329        io::stdin().read_line(&mut password)?;
330        Ok(password.trim().to_string())
331    }
332}
333
334/// Get the path to the credential storage file.
335fn credential_file_path() -> Result<std::path::PathBuf> {
336    use directories::ProjectDirs;
337    let dirs = ProjectDirs::from("ai", "codetether", "codetether-agent")
338        .ok_or_else(|| anyhow::anyhow!("Cannot determine config directory"))?;
339    Ok(dirs.config_dir().join("credentials.json"))
340}
341
342/// Stored credentials from `codetether auth login`.
343#[derive(Debug, Deserialize)]
344pub struct SavedCredentials {
345    pub server: String,
346    pub access_token: String,
347    pub expires_at: String,
348    #[serde(default)]
349    pub email: String,
350}
351
352/// Load saved credentials from disk, returning `None` if the file doesn't exist,
353/// is malformed, or the token has expired.
354pub fn load_saved_credentials() -> Option<SavedCredentials> {
355    let path = credential_file_path().ok()?;
356    let data = std::fs::read_to_string(&path).ok()?;
357    let creds: SavedCredentials = serde_json::from_str(&data).ok()?;
358
359    // Check expiry if parseable
360    if let Ok(expires) = chrono::DateTime::parse_from_rfc3339(&creds.expires_at) {
361        if expires < chrono::Utc::now() {
362            tracing::warn!(
363                "Saved credentials have expired — run `codetether auth login` to refresh"
364            );
365            return None;
366        }
367    }
368
369    Some(creds)
370}
371
372async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
373    if secrets::secrets_manager().is_none() {
374        anyhow::bail!(
375            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
376        );
377    }
378
379    let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
380        Some(raw) => {
381            let domain = normalize_enterprise_domain(&raw);
382            if domain.is_empty() {
383                anyhow::bail!("--enterprise-url cannot be empty");
384            }
385            ("github-copilot-enterprise", domain.clone(), Some(domain))
386        }
387        None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
388    };
389
390    let client = Client::new();
391    let client_id = resolve_client_id(args.client_id)?;
392    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
393    let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
394
395    println!("GitHub Copilot device authentication");
396    println!(
397        "Open this URL: {}",
398        device
399            .verification_uri_complete
400            .as_deref()
401            .unwrap_or(&device.verification_uri)
402    );
403    println!("Enter code: {}", device.user_code);
404    println!("Waiting for authorization...");
405
406    let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
407
408    let mut extra = HashMap::new();
409    if let Some(enterprise_url) = enterprise_domain {
410        extra.insert(
411            "enterpriseUrl".to_string(),
412            serde_json::Value::String(enterprise_url),
413        );
414    }
415
416    let provider_secrets = ProviderSecrets {
417        api_key: Some(token),
418        base_url: None,
419        organization: None,
420        headers: None,
421        extra,
422    };
423
424    secrets::set_provider_secrets(provider_id, &provider_secrets)
425        .await
426        .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
427
428    println!("Saved {} credentials to HashiCorp Vault.", provider_id);
429    Ok(())
430}
431
432async fn request_device_code(
433    client: &Client,
434    domain: &str,
435    user_agent: &str,
436    client_id: &str,
437) -> Result<DeviceCodeResponse> {
438    let url = format!("https://{domain}/login/device/code");
439    let response = client
440        .post(&url)
441        .header("Accept", "application/json")
442        .header("Content-Type", "application/json")
443        .header("User-Agent", user_agent)
444        .json(&json!({
445            "client_id": client_id,
446            "scope": "read:user",
447        }))
448        .send()
449        .await
450        .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
451
452    let status = response.status();
453    if !status.is_success() {
454        let body = response.text().await.unwrap_or_default();
455        anyhow::bail!(
456            "Failed to initiate device authorization ({}): {}",
457            status,
458            truncate_body(&body)
459        );
460    }
461
462    let mut device: DeviceCodeResponse = response
463        .json()
464        .await
465        .context("Failed to parse device authorization response")?;
466    if device.interval.unwrap_or(0) == 0 {
467        device.interval = Some(5);
468    }
469    Ok(device)
470}
471
472async fn poll_for_access_token(
473    client: &Client,
474    domain: &str,
475    user_agent: &str,
476    client_id: &str,
477    device: &DeviceCodeResponse,
478) -> Result<String> {
479    let url = format!("https://{domain}/login/oauth/access_token");
480    let mut interval_secs = device.interval.unwrap_or(5).max(1);
481
482    loop {
483        let response = client
484            .post(&url)
485            .header("Accept", "application/json")
486            .header("Content-Type", "application/json")
487            .header("User-Agent", user_agent)
488            .json(&json!({
489                "client_id": client_id,
490                "device_code": device.device_code,
491                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
492            }))
493            .send()
494            .await
495            .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
496
497        let status = response.status();
498        if !status.is_success() {
499            let body = response.text().await.unwrap_or_default();
500            anyhow::bail!(
501                "Failed to exchange device code for access token ({}): {}",
502                status,
503                truncate_body(&body)
504            );
505        }
506
507        let payload: AccessTokenResponse = response
508            .json()
509            .await
510            .context("Failed to parse OAuth token response")?;
511
512        if let Some(token) = payload.access_token {
513            if !token.trim().is_empty() {
514                return Ok(token);
515            }
516        }
517
518        match payload.error.as_deref() {
519            Some("authorization_pending") => sleep_with_margin(interval_secs).await,
520            Some("slow_down") => {
521                interval_secs = payload
522                    .interval
523                    .filter(|value| *value > 0)
524                    .unwrap_or(interval_secs + 5);
525                sleep_with_margin(interval_secs).await;
526            }
527            Some(error) => {
528                let description = payload
529                    .error_description
530                    .unwrap_or_else(|| "No error description provided".to_string());
531                anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
532            }
533            None => sleep_with_margin(interval_secs).await,
534        }
535    }
536}
537
538fn resolve_client_id(client_id: Option<String>) -> Result<String> {
539    let id = client_id
540        .map(|value| value.trim().to_string())
541        .filter(|value| !value.is_empty())
542        .ok_or_else(|| {
543            anyhow::anyhow!(
544                "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
545            )
546        })?;
547
548    Ok(id)
549}
550
551async fn sleep_with_margin(interval_secs: u64) {
552    sleep(Duration::from_millis(
553        interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
554    ))
555    .await;
556}
557
/// Shorten an HTTP response body for inclusion in an error message.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer bodies are cut
/// to at most 300 bytes and suffixed with `...`. The cut is moved back to the
/// nearest UTF-8 character boundary, so multi-byte characters never cause a
/// slicing panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = MAX_LEN;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &body[..end])
}