1use crate::QuartzResult;
2use chrono::prelude::*;
3use hyper::http::uri::Scheme;
4use std::{
5 collections::HashSet,
6 convert::Infallible,
7 hash::Hash,
8 ops::{Deref, DerefMut},
9 path::{Path, PathBuf},
10 str::FromStr,
11};
12
/// Columns of one tab-separated cookie-file record, in file order.
///
/// Each variant's discriminant is its zero-based column index, so a variant
/// can index a split record directly: `line[Field::Name as usize]`.
pub enum Field {
    /// Column 0: the cookie domain.
    Domain = 0,
    /// Column 1: subdomain flag, written as `TRUE`/`FALSE`.
    Subdomains = 1,
    /// Column 2: the cookie path.
    Path = 2,
    /// Column 3: secure-only flag, written as `TRUE`/`FALSE`.
    Secure = 3,
    /// Column 4: expiry timestamp (`0` = never expires).
    ExpiresAt = 4,
    /// Column 5: the cookie name.
    Name = 5,
    /// Column 6: the cookie value (last column, may contain tabs).
    Value = 6,
}
22
/// Error produced when a cookie cannot be built or parsed.
///
/// Examples are a builder missing its domain, name or value, or a cookie-file
/// line with too few columns.
#[derive(Debug, Clone)]
pub struct CookieError;

impl std::fmt::Display for CookieError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid cookie")
    }
}

// Implementing `Error` lets callers box this (`Box<dyn Error>`) or
// propagate it with `?` alongside other error types.
impl std::error::Error for CookieError {}
25
/// A cookie domain in canonical form.
///
/// The canonical form is ASCII-lower-cased, trimmed, and has runs of `.`
/// collapsed. The type dereferences to the underlying `String`.
///
/// NOTE: mutating through `DerefMut` bypasses canonicalization.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Domain(String);

impl Deref for Domain {
    type Target = String;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl DerefMut for Domain {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl From<&str> for Domain {
    fn from(value: &str) -> Self {
        Self::new(value)
    }
}

impl Domain {
    /// Creates a domain from `s`, canonicalized with [`Domain::canonicalize`].
    pub fn new<T>(s: T) -> Self
    where
        T: Into<String>,
    {
        Self(Self::canonicalize(&s.into()))
    }

    /// Returns the canonical form of `value`.
    ///
    /// The result is ASCII-lower-cased, has surrounding whitespace trimmed, and
    /// has consecutive dots collapsed into one, so `" Foo..Bar.COM "` becomes
    /// `"foo.bar.com"`. A single leading or trailing dot is kept.
    pub fn canonicalize(value: &str) -> String {
        let value = value.to_ascii_lowercase();
        let value = value.trim();

        let mut res = String::with_capacity(value.len());
        let mut last: Option<char> = None;
        for ch in value.chars() {
            if ch == '.' && last == Some('.') {
                continue;
            }
            last = Some(ch);
            res.push(ch);
        }

        res
    }

    /// Reports whether `other` is this domain or one of its subdomains.
    ///
    /// Both sides are compared label by label from the right. Every label of
    /// `self` must equal the label at the same position in `other`, and `other`
    /// may only add *leading* labels.
    ///
    /// So `example.com` matches `example.com` and `www.example.com`. It does
    /// not match `badexample.com`, its parent `com`, or an empty host. An empty
    /// `self` matches everything.
    #[must_use]
    pub fn matches<T>(&self, other: T) -> bool
    where
        T: Into<Domain>,
    {
        let other: Domain = other.into();

        if **self == *other {
            return true;
        }

        // If `other` runs out of labels before `self` does, `next()` yields
        // `None` and the match fails. A parent domain must never receive a
        // child domain's cookies.
        let mut other_segments = other.as_segments();
        self.as_segments()
            .all(|segment| other_segments.next() == Some(segment))
    }

    /// Iterates over the non-empty dot-separated labels from right to left.
    ///
    /// For example, `www.example.com` yields `com`, `example`, `www`.
    pub fn as_segments(&self) -> impl Iterator<Item = &str> {
        self.split('.').filter(|s| !s.is_empty()).rev()
    }
}
153
154#[derive(Default)]
155pub struct CookieBuilder {
156 domain: Option<String>,
157 subdomains: bool,
158 path: Option<String>,
159 secure: bool,
160 expires_at: i64,
161 name: Option<String>,
162 value: Option<String>,
163}
164
165impl CookieBuilder {
166 pub fn domain<T>(&mut self, s: T) -> &mut Self
167 where
168 T: Into<String>,
169 {
170 self.domain = Some(s.into());
171 self
172 }
173
174 pub fn subdomains(&mut self, v: bool) -> &mut Self {
175 self.subdomains = v;
176 self
177 }
178
179 pub fn path<T>(&mut self, s: T) -> &mut Self
180 where
181 T: Into<String>,
182 {
183 self.path = Some(s.into());
184 self
185 }
186
187 pub fn secure(&mut self, v: bool) -> &mut Self {
188 self.secure = v;
189 self
190 }
191
192 pub fn expires_at(&mut self, v: i64) -> &mut Self {
193 self.expires_at = v;
194 self
195 }
196
197 pub fn name<T>(&mut self, s: T) -> &mut Self
198 where
199 T: Into<String>,
200 {
201 self.name = Some(s.into());
202 self
203 }
204
205 pub fn value<T>(&mut self, s: T) -> &mut Self
206 where
207 T: Into<String>,
208 {
209 self.value = Some(s.into());
210 self
211 }
212
213 pub fn build(self) -> QuartzResult<Cookie, CookieError> {
220 let domain = Domain::new(self.domain.ok_or(CookieError)?);
221 let name = self.name.ok_or(CookieError)?;
222 let value = self.value.ok_or(CookieError)?;
223
224 Ok(Cookie {
225 domain,
226 subdomains: self.subdomains,
227 path: PathAttr::from(self.path.unwrap_or_default().as_str()),
228 secure: self.secure,
229 expires_at: self.expires_at,
230 name,
231 value,
232 })
233 }
234}
235
236#[derive(Debug, Clone)]
237pub struct Cookie {
238 domain: Domain,
239 subdomains: bool,
240 path: PathAttr,
241 secure: bool,
242 expires_at: i64,
243 name: String,
244 value: String,
245}
246
247impl Eq for Cookie {}
248
249impl PartialEq for Cookie {
250 fn eq(&self, other: &Self) -> bool {
251 self.domain == other.domain && self.name == other.name
252 }
253}
254
255impl Hash for Cookie {
256 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
257 self.domain.hash(state);
258 self.name.hash(state);
259 }
260}
261
262impl ToString for Cookie {
263 fn to_string(&self) -> String {
284 format!(
285 "{}\t{}\t{}\t{}\t{}\t{}\t{}",
286 *self.domain,
287 self.subdomains.to_string().to_uppercase(),
288 self.path.to_string(),
289 self.secure.to_string().to_uppercase(),
290 self.expires_at,
291 self.name,
292 self.value,
293 )
294 }
295}
296
297impl FromStr for Cookie {
298 type Err = CookieError;
299
300 fn from_str(s: &str) -> Result<Self, Self::Err> {
323 let mut cookie = Cookie::builder();
324 let line: Vec<&str> = s.splitn(7, '\t').collect();
325
326 if line.len() != 7 {
327 return Err(CookieError);
328 }
329
330 cookie
331 .domain(line[Field::Domain as usize])
332 .subdomains(line[Field::Subdomains as usize] == "TRUE")
333 .path(line[Field::Path as usize])
334 .secure(line[Field::Secure as usize] == "TRUE")
335 .name(line[Field::Name as usize])
336 .value(line[Field::Value as usize]);
337
338 if let Ok(v) = line[4].parse() {
339 cookie.expires_at(v);
340 }
341
342 cookie.build()
343 }
344}
345
346impl Cookie {
347 pub fn builder() -> CookieBuilder {
348 CookieBuilder::default()
349 }
350
351 pub fn matches<T>(&self, req: hyper::Request<T>) -> bool {
352 if !self.domain().matches(req.uri().host().unwrap_or_default()) {
353 return false;
354 }
355
356 if self.secure() {
357 let scheme = req.uri().scheme().unwrap_or(&Scheme::HTTP);
358 if scheme == &Scheme::HTTP {
359 return false;
360 }
361 }
362
363 if !self.path().matches(req.uri().path()) {
364 return false;
365 }
366
367 true
368 }
369
370 pub fn expired(&self) -> bool {
372 self.expires_at != 0 && Utc::now().timestamp_micros() > self.expires_at
373 }
374
375 pub fn domain(&self) -> &Domain {
376 &self.domain
377 }
378
379 pub fn subdomains(&self) -> bool {
380 self.subdomains
381 }
382
383 pub fn path(&self) -> &PathAttr {
384 &self.path
385 }
386
387 pub fn secure(&self) -> bool {
388 self.secure
389 }
390
391 pub fn expires_at(&self) -> i64 {
392 self.expires_at
393 }
394
395 pub fn name(&self) -> &str {
396 self.name.as_ref()
397 }
398
399 pub fn value(&self) -> &str {
400 self.value.as_ref()
401 }
402}
403
404#[derive(Default)]
405pub struct CookieJar {
406 data: HashSet<Cookie>,
407 pub path: PathBuf,
408}
409
410impl Deref for CookieJar {
411 type Target = HashSet<Cookie>;
412
413 fn deref(&self) -> &Self::Target {
414 &self.data
415 }
416}
417
418impl DerefMut for CookieJar {
419 fn deref_mut(&mut self) -> &mut Self::Target {
420 &mut self.data
421 }
422}
423
424impl ToString for CookieJar {
425 fn to_string(&self) -> String {
426 let mut jar = String::new();
427
428 for cookie in self.iter() {
429 jar.push_str(&cookie.to_string());
430 jar.push('\n');
431 }
432
433 jar
434 }
435}
436
437impl CookieJar {
438 fn pair(v: &str) -> Option<(&str, &str)> {
439 v.trim().split_once('=')
440 }
441
442 pub fn set(&mut self, domain: &str, input: &'_ str) -> Cookie {
444 let mut cookie = Cookie::builder();
445 cookie.domain(domain);
446
447 let (pair, settings) = input.split_once(';').unwrap_or((input, ""));
448
449 let (key, value) = Self::pair(pair).unwrap_or_else(|| panic!("malformed cookie: {}", pair));
450
451 cookie.name(key);
452 cookie.value(value);
453
454 for v in settings.split(';') {
455 let (key, value) = Self::pair(v).unwrap_or((v, ""));
456
457 match key.to_lowercase().as_str() {
458 "domain" => cookie.domain(value),
459 "path" => cookie.path(value),
460 "secure" => cookie.secure(true),
461 "max-age" => {
462 cookie.expires_at(value.parse::<i64>().unwrap() + Utc::now().timestamp_micros())
463 }
464 "expires" => cookie.expires_at(
465 DateTime::parse_from_rfc2822(value)
466 .unwrap()
467 .timestamp_micros(),
468 ),
469 _ => &mut cookie,
470 };
471 }
472
473 let cookie = cookie.build().unwrap();
474
475 if self.contains(&cookie) {
478 self.remove(&cookie);
479 }
480
481 if !cookie.expired() {
484 self.insert(cookie.clone());
485 }
486
487 cookie
488 }
489
490 pub fn find_by_name(&self, s: &str) -> Vec<&Cookie> {
491 self.iter().filter(|c| c.name() == s).collect()
492 }
493}
494
495impl CookieJar {
496 pub const FILENAME: &'static str = "cookies";
497
498 pub fn read(path: &Path) -> QuartzResult<Self> {
507 let mut cookies = Self::default();
508 let file = std::fs::read_to_string(path)?;
509 let lines = file.lines();
510
511 for line in lines {
512 if line.is_empty() || line.starts_with('#') {
513 continue;
514 }
515
516 if let Ok(cookie) = Cookie::from_str(line) {
517 if !cookie.expired() {
518 cookies.insert(cookie);
519 }
520 }
521 }
522
523 cookies.path = path.to_path_buf();
524 Ok(cookies)
525 }
526
527 pub fn write(&self) -> std::io::Result<()> {
529 self.write_at(&self.path)
530 }
531
532 pub fn write_at(&self, path: &Path) -> std::io::Result<()> {
534 std::fs::write(path, self.to_string())
535 }
536}
537
/// A cookie path stored as its non-empty `/`-separated segments.
///
/// For example, `/api/v1` is `["api", "v1"]`. The type dereferences to the
/// underlying `Vec<String>`.
#[derive(Debug, Default, Clone)]
pub struct PathAttr(Vec<String>);

impl Deref for PathAttr {
    type Target = Vec<String>;

    fn deref(&self) -> &Vec<String> {
        let PathAttr(segments) = self;
        segments
    }
}

impl DerefMut for PathAttr {
    fn deref_mut(&mut self) -> &mut Vec<String> {
        let PathAttr(segments) = self;
        segments
    }
}
554
555impl FromStr for PathAttr {
556 type Err = Infallible;
557
558 fn from_str(s: &str) -> Result<Self, Self::Err> {
559 Ok(Self::from(s))
560 }
561}
562
563impl From<&str> for PathAttr {
564 fn from(value: &str) -> Self {
582 if let Ok(uri) = hyper::Uri::from_str(value) {
583 let path = uri
584 .path()
585 .split('/')
586 .filter(|v| !v.is_empty())
587 .map(String::from);
588
589 return Self(path.collect::<Vec<String>>());
590 }
591
592 if !value.starts_with('/') {
596 return Self::default();
597 }
598
599 let path: Vec<String> = value
600 .split('/')
601 .filter(|v| !v.is_empty())
602 .map(String::from)
603 .collect();
604
605 Self(path)
606 }
607}
608
609impl ToString for PathAttr {
610 fn to_string(&self) -> String {
619 let mut s = String::from("/");
620
621 s.push_str(&self.join("/"));
622
623 s
624 }
625}
626
627impl PathAttr {
628 #[must_use]
647 pub fn matches<T>(&self, other: T) -> bool
648 where
649 T: Into<PathAttr>,
650 {
651 if self.is_empty() {
652 return true;
653 }
654
655 let other: PathAttr = other.into();
656
657 if self.len() > other.len() {
658 return false;
659 }
660
661 for (idx, p) in self.iter().enumerate() {
662 if p != &other[idx] {
663 return false;
664 }
665 }
666
667 true
668 }
669}
670
#[cfg(test)]
mod test {
    use super::*;

    /// A second `set` for the same name and domain replaces the first.
    #[test]
    fn jar_set_overwrite() {
        let mut jar = CookieJar::default();
        for header in ["foo=bar", "foo=baz"] {
            jar.set("example.com", header);
        }

        let found = jar.find_by_name("foo");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].value(), "baz");
    }

    /// Cookies sharing a name but not a domain are stored side by side.
    #[test]
    fn jar_set_same_name_different_domain() {
        let mut jar = CookieJar::default();
        jar.set("example.com", "mycookie=true");
        jar.set("httpbin.org", "mycookie=false");

        let cookies = jar.find_by_name("mycookie");
        assert_eq!(cookies.len(), 2);

        for (host, expected) in [("example.com", "true"), ("httpbin.org", "false")] {
            let cookie = cookies
                .iter()
                .find(|c| c.domain().matches(host))
                .expect("did not find cookie");
            assert_eq!(cookie.value(), expected);
        }
    }

    /// Re-setting a cookie with a past `Expires` date removes it from the jar.
    #[test]
    fn jar_set_remove() {
        let mut jar = CookieJar::default();

        let foo = jar.set("httpbin.org", "foo=bar");
        let baz = jar.set("httpbin.org", "baz=baz");
        assert_eq!(jar.len(), 2);
        assert!(jar.contains(&foo));

        jar.set("httpbin.org", "foo=; Expires=Sun, 06 Nov 1994 08:49:37 GMT");
        assert_eq!(jar.len(), 1);
        assert!(!jar.contains(&foo));

        jar.set(
            "httpbin.org",
            "baz=bar; Expires=Sun, 06 Nov 1994 08:49:37 GMT",
        );
        assert!(jar.is_empty());
        assert!(!jar.contains(&baz));
    }
}