1use base64::Engine;
11use serde::Deserialize;
12use sha2::{Digest, Sha256};
13
14use car_secrets::{SecretRef, SecretStore};
15
/// Secret-store key for the Parslee access token. `access_token_refreshing`
/// and `force_refresh` also check an environment variable with this name.
pub const PARSLEE_ACCESS_TOKEN_KEY: &str = "PARSLEE_ACCESS_TOKEN";
/// Secret-store key for the Parslee refresh token.
pub const PARSLEE_REFRESH_TOKEN_KEY: &str = "PARSLEE_REFRESH_TOKEN";
/// Secret-store key for the access token's absolute expiry, stored as a
/// decimal number of seconds since the Unix epoch.
pub const PARSLEE_EXPIRES_AT_KEY: &str = "PARSLEE_ACCESS_TOKEN_EXPIRES_AT";
/// Secret-store key for the API base URL stored alongside the tokens
/// (written without a trailing `/`).
pub const PARSLEE_API_BASE_KEY: &str = "PARSLEE_API_BASE";
/// API base used by `api_base` when neither an override nor a stored value
/// is available.
pub const DEFAULT_API_BASE: &str = "https://api.parslee.ai";
21
22#[derive(Debug, Clone, Deserialize)]
24pub struct TokenSet {
25 pub access_token: String,
26 pub refresh_token: String,
27 pub expires_in: u64,
28 pub token_type: String,
29}
30
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than 1970, this yields `0` rather than
/// failing.
fn epoch_seconds() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
37
38pub fn pkce_verifier() -> String {
40 let raw = format!(
41 "{}{}",
42 uuid::Uuid::new_v4().simple(),
43 uuid::Uuid::new_v4().simple()
44 );
45 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw.as_bytes())
46}
47
48pub fn new_state() -> String {
50 uuid::Uuid::new_v4().simple().to_string()
51}
52
53pub fn pkce_challenge(verifier: &str) -> String {
55 let digest = Sha256::digest(verifier.as_bytes());
56 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
57}
58
59pub fn authorize_url(
61 api_base: &str,
62 client_id: &str,
63 redirect_uri: &str,
64 state: &str,
65 challenge: &str,
66 provider: Option<&str>,
67) -> Result<String, String> {
68 let mut url = reqwest::Url::parse(&format!(
69 "{}/connect/authorize",
70 api_base.trim_end_matches('/')
71 ))
72 .map_err(|e| format!("build authorize URL: {e}"))?;
73 url.query_pairs_mut()
74 .append_pair("client_id", client_id)
75 .append_pair("redirect_uri", redirect_uri)
76 .append_pair("response_type", "code")
77 .append_pair("scope", "openid profile email")
78 .append_pair("state", state)
79 .append_pair("code_challenge", challenge)
80 .append_pair("code_challenge_method", "S256");
81 if let Some(provider) = provider {
82 url.query_pairs_mut().append_pair("provider", provider);
83 }
84 Ok(url.to_string())
85}
86
/// Serialises `pairs` as an `application/x-www-form-urlencoded` body of the
/// form `k1=v1&k2=v2…`, passing every key and value through `urlencode`.
/// An empty slice produces an empty string.
fn form_body(pairs: &[(&str, &str)]) -> String {
    pairs
        .iter()
        .map(|&(key, value)| format!("{}={}", urlencode(key), urlencode(value)))
        .collect::<Vec<_>>()
        .join("&")
}

/// Percent-encodes each UTF-8 byte of `s` as uppercase `%XX`, except bytes in
/// the unreserved set `A-Z a-z 0-9 - _ . ~`, which are copied unchanged.
/// A space therefore becomes `%20`, not `+`.
fn urlencode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
112
113pub async fn exchange_code(
115 api_base: &str,
116 client_id: &str,
117 redirect_uri: &str,
118 code: &str,
119 verifier: &str,
120) -> Result<TokenSet, String> {
121 let body = form_body(&[
122 ("grant_type", "authorization_code"),
123 ("client_id", client_id),
124 ("redirect_uri", redirect_uri),
125 ("code", code),
126 ("code_verifier", verifier),
127 ]);
128 let token_url = format!("{}/connect/token", api_base.trim_end_matches('/'));
129 let response = reqwest::Client::new()
130 .post(token_url)
131 .header("content-type", "application/x-www-form-urlencoded")
132 .body(body)
133 .send()
134 .await
135 .map_err(|e| format!("exchange Parslee authorization code: {e}"))?;
136 let status = response.status();
137 let text = response
138 .text()
139 .await
140 .map_err(|e| format!("read token response: {e}"))?;
141 if !status.is_success() {
142 return Err(format!("Parslee token exchange failed: HTTP {status}: {text}"));
143 }
144 let token: TokenSet =
145 serde_json::from_str(&text).map_err(|e| format!("parse token response: {e}"))?;
146 if !token.token_type.eq_ignore_ascii_case("bearer") {
147 return Err(format!("unexpected Parslee token_type `{}`", token.token_type));
148 }
149 Ok(token)
150}
151
152fn put(key: &str, value: &str) -> Result<(), String> {
153 SecretStore::new()
154 .put(&SecretRef::with_default_service(key), value)
155 .map_err(|e| format!("store {key}: {e}"))
156}
157
158pub fn store_tokens(api_base: &str, token: &TokenSet) -> Result<(), String> {
161 put(PARSLEE_ACCESS_TOKEN_KEY, &token.access_token)?;
162 put(PARSLEE_REFRESH_TOKEN_KEY, &token.refresh_token)?;
163 put(PARSLEE_API_BASE_KEY, api_base.trim_end_matches('/'))?;
164 put(
165 PARSLEE_EXPIRES_AT_KEY,
166 &(epoch_seconds() + token.expires_in).to_string(),
167 )?;
168 Ok(())
169}
170
171pub fn clear_tokens() -> Result<(), String> {
173 let store = SecretStore::new();
174 for key in [
175 PARSLEE_ACCESS_TOKEN_KEY,
176 PARSLEE_REFRESH_TOKEN_KEY,
177 PARSLEE_EXPIRES_AT_KEY,
178 PARSLEE_API_BASE_KEY,
179 ] {
180 let _ = store.delete(&SecretRef::with_default_service(key));
181 }
182 Ok(())
183}
184
185pub fn access_token() -> Option<String> {
187 car_secrets::resolve_env_or_keychain(PARSLEE_ACCESS_TOKEN_KEY)
188}
189
/// Safety margin, in seconds: `access_token_refreshing` attempts a refresh
/// once the current time plus this margin reaches the stored expiry.
pub const REFRESH_SKEW_SECS: u64 = 120;
194
/// Result of a refresh-token grant (returned by `refresh_grant`).
///
/// `Debug` is implemented by hand so that token values are redacted; a
/// derived `Debug` would leak live credentials into logs or panic messages.
#[derive(Clone)]
pub struct RefreshedTokens {
    /// Newly issued access token.
    pub access_token: String,
    /// Replacement refresh token, if the server returned one.
    pub refresh_token: Option<String>,
    /// Access-token lifetime in seconds, if the server returned one.
    pub expires_in: Option<u64>,
}

impl std::fmt::Debug for RefreshedTokens {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshedTokens")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_in", &self.expires_in)
            .finish()
    }
}
203
204pub async fn refresh_grant(api_base: &str, refresh_token: &str) -> Result<RefreshedTokens, String> {
212 #[derive(Deserialize)]
213 struct Resp {
214 access_token: String,
215 #[serde(default)]
216 refresh_token: Option<String>,
217 #[serde(default)]
218 expires_in: Option<u64>,
219 }
220 let body = form_body(&[
221 ("grant_type", "refresh_token"),
222 ("refresh_token", refresh_token),
223 ]);
224 let token_url = format!("{}/connect/token", api_base.trim_end_matches('/'));
225 let response = reqwest::Client::new()
226 .post(token_url)
227 .header("content-type", "application/x-www-form-urlencoded")
228 .body(body)
229 .send()
230 .await
231 .map_err(|e| format!("refresh Parslee token: {e}"))?;
232 let status = response.status();
233 let text = response
234 .text()
235 .await
236 .map_err(|e| format!("read Parslee token response: {e}"))?;
237 if !status.is_success() {
238 return Err(format!("refresh Parslee token: HTTP {status}: {text}"));
239 }
240 let r: Resp =
241 serde_json::from_str(&text).map_err(|e| format!("parse Parslee token response: {e}"))?;
242 Ok(RefreshedTokens {
243 access_token: r.access_token,
244 refresh_token: r.refresh_token,
245 expires_in: r.expires_in,
246 })
247}
248
249fn persist_refreshed(api_base: &str, t: &RefreshedTokens) {
253 let _ = put(PARSLEE_ACCESS_TOKEN_KEY, &t.access_token);
254 if let Some(refresh) = &t.refresh_token {
255 let _ = put(PARSLEE_REFRESH_TOKEN_KEY, refresh);
256 }
257 if let Some(expires_in) = t.expires_in {
258 let _ = put(
259 PARSLEE_EXPIRES_AT_KEY,
260 &(epoch_seconds() + expires_in).to_string(),
261 );
262 }
263 let _ = put(PARSLEE_API_BASE_KEY, api_base.trim_end_matches('/'));
264}
265
266pub async fn access_token_refreshing() -> Option<String> {
277 if let Ok(tok) = std::env::var(PARSLEE_ACCESS_TOKEN_KEY) {
279 if !tok.is_empty() {
280 return Some(tok);
281 }
282 }
283 let current = car_secrets::resolve_env_or_keychain(PARSLEE_ACCESS_TOKEN_KEY)?;
284 let expiring = car_secrets::resolve_env_or_keychain(PARSLEE_EXPIRES_AT_KEY)
287 .and_then(|s| s.trim().parse::<u64>().ok())
288 .map(|exp| epoch_seconds() + REFRESH_SKEW_SECS >= exp)
289 .unwrap_or(false);
290 if !expiring {
291 return Some(current);
292 }
293 let Some(refresh) = car_secrets::resolve_env_or_keychain(PARSLEE_REFRESH_TOKEN_KEY) else {
294 return Some(current);
295 };
296 let base = api_base(None);
297 match refresh_grant(&base, &refresh).await {
298 Ok(tokens) => {
299 let access = tokens.access_token.clone();
300 persist_refreshed(&base, &tokens);
301 Some(access)
302 }
303 Err(_) => Some(current),
307 }
308}
309
310pub async fn force_refresh() -> Option<String> {
324 if let Ok(tok) = std::env::var(PARSLEE_ACCESS_TOKEN_KEY) {
325 if !tok.is_empty() {
326 return None;
327 }
328 }
329 let refresh = car_secrets::resolve_env_or_keychain(PARSLEE_REFRESH_TOKEN_KEY)?;
330 let base = api_base(None);
331 match refresh_grant(&base, &refresh).await {
332 Ok(tokens) => {
333 let access = tokens.access_token.clone();
334 persist_refreshed(&base, &tokens);
335 Some(access)
336 }
337 Err(_) => None,
338 }
339}
340
341pub fn api_base(override_: Option<&str>) -> String {
343 override_
344 .map(|s| s.trim_end_matches('/').to_string())
345 .or_else(|| car_secrets::resolve_env_or_keychain(PARSLEE_API_BASE_KEY))
346 .unwrap_or_else(|| DEFAULT_API_BASE.to_string())
347}
348
349pub async fn fetch_status(api_base_override: Option<&str>) -> Result<Option<String>, String> {
352 let Some(access) = access_token() else {
353 return Ok(None);
354 };
355 let base = api_base(api_base_override);
356 let response = reqwest::Client::new()
357 .get(format!("{}/connect/session", base.trim_end_matches('/')))
358 .bearer_auth(access)
359 .send()
360 .await
361 .map_err(|e| format!("fetch Parslee session: {e}"))?;
362 let status = response.status();
363 let text = response
364 .text()
365 .await
366 .map_err(|e| format!("read Parslee session response: {e}"))?;
367 if !status.is_success() {
368 return Err(format!("Parslee session check failed: HTTP {status}: {text}"));
369 }
370 Ok(Some(text))
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime.
    ///
    /// On drop it restores the previous value, or removes the variable, even
    /// when the test panics. This stops one test's configuration leaking into
    /// another or into the developer's environment.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn pkce_verifier_is_rfc7636_shaped() {
        let v = pkce_verifier();
        // RFC 7636 §4.1: 43..=128 characters from the unreserved set.
        assert!((43..=128).contains(&v.len()), "length {}", v.len());
        assert!(v
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~')));
    }

    #[test]
    fn pkce_challenge_is_s256_urlsafe_nopad() {
        let v = pkce_verifier();
        let c = pkce_challenge(&v);
        assert!(!c.contains('=') && !c.contains('+') && !c.contains('/'));
        assert_eq!(c, pkce_challenge(&v));
    }

    #[test]
    fn pkce_challenge_matches_rfc7636_appendix_b() {
        assert_eq!(
            pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGEb-MqaDM"
        );
    }

    #[test]
    fn form_body_percent_encodes_reserved_bytes() {
        assert_eq!(
            form_body(&[
                ("grant_type", "refresh_token"),
                ("redirect_uri", "http://x/cb?a=b"),
            ]),
            "grant_type=refresh_token&redirect_uri=http%3A%2F%2Fx%2Fcb%3Fa%3Db"
        );
    }

    #[test]
    fn authorize_url_has_pkce_and_provider() {
        let u = authorize_url(
            "https://api.parslee.ai/",
            "parslee-car",
            "http://localhost:8765/auth/callback",
            "st8",
            "chal",
            Some("microsoft"),
        )
        .unwrap();
        assert!(u.starts_with("https://api.parslee.ai/connect/authorize?"));
        assert!(u.contains("code_challenge=chal"));
        assert!(u.contains("code_challenge_method=S256"));
        assert!(u.contains("client_id=parslee-car"));
        assert!(u.contains("provider=microsoft"));
    }

    #[test]
    fn api_base_precedence() {
        assert_eq!(api_base(Some("https://x.test/")), "https://x.test");
    }

    /// Minimal HTTP/1.1 server on an ephemeral localhost port that records
    /// each request and answers with a canned JSON response.
    mod mock {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::{Arc, Mutex};
        use std::thread;

        /// One request as observed by the mock server.
        pub struct Recorded {
            pub method: String,
            pub path: String,
            pub authorization: Option<String>,
            pub content_type: Option<String>,
            pub body: String,
        }

        /// Handle to a running mock server.
        pub struct Mock {
            pub base: String,
            pub recorded: Arc<Mutex<Vec<Recorded>>>,
            handle: Option<thread::JoinHandle<()>>,
        }

        impl Drop for Mock {
            fn drop(&mut self) {
                // If the test is already failing, the client may never have
                // connected. Joining would then block forever in `accept` and
                // turn a test failure into a hang, so detach instead.
                if thread::panicking() {
                    return;
                }
                if let Some(h) = self.handle.take() {
                    let _ = h.join();
                }
            }
        }

        fn find(hay: &[u8], needle: &[u8]) -> Option<usize> {
            hay.windows(needle.len()).position(|w| w == needle)
        }

        /// Serves exactly `expected` connections, one request each, replying
        /// with the `(status, body)` produced by `respond`.
        pub fn start(
            expected: usize,
            respond: impl Fn(&Recorded) -> (u16, String) + Send + 'static,
        ) -> Mock {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            let recorded = Arc::new(Mutex::new(Vec::new()));
            let rec = recorded.clone();
            let handle = thread::spawn(move || {
                for _ in 0..expected {
                    let (mut stream, _) = listener.accept().unwrap();
                    let mut buf = Vec::new();
                    let mut tmp = [0u8; 1024];
                    loop {
                        let n = stream.read(&mut tmp).unwrap();
                        if n == 0 {
                            break;
                        }
                        buf.extend_from_slice(&tmp[..n]);
                        let Some(hdr_end) = find(&buf, b"\r\n\r\n") else {
                            continue;
                        };
                        let headers = String::from_utf8_lossy(&buf[..hdr_end]).into_owned();
                        let content_length = headers
                            .lines()
                            .find_map(|l| {
                                let (k, v) = l.split_once(':')?;
                                if k.eq_ignore_ascii_case("content-length") {
                                    v.trim().parse::<usize>().ok()
                                } else {
                                    None
                                }
                            })
                            .unwrap_or(0);
                        let body_start = hdr_end + 4;
                        // Keep reading until the declared body has arrived.
                        while buf.len() < body_start + content_length {
                            let n = stream.read(&mut tmp).unwrap();
                            if n == 0 {
                                break;
                            }
                            buf.extend_from_slice(&tmp[..n]);
                        }
                        let mut header_lines = headers.lines();
                        let req_line = header_lines.next().unwrap_or("");
                        let mut rl = req_line.split_whitespace();
                        let method = rl.next().unwrap_or("").to_string();
                        let path = rl.next().unwrap_or("").to_string();
                        let mut authorization = None;
                        let mut content_type = None;
                        for l in header_lines {
                            if let Some((k, v)) = l.split_once(':') {
                                if k.eq_ignore_ascii_case("authorization") {
                                    authorization = Some(v.trim().to_string());
                                } else if k.eq_ignore_ascii_case("content-type") {
                                    content_type = Some(v.trim().to_string());
                                }
                            }
                        }
                        let body = String::from_utf8_lossy(
                            &buf[body_start..(body_start + content_length).min(buf.len())],
                        )
                        .into_owned();
                        let r = Recorded {
                            method,
                            path,
                            authorization,
                            content_type,
                            body,
                        };
                        let (code, resp_body) = respond(&r);
                        rec.lock().unwrap().push(r);
                        let resp = format!(
                            "HTTP/1.1 {code} OK\r\ncontent-type: application/json\r\n\
                             content-length: {}\r\nconnection: close\r\n\r\n{}",
                            resp_body.len(),
                            resp_body
                        );
                        stream.write_all(resp.as_bytes()).unwrap();
                        let _ = stream.flush();
                        break;
                    }
                }
            });
            Mock {
                base: format!("http://127.0.0.1:{port}"),
                recorded,
                handle: Some(handle),
            }
        }
    }

    #[tokio::test]
    async fn exchange_code_round_trips_token() {
        let mock = mock::start(1, |_r| {
            (
                200,
                r#"{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"Bearer"}"#
                    .to_string(),
            )
        });
        let token = exchange_code(
            &mock.base,
            "parslee-car",
            "http://localhost:1/cb",
            "thecode",
            "theverifier",
        )
        .await
        .unwrap();
        assert_eq!(token.access_token, "a");
        assert_eq!(token.refresh_token, "r");
        assert_eq!(token.expires_in, 3600);

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        assert_eq!(reqs[0].method, "POST");
        assert_eq!(reqs[0].path, "/connect/token");
        assert!(reqs[0].body.contains("grant_type=authorization_code"));
        assert!(reqs[0].body.contains("code=thecode"));
        assert!(reqs[0].body.contains("code_verifier=theverifier"));
    }

    #[tokio::test]
    async fn refresh_grant_round_trips_token() {
        let mock = mock::start(1, |_r| {
            (
                200,
                r#"{"access_token":"a2","expires_in":3600,"token_type":"Bearer"}"#.to_string(),
            )
        });
        let tokens = refresh_grant(&mock.base, "the-refresh-token").await.unwrap();
        assert_eq!(tokens.access_token, "a2");
        assert_eq!(tokens.refresh_token, None);
        assert_eq!(tokens.expires_in, Some(3600));

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        assert_eq!(reqs[0].method, "POST");
        assert_eq!(reqs[0].path, "/connect/token");
        assert!(reqs[0].body.contains("grant_type=refresh_token"));
        assert!(reqs[0].body.contains("refresh_token=the-refresh-token"));
        assert!(!reqs[0].body.contains("client_id"));
    }

    #[tokio::test]
    async fn fetch_status_sends_bearer() {
        // Restored on drop, even if an assertion below fails.
        let _env = EnvVarGuard::set(PARSLEE_ACCESS_TOKEN_KEY, "test-token");

        let mock = mock::start(1, |_r| (200, r#"{"authenticated":true}"#.to_string()));

        let session = fetch_status(Some(&mock.base)).await.unwrap();
        assert_eq!(session.as_deref(), Some(r#"{"authenticated":true}"#));

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        let sess = &reqs[0];
        assert_eq!(sess.method, "GET");
        assert_eq!(sess.path, "/connect/session");
        assert_eq!(sess.authorization.as_deref(), Some("Bearer test-token"));
    }
}
620}