mecha10_auth/
service.rs

//! Authentication service for the device code flow.
//!
//! Handles the OAuth 2.0 Device Authorization Grant (RFC 8628) flow
//! for authenticating users via a browser.
5
6use crate::types::{AuthError, Credentials, DeviceCodeResponse, DeviceCodeStatus};
7use crate::DEFAULT_AUTH_URL;
8use chrono::Utc;
9use std::time::Duration;
10
11/// Service for handling authentication flows
12pub struct AuthService {
13    /// Base URL for auth API
14    auth_url: String,
15    /// HTTP client
16    client: reqwest::Client,
17}
18
19impl AuthService {
20    /// Create a new AuthService with default auth URL
21    pub fn new() -> Self {
22        Self::with_auth_url(DEFAULT_AUTH_URL.to_string())
23    }
24
25    /// Create a new AuthService with custom auth URL
26    pub fn with_auth_url(auth_url: String) -> Self {
27        let client = reqwest::Client::builder()
28            .timeout(Duration::from_secs(30))
29            .build()
30            .expect("Failed to create HTTP client");
31
32        Self { auth_url, client }
33    }
34
35    /// Get the auth URL
36    pub fn auth_url(&self) -> &str {
37        &self.auth_url
38    }
39
40    /// Request a new device code for authentication
41    ///
42    /// Returns device code info including user_code and verification_uri
43    pub async fn request_device_code(&self) -> Result<DeviceCodeResponse, AuthError> {
44        let url = format!("{}/device/code", self.auth_url);
45
46        let response = self
47            .client
48            .post(&url)
49            .header("Content-Type", "application/json")
50            .send()
51            .await
52            .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
53
54        let status = response.status();
55
56        if status.is_success() {
57            response.json().await.map_err(|e| AuthError::ServerError {
58                message: format!("Failed to parse response: {}", e),
59                status_code: Some(status.as_u16()),
60            })
61        } else if status.as_u16() == 429 {
62            let retry_after = response
63                .headers()
64                .get("retry-after")
65                .and_then(|v| v.to_str().ok())
66                .and_then(|s| s.parse().ok());
67
68            Err(AuthError::RateLimited { retry_after })
69        } else {
70            let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
71
72            Err(AuthError::ServerError {
73                message,
74                status_code: Some(status.as_u16()),
75            })
76        }
77    }
78
79    /// Poll for device code status
80    ///
81    /// Should be called at the interval specified in DeviceCodeResponse
82    pub async fn poll_device_code(&self, device_code: &str) -> Result<DeviceCodeStatus, AuthError> {
83        let url = format!("{}/device/code/{}/status", self.auth_url, device_code);
84
85        let response = self
86            .client
87            .get(&url)
88            .send()
89            .await
90            .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
91
92        let status = response.status();
93
94        if status.is_success() {
95            response.json().await.map_err(|e| AuthError::ServerError {
96                message: format!("Failed to parse status: {}", e),
97                status_code: Some(status.as_u16()),
98            })
99        } else if status.as_u16() == 429 {
100            let retry_after = response
101                .headers()
102                .get("retry-after")
103                .and_then(|v| v.to_str().ok())
104                .and_then(|s| s.parse().ok());
105
106            Err(AuthError::RateLimited { retry_after })
107        } else if status.as_u16() == 404 {
108            // Device code not found or expired
109            Err(AuthError::ExpiredCode)
110        } else {
111            let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
112
113            Err(AuthError::ServerError {
114                message,
115                status_code: Some(status.as_u16()),
116            })
117        }
118    }
119
120    /// Run the complete device code flow
121    ///
122    /// This method:
123    /// 1. Requests a device code
124    /// 2. Returns immediately with device code info (caller should display to user)
125    /// 3. Polls until authorized, denied, or expired
126    ///
127    /// Returns credentials on success
128    pub async fn run_device_code_flow<F>(&self, on_device_code: F) -> Result<Credentials, AuthError>
129    where
130        F: FnOnce(&DeviceCodeResponse),
131    {
132        // Step 1: Request device code
133        let device_code_response = self.request_device_code().await?;
134        let poll_interval = Duration::from_secs(device_code_response.interval as u64);
135        let expires_at = std::time::Instant::now() + Duration::from_secs(device_code_response.expires_in as u64);
136
137        // Step 2: Notify caller with device code info
138        on_device_code(&device_code_response);
139
140        // Step 3: Poll until terminal state
141        loop {
142            // Check if expired
143            if std::time::Instant::now() > expires_at {
144                return Err(AuthError::ExpiredCode);
145            }
146
147            // Wait before polling
148            tokio::time::sleep(poll_interval).await;
149
150            // Poll for status
151            match self.poll_device_code(&device_code_response.device_code).await {
152                Ok(DeviceCodeStatus::Pending) => {
153                    // Continue polling
154                    continue;
155                }
156                Ok(DeviceCodeStatus::Authorized {
157                    api_key,
158                    user_id,
159                    email,
160                    name,
161                }) => {
162                    return Ok(Credentials {
163                        api_key,
164                        user_id,
165                        email,
166                        name,
167                        authenticated_at: Utc::now(),
168                        auth_url: self.auth_url.clone(),
169                    });
170                }
171                Ok(DeviceCodeStatus::Denied) => {
172                    return Err(AuthError::AccessDenied);
173                }
174                Ok(DeviceCodeStatus::Expired) => {
175                    return Err(AuthError::ExpiredCode);
176                }
177                Err(AuthError::RateLimited { retry_after }) => {
178                    // Back off and retry
179                    let wait_time = retry_after.map(|s| s as u64).unwrap_or(10);
180                    tokio::time::sleep(Duration::from_secs(wait_time)).await;
181                    continue;
182                }
183                Err(e) => {
184                    return Err(e);
185                }
186            }
187        }
188    }
189}
190
191impl Default for AuthService {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// `AuthService::new` should report the crate-wide default auth URL.
    #[test]
    fn test_auth_service_creation() {
        assert_eq!(AuthService::new().auth_url(), DEFAULT_AUTH_URL);
    }

    /// A caller-supplied auth URL should be reported back unchanged.
    #[test]
    fn test_custom_auth_url() {
        const CUSTOM: &str = "https://custom.auth.example.com";
        let service = AuthService::with_auth_url(CUSTOM.to_owned());
        assert_eq!(service.auth_url(), CUSTOM);
    }
}