Skip to main content

codetether_agent/cli/
auth.rs

1//! Provider authentication commands.
2
3use super::{AuthArgs, AuthCommand, CopilotAuthArgs, LoginAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::io::{self, Write};
12use tokio::time::{Duration, sleep};
13
/// GitHub host used for the Copilot device flow when no `--enterprise-url` is given.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay, in milliseconds, added to the polling interval before each
/// token poll in `sleep_with_margin`.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3000;
16
/// JSON body returned by the device authorization endpoint
/// (`https://{domain}/login/device/code`).
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    /// Code sent back to the token endpoint on every poll.
    device_code: String,
    /// Code printed for the user to enter ("Enter code: ...").
    user_code: String,
    /// URL printed for the user to open when `verification_uri_complete` is absent.
    verification_uri: String,
    /// Alternative URL that is printed in preference to `verification_uri` when present.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Polling interval in seconds; `request_device_code` replaces a missing
    /// or zero value with 5.
    #[serde(default)]
    interval: Option<u64>,
}
27
/// JSON body returned by the token endpoint
/// (`https://{domain}/login/oauth/access_token`) on each poll.
///
/// The poller treats a non-blank `access_token` as success and otherwise
/// inspects `error` to decide whether to keep waiting or abort.
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    /// Issued token; accepted only if it is not blank.
    #[serde(default)]
    access_token: Option<String>,
    /// OAuth error code. `authorization_pending` and `slow_down` keep the
    /// poller waiting; any other value aborts the flow.
    #[serde(default)]
    error: Option<String>,
    /// Human-readable detail included in the failure message for aborting errors.
    #[serde(default)]
    error_description: Option<String>,
    /// Replacement polling interval in seconds; read only on `slow_down`
    /// and ignored when zero.
    #[serde(default)]
    interval: Option<u64>,
}
39
40pub async fn execute(args: AuthArgs) -> Result<()> {
41    match args.command {
42        AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
43        AuthCommand::Login(login_args) => authenticate_login(login_args).await,
44    }
45}
46
47async fn authenticate_login(args: LoginAuthArgs) -> Result<()> {
48    let server_url = args.server.trim_end_matches('/').to_string();
49
50    // Prompt for email if not provided
51    let email = match args.email {
52        Some(e) => e,
53        None => {
54            print!("Email: ");
55            io::stdout().flush()?;
56            let mut email = String::new();
57            io::stdin().read_line(&mut email)?;
58            email.trim().to_string()
59        }
60    };
61
62    if email.is_empty() {
63        anyhow::bail!("Email is required");
64    }
65
66    // Prompt for password (no echo)
67    let password = rpassword_prompt("Password: ")?;
68    if password.is_empty() {
69        anyhow::bail!("Password is required");
70    }
71
72    println!("Authenticating with {}...", server_url);
73
74    let client = Client::new();
75    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
76
77    let resp = client
78        .post(format!("{}/v1/users/login", server_url))
79        .header("User-Agent", &user_agent)
80        .header("Content-Type", "application/json")
81        .json(&json!({
82            "email": email,
83            "password": password,
84        }))
85        .send()
86        .await
87        .context("Failed to connect to CodeTether server")?;
88
89    if !resp.status().is_success() {
90        let status = resp.status();
91        let body: serde_json::Value = resp.json().await.unwrap_or_default();
92        let detail = body
93            .get("detail")
94            .and_then(|v| v.as_str())
95            .unwrap_or("Authentication failed");
96        anyhow::bail!("Login failed ({}): {}", status, detail);
97    }
98
99    #[derive(Deserialize)]
100    struct LoginResponse {
101        access_token: String,
102        expires_at: String,
103        user: serde_json::Value,
104    }
105
106    let login: LoginResponse = resp
107        .json()
108        .await
109        .context("Failed to parse login response")?;
110
111    // Store token to ~/.config/codetether-agent/credentials.json
112    let cred_path = credential_file_path()?;
113    if let Some(parent) = cred_path.parent() {
114        std::fs::create_dir_all(parent)
115            .with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
116    }
117
118    let creds = json!({
119        "server": server_url,
120        "access_token": login.access_token,
121        "expires_at": login.expires_at,
122        "email": email,
123    });
124
125    // Write with restrictive permissions (owner-only read/write)
126    #[cfg(unix)]
127    {
128        use std::os::unix::fs::OpenOptionsExt;
129        let file = std::fs::OpenOptions::new()
130            .write(true)
131            .create(true)
132            .truncate(true)
133            .mode(0o600)
134            .open(&cred_path)
135            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
136        serde_json::to_writer_pretty(file, &creds)?;
137    }
138    #[cfg(not(unix))]
139    {
140        let file = std::fs::File::create(&cred_path)
141            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
142        serde_json::to_writer_pretty(file, &creds)?;
143    }
144
145    let user_email = login
146        .user
147        .get("email")
148        .and_then(|v| v.as_str())
149        .unwrap_or(&email);
150
151    println!("Logged in as {} (expires {})", user_email, login.expires_at);
152    println!("Credentials saved to {}", cred_path.display());
153    println!(
154        "\nUse the token with: export CODETETHER_TOKEN={}",
155        &login.access_token[..login.access_token.len().min(20)]
156    );
157    println!("Or run commands directly — the CLI reads credentials automatically.");
158
159    Ok(())
160}
161
162/// Read password from terminal without echo.
163fn rpassword_prompt(prompt: &str) -> Result<String> {
164    print!("{}", prompt);
165    io::stdout().flush()?;
166
167    // Disable echo on Unix
168    #[cfg(unix)]
169    {
170        use std::io::BufRead;
171        // Save terminal state
172        let fd = 0; // stdin
173        let orig = unsafe {
174            let mut termios = std::mem::zeroed::<libc::termios>();
175            libc::tcgetattr(fd, &mut termios);
176            termios
177        };
178
179        // Disable echo
180        unsafe {
181            let mut termios = orig;
182            termios.c_lflag &= !libc::ECHO;
183            libc::tcsetattr(fd, libc::TCSANOW, &termios);
184        }
185
186        let mut password = String::new();
187        let result = io::stdin().lock().read_line(&mut password);
188
189        // Restore terminal state
190        unsafe {
191            libc::tcsetattr(fd, libc::TCSANOW, &orig);
192        }
193        println!(); // newline after password entry
194
195        result?;
196        Ok(password.trim().to_string())
197    }
198
199    #[cfg(not(unix))]
200    {
201        let mut password = String::new();
202        io::stdin().read_line(&mut password)?;
203        Ok(password.trim().to_string())
204    }
205}
206
207/// Get the path to the credential storage file.
208fn credential_file_path() -> Result<std::path::PathBuf> {
209    use directories::ProjectDirs;
210    let dirs = ProjectDirs::from("ai", "codetether", "codetether-agent")
211        .ok_or_else(|| anyhow::anyhow!("Cannot determine config directory"))?;
212    Ok(dirs.config_dir().join("credentials.json"))
213}
214
215async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
216    if secrets::secrets_manager().is_none() {
217        anyhow::bail!(
218            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
219        );
220    }
221
222    let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
223        Some(raw) => {
224            let domain = normalize_enterprise_domain(&raw);
225            if domain.is_empty() {
226                anyhow::bail!("--enterprise-url cannot be empty");
227            }
228            ("github-copilot-enterprise", domain.clone(), Some(domain))
229        }
230        None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
231    };
232
233    let client = Client::new();
234    let client_id = resolve_client_id(args.client_id)?;
235    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
236    let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
237
238    println!("GitHub Copilot device authentication");
239    println!(
240        "Open this URL: {}",
241        device
242            .verification_uri_complete
243            .as_deref()
244            .unwrap_or(&device.verification_uri)
245    );
246    println!("Enter code: {}", device.user_code);
247    println!("Waiting for authorization...");
248
249    let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
250
251    let mut extra = HashMap::new();
252    if let Some(enterprise_url) = enterprise_domain {
253        extra.insert(
254            "enterpriseUrl".to_string(),
255            serde_json::Value::String(enterprise_url),
256        );
257    }
258
259    let provider_secrets = ProviderSecrets {
260        api_key: Some(token),
261        base_url: None,
262        organization: None,
263        headers: None,
264        extra,
265    };
266
267    secrets::set_provider_secrets(provider_id, &provider_secrets)
268        .await
269        .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
270
271    println!("Saved {} credentials to HashiCorp Vault.", provider_id);
272    Ok(())
273}
274
275async fn request_device_code(
276    client: &Client,
277    domain: &str,
278    user_agent: &str,
279    client_id: &str,
280) -> Result<DeviceCodeResponse> {
281    let url = format!("https://{domain}/login/device/code");
282    let response = client
283        .post(&url)
284        .header("Accept", "application/json")
285        .header("Content-Type", "application/json")
286        .header("User-Agent", user_agent)
287        .json(&json!({
288            "client_id": client_id,
289            "scope": "read:user",
290        }))
291        .send()
292        .await
293        .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
294
295    let status = response.status();
296    if !status.is_success() {
297        let body = response.text().await.unwrap_or_default();
298        anyhow::bail!(
299            "Failed to initiate device authorization ({}): {}",
300            status,
301            truncate_body(&body)
302        );
303    }
304
305    let mut device: DeviceCodeResponse = response
306        .json()
307        .await
308        .context("Failed to parse device authorization response")?;
309    if device.interval.unwrap_or(0) == 0 {
310        device.interval = Some(5);
311    }
312    Ok(device)
313}
314
315async fn poll_for_access_token(
316    client: &Client,
317    domain: &str,
318    user_agent: &str,
319    client_id: &str,
320    device: &DeviceCodeResponse,
321) -> Result<String> {
322    let url = format!("https://{domain}/login/oauth/access_token");
323    let mut interval_secs = device.interval.unwrap_or(5).max(1);
324
325    loop {
326        let response = client
327            .post(&url)
328            .header("Accept", "application/json")
329            .header("Content-Type", "application/json")
330            .header("User-Agent", user_agent)
331            .json(&json!({
332                "client_id": client_id,
333                "device_code": device.device_code,
334                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
335            }))
336            .send()
337            .await
338            .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
339
340        let status = response.status();
341        if !status.is_success() {
342            let body = response.text().await.unwrap_or_default();
343            anyhow::bail!(
344                "Failed to exchange device code for access token ({}): {}",
345                status,
346                truncate_body(&body)
347            );
348        }
349
350        let payload: AccessTokenResponse = response
351            .json()
352            .await
353            .context("Failed to parse OAuth token response")?;
354
355        if let Some(token) = payload.access_token {
356            if !token.trim().is_empty() {
357                return Ok(token);
358            }
359        }
360
361        match payload.error.as_deref() {
362            Some("authorization_pending") => sleep_with_margin(interval_secs).await,
363            Some("slow_down") => {
364                interval_secs = payload
365                    .interval
366                    .filter(|value| *value > 0)
367                    .unwrap_or(interval_secs + 5);
368                sleep_with_margin(interval_secs).await;
369            }
370            Some(error) => {
371                let description = payload
372                    .error_description
373                    .unwrap_or_else(|| "No error description provided".to_string());
374                anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
375            }
376            None => sleep_with_margin(interval_secs).await,
377        }
378    }
379}
380
381fn resolve_client_id(client_id: Option<String>) -> Result<String> {
382    let id = client_id
383        .map(|value| value.trim().to_string())
384        .filter(|value| !value.is_empty())
385        .ok_or_else(|| {
386            anyhow::anyhow!(
387                "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
388            )
389        })?;
390
391    Ok(id)
392}
393
394async fn sleep_with_margin(interval_secs: u64) {
395    sleep(Duration::from_millis(
396        interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
397    ))
398    .await;
399}
400
/// Shorten an HTTP response body for inclusion in an error message.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer ones are cut to
/// at most 300 bytes and suffixed with `...`. The cut backs off to the
/// previous UTF-8 character boundary so multi-byte text cannot cause a
/// slicing panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = MAX_LEN;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &body[..end])
}