use std::collections::HashMap;
use std::net::TcpStream;
use std::sync::{Arc, Mutex};
use crate::header_map::HeaderMap;
use crate::request::{FetchError, Request};
use crate::response::Response;
use crate::KEEP_ALIVE_TIMEOUT;
/// HTTP client that adds a fixed set of default headers to every request and
/// keeps idle connections in a pool that all clones of the client share.
#[derive(Clone, Default)]
pub struct Client {
    /// Idle connections; cloning the client clones the `Arc`, so clones share it.
    connection_pool: Arc<Mutex<ConnectionPool>>,
    /// Headers applied to every request sent through `fetch`.
    headers: HeaderMap,
}
impl Client {
    /// Creates a client with no default headers and an empty connection pool.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a header that [`Client::fetch`] adds to every request.
    ///
    /// Consumes and returns the client so calls can be chained.
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(name.into(), value.into());
        self
    }

    /// Sends `request` and reads its response, reusing an idle pooled connection to
    /// the same `host:port` when one is available (port 80 if the URL has none).
    ///
    /// The client's default headers are applied to the request before it is sent.
    /// After a response has been read successfully the connection is returned to
    /// the pool; if an error is returned, the stream is dropped rather than pooled.
    ///
    /// NOTE(review): a pooled connection the server has since closed is reused
    /// as-is; there is no retry on a fresh connection.
    ///
    /// # Errors
    ///
    /// Returns [`FetchError`] if the URL has no host, no connection can be opened,
    /// the read timeout cannot be set, or the response cannot be read.
    ///
    /// # Panics
    ///
    /// Panics if the connection-pool mutex is poisoned.
    pub fn fetch(&mut self, mut request: Request) -> Result<Response, FetchError> {
        for (name, value) in self.headers.iter() {
            request = request.header(name, value);
        }
        let addr = format!(
            "{}:{}",
            request.url.host().ok_or(FetchError)?,
            request.url.port().unwrap_or(80)
        );
        // Bind the pool lookup to a local so the MutexGuard temporary is dropped at
        // the end of this statement (a `match` on it would keep the lock held). The
        // connect below then runs without the lock, so a slow connect does not block
        // other clones of this client that share the pool.
        let pooled = self
            .connection_pool
            .lock()
            .expect("Can't lock connection pool")
            .take_connection(&addr);
        let mut stream = match pooled {
            Some(stream) => stream,
            None => TcpStream::connect(addr.as_str()).map_err(|_| FetchError)?,
        };
        // Bound how long we wait for the response on this stream.
        stream
            .set_read_timeout(Some(KEEP_ALIVE_TIMEOUT))
            .map_err(|_| FetchError)?;
        request.write_to_stream(&mut stream, true);
        let res = Response::read_from_stream(&mut stream).map_err(|_| FetchError)?;
        self.connection_pool
            .lock()
            .expect("Can't lock connection pool")
            .return_connection(&addr, stream);
        Ok(res)
    }
}

/// Idle TCP connections, keyed by the `"host:port"` string they were opened to.
#[derive(Default)]
struct ConnectionPool {
    connections: HashMap<String, Vec<TcpStream>>,
}

impl ConnectionPool {
    /// Removes and returns the most recently returned idle connection to `addr`,
    /// or `None` if none is pooled.
    ///
    /// Never opens a connection itself, so no network I/O happens while the
    /// caller holds the pool's lock.
    fn take_connection(&mut self, addr: &str) -> Option<TcpStream> {
        self.connections.get_mut(addr)?.pop()
    }

    /// Stores `conn` as an idle connection to `addr` for a later `take_connection`.
    ///
    /// Creates the entry for `addr` if needed, so the stream is never silently lost.
    fn return_connection(&mut self, addr: &str, conn: TcpStream) {
        self.connections
            .entry(addr.to_owned())
            .or_default()
            .push(conn);
    }
}
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::{Ipv4Addr, TcpListener, TcpStream};
    use std::thread;

    use super::*;

    /// Canned keep-alive response; the mock server answers every request with it.
    const RESPONSE: &[u8] =
        b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: keep-alive\r\n\r\ntest";

    /// Reads one request head: everything up to and including the blank line that
    /// ends the headers. Returns `false` once the peer has closed the connection or
    /// the read fails, which is the mock server's signal to stop.
    ///
    /// Assumes the GET requests sent by the test carry no body.
    fn read_request_head(stream: &mut TcpStream) -> bool {
        let mut head = Vec::new();
        let mut byte = [0u8; 1];
        while !head.ends_with(b"\r\n\r\n") {
            match stream.read(&mut byte) {
                Ok(0) | Err(_) => return false,
                Ok(_) => head.push(byte[0]),
            }
        }
        true
    }

    #[test]
    fn test_client_multiple_requests() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        // The server accepts exactly one connection, so every request after the
        // first can only be answered if the client reuses its pooled connection.
        let server = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut served = 0;
            while read_request_head(&mut stream) {
                stream.write_all(RESPONSE).unwrap();
                served += 1;
            }
            served
        });
        let mut client = Client::new();
        for _ in 0..10 {
            client
                .fetch(Request::get(format!("http://{server_addr}/")))
                .unwrap();
        }
        // Dropping the client closes its pooled connection, ending the server loop.
        drop(client);
        assert_eq!(server.join().unwrap(), 10);
    }
}