nexus_async_net/rest/tokio/
connection.rs1use std::io;
4use std::pin::Pin;
5
6use nexus_net::http::{HTTP_HANDSHAKE_BUFFER, HttpError, ResponseReader};
7use nexus_net::rest::{Request, RestError, RestResponse};
8#[cfg(feature = "tls")]
9use nexus_net::tls::TlsConfig;
10use nexus_net::{ParserSink, WireStream};
11use tokio::net::TcpStream;
12
13use crate::maybe_tls::MaybeTls;
14
15async fn fill_async<W: WireStream + Unpin, P: ParserSink>(
20 s: &mut W,
21 sink: &mut P,
22 max: usize,
23) -> io::Result<usize> {
24 std::future::poll_fn(|cx| Pin::new(&mut *s).poll_fill_into(cx, sink, max)).await
25}
26
27async fn write_all_async<W: WireStream + Unpin>(s: &mut W, mut buf: &[u8]) -> io::Result<()> {
28 while !buf.is_empty() {
29 let n = std::future::poll_fn(|cx| Pin::new(&mut *s).poll_write(cx, buf)).await?;
30 if n == 0 {
31 return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0"));
32 }
33 buf = &buf[n..];
34 }
35 Ok(())
36}
37
38async fn flush_async<W: WireStream + Unpin>(s: &mut W) -> io::Result<()> {
39 std::future::poll_fn(|cx| Pin::new(&mut *s).poll_flush(cx)).await
40}
41
42struct SliceSink<'a> {
45 buf: &'a mut [u8],
46 filled: usize,
47}
48
49impl<'a> SliceSink<'a> {
50 fn new(buf: &'a mut [u8]) -> Self {
51 Self { buf, filled: 0 }
52 }
53
54 fn data(&self) -> &[u8] {
55 &self.buf[..self.filled]
56 }
57}
58
59impl ParserSink for SliceSink<'_> {
60 fn spare(&mut self) -> &mut [u8] {
61 &mut self.buf[self.filled..]
62 }
63
64 fn filled(&mut self, n: usize) {
65 self.filled += n;
66 }
67}
68
/// Options collected before opening an [`HttpConnection`] via `connect`.
pub struct HttpConnectionBuilder {
    /// TLS configuration used when the parsed URL requests TLS. When `None`,
    /// `connect` creates one with `TlsConfig::new()`.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// When `true`, `connect` calls `set_nodelay(true)` on the TCP socket.
    nodelay: bool,
    /// Upper bound on the TCP connect step; `None` means no timeout is applied.
    connect_timeout: Option<std::time::Duration>,
    /// Keepalive time handed to `socket2::TcpKeepalive::with_time`, if set.
    #[cfg(feature = "socket-opts")]
    tcp_keepalive: Option<std::time::Duration>,
    /// Receive buffer size passed to `set_recv_buffer_size`, if set.
    #[cfg(feature = "socket-opts")]
    recv_buf_size: Option<usize>,
    /// Send buffer size passed to `set_send_buffer_size`, if set.
    #[cfg(feature = "socket-opts")]
    send_buf_size: Option<usize>,
}
86
87impl HttpConnectionBuilder {
88 #[must_use]
90 pub fn new() -> Self {
91 Self {
92 #[cfg(feature = "tls")]
93 tls_config: None,
94 nodelay: false,
95 connect_timeout: None,
96 #[cfg(feature = "socket-opts")]
97 tcp_keepalive: None,
98 #[cfg(feature = "socket-opts")]
99 recv_buf_size: None,
100 #[cfg(feature = "socket-opts")]
101 send_buf_size: None,
102 }
103 }
104
105 #[cfg(feature = "tls")]
107 #[must_use]
108 pub fn tls(mut self, config: &TlsConfig) -> Self {
109 self.tls_config = Some(config.clone());
110 self
111 }
112
113 #[must_use]
115 pub fn disable_nagle(mut self) -> Self {
116 self.nodelay = true;
117 self
118 }
119
120 #[must_use]
122 pub fn connect_timeout(mut self, d: std::time::Duration) -> Self {
123 self.connect_timeout = Some(d);
124 self
125 }
126
127 #[cfg(feature = "socket-opts")]
132 #[must_use]
133 pub fn tcp_keepalive(mut self, idle: std::time::Duration) -> Self {
134 self.tcp_keepalive = Some(idle);
135 self
136 }
137
138 #[cfg(feature = "socket-opts")]
140 #[must_use]
141 pub fn recv_buffer_size(mut self, n: usize) -> Self {
142 self.recv_buf_size = Some(n);
143 self
144 }
145
146 #[cfg(feature = "socket-opts")]
148 #[must_use]
149 pub fn send_buffer_size(mut self, n: usize) -> Self {
150 self.send_buf_size = Some(n);
151 self
152 }
153
154 pub async fn connect(self, url: &str) -> Result<HttpConnection<MaybeTls>, RestError> {
156 let parsed = nexus_net::rest::parse_base_url(url)?;
157 let addr = format!("{}:{}", parsed.host, parsed.port);
158
159 let tcp = match self.connect_timeout {
160 Some(timeout) => tokio::time::timeout(timeout, TcpStream::connect(&addr))
161 .await
162 .map_err(|_| {
163 RestError::Io(std::io::Error::new(
164 std::io::ErrorKind::TimedOut,
165 "connect timeout",
166 ))
167 })??,
168 None => TcpStream::connect(&addr).await?,
169 };
170 if self.nodelay {
171 tcp.set_nodelay(true)?;
172 }
173 #[cfg(feature = "socket-opts")]
174 self.apply_socket_opts(&tcp)?;
175
176 let stream = if parsed.tls {
177 #[cfg(feature = "tls")]
178 {
179 let tls_config = match &self.tls_config {
180 Some(c) => c.clone(),
181 None => TlsConfig::new().map_err(RestError::Tls)?,
182 };
183
184 let connector =
185 tokio_rustls::TlsConnector::from(tls_config.client_config().clone());
186 let server_name =
187 tokio_rustls::rustls::pki_types::ServerName::try_from(parsed.host.to_owned())
188 .map_err(|_| {
189 RestError::InvalidUrl(format!("invalid hostname: {}", parsed.host))
190 })?;
191 let tls_stream = connector
192 .connect(server_name, tcp)
193 .await
194 .map_err(RestError::Io)?;
195 MaybeTls::Tls(Box::new(tls_stream))
196 }
197 #[cfg(not(feature = "tls"))]
198 {
199 return Err(RestError::TlsNotEnabled);
200 }
201 } else {
202 MaybeTls::Plain(tcp)
203 };
204
205 Ok(HttpConnection {
206 stream,
207 poisoned: false,
208 })
209 }
210
211 pub fn connect_with<S: WireStream + Unpin>(self, stream: S) -> HttpConnection<S> {
213 HttpConnection {
214 stream,
215 poisoned: false,
216 }
217 }
218}
219
#[cfg(feature = "socket-opts")]
impl HttpConnectionBuilder {
    /// Applies the configured keepalive and buffer-size options to `tcp`.
    ///
    /// Options are applied in the order keepalive, receive buffer, send
    /// buffer; unset options are skipped, and the first failure is returned
    /// as `RestError::Io` without attempting the rest.
    fn apply_socket_opts(&self, tcp: &TcpStream) -> Result<(), RestError> {
        let socket = socket2::SockRef::from(tcp);

        self.tcp_keepalive
            .map(|idle| socket.set_tcp_keepalive(&socket2::TcpKeepalive::new().with_time(idle)))
            .transpose()
            .map_err(RestError::Io)?;

        self.recv_buf_size
            .map(|bytes| socket.set_recv_buffer_size(bytes))
            .transpose()
            .map_err(RestError::Io)?;

        self.send_buf_size
            .map(|bytes| socket.set_send_buffer_size(bytes))
            .transpose()
            .map_err(RestError::Io)?;

        Ok(())
    }
}
237
238impl Default for HttpConnectionBuilder {
239 fn default() -> Self {
240 Self::new()
241 }
242}
243
/// A single HTTP connection over a stream `S`.
///
/// `poisoned` is set while a request/response exchange is in flight and is
/// cleared only once a response has been read successfully; a poisoned
/// connection rejects further `send` calls.
///
/// `Debug` is derived so the connection can appear in logs and in `Debug`
/// output of types that embed it (available whenever `S: Debug`).
#[derive(Debug)]
pub struct HttpConnection<S> {
    /// Underlying transport.
    stream: S,
    /// `true` when the stream may be mid-exchange and must not be reused.
    poisoned: bool,
}
277
278impl<S: WireStream + Unpin> HttpConnection<S> {
281 pub fn new(stream: S) -> Self {
283 Self {
284 stream,
285 poisoned: false,
286 }
287 }
288
289 #[must_use]
291 pub fn builder() -> HttpConnectionBuilder {
292 HttpConnectionBuilder::new()
293 }
294
295 #[allow(clippy::needless_pass_by_value)] pub async fn send<'r>(
301 &mut self,
302 req: Request<'_>,
303 reader: &'r mut ResponseReader,
304 ) -> Result<RestResponse<'r>, RestError> {
305 if self.poisoned {
306 return Err(RestError::ConnectionPoisoned);
307 }
308
309 self.poisoned = true;
317
318 if let Err(e) = write_all_async(&mut self.stream, req.as_bytes()).await {
320 return Err(RestError::Io(e));
321 }
322 if let Err(e) = flush_async(&mut self.stream).await {
323 return Err(RestError::Io(e));
324 }
325
326 let resp = match self.read_response(reader).await {
328 Ok(resp) => resp,
329 Err(e) => return Err(self.diagnose_error(e)),
330 };
331
332 self.poisoned = false;
334 Ok(resp)
335 }
336
337 pub fn is_poisoned(&self) -> bool {
339 self.poisoned
340 }
341
342 #[cold]
348 #[allow(clippy::unused_self)] fn diagnose_error(&self, err: RestError) -> RestError {
350 if let RestError::Io(ref io_err) = err
351 && (io_err.kind() == std::io::ErrorKind::TimedOut
352 || io_err.kind() == std::io::ErrorKind::WouldBlock)
353 {
354 return RestError::ConnectionStale;
355 }
356 err
357 }
358
359 pub fn stream(&self) -> &S {
361 &self.stream
362 }
363
364 pub fn stream_mut(&mut self) -> &mut S {
366 &mut self.stream
367 }
368
369 async fn read_response<'r>(
374 &mut self,
375 reader: &'r mut ResponseReader,
376 ) -> Result<RestResponse<'r>, RestError> {
377 reader.consume_response();
378
379 loop {
382 match reader.next() {
383 Ok(Some(_)) => break,
384 Ok(None) => {}
385 Err(e) => {
386 self.poisoned = true;
387 return Err(e.into());
388 }
389 }
390 if reader.spare().is_empty() {
395 self.poisoned = true;
396 return Err(RestError::Http(HttpError::Malformed(
397 "response head exceeds reader capacity",
398 )));
399 }
400 match fill_async(&mut self.stream, reader, HTTP_HANDSHAKE_BUFFER).await {
401 Ok(0) => {
402 self.poisoned = true;
403 return Err(RestError::ConnectionClosed(
404 "server closed before response headers",
405 ));
406 }
407 Ok(_) => {}
408 Err(e) => {
409 self.poisoned = true;
410 return Err(RestError::Io(e));
411 }
412 }
413 }
414
415 let status = reader.status();
417
418 if matches!(status, 100..=199 | 204 | 304) {
419 reader.set_body_consumed(0);
420 return Ok(RestResponse::new(status, 0, reader));
421 }
422
423 if reader.is_chunked() {
424 let body = self.read_chunked_body(reader).await?;
425 reader.set_body_consumed(reader.body_remaining());
426 return Ok(RestResponse::new_chunked(status, body, reader));
427 }
428
429 let content_length = match reader.content_length() {
430 Some(Ok(n)) => n,
431 Some(Err(())) => {
432 return Err(RestError::Http(HttpError::Malformed(
433 "invalid Content-Length header",
434 )));
435 }
436 None => {
437 self.poisoned = true;
440 return Err(RestError::Http(HttpError::Malformed(
441 "no Content-Length and not chunked",
442 )));
443 }
444 };
445
446 let max_body = reader.max_body_size_limit();
447 if max_body > 0 && content_length > max_body {
448 self.poisoned = true;
449 return Err(RestError::BodyTooLarge {
450 size: content_length,
451 max: max_body,
452 });
453 }
454
455 while reader.body_remaining() < content_length {
457 if reader.spare().is_empty() {
461 self.poisoned = true;
462 let needed = content_length - reader.body_remaining();
463 return Err(RestError::Http(HttpError::BufferFull {
464 needed,
465 available: 0,
466 }));
467 }
468 match fill_async(&mut self.stream, reader, HTTP_HANDSHAKE_BUFFER).await {
469 Ok(0) => {
470 self.poisoned = true;
471 return Err(RestError::ConnectionClosed(
472 "server closed during body read",
473 ));
474 }
475 Ok(_) => {}
476 Err(e) => {
477 self.poisoned = true;
478 return Err(RestError::Io(e));
479 }
480 }
481 }
482
483 reader.set_body_consumed(content_length);
484 Ok(RestResponse::new(status, content_length, reader))
485 }
486
487 async fn read_chunked_body(&mut self, reader: &ResponseReader) -> Result<Vec<u8>, RestError> {
488 use nexus_net::http::ChunkedDecoder;
489
490 let max_body = reader.max_body_size_limit();
491 let mut decoder = ChunkedDecoder::new();
492 let mut body = Vec::with_capacity(HTTP_HANDSHAKE_BUFFER);
493 let mut wire_buf = [0u8; HTTP_HANDSHAKE_BUFFER];
494 let mut decode_buf = [0u8; HTTP_HANDSHAKE_BUFFER];
495
496 let remainder = reader.remainder();
498 if !remainder.is_empty() {
499 let mut pos = 0;
500 while pos < remainder.len() && !decoder.is_done() {
501 let (consumed, produced) = decoder
502 .decode(&remainder[pos..], &mut decode_buf)
503 .map_err(RestError::Http)?;
504 pos += consumed;
505 if produced > 0 {
506 body.extend_from_slice(&decode_buf[..produced]);
507 if max_body > 0 && body.len() > max_body {
508 self.poisoned = true;
509 return Err(RestError::BodyTooLarge {
510 size: body.len(),
511 max: max_body,
512 });
513 }
514 }
515 if consumed == 0 && produced == 0 {
516 break;
517 }
518 }
519 }
520
521 while !decoder.is_done() {
522 let mut sink = SliceSink::new(&mut wire_buf);
523 let cap = sink.spare().len();
524 let n = match fill_async(&mut self.stream, &mut sink, cap).await {
525 Ok(0) => {
526 self.poisoned = true;
527 return Err(RestError::ConnectionClosed(
528 "server closed during chunked body",
529 ));
530 }
531 Ok(n) => n,
532 Err(e) => {
533 self.poisoned = true;
534 return Err(RestError::Io(e));
535 }
536 };
537
538 let chunk = &sink.data()[..n];
539 let mut pos = 0;
540 while pos < n && !decoder.is_done() {
541 let (consumed, produced) = decoder
542 .decode(&chunk[pos..n], &mut decode_buf)
543 .map_err(RestError::Http)?;
544 pos += consumed;
545 if produced > 0 {
546 body.extend_from_slice(&decode_buf[..produced]);
547 if max_body > 0 && body.len() > max_body {
548 self.poisoned = true;
549 return Err(RestError::BodyTooLarge {
550 size: body.len(),
551 max: max_body,
552 });
553 }
554 }
555 if consumed == 0 && produced == 0 {
556 break;
557 }
558 }
559 }
560
561 Ok(body)
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AsyncReadAdapter;
    use nexus_net::rest::RequestWriter;
    use std::io::Cursor;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

    /// In-memory stream: records everything written and replays a canned
    /// response on reads. Every poll completes immediately.
    struct MockAsyncStream {
        written: Vec<u8>,
        response: Cursor<Vec<u8>>,
    }

    impl MockAsyncStream {
        fn new(response: &[u8]) -> Self {
            let response = Cursor::new(response.to_vec());
            Self {
                written: Vec::new(),
                response,
            }
        }

        /// Everything written so far, as UTF-8.
        fn written_str(&self) -> &str {
            std::str::from_utf8(&self.written).unwrap()
        }
    }

    impl AsyncRead for MockAsyncStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let read = std::io::Read::read(&mut self.response, buf.initialize_unfilled());
            Poll::Ready(read.map(|n| buf.advance(n)))
        }
    }

    impl AsyncWrite for MockAsyncStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            let accepted = buf.len();
            self.written.extend_from_slice(buf);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Builds a `200 OK` response with a matching `Content-Length` for `body`.
    fn ok_response(body: &str) -> Vec<u8> {
        let text = format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
            body.len(),
            body
        );
        text.into_bytes()
    }

    /// A GET is serialized with request line and Host header, and the
    /// response status and body come back intact.
    #[tokio::test]
    async fn async_get_request() {
        let canned = ok_response(r#"{"ok":true}"#);
        let mut conn = HttpConnection::new(AsyncReadAdapter::new(MockAsyncStream::new(&canned)));
        let mut reader = ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
        let mut writer = RequestWriter::new("api.example.com").unwrap();

        let req = writer.get("/status").finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body_str().unwrap(), r#"{"ok":true}"#);

        let sent = conn.stream().get_ref().written_str();
        assert!(sent.starts_with("GET /status HTTP/1.1\r\n"));
        assert!(sent.contains("Host: api.example.com\r\n"));
    }

    /// A POST carries a Content-Length header and ends with the body bytes.
    #[tokio::test]
    async fn async_post_with_body() {
        let canned = ok_response(r#"{"filled":true}"#);
        let mut conn = HttpConnection::new(AsyncReadAdapter::new(MockAsyncStream::new(&canned)));
        let mut reader = ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
        let mut writer = RequestWriter::new("api.example.com").unwrap();

        let payload = br#"{"symbol":"BTC","side":"buy"}"#;
        let req = writer.post("/order").body(payload).finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.status(), 200);

        let sent = conn.stream().get_ref().written_str();
        assert!(sent.contains(&format!("Content-Length: {}\r\n", payload.len())));
        assert!(sent.ends_with(std::str::from_utf8(payload).unwrap()));
    }

    /// Response headers are retrievable by name.
    #[tokio::test]
    async fn async_response_headers() {
        let canned = b"HTTP/1.1 200 OK\r\nX-Request-Id: abc\r\nContent-Length: 2\r\n\r\n{}";
        let mut conn = HttpConnection::new(AsyncReadAdapter::new(MockAsyncStream::new(canned)));
        let mut reader = ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
        let mut writer = RequestWriter::new("host").unwrap();

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.header("X-Request-Id"), Some("abc"));
    }

    /// A body cut short closes the exchange with `ConnectionClosed`, and the
    /// next send is refused with `ConnectionPoisoned`.
    #[tokio::test]
    async fn async_connection_poisoned() {
        // Declares 100 body bytes but delivers only 7.
        let canned = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\npartial";
        let mut conn = HttpConnection::new(AsyncReadAdapter::new(MockAsyncStream::new(canned)));
        let mut reader = ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
        let mut writer = RequestWriter::new("host").unwrap();

        let first = writer.get("/test").finish().unwrap();
        let outcome = conn.send(first, &mut reader).await;
        assert!(matches!(outcome, Err(RestError::ConnectionClosed(_))));

        let second = writer.get("/test2").finish().unwrap();
        let outcome = conn.send(second, &mut reader).await;
        assert!(matches!(outcome, Err(RestError::ConnectionPoisoned)));
    }

    /// A chunked response is decoded into its payload.
    #[tokio::test]
    async fn async_chunked_decoded() {
        let canned =
            b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n0\r\n\r\n";
        let mut conn = HttpConnection::new(AsyncReadAdapter::new(MockAsyncStream::new(canned)));
        let mut reader = ResponseReader::new(HTTP_HANDSHAKE_BUFFER);
        let mut writer = RequestWriter::new("host").unwrap();

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).await.unwrap();
        assert_eq!(resp.body_str().unwrap(), "hello");
    }
}