use std::io::{Read, Write};
use std::net::TcpStream;
use crate::error::Error;
use crate::headers::Headers;
use crate::method::Method;
use crate::request::Request;
use crate::response::Response;
use crate::url::Url;
pub fn exchange(request: &Request) -> Result<Response, Error> {
let url = request.url();
if url.scheme() == "http" {
let raw = transmit(request, url)?;
parse_response(&raw)
} else {
Err(Error::UnsupportedScheme {
scheme: url.scheme().to_owned(),
})
}
}
/// Sends the serialized `request` to `url`'s host and port over a new TCP
/// connection. Returns every byte received until the peer closes it.
///
/// The request always carries `Connection: close`, so end-of-stream marks the
/// end of the response.
// NOTE(review): no read timeout is configured on the stream, so a server that
// never closes the connection blocks this call indefinitely.
fn transmit(request: &Request, url: &Url) -> Result<Vec<u8>, Error> {
    let stream = TcpStream::connect(format!("{}:{}", url.host(), url.port()))?;
    let wire = build_wire_request(request, url);
    write_all(&stream, &wire)?;
    read_to_end(stream)
}
/// Serializes `request` into the exact bytes written to the socket.
///
/// The output has these parts, in order:
/// - the request line;
/// - the fixed `Host`, `User-Agent` and `Connection` headers;
/// - an optional `Content-Length` header;
/// - the caller's remaining headers;
/// - the blank line that ends the head;
/// - the request body.
fn build_wire_request(request: &Request, url: &Url) -> Vec<u8> {
    let mut head = format!(
        "{} {} HTTP/1.1\r\n",
        request.method().as_str(),
        url.request_target()
    );
    head.push_str(&host_header(url));
    head.push_str("User-Agent: net-cat/0.1\r\n");
    head.push_str("Connection: close\r\n");
    head.push_str(&content_length_header(request));
    head.push_str(&format_headers(request.headers()));
    // Blank line terminating the header section.
    head.push_str("\r\n");

    let mut wire = head.into_bytes();
    // Extending with an empty body is a no-op, so no emptiness check is needed.
    wire.extend_from_slice(request.body());
    wire
}
/// Builds the `Host` header line for `url`.
///
/// The port is omitted when it is the scheme's default (80 for `http`, 443
/// for `https`) and included as `host:port` otherwise.
fn host_header(url: &Url) -> String {
    let host = url.host();
    let port = url.port();
    match (url.scheme(), port) {
        ("http", 80) | ("https", 443) => format!("Host: {host}\r\n"),
        _ => format!("Host: {host}:{port}\r\n"),
    }
}
/// Builds the `Content-Length` header line for `request`, or an empty string
/// when no such header should be sent.
///
/// POST, PUT, DELETE and PATCH always carry the header, possibly as
/// `Content-Length: 0`.
///
/// GET, HEAD and OPTIONS carry it only when the request actually has a body.
/// The request body is always appended after the head. Without a declared
/// length, an HTTP/1.1 server treats a request as having no content and would
/// read those bytes as the start of the next request (RFC 9112 §6.3).
fn content_length_header(request: &Request) -> String {
    let body_len = request.body().len();
    let method_defines_body = match request.method() {
        Method::Get | Method::Head | Method::Options => false,
        Method::Post | Method::Put | Method::Delete | Method::Patch => true,
    };
    if method_defines_body || body_len > 0 {
        format!("Content-Length: {body_len}\r\n")
    } else {
        String::new()
    }
}
/// Renders the caller-supplied headers as `Name: value\r\n` lines.
///
/// This module always emits `Host`, `User-Agent`, `Connection` and
/// `Content-Length` itself. Caller-supplied headers with those names are
/// skipped (compared case-insensitively), so the request never carries
/// duplicates of them.
fn format_headers(headers: &Headers) -> String {
    // NOTE(review): names and values are written verbatim; a CR or LF inside
    // one would inject extra lines into the request head. Confirm that
    // `Headers` rejects such input on insertion.
    headers
        .iter()
        .filter(|(name, _)| {
            !matches!(
                name.to_ascii_lowercase().as_str(),
                "host" | "user-agent" | "connection" | "content-length"
            )
        })
        // Each line is formatted on its own and concatenated once, which keeps
        // this linear. Re-formatting an accumulator per header would copy the
        // whole prefix every time.
        .map(|(name, value)| format!("{name}: {value}\r\n"))
        .collect()
}
/// Writes all of `bytes` to `stream`, then flushes it.
///
/// `Write` is implemented for `&TcpStream`, so a mutable copy of the shared
/// reference is enough to drive the write. Flushing is skipped if the write
/// fails.
fn write_all(stream: &TcpStream, bytes: &[u8]) -> Result<(), Error> {
    let mut writer = stream;
    writer.write_all(bytes).and_then(|()| writer.flush())?;
    Ok(())
}
fn read_to_end(stream: TcpStream) -> Result<Vec<u8>, Error> {
let mut owned = stream;
let mut buffer = Vec::new();
owned.read_to_end(&mut buffer)?;
Ok(buffer)
}
fn parse_response(raw: &[u8]) -> Result<Response, Error> {
let split = find_double_crlf(raw).ok_or_else(|| Error::InvalidStatusLine {
text: String::from_utf8_lossy(raw).into_owned(),
})?;
let head = raw.get(..split).unwrap_or(&[]);
let body_start = split + 4;
let body: Vec<u8> = raw.get(body_start..).unwrap_or(&[]).to_vec();
let head_text = std::str::from_utf8(head).map_err(|_| Error::InvalidStatusLine {
text: String::from_utf8_lossy(head).into_owned(),
})?;
let mut lines = head_text.split("\r\n");
let status_line = lines.next().unwrap_or("");
let (status, reason) = parse_status_line(status_line)?;
let headers = lines.try_fold(Headers::new(), |acc, line| {
if line.is_empty() {
Ok(acc)
} else {
parse_header(line).map(|(name, value)| acc.with(name, value))
}
})?;
Ok(Response::new(status, reason, headers, body))
}
/// Locates the blank line that ends an HTTP message head.
///
/// Returns the byte offset where the first `\r\n\r\n` begins, or `None` if
/// `bytes` contains no such sequence.
fn find_double_crlf(bytes: &[u8]) -> Option<usize> {
    const BLANK_LINE: &[u8] = b"\r\n\r\n";
    bytes
        .windows(BLANK_LINE.len())
        .enumerate()
        .find(|(_, candidate)| *candidate == BLANK_LINE)
        .map(|(offset, _)| offset)
}
fn parse_status_line(line: &str) -> Result<(u16, String), Error> {
let mut parts = line.splitn(3, ' ');
let _version = parts.next().unwrap_or("");
let status_text = parts.next().unwrap_or("");
let reason = parts.next().unwrap_or("").to_owned();
let status = status_text
.parse::<u16>()
.map_err(|_| Error::InvalidStatusLine {
text: line.to_owned(),
})?;
Ok((status, reason))
}
fn parse_header(line: &str) -> Result<(String, String), Error> {
line.split_once(':')
.map(|(name, value)| (name.trim().to_owned(), value.trim().to_owned()))
.ok_or_else(|| Error::InvalidHeader {
line: line.to_owned(),
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;

    /// Converts a failed expectation into an `Err` carrying `description`, so
    /// each test reports failure through its `Result` return value instead of
    /// panicking.
    fn expect_that(holds: bool, description: &str) -> Result<(), Error> {
        if holds {
            Ok(())
        } else {
            Err(Error::InvalidStatusLine {
                text: description.to_owned(),
            })
        }
    }

    #[test]
    fn parse_minimal_response() -> Result<(), Error> {
        let response =
            parse_response(b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nhi")?;
        expect_that(
            response.status() == 200 && response.body() == b"hi",
            "expected 200 OK with body 'hi'",
        )
    }

    #[test]
    fn parse_empty_body() -> Result<(), Error> {
        let response =
            parse_response(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n")?;
        expect_that(
            response.status() == 204 && response.body().is_empty(),
            "expected 204",
        )
    }

    #[test]
    fn parse_multiple_headers() -> Result<(), Error> {
        let response = parse_response(
            b"HTTP/1.1 200 OK\r\nServer: net-cat\r\nContent-Type: text/html\r\n\r\n<p>hi</p>",
        )?;
        let headers = response.headers();
        let server = headers.get("server").unwrap_or("");
        let content_type = headers.get("content-type").unwrap_or("");
        expect_that(
            server == "net-cat" && content_type == "text/html",
            "header round-trip failed",
        )
    }
}