1use super::{AuthArgs, AuthCommand, CopilotAuthArgs, LoginAuthArgs, RegisterAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use std::io::{self, Write};
12use std::path::PathBuf;
13use tokio::time::{Duration, sleep};
14
/// Host used for the device flow when no `--enterprise-url` is supplied.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay, in milliseconds, added on top of every OAuth polling interval.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3_000;
17
18#[derive(Debug, Deserialize)]
19struct DeviceCodeResponse {
20 device_code: String,
21 user_code: String,
22 verification_uri: String,
23 #[serde(default)]
24 verification_uri_complete: Option<String>,
25 #[serde(default)]
26 interval: Option<u64>,
27}
28
29#[derive(Debug, Deserialize)]
30struct AccessTokenResponse {
31 #[serde(default)]
32 access_token: Option<String>,
33 #[serde(default)]
34 error: Option<String>,
35 #[serde(default)]
36 error_description: Option<String>,
37 #[serde(default)]
38 interval: Option<u64>,
39}
40
41pub async fn execute(args: AuthArgs) -> Result<()> {
42 match args.command {
43 AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
44 AuthCommand::Register(register_args) => authenticate_register(register_args).await,
45 AuthCommand::Login(login_args) => authenticate_login(login_args).await,
46 }
47}
48
49#[derive(Debug, Deserialize)]
50struct LoginResponsePayload {
51 access_token: String,
52 expires_at: String,
53 user: serde_json::Value,
54}
55
56async fn login_with_password(
57 client: &Client,
58 server_url: &str,
59 email: &str,
60 password: &str,
61) -> Result<LoginResponsePayload> {
62 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
63
64 let resp = client
65 .post(format!("{}/v1/users/login", server_url))
66 .header("User-Agent", &user_agent)
67 .header("Content-Type", "application/json")
68 .json(&json!({
69 "email": email,
70 "password": password,
71 }))
72 .send()
73 .await
74 .context("Failed to connect to CodeTether server")?;
75
76 if !resp.status().is_success() {
77 let status = resp.status();
78 let body: serde_json::Value = resp.json().await.unwrap_or_default();
79 let detail = body
80 .get("detail")
81 .and_then(|v| v.as_str())
82 .unwrap_or("Authentication failed");
83 anyhow::bail!("Login failed ({}): {}", status, detail);
84 }
85
86 let login: LoginResponsePayload = resp
87 .json()
88 .await
89 .context("Failed to parse login response")?;
90
91 Ok(login)
92}
93
94fn write_saved_credentials(
95 server_url: &str,
96 email: &str,
97 login: &LoginResponsePayload,
98) -> Result<PathBuf> {
99 let cred_path = credential_file_path()?;
101 if let Some(parent) = cred_path.parent() {
102 std::fs::create_dir_all(parent)
103 .with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
104 }
105
106 let creds = json!({
107 "server": server_url,
108 "access_token": login.access_token,
109 "expires_at": login.expires_at,
110 "email": email,
111 });
112
113 #[cfg(unix)]
115 {
116 use std::os::unix::fs::OpenOptionsExt;
117 let file = std::fs::OpenOptions::new()
118 .write(true)
119 .create(true)
120 .truncate(true)
121 .mode(0o600)
122 .open(&cred_path)
123 .with_context(|| {
124 format!("Failed to write credentials to {}", cred_path.display())
125 })?;
126 serde_json::to_writer_pretty(file, &creds)?;
127 }
128 #[cfg(not(unix))]
129 {
130 let file = std::fs::File::create(&cred_path)
131 .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
132 serde_json::to_writer_pretty(file, &creds)?;
133 }
134
135 Ok(cred_path)
136}
137
138async fn authenticate_register(args: RegisterAuthArgs) -> Result<()> {
139 #[derive(Debug, Deserialize)]
140 struct RegisterResponse {
141 user_id: String,
142 email: String,
143 message: String,
144 #[serde(default)]
145 instance_url: Option<String>,
146 #[serde(default)]
147 instance_namespace: Option<String>,
148 #[serde(default)]
149 provisioning_status: Option<String>,
150 }
151
152 let server_url = args.server.trim_end_matches('/').to_string();
153
154 let email = match args.email {
155 Some(e) => e,
156 None => {
157 print!("Email: ");
158 io::stdout().flush()?;
159 let mut email = String::new();
160 io::stdin().read_line(&mut email)?;
161 email.trim().to_string()
162 }
163 };
164
165 if email.is_empty() {
166 anyhow::bail!("Email is required");
167 }
168
169 let password = rpassword_prompt("Password (min 8 chars): ")?;
170 if password.trim().len() < 8 {
171 anyhow::bail!("Password must be at least 8 characters");
172 }
173 let confirm = rpassword_prompt("Confirm password: ")?;
174 if password != confirm {
175 anyhow::bail!("Passwords do not match");
176 }
177
178 println!("Registering with {}...", server_url);
179
180 let client = Client::new();
181 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
182
183 let resp = client
184 .post(format!("{}/v1/users/register", server_url))
185 .header("User-Agent", &user_agent)
186 .header("Content-Type", "application/json")
187 .json(&json!({
188 "email": email,
189 "password": password,
190 "first_name": args.first_name,
191 "last_name": args.last_name,
192 "referral_source": args.referral_source,
193 }))
194 .send()
195 .await
196 .context("Failed to connect to CodeTether server")?;
197
198 if !resp.status().is_success() {
199 let status = resp.status();
200 let body: serde_json::Value = resp.json().await.unwrap_or_default();
201 let detail = body
202 .get("detail")
203 .and_then(|v| v.as_str())
204 .unwrap_or("Registration failed");
205 anyhow::bail!("Registration failed ({}): {}", status, detail);
206 }
207
208 let reg: RegisterResponse = resp
209 .json()
210 .await
211 .context("Failed to parse registration response")?;
212
213 println!("Account created for {} (user_id={})", reg.email, reg.user_id);
214 println!("{}", reg.message);
215 if let Some(status) = reg.provisioning_status.as_deref() {
216 println!("Provisioning status: {}", status);
217 }
218 if let Some(url) = reg.instance_url.as_deref() {
219 println!("Instance URL: {}", url);
220 }
221 if let Some(ns) = reg.instance_namespace.as_deref() {
222 println!("Instance namespace: {}", ns);
223 }
224
225 println!("Logging in...");
227 let login = login_with_password(&client, &server_url, ®.email, &password).await?;
228 let cred_path = write_saved_credentials(&server_url, ®.email, &login)?;
229
230 let user_email = login
231 .user
232 .get("email")
233 .and_then(|v| v.as_str())
234 .unwrap_or(®.email);
235
236 println!("Logged in as {} (expires {})", user_email, login.expires_at);
237 println!("Credentials saved to {}", cred_path.display());
238 println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
239
240 Ok(())
241}
242
243async fn authenticate_login(args: LoginAuthArgs) -> Result<()> {
244 let server_url = args.server.trim_end_matches('/').to_string();
245
246 let email = match args.email {
248 Some(e) => e,
249 None => {
250 print!("Email: ");
251 io::stdout().flush()?;
252 let mut email = String::new();
253 io::stdin().read_line(&mut email)?;
254 email.trim().to_string()
255 }
256 };
257
258 if email.is_empty() {
259 anyhow::bail!("Email is required");
260 }
261
262 let password = rpassword_prompt("Password: ")?;
264 if password.is_empty() {
265 anyhow::bail!("Password is required");
266 }
267
268 println!("Authenticating with {}...", server_url);
269
270 let client = Client::new();
271
272 let login = login_with_password(&client, &server_url, &email, &password).await?;
273 let cred_path = write_saved_credentials(&server_url, &email, &login)?;
274
275 let user_email = login
276 .user
277 .get("email")
278 .and_then(|v| v.as_str())
279 .unwrap_or(&email);
280
281 println!("Logged in as {} (expires {})", user_email, login.expires_at);
282 println!("Credentials saved to {}", cred_path.display());
283 println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
284
285 Ok(())
286}
287
288fn rpassword_prompt(prompt: &str) -> Result<String> {
290 print!("{}", prompt);
291 io::stdout().flush()?;
292
293 #[cfg(unix)]
295 {
296 use std::io::BufRead;
297 let fd = 0; let orig = unsafe {
300 let mut termios = std::mem::zeroed::<libc::termios>();
301 libc::tcgetattr(fd, &mut termios);
302 termios
303 };
304
305 unsafe {
307 let mut termios = orig;
308 termios.c_lflag &= !libc::ECHO;
309 libc::tcsetattr(fd, libc::TCSANOW, &termios);
310 }
311
312 let mut password = String::new();
313 let result = io::stdin().lock().read_line(&mut password);
314
315 unsafe {
317 libc::tcsetattr(fd, libc::TCSANOW, &orig);
318 }
319 println!(); result?;
322 Ok(password.trim().to_string())
323 }
324
325 #[cfg(not(unix))]
326 {
327 let mut password = String::new();
328 io::stdin().read_line(&mut password)?;
329 Ok(password.trim().to_string())
330 }
331}
332
333fn credential_file_path() -> Result<std::path::PathBuf> {
335 use directories::ProjectDirs;
336 let dirs = ProjectDirs::from("ai", "codetether", "codetether-agent")
337 .ok_or_else(|| anyhow::anyhow!("Cannot determine config directory"))?;
338 Ok(dirs.config_dir().join("credentials.json"))
339}
340
341#[derive(Debug, Deserialize)]
343pub struct SavedCredentials {
344 pub server: String,
345 pub access_token: String,
346 pub expires_at: String,
347 #[serde(default)]
348 pub email: String,
349}
350
351pub fn load_saved_credentials() -> Option<SavedCredentials> {
354 let path = credential_file_path().ok()?;
355 let data = std::fs::read_to_string(&path).ok()?;
356 let creds: SavedCredentials = serde_json::from_str(&data).ok()?;
357
358 if let Ok(expires) = chrono::DateTime::parse_from_rfc3339(&creds.expires_at) {
360 if expires < chrono::Utc::now() {
361 tracing::warn!("Saved credentials have expired — run `codetether auth login` to refresh");
362 return None;
363 }
364 }
365
366 Some(creds)
367}
368
369async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
370 if secrets::secrets_manager().is_none() {
371 anyhow::bail!(
372 "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
373 );
374 }
375
376 let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
377 Some(raw) => {
378 let domain = normalize_enterprise_domain(&raw);
379 if domain.is_empty() {
380 anyhow::bail!("--enterprise-url cannot be empty");
381 }
382 ("github-copilot-enterprise", domain.clone(), Some(domain))
383 }
384 None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
385 };
386
387 let client = Client::new();
388 let client_id = resolve_client_id(args.client_id)?;
389 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
390 let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
391
392 println!("GitHub Copilot device authentication");
393 println!(
394 "Open this URL: {}",
395 device
396 .verification_uri_complete
397 .as_deref()
398 .unwrap_or(&device.verification_uri)
399 );
400 println!("Enter code: {}", device.user_code);
401 println!("Waiting for authorization...");
402
403 let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
404
405 let mut extra = HashMap::new();
406 if let Some(enterprise_url) = enterprise_domain {
407 extra.insert(
408 "enterpriseUrl".to_string(),
409 serde_json::Value::String(enterprise_url),
410 );
411 }
412
413 let provider_secrets = ProviderSecrets {
414 api_key: Some(token),
415 base_url: None,
416 organization: None,
417 headers: None,
418 extra,
419 };
420
421 secrets::set_provider_secrets(provider_id, &provider_secrets)
422 .await
423 .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
424
425 println!("Saved {} credentials to HashiCorp Vault.", provider_id);
426 Ok(())
427}
428
429async fn request_device_code(
430 client: &Client,
431 domain: &str,
432 user_agent: &str,
433 client_id: &str,
434) -> Result<DeviceCodeResponse> {
435 let url = format!("https://{domain}/login/device/code");
436 let response = client
437 .post(&url)
438 .header("Accept", "application/json")
439 .header("Content-Type", "application/json")
440 .header("User-Agent", user_agent)
441 .json(&json!({
442 "client_id": client_id,
443 "scope": "read:user",
444 }))
445 .send()
446 .await
447 .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
448
449 let status = response.status();
450 if !status.is_success() {
451 let body = response.text().await.unwrap_or_default();
452 anyhow::bail!(
453 "Failed to initiate device authorization ({}): {}",
454 status,
455 truncate_body(&body)
456 );
457 }
458
459 let mut device: DeviceCodeResponse = response
460 .json()
461 .await
462 .context("Failed to parse device authorization response")?;
463 if device.interval.unwrap_or(0) == 0 {
464 device.interval = Some(5);
465 }
466 Ok(device)
467}
468
469async fn poll_for_access_token(
470 client: &Client,
471 domain: &str,
472 user_agent: &str,
473 client_id: &str,
474 device: &DeviceCodeResponse,
475) -> Result<String> {
476 let url = format!("https://{domain}/login/oauth/access_token");
477 let mut interval_secs = device.interval.unwrap_or(5).max(1);
478
479 loop {
480 let response = client
481 .post(&url)
482 .header("Accept", "application/json")
483 .header("Content-Type", "application/json")
484 .header("User-Agent", user_agent)
485 .json(&json!({
486 "client_id": client_id,
487 "device_code": device.device_code,
488 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
489 }))
490 .send()
491 .await
492 .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
493
494 let status = response.status();
495 if !status.is_success() {
496 let body = response.text().await.unwrap_or_default();
497 anyhow::bail!(
498 "Failed to exchange device code for access token ({}): {}",
499 status,
500 truncate_body(&body)
501 );
502 }
503
504 let payload: AccessTokenResponse = response
505 .json()
506 .await
507 .context("Failed to parse OAuth token response")?;
508
509 if let Some(token) = payload.access_token {
510 if !token.trim().is_empty() {
511 return Ok(token);
512 }
513 }
514
515 match payload.error.as_deref() {
516 Some("authorization_pending") => sleep_with_margin(interval_secs).await,
517 Some("slow_down") => {
518 interval_secs = payload
519 .interval
520 .filter(|value| *value > 0)
521 .unwrap_or(interval_secs + 5);
522 sleep_with_margin(interval_secs).await;
523 }
524 Some(error) => {
525 let description = payload
526 .error_description
527 .unwrap_or_else(|| "No error description provided".to_string());
528 anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
529 }
530 None => sleep_with_margin(interval_secs).await,
531 }
532 }
533}
534
535fn resolve_client_id(client_id: Option<String>) -> Result<String> {
536 let id = client_id
537 .map(|value| value.trim().to_string())
538 .filter(|value| !value.is_empty())
539 .ok_or_else(|| {
540 anyhow::anyhow!(
541 "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
542 )
543 })?;
544
545 Ok(id)
546}
547
548async fn sleep_with_margin(interval_secs: u64) {
549 sleep(Duration::from_millis(
550 interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
551 ))
552 .await;
553}
554
/// Shortens an HTTP response body for inclusion in an error message.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer ones are cut to
/// at most 300 bytes and suffixed with `...`. The cut is moved back to the
/// nearest UTF-8 character boundary, so multi-byte text never causes a panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Byte index 0 is always a boundary, so this loop terminates.
    let mut end = MAX_LEN;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &body[..end])
}