1use base64::Engine;
11use serde::Deserialize;
12use sha2::{Digest, Sha256};
13
14use car_secrets::{SecretRef, SecretStore};
15
/// Key under which the Parslee bearer access token is stored in the secret
/// store. `access_token_refreshing` and `force_refresh` also read an
/// environment variable of this exact name directly.
pub const PARSLEE_ACCESS_TOKEN_KEY: &str = "PARSLEE_ACCESS_TOKEN";
/// Key for the refresh token passed to `refresh_grant`.
pub const PARSLEE_REFRESH_TOKEN_KEY: &str = "PARSLEE_REFRESH_TOKEN";
/// Key for the access token's absolute expiry, stored as a decimal string of
/// Unix-epoch seconds (`now + expires_in` at the time it was written).
pub const PARSLEE_EXPIRES_AT_KEY: &str = "PARSLEE_ACCESS_TOKEN_EXPIRES_AT";
/// Key for the API base URL the tokens were issued by (stored with trailing
/// `/` trimmed).
pub const PARSLEE_API_BASE_KEY: &str = "PARSLEE_API_BASE";
/// Base URL used by `api_base` when neither an override nor a stored value
/// is available.
pub const DEFAULT_API_BASE: &str = "https://api.parslee.ai";
21
22#[derive(Debug, Clone, Deserialize)]
24pub struct TokenSet {
25 pub access_token: String,
26 pub refresh_token: String,
27 pub expires_in: u64,
28 pub token_type: String,
29}
30
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn epoch_seconds() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
37
38pub fn pkce_verifier() -> String {
40 let raw = format!(
41 "{}{}",
42 uuid::Uuid::new_v4().simple(),
43 uuid::Uuid::new_v4().simple()
44 );
45 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw.as_bytes())
46}
47
48pub fn new_state() -> String {
50 uuid::Uuid::new_v4().simple().to_string()
51}
52
53pub fn pkce_challenge(verifier: &str) -> String {
55 let digest = Sha256::digest(verifier.as_bytes());
56 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
57}
58
59pub fn authorize_url(
61 api_base: &str,
62 client_id: &str,
63 redirect_uri: &str,
64 state: &str,
65 challenge: &str,
66 provider: Option<&str>,
67) -> Result<String, String> {
68 let mut url = reqwest::Url::parse(&format!(
69 "{}/connect/authorize",
70 api_base.trim_end_matches('/')
71 ))
72 .map_err(|e| format!("build authorize URL: {e}"))?;
73 url.query_pairs_mut()
74 .append_pair("client_id", client_id)
75 .append_pair("redirect_uri", redirect_uri)
76 .append_pair("response_type", "code")
77 .append_pair("scope", "openid profile email")
78 .append_pair("state", state)
79 .append_pair("code_challenge", challenge)
80 .append_pair("code_challenge_method", "S256");
81 if let Some(provider) = provider {
82 url.query_pairs_mut().append_pair("provider", provider);
83 }
84 Ok(url.to_string())
85}
86
87fn form_body(pairs: &[(&str, &str)]) -> String {
88 let mut s = String::new();
89 for (i, (k, v)) in pairs.iter().enumerate() {
90 if i > 0 {
91 s.push('&');
92 }
93 s.push_str(&urlencode(k));
94 s.push('=');
95 s.push_str(&urlencode(v));
96 }
97 s
98}
99
/// Percent-encodes every byte of `s` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`).
///
/// Encoding works on UTF-8 bytes and uses uppercase hex (`é` → `%C3%A9`).
/// Spaces become `%20`, never `+`.
fn urlencode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
112
113pub async fn exchange_code(
115 api_base: &str,
116 client_id: &str,
117 redirect_uri: &str,
118 code: &str,
119 verifier: &str,
120) -> Result<TokenSet, String> {
121 let body = form_body(&[
122 ("grant_type", "authorization_code"),
123 ("client_id", client_id),
124 ("redirect_uri", redirect_uri),
125 ("code", code),
126 ("code_verifier", verifier),
127 ]);
128 let token_url = format!("{}/connect/token", api_base.trim_end_matches('/'));
129 let response = reqwest::Client::new()
130 .post(token_url)
131 .header("content-type", "application/x-www-form-urlencoded")
132 .body(body)
133 .send()
134 .await
135 .map_err(|e| format!("exchange Parslee authorization code: {e}"))?;
136 let status = response.status();
137 let text = response
138 .text()
139 .await
140 .map_err(|e| format!("read token response: {e}"))?;
141 if !status.is_success() {
142 return Err(format!(
143 "Parslee token exchange failed: HTTP {status}: {text}"
144 ));
145 }
146 let token: TokenSet =
147 serde_json::from_str(&text).map_err(|e| format!("parse token response: {e}"))?;
148 if !token.token_type.eq_ignore_ascii_case("bearer") {
149 return Err(format!(
150 "unexpected Parslee token_type `{}`",
151 token.token_type
152 ));
153 }
154 Ok(token)
155}
156
157fn put(key: &str, value: &str) -> Result<(), String> {
158 SecretStore::new()
159 .put(&SecretRef::with_default_service(key), value)
160 .map_err(|e| format!("store {key}: {e}"))
161}
162
163pub fn store_tokens(api_base: &str, token: &TokenSet) -> Result<(), String> {
166 put(PARSLEE_ACCESS_TOKEN_KEY, &token.access_token)?;
167 put(PARSLEE_REFRESH_TOKEN_KEY, &token.refresh_token)?;
168 put(PARSLEE_API_BASE_KEY, api_base.trim_end_matches('/'))?;
169 put(
170 PARSLEE_EXPIRES_AT_KEY,
171 &(epoch_seconds() + token.expires_in).to_string(),
172 )?;
173 Ok(())
174}
175
176pub fn clear_tokens() -> Result<(), String> {
178 let store = SecretStore::new();
179 for key in [
180 PARSLEE_ACCESS_TOKEN_KEY,
181 PARSLEE_REFRESH_TOKEN_KEY,
182 PARSLEE_EXPIRES_AT_KEY,
183 PARSLEE_API_BASE_KEY,
184 ] {
185 let _ = store.delete(&SecretRef::with_default_service(key));
186 }
187 Ok(())
188}
189
190pub fn access_token() -> Option<String> {
192 car_secrets::resolve_env_or_keychain(PARSLEE_ACCESS_TOKEN_KEY)
193}
194
/// How many seconds before the stored expiry `access_token_refreshing`
/// starts treating the access token as expiring and attempts a refresh.
pub const REFRESH_SKEW_SECS: u64 = 120;
199
/// Tokens returned by a refresh-token grant ([`refresh_grant`]).
///
/// `refresh_token` and `expires_in` are `None` when the server omits them.
///
/// `Debug` is implemented by hand so that the tokens never reach logs through
/// an accidental `{:?}`; only the presence of a rotated refresh token and the
/// lifetime are shown.
#[derive(Clone)]
pub struct RefreshedTokens {
    /// New bearer access token.
    pub access_token: String,
    /// Rotated refresh token, if the server issued one.
    pub refresh_token: Option<String>,
    /// Lifetime of `access_token` in seconds, if reported.
    pub expires_in: Option<u64>,
}

impl std::fmt::Debug for RefreshedTokens {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshedTokens")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
208
209pub async fn refresh_grant(api_base: &str, refresh_token: &str) -> Result<RefreshedTokens, String> {
217 #[derive(Deserialize)]
218 struct Resp {
219 access_token: String,
220 #[serde(default)]
221 refresh_token: Option<String>,
222 #[serde(default)]
223 expires_in: Option<u64>,
224 }
225 let body = form_body(&[
226 ("grant_type", "refresh_token"),
227 ("refresh_token", refresh_token),
228 ]);
229 let token_url = format!("{}/connect/token", api_base.trim_end_matches('/'));
230 let response = reqwest::Client::new()
231 .post(token_url)
232 .header("content-type", "application/x-www-form-urlencoded")
233 .body(body)
234 .send()
235 .await
236 .map_err(|e| format!("refresh Parslee token: {e}"))?;
237 let status = response.status();
238 let text = response
239 .text()
240 .await
241 .map_err(|e| format!("read Parslee token response: {e}"))?;
242 if !status.is_success() {
243 return Err(format!("refresh Parslee token: HTTP {status}: {text}"));
244 }
245 let r: Resp =
246 serde_json::from_str(&text).map_err(|e| format!("parse Parslee token response: {e}"))?;
247 Ok(RefreshedTokens {
248 access_token: r.access_token,
249 refresh_token: r.refresh_token,
250 expires_in: r.expires_in,
251 })
252}
253
254fn persist_refreshed(api_base: &str, t: &RefreshedTokens) {
258 let _ = put(PARSLEE_ACCESS_TOKEN_KEY, &t.access_token);
259 if let Some(refresh) = &t.refresh_token {
260 let _ = put(PARSLEE_REFRESH_TOKEN_KEY, refresh);
261 }
262 if let Some(expires_in) = t.expires_in {
263 let _ = put(
264 PARSLEE_EXPIRES_AT_KEY,
265 &(epoch_seconds() + expires_in).to_string(),
266 );
267 }
268 let _ = put(PARSLEE_API_BASE_KEY, api_base.trim_end_matches('/'));
269}
270
271pub async fn access_token_refreshing() -> Option<String> {
282 if let Ok(tok) = std::env::var(PARSLEE_ACCESS_TOKEN_KEY) {
284 if !tok.is_empty() {
285 return Some(tok);
286 }
287 }
288 let current = car_secrets::resolve_env_or_keychain(PARSLEE_ACCESS_TOKEN_KEY)?;
289 let expiring = car_secrets::resolve_env_or_keychain(PARSLEE_EXPIRES_AT_KEY)
292 .and_then(|s| s.trim().parse::<u64>().ok())
293 .map(|exp| epoch_seconds() + REFRESH_SKEW_SECS >= exp)
294 .unwrap_or(false);
295 if !expiring {
296 return Some(current);
297 }
298 let Some(refresh) = car_secrets::resolve_env_or_keychain(PARSLEE_REFRESH_TOKEN_KEY) else {
299 return Some(current);
300 };
301 let base = api_base(None);
302 match refresh_grant(&base, &refresh).await {
303 Ok(tokens) => {
304 let access = tokens.access_token.clone();
305 persist_refreshed(&base, &tokens);
306 Some(access)
307 }
308 Err(_) => Some(current),
312 }
313}
314
315pub async fn force_refresh() -> Option<String> {
329 if let Ok(tok) = std::env::var(PARSLEE_ACCESS_TOKEN_KEY) {
330 if !tok.is_empty() {
331 return None;
332 }
333 }
334 let refresh = car_secrets::resolve_env_or_keychain(PARSLEE_REFRESH_TOKEN_KEY)?;
335 let base = api_base(None);
336 match refresh_grant(&base, &refresh).await {
337 Ok(tokens) => {
338 let access = tokens.access_token.clone();
339 persist_refreshed(&base, &tokens);
340 Some(access)
341 }
342 Err(_) => None,
343 }
344}
345
346pub fn api_base(override_: Option<&str>) -> String {
348 override_
349 .map(|s| s.trim_end_matches('/').to_string())
350 .or_else(|| car_secrets::resolve_env_or_keychain(PARSLEE_API_BASE_KEY))
351 .unwrap_or_else(|| DEFAULT_API_BASE.to_string())
352}
353
354pub async fn fetch_status(api_base_override: Option<&str>) -> Result<Option<String>, String> {
357 let Some(access) = access_token() else {
358 return Ok(None);
359 };
360 let base = api_base(api_base_override);
361 let response = reqwest::Client::new()
362 .get(format!("{}/connect/session", base.trim_end_matches('/')))
363 .bearer_auth(access)
364 .send()
365 .await
366 .map_err(|e| format!("fetch Parslee session: {e}"))?;
367 let status = response.status();
368 let text = response
369 .text()
370 .await
371 .map_err(|e| format!("read Parslee session response: {e}"))?;
372 if !status.is_success() {
373 return Err(format!(
374 "Parslee session check failed: HTTP {status}: {text}"
375 ));
376 }
377 Ok(Some(text))
378}
379
// Tests for the PKCE helpers, authorize-URL building, API-base override
// handling, and the token/session HTTP calls. The HTTP calls run against an
// in-process loopback stub server (`mock`).
#[cfg(test)]
mod tests {
    use super::*;

    // The challenge must be base64url without padding ('=', '+', '/' absent)
    // and deterministic for a given verifier.
    #[test]
    fn pkce_challenge_is_s256_urlsafe_nopad() {
        let v = pkce_verifier();
        let c = pkce_challenge(&v);
        assert!(!c.contains('=') && !c.contains('+') && !c.contains('/'));
        assert_eq!(c, pkce_challenge(&v));
    }

    // The trailing slash on the base is dropped, and the PKCE parameters,
    // client id, and optional provider all appear in the query string.
    #[test]
    fn authorize_url_has_pkce_and_provider() {
        let u = authorize_url(
            "https://api.parslee.ai/",
            "parslee-car",
            "http://localhost:8765/auth/callback",
            "st8",
            "chal",
            Some("microsoft"),
        )
        .unwrap();
        assert!(u.starts_with("https://api.parslee.ai/connect/authorize?"));
        assert!(u.contains("code_challenge=chal"));
        assert!(u.contains("code_challenge_method=S256"));
        assert!(u.contains("client_id=parslee-car"));
        assert!(u.contains("provider=microsoft"));
    }

    // Only the explicit-override branch is exercised; the env/keychain and
    // default branches depend on process and keychain state.
    #[test]
    fn api_base_precedence() {
        assert_eq!(api_base(Some("https://x.test/")), "https://x.test");
    }

    // Minimal HTTP/1.1 stub server for exercising the reqwest calls without
    // network access.
    //
    // `start(expected, respond)` binds 127.0.0.1 on an OS-chosen port and
    // spawns a thread that accepts exactly `expected` connections. On each
    // connection it:
    // - reads one request: headers up to the blank line, then
    //   `content-length` bytes of body;
    // - records the method, path, Authorization, Content-Type and body;
    // - asks `respond` for a status code and body;
    // - writes a `connection: close` JSON response.
    //
    // NOTE(review): `Drop` joins the server thread, and `accept()` blocks. A
    // test that makes fewer than `expected` requests (for example, one that
    // panics before sending) will therefore hang on drop instead of failing.
    mod mock {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::{Arc, Mutex};
        use std::thread;

        // One request as seen by the stub server.
        pub struct Recorded {
            pub method: String,
            pub path: String,
            pub authorization: Option<String>,
            pub content_type: Option<String>,
            pub body: String,
        }

        // Handle to a running stub: `base` is `http://127.0.0.1:{port}`, and
        // `recorded` collects requests in arrival order.
        pub struct Mock {
            pub base: String,
            pub recorded: Arc<Mutex<Vec<Recorded>>>,
            handle: Option<thread::JoinHandle<()>>,
        }

        impl Drop for Mock {
            // Waits for the server thread. A panic inside it is swallowed here.
            fn drop(&mut self) {
                if let Some(h) = self.handle.take() {
                    let _ = h.join();
                }
            }
        }

        // Byte-slice substring search; returns the index of the first match.
        fn find(hay: &[u8], needle: &[u8]) -> Option<usize> {
            hay.windows(needle.len()).position(|w| w == needle)
        }

        pub fn start(
            expected: usize,
            respond: impl Fn(&Recorded) -> (u16, String) + Send + 'static,
        ) -> Mock {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            let recorded = Arc::new(Mutex::new(Vec::new()));
            let rec = recorded.clone();
            let handle = thread::spawn(move || {
                for _ in 0..expected {
                    let (mut stream, _) = listener.accept().unwrap();
                    let mut buf = Vec::new();
                    let mut tmp = [0u8; 1024];
                    loop {
                        let n = stream.read(&mut tmp).unwrap();
                        if n == 0 {
                            break;
                        }
                        buf.extend_from_slice(&tmp[..n]);
                        // Keep reading until the whole header block has arrived.
                        let Some(hdr_end) = find(&buf, b"\r\n\r\n") else {
                            continue;
                        };
                        let headers = String::from_utf8_lossy(&buf[..hdr_end]).into_owned();
                        // A missing or unparsable content-length means no body.
                        let content_length = headers
                            .lines()
                            .find_map(|l| {
                                let (k, v) = l.split_once(':')?;
                                if k.eq_ignore_ascii_case("content-length") {
                                    v.trim().parse::<usize>().ok()
                                } else {
                                    None
                                }
                            })
                            .unwrap_or(0);
                        // Skip the "\r\n\r\n" separator.
                        let body_start = hdr_end + 4;
                        while buf.len() < body_start + content_length {
                            let n = stream.read(&mut tmp).unwrap();
                            if n == 0 {
                                break;
                            }
                            buf.extend_from_slice(&tmp[..n]);
                        }
                        // Request line: "METHOD PATH VERSION".
                        let mut header_lines = headers.lines();
                        let req_line = header_lines.next().unwrap_or("");
                        let mut rl = req_line.split_whitespace();
                        let method = rl.next().unwrap_or("").to_string();
                        let path = rl.next().unwrap_or("").to_string();
                        let mut authorization = None;
                        let mut content_type = None;
                        for l in header_lines {
                            if let Some((k, v)) = l.split_once(':') {
                                if k.eq_ignore_ascii_case("authorization") {
                                    authorization = Some(v.trim().to_string());
                                } else if k.eq_ignore_ascii_case("content-type") {
                                    content_type = Some(v.trim().to_string());
                                }
                            }
                        }
                        // `.min` guards against a peer that closed before
                        // sending the full announced body.
                        let body = String::from_utf8_lossy(
                            &buf[body_start..(body_start + content_length).min(buf.len())],
                        )
                        .into_owned();
                        let r = Recorded {
                            method,
                            path,
                            authorization,
                            content_type,
                            body,
                        };
                        let (code, resp_body) = respond(&r);
                        rec.lock().unwrap().push(r);
                        // The reason phrase is always "OK", whatever `code` is.
                        let resp = format!(
                            "HTTP/1.1 {code} OK\r\ncontent-type: application/json\r\n\
                             content-length: {}\r\nconnection: close\r\n\r\n{}",
                            resp_body.len(),
                            resp_body
                        );
                        stream.write_all(resp.as_bytes()).unwrap();
                        let _ = stream.flush();
                        // One request per connection.
                        break;
                    }
                }
            });
            Mock {
                base: format!("http://127.0.0.1:{port}"),
                recorded,
                handle: Some(handle),
            }
        }
    }

    // The authorization-code exchange POSTs the form to /connect/token and
    // parses every TokenSet field.
    #[tokio::test]
    async fn exchange_code_round_trips_token() {
        let mock = mock::start(1, |_r| {
            (
                200,
                r#"{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"Bearer"}"#
                    .to_string(),
            )
        });
        let token = exchange_code(
            &mock.base,
            "parslee-car",
            "http://localhost:1/cb",
            "thecode",
            "theverifier",
        )
        .await
        .unwrap();
        assert_eq!(token.access_token, "a");
        assert_eq!(token.refresh_token, "r");
        assert_eq!(token.expires_in, 3600);

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        assert_eq!(reqs[0].method, "POST");
        assert_eq!(reqs[0].path, "/connect/token");
        assert!(reqs[0].body.contains("grant_type=authorization_code"));
        assert!(reqs[0].body.contains("code=thecode"));
        assert!(reqs[0].body.contains("code_verifier=theverifier"));
    }

    // A refresh response without refresh_token yields None for it, and the
    // request body carries no client_id.
    #[tokio::test]
    async fn refresh_grant_round_trips_token() {
        let mock = mock::start(1, |_r| {
            (
                200,
                r#"{"access_token":"a2","expires_in":3600,"token_type":"Bearer"}"#.to_string(),
            )
        });
        let tokens = refresh_grant(&mock.base, "the-refresh-token")
            .await
            .unwrap();
        assert_eq!(tokens.access_token, "a2");
        assert_eq!(tokens.refresh_token, None);
        assert_eq!(tokens.expires_in, Some(3600));

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        assert_eq!(reqs[0].method, "POST");
        assert_eq!(reqs[0].path, "/connect/token");
        assert!(reqs[0].body.contains("grant_type=refresh_token"));
        assert!(reqs[0].body.contains("refresh_token=the-refresh-token"));
        assert!(!reqs[0].body.contains("client_id"));
    }

    // The session check sends the resolved access token as a Bearer header.
    //
    // NOTE(review): this mutates a process-global environment variable. The
    // test harness runs tests on parallel threads by default, so other tests
    // reading PARSLEE_ACCESS_TOKEN could observe it. The variable is also not
    // removed if an assertion below panics.
    #[tokio::test]
    async fn fetch_status_sends_bearer() {
        std::env::set_var(PARSLEE_ACCESS_TOKEN_KEY, "test-token");

        let mock = mock::start(1, |_r| (200, r#"{"authenticated":true}"#.to_string()));

        let session = fetch_status(Some(&mock.base)).await.unwrap();
        assert_eq!(session.as_deref(), Some(r#"{"authenticated":true}"#));

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        let sess = &reqs[0];
        assert_eq!(sess.method, "GET");
        assert_eq!(sess.path, "/connect/session");
        assert_eq!(sess.authorization.as_deref(), Some("Bearer test-token"));

        std::env::remove_var(PARSLEE_ACCESS_TOKEN_KEY);
    }
}
629}