Skip to main content

codetether_agent/cli/
auth.rs

//! Authentication commands: CodeTether account registration and login, and GitHub Copilot device-flow authentication.
2
3use super::{AuthArgs, AuthCommand, CopilotAuthArgs, LoginAuthArgs, RegisterAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::io::{self, Write};
12use std::path::PathBuf;
13use tokio::time::{Duration, sleep};
14
/// GitHub host used for the device flow when no `--enterprise-url` is given.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay, in milliseconds, added on top of every polling interval
/// while waiting for the user to complete the OAuth device flow.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3_000;
17
18#[derive(Debug, Deserialize)]
19struct DeviceCodeResponse {
20    device_code: String,
21    user_code: String,
22    verification_uri: String,
23    #[serde(default)]
24    verification_uri_complete: Option<String>,
25    #[serde(default)]
26    interval: Option<u64>,
27}
28
29#[derive(Debug, Deserialize)]
30struct AccessTokenResponse {
31    #[serde(default)]
32    access_token: Option<String>,
33    #[serde(default)]
34    error: Option<String>,
35    #[serde(default)]
36    error_description: Option<String>,
37    #[serde(default)]
38    interval: Option<u64>,
39}
40
41pub async fn execute(args: AuthArgs) -> Result<()> {
42    match args.command {
43        AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
44        AuthCommand::Register(register_args) => authenticate_register(register_args).await,
45        AuthCommand::Login(login_args) => authenticate_login(login_args).await,
46    }
47}
48
49#[derive(Debug, Deserialize)]
50struct LoginResponsePayload {
51    access_token: String,
52    expires_at: String,
53    user: serde_json::Value,
54}
55
56async fn login_with_password(
57    client: &Client,
58    server_url: &str,
59    email: &str,
60    password: &str,
61) -> Result<LoginResponsePayload> {
62    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
63
64    let resp = client
65        .post(format!("{}/v1/users/login", server_url))
66        .header("User-Agent", &user_agent)
67        .header("Content-Type", "application/json")
68        .json(&json!({
69            "email": email,
70            "password": password,
71        }))
72        .send()
73        .await
74        .context("Failed to connect to CodeTether server")?;
75
76    if !resp.status().is_success() {
77        let status = resp.status();
78        let body: serde_json::Value = resp.json().await.unwrap_or_default();
79        let detail = body
80            .get("detail")
81            .and_then(|v| v.as_str())
82            .unwrap_or("Authentication failed");
83        anyhow::bail!("Login failed ({}): {}", status, detail);
84    }
85
86    let login: LoginResponsePayload = resp
87        .json()
88        .await
89        .context("Failed to parse login response")?;
90
91    Ok(login)
92}
93
94fn write_saved_credentials(
95    server_url: &str,
96    email: &str,
97    login: &LoginResponsePayload,
98) -> Result<PathBuf> {
99    // Store token to ~/.config/codetether-agent/credentials.json
100    let cred_path = credential_file_path()?;
101    if let Some(parent) = cred_path.parent() {
102        std::fs::create_dir_all(parent)
103            .with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
104    }
105
106    let creds = json!({
107        "server": server_url,
108        "access_token": login.access_token,
109        "expires_at": login.expires_at,
110        "email": email,
111    });
112
113    // Write with restrictive permissions (owner-only read/write)
114    #[cfg(unix)]
115    {
116        use std::os::unix::fs::OpenOptionsExt;
117        let file = std::fs::OpenOptions::new()
118            .write(true)
119            .create(true)
120            .truncate(true)
121            .mode(0o600)
122            .open(&cred_path)
123            .with_context(|| {
124                format!("Failed to write credentials to {}", cred_path.display())
125            })?;
126        serde_json::to_writer_pretty(file, &creds)?;
127    }
128    #[cfg(not(unix))]
129    {
130        let file = std::fs::File::create(&cred_path)
131            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
132        serde_json::to_writer_pretty(file, &creds)?;
133    }
134
135    Ok(cred_path)
136}
137
138async fn authenticate_register(args: RegisterAuthArgs) -> Result<()> {
139    #[derive(Debug, Deserialize)]
140    struct RegisterResponse {
141        user_id: String,
142        email: String,
143        message: String,
144        #[serde(default)]
145        instance_url: Option<String>,
146        #[serde(default)]
147        instance_namespace: Option<String>,
148        #[serde(default)]
149        provisioning_status: Option<String>,
150    }
151
152    let server_url = args.server.trim_end_matches('/').to_string();
153
154    let email = match args.email {
155        Some(e) => e,
156        None => {
157            print!("Email: ");
158            io::stdout().flush()?;
159            let mut email = String::new();
160            io::stdin().read_line(&mut email)?;
161            email.trim().to_string()
162        }
163    };
164
165    if email.is_empty() {
166        anyhow::bail!("Email is required");
167    }
168
169    let password = rpassword_prompt("Password (min 8 chars): ")?;
170    if password.trim().len() < 8 {
171        anyhow::bail!("Password must be at least 8 characters");
172    }
173    let confirm = rpassword_prompt("Confirm password: ")?;
174    if password != confirm {
175        anyhow::bail!("Passwords do not match");
176    }
177
178    println!("Registering with {}...", server_url);
179
180    let client = Client::new();
181    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
182
183    let resp = client
184        .post(format!("{}/v1/users/register", server_url))
185        .header("User-Agent", &user_agent)
186        .header("Content-Type", "application/json")
187        .json(&json!({
188            "email": email,
189            "password": password,
190            "first_name": args.first_name,
191            "last_name": args.last_name,
192            "referral_source": args.referral_source,
193        }))
194        .send()
195        .await
196        .context("Failed to connect to CodeTether server")?;
197
198    if !resp.status().is_success() {
199        let status = resp.status();
200        let body: serde_json::Value = resp.json().await.unwrap_or_default();
201        let detail = body
202            .get("detail")
203            .and_then(|v| v.as_str())
204            .unwrap_or("Registration failed");
205        anyhow::bail!("Registration failed ({}): {}", status, detail);
206    }
207
208    let reg: RegisterResponse = resp
209        .json()
210        .await
211        .context("Failed to parse registration response")?;
212
213    println!("Account created for {} (user_id={})", reg.email, reg.user_id);
214    println!("{}", reg.message);
215    if let Some(status) = reg.provisioning_status.as_deref() {
216        println!("Provisioning status: {}", status);
217    }
218    if let Some(url) = reg.instance_url.as_deref() {
219        println!("Instance URL: {}", url);
220    }
221    if let Some(ns) = reg.instance_namespace.as_deref() {
222        println!("Instance namespace: {}", ns);
223    }
224
225    // Auto-login and save credentials for the worker.
226    println!("Logging in...");
227    let login = login_with_password(&client, &server_url, &reg.email, &password).await?;
228    let cred_path = write_saved_credentials(&server_url, &reg.email, &login)?;
229
230    let user_email = login
231        .user
232        .get("email")
233        .and_then(|v| v.as_str())
234        .unwrap_or(&reg.email);
235
236    println!("Logged in as {} (expires {})", user_email, login.expires_at);
237    println!("Credentials saved to {}", cred_path.display());
238    println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
239
240    Ok(())
241}
242
243async fn authenticate_login(args: LoginAuthArgs) -> Result<()> {
244    let server_url = args.server.trim_end_matches('/').to_string();
245
246    // Prompt for email if not provided
247    let email = match args.email {
248        Some(e) => e,
249        None => {
250            print!("Email: ");
251            io::stdout().flush()?;
252            let mut email = String::new();
253            io::stdin().read_line(&mut email)?;
254            email.trim().to_string()
255        }
256    };
257
258    if email.is_empty() {
259        anyhow::bail!("Email is required");
260    }
261
262    // Prompt for password (no echo)
263    let password = rpassword_prompt("Password: ")?;
264    if password.is_empty() {
265        anyhow::bail!("Password is required");
266    }
267
268    println!("Authenticating with {}...", server_url);
269
270    let client = Client::new();
271
272    let login = login_with_password(&client, &server_url, &email, &password).await?;
273    let cred_path = write_saved_credentials(&server_url, &email, &login)?;
274
275    let user_email = login
276        .user
277        .get("email")
278        .and_then(|v| v.as_str())
279        .unwrap_or(&email);
280
281    println!("Logged in as {} (expires {})", user_email, login.expires_at);
282    println!("Credentials saved to {}", cred_path.display());
283    println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
284
285    Ok(())
286}
287
288/// Read password from terminal without echo.
289fn rpassword_prompt(prompt: &str) -> Result<String> {
290    print!("{}", prompt);
291    io::stdout().flush()?;
292
293    // Disable echo on Unix
294    #[cfg(unix)]
295    {
296        use std::io::BufRead;
297        // Save terminal state
298        let fd = 0; // stdin
299        let orig = unsafe {
300            let mut termios = std::mem::zeroed::<libc::termios>();
301            libc::tcgetattr(fd, &mut termios);
302            termios
303        };
304
305        // Disable echo
306        unsafe {
307            let mut termios = orig;
308            termios.c_lflag &= !libc::ECHO;
309            libc::tcsetattr(fd, libc::TCSANOW, &termios);
310        }
311
312        let mut password = String::new();
313        let result = io::stdin().lock().read_line(&mut password);
314
315        // Restore terminal state
316        unsafe {
317            libc::tcsetattr(fd, libc::TCSANOW, &orig);
318        }
319        println!(); // newline after password entry
320
321        result?;
322        Ok(password.trim().to_string())
323    }
324
325    #[cfg(not(unix))]
326    {
327        let mut password = String::new();
328        io::stdin().read_line(&mut password)?;
329        Ok(password.trim().to_string())
330    }
331}
332
333/// Get the path to the credential storage file.
334fn credential_file_path() -> Result<std::path::PathBuf> {
335    use directories::ProjectDirs;
336    let dirs = ProjectDirs::from("ai", "codetether", "codetether-agent")
337        .ok_or_else(|| anyhow::anyhow!("Cannot determine config directory"))?;
338    Ok(dirs.config_dir().join("credentials.json"))
339}
340
341/// Stored credentials from `codetether auth login`.
342#[derive(Debug, Deserialize)]
343pub struct SavedCredentials {
344    pub server: String,
345    pub access_token: String,
346    pub expires_at: String,
347    #[serde(default)]
348    pub email: String,
349}
350
351/// Load saved credentials from disk, returning `None` if the file doesn't exist,
352/// is malformed, or the token has expired.
353pub fn load_saved_credentials() -> Option<SavedCredentials> {
354    let path = credential_file_path().ok()?;
355    let data = std::fs::read_to_string(&path).ok()?;
356    let creds: SavedCredentials = serde_json::from_str(&data).ok()?;
357
358    // Check expiry if parseable
359    if let Ok(expires) = chrono::DateTime::parse_from_rfc3339(&creds.expires_at) {
360        if expires < chrono::Utc::now() {
361            tracing::warn!("Saved credentials have expired — run `codetether auth login` to refresh");
362            return None;
363        }
364    }
365
366    Some(creds)
367}
368
369async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
370    if secrets::secrets_manager().is_none() {
371        anyhow::bail!(
372            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
373        );
374    }
375
376    let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
377        Some(raw) => {
378            let domain = normalize_enterprise_domain(&raw);
379            if domain.is_empty() {
380                anyhow::bail!("--enterprise-url cannot be empty");
381            }
382            ("github-copilot-enterprise", domain.clone(), Some(domain))
383        }
384        None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
385    };
386
387    let client = Client::new();
388    let client_id = resolve_client_id(args.client_id)?;
389    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
390    let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
391
392    println!("GitHub Copilot device authentication");
393    println!(
394        "Open this URL: {}",
395        device
396            .verification_uri_complete
397            .as_deref()
398            .unwrap_or(&device.verification_uri)
399    );
400    println!("Enter code: {}", device.user_code);
401    println!("Waiting for authorization...");
402
403    let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
404
405    let mut extra = HashMap::new();
406    if let Some(enterprise_url) = enterprise_domain {
407        extra.insert(
408            "enterpriseUrl".to_string(),
409            serde_json::Value::String(enterprise_url),
410        );
411    }
412
413    let provider_secrets = ProviderSecrets {
414        api_key: Some(token),
415        base_url: None,
416        organization: None,
417        headers: None,
418        extra,
419    };
420
421    secrets::set_provider_secrets(provider_id, &provider_secrets)
422        .await
423        .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
424
425    println!("Saved {} credentials to HashiCorp Vault.", provider_id);
426    Ok(())
427}
428
429async fn request_device_code(
430    client: &Client,
431    domain: &str,
432    user_agent: &str,
433    client_id: &str,
434) -> Result<DeviceCodeResponse> {
435    let url = format!("https://{domain}/login/device/code");
436    let response = client
437        .post(&url)
438        .header("Accept", "application/json")
439        .header("Content-Type", "application/json")
440        .header("User-Agent", user_agent)
441        .json(&json!({
442            "client_id": client_id,
443            "scope": "read:user",
444        }))
445        .send()
446        .await
447        .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
448
449    let status = response.status();
450    if !status.is_success() {
451        let body = response.text().await.unwrap_or_default();
452        anyhow::bail!(
453            "Failed to initiate device authorization ({}): {}",
454            status,
455            truncate_body(&body)
456        );
457    }
458
459    let mut device: DeviceCodeResponse = response
460        .json()
461        .await
462        .context("Failed to parse device authorization response")?;
463    if device.interval.unwrap_or(0) == 0 {
464        device.interval = Some(5);
465    }
466    Ok(device)
467}
468
469async fn poll_for_access_token(
470    client: &Client,
471    domain: &str,
472    user_agent: &str,
473    client_id: &str,
474    device: &DeviceCodeResponse,
475) -> Result<String> {
476    let url = format!("https://{domain}/login/oauth/access_token");
477    let mut interval_secs = device.interval.unwrap_or(5).max(1);
478
479    loop {
480        let response = client
481            .post(&url)
482            .header("Accept", "application/json")
483            .header("Content-Type", "application/json")
484            .header("User-Agent", user_agent)
485            .json(&json!({
486                "client_id": client_id,
487                "device_code": device.device_code,
488                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
489            }))
490            .send()
491            .await
492            .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
493
494        let status = response.status();
495        if !status.is_success() {
496            let body = response.text().await.unwrap_or_default();
497            anyhow::bail!(
498                "Failed to exchange device code for access token ({}): {}",
499                status,
500                truncate_body(&body)
501            );
502        }
503
504        let payload: AccessTokenResponse = response
505            .json()
506            .await
507            .context("Failed to parse OAuth token response")?;
508
509        if let Some(token) = payload.access_token {
510            if !token.trim().is_empty() {
511                return Ok(token);
512            }
513        }
514
515        match payload.error.as_deref() {
516            Some("authorization_pending") => sleep_with_margin(interval_secs).await,
517            Some("slow_down") => {
518                interval_secs = payload
519                    .interval
520                    .filter(|value| *value > 0)
521                    .unwrap_or(interval_secs + 5);
522                sleep_with_margin(interval_secs).await;
523            }
524            Some(error) => {
525                let description = payload
526                    .error_description
527                    .unwrap_or_else(|| "No error description provided".to_string());
528                anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
529            }
530            None => sleep_with_margin(interval_secs).await,
531        }
532    }
533}
534
535fn resolve_client_id(client_id: Option<String>) -> Result<String> {
536    let id = client_id
537        .map(|value| value.trim().to_string())
538        .filter(|value| !value.is_empty())
539        .ok_or_else(|| {
540            anyhow::anyhow!(
541                "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
542            )
543        })?;
544
545    Ok(id)
546}
547
548async fn sleep_with_margin(interval_secs: u64) {
549    sleep(Duration::from_millis(
550        interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
551    ))
552    .await;
553}
554
/// Shorten a response body for inclusion in an error message.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer ones are cut to
/// at most 300 bytes and suffixed with `...`. The cut backs off to the
/// nearest preceding UTF-8 character boundary, so a multi-byte character is
/// never split.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Slicing at a byte offset inside a multi-byte character panics, so step
    // back to a char boundary (offset 0 always is one, so this terminates).
    let mut cut = MAX_LEN;
    while !body.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}...", &body[..cut])
}