use std::sync::Arc;
use crate::error::Result;
use crate::handler::Handler;
use crate::proto::{H1Conn, Limits, Request, Response};
#[cfg(feature = "compress")]
use crate::compress;
#[cfg(feature = "h2")]
use crate::h2::H2Conn;
/// Byte-level transport sitting underneath a session.
///
/// The variant decides how wire bytes are turned into application bytes and
/// back again (see `decrypt` / `encrypt`).
enum Transport {
/// No encryption: bytes are copied through unchanged in both directions.
Plain,
/// TLS-protected stream, boxed. Only compiled in with the `tls` feature.
#[cfg(feature = "tls")]
Tls(Box<crate::tls::TlsStream>),
}
impl Transport {
    /// Turns bytes read off the wire into application plaintext.
    ///
    /// A TLS transport feeds the bytes to its stream and returns whatever
    /// `recv_all` yields, propagating an error from either call. A plain
    /// transport hands back an owned copy of the input.
    fn decrypt(&mut self, wire_in: &[u8]) -> Result<Vec<u8>> {
        #[cfg(feature = "tls")]
        if let Transport::Tls(tls) = self {
            tls.feed(wire_in)?;
            return tls.recv_all();
        }
        Ok(wire_in.to_vec())
    }

    /// Turns application bytes into bytes ready to be written to the wire.
    ///
    /// A TLS transport passes the bytes to `send` and returns the result of
    /// `pop_all`, propagating an error from either call. A plain transport
    /// hands back an owned copy of the input.
    fn encrypt(&mut self, app: &[u8]) -> Result<Vec<u8>> {
        #[cfg(feature = "tls")]
        if let Transport::Tls(tls) = self {
            tls.send(app)?;
            return tls.pop_all();
        }
        Ok(app.to_vec())
    }

    /// Reports whether a TLS handshake is still in progress.
    ///
    /// Always `false` for a plain transport.
    fn handshaking(&self) -> bool {
        #[cfg(feature = "tls")]
        if let Transport::Tls(tls) = self {
            return !tls.is_handshake_complete();
        }
        false
    }

    /// Returns the ALPN protocol reported by the TLS stream, if any.
    ///
    /// Always `None` for a plain transport.
    // Only the `h2`-enabled engine selection reads this, so other feature
    // combinations would otherwise warn about dead code.
    #[allow(unused)]
    fn alpn(&self) -> Option<Vec<u8>> {
        #[cfg(feature = "tls")]
        if let Transport::Tls(tls) = self {
            return tls.alpn_protocol();
        }
        None
    }
}
/// Protocol engine that parses requests and serializes responses for a session.
enum Engine {
/// Connection state backed by `H1Conn`.
H1(H1Conn),
/// Connection state backed by `H2Conn`, boxed. Only compiled in with the
/// `h2` feature; chosen when the transport's ALPN value is `h2`.
#[cfg(feature = "h2")]
H2(Box<H2Conn>),
}
/// Configuration used by a `Session` to build its engine and post-process
/// responses.
#[derive(Clone)]
pub struct SessionConfig {
/// Application handler invoked once for every request polled from the engine.
pub handler: Arc<dyn Handler>,
/// Limits handed to the protocol engine when it is constructed.
pub limits: Limits,
/// Server name passed to the protocol engine; `SessionConfig::new` defaults
/// it to `httpsd/<crate version>`.
pub server_name: Option<String>,
/// Value for the `Strict-Transport-Security` header. Applied only on
/// non-plain transports and only when the response does not already set it.
pub hsts: Option<String>,
/// Value for the `Alt-Svc` header. Applied on any transport when the
/// response does not already set it.
pub alt_svc: Option<String>,
/// Options forwarded to `compress::compress_response` (requires the
/// `compress` feature).
#[cfg(feature = "compress")]
pub compression: compress::Options,
}
impl SessionConfig {
    /// Builds a configuration around `handler` with default settings.
    ///
    /// Defaults: `Limits::default()`, a server name of `httpsd/<crate version>`,
    /// no HSTS value, no Alt-Svc value and (with the `compress` feature)
    /// `compress::Options::default()`.
    pub fn new(handler: Arc<dyn Handler>) -> SessionConfig {
        // Assembled at compile time from the crate version.
        let banner: &'static str = concat!("httpsd/", env!("CARGO_PKG_VERSION"));

        Self {
            limits: Limits::default(),
            server_name: Some(String::from(banner)),
            hsts: None,
            alt_svc: None,
            #[cfg(feature = "compress")]
            compression: compress::Options::default(),
            handler,
        }
    }
}
/// One connection's state: a transport plus the protocol engine running on it.
pub struct Session {
// `None` until an engine has been selected; TLS sessions start without one
// and pick it in `received` once the handshake is no longer in progress.
engine: Option<Engine>,
// Plain or TLS byte transport used by `received` and `to_send`.
transport: Transport,
// Handler, limits and header settings for this session.
cfg: SessionConfig,
}
impl Session {
    /// Creates a session over an unencrypted transport.
    ///
    /// The `H1` engine is installed immediately.
    pub fn plain(cfg: SessionConfig) -> Session {
        let engine = Engine::H1(Self::new_h1(&cfg));
        Session {
            engine: Some(engine),
            transport: Transport::Plain,
            cfg,
        }
    }

    /// Creates a session over the given TLS stream.
    ///
    /// No engine is installed yet; `received` selects one as soon as the
    /// transport reports that the handshake is no longer in progress.
    #[cfg(feature = "tls")]
    pub fn tls(cfg: SessionConfig, stream: crate::tls::TlsStream) -> Session {
        Session {
            engine: None,
            transport: Transport::Tls(Box::new(stream)),
            cfg,
        }
    }

    /// Builds an `H1Conn` from the configured limits and server name.
    fn new_h1(cfg: &SessionConfig) -> H1Conn {
        let mut conn = H1Conn::new(cfg.limits);
        conn.set_server_name(cfg.server_name.clone());
        conn
    }

    /// Feeds bytes read from the wire into the session.
    ///
    /// The bytes pass through the transport first. If no engine exists yet and
    /// the handshake has finished, an engine is selected. Any resulting
    /// plaintext is then driven through the engine, invoking the handler for
    /// every complete request.
    ///
    /// # Errors
    ///
    /// Returns an error when the transport fails to decrypt the input or when
    /// the engine reports an error while polling for requests.
    pub fn received(&mut self, wire_in: &[u8]) -> Result<()> {
        let plaintext = self.transport.decrypt(wire_in)?;
        if self.engine.is_none() && !self.transport.handshaking() {
            self.select_engine();
        }
        // Nothing to drive: no application bytes yet, or still handshaking.
        if plaintext.is_empty() || self.engine.is_none() {
            return Ok(());
        }
        self.drive(&plaintext)
    }

    /// Installs the engine for a TLS-capable build.
    ///
    /// With the `h2` feature, an ALPN value of exactly `h2` selects the `H2`
    /// engine; every other case falls back to `H1`.
    #[cfg(feature = "tls")]
    fn select_engine(&mut self) {
        #[cfg(feature = "h2")]
        if self.transport.alpn().as_deref() == Some(b"h2") {
            self.engine = Some(Engine::H2(Box::new(H2Conn::new(
                self.cfg.limits,
                self.cfg.server_name.clone(),
            ))));
            return;
        }
        self.engine = Some(Engine::H1(Self::new_h1(&self.cfg)));
    }

    /// Installs the engine for a build without TLS: always `H1`.
    #[cfg(not(feature = "tls"))]
    fn select_engine(&mut self) {
        self.engine = Some(Engine::H1(Self::new_h1(&self.cfg)));
    }

    /// Pushes plaintext into the engine and answers every request it yields.
    ///
    /// Does nothing when no engine has been selected.
    ///
    /// # Errors
    ///
    /// Propagates an error returned by the `H1` engine's `poll_request`.
    /// Responses queued for requests handled before the error stay queued in
    /// the engine.
    fn drive(&mut self, plaintext: &[u8]) -> Result<()> {
        let secure = !matches!(self.transport, Transport::Plain);
        let cfg = &self.cfg;
        // `received` only calls this with an engine present; treat a missing
        // engine the same way it does ("nothing to do") instead of panicking.
        let Some(engine) = self.engine.as_mut() else {
            return Ok(());
        };
        match engine {
            Engine::H1(conn) => {
                conn.feed(plaintext);
                // `?` surfaces engine errors to the caller; matching on
                // `Ok(Some(_))` would end the loop and silently drop them.
                while let Some(req) = conn.poll_request()? {
                    let resp = Self::run_handler(cfg, &req, secure);
                    conn.respond(resp);
                }
            }
            #[cfg(feature = "h2")]
            Engine::H2(conn) => {
                conn.received(plaintext);
                while let Some((sid, req)) = conn.poll_request() {
                    let resp = Self::run_handler(cfg, &req, secure);
                    conn.respond(sid, resp);
                }
            }
        }
        Ok(())
    }

    /// Produces the final response for one request.
    ///
    /// Runs the configured handler, then (with the `compress` feature) the
    /// compression pass, then `apply_edge_headers`.
    fn run_handler(cfg: &SessionConfig, req: &Request, secure: bool) -> Response {
        let resp = cfg.handler.handle(req);
        #[cfg(feature = "compress")]
        let resp = compress::compress_response(req, resp, &cfg.compression);
        apply_edge_headers(cfg, resp, secure)
    }

    /// Returns the bytes that should be written to the wire next.
    ///
    /// Pending engine output is taken (empty when no engine exists) and passed
    /// through the transport's `encrypt`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the transport's `encrypt`.
    pub fn to_send(&mut self) -> Result<Vec<u8>> {
        let app = match self.engine.as_mut() {
            Some(Engine::H1(conn)) => conn.take_out(),
            #[cfg(feature = "h2")]
            Some(Engine::H2(conn)) => conn.take_out(),
            None => Vec::new(),
        };
        // Called even with empty `app`. NOTE(review): presumably this is what
        // lets a TLS stream flush handshake records; confirm against TlsStream.
        self.transport.encrypt(&app)
    }

    /// Reports whether the engine wants the connection closed.
    ///
    /// Always `false` while no engine has been selected.
    pub fn wants_close(&self) -> bool {
        match self.engine.as_ref() {
            Some(Engine::H1(conn)) => conn.wants_close(),
            #[cfg(feature = "h2")]
            Some(Engine::H2(conn)) => conn.wants_close(),
            None => false,
        }
    }

    /// Reports whether the transport is still performing its handshake.
    pub fn handshaking(&self) -> bool {
        self.transport.handshaking()
    }
}
pub(crate) fn apply_edge_headers(
cfg: &SessionConfig,
mut resp: Response,
secure: bool,
) -> Response {
let h = resp.headers_mut();
if secure && let Some(value) = &cfg.hsts {
h.set_if_absent("Strict-Transport-Security", value.clone());
}
if let Some(value) = &cfg.alt_svc {
h.set_if_absent("Alt-Svc", value.clone());
}
resp
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::StatusCode;

    /// Header name used for lookups in the assertions below.
    const HSTS_HEADER: &str = "strict-transport-security";
    /// HSTS value configured by `cfg()`.
    const HSTS_POLICY: &str = "max-age=31536000";

    /// Default configuration around a handler that always answers `OK`.
    fn bare_cfg() -> SessionConfig {
        SessionConfig::new(Arc::new(|_: &Request| Response::status(StatusCode::OK)))
    }

    /// Same as `bare_cfg`, with an HSTS value configured.
    fn cfg() -> SessionConfig {
        SessionConfig {
            hsts: Some(HSTS_POLICY.into()),
            ..bare_cfg()
        }
    }

    /// The HSTS header appears when `secure` is true and is omitted otherwise.
    #[test]
    fn hsts_added_only_on_secure() {
        let with_hsts = cfg();

        let over_tls = apply_edge_headers(&with_hsts, Response::status(StatusCode::OK), true);
        assert_eq!(over_tls.headers().get(HSTS_HEADER), Some(HSTS_POLICY));

        let over_plain = apply_edge_headers(&with_hsts, Response::status(StatusCode::OK), false);
        assert!(over_plain.headers().get(HSTS_HEADER).is_none());
    }

    /// Without a configured value no HSTS header is added, even when secure.
    #[test]
    fn hsts_absent_when_unset() {
        let without_hsts = bare_cfg();
        let resp = apply_edge_headers(&without_hsts, Response::status(StatusCode::OK), true);
        assert!(resp.headers().get(HSTS_HEADER).is_none());
    }
}