use std::io::ErrorKind;
use std::time::Instant;
use crate::error::{Error, Result};
use crate::io::Machine;
use crate::net::NetStream;
pub(crate) fn drive<M, S>(machine: &mut M, io: &mut S) -> Result<Vec<M::Event>>
where
M: Machine,
S: NetStream + ?Sized,
{
let mut events = Vec::new();
let mut scratch = [0u8; 16 * 1024];
let mut out = Vec::new();
let mut eof_seen = false;
loop {
out.clear();
while machine.poll_transmit(&mut out) {}
if !out.is_empty() {
io.write_all(&out).map_err(Error::Io)?;
io.flush().map_err(Error::Io)?;
}
while let Some(ev) = machine.poll_event() {
events.push(ev);
}
if machine.is_finished() {
return Ok(events);
}
if eof_seen {
return Err(Error::UnexpectedEof);
}
if let Some(deadline) = machine.next_timeout() {
let now = Instant::now();
if now >= deadline {
machine.handle_timeout(now);
continue;
}
io.set_read_timeout(Some(deadline - now))
.map_err(Error::Io)?;
}
match io.read(&mut scratch) {
Ok(0) => {
eof_seen = true;
machine.handle_eof()?;
}
Ok(n) => {
machine.handle_input(&scratch[..n])?;
}
Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
machine.handle_timeout(Instant::now());
}
Err(e) => return Err(Error::Io(e)),
}
}
}
#[cfg(test)]
mod tests {
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::thread;
    use crate::proto::http1::{ClientExchange, Event};

    /// Starts a one-shot server on an ephemeral loopback port and returns that port.
    ///
    /// A background thread accepts a single connection, reads one byte at a time
    /// until the request ends with `\r\n\r\n` (or reading stops yielding bytes),
    /// writes `response`, and then drops the socket.
    fn serve_once(response: &'static [u8]) -> u16 {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        thread::spawn(move || {
            let (mut conn, _peer) = match listener.accept() {
                Ok(accepted) => accepted,
                Err(_) => return,
            };
            let mut received = Vec::new();
            let mut one = [0u8; 1];
            loop {
                // Anything other than exactly one byte (EOF or error) ends the read phase.
                match conn.read(&mut one) {
                    Ok(1) => received.push(one[0]),
                    _ => break,
                }
                if received.ends_with(b"\r\n\r\n") {
                    break;
                }
            }
            // Best effort: the client may already be gone.
            let _ = conn.write_all(response);
        });
        port
    }

    /// A response framed by `Content-Length` yields exactly one event with the full body.
    #[test]
    fn blocking_get_content_length_over_real_socket() {
        let port = serve_once(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello");
        let mut stream = TcpStream::connect(("127.0.0.1", port)).unwrap();
        let request =
            ClientExchange::encode_request("GET", "/", &[("Host".into(), "127.0.0.1".into())], b"");
        let mut exchange = ClientExchange::new("GET", request);
        let events = super::drive(&mut exchange, &mut stream).unwrap();
        assert_eq!(events.len(), 1);
        let Event::Response { head, body } = &events[0];
        assert_eq!(head.status, 200);
        assert_eq!(body, b"hello");
    }

    /// A response without a length header is delimited by the server closing the connection.
    #[test]
    fn blocking_get_eof_framed_body_over_real_socket() {
        let port = serve_once(b"HTTP/1.1 200 OK\r\nServer: t\r\n\r\nstreamed payload");
        let mut stream = TcpStream::connect(("127.0.0.1", port)).unwrap();
        let request =
            ClientExchange::encode_request("GET", "/", &[("Host".into(), "x".into())], b"");
        let mut exchange = ClientExchange::new("GET", request);
        let events = super::drive(&mut exchange, &mut stream).unwrap();
        let Event::Response { head, body } = &events[0];
        assert_eq!(head.status, 200);
        assert_eq!(body, b"streamed payload");
    }
}