use std::sync::Arc;
pub use cookie::Cookie;
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
use crate::http::{header, StatusCode};
use crate::{throw, Context, Next, Result};
/// Private scope key under which `cookie_parser` stores parsed cookies
/// (via `store_scoped`) and `CookieGetter::cookie` reads them back (via
/// `load_scoped`). Because the type is private, code outside this module
/// cannot name the scope.
struct CookieScope;
/// Read access to request cookies that were parsed by [`cookie_parser`].
pub trait CookieGetter {
    /// Looks up the cookie called `name` (its decoded name), returning
    /// `None` when no such cookie was parsed.
    fn cookie(&self, name: &str) -> Option<Arc<Cookie<'static>>>;

    /// Like [`cookie`](CookieGetter::cookie), but treats a missing cookie as
    /// an error.
    ///
    /// # Errors
    ///
    /// Fails when the cookie is absent. The `Context` implementation also
    /// sets a `WWW-Authenticate` header and responds with
    /// `401 Unauthorized`.
    fn must_cookie(&mut self, name: &str) -> Result<Arc<Cookie<'static>>>;
}
/// Write access to response cookies.
pub trait CookieSetter {
/// Adds `cookie` to the response.
///
/// The `Context` implementation serializes it with `Cookie::encoded`
/// (percent-encoding name and value) and appends it as a `Set-Cookie`
/// header, keeping any cookies set earlier.
///
/// # Errors
///
/// Fails if the serialized cookie is not a valid header value.
fn set_cookie(&mut self, cookie: Cookie<'_>) -> Result;
}
#[inline]
pub async fn cookie_parser<S>(ctx: &mut Context<S>, next: Next<'_>) -> Result {
if let Some(cookies) = ctx.get(header::COOKIE) {
for cookie in cookies
.split(';')
.map(|cookie| cookie.trim())
.map(Cookie::parse_encoded)
.filter_map(|cookie| cookie.ok())
.map(|cookie| cookie.into_owned())
.collect::<Vec<_>>()
.into_iter()
{
let name = cookie.name().to_string();
ctx.store_scoped(CookieScope, name, cookie);
}
}
next.await
}
impl<S> CookieGetter for Context<S> {
    /// Returns the cookie called `name`, or rejects the request.
    ///
    /// When the cookie is missing, this sets
    /// `WWW-Authenticate: Cookie name="<percent-encoded name>"` on the
    /// response and throws `401 Unauthorized`.
    #[inline]
    fn must_cookie(&mut self, name: &str) -> Result<Arc<Cookie<'static>>> {
        if let Some(found) = self.cookie(name) {
            return Ok(found);
        }

        // Percent-encode every non-alphanumeric byte so the challenge is
        // always a valid header value, whatever the cookie name contains.
        let challenge = format!(
            r#"Cookie name="{}""#,
            utf8_percent_encode(name, NON_ALPHANUMERIC)
        );
        self.resp
            .headers
            .insert(header::WWW_AUTHENTICATE, challenge.parse()?);
        throw!(StatusCode::UNAUTHORIZED)
    }

    /// Reads the cookie called `name` from the scope that `cookie_parser`
    /// fills; `None` if it is absent (or the parser never ran).
    #[inline]
    fn cookie(&self, name: &str) -> Option<Arc<Cookie<'static>>> {
        self.load_scoped::<CookieScope, Cookie<'static>>(name)
            .map(|variable| variable.value())
    }
}
impl<S> CookieSetter for Context<S> {
    /// Appends `cookie` to the response as a `Set-Cookie` header.
    ///
    /// The cookie is serialized with `Cookie::encoded`, so its name and value
    /// are percent-encoded. `append` (not `insert`) keeps previously set
    /// cookies, allowing several per response.
    #[inline]
    fn set_cookie(&mut self, cookie: Cookie<'_>) -> Result {
        let serialized = cookie.encoded().to_string();
        let headers = &mut self.resp.headers;
        headers.append(header::SET_COOKIE, serialized.parse()?);
        Ok(())
    }
}
#[cfg(all(test, feature = "tcp"))]
mod tests {
    use tokio::task::spawn;

    use crate::cookie::{cookie_parser, Cookie};
    use crate::http::header::{COOKIE, WWW_AUTHENTICATE};
    use crate::http::StatusCode;
    use crate::preload::*;
    use crate::{App, Context};

    /// Endpoint requiring the `nick name` cookie to equal `Hexi Lee`.
    async fn must(ctx: &mut Context) -> crate::Result {
        assert_eq!("Hexi Lee", ctx.must_cookie("nick name")?.value());
        Ok(())
    }

    /// Endpoint asserting that no `nick name` cookie is available.
    async fn none(ctx: &mut Context) -> crate::Result {
        assert!(ctx.cookie("nick name").is_none());
        Ok(())
    }

    #[tokio::test]
    async fn parser() -> Result<(), Box<dyn std::error::Error>> {
        // With the parser gate, the percent-encoded cookie is decoded and found.
        let (addr, server) = App::new().gate(cookie_parser).end(must).run()?;
        spawn(server);
        let client = reqwest::Client::new();
        let resp = client
            .get(&format!("http://{}", addr))
            .header(COOKIE, "nick%20name=Hexi%20Lee")
            .send()
            .await?;
        assert_eq!(StatusCode::OK, resp.status());

        // Without the parser gate, cookies are never stored, so `must` fails.
        let (addr, server) = App::new().end(must).run()?;
        spawn(server);
        let resp = client
            .get(&format!("http://{}", addr))
            .header(COOKIE, "nick%20name=Hexi%20Lee")
            .send()
            .await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        Ok(())
    }

    #[tokio::test]
    async fn cookie() -> Result<(), Box<dyn std::error::Error>> {
        let (addr, server) = App::new().end(none).run()?;
        spawn(server);
        let resp = reqwest::get(&format!("http://{}", addr)).await?;
        assert_eq!(StatusCode::OK, resp.status());

        // A missing required cookie yields 401 plus a percent-encoded challenge.
        let (addr, server) = App::new().gate(cookie_parser).end(must).run()?;
        spawn(server);
        let resp = reqwest::get(&format!("http://{}", addr)).await?;
        assert_eq!(StatusCode::UNAUTHORIZED, resp.status());
        assert_eq!(
            r#"Cookie name="nick%20name""#,
            resp.headers()
                .get(WWW_AUTHENTICATE)
                .unwrap()
                .to_str()
                .unwrap()
        );

        let (addr, server) = App::new().gate(cookie_parser).end(must).run()?;
        spawn(server);
        let client = reqwest::Client::new();
        let resp = client
            .get(&format!("http://{}", addr))
            .header(COOKIE, "nick%20name=Hexi%20Lee")
            .send()
            .await?;
        assert_eq!(StatusCode::OK, resp.status());
        Ok(())
    }

    #[tokio::test]
    async fn cookie_action() -> Result<(), Box<dyn std::error::Error>> {
        async fn test(ctx: &mut Context) -> crate::Result {
            assert_eq!("bar baz", ctx.must_cookie("bar baz")?.value());
            assert_eq!("bar foo", ctx.must_cookie("foo baz")?.value());
            Ok(())
        }
        let (addr, server) = App::new().gate(cookie_parser).end(test).run()?;
        spawn(server);
        let client = reqwest::Client::new();
        let resp = client
            .get(&format!("http://{}", addr))
            .header(COOKIE, "bar%20baz=bar%20baz; foo%20baz=bar%20foo")
            .send()
            .await?;
        assert_eq!(StatusCode::OK, resp.status());
        Ok(())
    }

    #[tokio::test]
    async fn set_cookie() -> Result<(), Box<dyn std::error::Error>> {
        async fn test(ctx: &mut Context) -> crate::Result {
            ctx.set_cookie(Cookie::new("bar baz", "bar baz"))?;
            ctx.set_cookie(Cookie::new("bar foo", "foo baz"))?;
            Ok(())
        }
        let (addr, server) = App::new().end(test).run()?;
        spawn(server);
        let resp = reqwest::get(&format!("http://{}", addr)).await?;
        assert_eq!(StatusCode::OK, resp.status());

        // Both cookies are sent, with percent-encoded names and values.
        let cookies: Vec<reqwest::cookie::Cookie> = resp.cookies().collect();
        assert_eq!(2, cookies.len());
        assert_eq!("bar%20baz", cookies[0].name());
        assert_eq!("bar%20baz", cookies[0].value());
        assert_eq!("bar%20foo", cookies[1].name());
        assert_eq!("foo%20baz", cookies[1].value());
        Ok(())
    }
}