use std::{
collections::HashMap,
io::{self, BufReader, BufWriter, Write},
net::TcpStream,
};
use headers::HeaderMapExt;
use http::{header::HOST, uri::Authority, StatusCode};
use thiserror::Error;
use crate::{request, response, Body, Connection, HttpBody};
#[derive(Debug, Error)]
pub enum RequestError {
#[error("invalid uri")]
InvalidUri,
#[error("unsupported scheme")]
UnsupportedScheme,
#[error("unsupported http version: {0}")]
UnsupportedHttpVersion(u8),
#[error("io error")]
Io(#[from] io::Error),
#[error("invalid request")]
InvalidRequest(#[from] Box<RequestError>),
}
/// A blocking HTTP/1.x client that keeps persistent connections for reuse.
///
/// Connections the server allows to stay open are stored per URI authority and
/// handed out again for the next request to the same authority.
#[derive(Debug, Default)]
pub struct Client {
    /// Idle, reusable connections keyed by the authority they were opened to.
    connections: HashMap<Authority, Connection>,
}

impl Client {
    /// Creates a client with no pooled connections.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sends `req` and returns the server's response.
    ///
    /// A pooled connection for the request's authority is reused when available;
    /// otherwise a new TCP connection is opened (port 80 unless the URI names one).
    /// The `Host` header is set from the URI authority. After the exchange the
    /// connection is dropped, returned to the pool, or — for an upgraded
    /// connection — attached to the response's extensions.
    ///
    /// # Errors
    ///
    /// - [`RequestError::InvalidUri`] if the URI has no authority, or its host/port
    ///   cannot form a valid `Host` header value.
    /// - [`RequestError::UnsupportedScheme`] if the URI has a scheme other than
    ///   `http` (only plain TCP is implemented here).
    /// - [`RequestError::Io`] if connecting, writing the request, or reading the
    ///   response fails.
    pub fn request<B: HttpBody>(
        &mut self,
        mut req: http::Request<B>,
    ) -> Result<http::Response<Body>, RequestError> {
        let authority = req
            .uri()
            .authority()
            .ok_or(RequestError::InvalidUri)?
            .clone();

        // Only plain-text HTTP is implemented; refusing other schemes (e.g. `https`)
        // avoids sending their traffic unencrypted. A URI without a scheme is
        // treated as `http`.
        if let Some(scheme) = req.uri().scheme_str() {
            if !scheme.eq_ignore_ascii_case("http") {
                return Err(RequestError::UnsupportedScheme);
            }
        }

        let host = authority.host();
        let explicit_port = authority.port_u16();
        let port = explicit_port.unwrap_or(80);

        // RFC 9110 §7.2: `Host` must match the URI authority (minus userinfo), so an
        // explicit port is included; a bare host can misroute requests on servers
        // that distinguish virtual hosts by port. Validated before connecting so a
        // bad value fails fast instead of panicking.
        let host_header = match explicit_port {
            Some(p) => format!("{host}:{p}"),
            None => host.to_owned(),
        };
        let host_value =
            http::HeaderValue::from_str(&host_header).map_err(|_| RequestError::InvalidUri)?;
        req.headers_mut().insert(HOST, host_value);

        let connection = match self.connections.remove(&authority) {
            Some(pooled) => pooled,
            None => TcpStream::connect(format!("{host}:{port}"))?.into(),
        };

        let (outcome, mut res) = send_request(connection, req)?;
        match outcome {
            ConnectionOutcome::Close => {}
            // The caller takes over the upgraded connection via the response extensions.
            ConnectionOutcome::Upgrade(conn) => {
                res.extensions_mut().insert(conn);
            }
            // NOTE(review): the connection is pooled before the caller reads the
            // response body; if `Body` streams from this same socket, sending another
            // request to this authority before draining the body could mix up
            // responses — confirm how `Body` is read.
            ConnectionOutcome::KeepAlive(conn) => {
                self.connections.insert(authority, conn);
            }
        }
        Ok(res)
    }
}
/// What became of a connection after one request/response exchange.
#[derive(Debug)]
pub enum ConnectionOutcome {
    /// The connection must not be reused; no handle is returned.
    Close,
    /// The connection may carry further requests.
    KeepAlive(Connection),
    /// The server answered `101 Switching Protocols`; the connection is handed over as-is.
    Upgrade(Connection),
}

impl ConnectionOutcome {
    /// Returns `true` if this outcome is [`ConnectionOutcome::Close`].
    pub fn closed(&self) -> bool {
        match self {
            ConnectionOutcome::Close => true,
            ConnectionOutcome::KeepAlive(_) | ConnectionOutcome::Upgrade(_) => false,
        }
    }

    /// Returns the connection carried by `KeepAlive` or `Upgrade`.
    ///
    /// # Panics
    ///
    /// Panics with `"Connection closed"` if the outcome is `Close`.
    pub fn unwrap(self) -> Connection {
        match self.into_inner() {
            Ok(conn) => conn,
            Err(_) => panic!("Connection closed"),
        }
    }

    /// Returns the connection carried by `KeepAlive` or `Upgrade`; for `Close`, the
    /// outcome itself is given back as the `Err` value.
    pub fn into_inner(self) -> Result<Connection, ConnectionOutcome> {
        match self {
            ConnectionOutcome::KeepAlive(conn) | ConnectionOutcome::Upgrade(conn) => Ok(conn),
            closed @ ConnectionOutcome::Close => Err(closed),
        }
    }
}
/// Writes `req` to `connection`, parses the response, and decides whether the
/// connection may be reused.
///
/// The returned [`ConnectionOutcome`] is:
/// - `Close` when either side signalled that the connection will not persist: the
///   response carries `Connection: close`, the request carried `Connection: close`,
///   or the request was HTTP/1.0 (or older) without `Connection: keep-alive`;
/// - `Upgrade` when the server answered `101 Switching Protocols`;
/// - `KeepAlive` otherwise.
///
/// # Errors
///
/// Returns any I/O error raised while writing or flushing the request, or while
/// recovering the connection from the buffered writer. Failures from
/// `response::parse_response` are wrapped in an `io::ErrorKind::Other` error.
pub fn send_request<C, B>(
    connection: C,
    req: http::Request<B>,
) -> io::Result<(ConnectionOutcome, http::Response<Body>)>
where
    C: Into<Connection>,
    B: HttpBody,
{
    // Decided before `req` is consumed by `write_request`. A client that sends
    // `close` must not reuse the connection (RFC 9112 §9.6), and an HTTP/1.0
    // connection only persists when `keep-alive` was requested (RFC 9112 §9.3).
    let request_closes = has_connection_option(req.headers(), "close")
        || (is_pre_http_11(req.version()) && !has_connection_option(req.headers(), "keep-alive"));

    let conn = connection.into();
    // Reading and writing each go through their own clone of the connection handle.
    let reader = BufReader::new(conn.clone());
    let mut writer = BufWriter::new(conn);
    request::write_request(req, &mut writer)?;
    writer.flush()?;

    let res = response::parse_response(reader)
        .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;

    let outcome = if request_closes || has_connection_option(res.headers(), "close") {
        ConnectionOutcome::Close
    } else if res.status() == StatusCode::SWITCHING_PROTOCOLS {
        ConnectionOutcome::Upgrade(writer.into_inner()?)
    } else {
        ConnectionOutcome::KeepAlive(writer.into_inner()?)
    };
    Ok((outcome, res))
}

/// Returns `true` if `map` has a `Connection` header listing `option`.
fn has_connection_option(map: &http::HeaderMap, option: &str) -> bool {
    matches!(
        map.typed_get::<headers::Connection>(),
        Some(conn) if conn.contains(option)
    )
}

/// Returns `true` for HTTP versions whose connections are not persistent by default.
fn is_pre_http_11(version: http::Version) -> bool {
    version == http::Version::HTTP_09 || version == http::Version::HTTP_10
}
#[cfg(test)]
mod tests {
    use std::{
        io::Cursor,
        net::{TcpListener, TcpStream},
        thread,
    };

    use http::{Request, Version};

    use crate::Server;

    use super::*;

    /// Binds a listener on an OS-assigned localhost port and returns it with that port.
    fn bind_localhost() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Starts a background server that answers every request with the request's own
    /// body, and returns the port it listens on.
    fn spawn_echo_server() -> u16 {
        let (listener, port) = bind_localhost();
        thread::spawn(move || {
            Server::from(listener)
                .serve(|req: Request<_>| http::Response::builder().body(req.into_body()))
                .ok()
        });
        port
    }

    /// Opens a raw TCP connection to the local test server listening on `port`.
    fn connect_local(port: u16) -> TcpStream {
        TcpStream::connect(("127.0.0.1", port)).unwrap()
    }

    #[test]
    fn test_client() {
        let uri = format!("http://127.0.0.1:{}", spawn_echo_server());
        let mut client = Client::new();

        // Two requests to the same authority; the second may reuse a pooled connection.
        let hello = http::Request::builder()
            .uri(&uri)
            .method("POST")
            .body("Hello world")
            .unwrap();
        let response = client.request(hello).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"Hello world");

        let bye = http::Request::builder()
            .uri(&uri)
            .method("POST")
            .body("Bye world")
            .unwrap();
        let response = client.request(bye).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"Bye world");
    }

    #[test]
    fn test_send_request() {
        let stream = connect_local(spawn_echo_server());

        let hello = http::Request::builder().body("Hello world").unwrap();
        let (outcome, response) = send_request(stream, hello).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"Hello world");

        let bye = http::Request::builder().body("Bye world").unwrap();
        let (outcome, response) = send_request(outcome.unwrap(), bye).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"Bye world");

        // Request with an empty body.
        let empty = http::Request::builder().body(()).unwrap();
        let (outcome, response) = send_request(outcome.unwrap(), empty).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"");

        // Chunked request body assembled from two slices.
        let chunked = http::Request::builder()
            .header("transfer-encoding", "chunked")
            .body(Body::from_iter(vec![&b"lol"[..], &b"wut"[..]]))
            .unwrap();
        let (_outcome, response) = send_request(outcome.unwrap(), chunked).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"lolwut");
    }

    #[test]
    fn correctly_handles_closing_connections() {
        let (listener, port) = bind_localhost();
        thread::spawn(move || {
            Server::from(listener)
                .serve(|_req| {
                    http::Response::builder()
                        .header("connection", "close")
                        .body(Body::from_reader(Cursor::new(b"lolwut"), None))
                })
                .ok();
        });

        let request = http::Request::builder().body(()).unwrap();
        let (outcome, response) = send_request(connect_local(port), request).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"lolwut");
        assert!(outcome.closed());
    }

    #[test]
    fn keep_http_10_connection_alive_when_asked_to() {
        let (listener, port) = bind_localhost();
        thread::spawn(move || {
            Server::from(listener)
                .serve(|_req| http::Response::builder().body("lolwut"))
                .ok();
        });

        // HTTP/1.0 with an explicit keep-alive: the connection should stay usable.
        let keep_alive = http::Request::builder()
            .version(Version::HTTP_10)
            .header("connection", "keep-alive")
            .body(())
            .unwrap();
        let (outcome, response) = send_request(connect_local(port), keep_alive).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"lolwut");
        assert!(matches!(outcome, ConnectionOutcome::KeepAlive(_)));

        // Plain HTTP/1.0 without keep-alive: the connection should end up closed.
        let plain = http::Request::builder()
            .version(Version::HTTP_10)
            .body(())
            .unwrap();
        let (outcome, response) = send_request(outcome.unwrap(), plain).unwrap();
        assert_eq!(response.into_body().into_bytes().unwrap(), b"lolwut");
        assert!(matches!(outcome, ConnectionOutcome::Close));
    }
}