1use crate::types::{AuthError, Credentials, DeviceCodeResponse, DeviceCodeStatus};
7use crate::DEFAULT_AUTH_URL;
8use chrono::Utc;
9use std::time::Duration;
10
11pub struct AuthService {
13 auth_url: String,
15 client: reqwest::Client,
17}
18
19impl AuthService {
20 pub fn new() -> Self {
22 Self::with_auth_url(DEFAULT_AUTH_URL.to_string())
23 }
24
25 pub fn with_auth_url(auth_url: String) -> Self {
27 let client = reqwest::Client::builder()
28 .timeout(Duration::from_secs(30))
29 .build()
30 .expect("Failed to create HTTP client");
31
32 Self { auth_url, client }
33 }
34
35 pub fn auth_url(&self) -> &str {
37 &self.auth_url
38 }
39
40 pub async fn request_device_code(&self) -> Result<DeviceCodeResponse, AuthError> {
44 let url = format!("{}/device/code", self.auth_url);
45
46 let response = self
47 .client
48 .post(&url)
49 .header("Content-Type", "application/json")
50 .send()
51 .await
52 .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
53
54 let status = response.status();
55
56 if status.is_success() {
57 response.json().await.map_err(|e| AuthError::ServerError {
58 message: format!("Failed to parse response: {}", e),
59 status_code: Some(status.as_u16()),
60 })
61 } else if status.as_u16() == 429 {
62 let retry_after = response
63 .headers()
64 .get("retry-after")
65 .and_then(|v| v.to_str().ok())
66 .and_then(|s| s.parse().ok());
67
68 Err(AuthError::RateLimited { retry_after })
69 } else {
70 let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
71
72 Err(AuthError::ServerError {
73 message,
74 status_code: Some(status.as_u16()),
75 })
76 }
77 }
78
79 pub async fn poll_device_code(&self, device_code: &str) -> Result<DeviceCodeStatus, AuthError> {
83 let url = format!("{}/device/code/{}/status", self.auth_url, device_code);
84
85 let response = self
86 .client
87 .get(&url)
88 .send()
89 .await
90 .map_err(|e| AuthError::NetworkError { message: e.to_string() })?;
91
92 let status = response.status();
93
94 if status.is_success() {
95 response.json().await.map_err(|e| AuthError::ServerError {
96 message: format!("Failed to parse status: {}", e),
97 status_code: Some(status.as_u16()),
98 })
99 } else if status.as_u16() == 429 {
100 let retry_after = response
101 .headers()
102 .get("retry-after")
103 .and_then(|v| v.to_str().ok())
104 .and_then(|s| s.parse().ok());
105
106 Err(AuthError::RateLimited { retry_after })
107 } else if status.as_u16() == 404 {
108 Err(AuthError::ExpiredCode)
110 } else {
111 let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
112
113 Err(AuthError::ServerError {
114 message,
115 status_code: Some(status.as_u16()),
116 })
117 }
118 }
119
120 pub async fn run_device_code_flow<F>(&self, on_device_code: F) -> Result<Credentials, AuthError>
129 where
130 F: FnOnce(&DeviceCodeResponse),
131 {
132 let device_code_response = self.request_device_code().await?;
134 let poll_interval = Duration::from_secs(device_code_response.interval as u64);
135 let expires_at = std::time::Instant::now() + Duration::from_secs(device_code_response.expires_in as u64);
136
137 on_device_code(&device_code_response);
139
140 loop {
142 if std::time::Instant::now() > expires_at {
144 return Err(AuthError::ExpiredCode);
145 }
146
147 tokio::time::sleep(poll_interval).await;
149
150 match self.poll_device_code(&device_code_response.device_code).await {
152 Ok(DeviceCodeStatus::Pending) => {
153 continue;
155 }
156 Ok(DeviceCodeStatus::Authorized {
157 api_key,
158 user_id,
159 email,
160 name,
161 }) => {
162 return Ok(Credentials {
163 api_key,
164 user_id,
165 email,
166 name,
167 authenticated_at: Utc::now(),
168 auth_url: self.auth_url.clone(),
169 });
170 }
171 Ok(DeviceCodeStatus::Denied) => {
172 return Err(AuthError::AccessDenied);
173 }
174 Ok(DeviceCodeStatus::Expired) => {
175 return Err(AuthError::ExpiredCode);
176 }
177 Err(AuthError::RateLimited { retry_after }) => {
178 let wait_time = retry_after.map(|s| s as u64).unwrap_or(10);
180 tokio::time::sleep(Duration::from_secs(wait_time)).await;
181 continue;
182 }
183 Err(e) => {
184 return Err(e);
185 }
186 }
187 }
188 }
189}
190
191impl Default for AuthService {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` must point the service at the crate-wide default auth URL.
    #[test]
    fn test_auth_service_creation() {
        assert_eq!(AuthService::new().auth_url(), DEFAULT_AUTH_URL);
    }

    /// A caller-supplied URL must be reported back unchanged.
    #[test]
    fn test_custom_auth_url() {
        let expected = "https://custom.auth.example.com";
        let service = AuthService::with_auth_url(String::from(expected));
        assert_eq!(service.auth_url(), expected);
    }
}