use std::time::Duration;
use crate::{Error, ErrorType::*, OrErr, Result};
use pingora_timeout::timeout;
use bytes::Bytes;
use h2::SendStream;
pub mod client;
pub mod server;
async fn reserve_and_send(
writer: &mut SendStream<Bytes>,
remaining: &mut Bytes,
end: bool,
) -> Result<()> {
writer.reserve_capacity(remaining.len());
let res = std::future::poll_fn(|cx| writer.poll_capacity(cx)).await;
match res {
None => Error::e_explain(H2Error, "cannot reserve capacity"),
Some(ready) => {
let n = ready.or_err(H2Error, "while waiting for capacity")?;
let remaining_size = remaining.len();
let data_to_send = remaining.split_to(std::cmp::min(remaining_size, n));
writer
.send_data(data_to_send, remaining.is_empty() && end)
.or_err(WriteError, "while writing h2 request body")?;
Ok(())
}
}
}
pub async fn write_body(
writer: &mut SendStream<Bytes>,
data: Bytes,
end: bool,
write_timeout: Option<Duration>,
) -> Result<()> {
let mut remaining = data;
if remaining.is_empty() {
writer
.send_data(remaining, end)
.or_err(WriteError, "while writing h2 request body")?;
return Ok(());
}
loop {
match write_timeout {
Some(t) => match timeout(t, reserve_and_send(writer, &mut remaining, end)).await {
Ok(res) => res?,
Err(_) => Error::e_explain(
WriteTimedout,
format!("while writing h2 request body, timeout: {t:?}"),
)?,
},
None => {
reserve_and_send(writer, &mut remaining, end).await?;
}
}
if remaining.is_empty() {
return Ok(());
}
}
}
// Tests for the write timeout of h2 body writes, on both the client and the server side.
//
// Both tests follow the same recipe: the peer is driven by hand through an `h2::Codec`
// over an in-memory duplex pipe, advertises an initial window size of 1 byte, and delays
// its WINDOW_UPDATE by 200ms. The side under test uses a 100ms write timeout and tries to
// write an 11-byte body, so the write is expected to fail with `WriteTimedout`.
#[cfg(test)]
mod test {
use std::{sync::Arc, time::Duration};
use bytes::Bytes;
use futures::SinkExt;
use h2::frame::*;
use http::{HeaderMap, Method, Uri};
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_stream::StreamExt;
use pingora_http::{RequestHeader, ResponseHeader};
use pingora_timeout::sleep;
use crate::protocols::{
http::v2::server::{handshake, HttpSession},
Digest,
};
// Client side: writing the request body must time out while the hand-driven server
// withholds its WINDOW_UPDATE.
#[tokio::test]
async fn test_client_write_timeout() {
let mut handles = vec![];
let (client, mut server) = duplex(65536);
// Task 1: the real h2 client under test.
handles.push(tokio::spawn(async move {
// NOTE(review): the meaning of the `500` and `None` arguments is defined by the
// connector's `handshake` signature, which is not visible here.
let conn = crate::connectors::http::v2::handshake(Box::new(client), 500, None)
.await
.unwrap();
let mut h2_stream = conn.spawn_stream().await.unwrap().unwrap();
// Shorter than the 200ms delay the fake server inserts before its WINDOW_UPDATE.
h2_stream.write_timeout = Some(Duration::from_millis(100));
let mut request = RequestHeader::build("GET", b"/", None).unwrap();
request.insert_header("Host", "one.one.one.one").unwrap();
h2_stream
.write_request_header(Box::new(request), false)
.unwrap();
h2_stream.read_response_header().await.unwrap();
assert_eq!(h2_stream.response_header().unwrap().status.as_u16(), 200);
// The body is 11 bytes but the peer advertised a 1-byte window, so the write
// cannot complete before the timeout.
let err = h2_stream
.write_request_body(Bytes::from_static(b"client body"), true)
.await
.err()
.unwrap();
assert_eq!(err.etype(), &pingora_error::ErrorType::WriteTimedout);
}));
// Task 2: a fake server that speaks raw h2 frames.
handles.push(tokio::spawn(async move {
// Frames to send, in order: SETTINGS, SETTINGS ack, response HEADERS, WINDOW_UPDATE.
let mut outbound: Vec<h2::frame::Frame<Bytes>> = Vec::new();
let mut settings = Settings::default();
// A 1-byte initial window forces the client to wait for more capacity.
settings.set_initial_window_size(Some(1));
settings.set_max_concurrent_streams(Some(1));
outbound.push(settings.into());
outbound.push(Settings::ack().into());
let headers = HeaderMap::new();
outbound.push(
Headers::new(1.into(), Pseudo::response(http::StatusCode::OK), headers).into(),
);
outbound.push(WindowUpdate::new(1.into(), 10000).into());
// Consume the 24-byte client connection preface before switching to frame decoding.
server.read_exact(&mut [0u8; 24]).await.unwrap();
let mut server: h2::Codec<DuplexStream, Bytes> = h2::Codec::new(server);
// Drain the first three frames from the client; their contents are ignored.
// NOTE(review): presumably the client's SETTINGS/ack and request HEADERS — confirm.
for _ in 0..3 {
_ = server.next().await.unwrap();
}
for (i, frame) in outbound.into_iter().enumerate() {
// Index 3 is the WINDOW_UPDATE: hold it back longer than the client's timeout.
if i == 3 {
sleep(Duration::from_millis(200)).await;
}
// Send errors are ignored: the client may already have given up by now.
_ = server.send(frame).await;
}
}));
// A panic (e.g. failed assertion) inside a spawned task shows up as a join error here.
for handle in handles {
assert!(handle.await.is_ok());
}
}
// Server side: writing the response body must time out while the hand-driven client
// withholds its WINDOW_UPDATE.
#[tokio::test]
async fn test_server_write_timeout() {
let mut handles = vec![];
let (mut client, server) = duplex(65536);
// Task 1: a fake client that speaks raw h2 frames.
handles.push(tokio::spawn(async move {
// Frames to send, in order: SETTINGS, SETTINGS ack, request HEADERS, WINDOW_UPDATE.
let mut outbound: Vec<h2::frame::Frame<Bytes>> = Vec::new();
let mut settings = Settings::default();
// A 1-byte initial window forces the server to wait for more capacity.
settings.set_initial_window_size(Some(1));
settings.set_max_concurrent_streams(Some(1));
outbound.push(settings.into());
outbound.push(Settings::ack().into());
let mut headers = Headers::new(
1.into(),
Pseudo::request(
Method::GET,
Uri::from_static("https://one.one.one.one"),
None,
),
HeaderMap::new(),
);
headers.set_end_headers();
outbound.push(headers.into());
outbound.push(WindowUpdate::new(1.into(), 10000).into());
// Write the 24-byte client connection preface by hand before sending frames.
client
.write_all(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
.await
.unwrap();
let mut client: h2::Codec<DuplexStream, Bytes> = h2::Codec::new(client);
for (i, frame) in outbound.into_iter().enumerate() {
// Index 3 is the WINDOW_UPDATE: hold it back longer than the server's timeout.
if i == 3 {
sleep(Duration::from_millis(200)).await;
}
// Send errors are ignored on purpose.
_ = client.send(frame).await;
}
// Drain three frames from the server; their contents are ignored.
for _ in 0..3 {
_ = client.next().await.unwrap();
}
}));
// The real h2 server under test runs on the test's own task.
let mut connection = handshake(Box::new(server), None).await.unwrap();
let digest = Arc::new(Digest::default());
// Accept streams until `from_h2_conn` yields `None`; each one is handled in its own task.
while let Some(mut h2_stream) = HttpSession::from_h2_conn(&mut connection, digest.clone())
.await
.unwrap()
{
handles.push(tokio::spawn(async move {
// Shorter than the 200ms delay the fake client inserts before its WINDOW_UPDATE.
h2_stream.set_write_timeout(Some(Duration::from_millis(100)));
let req = h2_stream.req_header();
assert_eq!(req.method, Method::GET);
let response_header = Box::new(ResponseHeader::build(200, None).unwrap());
assert!(h2_stream
.write_response_header(response_header.clone(), false)
.is_ok());
// The body is 11 bytes but the peer advertised a 1-byte window, so the write
// cannot complete before the timeout.
let err = h2_stream
.write_body(Bytes::from_static(b"server body"), true)
.await
.err()
.unwrap();
assert_eq!(err.etype(), &pingora_error::ErrorType::WriteTimedout);
}));
}
// A panic (e.g. failed assertion) inside a spawned task shows up as a join error here.
for handle in handles {
assert!(handle.await.is_ok());
}
}
}