1use secrecy::SecretString;
2use tracing::{debug, info, warn, Instrument};
3
4use crate::error::MemoryError;
5
6pub mod github;
8pub use github::GitHubDeviceFlow;
9
/// Configuration for an OAuth provider that [`device_flow_login`] can drive.
///
/// Implementations supply the client identity, the two endpoints used by the
/// flow, and the scopes to request. The `Send + Sync` bounds allow a provider
/// to be shared across threads.
pub trait DeviceFlowProvider: Send + Sync {
    /// Returns the OAuth client ID, sent as the `client_id` form field on both
    /// the device-code request and every token poll.
    fn client_id(&self) -> &str;
    /// Returns the URL that receives the initial POST requesting a device code.
    fn device_code_url(&self) -> &str;
    /// Returns the URL that is polled for the access token.
    fn access_token_url(&self) -> &str;
    /// Returns the scopes to request; they are joined with a single space
    /// before being sent as the `scope` form field.
    fn scopes(&self) -> &[&str];
    /// Checks the provider's configuration.
    ///
    /// [`device_flow_login`] calls this before making any network request and
    /// returns the error unchanged if validation fails.
    fn validate(&self) -> Result<(), MemoryError>;
}
30
31pub(crate) fn validate_endpoint_url(url: &str, field_name: &str) -> Result<(), MemoryError> {
37 let parsed = reqwest::Url::parse(url)
38 .map_err(|e| MemoryError::OAuth(format!("invalid {field_name} URL: {e}")))?;
39 match parsed.scheme() {
40 "https" => Ok(()),
41 "http" if matches!(parsed.host_str(), Some("localhost" | "127.0.0.1" | "[::1]")) => Ok(()),
42 _ => Err(MemoryError::OAuth(format!(
43 "{field_name} must use HTTPS (got {url})"
44 ))),
45 }
46}
47
48#[derive(serde::Deserialize)]
53struct DeviceCodeResponse {
54 device_code: String,
55 user_code: String,
56 verification_uri: String,
57 expires_in: u64,
58 interval: u64,
59}
60
/// JSON body returned by the provider's access-token endpoint on each poll.
///
/// Every field is optional: a response carries either a token or an error
/// code, and absent fields deserialize to `None`.
#[derive(serde::Deserialize)]
struct AccessTokenResponse {
    /// The issued access token. The polling loop only accepts it when it is
    /// non-empty after trimming whitespace.
    #[serde(default)]
    access_token: Option<String>,
    /// OAuth error code. `authorization_pending` and `slow_down` keep the
    /// polling loop running; any other value ends the flow with an error.
    #[serde(default)]
    error: Option<String>,
    /// Human-readable detail accompanying `error`; only surfaced for error
    /// codes the polling loop does not recognize.
    #[serde(default)]
    error_description: Option<String>,
}
70
71pub async fn device_flow_login(
79 provider: &dyn DeviceFlowProvider,
80 store: Option<super::StoreBackend>,
81 #[cfg(feature = "k8s")] k8s_config: Option<super::K8sSecretConfig>,
82) -> Result<(), MemoryError> {
83 use std::time::{Duration, Instant};
84 use tokio::time::sleep;
85
86 let provider_label = reqwest::Url::parse(provider.device_code_url())
89 .ok()
90 .and_then(|u| u.host_str().map(str::to_owned))
91 .unwrap_or_else(|| provider.device_code_url().to_owned());
92
93 let span = tracing::info_span!(
94 "auth.device_flow_login",
95 provider = %provider_label,
96 scopes = %provider.scopes().join(" "),
97 poll_count = tracing::field::Empty,
98 elapsed_ms = tracing::field::Empty,
99 outcome = tracing::field::Empty,
100 );
101 let start = Instant::now();
102
103 let result = async {
104 provider.validate()?;
105
106 let client = reqwest::Client::builder()
107 .connect_timeout(Duration::from_secs(10))
108 .timeout(Duration::from_secs(30))
109 .build()
110 .map_err(|e| MemoryError::OAuth(format!("failed to build HTTP client: {e}")))?;
111
112 let scope = provider.scopes().join(" ");
113
114 debug!(
116 url = provider.device_code_url(),
117 "auth.device_flow: requesting device code"
118 );
119 let device_resp = async {
120 client
121 .post(provider.device_code_url())
122 .header("Accept", "application/json")
123 .form(&[("client_id", provider.client_id()), ("scope", &scope)])
124 .send()
125 .await
126 .map_err(|e| {
127 MemoryError::OAuth(format!("failed to contact device code endpoint: {e}"))
128 })?
129 .error_for_status()
130 .map_err(|e| MemoryError::OAuth(format!("device code request failed: {e}")))?
131 .json::<DeviceCodeResponse>()
132 .await
133 .map_err(|e| {
134 MemoryError::OAuth(format!("failed to parse device code response: {e}"))
135 })
136 }
137 .instrument(tracing::debug_span!("auth.device_flow.request_code"))
138 .await?;
139
140 let expires_in = device_resp.expires_in.min(1800);
142 let deadline = Instant::now() + Duration::from_secs(expires_in);
143
144 debug!(
145 expires_in,
146 verification_uri = %device_resp.verification_uri,
147 "auth.device_flow: device code obtained"
148 );
149
150 eprintln!();
152 eprintln!(" Open this URL in your browser:");
153 eprintln!(" {}", device_resp.verification_uri);
154 eprintln!();
155 eprintln!(" Enter this code when prompted:");
156 eprintln!(" {}", device_resp.user_code);
157 eprintln!();
158 eprintln!(" Waiting for authorization...");
159
160 let mut poll_interval = device_resp.interval.clamp(1, 30);
162 let mut poll_count: u32 = 0;
163 let token = loop {
164 if Instant::now() >= deadline {
165 tracing::Span::current().record("poll_count", poll_count);
166 warn!(
167 poll_count,
168 expires_in, "auth.device_flow: device code expired"
169 );
170 return Err(MemoryError::OAuth(format!(
171 "Device code expired after {expires_in} seconds"
172 )));
173 }
174
175 sleep(Duration::from_secs(poll_interval)).await;
176 poll_count += 1;
177
178 debug!(
179 poll = poll_count,
180 interval_secs = poll_interval,
181 "auth.device_flow: polling token endpoint"
182 );
183
184 let resp = client
185 .post(provider.access_token_url())
186 .header("Accept", "application/json")
187 .form(&[
188 ("client_id", provider.client_id()),
189 ("device_code", device_resp.device_code.as_str()),
190 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
191 ])
192 .send()
193 .await
194 .map_err(|e| MemoryError::OAuth(format!("polling token endpoint failed: {e}")))?
195 .error_for_status()
196 .map_err(|e| {
197 MemoryError::OAuth(format!("token request returned error status: {e}"))
198 })?
199 .json::<AccessTokenResponse>()
200 .await
201 .map_err(|e| MemoryError::OAuth(format!("failed to parse token response: {e}")))?;
202
203 if let Some(tok) = resp.access_token.filter(|t| !t.trim().is_empty()) {
204 break SecretString::from(tok);
205 }
206
207 match resp.error.as_deref() {
208 Some("authorization_pending") => {
209 debug!(poll = poll_count, "auth.device_flow: authorization pending");
210 continue;
211 }
212 Some("slow_down") => {
213 poll_interval = (poll_interval + 5).min(60);
214 debug!(
215 poll = poll_count,
216 new_interval_secs = poll_interval,
217 "auth.device_flow: slow_down received, backing off"
218 );
219 continue;
220 }
221 Some("expired_token") => {
222 tracing::Span::current().record("poll_count", poll_count);
223 warn!(
224 poll_count,
225 "auth.device_flow: device code expired during poll"
226 );
227 return Err(MemoryError::OAuth(
228 "device code expired; please run `memory-mcp auth login` again".to_string(),
229 ));
230 }
231 Some("access_denied") => {
232 tracing::Span::current().record("poll_count", poll_count);
233 warn!(poll_count, "auth.device_flow: access denied by user");
234 return Err(MemoryError::OAuth(
235 "authorization denied by user".to_string(),
236 ));
237 }
238 Some(other) => {
239 let desc = resp
240 .error_description
241 .as_deref()
242 .unwrap_or("no description");
243 tracing::Span::current().record("poll_count", poll_count);
244 warn!(
245 poll_count,
246 error = other,
247 description = desc,
248 "auth.device_flow: unexpected OAuth error"
249 );
250 return Err(MemoryError::OAuth(format!(
251 "unexpected OAuth error '{other}': {desc}"
252 )));
253 }
254 None => {
255 tracing::Span::current().record("poll_count", poll_count);
256 warn!(
257 poll_count,
258 "auth.device_flow: server returned neither access_token nor error"
259 );
260 return Err(MemoryError::OAuth(
261 "server returned neither an access_token nor an error field; \
262 unexpected response"
263 .to_string(),
264 ));
265 }
266 }
267 };
268
269 tracing::Span::current().record("poll_count", poll_count);
270 info!(poll_count, "auth.device_flow: token obtained successfully");
271
272 super::store_token(
274 &token,
275 store,
276 #[cfg(feature = "k8s")]
277 k8s_config,
278 )
279 .await?;
280 eprintln!("Authentication successful.");
281
282 Ok(())
283 }
284 .instrument(span.clone())
285 .await;
286
287 let elapsed_ms = start.elapsed().as_millis() as u64;
288 let outcome = if result.is_ok() { "success" } else { "error" };
289 span.record("elapsed_ms", elapsed_ms);
290 span.record("outcome", outcome);
291
292 result
293}
294
/// Unit tests for provider validation and the endpoint URL policy.
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider whose fields can be overridden per test.
    struct MockDeviceFlow {
        client_id: &'static str,
        device_code_url: &'static str,
        access_token_url: &'static str,
        scopes: &'static [&'static str],
    }

    impl DeviceFlowProvider for MockDeviceFlow {
        fn client_id(&self) -> &str {
            self.client_id
        }
        fn device_code_url(&self) -> &str {
            self.device_code_url
        }
        fn access_token_url(&self) -> &str {
            self.access_token_url
        }
        fn scopes(&self) -> &[&str] {
            self.scopes
        }
        fn validate(&self) -> Result<(), MemoryError> {
            if self.client_id.is_empty() {
                return Err(MemoryError::OAuth("client ID must not be empty".into()));
            }
            if self.client_id.len() < 4 || self.client_id.len() > 64 {
                return Err(MemoryError::OAuth(format!(
                    "client ID has unexpected length ({})",
                    self.client_id.len()
                )));
            }
            validate_endpoint_url(self.device_code_url, "device_code_url")?;
            validate_endpoint_url(self.access_token_url, "access_token_url")?;
            Ok(())
        }
    }

    /// Returns a mock provider that passes validation.
    fn valid_mock() -> MockDeviceFlow {
        MockDeviceFlow {
            client_id: "test-client-id",
            device_code_url: "https://example.com/device/code",
            access_token_url: "https://example.com/oauth/token",
            scopes: &["repo"],
        }
    }

    /// The built-in GitHub provider exposes the expected constants.
    #[test]
    fn github_provider_returns_expected_values() {
        let p = GitHubDeviceFlow;
        assert_eq!(p.client_id(), "Ov23liWxHYkwXTxCrYHp");
        assert_eq!(p.device_code_url(), "https://github.com/login/device/code");
        assert_eq!(
            p.access_token_url(),
            "https://github.com/login/oauth/access_token"
        );
        assert_eq!(p.scopes(), &["repo"]);
    }

    // Never called: this exists only so the build fails if `device_flow_login`
    // stops accepting a `&dyn DeviceFlowProvider`.
    #[allow(dead_code)]
    async fn accepts_trait_object(provider: &dyn DeviceFlowProvider) {
        let _ = device_flow_login(
            provider,
            None,
            #[cfg(feature = "k8s")]
            None,
        )
        .await;
    }

    /// The built-in GitHub provider passes its own validation.
    #[test]
    fn github_provider_validates_ok() {
        assert!(GitHubDeviceFlow.validate().is_ok());
    }

    /// An empty client ID is rejected with a message naming the client ID.
    #[test]
    fn empty_client_id_fails_validation() {
        let mock = MockDeviceFlow {
            client_id: "",
            ..valid_mock()
        };
        let err = mock.validate().unwrap_err();
        assert!(err.to_string().contains("client ID"), "got: {err}");
    }

    /// GitHub client ID validation rejects empty and too-short IDs.
    #[test]
    fn malformed_github_client_id_fails_validation() {
        assert!(github::validate_github_client_id("").is_err());
        assert!(github::validate_github_client_id("x").is_err());
        assert!(github::validate_github_client_id("Ov23liWxHYkwXTxCrYHp").is_ok());
    }

    /// Plain HTTP to a non-loopback host is rejected for the device-code URL.
    #[test]
    fn http_url_fails_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://example.com/device/code",
            ..valid_mock()
        };
        assert!(mock.validate().is_err());
    }

    /// Plain HTTP to a non-loopback host is rejected for the token URL too.
    #[test]
    fn http_access_token_url_fails_validation() {
        let mock = MockDeviceFlow {
            access_token_url: "http://example.com/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_err());
    }

    /// Plain HTTP is allowed for `localhost`.
    #[test]
    fn http_localhost_passes_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://localhost/device/code",
            access_token_url: "http://localhost/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_ok());
    }

    /// HTTPS endpoints are always accepted.
    #[test]
    fn https_urls_pass_validation() {
        assert!(valid_mock().validate().is_ok());
    }

    /// Plain HTTP is allowed for the IPv6 loopback address.
    #[test]
    fn http_ipv6_localhost_passes_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://[::1]/device/code",
            access_token_url: "http://[::1]/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_ok());
    }

    /// Plain HTTP to a non-loopback IPv6 address is rejected.
    #[test]
    fn http_ipv6_non_loopback_fails_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://[::2]/device/code",
            ..valid_mock()
        };
        assert!(mock.validate().is_err());
    }

    /// Plain HTTP is allowed for the IPv4 loopback address.
    #[test]
    fn http_127_0_0_1_passes_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://127.0.0.1/device/code",
            access_token_url: "http://127.0.0.1/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_ok());
    }

    /// A port does not affect the loopback exemption, because only the host
    /// is compared.
    #[test]
    fn http_loopback_with_port_passes_validation() {
        assert!(validate_endpoint_url("http://127.0.0.1:8080/device/code", "device_code_url").is_ok());
        assert!(validate_endpoint_url("http://localhost:3000/oauth/token", "access_token_url").is_ok());
    }

    /// A host that merely starts with `localhost` must not get the exemption.
    #[test]
    fn http_localhost_lookalike_fails_validation() {
        assert!(
            validate_endpoint_url("http://localhost.example.com/device/code", "device_code_url")
                .is_err()
        );
    }

    /// Schemes other than `http`/`https` are rejected.
    #[test]
    fn non_http_scheme_fails_validation() {
        assert!(validate_endpoint_url("ftp://example.com/device/code", "device_code_url").is_err());
    }

    /// Text that is not a URL is reported as invalid, naming the field.
    #[test]
    fn unparseable_url_fails_validation() {
        let err = validate_endpoint_url("not a url", "device_code_url").unwrap_err();
        assert!(
            err.to_string().contains("invalid device_code_url URL"),
            "got: {err}"
        );
    }

    /// End-to-end login against GitHub; needs a human to approve the device.
    #[tokio::test]
    #[ignore = "requires real OAuth interaction"]
    async fn device_flow_login_ignored_in_ci() {
        device_flow_login(
            &GitHubDeviceFlow,
            Some(super::super::StoreBackend::Stdout),
            #[cfg(feature = "k8s")]
            None,
        )
        .await
        .expect("device flow should succeed");
    }
}