use std::collections::HashMap;
use std::fmt;
use std::sync::Mutex;
use reqwest::cookie::CookieStore;
use reqwest::header::HeaderValue;
use url::Url;
/// A minimal cookie jar that keeps at most one value per cookie name.
///
/// Cookie attributes (`Path`, `Domain`, `Secure`, ...) are discarded and the request URL is
/// not consulted, so every stored cookie is sent on every request. Storing a cookie whose
/// name already exists replaces the previous value.
#[derive(Default)]
pub struct NameKeyedJar {
    /// Cookie name -> cookie value, behind a `Mutex` so the jar can be mutated through `&self`.
    inner: Mutex<HashMap<String, String>>,
}
impl fmt::Debug for NameKeyedJar {
    /// Formats the jar as its cookie count plus the cookie names in sorted order.
    ///
    /// Values are deliberately left out: cookie values can be session secrets and must not
    /// end up in logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A poisoned lock still holds usable data; formatting should not panic because of it.
        let map = match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let mut sorted_names = map.keys().map(|name| name.as_str()).collect::<Vec<&str>>();
        sorted_names.sort_unstable();
        let count = map.len();
        f.debug_struct("NameKeyedJar")
            .field("len", &count)
            .field("names", &sorted_names)
            .finish()
    }
}
/// Returns `true` for bytes that must never appear in a stored cookie name or value:
/// the C0 control characters (`0x00..=0x1F`, which include CR and LF) and DEL (`0x7F`).
fn is_unsafe_byte(b: u8) -> bool {
    b.is_ascii_control()
}
fn is_safe(s: &str) -> bool {
!s.bytes().any(is_unsafe_byte)
}
impl NameKeyedJar {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn set_pairs<I, S>(&self, pairs: I)
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
for raw in pairs {
let trimmed = raw.as_ref().trim();
if trimmed.is_empty() {
continue;
}
if let Some((name, value)) = trimmed.split_once('=') {
let name = name.trim();
let value = value.split(';').next().unwrap_or("").trim();
if name.is_empty() || !is_safe(name) || !is_safe(value) {
continue;
}
guard.insert(name.to_owned(), value.to_owned());
}
}
}
pub fn clear(&self) {
self.inner.lock().unwrap_or_else(|e| e.into_inner()).clear();
}
#[must_use]
pub fn len(&self) -> usize {
self.inner.lock().unwrap_or_else(|e| e.into_inner()).len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn snapshot(&self) -> Vec<(String, String)> {
let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
let mut entries: Vec<_> = guard.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
entries.sort_by(|a, b| a.0.cmp(&b.0));
entries
}
}
impl CookieStore for NameKeyedJar {
    /// Records the `name=value` pair of each `Set-Cookie` header, replacing same-named cookies.
    ///
    /// Attributes after the first `;` and the response URL are ignored. A header is skipped
    /// when any of the following holds:
    /// - it is not valid for `HeaderValue::to_str`;
    /// - its first segment has no `=`;
    /// - the name is empty;
    /// - the name or value contains an ASCII control character.
    fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &HeaderValue>, _url: &Url) {
        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        for header in cookie_headers {
            // `to_str` fails for non-visible-ASCII bytes; such headers are skipped.
            let Ok(raw) = header.to_str() else { continue };
            let primary = raw.split(';').next().unwrap_or("").trim();
            let Some((name, value)) = primary.split_once('=') else {
                continue;
            };
            let (name, value) = (name.trim(), value.trim());
            if name.is_empty() || !is_safe(name) || !is_safe(value) {
                continue;
            }
            guard.insert(name.to_owned(), value.to_owned());
        }
    }

    /// Builds a `Cookie` header containing every stored pair, regardless of `_url`.
    ///
    /// Pairs are emitted sorted by name (e.g. `a=1; b=2`). `HashMap` iteration order differs
    /// between instances and runs, so sorting keeps the header deterministic. Returns `None`
    /// when the jar is empty.
    fn cookies(&self, _url: &Url) -> Option<HeaderValue> {
        let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        if guard.is_empty() {
            return None;
        }
        let mut pairs: Vec<(&str, &str)> = guard
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
            .collect();
        pairs.sort_unstable_by_key(|&(name, _)| name);

        // Each pair needs `name=value`, plus a `; ` separator between pairs.
        let capacity: usize = pairs.iter().map(|(n, v)| n.len() + v.len() + 3).sum();
        let mut buf = String::with_capacity(capacity);
        for (i, (name, value)) in pairs.into_iter().enumerate() {
            if i > 0 {
                buf.push_str("; ");
            }
            buf.push_str(name);
            buf.push('=');
            buf.push_str(value);
        }
        // NOTE: stored pairs passed `is_safe`, so this is expected to succeed. `None` is only
        // a defensive fallback.
        HeaderValue::from_str(&buf).ok()
    }
}
// Unit tests for `NameKeyedJar`. They cover:
// - replacement by name and attribute stripping;
// - `Cookie` header formatting;
// - rejection of malformed or control-character input;
// - value-redacting `Debug` output.
#[cfg(test)]
mod tests {
use super::*;
// Arbitrary URL: both `CookieStore` methods ignore their URL argument.
fn url() -> Url {
Url::parse("https://localhost:5000/").unwrap()
}
// A second `Set-Cookie` with the same name but different attributes must overwrite the first.
#[test]
fn set_cookies_replaces_by_name() {
let jar = NameKeyedJar::new();
let h1 = HeaderValue::from_static("JSESSIONID=OLD; Path=/sso; HttpOnly");
let h2 = HeaderValue::from_static("JSESSIONID=NEW; Path=/; Secure");
jar.set_cookies(&mut [h1].iter(), &url());
jar.set_cookies(&mut [h2].iter(), &url());
assert_eq!(jar.snapshot(), vec![("JSESSIONID".into(), "NEW".into())]);
}
// Same replacement rule for the manual `set_pairs` entry point.
#[test]
fn set_pairs_replaces_by_name() {
let jar = NameKeyedJar::new();
jar.set_pairs(["X=A"]);
jar.set_pairs(["X=B"]);
assert_eq!(jar.snapshot(), vec![("X".into(), "B".into())]);
}
// Attributes after the first `;` must not leak into the stored value.
#[test]
fn set_pairs_strips_attributes() {
let jar = NameKeyedJar::new();
jar.set_pairs(["JSESSIONID=ABC; Domain=.ibkr.com; Secure"]);
assert_eq!(jar.snapshot(), vec![("JSESSIONID".into(), "ABC".into())]);
}
// Checks only membership and separator count, so the test does not depend on pair order.
#[test]
fn cookies_combines_pairs_with_semicolon_separator() {
let jar = NameKeyedJar::new();
jar.set_pairs(["b=2", "a=1", "c=3"]);
let s = jar.cookies(&url()).unwrap().to_str().unwrap().to_owned();
for needle in ["a=1", "b=2", "c=3"] {
assert!(s.contains(needle), "expected {needle} in {s:?}");
}
assert_eq!(s.matches("; ").count(), 2);
}
// A single pair must be emitted without any leading or trailing `; `.
#[test]
fn cookies_single_entry_has_no_trailing_separator() {
let jar = NameKeyedJar::new();
jar.set_pairs(["only=one"]);
let header = jar.cookies(&url()).unwrap();
assert_eq!(header.to_str().unwrap(), "only=one");
}
// An empty jar sends no `Cookie` header at all.
#[test]
fn empty_jar_returns_none() {
let jar = NameKeyedJar::new();
assert!(jar.cookies(&url()).is_none());
assert!(jar.is_empty());
assert_eq!(jar.len(), 0);
}
#[test]
fn clear_empties_jar() {
let jar = NameKeyedJar::new();
jar.set_pairs(["a=1", "b=2"]);
assert_eq!(jar.len(), 2);
jar.clear();
assert!(jar.is_empty());
assert!(jar.cookies(&url()).is_none());
}
// `snapshot` must be ordered by name regardless of insertion order.
#[test]
fn snapshot_returns_sorted_by_name() {
let jar = NameKeyedJar::new();
jar.set_pairs(["c=3", "a=1", "b=2"]);
assert_eq!(
jar.snapshot(),
vec![
("a".into(), "1".into()),
("b".into(), "2".into()),
("c".into(), "3".into()),
]
);
}
// One `set_cookies` call may deliver several `Set-Cookie` headers.
#[test]
fn set_cookies_handles_iterator_of_multiple_values() {
let jar = NameKeyedJar::new();
let h1 = HeaderValue::from_static("a=1; Path=/");
let h2 = HeaderValue::from_static("b=2; HttpOnly");
let h3 = HeaderValue::from_static("c=3");
let mut iter = [&h1, &h2, &h3].into_iter();
jar.set_cookies(&mut iter, &url());
assert_eq!(jar.len(), 3);
}
// `to_str` rejects the 0xFF byte, so that header is skipped while later headers still apply.
#[test]
fn set_cookies_skips_non_ascii_header_values() {
let jar = NameKeyedJar::new();
let bad = HeaderValue::from_bytes(b"name=\xff").unwrap();
let good = HeaderValue::from_static("kept=yes");
jar.set_cookies(&mut [&bad, &good].into_iter(), &url());
assert_eq!(jar.snapshot(), vec![("kept".into(), "yes".into())]);
}
// Blank, `=`-less and empty-name entries are dropped.
// An empty value and a value that itself contains `=` are kept.
#[test]
fn set_pairs_skips_malformed_inputs() {
let jar = NameKeyedJar::new();
jar.set_pairs([
"", " ", "novalue", "=onlyvalue", "name=", "n=a=b", ]);
let snap = jar.snapshot();
assert!(snap.contains(&("name".into(), String::new())));
assert!(snap.contains(&("n".into(), "a=b".into())));
assert_eq!(snap.len(), 2);
}
// CR/LF in a value could inject extra header lines, so such entries are dropped.
// Valid entries in the same call still apply.
#[test]
fn set_pairs_rejects_crlf_injection_in_value() {
let jar = NameKeyedJar::new();
jar.set_pairs(["X=foo\r\nInjected: yes", "kept=yes"]);
assert_eq!(jar.snapshot(), vec![("kept".into(), "yes".into())]);
let header = jar.cookies(&url()).unwrap();
assert_eq!(header.to_str().unwrap(), "kept=yes");
}
// Control characters are rejected in names as well as values.
#[test]
fn rejects_control_chars_in_name() {
let jar = NameKeyedJar::new();
jar.set_pairs(["bad\rname=v", "ok=v"]);
assert_eq!(jar.snapshot(), vec![("ok".into(), "v".into())]);
}
// Cookie values (e.g. session ids) are secrets: `Debug` may show names and count, never values.
#[test]
fn debug_does_not_leak_cookie_values() {
let jar = NameKeyedJar::new();
jar.set_pairs(["JSESSIONID=SECRET-VALUE-DO-NOT-LOG"]);
let rendered = format!("{jar:?}");
assert!(
!rendered.contains("SECRET-VALUE"),
"Debug must not leak cookie values, got: {rendered}"
);
assert!(rendered.contains("JSESSIONID"));
assert!(rendered.contains("len"));
}
}