1use std::collections::HashMap;
8use std::net::TcpStream;
9use std::sync::{Arc, Mutex};
10
11use crate::header_map::HeaderMap;
12use crate::request::{FetchError, Request};
13use crate::response::Response;
14use crate::KEEP_ALIVE_TIMEOUT;
15
16#[derive(Default, Clone)]
19pub struct Client {
20    connection_pool: Arc<Mutex<ConnectionPool>>,
21    headers: HeaderMap,
22}
23
24impl Client {
25    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
32        self.headers.insert(name.into(), value.into());
33        self
34    }
35
36    pub fn fetch(&mut self, mut request: Request) -> Result<Response, FetchError> {
38        for (name, value) in self.headers.iter() {
40            request = request.header(name, value);
41        }
42
43        let addr = format!(
45            "{}:{}",
46            request.url.host().expect("No host in URL"),
47            request.url.port().unwrap_or(80)
48        );
49        let mut stream = self
50            .connection_pool
51            .lock()
52            .expect("Can't lock connection pool")
53            .take_connection(&addr)
54            .ok_or(FetchError)?;
55        stream
56            .set_read_timeout(Some(KEEP_ALIVE_TIMEOUT))
57            .map_err(|_| FetchError)?;
58
59        request.write_to_stream(&mut stream, true);
61        let res = Response::read_from_stream(&mut stream).map_err(|_| FetchError)?;
62
63        self.connection_pool
65            .lock()
66            .expect("Can't lock connection pool")
67            .return_connection(&addr, stream);
68        Ok(res)
69    }
70}
71
/// Idle TCP connections, grouped by the address string they were opened with.
#[derive(Default)]
struct ConnectionPool {
    /// Idle streams per address; the most recently returned stream is reused first.
    connections: HashMap<String, Vec<TcpStream>>,
}

impl ConnectionPool {
    /// Takes an idle connection to `addr`, or opens a new one if none is pooled.
    ///
    /// Returns `None` when no idle connection exists and `TcpStream::connect(addr)` fails.
    /// The caller owns the stream until it hands it back with
    /// [`return_connection`](Self::return_connection).
    fn take_connection(&mut self, addr: &str) -> Option<TcpStream> {
        // `Vec::pop` yields the most recently returned stream (LIFO order).
        if let Some(conn) = self.connections.get_mut(addr).and_then(Vec::pop) {
            return Some(conn);
        }
        // The connect error itself is discarded; callers only observe `None`.
        TcpStream::connect(addr).ok()
    }

    /// Puts `conn` back so a later `take_connection(addr)` can reuse it.
    fn return_connection(&mut self, addr: &str, conn: TcpStream) {
        // Create the per-address list on demand so a returned connection is never silently
        // dropped just because no list existed yet for `addr`.
        self.connections
            .entry(addr.to_string())
            .or_default()
            .push(conn);
    }
}
109
#[cfg(test)]
mod test {
    use std::io::{Read, Write};
    use std::net::{Ipv4Addr, TcpListener};
    use std::thread;

    use super::*;

    /// Sends ten requests through one `Client` to a server that accepts only a single
    /// connection, so every request after the first has to reuse the pooled stream.
    #[test]
    fn test_client_multiple_requests() {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
        let server_addr = listener.local_addr().unwrap();
        thread::spawn(move || {
            // Accept exactly once: if the client opened a second connection nobody would
            // answer it, the client's read would time out, and the test would fail.
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = [0; 512];
            let mut head = Vec::new();
            loop {
                // Read the whole request head before replying so each request gets exactly
                // one response, even if the client writes it in several pieces.
                head.clear();
                while !head.windows(4).any(|w| w == b"\r\n\r\n") {
                    match stream.read(&mut buf) {
                        // Client hung up or the socket failed: stop serving instead of
                        // spinning on EOF.
                        Ok(0) | Err(_) => return,
                        Ok(n) => head.extend_from_slice(&buf[..n]),
                    }
                }
                // Advertise keep-alive: the test relies on the connection staying open.
                let written = stream.write_all(
                    b"HTTP/1.1 200 OK\r\nContent-Length: 4\r\nConnection: keep-alive\r\n\r\ntest",
                );
                if written.is_err() {
                    return;
                }
            }
        });

        let mut client = Client::new();
        for _ in 0..10 {
            client
                .fetch(Request::get(format!("http://{}/", server_addr)))
                .unwrap();
        }
    }
}