use std::{collections::HashSet, sync::Arc, time::Instant};
use bytes::Bytes;
use http::{Request, StatusCode};
use http_body_util::{BodyExt, Either, Full};
use http_wire::{WireDecode, WireEncode, response::FullResponse};
use tokio_uring::net::TcpStream;
use crate::{
client::utils::{
build_conn_endpoint, build_headers, build_trailers, get_conn_address, should_stop,
},
fatal,
options::Options,
stats::{RealtimeStats, Statistics},
};
/// Drives one benchmark connection slot over `tokio-uring` until `should_stop` reports that the
/// run is over, and returns the statistics gathered along the way.
///
/// The target is `opts.uri[cid % opts.uri.len()]`. The request (method, URI, headers, first
/// configured body and optional trailers) is serialized once up front. Each iteration writes
/// those same bytes on a `tokio_uring::net::TcpStream` and then reads until one full response
/// can be decoded.
///
/// How each outcome is recorded:
/// - A decoded `200` is recorded with `inc_ok`.
/// - Any other status code is recorded with `set_http_status`; a response without a status code
///   is recorded as `500`.
/// - When `opts.latency` is set, the time from just before the write to the successful decode
///   is recorded in microseconds.
///
/// A connection is reused until `opts.rpc` requests have completed on it. Any connect, write or
/// read failure, including the peer closing mid-request, is recorded as an error. The connection
/// is then dropped and a new one is opened.
///
/// # Parameters
/// - `tid`: worker id, used only in the "connecting to" banner.
/// - `cid`: connection slot id; selects the URI and gates the banner (`cid < opts.uri.len()`).
/// - `opts`: shared run options.
/// - `rt_stats`: realtime statistics handle passed to every `Statistics` update.
///
/// # Panics
/// Panics (remainder by zero) if `opts.uri` is empty.
///
/// An invalid URI, host, address, header set, trailer set, body, or an unserializable request is
/// reported through `fatal!`, which diverges.
pub async fn http_io_uring(
    tid: usize,
    cid: usize,
    opts: Arc<Options>,
    rt_stats: &RealtimeStats,
) -> Statistics {
    // Capacity of the header slot array handed to the response decoder. Decode failures are
    // treated as "need more bytes" below, so a response with more header lines than this would
    // presumably never decode and the read loop would stall until the peer closed the socket.
    const MAX_RESPONSE_HEADERS: usize = 64;

    let mut statistics = Statistics::new(opts.latency);
    let mut total: u32 = 0;
    let mut banner = HashSet::new();

    let uri_str = opts.uri[cid % opts.uri.len()].as_str();
    let uri = uri_str
        .parse::<hyper::Uri>()
        .unwrap_or_else(|e| fatal!(1, "invalid uri: {e}"));
    let (host, port) =
        get_conn_address(&opts, &uri).unwrap_or_else(|| fatal!(1, "no host specified in uri"));
    let endpoint = build_conn_endpoint(&host, port);
    // The endpoint never changes, so parse it once instead of on every reconnect.
    let addr: std::net::SocketAddr = endpoint
        .parse()
        .unwrap_or_else(|e| fatal!(1, "invalid address: {e}"));
    let method = opts.method.clone().unwrap_or(http::Method::GET);

    let headers = build_headers(&uri, opts.as_ref())
        .unwrap_or_else(|e| fatal!(2, "could not build headers: {e}"));
    let trailers = build_trailers(opts.as_ref())
        .unwrap_or_else(|e| fatal!(2, "could not build trailers: {e}"));
    let bodies: Vec<Full<Bytes>> = opts.bodies().map_or_else(
        |e| fatal!(2, "could not read body: {e}"),
        |b| b.into_iter().map(Full::new).collect::<Vec<_>>(),
    );
    // Only the first configured body is sent; without one the request body is empty.
    let body = bodies
        .first()
        .cloned()
        .unwrap_or_else(|| Full::new(Bytes::new()));
    let body = match trailers {
        None => Either::Left(body),
        Some(trailers) => {
            Either::Right(body.with_trailers(std::future::ready(Some(Ok(trailers)))))
        }
    };

    let mut req = Request::new(body);
    *req.method_mut() = method.clone();
    *req.uri_mut() = uri.clone();
    *req.headers_mut() = headers;
    let request_bytes = req
        .encode()
        .unwrap_or_else(|e| fatal!(2, "could not serialize request: {e}"));

    let clock = quanta::Clock::new();
    let start = Instant::now();

    'connection: loop {
        if should_stop(total, start, &opts) {
            break 'connection;
        }
        if cid < opts.uri.len() && !banner.contains(uri_str) {
            banner.insert(uri_str.to_owned());
            println!(
                "tokio-uring [{tid:>2}] -> connecting to {}:{}, method = {} uri = {} ...",
                host, port, method, uri,
            );
        }

        let stream = match TcpStream::connect(addr).await {
            Ok(stream) => stream,
            Err(ref err) => {
                statistics.set_error(err, rt_stats);
                total += 1;
                continue 'connection;
            }
        };
        statistics.inc_conn();

        let mut conn_req_count: u32 = 0;
        // Bytes received on this connection that have not yet been consumed by a decoded response.
        let mut connection_buffer = Vec::new();
        let mut read_buf = vec![0u8; 4096];
        // tokio-uring takes buffers by value and hands them back, so they are reassigned after
        // every write/read to keep reusing the same allocation.
        let mut request: Vec<u8> = request_bytes.clone().into();

        loop {
            let start_lat = opts.latency.then_some(clock.raw());
            let (result, buf) = stream.write_all(request).await;
            request = buf;
            if let Err(ref err) = result {
                statistics.set_error(err, rt_stats);
                total += 1;
                continue 'connection;
            }

            loop {
                let (result, buf) = stream.read(read_buf).await;
                read_buf = buf;
                let bytes_read = match result {
                    Ok(0) => {
                        // The peer closed before a complete response arrived, so the request
                        // just written is lost. Record it as an error instead of only counting
                        // it in `total`.
                        let eof = std::io::Error::from(std::io::ErrorKind::UnexpectedEof);
                        statistics.set_error(&eof, rt_stats);
                        total += 1;
                        continue 'connection;
                    }
                    Ok(n) => n,
                    Err(ref err) => {
                        statistics.set_error(err, rt_stats);
                        total += 1;
                        continue 'connection;
                    }
                };
                connection_buffer.extend_from_slice(&read_buf[..bytes_read]);

                let mut header_slots = [httparse::EMPTY_HEADER; MAX_RESPONSE_HEADERS];
                let Ok((resp, response_end)) =
                    FullResponse::decode(&connection_buffer, &mut header_slots)
                else {
                    // No complete response buffered yet: keep reading.
                    continue;
                };

                if let Some(start_lat) = start_lat
                    && let Some(hist) = &mut statistics.latency
                {
                    // Nanoseconds -> microseconds; a failed record is deliberately ignored.
                    hist.record(clock.delta_as_nanos(start_lat, clock.raw()) / 1000)
                        .ok();
                }
                match resp.head.code {
                    Some(200) => statistics.inc_ok(rt_stats),
                    Some(code) => {
                        let status = StatusCode::from_u16(code)
                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
                        statistics.set_http_status(status, rt_stats);
                    }
                    None => {
                        statistics.set_http_status(StatusCode::INTERNAL_SERVER_ERROR, rt_stats);
                    }
                }
                // Keep any pipelined bytes beyond this response for the next decode.
                connection_buffer.drain(..response_end);
                total += 1;
                conn_req_count += 1;

                if should_stop(total, start, &opts) {
                    break 'connection;
                }
                if conn_req_count >= opts.rpc {
                    continue 'connection;
                }
                // Response complete: send the next request on this connection.
                break;
            }
        }
    }
    statistics
}