1use base64::Engine;
11use serde::Deserialize;
12use sha2::{Digest, Sha256};
13
14use car_secrets::{SecretRef, SecretStore};
15
/// Key under which the Parslee access token is stored and looked up.
pub const PARSLEE_ACCESS_TOKEN_KEY: &str = "PARSLEE_ACCESS_TOKEN";
/// Key under which the Parslee refresh token is stored.
pub const PARSLEE_REFRESH_TOKEN_KEY: &str = "PARSLEE_REFRESH_TOKEN";
/// Key holding the access token's expiry, written as decimal Unix seconds.
pub const PARSLEE_EXPIRES_AT_KEY: &str = "PARSLEE_ACCESS_TOKEN_EXPIRES_AT";
/// Key holding the API base URL that the stored tokens were issued by.
pub const PARSLEE_API_BASE_KEY: &str = "PARSLEE_API_BASE";
/// API base URL used when neither an override nor a stored value is available.
pub const DEFAULT_API_BASE: &str = "https://api.parslee.ai";
21
22#[derive(Debug, Clone, Deserialize)]
24pub struct TokenSet {
25 pub access_token: String,
26 pub refresh_token: String,
27 pub expires_in: u64,
28 pub token_type: String,
29}
30
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of an error.
fn epoch_seconds() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
37
38pub fn pkce_verifier() -> String {
40 let raw = format!(
41 "{}{}",
42 uuid::Uuid::new_v4().simple(),
43 uuid::Uuid::new_v4().simple()
44 );
45 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw.as_bytes())
46}
47
48pub fn new_state() -> String {
50 uuid::Uuid::new_v4().simple().to_string()
51}
52
53pub fn pkce_challenge(verifier: &str) -> String {
55 let digest = Sha256::digest(verifier.as_bytes());
56 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
57}
58
59pub fn authorize_url(
61 api_base: &str,
62 client_id: &str,
63 redirect_uri: &str,
64 state: &str,
65 challenge: &str,
66 provider: Option<&str>,
67) -> Result<String, String> {
68 let mut url = reqwest::Url::parse(&format!(
69 "{}/connect/authorize",
70 api_base.trim_end_matches('/')
71 ))
72 .map_err(|e| format!("build authorize URL: {e}"))?;
73 url.query_pairs_mut()
74 .append_pair("client_id", client_id)
75 .append_pair("redirect_uri", redirect_uri)
76 .append_pair("response_type", "code")
77 .append_pair("scope", "openid profile email")
78 .append_pair("state", state)
79 .append_pair("code_challenge", challenge)
80 .append_pair("code_challenge_method", "S256");
81 if let Some(provider) = provider {
82 url.query_pairs_mut().append_pair("provider", provider);
83 }
84 Ok(url.to_string())
85}
86
87fn form_body(pairs: &[(&str, &str)]) -> String {
88 let mut s = String::new();
89 for (i, (k, v)) in pairs.iter().enumerate() {
90 if i > 0 {
91 s.push('&');
92 }
93 s.push_str(&urlencode(k));
94 s.push('=');
95 s.push_str(&urlencode(v));
96 }
97 s
98}
99
/// Percent-encodes `s` byte by byte.
///
/// ASCII letters, digits, and `-`, `_`, `.`, `~` pass through unchanged;
/// every other byte (including each byte of a multi-byte UTF-8 character and
/// the space character) becomes `%XX` with upper-case hex digits.
fn urlencode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
112
113pub async fn exchange_code(
115 api_base: &str,
116 client_id: &str,
117 redirect_uri: &str,
118 code: &str,
119 verifier: &str,
120) -> Result<TokenSet, String> {
121 let body = form_body(&[
122 ("grant_type", "authorization_code"),
123 ("client_id", client_id),
124 ("redirect_uri", redirect_uri),
125 ("code", code),
126 ("code_verifier", verifier),
127 ]);
128 let token_url = format!("{}/connect/token", api_base.trim_end_matches('/'));
129 let response = reqwest::Client::new()
130 .post(token_url)
131 .header("content-type", "application/x-www-form-urlencoded")
132 .body(body)
133 .send()
134 .await
135 .map_err(|e| format!("exchange Parslee authorization code: {e}"))?;
136 let status = response.status();
137 let text = response
138 .text()
139 .await
140 .map_err(|e| format!("read token response: {e}"))?;
141 if !status.is_success() {
142 return Err(format!("Parslee token exchange failed: HTTP {status}: {text}"));
143 }
144 let token: TokenSet =
145 serde_json::from_str(&text).map_err(|e| format!("parse token response: {e}"))?;
146 if !token.token_type.eq_ignore_ascii_case("bearer") {
147 return Err(format!("unexpected Parslee token_type `{}`", token.token_type));
148 }
149 Ok(token)
150}
151
152fn put(key: &str, value: &str) -> Result<(), String> {
153 SecretStore::new()
154 .put(&SecretRef::with_default_service(key), value)
155 .map_err(|e| format!("store {key}: {e}"))
156}
157
158pub fn store_tokens(api_base: &str, token: &TokenSet) -> Result<(), String> {
161 put(PARSLEE_ACCESS_TOKEN_KEY, &token.access_token)?;
162 put(PARSLEE_REFRESH_TOKEN_KEY, &token.refresh_token)?;
163 put(PARSLEE_API_BASE_KEY, api_base.trim_end_matches('/'))?;
164 put(
165 PARSLEE_EXPIRES_AT_KEY,
166 &(epoch_seconds() + token.expires_in).to_string(),
167 )?;
168 Ok(())
169}
170
171pub fn clear_tokens() -> Result<(), String> {
173 let store = SecretStore::new();
174 for key in [
175 PARSLEE_ACCESS_TOKEN_KEY,
176 PARSLEE_REFRESH_TOKEN_KEY,
177 PARSLEE_EXPIRES_AT_KEY,
178 PARSLEE_API_BASE_KEY,
179 ] {
180 let _ = store.delete(&SecretRef::with_default_service(key));
181 }
182 Ok(())
183}
184
185pub fn access_token() -> Option<String> {
187 car_secrets::resolve_env_or_keychain(PARSLEE_ACCESS_TOKEN_KEY)
188}
189
190pub fn api_base(override_: Option<&str>) -> String {
192 override_
193 .map(|s| s.trim_end_matches('/').to_string())
194 .or_else(|| car_secrets::resolve_env_or_keychain(PARSLEE_API_BASE_KEY))
195 .unwrap_or_else(|| DEFAULT_API_BASE.to_string())
196}
197
198pub async fn fetch_status(api_base_override: Option<&str>) -> Result<Option<String>, String> {
201 let Some(access) = access_token() else {
202 return Ok(None);
203 };
204 let base = api_base(api_base_override);
205 let response = reqwest::Client::new()
206 .get(format!("{}/connect/session", base.trim_end_matches('/')))
207 .bearer_auth(access)
208 .send()
209 .await
210 .map_err(|e| format!("fetch Parslee session: {e}"))?;
211 let status = response.status();
212 let text = response
213 .text()
214 .await
215 .map_err(|e| format!("read Parslee session response: {e}"))?;
216 if !status.is_success() {
217 return Err(format!("Parslee session check failed: HTTP {status}: {text}"));
218 }
219 Ok(Some(text))
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the lifetime of the guard and removes
    /// it on drop, so the variable is cleaned up even when an assertion in
    /// the test panics.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            Self(key)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn pkce_challenge_is_s256_urlsafe_nopad() {
        let v = pkce_verifier();
        let c = pkce_challenge(&v);
        // URL-safe alphabet, no padding.
        assert!(!c.contains('=') && !c.contains('+') && !c.contains('/'));
        // Deterministic for a given verifier.
        assert_eq!(c, pkce_challenge(&v));
        // Known-answer check: the verifier/challenge pair from RFC 7636, Appendix B.
        assert_eq!(
            pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn authorize_url_has_pkce_and_provider() {
        let u = authorize_url(
            "https://api.parslee.ai/",
            "parslee-car",
            "http://localhost:8765/auth/callback",
            "st8",
            "chal",
            Some("microsoft"),
        )
        .unwrap();
        assert!(u.starts_with("https://api.parslee.ai/connect/authorize?"));
        assert!(u.contains("code_challenge=chal"));
        assert!(u.contains("code_challenge_method=S256"));
        assert!(u.contains("client_id=parslee-car"));
        assert!(u.contains("provider=microsoft"));
    }

    #[test]
    fn api_base_precedence() {
        // An explicit override wins and loses its trailing slash.
        assert_eq!(api_base(Some("https://x.test/")), "https://x.test");
    }

    /// Minimal single-threaded HTTP/1.1 server used to observe the requests
    /// the client code actually sends.
    mod mock {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::{Arc, Mutex};
        use std::thread;

        /// The parts of one received request that the tests assert on.
        pub struct Recorded {
            pub method: String,
            pub path: String,
            pub authorization: Option<String>,
            pub content_type: Option<String>,
            pub body: String,
        }

        /// Handle to a running mock server.
        pub struct Mock {
            /// `http://127.0.0.1:{port}` of the listening socket.
            pub base: String,
            /// Requests received so far, in arrival order.
            pub recorded: Arc<Mutex<Vec<Recorded>>>,
            handle: Option<thread::JoinHandle<()>>,
        }

        impl Drop for Mock {
            fn drop(&mut self) {
                let Some(h) = self.handle.take() else {
                    return;
                };
                // When the test is already unwinding, the server may still be
                // blocked in `accept()` waiting for a request that will never
                // arrive. Joining would turn the failure into a hang, so the
                // thread is left detached instead.
                if thread::panicking() {
                    return;
                }
                let _ = h.join();
            }
        }

        /// Index of the first occurrence of `needle` in `hay`, if any.
        fn find(hay: &[u8], needle: &[u8]) -> Option<usize> {
            hay.windows(needle.len()).position(|w| w == needle)
        }

        /// Starts a server on an ephemeral localhost port.
        ///
        /// The server accepts exactly `expected` connections, reads one
        /// request from each, records it, answers with the status code and
        /// JSON body produced by `respond`, and then its thread exits.
        pub fn start(
            expected: usize,
            respond: impl Fn(&Recorded) -> (u16, String) + Send + 'static,
        ) -> Mock {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            let recorded = Arc::new(Mutex::new(Vec::new()));
            let rec = recorded.clone();
            let handle = thread::spawn(move || {
                for _ in 0..expected {
                    let (mut stream, _) = listener.accept().unwrap();
                    let mut buf = Vec::new();
                    let mut tmp = [0u8; 1024];
                    loop {
                        let n = stream.read(&mut tmp).unwrap();
                        if n == 0 {
                            break;
                        }
                        buf.extend_from_slice(&tmp[..n]);
                        // Keep reading until the header terminator has arrived.
                        let Some(hdr_end) = find(&buf, b"\r\n\r\n") else {
                            continue;
                        };
                        let headers = String::from_utf8_lossy(&buf[..hdr_end]).into_owned();
                        let content_length = headers
                            .lines()
                            .find_map(|l| {
                                let (k, v) = l.split_once(':')?;
                                if k.eq_ignore_ascii_case("content-length") {
                                    v.trim().parse::<usize>().ok()
                                } else {
                                    None
                                }
                            })
                            .unwrap_or(0);
                        // Then read until the declared body length is available.
                        let body_start = hdr_end + 4;
                        while buf.len() < body_start + content_length {
                            let n = stream.read(&mut tmp).unwrap();
                            if n == 0 {
                                break;
                            }
                            buf.extend_from_slice(&tmp[..n]);
                        }
                        let mut header_lines = headers.lines();
                        let req_line = header_lines.next().unwrap_or("");
                        let mut rl = req_line.split_whitespace();
                        let method = rl.next().unwrap_or("").to_string();
                        let path = rl.next().unwrap_or("").to_string();
                        let mut authorization = None;
                        let mut content_type = None;
                        for l in header_lines {
                            if let Some((k, v)) = l.split_once(':') {
                                if k.eq_ignore_ascii_case("authorization") {
                                    authorization = Some(v.trim().to_string());
                                } else if k.eq_ignore_ascii_case("content-type") {
                                    content_type = Some(v.trim().to_string());
                                }
                            }
                        }
                        let body = String::from_utf8_lossy(
                            &buf[body_start..(body_start + content_length).min(buf.len())],
                        )
                        .into_owned();
                        let r = Recorded {
                            method,
                            path,
                            authorization,
                            content_type,
                            body,
                        };
                        let (code, resp_body) = respond(&r);
                        rec.lock().unwrap().push(r);
                        let resp = format!(
                            "HTTP/1.1 {code} OK\r\ncontent-type: application/json\r\n\
                             content-length: {}\r\nconnection: close\r\n\r\n{}",
                            resp_body.len(),
                            resp_body
                        );
                        stream.write_all(resp.as_bytes()).unwrap();
                        let _ = stream.flush();
                        // One request per connection.
                        break;
                    }
                }
            });
            Mock {
                base: format!("http://127.0.0.1:{port}"),
                recorded,
                handle: Some(handle),
            }
        }
    }

    #[tokio::test]
    async fn exchange_code_round_trips_token() {
        let mock = mock::start(1, |_r| {
            (
                200,
                r#"{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"Bearer"}"#
                    .to_string(),
            )
        });
        let token = exchange_code(
            &mock.base,
            "parslee-car",
            "http://localhost:1/cb",
            "thecode",
            "theverifier",
        )
        .await
        .unwrap();
        assert_eq!(token.access_token, "a");
        assert_eq!(token.refresh_token, "r");
        assert_eq!(token.expires_in, 3600);

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        assert_eq!(reqs[0].method, "POST");
        assert_eq!(reqs[0].path, "/connect/token");
        assert_eq!(
            reqs[0].content_type.as_deref(),
            Some("application/x-www-form-urlencoded")
        );
        assert!(reqs[0].body.contains("grant_type=authorization_code"));
        assert!(reqs[0].body.contains("code=thecode"));
        assert!(reqs[0].body.contains("code_verifier=theverifier"));
        // Reserved characters in values must arrive percent-encoded.
        assert!(reqs[0]
            .body
            .contains("redirect_uri=http%3A%2F%2Flocalhost%3A1%2Fcb"));
    }

    #[tokio::test]
    async fn fetch_status_sends_bearer() {
        // Declared first so it is dropped last; removes the variable even if
        // an assertion below fails.
        let _env = EnvVarGuard::set(PARSLEE_ACCESS_TOKEN_KEY, "test-token");

        let mock = mock::start(1, |_r| (200, r#"{"authenticated":true}"#.to_string()));

        let session = fetch_status(Some(&mock.base)).await.unwrap();
        assert_eq!(session.as_deref(), Some(r#"{"authenticated":true}"#));

        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        let sess = &reqs[0];
        assert_eq!(sess.method, "GET");
        assert_eq!(sess.path, "/connect/session");
        assert_eq!(sess.authorization.as_deref(), Some("Bearer test-token"));
    }
}