Skip to main content

memory_mcp/auth/oauth/
mod.rs

1use secrecy::SecretString;
2use tracing::{debug, info, warn, Instrument};
3
4use crate::error::MemoryError;
5
6/// GitHub device flow provider implementation.
7pub mod github;
8pub use github::GitHubDeviceFlow;
9
10// ---------------------------------------------------------------------------
11// DeviceFlowProvider trait (RFC 8628)
12// ---------------------------------------------------------------------------
13
14/// Abstraction over OAuth device authorization grant (RFC 8628) parameters.
15///
16/// Each provider supplies its own client ID, endpoint URLs, and scopes.
17/// Implementations must validate their own parameters via [`validate`](Self::validate).
18pub trait DeviceFlowProvider: Send + Sync {
19    /// Returns the OAuth client ID for this provider.
20    fn client_id(&self) -> &str;
21    /// Returns the device code endpoint URL.
22    fn device_code_url(&self) -> &str;
23    /// Returns the access token endpoint URL.
24    fn access_token_url(&self) -> &str;
25    /// Returns the list of OAuth scopes to request.
26    fn scopes(&self) -> &[&str];
27    /// Validates provider configuration; returns an error if any value is invalid.
28    fn validate(&self) -> Result<(), MemoryError>;
29}
30
31// ---------------------------------------------------------------------------
32// URL validation helper
33// ---------------------------------------------------------------------------
34
35/// Validates that a URL uses HTTPS, with an exception for localhost (dev/testing).
36pub(crate) fn validate_endpoint_url(url: &str, field_name: &str) -> Result<(), MemoryError> {
37    let parsed = reqwest::Url::parse(url)
38        .map_err(|e| MemoryError::OAuth(format!("invalid {field_name} URL: {e}")))?;
39    match parsed.scheme() {
40        "https" => Ok(()),
41        "http" if matches!(parsed.host_str(), Some("localhost" | "127.0.0.1" | "[::1]")) => Ok(()),
42        _ => Err(MemoryError::OAuth(format!(
43            "{field_name} must use HTTPS (got {url})"
44        ))),
45    }
46}
47
48// ---------------------------------------------------------------------------
49// OAuth response types
50// ---------------------------------------------------------------------------
51
52#[derive(serde::Deserialize)]
53struct DeviceCodeResponse {
54    device_code: String,
55    user_code: String,
56    verification_uri: String,
57    expires_in: u64,
58    interval: u64,
59}
60
61#[derive(serde::Deserialize)]
62struct AccessTokenResponse {
63    #[serde(default)]
64    access_token: Option<String>,
65    #[serde(default)]
66    error: Option<String>,
67    #[serde(default)]
68    error_description: Option<String>,
69}
70
71// ---------------------------------------------------------------------------
72// Device flow login
73// ---------------------------------------------------------------------------
74
75/// Authenticate via the OAuth device flow and persist the token.
76///
77/// Prints user-facing prompts to stderr. Never logs the token value.
78pub async fn device_flow_login(
79    provider: &dyn DeviceFlowProvider,
80    store: Option<super::StoreBackend>,
81    #[cfg(feature = "k8s")] k8s_config: Option<super::K8sSecretConfig>,
82) -> Result<(), MemoryError> {
83    use std::time::{Duration, Instant};
84    use tokio::time::sleep;
85
86    // Derive a safe provider label from the host of the device code URL, falling
87    // back to the literal URL string. Never records client_id (could be secret).
88    let provider_label = reqwest::Url::parse(provider.device_code_url())
89        .ok()
90        .and_then(|u| u.host_str().map(str::to_owned))
91        .unwrap_or_else(|| provider.device_code_url().to_owned());
92
93    let span = tracing::info_span!(
94        "auth.device_flow_login",
95        provider = %provider_label,
96        scopes = %provider.scopes().join(" "),
97        poll_count = tracing::field::Empty,
98        elapsed_ms = tracing::field::Empty,
99        outcome = tracing::field::Empty,
100    );
101    let start = Instant::now();
102
103    let result = async {
104        provider.validate()?;
105
106        let client = reqwest::Client::builder()
107            .connect_timeout(Duration::from_secs(10))
108            .timeout(Duration::from_secs(30))
109            .build()
110            .map_err(|e| MemoryError::OAuth(format!("failed to build HTTP client: {e}")))?;
111
112        let scope = provider.scopes().join(" ");
113
114        // Step 1: Request a device code.
115        debug!(
116            url = provider.device_code_url(),
117            "auth.device_flow: requesting device code"
118        );
119        let device_resp = async {
120            client
121                .post(provider.device_code_url())
122                .header("Accept", "application/json")
123                .form(&[("client_id", provider.client_id()), ("scope", &scope)])
124                .send()
125                .await
126                .map_err(|e| {
127                    MemoryError::OAuth(format!("failed to contact device code endpoint: {e}"))
128                })?
129                .error_for_status()
130                .map_err(|e| MemoryError::OAuth(format!("device code request failed: {e}")))?
131                .json::<DeviceCodeResponse>()
132                .await
133                .map_err(|e| {
134                    MemoryError::OAuth(format!("failed to parse device code response: {e}"))
135                })
136        }
137        .instrument(tracing::debug_span!("auth.device_flow.request_code"))
138        .await?;
139
140        // Compute overall deadline from expires_in, capped at 30 minutes.
141        let expires_in = device_resp.expires_in.min(1800);
142        let deadline = Instant::now() + Duration::from_secs(expires_in);
143
144        debug!(
145            expires_in,
146            verification_uri = %device_resp.verification_uri,
147            "auth.device_flow: device code obtained"
148        );
149
150        // Step 2: Display instructions to the user.
151        eprintln!();
152        eprintln!("  Open this URL in your browser:");
153        eprintln!("    {}", device_resp.verification_uri);
154        eprintln!();
155        eprintln!("  Enter this code when prompted:");
156        eprintln!("    {}", device_resp.user_code);
157        eprintln!();
158        eprintln!("  Waiting for authorization...");
159
160        // Step 3: Poll for the access token.
161        let mut poll_interval = device_resp.interval.clamp(1, 30);
162        let mut poll_count: u32 = 0;
163        let token = loop {
164            if Instant::now() >= deadline {
165                tracing::Span::current().record("poll_count", poll_count);
166                warn!(
167                    poll_count,
168                    expires_in, "auth.device_flow: device code expired"
169                );
170                return Err(MemoryError::OAuth(format!(
171                    "Device code expired after {expires_in} seconds"
172                )));
173            }
174
175            sleep(Duration::from_secs(poll_interval)).await;
176            poll_count += 1;
177
178            debug!(
179                poll = poll_count,
180                interval_secs = poll_interval,
181                "auth.device_flow: polling token endpoint"
182            );
183
184            let resp = client
185                .post(provider.access_token_url())
186                .header("Accept", "application/json")
187                .form(&[
188                    ("client_id", provider.client_id()),
189                    ("device_code", device_resp.device_code.as_str()),
190                    ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
191                ])
192                .send()
193                .await
194                .map_err(|e| MemoryError::OAuth(format!("polling token endpoint failed: {e}")))?
195                .error_for_status()
196                .map_err(|e| {
197                    MemoryError::OAuth(format!("token request returned error status: {e}"))
198                })?
199                .json::<AccessTokenResponse>()
200                .await
201                .map_err(|e| MemoryError::OAuth(format!("failed to parse token response: {e}")))?;
202
203            if let Some(tok) = resp.access_token.filter(|t| !t.trim().is_empty()) {
204                break SecretString::from(tok);
205            }
206
207            match resp.error.as_deref() {
208                Some("authorization_pending") => {
209                    debug!(poll = poll_count, "auth.device_flow: authorization pending");
210                    continue;
211                }
212                Some("slow_down") => {
213                    poll_interval = (poll_interval + 5).min(60);
214                    debug!(
215                        poll = poll_count,
216                        new_interval_secs = poll_interval,
217                        "auth.device_flow: slow_down received, backing off"
218                    );
219                    continue;
220                }
221                Some("expired_token") => {
222                    tracing::Span::current().record("poll_count", poll_count);
223                    warn!(
224                        poll_count,
225                        "auth.device_flow: device code expired during poll"
226                    );
227                    return Err(MemoryError::OAuth(
228                        "device code expired; please run `memory-mcp auth login` again".to_string(),
229                    ));
230                }
231                Some("access_denied") => {
232                    tracing::Span::current().record("poll_count", poll_count);
233                    warn!(poll_count, "auth.device_flow: access denied by user");
234                    return Err(MemoryError::OAuth(
235                        "authorization denied by user".to_string(),
236                    ));
237                }
238                Some(other) => {
239                    let desc = resp
240                        .error_description
241                        .as_deref()
242                        .unwrap_or("no description");
243                    tracing::Span::current().record("poll_count", poll_count);
244                    warn!(
245                        poll_count,
246                        error = other,
247                        description = desc,
248                        "auth.device_flow: unexpected OAuth error"
249                    );
250                    return Err(MemoryError::OAuth(format!(
251                        "unexpected OAuth error '{other}': {desc}"
252                    )));
253                }
254                None => {
255                    tracing::Span::current().record("poll_count", poll_count);
256                    warn!(
257                        poll_count,
258                        "auth.device_flow: server returned neither access_token nor error"
259                    );
260                    return Err(MemoryError::OAuth(
261                        "server returned neither an access_token nor an error field; \
262                         unexpected response"
263                            .to_string(),
264                    ));
265                }
266            }
267        };
268
269        tracing::Span::current().record("poll_count", poll_count);
270        info!(poll_count, "auth.device_flow: token obtained successfully");
271
272        // Step 4: Store the token.
273        super::store_token(
274            &token,
275            store,
276            #[cfg(feature = "k8s")]
277            k8s_config,
278        )
279        .await?;
280        eprintln!("Authentication successful.");
281
282        Ok(())
283    }
284    .instrument(span.clone())
285    .await;
286
287    let elapsed_ms = start.elapsed().as_millis() as u64;
288    let outcome = if result.is_ok() { "success" } else { "error" };
289    span.record("elapsed_ms", elapsed_ms);
290    span.record("outcome", outcome);
291
292    result
293}
294
295// ---------------------------------------------------------------------------
296// Tests
297// ---------------------------------------------------------------------------
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Configurable provider used to exercise validation without network access.
    struct MockDeviceFlow {
        client_id: &'static str,
        device_code_url: &'static str,
        access_token_url: &'static str,
        scopes: &'static [&'static str],
    }

    impl DeviceFlowProvider for MockDeviceFlow {
        fn client_id(&self) -> &str {
            self.client_id
        }
        fn device_code_url(&self) -> &str {
            self.device_code_url
        }
        fn access_token_url(&self) -> &str {
            self.access_token_url
        }
        fn scopes(&self) -> &[&str] {
            self.scopes
        }
        fn validate(&self) -> Result<(), MemoryError> {
            if self.client_id.is_empty() {
                return Err(MemoryError::OAuth("client ID must not be empty".into()));
            }
            if self.client_id.len() < 4 || self.client_id.len() > 64 {
                return Err(MemoryError::OAuth(format!(
                    "client ID has unexpected length ({})",
                    self.client_id.len()
                )));
            }
            validate_endpoint_url(self.device_code_url, "device_code_url")?;
            validate_endpoint_url(self.access_token_url, "access_token_url")?;
            Ok(())
        }
    }

    fn valid_mock() -> MockDeviceFlow {
        MockDeviceFlow {
            client_id: "test-client-id",
            device_code_url: "https://example.com/device/code",
            access_token_url: "https://example.com/oauth/token",
            scopes: &["repo"],
        }
    }

    // TC-08a: GitHubDeviceFlow returns expected values
    #[test]
    fn github_provider_returns_expected_values() {
        let p = GitHubDeviceFlow;
        assert_eq!(p.client_id(), "Ov23liWxHYkwXTxCrYHp");
        assert_eq!(p.device_code_url(), "https://github.com/login/device/code");
        assert_eq!(
            p.access_token_url(),
            "https://github.com/login/oauth/access_token"
        );
        assert_eq!(p.scopes(), &["repo"]);
    }

    // TC-08b: device_flow_login accepts &dyn DeviceFlowProvider (compile-time check).
    // Never called: it only needs to type-check, hence the dead_code allowance.
    #[allow(dead_code)]
    async fn accepts_trait_object(provider: &dyn DeviceFlowProvider) {
        let _ = device_flow_login(
            provider,
            None,
            #[cfg(feature = "k8s")]
            None,
        )
        .await;
    }

    // TC-09a: GitHubDeviceFlow validates OK
    #[test]
    fn github_provider_validates_ok() {
        assert!(GitHubDeviceFlow.validate().is_ok());
    }

    // TC-09b: Empty client ID fails validation
    #[test]
    fn empty_client_id_fails_validation() {
        let mock = MockDeviceFlow {
            client_id: "",
            ..valid_mock()
        };
        let err = mock.validate().unwrap_err();
        assert!(err.to_string().contains("client ID"), "got: {err}");
    }

    // TC-09c: Malformed client ID fails (GitHub-specific format check)
    #[test]
    fn malformed_github_client_id_fails_validation() {
        assert!(github::validate_github_client_id("").is_err());
        assert!(github::validate_github_client_id("x").is_err());
        assert!(github::validate_github_client_id("Ov23liWxHYkwXTxCrYHp").is_ok());
    }

    // TC-10a: HTTP URL fails validation
    #[test]
    fn http_url_fails_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://example.com/device/code",
            ..valid_mock()
        };
        assert!(mock.validate().is_err());
    }

    // The HTTPS requirement applies to the token endpoint as well.
    #[test]
    fn http_access_token_url_fails_validation() {
        let mock = MockDeviceFlow {
            access_token_url: "http://example.com/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_err());
    }

    // TC-10b: HTTP localhost passes validation
    #[test]
    fn http_localhost_passes_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://localhost/device/code",
            access_token_url: "http://localhost/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_ok());
    }

    // A host that merely starts with "localhost" is not loopback.
    #[test]
    fn http_localhost_lookalike_fails_validation() {
        assert!(
            validate_endpoint_url("http://localhost.example.com/device/code", "device_code_url")
                .is_err()
        );
    }

    // TC-10c: HTTPS URLs pass validation
    #[test]
    fn https_urls_pass_validation() {
        assert!(valid_mock().validate().is_ok());
    }

    // IPv6 loopback passes validation (host_str returns "[::1]" with brackets)
    #[test]
    fn http_ipv6_localhost_passes_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://[::1]/device/code",
            access_token_url: "http://[::1]/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_ok());
    }

    // Non-loopback IPv6 HTTP URL is rejected
    #[test]
    fn http_ipv6_non_loopback_fails_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://[::2]/device/code",
            ..valid_mock()
        };
        assert!(mock.validate().is_err());
    }

    // 127.0.0.1 passes validation
    #[test]
    fn http_127_0_0_1_passes_validation() {
        let mock = MockDeviceFlow {
            device_code_url: "http://127.0.0.1/device/code",
            access_token_url: "http://127.0.0.1/oauth/token",
            ..valid_mock()
        };
        assert!(mock.validate().is_ok());
    }

    // Unparseable URLs are rejected, and the error names the offending field.
    #[test]
    fn unparseable_url_fails_validation() {
        let err = validate_endpoint_url("not a url", "device_code_url").unwrap_err();
        assert!(err.to_string().contains("device_code_url"), "got: {err}");
    }

    // Schemes other than http/https are rejected even on loopback hosts.
    #[test]
    fn non_http_scheme_fails_validation() {
        assert!(validate_endpoint_url("ftp://localhost/device/code", "device_code_url").is_err());
    }

    /// Device flow requires real OAuth — skip in CI.
    #[tokio::test]
    #[ignore = "requires real OAuth interaction"]
    async fn device_flow_login_ignored_in_ci() {
        device_flow_login(
            &GitHubDeviceFlow,
            Some(super::super::StoreBackend::Stdout),
            #[cfg(feature = "k8s")]
            None,
        )
        .await
        .expect("device flow should succeed");
    }
}