use base64::Engine;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use car_secrets::{SecretRef, SecretStore};
/// Key under which the Parslee OAuth access token is stored in the secret store;
/// `access_token` also resolves it from the environment via `car_secrets`.
pub const PARSLEE_ACCESS_TOKEN_KEY: &str = "PARSLEE_ACCESS_TOKEN";
/// Secret-store key holding the Parslee OAuth refresh token.
pub const PARSLEE_REFRESH_TOKEN_KEY: &str = "PARSLEE_REFRESH_TOKEN";
/// Secret-store key holding the access token's absolute expiry, written by
/// `store_tokens` as a decimal count of Unix epoch seconds.
pub const PARSLEE_EXPIRES_AT_KEY: &str = "PARSLEE_ACCESS_TOKEN_EXPIRES_AT";
/// Key holding the Parslee API base URL that issued the stored tokens
/// (stored without a trailing `/`).
pub const PARSLEE_API_BASE_KEY: &str = "PARSLEE_API_BASE";
/// API base used by `api_base` when neither an override nor a stored value exists.
pub const DEFAULT_API_BASE: &str = "https://api.parslee.ai";
/// Token response body returned by the Parslee `/connect/token` endpoint.
///
/// Deserialized from JSON in `exchange_code`. Every field is required, so a
/// response that omits any of them fails to parse.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenSet {
/// Token sent as `Authorization: Bearer …` on authenticated API calls.
pub access_token: String,
/// Refresh token persisted next to the access token by `store_tokens`.
// NOTE(review): RFC 6749 §5.1 marks `refresh_token` as optional, and the scope
// requested in `authorize_url` does not include `offline_access`; confirm the
// server always returns it, otherwise deserialization fails here.
pub refresh_token: String,
/// Access-token lifetime in seconds; `store_tokens` adds it to the current time.
pub expires_in: u64,
/// Token type; `exchange_code` rejects anything other than `bearer` (ASCII case-insensitive).
pub token_type: String,
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn epoch_seconds() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Generates a fresh PKCE `code_verifier`.
///
/// Two random v4 UUIDs are rendered in their hyphen-less `simple` form (32 hex
/// characters each) and concatenated into 64 ASCII bytes. Those bytes are
/// base64url-encoded without padding, giving an 86-character string drawn only
/// from the URL-safe alphabet.
pub fn pkce_verifier() -> String {
    let mut seed = String::with_capacity(64);
    for _ in 0..2 {
        seed.push_str(&uuid::Uuid::new_v4().simple().to_string());
    }
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(seed.as_bytes())
}
/// Generates an opaque OAuth `state` value: a random v4 UUID rendered as
/// 32 hex characters with no hyphens.
pub fn new_state() -> String {
    let id = uuid::Uuid::new_v4();
    format!("{}", id.simple())
}
/// Derives the PKCE `code_challenge` for `verifier` using the `S256` method:
/// base64url, without padding, of the SHA-256 digest of the verifier's bytes.
pub fn pkce_challenge(verifier: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(verifier.as_bytes());
    let hash = hasher.finalize();
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash)
}
/// Builds the browser URL for the Parslee `/connect/authorize` endpoint.
///
/// The query carries the authorization-code + PKCE parameters (`response_type=code`,
/// scope `openid profile email`, `code_challenge_method=S256`). When `provider`
/// is given, a `provider` parameter is appended last. Trailing `/` on `api_base`
/// is ignored.
///
/// # Errors
/// Returns `"build authorize URL: …"` if the resulting endpoint is not a valid URL.
pub fn authorize_url(
    api_base: &str,
    client_id: &str,
    redirect_uri: &str,
    state: &str,
    challenge: &str,
    provider: Option<&str>,
) -> Result<String, String> {
    let endpoint = format!("{}/connect/authorize", api_base.trim_end_matches('/'));
    let mut url =
        reqwest::Url::parse(&endpoint).map_err(|e| format!("build authorize URL: {e}"))?;
    {
        let fixed = [
            ("client_id", client_id),
            ("redirect_uri", redirect_uri),
            ("response_type", "code"),
            ("scope", "openid profile email"),
            ("state", state),
            ("code_challenge", challenge),
            ("code_challenge_method", "S256"),
        ];
        let mut query = url.query_pairs_mut();
        for (name, value) in fixed {
            query.append_pair(name, value);
        }
        if let Some(p) = provider {
            query.append_pair("provider", p);
        }
        // `query` is dropped at the end of this scope, committing the pairs to `url`.
    }
    Ok(url.to_string())
}
/// Serializes `pairs` as an `application/x-www-form-urlencoded` request body.
///
/// Each key and value is percent-encoded with [`urlencode`], and the pairs are
/// joined as `k1=v1&k2=v2`. An empty slice yields an empty string.
fn form_body(pairs: &[(&str, &str)]) -> String {
    pairs
        .iter()
        .map(|&(key, value)| format!("{}={}", urlencode(key), urlencode(value)))
        .collect::<Vec<_>>()
        .join("&")
}

/// Percent-encodes every byte of `s` except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`), using uppercase hex (`%XX`).
///
/// The input is processed as UTF-8 bytes, so a multi-byte character becomes
/// several `%XX` escapes. A space becomes `%20`, not `+`.
fn urlencode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Exchanges an authorization `code` for tokens at `{api_base}/connect/token`.
///
/// Sends a form-encoded `authorization_code` grant that includes the PKCE
/// `code_verifier`, then parses the JSON body into a [`TokenSet`].
///
/// # Errors
/// Returns a descriptive `String` if the request fails, the body cannot be read,
/// the server answers with a non-2xx status (the body is included), the JSON does
/// not match [`TokenSet`], or `token_type` is not `bearer` (case-insensitive).
pub async fn exchange_code(
    api_base: &str,
    client_id: &str,
    redirect_uri: &str,
    code: &str,
    verifier: &str,
) -> Result<TokenSet, String> {
    let form = form_body(&[
        ("grant_type", "authorization_code"),
        ("client_id", client_id),
        ("redirect_uri", redirect_uri),
        ("code", code),
        ("code_verifier", verifier),
    ]);
    let endpoint = format!("{}/connect/token", api_base.trim_end_matches('/'));
    let client = reqwest::Client::new();
    let request = client
        .post(endpoint)
        .header("content-type", "application/x-www-form-urlencoded")
        .body(form);
    let response = match request.send().await {
        Ok(resp) => resp,
        Err(e) => return Err(format!("exchange Parslee authorization code: {e}")),
    };
    let status = response.status();
    let text = match response.text().await {
        Ok(body) => body,
        Err(e) => return Err(format!("read token response: {e}")),
    };
    if !status.is_success() {
        return Err(format!("Parslee token exchange failed: HTTP {status}: {text}"));
    }
    let token = serde_json::from_str::<TokenSet>(&text)
        .map_err(|e| format!("parse token response: {e}"))?;
    if token.token_type.eq_ignore_ascii_case("bearer") {
        Ok(token)
    } else {
        Err(format!("unexpected Parslee token_type `{}`", token.token_type))
    }
}
/// Writes `value` to the secret store under `key`, using the default service.
///
/// # Errors
/// Returns `"store {key}: {cause}"` when the secret store rejects the write.
fn put(key: &str, value: &str) -> Result<(), String> {
    let store = SecretStore::new();
    let secret = SecretRef::with_default_service(key);
    store
        .put(&secret, value)
        .map_err(|err| format!("store {key}: {err}"))
}
/// Persists a freshly issued token set together with the API base that issued it.
///
/// Writes the access token, the refresh token, `api_base` (trailing `/` removed)
/// and the absolute expiry (`now + expires_in`, in Unix epoch seconds) to the
/// secret store.
///
/// The expiry uses saturating addition. An absurd server-supplied `expires_in`
/// therefore cannot overflow `u64`, which would panic in debug builds or wrap to
/// a timestamp in the past in release builds.
///
/// # Errors
/// Returns the first secret-store failure. Writes that already succeeded are not
/// rolled back.
pub fn store_tokens(api_base: &str, token: &TokenSet) -> Result<(), String> {
    // Compute the expiry before any (possibly slow) keychain writes, so it stays
    // as close as possible to the moment the token was issued.
    let expires_at = epoch_seconds().saturating_add(token.expires_in);
    put(PARSLEE_ACCESS_TOKEN_KEY, &token.access_token)?;
    put(PARSLEE_REFRESH_TOKEN_KEY, &token.refresh_token)?;
    put(PARSLEE_API_BASE_KEY, api_base.trim_end_matches('/'))?;
    put(PARSLEE_EXPIRES_AT_KEY, &expires_at.to_string())?;
    Ok(())
}
/// Removes the stored Parslee access token, refresh token, expiry and API base.
///
/// Deletion is best-effort: an error from deleting any individual key is
/// ignored, so this function currently always returns `Ok(())`.
pub fn clear_tokens() -> Result<(), String> {
    const KEYS: [&str; 4] = [
        PARSLEE_ACCESS_TOKEN_KEY,
        PARSLEE_REFRESH_TOKEN_KEY,
        PARSLEE_EXPIRES_AT_KEY,
        PARSLEE_API_BASE_KEY,
    ];
    let store = SecretStore::new();
    KEYS.iter().copied().for_each(|key| {
        // Intentionally ignored: one failed delete must not stop the others.
        let _ = store.delete(&SecretRef::with_default_service(key));
    });
    Ok(())
}
/// Returns the Parslee access token, resolved from the environment or keychain
/// by `car_secrets::resolve_env_or_keychain`, or `None` if neither has one.
///
/// The stored expiry (`PARSLEE_EXPIRES_AT_KEY`) is not consulted here, so the
/// returned token may already be expired.
pub fn access_token() -> Option<String> {
    let key = PARSLEE_ACCESS_TOKEN_KEY;
    car_secrets::resolve_env_or_keychain(key)
}
/// Resolves the Parslee API base URL.
///
/// Precedence is:
/// 1. the explicit `override_`;
/// 2. the `PARSLEE_API_BASE` value from the environment or keychain;
/// 3. [`DEFAULT_API_BASE`].
///
/// Every candidate is normalized identically: surrounding whitespace and any
/// trailing `/` are removed. A candidate that is empty after normalization (for
/// example `""` or `"/"`) is skipped. The result therefore never ends in `/` and
/// is never empty, whichever source it came from.
pub fn api_base(override_: Option<&str>) -> String {
    // Shared normalization so stored values behave exactly like overrides.
    fn normalize(raw: &str) -> Option<String> {
        let trimmed = raw.trim().trim_end_matches('/');
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    }
    override_
        .and_then(normalize)
        .or_else(|| {
            car_secrets::resolve_env_or_keychain(PARSLEE_API_BASE_KEY)
                .as_deref()
                .and_then(normalize)
        })
        .unwrap_or_else(|| DEFAULT_API_BASE.to_string())
}
/// Queries `{base}/connect/session` with the stored access token.
///
/// Returns `Ok(None)` without any network call when no access token is
/// available. Otherwise it returns the raw response body on a 2xx status. The
/// base comes from [`api_base`], using `api_base_override` when given.
///
/// # Errors
/// Returns a descriptive `String` if the request fails, the body cannot be read,
/// or the server answers with a non-2xx status (the body is included).
pub async fn fetch_status(api_base_override: Option<&str>) -> Result<Option<String>, String> {
    let access = match access_token() {
        Some(token) => token,
        None => return Ok(None),
    };
    let base = api_base(api_base_override);
    let session_url = format!("{}/connect/session", base.trim_end_matches('/'));
    let request = reqwest::Client::new().get(session_url).bearer_auth(access);
    let response = match request.send().await {
        Ok(resp) => resp,
        Err(e) => return Err(format!("fetch Parslee session: {e}")),
    };
    let status = response.status();
    let body = response
        .text()
        .await
        .map_err(|e| format!("read Parslee session response: {e}"))?;
    if status.is_success() {
        Ok(Some(body))
    } else {
        Err(format!("Parslee session check failed: HTTP {status}: {body}"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The generated verifier satisfies RFC 7636 §4.1, and the challenge is the
    /// genuine S256 transform, checked against the RFC's known-answer vector.
    #[test]
    fn pkce_challenge_is_s256_urlsafe_nopad() {
        let v = pkce_verifier();
        // RFC 7636 §4.1: 43..=128 characters from the unreserved set.
        assert!((43..=128).contains(&v.len()), "verifier length {}", v.len());
        assert!(v
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~')));
        let c = pkce_challenge(&v);
        assert!(!c.contains('=') && !c.contains('+') && !c.contains('/'));
        assert_eq!(c, pkce_challenge(&v));
        // RFC 7636 Appendix B: without a known answer, any stable URL-safe hash
        // would pass the checks above.
        assert_eq!(
            pkce_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn authorize_url_has_pkce_and_provider() {
        // The trailing `/` on the base must not produce `//connect/authorize`.
        let u = authorize_url(
            "https://api.parslee.ai/",
            "parslee-car",
            "http://localhost:8765/auth/callback",
            "st8",
            "chal",
            Some("microsoft"),
        )
        .unwrap();
        assert!(u.starts_with("https://api.parslee.ai/connect/authorize?"));
        assert!(u.contains("code_challenge=chal"));
        assert!(u.contains("code_challenge_method=S256"));
        assert!(u.contains("client_id=parslee-car"));
        assert!(u.contains("provider=microsoft"));
    }

    #[test]
    fn api_base_precedence() {
        // Only the override branch is exercised: it wins and is normalized.
        assert_eq!(api_base(Some("https://x.test/")), "https://x.test");
    }

    /// Minimal blocking HTTP/1.1 stub server for exercising the async clients.
    mod mock {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::sync::{Arc, Mutex};
        use std::thread;

        /// One request as observed by the stub server.
        pub struct Recorded {
            pub method: String,
            pub path: String,
            pub authorization: Option<String>,
            pub content_type: Option<String>,
            pub body: String,
        }

        /// Handle to a running stub server on an ephemeral `127.0.0.1` port.
        pub struct Mock {
            /// `http://127.0.0.1:{port}`, suitable as an API base.
            pub base: String,
            /// Requests in the order they were served.
            pub recorded: Arc<Mutex<Vec<Recorded>>>,
            handle: Option<thread::JoinHandle<()>>,
        }

        impl Drop for Mock {
            // Joins the server thread. If fewer than `expected` connections ever
            // arrive, the thread stays blocked in `accept` and this drop never returns.
            fn drop(&mut self) {
                if let Some(h) = self.handle.take() {
                    let _ = h.join();
                }
            }
        }

        /// Byte offset of the first occurrence of `needle` in `hay`.
        fn find(hay: &[u8], needle: &[u8]) -> Option<usize> {
            hay.windows(needle.len()).position(|w| w == needle)
        }

        /// Starts a server thread that accepts exactly `expected` connections and
        /// serves one request on each.
        ///
        /// For every request it records the method, path, `authorization` and
        /// `content-type` headers, and the body (read up to `content-length`). It
        /// then replies with the `(status, body)` produced by `respond`, always
        /// using reason phrase `OK`, `content-type: application/json` and
        /// `connection: close`.
        pub fn start(
            expected: usize,
            respond: impl Fn(&Recorded) -> (u16, String) + Send + 'static,
        ) -> Mock {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            let recorded = Arc::new(Mutex::new(Vec::new()));
            let rec = recorded.clone();
            let handle = thread::spawn(move || {
                for _ in 0..expected {
                    let (mut stream, _) = listener.accept().unwrap();
                    let mut buf = Vec::new();
                    let mut tmp = [0u8; 1024];
                    loop {
                        let n = stream.read(&mut tmp).unwrap();
                        if n == 0 {
                            break;
                        }
                        buf.extend_from_slice(&tmp[..n]);
                        // Keep reading until the full header block has arrived.
                        let Some(hdr_end) = find(&buf, b"\r\n\r\n") else {
                            continue;
                        };
                        let headers = String::from_utf8_lossy(&buf[..hdr_end]).into_owned();
                        let content_length = headers
                            .lines()
                            .find_map(|l| {
                                let (k, v) = l.split_once(':')?;
                                if k.eq_ignore_ascii_case("content-length") {
                                    v.trim().parse::<usize>().ok()
                                } else {
                                    None
                                }
                            })
                            .unwrap_or(0);
                        let body_start = hdr_end + 4;
                        // The body may still be in flight after the headers.
                        while buf.len() < body_start + content_length {
                            let n = stream.read(&mut tmp).unwrap();
                            if n == 0 {
                                break;
                            }
                            buf.extend_from_slice(&tmp[..n]);
                        }
                        let mut header_lines = headers.lines();
                        let req_line = header_lines.next().unwrap_or("");
                        let mut rl = req_line.split_whitespace();
                        let method = rl.next().unwrap_or("").to_string();
                        let path = rl.next().unwrap_or("").to_string();
                        let mut authorization = None;
                        let mut content_type = None;
                        for l in header_lines {
                            if let Some((k, v)) = l.split_once(':') {
                                if k.eq_ignore_ascii_case("authorization") {
                                    authorization = Some(v.trim().to_string());
                                } else if k.eq_ignore_ascii_case("content-type") {
                                    content_type = Some(v.trim().to_string());
                                }
                            }
                        }
                        // `.min` guards against a peer that closed before sending the full body.
                        let body = String::from_utf8_lossy(
                            &buf[body_start..(body_start + content_length).min(buf.len())],
                        )
                        .into_owned();
                        let r = Recorded {
                            method,
                            path,
                            authorization,
                            content_type,
                            body,
                        };
                        let (code, resp_body) = respond(&r);
                        rec.lock().unwrap().push(r);
                        let resp = format!(
                            "HTTP/1.1 {code} OK\r\ncontent-type: application/json\r\n\
                             content-length: {}\r\nconnection: close\r\n\r\n{}",
                            resp_body.len(),
                            resp_body
                        );
                        stream.write_all(resp.as_bytes()).unwrap();
                        let _ = stream.flush();
                        break;
                    }
                }
            });
            Mock {
                base: format!("http://127.0.0.1:{port}"),
                recorded,
                handle: Some(handle),
            }
        }
    }

    #[tokio::test]
    async fn exchange_code_round_trips_token() {
        let mock = mock::start(1, |_r| {
            (
                200,
                r#"{"access_token":"a","refresh_token":"r","expires_in":3600,"token_type":"Bearer"}"#
                    .to_string(),
            )
        });
        let token = exchange_code(
            &mock.base,
            "parslee-car",
            "http://localhost:1/cb",
            "thecode",
            "theverifier",
        )
        .await
        .unwrap();
        assert_eq!(token.access_token, "a");
        assert_eq!(token.refresh_token, "r");
        assert_eq!(token.expires_in, 3600);
        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        assert_eq!(reqs[0].method, "POST");
        assert_eq!(reqs[0].path, "/connect/token");
        assert!(reqs[0].body.contains("grant_type=authorization_code"));
        assert!(reqs[0].body.contains("code=thecode"));
        assert!(reqs[0].body.contains("code_verifier=theverifier"));
    }

    #[tokio::test]
    async fn fetch_status_sends_bearer() {
        // NOTE(review): the environment is process-global, and the cleanup below is
        // skipped if an assertion panics; confirm no other test reads this key concurrently.
        std::env::set_var(PARSLEE_ACCESS_TOKEN_KEY, "test-token");
        let mock = mock::start(1, |_r| (200, r#"{"authenticated":true}"#.to_string()));
        let session = fetch_status(Some(&mock.base)).await.unwrap();
        assert_eq!(session.as_deref(), Some(r#"{"authenticated":true}"#));
        let reqs = mock.recorded.lock().unwrap();
        assert_eq!(reqs.len(), 1);
        let sess = &reqs[0];
        assert_eq!(sess.method, "GET");
        assert_eq!(sess.path, "/connect/session");
        assert_eq!(sess.authorization.as_deref(), Some("Bearer test-token"));
        std::env::remove_var(PARSLEE_ACCESS_TOKEN_KEY);
    }
}