mecha10_cli/services/
auth.rs

1//! Authentication service for device code flow
2//!
3//! Handles the OAuth 2.0 Device Authorization Grant (RFC 8628) flow
4//! for authenticating users via browser.
5
6use crate::services::credentials::DEFAULT_AUTH_URL;
7use crate::types::credentials::{AuthError, Credentials, DeviceCodeResponse, DeviceCodeStatus};
8use anyhow::Result;
9use chrono::Utc;
10use std::time::Duration;
11
12/// Service for handling authentication flows
13pub struct AuthService {
14    /// Base URL for auth API
15    auth_url: String,
16    /// HTTP client
17    client: reqwest::Client,
18}
19
20impl AuthService {
21    /// Create a new AuthService with default auth URL
22    pub fn new() -> Self {
23        Self::with_auth_url(DEFAULT_AUTH_URL.to_string())
24    }
25
26    /// Create a new AuthService with custom auth URL
27    pub fn with_auth_url(auth_url: String) -> Self {
28        let client = reqwest::Client::builder()
29            .timeout(Duration::from_secs(30))
30            .build()
31            .expect("Failed to create HTTP client");
32
33        Self { auth_url, client }
34    }
35
36    /// Get the auth URL
37    #[allow(dead_code)]
38    pub fn auth_url(&self) -> &str {
39        &self.auth_url
40    }
41
42    /// Request a new device code for authentication
43    ///
44    /// Returns device code info including user_code and verification_uri
45    pub async fn request_device_code(&self) -> Result<DeviceCodeResponse, AuthError> {
46        let url = format!("{}/device/code", self.auth_url);
47
48        let response = self
49            .client
50            .post(&url)
51            .header("Content-Type", "application/json")
52            .send()
53            .await
54            .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
55
56        let status = response.status();
57
58        if status.is_success() {
59            response.json().await.map_err(|e| AuthError::ServerError {
60                message: format!("Failed to parse response: {}", e),
61                status_code: Some(status.as_u16()),
62            })
63        } else if status.as_u16() == 429 {
64            let retry_after = response
65                .headers()
66                .get("retry-after")
67                .and_then(|v| v.to_str().ok())
68                .and_then(|s| s.parse().ok());
69
70            Err(AuthError::RateLimited { retry_after })
71        } else {
72            let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
73
74            Err(AuthError::ServerError {
75                message,
76                status_code: Some(status.as_u16()),
77            })
78        }
79    }
80
81    /// Poll for device code status
82    ///
83    /// Should be called at the interval specified in DeviceCodeResponse
84    pub async fn poll_device_code(&self, device_code: &str) -> Result<DeviceCodeStatus, AuthError> {
85        let url = format!("{}/device/code/{}/status", self.auth_url, device_code);
86
87        let response = self
88            .client
89            .get(&url)
90            .send()
91            .await
92            .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
93
94        let status = response.status();
95
96        if status.is_success() {
97            response.json().await.map_err(|e| AuthError::ServerError {
98                message: format!("Failed to parse status: {}", e),
99                status_code: Some(status.as_u16()),
100            })
101        } else if status.as_u16() == 429 {
102            let retry_after = response
103                .headers()
104                .get("retry-after")
105                .and_then(|v| v.to_str().ok())
106                .and_then(|s| s.parse().ok());
107
108            Err(AuthError::RateLimited { retry_after })
109        } else if status.as_u16() == 404 {
110            // Device code not found or expired
111            Err(AuthError::ExpiredCode)
112        } else {
113            let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
114
115            Err(AuthError::ServerError {
116                message,
117                status_code: Some(status.as_u16()),
118            })
119        }
120    }
121
122    /// Run the complete device code flow
123    ///
124    /// This method:
125    /// 1. Requests a device code
126    /// 2. Returns immediately with device code info (caller should display to user)
127    /// 3. Polls until authorized, denied, or expired
128    ///
129    /// Returns credentials on success
130    pub async fn run_device_code_flow<F>(&self, on_device_code: F) -> Result<Credentials, AuthError>
131    where
132        F: FnOnce(&DeviceCodeResponse),
133    {
134        // Step 1: Request device code
135        let device_code_response = self.request_device_code().await?;
136        let poll_interval = Duration::from_secs(device_code_response.interval as u64);
137        let expires_at = std::time::Instant::now() + Duration::from_secs(device_code_response.expires_in as u64);
138
139        // Step 2: Notify caller with device code info
140        on_device_code(&device_code_response);
141
142        // Step 3: Poll until terminal state
143        loop {
144            // Check if expired
145            if std::time::Instant::now() > expires_at {
146                return Err(AuthError::ExpiredCode);
147            }
148
149            // Wait before polling
150            tokio::time::sleep(poll_interval).await;
151
152            // Poll for status
153            match self.poll_device_code(&device_code_response.device_code).await {
154                Ok(DeviceCodeStatus::Pending) => {
155                    // Continue polling
156                    continue;
157                }
158                Ok(DeviceCodeStatus::Authorized {
159                    api_key,
160                    user_id,
161                    email,
162                    name,
163                }) => {
164                    return Ok(Credentials {
165                        api_key,
166                        user_id,
167                        email,
168                        name,
169                        authenticated_at: Utc::now(),
170                        auth_url: self.auth_url.clone(),
171                    });
172                }
173                Ok(DeviceCodeStatus::Denied) => {
174                    return Err(AuthError::AccessDenied);
175                }
176                Ok(DeviceCodeStatus::Expired) => {
177                    return Err(AuthError::ExpiredCode);
178                }
179                Err(AuthError::RateLimited { retry_after }) => {
180                    // Back off and retry
181                    let wait_time = retry_after.map(|s| s as u64).unwrap_or(10);
182                    tokio::time::sleep(Duration::from_secs(wait_time)).await;
183                    continue;
184                }
185                Err(e) => {
186                    return Err(e);
187                }
188            }
189        }
190    }
191}
192
193impl Default for AuthService {
194    fn default() -> Self {
195        Self::new()
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must point at the default auth endpoint.
    #[test]
    fn test_auth_service_creation() {
        assert_eq!(AuthService::new().auth_url(), DEFAULT_AUTH_URL);
    }

    /// `with_auth_url()` must keep the caller-supplied URL verbatim.
    #[test]
    fn test_custom_auth_url() {
        const CUSTOM: &str = "https://custom.auth.example.com";
        let svc = AuthService::with_auth_url(String::from(CUSTOM));
        assert_eq!(svc.auth_url(), CUSTOM);
    }
}