mecha10_cli/services/
auth.rs

//! Authentication service for the device code flow.
//!
//! Handles the OAuth 2.0 Device Authorization Grant (RFC 8628) flow
//! for authenticating users via a browser.
5
6use crate::services::credentials::DEFAULT_AUTH_URL;
7use crate::types::credentials::{AuthError, Credentials, DeviceCodeResponse, DeviceCodeStatus};
8use anyhow::{Context, Result};
9use chrono::Utc;
10use std::time::Duration;
11
12/// Service for handling authentication flows
13pub struct AuthService {
14    /// Base URL for auth API
15    auth_url: String,
16    /// HTTP client
17    client: reqwest::Client,
18}
19
20impl AuthService {
21    /// Create a new AuthService with default auth URL
22    pub fn new() -> Self {
23        Self::with_auth_url(DEFAULT_AUTH_URL.to_string())
24    }
25
26    /// Create a new AuthService with custom auth URL
27    pub fn with_auth_url(auth_url: String) -> Self {
28        let client = reqwest::Client::builder()
29            .timeout(Duration::from_secs(30))
30            .build()
31            .expect("Failed to create HTTP client");
32
33        Self { auth_url, client }
34    }
35
36    /// Get the auth URL
37    pub fn auth_url(&self) -> &str {
38        &self.auth_url
39    }
40
41    /// Request a new device code for authentication
42    ///
43    /// Returns device code info including user_code and verification_uri
44    pub async fn request_device_code(&self) -> Result<DeviceCodeResponse, AuthError> {
45        let url = format!("{}/device/code", self.auth_url);
46
47        let response = self
48            .client
49            .post(&url)
50            .header("Content-Type", "application/json")
51            .send()
52            .await
53            .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
54
55        let status = response.status();
56
57        if status.is_success() {
58            response.json().await.map_err(|e| AuthError::ServerError {
59                message: format!("Failed to parse response: {}", e),
60                status_code: Some(status.as_u16()),
61            })
62        } else if status.as_u16() == 429 {
63            let retry_after = response
64                .headers()
65                .get("retry-after")
66                .and_then(|v| v.to_str().ok())
67                .and_then(|s| s.parse().ok());
68
69            Err(AuthError::RateLimited { retry_after })
70        } else {
71            let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
72
73            Err(AuthError::ServerError {
74                message,
75                status_code: Some(status.as_u16()),
76            })
77        }
78    }
79
80    /// Poll for device code status
81    ///
82    /// Should be called at the interval specified in DeviceCodeResponse
83    pub async fn poll_device_code(&self, device_code: &str) -> Result<DeviceCodeStatus, AuthError> {
84        let url = format!("{}/device/code/{}/status", self.auth_url, device_code);
85
86        let response = self
87            .client
88            .get(&url)
89            .send()
90            .await
91            .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
92
93        let status = response.status();
94
95        if status.is_success() {
96            response.json().await.map_err(|e| AuthError::ServerError {
97                message: format!("Failed to parse status: {}", e),
98                status_code: Some(status.as_u16()),
99            })
100        } else if status.as_u16() == 429 {
101            let retry_after = response
102                .headers()
103                .get("retry-after")
104                .and_then(|v| v.to_str().ok())
105                .and_then(|s| s.parse().ok());
106
107            Err(AuthError::RateLimited { retry_after })
108        } else if status.as_u16() == 404 {
109            // Device code not found or expired
110            Err(AuthError::ExpiredCode)
111        } else {
112            let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
113
114            Err(AuthError::ServerError {
115                message,
116                status_code: Some(status.as_u16()),
117            })
118        }
119    }
120
121    /// Run the complete device code flow
122    ///
123    /// This method:
124    /// 1. Requests a device code
125    /// 2. Returns immediately with device code info (caller should display to user)
126    /// 3. Polls until authorized, denied, or expired
127    ///
128    /// Returns credentials on success
129    pub async fn run_device_code_flow<F>(&self, on_device_code: F) -> Result<Credentials, AuthError>
130    where
131        F: FnOnce(&DeviceCodeResponse),
132    {
133        // Step 1: Request device code
134        let device_code_response = self.request_device_code().await?;
135        let poll_interval = Duration::from_secs(device_code_response.interval as u64);
136        let expires_at = std::time::Instant::now() + Duration::from_secs(device_code_response.expires_in as u64);
137
138        // Step 2: Notify caller with device code info
139        on_device_code(&device_code_response);
140
141        // Step 3: Poll until terminal state
142        loop {
143            // Check if expired
144            if std::time::Instant::now() > expires_at {
145                return Err(AuthError::ExpiredCode);
146            }
147
148            // Wait before polling
149            tokio::time::sleep(poll_interval).await;
150
151            // Poll for status
152            match self.poll_device_code(&device_code_response.device_code).await {
153                Ok(DeviceCodeStatus::Pending) => {
154                    // Continue polling
155                    continue;
156                }
157                Ok(DeviceCodeStatus::Authorized {
158                    api_key,
159                    user_id,
160                    email,
161                    name,
162                }) => {
163                    return Ok(Credentials {
164                        api_key,
165                        user_id,
166                        email,
167                        name,
168                        authenticated_at: Utc::now(),
169                        auth_url: self.auth_url.clone(),
170                    });
171                }
172                Ok(DeviceCodeStatus::Denied) => {
173                    return Err(AuthError::AccessDenied);
174                }
175                Ok(DeviceCodeStatus::Expired) => {
176                    return Err(AuthError::ExpiredCode);
177                }
178                Err(AuthError::RateLimited { retry_after }) => {
179                    // Back off and retry
180                    let wait_time = retry_after.map(|s| s as u64).unwrap_or(10);
181                    tokio::time::sleep(Duration::from_secs(wait_time)).await;
182                    continue;
183                }
184                Err(e) => {
185                    return Err(e);
186                }
187            }
188        }
189    }
190}
191
192impl Default for AuthService {
193    fn default() -> Self {
194        Self::new()
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` should target the built-in default auth endpoint.
    #[test]
    fn test_auth_service_creation() {
        assert_eq!(AuthService::new().auth_url(), DEFAULT_AUTH_URL);
    }

    /// `with_auth_url()` should keep the caller-supplied base URL verbatim.
    #[test]
    fn test_custom_auth_url() {
        const CUSTOM: &str = "https://custom.auth.example.com";
        let svc = AuthService::with_auth_url(String::from(CUSTOM));
        assert_eq!(svc.auth_url(), CUSTOM);
    }
}