1use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
4use chrono::{DateTime, Duration, Utc};
5use rand::Rng;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11use crate::error::{PolestarError, Result};
12
/// Base URL of the Polestar ID OpenID Connect provider; discovery and resume
/// paths are appended to it.
pub const OIDC_PROVIDER_BASE_URL: &str = "https://polestarid.eu.polestar.com";

/// OAuth client identifier sent with authorization, code-exchange and refresh requests.
pub const OIDC_CLIENT_ID: &str = "l3oopkc_10";

/// Redirect URI sent both in the authorization request and in the code exchange.
pub const OIDC_REDIRECT_URI: &str = "https://www.polestar.com/sign-in-callback";

/// Space-separated list of scopes requested during authorization.
pub const OIDC_SCOPE: &str = "openid profile email customer:attributes";

/// Upper bound, in seconds, of the pre-expiry window in which a token is refreshed.
pub const TOKEN_REFRESH_WINDOW_SECS: i64 = 5 * 60;
19
/// Subset of the provider's OpenID Connect discovery document
/// (`/.well-known/openid-configuration`) that this module reads.
///
/// Any other fields in the document are ignored during deserialization
/// (serde's default, since `deny_unknown_fields` is not set).
#[derive(Debug, Clone, Deserialize)]
pub struct OidcConfig {
    /// Issuer identifier advertised by the provider.
    pub issuer: String,
    /// Token endpoint; used for both the `authorization_code` and the
    /// `refresh_token` grants.
    pub token_endpoint: String,
    /// Authorization endpoint; the login page fetched from here is scraped
    /// for the form's resume path.
    pub authorization_endpoint: String,
}
27
/// Body of a successful token-endpoint response (code exchange or refresh).
///
/// NOTE(review): the derived `Debug` prints both tokens verbatim; avoid
/// logging this type with `{:?}`.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenResponse {
    /// Access token that is eventually handed to API callers.
    pub access_token: String,
    /// Token to present with the `refresh_token` grant.
    pub refresh_token: String,
    /// Lifetime of `access_token` in seconds, counted from when the response
    /// is turned into a `TokenState`.
    pub expires_in: i64,
}
35
/// Token set currently held by an `AuthState`, with an absolute expiry time
/// instead of the relative `expires_in` of the wire format.
///
/// NOTE(review): the derived `Debug` prints both tokens verbatim; avoid
/// logging this type with `{:?}`.
#[derive(Debug, Clone)]
pub struct TokenState {
    /// Access token returned to callers of `AuthState::get_valid_token`.
    pub access_token: String,
    /// Refresh token sent with the `refresh_token` grant.
    pub refresh_token: String,
    /// UTC instant after which `access_token` is no longer considered valid.
    pub expires_at: DateTime<Utc>,
}
43
44impl TokenState {
45 pub fn from_response(response: TokenResponse) -> Self {
46 Self {
47 access_token: response.access_token,
48 refresh_token: response.refresh_token,
49 expires_at: Utc::now() + Duration::seconds(response.expires_in),
50 }
51 }
52
53 pub fn is_valid(&self) -> bool {
54 Utc::now() < self.expires_at
55 }
56
57 pub fn needs_refresh(&self, token_lifetime_secs: i64) -> bool {
58 let refresh_window = std::cmp::min(token_lifetime_secs / 2, TOKEN_REFRESH_WINDOW_SECS);
59 let expires_in = (self.expires_at - Utc::now()).num_seconds();
60 expires_in < refresh_window
61 }
62}
63
/// Login credentials plus the shared authentication state for one account.
///
/// Both lock-protected fields start as `None` (see `AuthState::new`) and are
/// filled in by the methods of this type.
///
/// NOTE(review): this type does not derive `Debug`, which keeps `password`
/// out of debug output. Preserve that if derives are ever added.
pub struct AuthState {
    /// Username submitted as `pf.username` to the login form.
    pub username: String,
    /// Password submitted as `pf.pass` to the login form.
    pub password: String,
    /// Current token set. `None` before the first successful authentication,
    /// and reset to `None` when a refresh is rejected as expired.
    pub token: Arc<RwLock<Option<TokenState>>>,
    /// Cached OIDC discovery document; `None` until first fetched.
    pub oidc_config: Arc<RwLock<Option<OidcConfig>>>,
}
75
76impl AuthState {
77 pub fn new(username: String, password: String) -> Self {
79 Self {
80 username,
81 password,
82 token: Arc::new(RwLock::new(None)),
83 oidc_config: Arc::new(RwLock::new(None)),
84 }
85 }
86
87 pub async fn get_oidc_config(&self, client: &reqwest::Client) -> Result<OidcConfig> {
89 {
91 let config = self.oidc_config.read().await;
92 if let Some(cfg) = config.as_ref() {
93 return Ok(cfg.clone());
94 }
95 }
96
97 let url = format!("{}/.well-known/openid-configuration", OIDC_PROVIDER_BASE_URL);
99 let response = client
100 .get(&url)
101 .send()
102 .await?;
103
104 if !response.status().is_success() {
105 return Err(PolestarError::OidcConfigError(format!(
106 "HTTP {}",
107 response.status()
108 )));
109 }
110
111 let config: OidcConfig = response.json().await?;
112
113 {
115 let mut cached = self.oidc_config.write().await;
116 *cached = Some(config.clone());
117 }
118
119 Ok(config)
120 }
121
122 async fn get_resume_path(
124 &self,
125 client: &reqwest::Client,
126 config: &OidcConfig,
127 state: &str,
128 code_challenge: &str,
129 ) -> Result<String> {
130 let params = [
131 ("client_id", OIDC_CLIENT_ID),
132 ("redirect_uri", OIDC_REDIRECT_URI),
133 ("response_type", "code"),
134 ("scope", OIDC_SCOPE),
135 ("state", state),
136 ("code_challenge", code_challenge),
137 ("code_challenge_method", "S256"),
138 ("response_mode", "query"),
139 ];
140
141 let response = client
142 .get(&config.authorization_endpoint)
143 .query(¶ms)
144 .send()
145 .await?;
146
147 if !response.status().is_success() {
148 return Err(PolestarError::AuthError(format!(
149 "Failed to get resume path: {}",
150 response.status()
151 )));
152 }
153
154 let text = response.text().await?;
155
156 let re = regex::Regex::new(r#"(?:url|action):\s*"(.+?)""#)
158 .map_err(|e| PolestarError::ApiError(format!("Regex error: {}", e)))?;
159
160 if let Some(caps) = re.captures(&text) {
161 if let Some(path) = caps.get(1) {
162 return Ok(path.as_str().to_string());
163 }
164 }
165
166 Err(PolestarError::AuthError("Resume path not found in response".to_string()))
167 }
168
169 pub async fn get_authorization_code(
172 &self,
173 client: &reqwest::Client,
174 ) -> Result<(String, String)> {
175 let config = self.get_oidc_config(client).await?;
176 let state = generate_state();
177 let code_verifier = generate_code_verifier();
178 let code_challenge = generate_code_challenge(&code_verifier);
179
180 let resume_path = self.get_resume_path(client, &config, &state, &code_challenge).await?;
181 let resume_url = format!("{}{}", OIDC_PROVIDER_BASE_URL, resume_path);
182
183 let params = [
185 ("client_id", OIDC_CLIENT_ID),
186 ("redirect_uri", OIDC_REDIRECT_URI),
187 ("response_type", "code"),
188 ("scope", OIDC_SCOPE),
189 ("state", state.as_str()),
190 ("code_challenge", code_challenge.as_str()),
191 ("code_challenge_method", "S256"),
192 ("response_mode", "query"),
193 ];
194
195 let form = [
197 ("pf.username", self.username.as_str()),
198 ("pf.pass", self.password.as_str()),
199 ];
200
201 let response = client
202 .post(&resume_url)
203 .query(¶ms)
204 .form(&form)
205 .send()
206 .await?;
207
208 let status = response.status();
209
210 if status.is_client_error() && !status.is_redirection() {
212 let text = response.text().await?;
213 if text.contains(r#"authMessage: "ERR001""#) {
214 return Err(PolestarError::InvalidCredentials);
215 }
216 return Err(PolestarError::AuthError(format!("Authentication failed: {}", status)));
217 }
218
219 let final_url = response.url().clone();
222
223 if let Some(code) = final_url.query_pairs().find(|(k, _)| k == "code").map(|(_, v)| v.to_string()) {
225 return Ok((code, code_verifier));
226 }
227
228 if let Some(uid) = final_url.query_pairs().find(|(k, _)| k == "uid").map(|(_, v)| v.to_string()) {
230 let form = [
232 ("pf.submit", "true"),
233 ("subject", &uid),
234 ];
235
236 let response = client
237 .post(&resume_url)
238 .query(¶ms)
239 .form(&form)
240 .send()
241 .await?;
242
243 let final_url = response.url().clone();
244
245 if let Some(code) = final_url.query_pairs().find(|(k, _)| k == "code").map(|(_, v)| v.to_string()) {
246 return Ok((code, code_verifier));
247 }
248 }
249
250 Err(PolestarError::AuthError("No authorization code found".to_string()))
251 }
252
253 pub async fn exchange_code_for_token(
255 &self,
256 client: &reqwest::Client,
257 code: &str,
258 code_verifier: &str,
259 ) -> Result<TokenState> {
260 let config = self.get_oidc_config(client).await?;
261
262 let form = [
263 ("grant_type", "authorization_code"),
264 ("client_id", OIDC_CLIENT_ID),
265 ("code", code),
266 ("redirect_uri", OIDC_REDIRECT_URI),
267 ("code_verifier", code_verifier),
268 ];
269
270 let response = client
271 .post(&config.token_endpoint)
272 .form(&form)
273 .send()
274 .await?;
275
276 if !response.status().is_success() {
277 let text = response.text().await?;
278 return Err(PolestarError::AuthError(format!(
279 "Token exchange failed: {}",
280 text
281 )));
282 }
283
284 let token_response: TokenResponse = response.json().await?;
285 let token_state = TokenState::from_response(token_response);
286
287 {
289 let mut token = self.token.write().await;
290 *token = Some(token_state.clone());
291 }
292
293 Ok(token_state)
294 }
295
296 pub async fn refresh_token(&self, client: &reqwest::Client) -> Result<TokenState> {
298 let config = self.get_oidc_config(client).await?;
299
300 let refresh_token = {
302 let token = self.token.read().await;
303 token
304 .as_ref()
305 .map(|t| t.refresh_token.clone())
306 .ok_or_else(|| PolestarError::AuthError("No refresh token available".to_string()))?
307 };
308
309 let form = [
310 ("grant_type", "refresh_token"),
311 ("client_id", OIDC_CLIENT_ID),
312 ("refresh_token", &refresh_token),
313 ];
314
315 let response = client
316 .post(&config.token_endpoint)
317 .form(&form)
318 .send()
319 .await?;
320
321 if !response.status().is_success() {
322 let status = response.status();
323 let text = response.text().await?;
324
325 if status == reqwest::StatusCode::UNAUTHORIZED || text.contains("invalid_grant") {
327 return Err(PolestarError::TokenExpired);
328 }
329
330 return Err(PolestarError::AuthError(format!(
331 "Token refresh failed: {}",
332 text
333 )));
334 }
335
336 let token_response: TokenResponse = response.json().await?;
337 let token_state = TokenState::from_response(token_response);
338
339 {
341 let mut token = self.token.write().await;
342 *token = Some(token_state.clone());
343 }
344
345 Ok(token_state)
346 }
347
348 pub async fn is_token_valid(&self) -> bool {
350 let token = self.token.read().await;
351 token.as_ref().map(|t| t.is_valid()).unwrap_or(false)
352 }
353
354 pub async fn needs_token_refresh(&self) -> bool {
356 let token = self.token.read().await;
357 if let Some(t) = token.as_ref() {
358 let lifetime = (t.expires_at - Utc::now()).num_seconds() + 3600; t.needs_refresh(lifetime)
361 } else {
362 true }
364 }
365
366 pub async fn get_valid_token(&self, client: &reqwest::Client) -> Result<String> {
368 if self.is_token_valid().await && !self.needs_token_refresh().await {
370 let token = self.token.read().await;
371 return Ok(token.as_ref().unwrap().access_token.clone());
372 }
373
374 {
376 let token = self.token.read().await;
377 if token.is_some() {
378 drop(token); match self.refresh_token(client).await {
380 Ok(state) => return Ok(state.access_token),
381 Err(PolestarError::TokenExpired) => {
382 let mut token = self.token.write().await;
384 *token = None;
385 }
386 Err(e) => {
387 eprintln!("Token refresh failed: {}, attempting full auth", e);
389 }
390 }
391 }
392 }
393
394 let max_retries = 2;
396 let mut last_error = None;
397
398 for attempt in 0..max_retries {
399 match self.get_authorization_code(client).await {
400 Ok((code, verifier)) => {
401 match self.exchange_code_for_token(client, &code, &verifier).await {
402 Ok(token_state) => return Ok(token_state.access_token),
403 Err(e) => {
404 last_error = Some(e);
405 if attempt < max_retries - 1 {
406 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
407 }
408 }
409 }
410 }
411 Err(e) => {
412 last_error = Some(e);
413 if attempt < max_retries - 1 {
414 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
415 }
416 }
417 }
418 }
419
420 Err(last_error.unwrap_or_else(|| PolestarError::AuthError("Authentication failed".to_string())))
421 }
422}
423
424pub fn generate_code_verifier() -> String {
428 let random_bytes: [u8; 32] = rand::thread_rng().gen();
429 URL_SAFE_NO_PAD.encode(random_bytes)
430}
431
432pub fn generate_code_challenge(verifier: &str) -> String {
434 let mut hasher = Sha256::new();
435 hasher.update(verifier.as_bytes());
436 let hash = hasher.finalize();
437 URL_SAFE_NO_PAD.encode(hash)
438}
439
440pub fn generate_state() -> String {
442 let random_bytes: [u8; 32] = rand::thread_rng().gen();
443 URL_SAFE_NO_PAD.encode(random_bytes)
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// True when every byte belongs to the unpadded base64url alphabet.
    fn is_base64url(s: &str) -> bool {
        s.bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
    }

    #[test]
    fn test_generate_code_verifier() {
        let verifier = generate_code_verifier();
        assert!(!verifier.is_empty());
        // 32 random bytes encode to 43 base64url characters, the RFC 7636 minimum length.
        assert_eq!(verifier.len(), 43);
        assert!(is_base64url(&verifier));
    }

    #[test]
    fn test_generate_code_challenge() {
        let verifier = "test_verifier";
        let challenge = generate_code_challenge(verifier);
        assert!(!challenge.is_empty());
        // A 32-byte SHA-256 digest encodes to 43 base64url characters.
        assert_eq!(challenge.len(), 43);
        assert!(is_base64url(&challenge));
    }

    #[test]
    fn test_generate_code_challenge_rfc7636_vector() {
        // Known-answer example from RFC 7636, Appendix B.
        assert_eq!(
            generate_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn test_generate_state() {
        let state = generate_state();
        assert!(!state.is_empty());
        assert_eq!(state.len(), 43);
        assert!(is_base64url(&state));
        assert_ne!(state, generate_state());
    }

    #[test]
    fn test_token_state_is_valid() {
        let response = TokenResponse {
            access_token: "test_token".to_string(),
            refresh_token: "test_refresh".to_string(),
            expires_in: 3600,
        };
        let state = TokenState::from_response(response);
        assert!(state.is_valid());
    }

    #[test]
    fn test_token_state_expired_is_not_valid() {
        let response = TokenResponse {
            access_token: "test_token".to_string(),
            refresh_token: "test_refresh".to_string(),
            expires_in: -10,
        };
        let state = TokenState::from_response(response);
        assert!(!state.is_valid());
    }

    #[test]
    fn test_token_state_needs_refresh() {
        let mut state = TokenState {
            access_token: "test".to_string(),
            refresh_token: "test".to_string(),
            expires_at: Utc::now() + Duration::seconds(10),
        };
        // 10s left, window = min(600 / 2, 300) = 300s -> refresh needed.
        assert!(state.needs_refresh(600));

        // ~3600s left, window = min(3600 / 2, 300) = 300s -> no refresh yet.
        state.expires_at = Utc::now() + Duration::seconds(3600);
        assert!(!state.needs_refresh(3600));
    }

    #[tokio::test]
    async fn test_get_oidc_config() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let config_json = serde_json::json!({
            "issuer": "https://test.polestar.com",
            "token_endpoint": "https://test.polestar.com/token",
            "authorization_endpoint": "https://test.polestar.com/authorize"
        });

        Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&config_json))
            .mount(&mock_server)
            .await;

        let client = reqwest::Client::new();

        // OIDC_PROVIDER_BASE_URL is a compile-time constant, so the discovery
        // request cannot be pointed at the mock server. Instead:
        // 1. fetch the mocked document and check it deserializes into OidcConfig;
        // 2. seed the cache with it;
        // 3. verify get_oidc_config serves the cached value. A cache hit
        //    returns before any network access.
        let fetched: OidcConfig = client
            .get(format!("{}/.well-known/openid-configuration", mock_server.uri()))
            .send()
            .await
            .expect("request to mock server failed")
            .json()
            .await
            .expect("mock discovery document should deserialize into OidcConfig");

        let auth = AuthState::new("user".to_string(), "pass".to_string());
        assert_eq!(auth.username, "user");
        assert_eq!(auth.password, "pass");

        *auth.oidc_config.write().await = Some(fetched);

        let config = match auth.get_oidc_config(&client).await {
            Ok(config) => config,
            Err(e) => panic!("cached config should be returned: {}", e),
        };
        assert_eq!(config.issuer, "https://test.polestar.com");
        assert_eq!(config.token_endpoint, "https://test.polestar.com/token");
        assert_eq!(
            config.authorization_endpoint,
            "https://test.polestar.com/authorize"
        );
    }
}