1use super::{AuthArgs, AuthCommand, CopilotAuthArgs, LoginAuthArgs, RegisterAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::io::{self, Write};
12use std::path::PathBuf;
13use tokio::time::{Duration, sleep};
14
/// Host used for the Copilot device flow when no `--enterprise-url` is supplied.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";

/// Extra milliseconds added on top of every polling interval in `sleep_with_margin`.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3_000;
17
18#[derive(Debug, Deserialize)]
19struct DeviceCodeResponse {
20 device_code: String,
21 user_code: String,
22 verification_uri: String,
23 #[serde(default)]
24 verification_uri_complete: Option<String>,
25 #[serde(default)]
26 interval: Option<u64>,
27}
28
29#[derive(Debug, Deserialize)]
30struct AccessTokenResponse {
31 #[serde(default)]
32 access_token: Option<String>,
33 #[serde(default)]
34 error: Option<String>,
35 #[serde(default)]
36 error_description: Option<String>,
37 #[serde(default)]
38 interval: Option<u64>,
39}
40
41pub async fn execute(args: AuthArgs) -> Result<()> {
42 match args.command {
43 AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
44 AuthCommand::Register(register_args) => authenticate_register(register_args).await,
45 AuthCommand::Login(login_args) => authenticate_login(login_args).await,
46 }
47}
48
49#[derive(Debug, Deserialize)]
50struct LoginResponsePayload {
51 access_token: String,
52 expires_at: String,
53 user: serde_json::Value,
54}
55
56async fn login_with_password(
57 client: &Client,
58 server_url: &str,
59 email: &str,
60 password: &str,
61) -> Result<LoginResponsePayload> {
62 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
63
64 let resp = client
65 .post(format!("{}/v1/users/login", server_url))
66 .header("User-Agent", &user_agent)
67 .header("Content-Type", "application/json")
68 .json(&json!({
69 "email": email,
70 "password": password,
71 }))
72 .send()
73 .await
74 .context("Failed to connect to CodeTether server")?;
75
76 if !resp.status().is_success() {
77 let status = resp.status();
78 let body: serde_json::Value = resp.json().await.unwrap_or_default();
79 let detail = body
80 .get("detail")
81 .and_then(|v| v.as_str())
82 .unwrap_or("Authentication failed");
83 anyhow::bail!("Login failed ({}): {}", status, detail);
84 }
85
86 let login: LoginResponsePayload = resp
87 .json()
88 .await
89 .context("Failed to parse login response")?;
90
91 Ok(login)
92}
93
94fn write_saved_credentials(
95 server_url: &str,
96 email: &str,
97 login: &LoginResponsePayload,
98) -> Result<PathBuf> {
99 let cred_path = credential_file_path()?;
101 if let Some(parent) = cred_path.parent() {
102 std::fs::create_dir_all(parent)
103 .with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
104 }
105
106 let creds = json!({
107 "server": server_url,
108 "access_token": login.access_token,
109 "expires_at": login.expires_at,
110 "email": email,
111 });
112
113 #[cfg(unix)]
115 {
116 use std::os::unix::fs::OpenOptionsExt;
117 let file = std::fs::OpenOptions::new()
118 .write(true)
119 .create(true)
120 .truncate(true)
121 .mode(0o600)
122 .open(&cred_path)
123 .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
124 serde_json::to_writer_pretty(file, &creds)?;
125 }
126 #[cfg(not(unix))]
127 {
128 let file = std::fs::File::create(&cred_path)
129 .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
130 serde_json::to_writer_pretty(file, &creds)?;
131 }
132
133 Ok(cred_path)
134}
135
136async fn authenticate_register(args: RegisterAuthArgs) -> Result<()> {
137 #[derive(Debug, Deserialize)]
138 struct RegisterResponse {
139 user_id: String,
140 email: String,
141 message: String,
142 #[serde(default)]
143 instance_url: Option<String>,
144 #[serde(default)]
145 instance_namespace: Option<String>,
146 #[serde(default)]
147 provisioning_status: Option<String>,
148 }
149
150 let server_url = args.server.trim_end_matches('/').to_string();
151
152 let email = match args.email {
153 Some(e) => e,
154 None => {
155 print!("Email: ");
156 io::stdout().flush()?;
157 let mut email = String::new();
158 io::stdin().read_line(&mut email)?;
159 email.trim().to_string()
160 }
161 };
162
163 if email.is_empty() {
164 anyhow::bail!("Email is required");
165 }
166
167 let password = rpassword_prompt("Password (min 8 chars): ")?;
168 if password.trim().len() < 8 {
169 anyhow::bail!("Password must be at least 8 characters");
170 }
171 let confirm = rpassword_prompt("Confirm password: ")?;
172 if password != confirm {
173 anyhow::bail!("Passwords do not match");
174 }
175
176 println!("Registering with {}...", server_url);
177
178 let client = Client::new();
179 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
180
181 let resp = client
182 .post(format!("{}/v1/users/register", server_url))
183 .header("User-Agent", &user_agent)
184 .header("Content-Type", "application/json")
185 .json(&json!({
186 "email": email,
187 "password": password,
188 "first_name": args.first_name,
189 "last_name": args.last_name,
190 "referral_source": args.referral_source,
191 }))
192 .send()
193 .await
194 .context("Failed to connect to CodeTether server")?;
195
196 if !resp.status().is_success() {
197 let status = resp.status();
198 let body: serde_json::Value = resp.json().await.unwrap_or_default();
199 let detail = body
200 .get("detail")
201 .and_then(|v| v.as_str())
202 .unwrap_or("Registration failed");
203 anyhow::bail!("Registration failed ({}): {}", status, detail);
204 }
205
206 let reg: RegisterResponse = resp
207 .json()
208 .await
209 .context("Failed to parse registration response")?;
210
211 println!(
212 "Account created for {} (user_id={})",
213 reg.email, reg.user_id
214 );
215 println!("{}", reg.message);
216 if let Some(status) = reg.provisioning_status.as_deref() {
217 println!("Provisioning status: {}", status);
218 }
219 if let Some(url) = reg.instance_url.as_deref() {
220 println!("Instance URL: {}", url);
221 }
222 if let Some(ns) = reg.instance_namespace.as_deref() {
223 println!("Instance namespace: {}", ns);
224 }
225
226 println!("Logging in...");
228 let login = login_with_password(&client, &server_url, ®.email, &password).await?;
229 let cred_path = write_saved_credentials(&server_url, ®.email, &login)?;
230
231 let user_email = login
232 .user
233 .get("email")
234 .and_then(|v| v.as_str())
235 .unwrap_or(®.email);
236
237 println!("Logged in as {} (expires {})", user_email, login.expires_at);
238 println!("Credentials saved to {}", cred_path.display());
239 println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
240
241 Ok(())
242}
243
244async fn authenticate_login(args: LoginAuthArgs) -> Result<()> {
245 let server_url = args.server.trim_end_matches('/').to_string();
246
247 let email = match args.email {
249 Some(e) => e,
250 None => {
251 print!("Email: ");
252 io::stdout().flush()?;
253 let mut email = String::new();
254 io::stdin().read_line(&mut email)?;
255 email.trim().to_string()
256 }
257 };
258
259 if email.is_empty() {
260 anyhow::bail!("Email is required");
261 }
262
263 let password = rpassword_prompt("Password: ")?;
265 if password.is_empty() {
266 anyhow::bail!("Password is required");
267 }
268
269 println!("Authenticating with {}...", server_url);
270
271 let client = Client::new();
272
273 let login = login_with_password(&client, &server_url, &email, &password).await?;
274 let cred_path = write_saved_credentials(&server_url, &email, &login)?;
275
276 let user_email = login
277 .user
278 .get("email")
279 .and_then(|v| v.as_str())
280 .unwrap_or(&email);
281
282 println!("Logged in as {} (expires {})", user_email, login.expires_at);
283 println!("Credentials saved to {}", cred_path.display());
284 println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
285
286 Ok(())
287}
288
289fn rpassword_prompt(prompt: &str) -> Result<String> {
291 print!("{}", prompt);
292 io::stdout().flush()?;
293
294 #[cfg(unix)]
296 {
297 use std::io::BufRead;
298 let fd = 0; let orig = unsafe {
301 let mut termios = std::mem::zeroed::<libc::termios>();
302 libc::tcgetattr(fd, &mut termios);
303 termios
304 };
305
306 unsafe {
308 let mut termios = orig;
309 termios.c_lflag &= !libc::ECHO;
310 libc::tcsetattr(fd, libc::TCSANOW, &termios);
311 }
312
313 let mut password = String::new();
314 let result = io::stdin().lock().read_line(&mut password);
315
316 unsafe {
318 libc::tcsetattr(fd, libc::TCSANOW, &orig);
319 }
320 println!(); result?;
323 Ok(password.trim().to_string())
324 }
325
326 #[cfg(not(unix))]
327 {
328 let mut password = String::new();
329 io::stdin().read_line(&mut password)?;
330 Ok(password.trim().to_string())
331 }
332}
333
334fn credential_file_path() -> Result<std::path::PathBuf> {
336 use directories::ProjectDirs;
337 let dirs = ProjectDirs::from("ai", "codetether", "codetether-agent")
338 .ok_or_else(|| anyhow::anyhow!("Cannot determine config directory"))?;
339 Ok(dirs.config_dir().join("credentials.json"))
340}
341
342#[derive(Debug, Deserialize)]
344pub struct SavedCredentials {
345 pub server: String,
346 pub access_token: String,
347 pub expires_at: String,
348 #[serde(default)]
349 pub email: String,
350}
351
352pub fn load_saved_credentials() -> Option<SavedCredentials> {
355 let path = credential_file_path().ok()?;
356 let data = std::fs::read_to_string(&path).ok()?;
357 let creds: SavedCredentials = serde_json::from_str(&data).ok()?;
358
359 if let Ok(expires) = chrono::DateTime::parse_from_rfc3339(&creds.expires_at) {
361 if expires < chrono::Utc::now() {
362 tracing::warn!(
363 "Saved credentials have expired — run `codetether auth login` to refresh"
364 );
365 return None;
366 }
367 }
368
369 Some(creds)
370}
371
372async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
373 if secrets::secrets_manager().is_none() {
374 anyhow::bail!(
375 "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
376 );
377 }
378
379 let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
380 Some(raw) => {
381 let domain = normalize_enterprise_domain(&raw);
382 if domain.is_empty() {
383 anyhow::bail!("--enterprise-url cannot be empty");
384 }
385 ("github-copilot-enterprise", domain.clone(), Some(domain))
386 }
387 None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
388 };
389
390 let client = Client::new();
391 let client_id = resolve_client_id(args.client_id)?;
392 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
393 let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
394
395 println!("GitHub Copilot device authentication");
396 println!(
397 "Open this URL: {}",
398 device
399 .verification_uri_complete
400 .as_deref()
401 .unwrap_or(&device.verification_uri)
402 );
403 println!("Enter code: {}", device.user_code);
404 println!("Waiting for authorization...");
405
406 let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
407
408 let mut extra = HashMap::new();
409 if let Some(enterprise_url) = enterprise_domain {
410 extra.insert(
411 "enterpriseUrl".to_string(),
412 serde_json::Value::String(enterprise_url),
413 );
414 }
415
416 let provider_secrets = ProviderSecrets {
417 api_key: Some(token),
418 base_url: None,
419 organization: None,
420 headers: None,
421 extra,
422 };
423
424 secrets::set_provider_secrets(provider_id, &provider_secrets)
425 .await
426 .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
427
428 println!("Saved {} credentials to HashiCorp Vault.", provider_id);
429 Ok(())
430}
431
432async fn request_device_code(
433 client: &Client,
434 domain: &str,
435 user_agent: &str,
436 client_id: &str,
437) -> Result<DeviceCodeResponse> {
438 let url = format!("https://{domain}/login/device/code");
439 let response = client
440 .post(&url)
441 .header("Accept", "application/json")
442 .header("Content-Type", "application/json")
443 .header("User-Agent", user_agent)
444 .json(&json!({
445 "client_id": client_id,
446 "scope": "read:user",
447 }))
448 .send()
449 .await
450 .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
451
452 let status = response.status();
453 if !status.is_success() {
454 let body = response.text().await.unwrap_or_default();
455 anyhow::bail!(
456 "Failed to initiate device authorization ({}): {}",
457 status,
458 truncate_body(&body)
459 );
460 }
461
462 let mut device: DeviceCodeResponse = response
463 .json()
464 .await
465 .context("Failed to parse device authorization response")?;
466 if device.interval.unwrap_or(0) == 0 {
467 device.interval = Some(5);
468 }
469 Ok(device)
470}
471
472async fn poll_for_access_token(
473 client: &Client,
474 domain: &str,
475 user_agent: &str,
476 client_id: &str,
477 device: &DeviceCodeResponse,
478) -> Result<String> {
479 let url = format!("https://{domain}/login/oauth/access_token");
480 let mut interval_secs = device.interval.unwrap_or(5).max(1);
481
482 loop {
483 let response = client
484 .post(&url)
485 .header("Accept", "application/json")
486 .header("Content-Type", "application/json")
487 .header("User-Agent", user_agent)
488 .json(&json!({
489 "client_id": client_id,
490 "device_code": device.device_code,
491 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
492 }))
493 .send()
494 .await
495 .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
496
497 let status = response.status();
498 if !status.is_success() {
499 let body = response.text().await.unwrap_or_default();
500 anyhow::bail!(
501 "Failed to exchange device code for access token ({}): {}",
502 status,
503 truncate_body(&body)
504 );
505 }
506
507 let payload: AccessTokenResponse = response
508 .json()
509 .await
510 .context("Failed to parse OAuth token response")?;
511
512 if let Some(token) = payload.access_token {
513 if !token.trim().is_empty() {
514 return Ok(token);
515 }
516 }
517
518 match payload.error.as_deref() {
519 Some("authorization_pending") => sleep_with_margin(interval_secs).await,
520 Some("slow_down") => {
521 interval_secs = payload
522 .interval
523 .filter(|value| *value > 0)
524 .unwrap_or(interval_secs + 5);
525 sleep_with_margin(interval_secs).await;
526 }
527 Some(error) => {
528 let description = payload
529 .error_description
530 .unwrap_or_else(|| "No error description provided".to_string());
531 anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
532 }
533 None => sleep_with_margin(interval_secs).await,
534 }
535 }
536}
537
538fn resolve_client_id(client_id: Option<String>) -> Result<String> {
539 let id = client_id
540 .map(|value| value.trim().to_string())
541 .filter(|value| !value.is_empty())
542 .ok_or_else(|| {
543 anyhow::anyhow!(
544 "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
545 )
546 })?;
547
548 Ok(id)
549}
550
551async fn sleep_with_margin(interval_secs: u64) {
552 sleep(Duration::from_millis(
553 interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
554 ))
555 .await;
556}
557
/// Shorten a response body for inclusion in error messages.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer bodies are cut to
/// at most 300 bytes and suffixed with `...`. The cut point is moved back to the
/// nearest UTF-8 character boundary so multi-byte text never causes a panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = MAX_LEN;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &body[..end])
}
565}