1use std::time::Duration;
9
10use super::error::RestError;
11use super::request::RequestWriter;
12
13use super::request::Request;
14use super::response::RestResponse;
15use crate::http::{HttpError, ResponseReader};
16use std::io::{self, Read, Write};
17
18#[cfg(feature = "tls")]
19use nexus_net::tls::TlsConfig;
20
/// The pieces of an `http://` / `https://` base URL, as produced by
/// [`parse_base_url`]. All string fields borrow from the input URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct ParsedUrl<'a> {
    /// `true` for the `https://` scheme, `false` for `http://`.
    pub tls: bool,
    /// Host name or IP literal. IPv6 literals are stored without brackets.
    pub host: &'a str,
    /// Explicit port, or the scheme default (443 / 80) when none was given.
    pub port: u16,
    /// Everything from the first `/` after the authority; empty if absent.
    pub path: &'a str,
}
39
40impl ParsedUrl<'_> {
41 pub fn host_header(&self) -> String {
43 let default = if self.tls { 443 } else { 80 };
44 if self.port == default {
45 self.host.to_string()
46 } else {
47 format!("{}:{}", self.host, self.port)
48 }
49 }
50}
51
52pub fn parse_base_url(url: &str) -> Result<ParsedUrl<'_>, RestError> {
56 let (tls, rest) = if let Some(r) = url.strip_prefix("https://") {
57 (true, r)
58 } else if let Some(r) = url.strip_prefix("http://") {
59 (false, r)
60 } else {
61 return Err(RestError::InvalidUrl(url.to_string()));
62 };
63
64 let (host_port, path) = rest
66 .find('/')
67 .map_or((rest, ""), |i| (&rest[..i], &rest[i..]));
68
69 if host_port.is_empty() {
70 return Err(RestError::InvalidUrl(format!("empty host: {url}")));
71 }
72
73 let default_port = if tls { 443 } else { 80 };
74
75 let (host, port) = if host_port.starts_with('[') {
77 match host_port.find(']') {
78 Some(end) => {
79 let h = &host_port[1..end];
80 let rest = &host_port[end + 1..];
81 if let Some(port_str) = rest.strip_prefix(':') {
82 let p = port_str
83 .parse::<u16>()
84 .map_err(|_| RestError::InvalidUrl(format!("invalid port: {url}")))?;
85 (h, p)
86 } else {
87 (h, default_port)
88 }
89 }
90 None => return Err(RestError::InvalidUrl(format!("unclosed bracket: {url}"))),
91 }
92 } else {
93 match host_port.rfind(':') {
94 None => (host_port, default_port),
95 Some(i) => {
96 let port_str = &host_port[i + 1..];
97 if port_str.is_empty() {
98 (&host_port[..i], default_port)
100 } else {
101 let p = port_str
102 .parse::<u16>()
103 .map_err(|_| RestError::InvalidUrl(format!("invalid port: {url}")))?;
104 (&host_port[..i], p)
105 }
106 }
107 }
108 };
109
110 Ok(ParsedUrl {
111 tls,
112 host,
113 port,
114 path,
115 })
116}
117
/// Collects connection options, then opens a [`Client`] via `connect`.
pub struct ClientBuilder {
    /// Upper bound for each TCP connect attempt; `None` uses the OS default.
    connect_timeout: Option<Duration>,
    /// Read timeout applied to the socket after connecting.
    read_timeout: Option<Duration>,
    /// When `true`, `TCP_NODELAY` is set on the connected socket.
    tcp_nodelay: bool,
    /// TLS settings for `https://` URLs; `None` falls back to `TlsConfig::new()`.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
}
134
135impl ClientBuilder {
136 #[must_use]
138 pub fn new() -> Self {
139 Self {
140 #[cfg(feature = "tls")]
141 tls_config: None,
142 tcp_nodelay: false,
143 connect_timeout: None,
144 read_timeout: None,
145 }
146 }
147
148 #[cfg(feature = "tls")]
152 #[must_use]
153 pub fn tls(mut self, config: &TlsConfig) -> Self {
154 self.tls_config = Some(config.clone());
155 self
156 }
157
158 #[must_use]
160 pub fn disable_nagle(mut self) -> Self {
161 self.tcp_nodelay = true;
162 self
163 }
164
165 #[must_use]
167 pub fn connect_timeout(mut self, d: Duration) -> Self {
168 self.connect_timeout = Some(d);
169 self
170 }
171
172 #[must_use]
174 pub fn read_timeout(mut self, d: Duration) -> Self {
175 self.read_timeout = Some(d);
176 self
177 }
178
179 #[cfg(feature = "tls")]
186 pub fn connect(
187 self,
188 url: &str,
189 ) -> Result<Client<nexus_net::MaybeTls<std::net::TcpStream>>, RestError> {
190 let parsed = parse_base_url(url)?;
191 let addr = format!("{}:{}", parsed.host, parsed.port);
192
193 let tcp = match self.connect_timeout {
194 Some(timeout) => {
195 let addrs: Vec<std::net::SocketAddr> =
196 std::net::ToSocketAddrs::to_socket_addrs(&addr)
197 .map_err(RestError::Io)?
198 .collect();
199 let first = addrs
200 .first()
201 .ok_or_else(|| RestError::Io(io::Error::other("DNS resolution failed")))?;
202 std::net::TcpStream::connect_timeout(first, timeout)?
203 }
204 None => std::net::TcpStream::connect(&addr)?,
205 };
206
207 if self.tcp_nodelay {
208 tcp.set_nodelay(true)?;
209 }
210 if let Some(timeout) = self.read_timeout {
211 tcp.set_read_timeout(Some(timeout))?;
212 }
213
214 let stream = if parsed.tls {
215 let config = match self.tls_config {
216 Some(c) => c,
217 None => TlsConfig::new().map_err(RestError::Tls)?,
218 };
219 let codec = nexus_net::tls::TlsCodec::new(&config, parsed.host)?;
220 let tls = nexus_net::tls::TlsStream::connect(tcp, codec).map_err(RestError::Tls)?;
221 nexus_net::MaybeTls::Tls(Box::new(tls))
222 } else {
223 nexus_net::MaybeTls::Plain(tcp)
224 };
225
226 Ok(Client {
227 stream,
228 poisoned: false,
229 })
230 }
231
232 #[cfg(not(feature = "tls"))]
234 pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, RestError> {
235 let parsed = parse_base_url(url)?;
236 if parsed.tls {
237 return Err(RestError::TlsNotEnabled);
238 }
239 let addr = format!("{}:{}", parsed.host, parsed.port);
240
241 let tcp = match self.connect_timeout {
242 Some(timeout) => {
243 let addrs: Vec<std::net::SocketAddr> =
244 std::net::ToSocketAddrs::to_socket_addrs(&addr)
245 .map_err(RestError::Io)?
246 .collect();
247 let first = addrs
248 .first()
249 .ok_or_else(|| RestError::Io(io::Error::other("DNS resolution failed")))?;
250 std::net::TcpStream::connect_timeout(first, timeout)?
251 }
252 None => std::net::TcpStream::connect(&addr)?,
253 };
254
255 if self.tcp_nodelay {
256 tcp.set_nodelay(true)?;
257 }
258 if let Some(timeout) = self.read_timeout {
259 tcp.set_read_timeout(Some(timeout))?;
260 }
261
262 Ok(Client {
263 stream: tcp,
264 poisoned: false,
265 })
266 }
267
268 pub fn connect_with<S: Read + Write>(
273 self,
274 stream: S,
275 url: &str,
276 ) -> Result<Client<S>, RestError> {
277 parse_base_url(url)?;
280 Ok(Client::new(stream))
281 }
282
283 pub fn writer_for(url: &str) -> Result<RequestWriter, RestError> {
288 let parsed = parse_base_url(url)?;
289 let host_header = parsed.host_header();
290 let mut writer = RequestWriter::new(&host_header)?;
291 if !parsed.path.is_empty() {
292 writer.set_base_path(parsed.path)?;
293 }
294 Ok(writer)
295 }
296}
297
298impl Default for ClientBuilder {
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
/// An HTTP/1.1 connection over a `Read + Write` stream `S`, reused for
/// sequential requests.
///
/// Once an exchange fails, the connection is marked poisoned and further
/// sends are refused.
pub struct Client<S> {
    /// Set after a failed exchange; `send` then returns `ConnectionPoisoned`.
    pub(crate) poisoned: bool,
    /// The underlying transport (plain TCP, TLS, or a test double).
    pub(crate) stream: S,
}
337
338impl Client<std::net::TcpStream> {
339 #[must_use]
341 pub fn builder() -> ClientBuilder {
342 ClientBuilder::new()
343 }
344
345 pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> Result<(), RestError> {
350 self.stream.set_read_timeout(timeout).map_err(RestError::Io)
351 }
352
353 #[cfg(feature = "socket-opts")]
358 pub fn set_tcp_keepalive(&self, idle: std::time::Duration) -> Result<(), RestError> {
359 let sock = socket2::SockRef::from(&self.stream);
360 let keepalive = socket2::TcpKeepalive::new().with_time(idle);
361 sock.set_tcp_keepalive(&keepalive).map_err(RestError::Io)
362 }
363}
364
365impl<S> Client<S> {
368 pub fn new(stream: S) -> Self {
370 Self {
371 stream,
372 poisoned: false,
373 }
374 }
375
376 pub fn is_poisoned(&self) -> bool {
378 self.poisoned
379 }
380
381 pub fn stream(&self) -> &S {
383 &self.stream
384 }
385
386 pub fn stream_mut(&mut self) -> &mut S {
388 &mut self.stream
389 }
390}
391
392impl<S: Read + Write> Client<S> {
395 #[allow(clippy::needless_pass_by_value)] pub fn send<'r>(
408 &mut self,
409 req: Request<'_>,
410 reader: &'r mut ResponseReader,
411 ) -> Result<RestResponse<'r>, RestError> {
412 if self.poisoned {
413 return Err(RestError::ConnectionPoisoned);
414 }
415
416 if let Err(e) = self.write_all(req.as_bytes()) {
418 self.poisoned = true;
419 return Err(e);
420 }
421
422 match self.read_response(reader) {
424 Ok(resp) => Ok(resp),
425 Err(e) => self.handle_send_error(e),
426 }
427 }
428
429 #[cold]
431 fn handle_send_error<T>(&mut self, err: RestError) -> Result<T, RestError> {
432 self.poisoned = true;
433 if let RestError::Io(ref io_err) = err
436 && (io_err.kind() == std::io::ErrorKind::TimedOut
437 || io_err.kind() == std::io::ErrorKind::WouldBlock)
438 {
439 if self.peek_is_dead() {
440 return Err(RestError::ConnectionStale);
441 }
442 return Err(RestError::ReadTimeout);
443 }
444 Err(err)
445 }
446
447 #[allow(clippy::unused_self)]
452 fn peek_is_dead(&self) -> bool {
453 false
456 }
457
458 fn write_all(&mut self, data: &[u8]) -> Result<(), RestError> {
463 self.stream.write_all(data)?;
464 self.stream.flush()?;
465 Ok(())
466 }
467
468 fn read_into_reader(&mut self, reader: &mut ResponseReader) -> Result<usize, RestError> {
469 let n = reader.read_from(&mut self.stream)?;
470 Ok(n)
471 }
472
473 fn read_response<'r>(
474 &mut self,
475 reader: &'r mut ResponseReader,
476 ) -> Result<RestResponse<'r>, RestError> {
477 reader.consume_response();
479
480 loop {
482 match reader.next() {
483 Ok(Some(_)) => break,
484 Ok(None) => {}
485 Err(e) => {
486 self.poisoned = true;
487 return Err(e.into());
488 }
489 }
490 match self.read_into_reader(reader) {
491 Ok(0) => {
492 self.poisoned = true;
493 return Err(RestError::ConnectionClosed(
494 "server closed before response headers",
495 ));
496 }
497 Ok(_) => {}
498 Err(e) => {
499 self.poisoned = true;
500 return Err(e);
501 }
502 }
503 }
504
505 let status = reader.status();
507
508 if matches!(status, 100..=199 | 204 | 304) {
510 reader.set_body_consumed(0);
511 return Ok(RestResponse::new(status, 0, reader));
512 }
513
514 if reader.is_chunked() {
515 let body = self.read_chunked_body(reader)?;
516 reader.set_body_consumed(reader.body_remaining());
524 return Ok(RestResponse::new_chunked(status, body, reader));
525 }
526
527 let content_length = match reader.content_length() {
528 Some(Ok(n)) => n,
529 Some(Err(())) => {
530 return Err(RestError::Http(HttpError::Malformed(
531 "invalid Content-Length header",
532 )));
533 }
534 None => {
535 self.poisoned = true;
538 return Err(RestError::Http(HttpError::Malformed(
539 "no Content-Length and not chunked",
540 )));
541 }
542 };
543
544 let max_body = reader.max_body_size_limit();
545 if max_body > 0 && content_length > max_body {
546 self.poisoned = true;
547 return Err(RestError::BodyTooLarge {
548 size: content_length,
549 max: max_body,
550 });
551 }
552
553 while reader.body_remaining() < content_length {
555 match self.read_into_reader(reader) {
556 Ok(0) => {
557 self.poisoned = true;
558 return Err(RestError::ConnectionClosed(
559 "server closed during body read",
560 ));
561 }
562 Ok(_) => {}
563 Err(e) => {
564 self.poisoned = true;
565 return Err(e);
566 }
567 }
568 }
569
570 reader.set_body_consumed(content_length);
571 Ok(RestResponse::new(status, content_length, reader))
572 }
573
574 fn read_chunked_body(&mut self, reader: &ResponseReader) -> Result<Vec<u8>, RestError> {
579 use crate::http::ChunkedDecoder;
580
581 let max_body = reader.max_body_size_limit();
582 let mut decoder = ChunkedDecoder::new();
583 let mut body = Vec::with_capacity(4096);
584 let mut wire_buf = [0u8; 4096];
585 let mut decode_buf = [0u8; 4096];
586
587 let remainder = reader.remainder();
589 if !remainder.is_empty() {
590 let mut pos = 0;
591 while pos < remainder.len() && !decoder.is_done() {
592 let (consumed, produced) = decoder
593 .decode(&remainder[pos..], &mut decode_buf)
594 .map_err(RestError::Http)?;
595 pos += consumed;
596 if produced > 0 {
597 body.extend_from_slice(&decode_buf[..produced]);
598 if max_body > 0 && body.len() > max_body {
599 self.poisoned = true;
600 return Err(RestError::BodyTooLarge {
601 size: body.len(),
602 max: max_body,
603 });
604 }
605 }
606 if consumed == 0 && produced == 0 {
607 break;
608 }
609 }
610 }
611
612 while !decoder.is_done() {
614 let n = self.read_wire_bytes(&mut wire_buf)?;
615 if n == 0 {
616 self.poisoned = true;
617 return Err(RestError::ConnectionClosed(
618 "server closed during chunked body",
619 ));
620 }
621
622 let mut pos = 0;
623 while pos < n && !decoder.is_done() {
624 let (consumed, produced) = decoder
625 .decode(&wire_buf[pos..n], &mut decode_buf)
626 .map_err(RestError::Http)?;
627 pos += consumed;
628 if produced > 0 {
629 body.extend_from_slice(&decode_buf[..produced]);
630 if max_body > 0 && body.len() > max_body {
632 self.poisoned = true;
633 return Err(RestError::BodyTooLarge {
634 size: body.len(),
635 max: max_body,
636 });
637 }
638 }
639 if consumed == 0 && produced == 0 {
640 break;
641 }
642 }
643 }
644
645 Ok(body)
646 }
647
648 fn read_wire_bytes(&mut self, buf: &mut [u8]) -> Result<usize, RestError> {
650 Ok(self.stream.read(buf)?)
651 }
652}
653
// Unit tests for request serialization, response parsing, URL parsing and
// connection reuse. Most tests drive `Client` over the in-memory `MockStream`;
// `keep_alive_sequential_requests` uses a real loopback TCP socket.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, Read, Write};
    use std::net::{TcpListener, TcpStream};

    /// In-memory stream: records every written byte and replays a canned
    /// response on read (EOF once the response is exhausted).
    struct MockStream {
        written: Vec<u8>,
        response: Cursor<Vec<u8>>,
    }

    impl MockStream {
        fn new(response: &[u8]) -> Self {
            Self {
                written: Vec::new(),
                response: Cursor::new(response.to_vec()),
            }
        }

        /// Everything the client wrote, as UTF-8 (panics if not UTF-8).
        fn written_str(&self) -> &str {
            std::str::from_utf8(&self.written).unwrap()
        }
    }

    impl Read for MockStream {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            self.response.read(buf)
        }
    }

    impl Write for MockStream {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.written.extend_from_slice(buf);
            Ok(buf.len())
        }
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// Builds a `200 OK` response with a matching `Content-Length`.
    fn ok_response(body: &str) -> Vec<u8> {
        format!(
            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
            body.len(),
            body
        )
        .into_bytes()
    }

    // Convenience GET helper; not called by the tests below, hence the allow.
    #[allow(dead_code)]
    fn send_get<'r>(
        writer: &mut RequestWriter,
        conn: &mut Client<MockStream>,
        reader: &'r mut ResponseReader,
        path: &str,
    ) -> Result<RestResponse<'r>, RestError> {
        let req = writer.get(path).finish()?;
        conn.send(req, reader)
    }

    // Request line and core headers of a GET, plus body of the parsed response.
    #[test]
    fn get_request_format() {
        let resp = ok_response(r#"{"ok":true}"#);
        let mock = MockStream::new(&resp);
        let mut writer = RequestWriter::new("api.example.com").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/api/v1/status").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body_str().unwrap(), r#"{"ok":true}"#);

        let written = conn.stream().written_str();
        assert!(written.starts_with("GET /api/v1/status HTTP/1.1\r\n"));
        assert!(written.contains("Host: api.example.com\r\n"));
        assert!(written.contains("Connection: keep-alive\r\n"));
        assert!(written.ends_with("\r\n\r\n"));
    }

    #[test]
    fn post_with_body() {
        let resp = ok_response(r#"{"filled":true}"#);
        let mock = MockStream::new(&resp);
        let mut writer = RequestWriter::new("api.example.com").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let body = br#"{"symbol":"BTC","side":"buy"}"#;
        let req = writer.post("/api/v3/order").body(body).finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.status(), 200);

        let written = conn.stream().written_str();
        assert!(written.starts_with("POST /api/v3/order HTTP/1.1\r\n"));
        assert!(written.contains(&format!("Content-Length: {}\r\n", body.len())));
        assert!(written.ends_with(std::str::from_utf8(body).unwrap()));
    }

    #[test]
    fn post_body_writer() {
        let resp = ok_response(r#"{"ok":true}"#);
        let mock = MockStream::new(&resp);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let body = br#"{"symbol":"BTC","side":"buy"}"#;
        let req = writer
            .post("/order")
            .body_writer(|w| {
                use std::io::Write;
                w.write_all(body)
            })
            .finish()
            .unwrap();

        let written_before = std::str::from_utf8(req.as_bytes()).unwrap().to_string();
        assert!(written_before.contains("Content-Length:"));
        // The header name and the length are checked separately, so the
        // exact spacing between them is not pinned.
        assert!(written_before.contains(&format!("{}", body.len())));
        assert!(written_before.ends_with(std::str::from_utf8(body).unwrap()));

        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.status(), 200);
    }

    #[test]
    fn body_writer_from_headers_phase() {
        let mut writer = RequestWriter::new("host").unwrap();
        let body = b"test-body";
        let req = writer
            .post("/order")
            .header("X-Custom", "val")
            .body_writer(|w| {
                use std::io::Write;
                w.write_all(body)
            })
            .finish()
            .unwrap();

        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.contains("X-Custom: val\r\n"));
        assert!(data.contains(&format!("{}", body.len())));
        assert!(data.ends_with("test-body"));
    }

    #[test]
    fn body_writer_empty() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .post("/order")
            .body_writer(|_w| Ok::<(), std::io::Error>(()))
            .finish()
            .unwrap();

        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.contains("Content-Length:"));
        // A zero length immediately followed by the end of the headers.
        assert!(data.contains("0\r\n\r\n"));
    }

    // `body` and `body_writer` must serialize byte-for-byte identically.
    #[test]
    fn body_writer_matches_body() {
        let mut writer1 = RequestWriter::new("host").unwrap();
        let mut writer2 = RequestWriter::new("host").unwrap();

        let body = b"identical-content";

        let req1 = writer1.post("/test").body(body).finish().unwrap();
        let req2 = writer2
            .post("/test")
            .body_writer(|w| {
                use std::io::Write;
                w.write_all(body)
            })
            .finish()
            .unwrap();

        let d1 = std::str::from_utf8(req1.as_bytes()).unwrap();
        let d2 = std::str::from_utf8(req2.as_bytes()).unwrap();
        assert_eq!(d1, d2);
    }

    #[test]
    fn all_methods() {
        for (method, expected) in [
            (super::super::request::Method::Put, "PUT"),
            (super::super::request::Method::Delete, "DELETE"),
            (super::super::request::Method::Patch, "PATCH"),
        ] {
            let resp = ok_response("{}");
            let mock = MockStream::new(&resp);
            let mut writer = RequestWriter::new("host").unwrap();
            let mut reader = ResponseReader::new(4096);
            let mut conn = Client::new(mock);

            let req = writer.request(method, "/test").finish().unwrap();
            let _ = conn.send(req, &mut reader).unwrap();
            assert!(
                conn.stream()
                    .written_str()
                    .starts_with(&format!("{expected} /test HTTP/1.1\r\n"))
            );
        }
    }

    #[test]
    fn default_headers_included() {
        let resp = ok_response("{}");
        let mock = MockStream::new(&resp);
        let mut writer = RequestWriter::new("api.example.com").unwrap();
        writer.default_header("X-API-KEY", "secret123").unwrap();
        writer
            .default_header("Content-Type", "application/json")
            .unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let _ = conn.send(req, &mut reader).unwrap();

        let written = conn.stream().written_str();
        assert!(written.contains("X-API-KEY: secret123\r\n"));
        assert!(written.contains("Content-Type: application/json\r\n"));
    }

    #[test]
    fn extra_headers() {
        let resp = ok_response("{}");
        let mock = MockStream::new(&resp);
        let mut writer = RequestWriter::new("api.example.com").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer
            .get("/test")
            .header("X-Custom", "value1")
            .header("Authorization", "Bearer tok")
            .finish()
            .unwrap();
        let _ = conn.send(req, &mut reader).unwrap();

        let written = conn.stream().written_str();
        assert!(written.contains("X-Custom: value1\r\n"));
        assert!(written.contains("Authorization: Bearer tok\r\n"));
    }

    // Query parameters are appended in call order, joined by '&'.
    #[test]
    fn query_params_encoded() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .get("/orders")
            .query("symbol", "BTC-USD")
            .query("limit", "100")
            .finish()
            .unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /orders?symbol=BTC-USD&limit=100 HTTP/1.1\r\n"));
    }

    // Space, '&' and '=' in a value are percent-encoded by `query`.
    #[test]
    fn query_encodes_special_chars() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .get("/search")
            .query("q", "hello world&more=yes")
            .finish()
            .unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /search?q=hello%20world%26more%3Dyes HTTP/1.1\r\n"));
    }

    #[test]
    fn query_raw_no_encoding() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .get("/orders")
            .query_raw("symbol", "BTC-USD")
            .finish()
            .unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /orders?symbol=BTC-USD HTTP/1.1\r\n"));
    }

    #[test]
    fn query_then_header() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .get("/orders")
            .query("sym", "ETH")
            .header("X-Nonce", "123")
            .finish()
            .unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /orders?sym=ETH HTTP/1.1\r\n"));
        assert!(data.contains("X-Nonce: 123\r\n"));
    }

    // An existing '?' in the path makes later params join with '&'.
    #[test]
    fn path_with_existing_query() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .get("/path?existing=true")
            .query("extra", "val")
            .finish()
            .unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /path?existing=true&extra=val HTTP/1.1\r\n"));
    }

    #[test]
    fn base_path_prepended() {
        let mut writer = RequestWriter::new("host").unwrap();
        writer.set_base_path("/api/v3").unwrap();
        let req = writer.get("/orders").finish().unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /api/v3/orders HTTP/1.1\r\n"));
    }

    #[test]
    fn get_raw_skips_query_phase() {
        let mut writer = RequestWriter::new("host").unwrap();
        let req = writer
            .get_raw("/orders?symbol=BTC&limit=100")
            .finish()
            .unwrap();
        let data = std::str::from_utf8(req.as_bytes()).unwrap();
        assert!(data.starts_with("GET /orders?symbol=BTC&limit=100 HTTP/1.1\r\n"));
    }

    // CR/LF anywhere user-controlled must be rejected with `CrlfInjection`.
    #[test]
    fn crlf_in_header_rejected() {
        let mut writer = RequestWriter::new("host").unwrap();
        let result = writer.get("/test").header("X-Bad\r\n", "val").finish();
        assert!(matches!(result, Err(RestError::CrlfInjection)));
    }

    #[test]
    fn crlf_in_path_rejected() {
        let mut writer = RequestWriter::new("host").unwrap();
        let result = writer.get("/path\r\nEvil: yes").finish();
        assert!(matches!(result, Err(RestError::CrlfInjection)));
    }

    #[test]
    fn crlf_in_default_header_rejected() {
        let mut writer = RequestWriter::new("host").unwrap();
        assert!(matches!(
            writer.default_header("X-Bad\n", "val"),
            Err(RestError::CrlfInjection)
        ));
    }

    #[test]
    fn crlf_in_query_raw_rejected() {
        let mut writer = RequestWriter::new("host").unwrap();
        let result = writer.get("/test").query_raw("k", "v\r\n").finish();
        assert!(matches!(result, Err(RestError::CrlfInjection)));
    }

    #[test]
    fn crlf_in_host_rejected() {
        assert!(matches!(
            RequestWriter::new("evil.com\r\nX-Injected: yes"),
            Err(RestError::CrlfInjection)
        ));
    }

    // Response headers can be looked up by name on the parsed response.
    #[test]
    fn response_headers_accessible() {
        let resp_bytes = b"HTTP/1.1 200 OK\r\nX-Request-Id: abc123\r\nX-RateLimit-Remaining: 42\r\nContent-Length: 2\r\n\r\n{}";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.header("X-Request-Id"), Some("abc123"));
        assert_eq!(resp.header("X-RateLimit-Remaining"), Some("42"));
    }

    // Chunk sizes are hex: 7 and 0x11 (= 17) bytes.
    #[test]
    fn chunked_encoding_decoded() {
        let resp_bytes = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nMozilla\r\n11\r\nDeveloper Network\r\n0\r\n\r\n";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(resp.body_str().unwrap(), "MozillaDeveloper Network");
    }

    #[test]
    fn chunked_empty_body() {
        let resp_bytes = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.body().len(), 0);
    }

    #[test]
    fn chunked_json_response() {
        let body = r#"{"orderId":12345,"status":"FILLED"}"#;
        // Single chunk whose size is written in hex, followed by the terminator.
        let chunked = format!(
            "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n",
            body.len(),
            body
        );
        let mock = MockStream::new(chunked.as_bytes());
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.body_str().unwrap(), body);
    }

    #[test]
    fn malformed_content_length_rejected() {
        let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: abc\r\n\r\nbody";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let result = conn.send(req, &mut reader);
        assert!(matches!(result, Err(RestError::Http(_))));
    }

    // A declared length over the limit is rejected before any body is read.
    #[test]
    fn body_too_large_rejected() {
        let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: 999999\r\n\r\n";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096).max_body_size(32 * 1024);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let result = conn.send(req, &mut reader);
        assert!(matches!(
            result,
            Err(RestError::BodyTooLarge { size: 999_999, .. })
        ));
    }

    // 204 has no body even if the server sends a Content-Length.
    #[test]
    fn status_204_no_body() {
        let resp_bytes = b"HTTP/1.1 204 No Content\r\nContent-Length: 5\r\n\r\nxxxxx";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.status(), 204);
        assert_eq!(resp.body().len(), 0);
    }

    // A truncated body (EOF before Content-Length bytes) poisons the client.
    #[test]
    fn connection_poisoned_after_io_error() {
        let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\npartial";
        let mock = MockStream::new(resp_bytes);
        let mut writer = RequestWriter::new("host").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(mock);

        let req = writer.get("/test").finish().unwrap();
        let result = conn.send(req, &mut reader);
        assert!(matches!(result, Err(RestError::ConnectionClosed(_))));

        let req = writer.get("/test2").finish().unwrap();
        let result = conn.send(req, &mut reader);
        assert!(matches!(result, Err(RestError::ConnectionPoisoned)));
    }

    // Scheme, default/explicit port, path, and rejection of bad URLs.
    #[test]
    fn url_parsing() {
        let parsed = parse_base_url("https://api.binance.com").unwrap();
        assert!(parsed.tls);
        assert_eq!(parsed.host, "api.binance.com");
        assert_eq!(parsed.port, 443);
        assert_eq!(parsed.path, "");

        let parsed = parse_base_url("http://localhost:8080").unwrap();
        assert!(!parsed.tls);
        assert_eq!(parsed.host, "localhost");
        assert_eq!(parsed.port, 8080);

        let parsed = parse_base_url("https://api.example.com/v1/foo").unwrap();
        assert_eq!(parsed.path, "/v1/foo");

        assert!(parse_base_url("ftp://host").is_err());
        assert!(parse_base_url("http://").is_err());
    }

    #[test]
    fn ipv6_url_parsing() {
        let parsed = parse_base_url("http://[::1]:8080").unwrap();
        assert_eq!(parsed.host, "::1");
        assert_eq!(parsed.port, 8080);

        let parsed = parse_base_url("http://[::1]").unwrap();
        assert_eq!(parsed.host, "::1");
        assert_eq!(parsed.port, 80);

        assert!(parse_base_url("http://[::1").is_err());
    }

    // Two requests over one real loopback TCP connection; the server thread
    // asserts that it sees them in order.
    #[test]
    fn keep_alive_sequential_requests() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let server = std::thread::spawn(move || {
            let (mut tcp, _) = listener.accept().unwrap();
            let mut buf = [0u8; 4096];

            let n = tcp.read(&mut buf).unwrap();
            assert!(
                std::str::from_utf8(&buf[..n])
                    .unwrap()
                    .contains("GET /first")
            );
            let body1 = r#"{"id":1}"#;
            let resp1 = format!(
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
                body1.len(),
                body1
            );
            tcp.write_all(resp1.as_bytes()).unwrap();

            let n = tcp.read(&mut buf).unwrap();
            assert!(
                std::str::from_utf8(&buf[..n])
                    .unwrap()
                    .contains("GET /second")
            );
            let body2 = r#"{"id":2}"#;
            let resp2 = format!(
                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
                body2.len(),
                body2
            );
            tcp.write_all(resp2.as_bytes()).unwrap();
        });

        let tcp = TcpStream::connect(addr).unwrap();
        let mut writer = RequestWriter::new("localhost").unwrap();
        let mut reader = ResponseReader::new(4096);
        let mut conn = Client::new(tcp);

        let req = writer.get("/first").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.body_str().unwrap(), r#"{"id":1}"#);
        // Release the borrow of `reader` before the next send.
        drop(resp);

        let req = writer.get("/second").finish().unwrap();
        let resp = conn.send(req, &mut reader).unwrap();
        assert_eq!(resp.body_str().unwrap(), r#"{"id":2}"#);

        server.join().unwrap();
    }

    // `Display` for `Method` yields the uppercase HTTP token.
    #[test]
    fn method_display() {
        use super::super::request::Method;
        assert_eq!(format!("{}", Method::Get), "GET");
        assert_eq!(format!("{}", Method::Post), "POST");
        assert_eq!(format!("{}", Method::Delete), "DELETE");
    }
}