codetether_agent/cli/
auth.rs1use super::{AuthArgs, AuthCommand, CopilotAuthArgs};
4use crate::provider::copilot::normalize_enterprise_domain;
5use crate::secrets::{self, ProviderSecrets};
6use anyhow::{Context, Result};
7use reqwest::Client;
8use serde::Deserialize;
9use serde_json::json;
10use std::collections::HashMap;
11use tokio::time::{Duration, sleep};
12
/// OAuth host used for the device flow when no `--enterprise-url` is given.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay (milliseconds) added on top of every token-polling interval.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3_000;
15
16#[derive(Debug, Deserialize)]
17struct DeviceCodeResponse {
18 device_code: String,
19 user_code: String,
20 verification_uri: String,
21 #[serde(default)]
22 verification_uri_complete: Option<String>,
23 #[serde(default)]
24 interval: Option<u64>,
25}
26
27#[derive(Debug, Deserialize)]
28struct AccessTokenResponse {
29 #[serde(default)]
30 access_token: Option<String>,
31 #[serde(default)]
32 error: Option<String>,
33 #[serde(default)]
34 error_description: Option<String>,
35 #[serde(default)]
36 interval: Option<u64>,
37}
38
39pub async fn execute(args: AuthArgs) -> Result<()> {
40 match args.command {
41 AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
42 }
43}
44
45async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
46 if secrets::secrets_manager().is_none() {
47 anyhow::bail!(
48 "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
49 );
50 }
51
52 let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
53 Some(raw) => {
54 let domain = normalize_enterprise_domain(&raw);
55 if domain.is_empty() {
56 anyhow::bail!("--enterprise-url cannot be empty");
57 }
58 ("github-copilot-enterprise", domain.clone(), Some(domain))
59 }
60 None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
61 };
62
63 let client = Client::new();
64 let client_id = resolve_client_id(args.client_id)?;
65 let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
66 let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
67
68 println!("GitHub Copilot device authentication");
69 println!(
70 "Open this URL: {}",
71 device
72 .verification_uri_complete
73 .as_deref()
74 .unwrap_or(&device.verification_uri)
75 );
76 println!("Enter code: {}", device.user_code);
77 println!("Waiting for authorization...");
78
79 let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
80
81 let mut extra = HashMap::new();
82 if let Some(enterprise_url) = enterprise_domain {
83 extra.insert(
84 "enterpriseUrl".to_string(),
85 serde_json::Value::String(enterprise_url),
86 );
87 }
88
89 let provider_secrets = ProviderSecrets {
90 api_key: Some(token),
91 base_url: None,
92 organization: None,
93 headers: None,
94 extra,
95 };
96
97 secrets::set_provider_secrets(provider_id, &provider_secrets)
98 .await
99 .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
100
101 println!("Saved {} credentials to HashiCorp Vault.", provider_id);
102 Ok(())
103}
104
105async fn request_device_code(
106 client: &Client,
107 domain: &str,
108 user_agent: &str,
109 client_id: &str,
110) -> Result<DeviceCodeResponse> {
111 let url = format!("https://{domain}/login/device/code");
112 let response = client
113 .post(&url)
114 .header("Accept", "application/json")
115 .header("Content-Type", "application/json")
116 .header("User-Agent", user_agent)
117 .json(&json!({
118 "client_id": client_id,
119 "scope": "read:user",
120 }))
121 .send()
122 .await
123 .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
124
125 let status = response.status();
126 if !status.is_success() {
127 let body = response.text().await.unwrap_or_default();
128 anyhow::bail!(
129 "Failed to initiate device authorization ({}): {}",
130 status,
131 truncate_body(&body)
132 );
133 }
134
135 let mut device: DeviceCodeResponse = response
136 .json()
137 .await
138 .context("Failed to parse device authorization response")?;
139 if device.interval.unwrap_or(0) == 0 {
140 device.interval = Some(5);
141 }
142 Ok(device)
143}
144
145async fn poll_for_access_token(
146 client: &Client,
147 domain: &str,
148 user_agent: &str,
149 client_id: &str,
150 device: &DeviceCodeResponse,
151) -> Result<String> {
152 let url = format!("https://{domain}/login/oauth/access_token");
153 let mut interval_secs = device.interval.unwrap_or(5).max(1);
154
155 loop {
156 let response = client
157 .post(&url)
158 .header("Accept", "application/json")
159 .header("Content-Type", "application/json")
160 .header("User-Agent", user_agent)
161 .json(&json!({
162 "client_id": client_id,
163 "device_code": device.device_code,
164 "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
165 }))
166 .send()
167 .await
168 .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
169
170 let status = response.status();
171 if !status.is_success() {
172 let body = response.text().await.unwrap_or_default();
173 anyhow::bail!(
174 "Failed to exchange device code for access token ({}): {}",
175 status,
176 truncate_body(&body)
177 );
178 }
179
180 let payload: AccessTokenResponse = response
181 .json()
182 .await
183 .context("Failed to parse OAuth token response")?;
184
185 if let Some(token) = payload.access_token {
186 if !token.trim().is_empty() {
187 return Ok(token);
188 }
189 }
190
191 match payload.error.as_deref() {
192 Some("authorization_pending") => sleep_with_margin(interval_secs).await,
193 Some("slow_down") => {
194 interval_secs = payload
195 .interval
196 .filter(|value| *value > 0)
197 .unwrap_or(interval_secs + 5);
198 sleep_with_margin(interval_secs).await;
199 }
200 Some(error) => {
201 let description = payload
202 .error_description
203 .unwrap_or_else(|| "No error description provided".to_string());
204 anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
205 }
206 None => sleep_with_margin(interval_secs).await,
207 }
208 }
209}
210
211fn resolve_client_id(client_id: Option<String>) -> Result<String> {
212 let id = client_id
213 .map(|value| value.trim().to_string())
214 .filter(|value| !value.is_empty())
215 .ok_or_else(|| {
216 anyhow::anyhow!(
217 "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
218 )
219 })?;
220
221 Ok(id)
222}
223
224async fn sleep_with_margin(interval_secs: u64) {
225 sleep(Duration::from_millis(
226 interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
227 ))
228 .await;
229}
230
/// Shortens an HTTP response body for inclusion in an error message.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer bodies are cut to at most
/// 300 bytes and suffixed with `...`. The cut always lands on a UTF-8 character boundary,
/// so multi-byte text (e.g. localized error pages) never causes a slicing panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Walk back from the byte limit to the nearest char boundary. Index 0 is always a
    // boundary, so the fallback is never actually used.
    let cut = (0..=MAX_LEN)
        .rev()
        .find(|&idx| body.is_char_boundary(idx))
        .unwrap_or(0);
    format!("{}...", &body[..cut])
}