Skip to main content

codetether_agent/cli/
auth.rs

1//! Provider authentication commands.
2
3use super::{AuthArgs, AuthCommand, CopilotAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use tokio::time::{Duration, sleep};
12
/// Host used for the device flow when no `--enterprise-url` is supplied.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay, in milliseconds, added on top of every polling interval before retrying.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3_000;
15
16#[derive(Debug, Deserialize)]
17struct DeviceCodeResponse {
18    device_code: String,
19    user_code: String,
20    verification_uri: String,
21    #[serde(default)]
22    verification_uri_complete: Option<String>,
23    #[serde(default)]
24    interval: Option<u64>,
25}
26
27#[derive(Debug, Deserialize)]
28struct AccessTokenResponse {
29    #[serde(default)]
30    access_token: Option<String>,
31    #[serde(default)]
32    error: Option<String>,
33    #[serde(default)]
34    error_description: Option<String>,
35    #[serde(default)]
36    interval: Option<u64>,
37}
38
39pub async fn execute(args: AuthArgs) -> Result<()> {
40    match args.command {
41        AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
42    }
43}
44
45async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
46    if secrets::secrets_manager().is_none() {
47        anyhow::bail!(
48            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
49        );
50    }
51
52    let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
53        Some(raw) => {
54            let domain = normalize_enterprise_domain(&raw);
55            if domain.is_empty() {
56                anyhow::bail!("--enterprise-url cannot be empty");
57            }
58            ("github-copilot-enterprise", domain.clone(), Some(domain))
59        }
60        None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
61    };
62
63    let client = Client::new();
64    let client_id = resolve_client_id(args.client_id)?;
65    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
66    let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
67
68    println!("GitHub Copilot device authentication");
69    println!(
70        "Open this URL: {}",
71        device
72            .verification_uri_complete
73            .as_deref()
74            .unwrap_or(&device.verification_uri)
75    );
76    println!("Enter code: {}", device.user_code);
77    println!("Waiting for authorization...");
78
79    let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
80
81    let mut extra = HashMap::new();
82    if let Some(enterprise_url) = enterprise_domain {
83        extra.insert(
84            "enterpriseUrl".to_string(),
85            serde_json::Value::String(enterprise_url),
86        );
87    }
88
89    let provider_secrets = ProviderSecrets {
90        api_key: Some(token),
91        base_url: None,
92        organization: None,
93        headers: None,
94        extra,
95    };
96
97    secrets::set_provider_secrets(provider_id, &provider_secrets)
98        .await
99        .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
100
101    println!("Saved {} credentials to HashiCorp Vault.", provider_id);
102    Ok(())
103}
104
105async fn request_device_code(
106    client: &Client,
107    domain: &str,
108    user_agent: &str,
109    client_id: &str,
110) -> Result<DeviceCodeResponse> {
111    let url = format!("https://{domain}/login/device/code");
112    let response = client
113        .post(&url)
114        .header("Accept", "application/json")
115        .header("Content-Type", "application/json")
116        .header("User-Agent", user_agent)
117        .json(&json!({
118            "client_id": client_id,
119            "scope": "read:user",
120        }))
121        .send()
122        .await
123        .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
124
125    let status = response.status();
126    if !status.is_success() {
127        let body = response.text().await.unwrap_or_default();
128        anyhow::bail!(
129            "Failed to initiate device authorization ({}): {}",
130            status,
131            truncate_body(&body)
132        );
133    }
134
135    let mut device: DeviceCodeResponse = response
136        .json()
137        .await
138        .context("Failed to parse device authorization response")?;
139    if device.interval.unwrap_or(0) == 0 {
140        device.interval = Some(5);
141    }
142    Ok(device)
143}
144
145async fn poll_for_access_token(
146    client: &Client,
147    domain: &str,
148    user_agent: &str,
149    client_id: &str,
150    device: &DeviceCodeResponse,
151) -> Result<String> {
152    let url = format!("https://{domain}/login/oauth/access_token");
153    let mut interval_secs = device.interval.unwrap_or(5).max(1);
154
155    loop {
156        let response = client
157            .post(&url)
158            .header("Accept", "application/json")
159            .header("Content-Type", "application/json")
160            .header("User-Agent", user_agent)
161            .json(&json!({
162                "client_id": client_id,
163                "device_code": device.device_code,
164                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
165            }))
166            .send()
167            .await
168            .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
169
170        let status = response.status();
171        if !status.is_success() {
172            let body = response.text().await.unwrap_or_default();
173            anyhow::bail!(
174                "Failed to exchange device code for access token ({}): {}",
175                status,
176                truncate_body(&body)
177            );
178        }
179
180        let payload: AccessTokenResponse = response
181            .json()
182            .await
183            .context("Failed to parse OAuth token response")?;
184
185        if let Some(token) = payload.access_token {
186            if !token.trim().is_empty() {
187                return Ok(token);
188            }
189        }
190
191        match payload.error.as_deref() {
192            Some("authorization_pending") => sleep_with_margin(interval_secs).await,
193            Some("slow_down") => {
194                interval_secs = payload
195                    .interval
196                    .filter(|value| *value > 0)
197                    .unwrap_or(interval_secs + 5);
198                sleep_with_margin(interval_secs).await;
199            }
200            Some(error) => {
201                let description = payload
202                    .error_description
203                    .unwrap_or_else(|| "No error description provided".to_string());
204                anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
205            }
206            None => sleep_with_margin(interval_secs).await,
207        }
208    }
209}
210
211fn resolve_client_id(client_id: Option<String>) -> Result<String> {
212    let id = client_id
213        .map(|value| value.trim().to_string())
214        .filter(|value| !value.is_empty())
215        .ok_or_else(|| {
216            anyhow::anyhow!(
217                "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
218            )
219        })?;
220
221    Ok(id)
222}
223
224async fn sleep_with_margin(interval_secs: u64) {
225    sleep(Duration::from_millis(
226        interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
227    ))
228    .await;
229}
230
/// Shortens an HTTP response body for inclusion in an error message.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer bodies are cut to at most
/// 300 bytes and suffixed with `...`. The cut is moved back to the nearest UTF-8 character
/// boundary, so non-ASCII bodies (e.g. localized error pages) never cause a slicing panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = MAX_LEN;
    while !body.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &body[..end])
}