use nexus_net::http::{HttpError, ResponseReader};
use nexus_net::rest::{Request, RestError, RestResponse};
#[cfg(feature = "tls")]
use nexus_net::tls::TlsConfig;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::maybe_tls::MaybeTls;
/// Builder for [`AsyncHttpConnection`]s opened over TCP, optionally wrapped in TLS.
///
/// Every option starts out unset; connecting only applies what was configured.
pub struct AsyncHttpConnectionBuilder {
    /// When `true`, `TCP_NODELAY` is enabled on the socket after connecting.
    nodelay: bool,
    /// Optional upper bound on how long the TCP connect may take.
    connect_timeout: Option<std::time::Duration>,
    /// TLS settings for TLS URLs; `None` means a default `TlsConfig` is built.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// Idle time before TCP keepalive probes start; `None` leaves the socket untouched.
    #[cfg(feature = "socket-opts")]
    tcp_keepalive: Option<std::time::Duration>,
    /// Requested kernel receive-buffer size in bytes; `None` leaves the socket untouched.
    #[cfg(feature = "socket-opts")]
    recv_buf_size: Option<usize>,
    /// Requested kernel send-buffer size in bytes; `None` leaves the socket untouched.
    #[cfg(feature = "socket-opts")]
    send_buf_size: Option<usize>,
}
impl AsyncHttpConnectionBuilder {
    /// Creates a builder with every option unset.
    ///
    /// The defaults are: no TLS override, Nagle's algorithm left enabled, and no
    /// connect timeout. With `socket-opts`, there are also no keepalive or
    /// buffer-size overrides.
    #[must_use]
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "tls")]
            tls_config: None,
            nodelay: false,
            connect_timeout: None,
            #[cfg(feature = "socket-opts")]
            tcp_keepalive: None,
            #[cfg(feature = "socket-opts")]
            recv_buf_size: None,
            #[cfg(feature = "socket-opts")]
            send_buf_size: None,
        }
    }

    /// Uses a clone of `config` for TLS URLs instead of `TlsConfig::new()`.
    #[cfg(feature = "tls")]
    #[must_use]
    pub fn tls(mut self, config: &TlsConfig) -> Self {
        self.tls_config = Some(config.clone());
        self
    }

    /// Enables `TCP_NODELAY` (disables Nagle's algorithm) once connected.
    #[must_use]
    pub fn disable_nagle(mut self) -> Self {
        self.nodelay = true;
        self
    }

    /// Bounds the TCP connect phase by `d`.
    ///
    /// The TLS handshake that may follow is not covered by this timeout.
    #[must_use]
    pub fn connect_timeout(mut self, d: std::time::Duration) -> Self {
        self.connect_timeout = Some(d);
        self
    }

    /// Enables TCP keepalive, with probes starting after `idle` of inactivity.
    #[cfg(feature = "socket-opts")]
    #[must_use]
    pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
        self.tcp_keepalive = Some(idle);
        self
    }

    /// Requests a kernel receive-buffer size of `n` bytes.
    #[cfg(feature = "socket-opts")]
    #[must_use]
    pub fn recv_buffer_size(mut self, n: usize) -> Self {
        self.recv_buf_size = Some(n);
        self
    }

    /// Requests a kernel send-buffer size of `n` bytes.
    #[cfg(feature = "socket-opts")]
    #[must_use]
    pub fn send_buffer_size(mut self, n: usize) -> Self {
        self.send_buf_size = Some(n);
        self
    }

    /// Opens a connection to `url`.
    ///
    /// Connects over TCP (bounded by the connect timeout, if set), applies the
    /// configured socket options, and performs a TLS handshake when `url`
    /// requires TLS.
    ///
    /// # Errors
    /// - any error from `parse_base_url` for an unusable `url`;
    /// - `RestError::Io` for connect, socket-option or handshake failures,
    ///   including an `ErrorKind::TimedOut` error when the connect timeout elapses;
    /// - `RestError::Tls` if no TLS config was supplied and the default cannot be built;
    /// - `RestError::InvalidUrl` if the host is not a valid TLS server name;
    /// - `RestError::TlsNotEnabled` for TLS URLs when the `tls` feature is disabled.
    pub async fn connect(self, url: &str) -> Result<AsyncHttpConnection<MaybeTls>, RestError> {
        let parsed = nexus_net::rest::parse_base_url(url)?;
        let addr = format!("{}:{}", parsed.host, parsed.port);
        let tcp = match self.connect_timeout {
            Some(timeout) => tokio::time::timeout(timeout, TcpStream::connect(&addr))
                .await
                .map_err(|_| {
                    RestError::Io(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        "connect timeout",
                    ))
                })??,
            None => TcpStream::connect(&addr).await?,
        };
        if self.nodelay {
            tcp.set_nodelay(true)?;
        }
        #[cfg(feature = "socket-opts")]
        self.apply_socket_opts(&tcp)?;
        let stream = if parsed.tls {
            #[cfg(feature = "tls")]
            {
                // The builder is consumed by `connect`, so move the stored
                // config out rather than cloning it.
                let tls_config = match self.tls_config {
                    Some(config) => config,
                    None => TlsConfig::new().map_err(RestError::Tls)?,
                };
                let connector =
                    tokio_rustls::TlsConnector::from(tls_config.client_config().clone());
                let server_name =
                    tokio_rustls::rustls::pki_types::ServerName::try_from(parsed.host.to_owned())
                        .map_err(|_| {
                            RestError::InvalidUrl(format!("invalid hostname: {}", parsed.host))
                        })?;
                let tls_stream = connector
                    .connect(server_name, tcp)
                    .await
                    .map_err(RestError::Io)?;
                MaybeTls::Tls(Box::new(tls_stream))
            }
            #[cfg(not(feature = "tls"))]
            {
                return Err(RestError::TlsNotEnabled);
            }
        } else {
            MaybeTls::Plain(tcp)
        };
        Ok(AsyncHttpConnection {
            stream,
            poisoned: false,
        })
    }

    /// Wraps a caller-supplied, already-connected `stream`.
    ///
    /// None of this builder's options (TLS, nodelay, timeout, socket options)
    /// are applied to `stream`; the caller is responsible for configuring it.
    pub fn connect_with<S: AsyncRead + AsyncWrite + Unpin>(
        self,
        stream: S,
    ) -> AsyncHttpConnection<S> {
        AsyncHttpConnection {
            stream,
            poisoned: false,
        }
    }
}
#[cfg(feature = "socket-opts")]
impl AsyncHttpConnectionBuilder {
    /// Applies the optional socket-level settings to a freshly connected socket.
    ///
    /// Settings are applied in a fixed order: keepalive, then receive buffer,
    /// then send buffer. Unset (`None`) options are skipped. The first failing
    /// call is returned as `RestError::Io`, and the remaining options are not applied.
    fn apply_socket_opts(&self, tcp: &TcpStream) -> Result<(), RestError> {
        let socket = socket2::SockRef::from(tcp);

        self.tcp_keepalive
            .map(|idle| socket.set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(idle)))
            .transpose()
            .map_err(RestError::Io)?;
        self.recv_buf_size
            .map(|bytes| socket.set_recv_buffer_size(bytes))
            .transpose()
            .map_err(RestError::Io)?;
        self.send_buf_size
            .map(|bytes| socket.set_send_buffer_size(bytes))
            .transpose()
            .map_err(RestError::Io)?;

        Ok(())
    }
}
impl Default for AsyncHttpConnectionBuilder {
fn default() -> Self {
Self::new()
}
}
/// An HTTP/1.1 client connection over any async byte stream `S`.
///
/// If an exchange fails, the connection is marked poisoned and refuses further
/// requests, since the stream may have been left part-way through a response.
pub struct AsyncHttpConnection<S> {
    /// Set once an exchange has failed; a poisoned connection should be dropped.
    poisoned: bool,
    /// The underlying transport (plain TCP, TLS, or a caller-supplied stream).
    stream: S,
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncHttpConnection<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
poisoned: false,
}
}
#[must_use]
pub fn builder() -> AsyncHttpConnectionBuilder {
AsyncHttpConnectionBuilder::new()
}
#[allow(clippy::needless_pass_by_value)] pub async fn send<'r>(
&mut self,
req: Request<'_>,
reader: &'r mut ResponseReader,
) -> Result<RestResponse<'r>, RestError> {
if self.poisoned {
return Err(RestError::ConnectionPoisoned);
}
if let Err(e) = self.stream.write_all(req.as_bytes()).await {
self.poisoned = true;
return Err(RestError::Io(e));
}
if let Err(e) = self.stream.flush().await {
self.poisoned = true;
return Err(RestError::Io(e));
}
match self.read_response(reader).await {
Ok(resp) => Ok(resp),
Err(e) => {
self.poisoned = true;
Err(self.diagnose_error(e))
}
}
}
pub fn is_poisoned(&self) -> bool {
self.poisoned
}
#[cold]
#[allow(clippy::unused_self)] fn diagnose_error(&self, err: RestError) -> RestError {
if let RestError::Io(ref io_err) = err {
if io_err.kind() == std::io::ErrorKind::TimedOut
|| io_err.kind() == std::io::ErrorKind::WouldBlock
{
return RestError::ConnectionStale;
}
}
err
}
pub fn stream(&self) -> &S {
&self.stream
}
pub fn stream_mut(&mut self) -> &mut S {
&mut self.stream
}
async fn read_response<'r>(
&mut self,
reader: &'r mut ResponseReader,
) -> Result<RestResponse<'r>, RestError> {
reader.consume_response();
let mut tmp = [0u8; 4096];
loop {
match reader.next() {
Ok(Some(_)) => break,
Ok(None) => {}
Err(e) => {
self.poisoned = true;
return Err(e.into());
}
}
match self.stream.read(&mut tmp).await {
Ok(0) => {
self.poisoned = true;
return Err(RestError::ConnectionClosed(
"server closed before response headers",
));
}
Ok(n) => {
if let Err(e) = reader.read(&tmp[..n]) {
self.poisoned = true;
return Err(e.into());
}
}
Err(e) => {
self.poisoned = true;
return Err(RestError::Io(e));
}
}
}
let status = reader.status();
if matches!(status, 100..=199 | 204 | 304) {
reader.set_body_consumed(0);
return Ok(RestResponse::new(status, 0, reader));
}
if reader.is_chunked() {
let body = self.read_chunked_body(reader).await?;
reader.set_body_consumed(reader.body_remaining());
return Ok(RestResponse::new_chunked(status, body, reader));
}
let content_length = match reader.content_length() {
Some(Ok(n)) => n,
Some(Err(())) => {
return Err(RestError::Http(HttpError::Malformed(
"invalid Content-Length header",
)));
}
None => {
self.poisoned = true;
return Err(RestError::Http(HttpError::Malformed(
"no Content-Length and not chunked",
)));
}
};
let max_body = reader.max_body_size_limit();
if max_body > 0 && content_length > max_body {
self.poisoned = true;
return Err(RestError::BodyTooLarge {
size: content_length,
max: max_body,
});
}
while reader.body_remaining() < content_length {
match self.stream.read(&mut tmp).await {
Ok(0) => {
self.poisoned = true;
return Err(RestError::ConnectionClosed(
"server closed during body read",
));
}
Ok(n) => {
if let Err(e) = reader.read(&tmp[..n]) {
self.poisoned = true;
return Err(e.into());
}
}
Err(e) => {
self.poisoned = true;
return Err(RestError::Io(e));
}
}
}
reader.set_body_consumed(content_length);
Ok(RestResponse::new(status, content_length, reader))
}
async fn read_chunked_body(&mut self, reader: &ResponseReader) -> Result<Vec<u8>, RestError> {
use nexus_net::http::ChunkedDecoder;
let max_body = reader.max_body_size_limit();
let mut decoder = ChunkedDecoder::new();
let mut body = Vec::with_capacity(4096);
let mut wire_buf = [0u8; 4096];
let mut decode_buf = [0u8; 4096];
let remainder = reader.remainder();
if !remainder.is_empty() {
let mut pos = 0;
while pos < remainder.len() && !decoder.is_done() {
let (consumed, produced) = decoder
.decode(&remainder[pos..], &mut decode_buf)
.map_err(RestError::Http)?;
pos += consumed;
if produced > 0 {
body.extend_from_slice(&decode_buf[..produced]);
if max_body > 0 && body.len() > max_body {
self.poisoned = true;
return Err(RestError::BodyTooLarge {
size: body.len(),
max: max_body,
});
}
}
if consumed == 0 && produced == 0 {
break;
}
}
}
while !decoder.is_done() {
let n = match self.stream.read(&mut wire_buf).await {
Ok(0) => {
self.poisoned = true;
return Err(RestError::ConnectionClosed(
"server closed during chunked body",
));
}
Ok(n) => n,
Err(e) => {
self.poisoned = true;
return Err(RestError::Io(e));
}
};
let mut pos = 0;
while pos < n && !decoder.is_done() {
let (consumed, produced) = decoder
.decode(&wire_buf[pos..n], &mut decode_buf)
.map_err(RestError::Http)?;
pos += consumed;
if produced > 0 {
body.extend_from_slice(&decode_buf[..produced]);
if max_body > 0 && body.len() > max_body {
self.poisoned = true;
return Err(RestError::BodyTooLarge {
size: body.len(),
max: max_body,
});
}
}
if consumed == 0 && produced == 0 {
break;
}
}
}
Ok(body)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use nexus_net::rest::RequestWriter;
    use std::io::Cursor;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::ReadBuf;

    /// In-memory transport: records everything written and replays a canned response.
    struct MockAsyncStream {
        written: Vec<u8>,
        response: Cursor<Vec<u8>>,
    }

    impl MockAsyncStream {
        fn new(response: &[u8]) -> Self {
            MockAsyncStream {
                response: Cursor::new(Vec::from(response)),
                written: Vec::new(),
            }
        }

        fn written_str(&self) -> &str {
            std::str::from_utf8(&self.written).unwrap()
        }
    }

    impl AsyncRead for MockAsyncStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            // The mock is `Unpin`, so it can be unwrapped from the pin directly.
            let this = self.get_mut();
            let dst = buf.initialize_unfilled();
            let copied = std::io::Read::read(&mut this.response, dst)?;
            buf.advance(copied);
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncWrite for MockAsyncStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            self.get_mut().written.extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Builds a `200 OK` response carrying `body` with a matching Content-Length.
    fn ok_response(body: &str) -> Vec<u8> {
        let mut bytes =
            format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).into_bytes();
        bytes.extend_from_slice(body.as_bytes());
        bytes
    }

    /// A connection whose peer will answer with exactly `response`.
    fn mock_conn(response: &[u8]) -> AsyncHttpConnection<MockAsyncStream> {
        AsyncHttpConnection::new(MockAsyncStream::new(response))
    }

    #[tokio::test]
    async fn async_get_request() {
        let mut conn = mock_conn(&ok_response(r#"{"ok":true}"#));
        let mut writer = RequestWriter::new("api.example.com").unwrap();
        let mut reader = ResponseReader::new(4096);

        let req = writer.get("/status").finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body_str().unwrap(), r#"{"ok":true}"#);

        let sent = conn.stream().written_str();
        assert!(sent.starts_with("GET /status HTTP/1.1\r\n"));
        assert!(sent.contains("Host: api.example.com\r\n"));
    }

    #[tokio::test]
    async fn async_post_with_body() {
        let mut conn = mock_conn(&ok_response(r#"{"filled":true}"#));
        let mut writer = RequestWriter::new("api.example.com").unwrap();
        let mut reader = ResponseReader::new(4096);

        let payload = br#"{"symbol":"BTC","side":"buy"}"#;
        let req = writer.post("/order").body(payload).finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.status(), 200);

        let sent = conn.stream().written_str();
        assert!(sent.contains(&format!("Content-Length: {}\r\n", payload.len())));
        assert!(sent.ends_with(std::str::from_utf8(payload).unwrap()));
    }

    #[tokio::test]
    async fn async_response_headers() {
        let mut conn =
            mock_conn(b"HTTP/1.1 200 OK\r\nX-Request-Id: abc\r\nContent-Length: 2\r\n\r\n{}");
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.header("X-Request-Id"), Some("abc"));
    }

    #[tokio::test]
    async fn async_connection_poisoned() {
        let mut conn = mock_conn(b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\npartial");
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);

        // The body is shorter than advertised, so the first exchange hits EOF.
        let req = writer.get("/test").finish().unwrap();
        let first = conn.send(req, &mut reader).await;
        assert!(matches!(first, Err(RestError::ConnectionClosed(_))));

        // After that failure the connection refuses further work.
        let req = writer.get("/test2").finish().unwrap();
        let second = conn.send(req, &mut reader).await;
        assert!(matches!(second, Err(RestError::ConnectionPoisoned)));
    }

    #[tokio::test]
    async fn async_chunked_decoded() {
        let mut conn = mock_conn(
            b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n",
        );
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.body_str().unwrap(), "hello");
    }
}