Skip to main content

nexus_web/rest/
connection.rs

1//! HTTP/1.1 keep-alive connection — pure transport.
2//!
3//! `Client<S>` is a thin I/O wrapper. It sends request bytes and
4//! reads response bytes. All protocol logic (request encoding, response
5//! parsing) lives in [`RequestWriter`](super::RequestWriter) and
6//! [`ResponseReader`](crate::http::ResponseReader).
7
8use std::time::Duration;
9
10use super::error::RestError;
11use super::request::RequestWriter;
12
13use super::request::Request;
14use super::response::RestResponse;
15use crate::http::{HttpError, ResponseReader};
16use std::io::{self, Read, Write};
17
18#[cfg(feature = "tls")]
19use nexus_net::tls::TlsConfig;
20
21// =============================================================================
22// URL parsing
23// =============================================================================
24
/// Parsed HTTP URL.
///
/// Borrows `host` and `path` from the string given to [`parse_base_url`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParsedUrl<'a> {
    /// Whether the URL is `https://` (true) or `http://` (false).
    pub tls: bool,
    /// Host portion (no port). IPv6 literals are stored without their
    /// surrounding brackets, e.g. `::1` for `http://[::1]:8080`.
    pub host: &'a str,
    /// Port — explicit if present, otherwise the scheme default
    /// (80 for http, 443 for https).
    pub port: u16,
    /// Path portion (everything after the host:port, including the
    /// leading `/`). Empty if the URL had no path.
    pub path: &'a str,
}

impl ParsedUrl<'_> {
    /// Value for the HTTP `Host` header.
    ///
    /// The port is appended only when it differs from the scheme default.
    /// IPv6 literal hosts are re-wrapped in brackets as required by
    /// RFC 7230 §5.4 / RFC 3986 §3.2.2, so host `::1` on port 8080 yields
    /// `"[::1]:8080"` rather than the ambiguous `"::1:8080"`.
    pub fn host_header(&self) -> String {
        let default_port = if self.tls { 443 } else { 80 };
        // A registered name or IPv4 address never contains ':'; a host that
        // does is an IPv6 literal whose brackets were stripped when parsing.
        let host = if self.host.contains(':') {
            format!("[{}]", self.host)
        } else {
            self.host.to_string()
        };
        if self.port == default_port {
            host
        } else {
            format!("{host}:{}", self.port)
        }
    }
}
51
52/// Parse an `http://` or `https://` URL into its scheme, host, port,
53/// and path. Returns [`RestError::InvalidUrl`] on a malformed input or
54/// missing scheme.
55pub fn parse_base_url(url: &str) -> Result<ParsedUrl<'_>, RestError> {
56    let (tls, rest) = if let Some(r) = url.strip_prefix("https://") {
57        (true, r)
58    } else if let Some(r) = url.strip_prefix("http://") {
59        (false, r)
60    } else {
61        return Err(RestError::InvalidUrl(url.to_string()));
62    };
63
64    // Split host:port from path
65    let (host_port, path) = rest
66        .find('/')
67        .map_or((rest, ""), |i| (&rest[..i], &rest[i..]));
68
69    if host_port.is_empty() {
70        return Err(RestError::InvalidUrl(format!("empty host: {url}")));
71    }
72
73    let default_port = if tls { 443 } else { 80 };
74
75    // IPv6 bracket notation: [::1]:8080
76    let (host, port) = if host_port.starts_with('[') {
77        match host_port.find(']') {
78            Some(end) => {
79                let h = &host_port[1..end];
80                let rest = &host_port[end + 1..];
81                if let Some(port_str) = rest.strip_prefix(':') {
82                    let p = port_str
83                        .parse::<u16>()
84                        .map_err(|_| RestError::InvalidUrl(format!("invalid port: {url}")))?;
85                    (h, p)
86                } else {
87                    (h, default_port)
88                }
89            }
90            None => return Err(RestError::InvalidUrl(format!("unclosed bracket: {url}"))),
91        }
92    } else {
93        match host_port.rfind(':') {
94            None => (host_port, default_port),
95            Some(i) => {
96                let port_str = &host_port[i + 1..];
97                if port_str.is_empty() {
98                    // Trailing colon with no port: "host:" → strip colon
99                    (&host_port[..i], default_port)
100                } else {
101                    let p = port_str
102                        .parse::<u16>()
103                        .map_err(|_| RestError::InvalidUrl(format!("invalid port: {url}")))?;
104                    (&host_port[..i], p)
105                }
106            }
107        }
108    };
109
110    Ok(ParsedUrl {
111        tls,
112        host,
113        port,
114        path,
115    })
116}
117
118// =============================================================================
119// ClientBuilder
120// =============================================================================
121
/// Builder for [`Client`].
///
/// Configures transport only: TLS, timeouts, socket options.
/// Protocol configuration (host, headers, base path) lives on
/// [`RequestWriter`].
///
/// Every `connect*` method consumes the builder, so it is `Clone`: keep a
/// configured template and clone it to open replacement connections.
#[derive(Clone)]
pub struct ClientBuilder {
    /// Custom TLS configuration; `None` falls back to `TlsConfig::new()`.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// Whether to set `TCP_NODELAY` on the connected socket.
    tcp_nodelay: bool,
    /// TCP connect timeout; `None` uses the OS's blocking connect.
    connect_timeout: Option<Duration>,
    /// Socket read timeout; `None` lets reads block indefinitely.
    read_timeout: Option<Duration>,
}
134
135impl ClientBuilder {
136    /// Create a new builder with defaults.
137    #[must_use]
138    pub fn new() -> Self {
139        Self {
140            #[cfg(feature = "tls")]
141            tls_config: None,
142            tcp_nodelay: false,
143            connect_timeout: None,
144            read_timeout: None,
145        }
146    }
147
148    /// Set a custom TLS configuration.
149    ///
150    /// If not set, `https://` URLs use [`TlsConfig::new()`] (system defaults).
151    #[cfg(feature = "tls")]
152    #[must_use]
153    pub fn tls(mut self, config: &TlsConfig) -> Self {
154        self.tls_config = Some(config.clone());
155        self
156    }
157
158    /// Set `TCP_NODELAY` (disable Nagle's algorithm).
159    #[must_use]
160    pub fn disable_nagle(mut self) -> Self {
161        self.tcp_nodelay = true;
162        self
163    }
164
165    /// TCP connect timeout.
166    #[must_use]
167    pub fn connect_timeout(mut self, d: Duration) -> Self {
168        self.connect_timeout = Some(d);
169        self
170    }
171
172    /// Socket read timeout.
173    #[must_use]
174    pub fn read_timeout(mut self, d: Duration) -> Self {
175        self.read_timeout = Some(d);
176        self
177    }
178
179    /// Connect to an HTTP(S) endpoint (blocking).
180    ///
181    /// TLS is auto-detected from the URL scheme. When the `tls` feature is
182    /// enabled, returns `Client<MaybeTls<TcpStream>>` — `https://` uses
183    /// `MaybeTls::Tls`, `http://` uses `MaybeTls::Plain`. Without the `tls`
184    /// feature, returns `Client<TcpStream>` and errors on `https://`.
185    #[cfg(feature = "tls")]
186    pub fn connect(
187        self,
188        url: &str,
189    ) -> Result<Client<nexus_net::MaybeTls<std::net::TcpStream>>, RestError> {
190        let parsed = parse_base_url(url)?;
191        let addr = format!("{}:{}", parsed.host, parsed.port);
192
193        let tcp = match self.connect_timeout {
194            Some(timeout) => {
195                let addrs: Vec<std::net::SocketAddr> =
196                    std::net::ToSocketAddrs::to_socket_addrs(&addr)
197                        .map_err(RestError::Io)?
198                        .collect();
199                let first = addrs
200                    .first()
201                    .ok_or_else(|| RestError::Io(io::Error::other("DNS resolution failed")))?;
202                std::net::TcpStream::connect_timeout(first, timeout)?
203            }
204            None => std::net::TcpStream::connect(&addr)?,
205        };
206
207        if self.tcp_nodelay {
208            tcp.set_nodelay(true)?;
209        }
210        if let Some(timeout) = self.read_timeout {
211            tcp.set_read_timeout(Some(timeout))?;
212        }
213
214        let stream = if parsed.tls {
215            let config = match self.tls_config {
216                Some(c) => c,
217                None => TlsConfig::new().map_err(RestError::Tls)?,
218            };
219            let codec = nexus_net::tls::TlsCodec::new(&config, parsed.host)?;
220            let tls = nexus_net::tls::TlsStream::connect(tcp, codec).map_err(RestError::Tls)?;
221            nexus_net::MaybeTls::Tls(Box::new(tls))
222        } else {
223            nexus_net::MaybeTls::Plain(tcp)
224        };
225
226        Ok(Client {
227            stream,
228            poisoned: false,
229        })
230    }
231
232    /// Connect to an HTTP(S) endpoint (blocking, no TLS feature).
233    #[cfg(not(feature = "tls"))]
234    pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, RestError> {
235        let parsed = parse_base_url(url)?;
236        if parsed.tls {
237            return Err(RestError::TlsNotEnabled);
238        }
239        let addr = format!("{}:{}", parsed.host, parsed.port);
240
241        let tcp = match self.connect_timeout {
242            Some(timeout) => {
243                let addrs: Vec<std::net::SocketAddr> =
244                    std::net::ToSocketAddrs::to_socket_addrs(&addr)
245                        .map_err(RestError::Io)?
246                        .collect();
247                let first = addrs
248                    .first()
249                    .ok_or_else(|| RestError::Io(io::Error::other("DNS resolution failed")))?;
250                std::net::TcpStream::connect_timeout(first, timeout)?
251            }
252            None => std::net::TcpStream::connect(&addr)?,
253        };
254
255        if self.tcp_nodelay {
256            tcp.set_nodelay(true)?;
257        }
258        if let Some(timeout) = self.read_timeout {
259            tcp.set_read_timeout(Some(timeout))?;
260        }
261
262        Ok(Client {
263            stream: tcp,
264            poisoned: false,
265        })
266    }
267
268    /// Connect using a pre-connected socket.
269    ///
270    /// The stream must already handle TLS if connecting to `https://`.
271    /// For example, pass a `TlsStream<TcpStream>` or `MaybeTls<TcpStream>`.
272    pub fn connect_with<S: Read + Write>(
273        self,
274        stream: S,
275        url: &str,
276    ) -> Result<Client<S>, RestError> {
277        // Validate the URL even though we don't use it for connection —
278        // catches malformed URLs early rather than at first request.
279        parse_base_url(url)?;
280        Ok(Client::new(stream))
281    }
282
283    /// Create a `RequestWriter` configured for this URL.
284    ///
285    /// Convenience: extracts host and path from the URL to create
286    /// a writer with the correct Host header and base path.
287    pub fn writer_for(url: &str) -> Result<RequestWriter, RestError> {
288        let parsed = parse_base_url(url)?;
289        let host_header = parsed.host_header();
290        let mut writer = RequestWriter::new(&host_header)?;
291        if !parsed.path.is_empty() {
292            writer.set_base_path(parsed.path)?;
293        }
294        Ok(writer)
295    }
296}
297
298impl Default for ClientBuilder {
299    fn default() -> Self {
300        Self::new()
301    }
302}
303
304// =============================================================================
305// Client — pure transport
306// =============================================================================
307
/// HTTP/1.1 keep-alive connection — pure transport.
///
/// Sends request bytes and reads response bytes. All protocol logic
/// lives in [`RequestWriter`] (request encoding) and
/// [`ResponseReader`](crate::http::ResponseReader) (response parsing).
///
/// Any I/O or framing error *poisons* the connection: afterwards `send`
/// fails with `RestError::ConnectionPoisoned`, and the caller should open
/// a new connection.
///
/// # Usage
///
/// ```ignore
/// use nexus_web::rest::{Client, RequestWriter};
/// use nexus_web::http::ResponseReader;
/// use nexus_web::tls::TlsConfig;
///
/// // Protocol (sans-IO)
/// let mut writer = RequestWriter::new("api.binance.com").unwrap();
/// let mut reader = ResponseReader::new(32 * 1024);
///
/// // Transport
/// let tls = TlsConfig::new()?;
/// let mut conn = Client::builder().tls(&tls).connect("https://api.binance.com")?;
///
/// // Build + send
/// let req = writer.get("/orders").query("symbol", "BTC").finish()?;
/// let resp = conn.send(req, &mut reader)?;
/// ```
#[derive(Debug)]
pub struct Client<S> {
    /// Underlying byte stream: plain TCP, TLS, or any caller-supplied stream.
    pub(crate) stream: S,
    /// Set once an I/O or framing error has left the connection unusable.
    pub(crate) poisoned: bool,
}
337
338impl Client<std::net::TcpStream> {
339    /// Create a transport builder.
340    #[must_use]
341    pub fn builder() -> ClientBuilder {
342        ClientBuilder::new()
343    }
344
345    /// Set read timeout on the socket.
346    ///
347    /// **Strongly recommended for production.** Without a timeout, reads
348    /// block indefinitely on stale connections.
349    pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> Result<(), RestError> {
350        self.stream.set_read_timeout(timeout).map_err(RestError::Io)
351    }
352
353    /// Set TCP keepalive on the underlying socket.
354    ///
355    /// Enables OS-level dead connection detection. The kernel sends
356    /// probes after `idle` of inactivity.
357    #[cfg(feature = "socket-opts")]
358    pub fn set_tcp_keepalive(&self, idle: std::time::Duration) -> Result<(), RestError> {
359        let sock = socket2::SockRef::from(&self.stream);
360        let keepalive = socket2::TcpKeepalive::new().with_time(idle);
361        sock.set_tcp_keepalive(&keepalive).map_err(RestError::Io)
362    }
363}
364
365// -- Unbounded impl: accessors and constructors that need no I/O traits -------
366
367impl<S> Client<S> {
368    /// Wrap a pre-connected stream.
369    pub fn new(stream: S) -> Self {
370        Self {
371            stream,
372            poisoned: false,
373        }
374    }
375
376    /// Whether the connection is poisoned (I/O error occurred).
377    pub fn is_poisoned(&self) -> bool {
378        self.poisoned
379    }
380
381    /// Access the underlying stream.
382    pub fn stream(&self) -> &S {
383        &self.stream
384    }
385
386    /// Mutable access to the underlying stream.
387    pub fn stream_mut(&mut self) -> &mut S {
388        &mut self.stream
389    }
390}
391
392// -- Blocking I/O impl --------------------------------------------------------
393
394impl<S: Read + Write> Client<S> {
395    /// Send a request and read the response.
396    ///
397    /// `req` provides the outbound bytes (from [`RequestWriter`]).
398    /// `reader` receives and parses the response (body size limit
399    /// configured on the reader via [`ResponseReader::max_body_size`]).
400    ///
401    /// Read timeout is a stream-level concern — configure via the builder
402    /// (`read_timeout`) or `conn.stream().set_read_timeout()` for
403    /// `TcpStream`. Without a timeout, reads block indefinitely.
404    ///
405    /// `Response` borrows from `reader` — drop before next send.
406    #[allow(clippy::needless_pass_by_value)] // Move by design — request is consumed after send.
407    pub fn send<'r>(
408        &mut self,
409        req: Request<'_>,
410        reader: &'r mut ResponseReader,
411    ) -> Result<RestResponse<'r>, RestError> {
412        if self.poisoned {
413            return Err(RestError::ConnectionPoisoned);
414        }
415
416        // Send request bytes
417        if let Err(e) = self.write_all(req.as_bytes()) {
418            self.poisoned = true;
419            return Err(e);
420        }
421
422        // Read response
423        match self.read_response(reader) {
424            Ok(resp) => Ok(resp),
425            Err(e) => self.handle_send_error(e),
426        }
427    }
428
429    /// Cold path: diagnose send failure.
430    #[cold]
431    fn handle_send_error<T>(&mut self, err: RestError) -> Result<T, RestError> {
432        self.poisoned = true;
433        // On timeout, check if the socket is actually dead (stale connection)
434        // vs the server just being slow.
435        if let RestError::Io(ref io_err) = err
436            && (io_err.kind() == std::io::ErrorKind::TimedOut
437                || io_err.kind() == std::io::ErrorKind::WouldBlock)
438        {
439            if self.peek_is_dead() {
440                return Err(RestError::ConnectionStale);
441            }
442            return Err(RestError::ReadTimeout);
443        }
444        Err(err)
445    }
446
447    /// Check if the socket has been closed by the peer.
448    ///
449    /// For generic streams we can't peek, so we assume alive and
450    /// report `ReadTimeout`. The connection is poisoned either way.
451    #[allow(clippy::unused_self)]
452    fn peek_is_dead(&self) -> bool {
453        // For generic S, assume alive (report ReadTimeout not ConnectionStale).
454        // The caller still gets an error; it's just less specific.
455        false
456    }
457
458    // =========================================================================
459    // Internal — I/O with optional TLS
460    // =========================================================================
461
462    fn write_all(&mut self, data: &[u8]) -> Result<(), RestError> {
463        self.stream.write_all(data)?;
464        self.stream.flush()?;
465        Ok(())
466    }
467
468    fn read_into_reader(&mut self, reader: &mut ResponseReader) -> Result<usize, RestError> {
469        let n = reader.read_from(&mut self.stream)?;
470        Ok(n)
471    }
472
473    fn read_response<'r>(
474        &mut self,
475        reader: &'r mut ResponseReader,
476    ) -> Result<RestResponse<'r>, RestError> {
477        // Consume previous response, preserving pipelined bytes.
478        reader.consume_response();
479
480        // Read until headers are complete.
481        loop {
482            match reader.next() {
483                Ok(Some(_)) => break,
484                Ok(None) => {}
485                Err(e) => {
486                    self.poisoned = true;
487                    return Err(e.into());
488                }
489            }
490            match self.read_into_reader(reader) {
491                Ok(0) => {
492                    self.poisoned = true;
493                    return Err(RestError::ConnectionClosed(
494                        "server closed before response headers",
495                    ));
496                }
497                Ok(_) => {}
498                Err(e) => {
499                    self.poisoned = true;
500                    return Err(e);
501                }
502            }
503        }
504
505        // Validate using cached values from try_parse.
506        let status = reader.status();
507
508        // RFC 7230: 1xx, 204, 304 have no body.
509        if matches!(status, 100..=199 | 204 | 304) {
510            reader.set_body_consumed(0);
511            return Ok(RestResponse::new(status, 0, reader));
512        }
513
514        if reader.is_chunked() {
515            let body = self.read_chunked_body(reader)?;
516            // All remainder bytes were consumed (decoded or framing),
517            // plus whatever was read from the socket during decode.
518            // For consume_response, we need the total raw bytes in the
519            // reader's buffer that belong to this response's body.
520            // Since chunked body goes into a Vec (not the reader), the
521            // remainder bytes are all raw chunked wire data that should
522            // be skipped on consume.
523            reader.set_body_consumed(reader.body_remaining());
524            return Ok(RestResponse::new_chunked(status, body, reader));
525        }
526
527        let content_length = match reader.content_length() {
528            Some(Ok(n)) => n,
529            Some(Err(())) => {
530                return Err(RestError::Http(HttpError::Malformed(
531                    "invalid Content-Length header",
532                )));
533            }
534            None => {
535                // No Content-Length and not chunked — can't determine body
536                // boundaries for keep-alive. Error instead of silent empty body.
537                self.poisoned = true;
538                return Err(RestError::Http(HttpError::Malformed(
539                    "no Content-Length and not chunked",
540                )));
541            }
542        };
543
544        let max_body = reader.max_body_size_limit();
545        if max_body > 0 && content_length > max_body {
546            self.poisoned = true;
547            return Err(RestError::BodyTooLarge {
548                size: content_length,
549                max: max_body,
550            });
551        }
552
553        // Read remaining body bytes (Content-Length delimited).
554        while reader.body_remaining() < content_length {
555            match self.read_into_reader(reader) {
556                Ok(0) => {
557                    self.poisoned = true;
558                    return Err(RestError::ConnectionClosed(
559                        "server closed during body read",
560                    ));
561                }
562                Ok(_) => {}
563                Err(e) => {
564                    self.poisoned = true;
565                    return Err(e);
566                }
567            }
568        }
569
570        reader.set_body_consumed(content_length);
571        Ok(RestResponse::new(status, content_length, reader))
572    }
573
574    /// Read a chunked transfer-encoded body. Returns decoded body bytes.
575    ///
576    /// One allocation: the Vec for the decoded body. The chunk framing
577    /// is stripped and only payload bytes are accumulated.
578    fn read_chunked_body(&mut self, reader: &ResponseReader) -> Result<Vec<u8>, RestError> {
579        use crate::http::ChunkedDecoder;
580
581        let max_body = reader.max_body_size_limit();
582        let mut decoder = ChunkedDecoder::new();
583        let mut body = Vec::with_capacity(4096);
584        let mut wire_buf = [0u8; 4096];
585        let mut decode_buf = [0u8; 4096];
586
587        // Decode any chunk data that arrived with the headers.
588        let remainder = reader.remainder();
589        if !remainder.is_empty() {
590            let mut pos = 0;
591            while pos < remainder.len() && !decoder.is_done() {
592                let (consumed, produced) = decoder
593                    .decode(&remainder[pos..], &mut decode_buf)
594                    .map_err(RestError::Http)?;
595                pos += consumed;
596                if produced > 0 {
597                    body.extend_from_slice(&decode_buf[..produced]);
598                    if max_body > 0 && body.len() > max_body {
599                        self.poisoned = true;
600                        return Err(RestError::BodyTooLarge {
601                            size: body.len(),
602                            max: max_body,
603                        });
604                    }
605                }
606                if consumed == 0 && produced == 0 {
607                    break;
608                }
609            }
610        }
611
612        // Read from socket until all chunks decoded.
613        while !decoder.is_done() {
614            let n = self.read_wire_bytes(&mut wire_buf)?;
615            if n == 0 {
616                self.poisoned = true;
617                return Err(RestError::ConnectionClosed(
618                    "server closed during chunked body",
619                ));
620            }
621
622            let mut pos = 0;
623            while pos < n && !decoder.is_done() {
624                let (consumed, produced) = decoder
625                    .decode(&wire_buf[pos..n], &mut decode_buf)
626                    .map_err(RestError::Http)?;
627                pos += consumed;
628                if produced > 0 {
629                    body.extend_from_slice(&decode_buf[..produced]);
630                    // Check body size limit after each decode, not per read.
631                    if max_body > 0 && body.len() > max_body {
632                        self.poisoned = true;
633                        return Err(RestError::BodyTooLarge {
634                            size: body.len(),
635                            max: max_body,
636                        });
637                    }
638                }
639                if consumed == 0 && produced == 0 {
640                    break;
641                }
642            }
643        }
644
645        Ok(body)
646    }
647
648    /// Read raw bytes from the socket.
649    fn read_wire_bytes(&mut self, buf: &mut [u8]) -> Result<usize, RestError> {
650        Ok(self.stream.read(buf)?)
651    }
652}
653
654// =============================================================================
655// Tests
656// =============================================================================
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use std::io::{Cursor, Read, Write};
662    use std::net::{TcpListener, TcpStream};
663
664    struct MockStream {
665        written: Vec<u8>,
666        response: Cursor<Vec<u8>>,
667    }
668
669    impl MockStream {
670        fn new(response: &[u8]) -> Self {
671            Self {
672                written: Vec::new(),
673                response: Cursor::new(response.to_vec()),
674            }
675        }
676
677        fn written_str(&self) -> &str {
678            std::str::from_utf8(&self.written).unwrap()
679        }
680    }
681
682    impl Read for MockStream {
683        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
684            self.response.read(buf)
685        }
686    }
687
688    impl Write for MockStream {
689        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
690            self.written.extend_from_slice(buf);
691            Ok(buf.len())
692        }
693        fn flush(&mut self) -> io::Result<()> {
694            Ok(())
695        }
696    }
697
698    fn ok_response(body: &str) -> Vec<u8> {
699        format!(
700            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
701            body.len(),
702            body
703        )
704        .into_bytes()
705    }
706
707    /// Helper: build request + send via mock.
708    #[allow(dead_code)]
709    fn send_get<'r>(
710        writer: &mut RequestWriter,
711        conn: &mut Client<MockStream>,
712        reader: &'r mut ResponseReader,
713        path: &str,
714    ) -> Result<RestResponse<'r>, RestError> {
715        let req = writer.get(path).finish()?;
716        conn.send(req, reader)
717    }
718
719    // --- Request format ---
720
721    #[test]
722    fn get_request_format() {
723        let resp = ok_response(r#"{"ok":true}"#);
724        let mock = MockStream::new(&resp);
725        let mut writer = RequestWriter::new("api.example.com").unwrap();
726        let mut reader = ResponseReader::new(4096);
727        let mut conn = Client::new(mock);
728
729        let req = writer.get("/api/v1/status").finish().unwrap();
730        let resp = conn.send(req, &mut reader).unwrap();
731        assert_eq!(resp.status(), 200);
732        assert_eq!(resp.body_str().unwrap(), r#"{"ok":true}"#);
733
734        let written = conn.stream().written_str();
735        assert!(written.starts_with("GET /api/v1/status HTTP/1.1\r\n"));
736        assert!(written.contains("Host: api.example.com\r\n"));
737        assert!(written.contains("Connection: keep-alive\r\n"));
738        assert!(written.ends_with("\r\n\r\n"));
739    }
740
741    #[test]
742    fn post_with_body() {
743        let resp = ok_response(r#"{"filled":true}"#);
744        let mock = MockStream::new(&resp);
745        let mut writer = RequestWriter::new("api.example.com").unwrap();
746        let mut reader = ResponseReader::new(4096);
747        let mut conn = Client::new(mock);
748
749        let body = br#"{"symbol":"BTC","side":"buy"}"#;
750        let req = writer.post("/api/v3/order").body(body).finish().unwrap();
751        let resp = conn.send(req, &mut reader).unwrap();
752        assert_eq!(resp.status(), 200);
753
754        let written = conn.stream().written_str();
755        assert!(written.starts_with("POST /api/v3/order HTTP/1.1\r\n"));
756        assert!(written.contains(&format!("Content-Length: {}\r\n", body.len())));
757        assert!(written.ends_with(std::str::from_utf8(body).unwrap()));
758    }
759
760    #[test]
761    fn post_body_writer() {
762        let resp = ok_response(r#"{"ok":true}"#);
763        let mock = MockStream::new(&resp);
764        let mut writer = RequestWriter::new("host").unwrap();
765        let mut reader = ResponseReader::new(4096);
766        let mut conn = Client::new(mock);
767
768        let body = br#"{"symbol":"BTC","side":"buy"}"#;
769        let req = writer
770            .post("/order")
771            .body_writer(|w| {
772                use std::io::Write;
773                w.write_all(body)
774            })
775            .finish()
776            .unwrap();
777
778        let written_before = std::str::from_utf8(req.as_bytes()).unwrap().to_string();
779        // Verify Content-Length is backfilled correctly (exact digits)
780        assert!(written_before.contains("Content-Length:"));
781        assert!(written_before.contains(&format!("{}", body.len())));
782        assert!(written_before.ends_with(std::str::from_utf8(body).unwrap()));
783
784        let resp = conn.send(req, &mut reader).unwrap();
785        assert_eq!(resp.status(), 200);
786    }
787
788    #[test]
789    fn body_writer_from_headers_phase() {
790        let mut writer = RequestWriter::new("host").unwrap();
791        let body = b"test-body";
792        let req = writer
793            .post("/order")
794            .header("X-Custom", "val")
795            .body_writer(|w| {
796                use std::io::Write;
797                w.write_all(body)
798            })
799            .finish()
800            .unwrap();
801
802        let data = std::str::from_utf8(req.as_bytes()).unwrap();
803        assert!(data.contains("X-Custom: val\r\n"));
804        assert!(data.contains(&format!("{}", body.len())));
805        assert!(data.ends_with("test-body"));
806    }
807
808    #[test]
809    fn body_writer_empty() {
810        let mut writer = RequestWriter::new("host").unwrap();
811        let req = writer
812            .post("/order")
813            .body_writer(|_w| Ok::<(), std::io::Error>(()))
814            .finish()
815            .unwrap();
816
817        let data = std::str::from_utf8(req.as_bytes()).unwrap();
818        // Content-Length should be 0
819        assert!(data.contains("Content-Length:"));
820        assert!(data.contains("0\r\n\r\n"));
821    }
822
823    #[test]
824    fn body_writer_matches_body() {
825        // Verify body_writer produces identical wire bytes to body()
826        let mut writer1 = RequestWriter::new("host").unwrap();
827        let mut writer2 = RequestWriter::new("host").unwrap();
828
829        let body = b"identical-content";
830
831        let req1 = writer1.post("/test").body(body).finish().unwrap();
832        let req2 = writer2
833            .post("/test")
834            .body_writer(|w| {
835                use std::io::Write;
836                w.write_all(body)
837            })
838            .finish()
839            .unwrap();
840
841        // Both paths produce identical wire format.
842        let d1 = std::str::from_utf8(req1.as_bytes()).unwrap();
843        let d2 = std::str::from_utf8(req2.as_bytes()).unwrap();
844        assert_eq!(d1, d2);
845    }
846
847    #[test]
848    fn all_methods() {
849        for (method, expected) in [
850            (super::super::request::Method::Put, "PUT"),
851            (super::super::request::Method::Delete, "DELETE"),
852            (super::super::request::Method::Patch, "PATCH"),
853        ] {
854            let resp = ok_response("{}");
855            let mock = MockStream::new(&resp);
856            let mut writer = RequestWriter::new("host").unwrap();
857            let mut reader = ResponseReader::new(4096);
858            let mut conn = Client::new(mock);
859
860            let req = writer.request(method, "/test").finish().unwrap();
861            let _ = conn.send(req, &mut reader).unwrap();
862            assert!(
863                conn.stream()
864                    .written_str()
865                    .starts_with(&format!("{expected} /test HTTP/1.1\r\n"))
866            );
867        }
868    }
869
870    #[test]
871    fn default_headers_included() {
872        let resp = ok_response("{}");
873        let mock = MockStream::new(&resp);
874        let mut writer = RequestWriter::new("api.example.com").unwrap();
875        writer.default_header("X-API-KEY", "secret123").unwrap();
876        writer
877            .default_header("Content-Type", "application/json")
878            .unwrap();
879        let mut reader = ResponseReader::new(4096);
880        let mut conn = Client::new(mock);
881
882        let req = writer.get("/test").finish().unwrap();
883        let _ = conn.send(req, &mut reader).unwrap();
884
885        let written = conn.stream().written_str();
886        assert!(written.contains("X-API-KEY: secret123\r\n"));
887        assert!(written.contains("Content-Type: application/json\r\n"));
888    }
889
890    #[test]
891    fn extra_headers() {
892        let resp = ok_response("{}");
893        let mock = MockStream::new(&resp);
894        let mut writer = RequestWriter::new("api.example.com").unwrap();
895        let mut reader = ResponseReader::new(4096);
896        let mut conn = Client::new(mock);
897
898        let req = writer
899            .get("/test")
900            .header("X-Custom", "value1")
901            .header("Authorization", "Bearer tok")
902            .finish()
903            .unwrap();
904        let _ = conn.send(req, &mut reader).unwrap();
905
906        let written = conn.stream().written_str();
907        assert!(written.contains("X-Custom: value1\r\n"));
908        assert!(written.contains("Authorization: Bearer tok\r\n"));
909    }
910
911    // --- Query parameters ---
912
913    #[test]
914    fn query_params_encoded() {
915        let mut writer = RequestWriter::new("host").unwrap();
916        let req = writer
917            .get("/orders")
918            .query("symbol", "BTC-USD")
919            .query("limit", "100")
920            .finish()
921            .unwrap();
922        let data = std::str::from_utf8(req.as_bytes()).unwrap();
923        assert!(data.starts_with("GET /orders?symbol=BTC-USD&limit=100 HTTP/1.1\r\n"));
924    }
925
926    #[test]
927    fn query_encodes_special_chars() {
928        let mut writer = RequestWriter::new("host").unwrap();
929        let req = writer
930            .get("/search")
931            .query("q", "hello world&more=yes")
932            .finish()
933            .unwrap();
934        let data = std::str::from_utf8(req.as_bytes()).unwrap();
935        assert!(data.starts_with("GET /search?q=hello%20world%26more%3Dyes HTTP/1.1\r\n"));
936    }
937
938    #[test]
939    fn query_raw_no_encoding() {
940        let mut writer = RequestWriter::new("host").unwrap();
941        let req = writer
942            .get("/orders")
943            .query_raw("symbol", "BTC-USD")
944            .finish()
945            .unwrap();
946        let data = std::str::from_utf8(req.as_bytes()).unwrap();
947        assert!(data.starts_with("GET /orders?symbol=BTC-USD HTTP/1.1\r\n"));
948    }
949
950    #[test]
951    fn query_then_header() {
952        let mut writer = RequestWriter::new("host").unwrap();
953        let req = writer
954            .get("/orders")
955            .query("sym", "ETH")
956            .header("X-Nonce", "123")
957            .finish()
958            .unwrap();
959        let data = std::str::from_utf8(req.as_bytes()).unwrap();
960        assert!(data.starts_with("GET /orders?sym=ETH HTTP/1.1\r\n"));
961        assert!(data.contains("X-Nonce: 123\r\n"));
962    }
963
964    #[test]
965    fn path_with_existing_query() {
966        let mut writer = RequestWriter::new("host").unwrap();
967        let req = writer
968            .get("/path?existing=true")
969            .query("extra", "val")
970            .finish()
971            .unwrap();
972        let data = std::str::from_utf8(req.as_bytes()).unwrap();
973        assert!(data.starts_with("GET /path?existing=true&extra=val HTTP/1.1\r\n"));
974    }
975
976    #[test]
977    fn base_path_prepended() {
978        let mut writer = RequestWriter::new("host").unwrap();
979        writer.set_base_path("/api/v3").unwrap();
980        let req = writer.get("/orders").finish().unwrap();
981        let data = std::str::from_utf8(req.as_bytes()).unwrap();
982        assert!(data.starts_with("GET /api/v3/orders HTTP/1.1\r\n"));
983    }
984
985    #[test]
986    fn get_raw_skips_query_phase() {
987        let mut writer = RequestWriter::new("host").unwrap();
988        let req = writer
989            .get_raw("/orders?symbol=BTC&limit=100")
990            .finish()
991            .unwrap();
992        let data = std::str::from_utf8(req.as_bytes()).unwrap();
993        assert!(data.starts_with("GET /orders?symbol=BTC&limit=100 HTTP/1.1\r\n"));
994    }
995
996    // --- Validation ---
997
998    #[test]
999    fn crlf_in_header_rejected() {
1000        let mut writer = RequestWriter::new("host").unwrap();
1001        let result = writer.get("/test").header("X-Bad\r\n", "val").finish();
1002        assert!(matches!(result, Err(RestError::CrlfInjection)));
1003    }
1004
1005    #[test]
1006    fn crlf_in_path_rejected() {
1007        let mut writer = RequestWriter::new("host").unwrap();
1008        let result = writer.get("/path\r\nEvil: yes").finish();
1009        assert!(matches!(result, Err(RestError::CrlfInjection)));
1010    }
1011
1012    #[test]
1013    fn crlf_in_default_header_rejected() {
1014        let mut writer = RequestWriter::new("host").unwrap();
1015        assert!(matches!(
1016            writer.default_header("X-Bad\n", "val"),
1017            Err(RestError::CrlfInjection)
1018        ));
1019    }
1020
1021    #[test]
1022    fn crlf_in_query_raw_rejected() {
1023        let mut writer = RequestWriter::new("host").unwrap();
1024        let result = writer.get("/test").query_raw("k", "v\r\n").finish();
1025        assert!(matches!(result, Err(RestError::CrlfInjection)));
1026    }
1027
1028    #[test]
1029    fn crlf_in_host_rejected() {
1030        assert!(matches!(
1031            RequestWriter::new("evil.com\r\nX-Injected: yes"),
1032            Err(RestError::CrlfInjection)
1033        ));
1034    }
1035
1036    // --- Response handling ---
1037
1038    #[test]
1039    fn response_headers_accessible() {
1040        let resp_bytes = b"HTTP/1.1 200 OK\r\nX-Request-Id: abc123\r\nX-RateLimit-Remaining: 42\r\nContent-Length: 2\r\n\r\n{}";
1041        let mock = MockStream::new(resp_bytes);
1042        let mut writer = RequestWriter::new("host").unwrap();
1043        let mut reader = ResponseReader::new(4096);
1044        let mut conn = Client::new(mock);
1045
1046        let req = writer.get("/test").finish().unwrap();
1047        let resp = conn.send(req, &mut reader).unwrap();
1048        assert_eq!(resp.header("X-Request-Id"), Some("abc123"));
1049        assert_eq!(resp.header("X-RateLimit-Remaining"), Some("42"));
1050    }
1051
1052    #[test]
1053    fn chunked_encoding_decoded() {
1054        let resp_bytes = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nMozilla\r\n11\r\nDeveloper Network\r\n0\r\n\r\n";
1055        let mock = MockStream::new(resp_bytes);
1056        let mut writer = RequestWriter::new("host").unwrap();
1057        let mut reader = ResponseReader::new(4096);
1058        let mut conn = Client::new(mock);
1059
1060        let req = writer.get("/test").finish().unwrap();
1061        let resp = conn.send(req, &mut reader).unwrap();
1062        assert_eq!(resp.status(), 200);
1063        assert_eq!(resp.body_str().unwrap(), "MozillaDeveloper Network");
1064    }
1065
1066    #[test]
1067    fn chunked_empty_body() {
1068        let resp_bytes = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n";
1069        let mock = MockStream::new(resp_bytes);
1070        let mut writer = RequestWriter::new("host").unwrap();
1071        let mut reader = ResponseReader::new(4096);
1072        let mut conn = Client::new(mock);
1073
1074        let req = writer.get("/test").finish().unwrap();
1075        let resp = conn.send(req, &mut reader).unwrap();
1076        assert_eq!(resp.body().len(), 0);
1077    }
1078
1079    #[test]
1080    fn chunked_json_response() {
1081        // Simulates a CDN/proxy chunking a JSON response
1082        let body = r#"{"orderId":12345,"status":"FILLED"}"#;
1083        let chunked = format!(
1084            "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n",
1085            body.len(),
1086            body
1087        );
1088        let mock = MockStream::new(chunked.as_bytes());
1089        let mut writer = RequestWriter::new("host").unwrap();
1090        let mut reader = ResponseReader::new(4096);
1091        let mut conn = Client::new(mock);
1092
1093        let req = writer.get("/test").finish().unwrap();
1094        let resp = conn.send(req, &mut reader).unwrap();
1095        assert_eq!(resp.body_str().unwrap(), body);
1096    }
1097
1098    #[test]
1099    fn malformed_content_length_rejected() {
1100        let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: abc\r\n\r\nbody";
1101        let mock = MockStream::new(resp_bytes);
1102        let mut writer = RequestWriter::new("host").unwrap();
1103        let mut reader = ResponseReader::new(4096);
1104        let mut conn = Client::new(mock);
1105
1106        let req = writer.get("/test").finish().unwrap();
1107        let result = conn.send(req, &mut reader);
1108        assert!(matches!(result, Err(RestError::Http(_))));
1109    }
1110
1111    #[test]
1112    fn body_too_large_rejected() {
1113        let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: 999999\r\n\r\n";
1114        let mock = MockStream::new(resp_bytes);
1115        let mut writer = RequestWriter::new("host").unwrap();
1116        let mut reader = ResponseReader::new(4096).max_body_size(32 * 1024);
1117        let mut conn = Client::new(mock);
1118
1119        let req = writer.get("/test").finish().unwrap();
1120        let result = conn.send(req, &mut reader);
1121        assert!(matches!(
1122            result,
1123            Err(RestError::BodyTooLarge { size: 999_999, .. })
1124        ));
1125    }
1126
1127    #[test]
1128    fn status_204_no_body() {
1129        let resp_bytes = b"HTTP/1.1 204 No Content\r\nContent-Length: 5\r\n\r\nxxxxx";
1130        let mock = MockStream::new(resp_bytes);
1131        let mut writer = RequestWriter::new("host").unwrap();
1132        let mut reader = ResponseReader::new(4096);
1133        let mut conn = Client::new(mock);
1134
1135        let req = writer.get("/test").finish().unwrap();
1136        let resp = conn.send(req, &mut reader).unwrap();
1137        assert_eq!(resp.status(), 204);
1138        assert_eq!(resp.body().len(), 0);
1139    }
1140
1141    #[test]
1142    fn connection_poisoned_after_io_error() {
1143        let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\npartial";
1144        let mock = MockStream::new(resp_bytes);
1145        let mut writer = RequestWriter::new("host").unwrap();
1146        let mut reader = ResponseReader::new(4096);
1147        let mut conn = Client::new(mock);
1148
1149        let req = writer.get("/test").finish().unwrap();
1150        let result = conn.send(req, &mut reader);
1151        assert!(matches!(result, Err(RestError::ConnectionClosed(_))));
1152
1153        let req = writer.get("/test2").finish().unwrap();
1154        let result = conn.send(req, &mut reader);
1155        assert!(matches!(result, Err(RestError::ConnectionPoisoned)));
1156    }
1157
1158    // --- URL parsing ---
1159
1160    #[test]
1161    fn url_parsing() {
1162        let parsed = parse_base_url("https://api.binance.com").unwrap();
1163        assert!(parsed.tls);
1164        assert_eq!(parsed.host, "api.binance.com");
1165        assert_eq!(parsed.port, 443);
1166        assert_eq!(parsed.path, "");
1167
1168        let parsed = parse_base_url("http://localhost:8080").unwrap();
1169        assert!(!parsed.tls);
1170        assert_eq!(parsed.host, "localhost");
1171        assert_eq!(parsed.port, 8080);
1172
1173        let parsed = parse_base_url("https://api.example.com/v1/foo").unwrap();
1174        assert_eq!(parsed.path, "/v1/foo");
1175
1176        assert!(parse_base_url("ftp://host").is_err());
1177        assert!(parse_base_url("http://").is_err());
1178    }
1179
1180    #[test]
1181    fn ipv6_url_parsing() {
1182        let parsed = parse_base_url("http://[::1]:8080").unwrap();
1183        assert_eq!(parsed.host, "::1");
1184        assert_eq!(parsed.port, 8080);
1185
1186        let parsed = parse_base_url("http://[::1]").unwrap();
1187        assert_eq!(parsed.host, "::1");
1188        assert_eq!(parsed.port, 80);
1189
1190        assert!(parse_base_url("http://[::1").is_err());
1191    }
1192
1193    // --- Keep-alive ---
1194
1195    #[test]
1196    fn keep_alive_sequential_requests() {
1197        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1198        let addr = listener.local_addr().unwrap();
1199
1200        let server = std::thread::spawn(move || {
1201            let (mut tcp, _) = listener.accept().unwrap();
1202            let mut buf = [0u8; 4096];
1203
1204            let n = tcp.read(&mut buf).unwrap();
1205            assert!(
1206                std::str::from_utf8(&buf[..n])
1207                    .unwrap()
1208                    .contains("GET /first")
1209            );
1210            let body1 = r#"{"id":1}"#;
1211            let resp1 = format!(
1212                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
1213                body1.len(),
1214                body1
1215            );
1216            tcp.write_all(resp1.as_bytes()).unwrap();
1217
1218            let n = tcp.read(&mut buf).unwrap();
1219            assert!(
1220                std::str::from_utf8(&buf[..n])
1221                    .unwrap()
1222                    .contains("GET /second")
1223            );
1224            let body2 = r#"{"id":2}"#;
1225            let resp2 = format!(
1226                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
1227                body2.len(),
1228                body2
1229            );
1230            tcp.write_all(resp2.as_bytes()).unwrap();
1231        });
1232
1233        let tcp = TcpStream::connect(addr).unwrap();
1234        let mut writer = RequestWriter::new("localhost").unwrap();
1235        let mut reader = ResponseReader::new(4096);
1236        let mut conn = Client::new(tcp);
1237
1238        let req = writer.get("/first").finish().unwrap();
1239        let resp = conn.send(req, &mut reader).unwrap();
1240        assert_eq!(resp.body_str().unwrap(), r#"{"id":1}"#);
1241        drop(resp);
1242
1243        let req = writer.get("/second").finish().unwrap();
1244        let resp = conn.send(req, &mut reader).unwrap();
1245        assert_eq!(resp.body_str().unwrap(), r#"{"id":2}"#);
1246
1247        server.join().unwrap();
1248    }
1249
1250    // --- Display ---
1251
1252    #[test]
1253    fn method_display() {
1254        use super::super::request::Method;
1255        assert_eq!(format!("{}", Method::Get), "GET");
1256        assert_eq!(format!("{}", Method::Post), "POST");
1257        assert_eq!(format!("{}", Method::Delete), "DELETE");
1258    }
1259}