use std::time::Duration;
use super::error::RestError;
use super::request::RequestWriter;
#[cfg(not(feature = "tokio"))]
use super::request::Request;
#[cfg(not(feature = "tokio"))]
use super::response::RestResponse;
#[cfg(not(feature = "tokio"))]
use crate::http::{HttpError, ResponseReader};
#[cfg(not(feature = "tokio"))]
use std::io::{self, Read, Write};
#[cfg(feature = "tls")]
use crate::tls::TlsConfig;
/// Components of an absolute `http://` / `https://` base URL, as produced by
/// [`parse_base_url`]. The string fields borrow from the original URL.
#[non_exhaustive]
pub struct ParsedUrl<'a> {
    /// `true` for `https://`, `false` for `http://`.
    pub tls: bool,
    /// Host name or IP literal. IPv6 literals are stored *without* brackets.
    pub host: &'a str,
    /// Explicit port, or the scheme default (443 / 80) when none was given.
    pub port: u16,
    /// Everything from the first `/` after the authority, or `""` if absent.
    pub path: &'a str,
}

impl ParsedUrl<'_> {
    /// Returns the value to send in the HTTP `Host` header.
    ///
    /// The port is omitted when it equals the scheme default (443 for TLS,
    /// 80 otherwise). An IPv6 literal — recognised by containing `:`, which a
    /// registered host name never does — is re-wrapped in brackets as RFC 3986
    /// requires, e.g. `[::1]:8080`; without the brackets the port would be
    /// indistinguishable from the last address group.
    pub fn host_header(&self) -> String {
        let default = if self.tls { 443 } else { 80 };
        let is_ipv6 = self.host.contains(':') && !self.host.starts_with('[');
        match (is_ipv6, self.port == default) {
            (false, true) => self.host.to_string(),
            (false, false) => format!("{}:{}", self.host, self.port),
            (true, true) => format!("[{}]", self.host),
            (true, false) => format!("[{}]:{}", self.host, self.port),
        }
    }
}
/// Splits an absolute `http://` or `https://` base URL into its parts.
///
/// * The scheme selects `tls` and the default port (443 or 80).
/// * A bracketed IPv6 literal (`http://[::1]:8080`) is returned without its
///   brackets; any other host must not contain `:`.
/// * A missing or empty port (`http://host`, `http://host:`) falls back to
///   the scheme default.
/// * Everything from the first `/` onward is returned verbatim as `path`
///   (`""` when there is none).
///
/// All returned string slices borrow from `url`.
///
/// # Errors
///
/// Returns [`RestError::InvalidUrl`] when the scheme is not `http`/`https`,
/// the host is empty, an IPv6 bracket is unclosed or followed by anything
/// other than `:port`, an unbracketed host contains `:`, or the port is not a
/// decimal number that fits in a `u16`.
pub fn parse_base_url(url: &str) -> Result<ParsedUrl<'_>, RestError> {
    let invalid = |reason: &str| RestError::InvalidUrl(format!("{reason}: {url}"));
    let (tls, rest) = if let Some(r) = url.strip_prefix("https://") {
        (true, r)
    } else if let Some(r) = url.strip_prefix("http://") {
        (false, r)
    } else {
        return Err(RestError::InvalidUrl(url.to_string()));
    };
    let (host_port, path) = rest
        .find('/')
        .map_or((rest, ""), |i| (&rest[..i], &rest[i..]));
    if host_port.is_empty() {
        return Err(invalid("empty host"));
    }
    let default_port: u16 = if tls { 443 } else { 80 };
    // An empty port means "use the default" (RFC 3986 §3.2.3). ASCII digits are
    // checked explicitly because `u16::from_str` also accepts a leading `+`.
    let parse_port = |port_str: &str| -> Result<u16, RestError> {
        if port_str.is_empty() {
            Ok(default_port)
        } else if port_str.bytes().all(|b| b.is_ascii_digit()) {
            port_str.parse::<u16>().map_err(|_| invalid("invalid port"))
        } else {
            Err(invalid("invalid port"))
        }
    };
    let (host, port) = if let Some(bracketed) = host_port.strip_prefix('[') {
        let Some(end) = bracketed.find(']') else {
            return Err(invalid("unclosed bracket"));
        };
        let after = &bracketed[end + 1..];
        let port = if after.is_empty() {
            default_port
        } else if let Some(port_str) = after.strip_prefix(':') {
            parse_port(port_str)?
        } else {
            return Err(invalid("unexpected characters after IPv6 literal"));
        };
        (&bracketed[..end], port)
    } else {
        let (host, port_str) = host_port.rsplit_once(':').unwrap_or((host_port, ""));
        if host.contains(':') {
            // Only a bracketed IP-literal may contain ':' (RFC 3986 §3.2.2);
            // otherwise the host/port split is ambiguous.
            return Err(invalid("IPv6 address must be enclosed in brackets"));
        }
        (host, parse_port(port_str)?)
    };
    // Catches `http://:8080` and `http://[]`, which pass the check above.
    if host.is_empty() {
        return Err(invalid("empty host"));
    }
    Ok(ParsedUrl {
        tls,
        host,
        port,
        path,
    })
}
/// Builder that collects connection options before a client is opened.
///
/// Holds socket options (`TCP_NODELAY`, connect and read timeouts) and, when
/// the `tls` feature is enabled, an optional TLS configuration.
pub struct ClientBuilder {
    /// Whether `TCP_NODELAY` is set on the socket after connecting.
    tcp_nodelay: bool,
    /// Bound on establishing the TCP connection; `None` means a plain
    /// blocking connect.
    connect_timeout: Option<Duration>,
    /// Read timeout applied to the socket; `None` leaves it unset.
    read_timeout: Option<Duration>,
    /// TLS settings for `https://` URLs; `None` means `TlsConfig::new()` is
    /// used at connect time.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
}
impl ClientBuilder {
    /// Creates a builder with no TLS override, Nagle's algorithm left
    /// enabled, and no connect or read timeout.
    #[must_use]
    pub fn new() -> Self {
        Self {
            #[cfg(feature = "tls")]
            tls_config: None,
            tcp_nodelay: false,
            connect_timeout: None,
            read_timeout: None,
        }
    }

    /// Uses a copy of `config` for `https://` connections instead of
    /// `TlsConfig::new()`.
    #[cfg(feature = "tls")]
    #[must_use]
    pub fn tls(mut self, config: &TlsConfig) -> Self {
        self.tls_config = Some(config.clone());
        self
    }

    /// Sets `TCP_NODELAY` on the socket once connected.
    #[must_use]
    pub fn disable_nagle(mut self) -> Self {
        self.tcp_nodelay = true;
        self
    }

    /// Bounds how long each connection attempt may take. The bound applies
    /// separately to every resolved address that is tried.
    #[must_use]
    pub fn connect_timeout(mut self, d: Duration) -> Self {
        self.connect_timeout = Some(d);
        self
    }

    /// Sets the socket read timeout applied right after connecting.
    #[must_use]
    pub fn read_timeout(mut self, d: Duration) -> Self {
        self.read_timeout = Some(d);
        self
    }

    /// Connects to `url` over TCP, performing a TLS handshake for `https://`.
    ///
    /// # Errors
    ///
    /// Returns an error if the URL is invalid, the TCP connection or a socket
    /// option fails, or TLS setup or the handshake fails.
    #[cfg(all(not(feature = "tokio"), feature = "tls"))]
    pub fn connect(
        self,
        url: &str,
    ) -> Result<Client<crate::MaybeTls<std::net::TcpStream>>, RestError> {
        let parsed = parse_base_url(url)?;
        let tcp = self.open_tcp(parsed.host, parsed.port)?;
        let stream = if parsed.tls {
            let config = match self.tls_config {
                Some(c) => c,
                None => TlsConfig::new().map_err(RestError::Tls)?,
            };
            let codec = crate::tls::TlsCodec::new(&config, parsed.host)?;
            let mut tls = crate::tls::TlsStream::new(tcp, codec);
            tls.handshake().map_err(RestError::Tls)?;
            crate::MaybeTls::Tls(Box::new(tls))
        } else {
            crate::MaybeTls::Plain(tcp)
        };
        Ok(Client {
            stream,
            poisoned: false,
        })
    }

    /// Connects to a plain `http://` URL.
    ///
    /// # Errors
    ///
    /// Returns [`RestError::TlsNotEnabled`] for `https://` URLs (the `tls`
    /// feature is off). Also fails if the URL is invalid or the TCP
    /// connection or a socket option fails.
    #[cfg(all(not(feature = "tokio"), not(feature = "tls")))]
    pub fn connect(self, url: &str) -> Result<Client<std::net::TcpStream>, RestError> {
        let parsed = parse_base_url(url)?;
        if parsed.tls {
            return Err(RestError::TlsNotEnabled);
        }
        let tcp = self.open_tcp(parsed.host, parsed.port)?;
        Ok(Client {
            stream: tcp,
            poisoned: false,
        })
    }

    /// Opens a TCP connection to `host:port` and applies this builder's
    /// `TCP_NODELAY` and read-timeout settings.
    ///
    /// `(host, port)` is used instead of a formatted `"host:port"` string so
    /// that IPv6 literals such as `::1` resolve without brackets.
    #[cfg(not(feature = "tokio"))]
    fn open_tcp(&self, host: &str, port: u16) -> Result<std::net::TcpStream, RestError> {
        let tcp = match self.connect_timeout {
            Some(timeout) => {
                let addrs = std::net::ToSocketAddrs::to_socket_addrs(&(host, port))
                    .map_err(RestError::Io)?;
                Self::connect_any(addrs, timeout).map_err(RestError::Io)?
            }
            None => std::net::TcpStream::connect((host, port))?,
        };
        if self.tcp_nodelay {
            tcp.set_nodelay(true)?;
        }
        if let Some(timeout) = self.read_timeout {
            tcp.set_read_timeout(Some(timeout))?;
        }
        Ok(tcp)
    }

    /// Tries each resolved address in order with `timeout`. Returns the
    /// first stream that connects, mirroring the fallback behaviour of
    /// `TcpStream::connect`.
    ///
    /// Returns the last connection error if every attempt fails. Returns
    /// "DNS resolution failed" if there were no addresses at all.
    #[cfg(not(feature = "tokio"))]
    fn connect_any(
        addrs: impl Iterator<Item = std::net::SocketAddr>,
        timeout: Duration,
    ) -> io::Result<std::net::TcpStream> {
        let mut last_err = None;
        for addr in addrs {
            match std::net::TcpStream::connect_timeout(&addr, timeout) {
                Ok(tcp) => return Ok(tcp),
                Err(e) => last_err = Some(e),
            }
        }
        Err(last_err.unwrap_or_else(|| io::Error::other("DNS resolution failed")))
    }

    /// Wraps an already-connected `stream` in a [`Client`].
    ///
    /// `url` is only validated. Its scheme, host and port are not used, and
    /// this builder's socket and TLS options are not applied. Any TLS must
    /// already be layered onto `stream`.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` fails [`parse_base_url`].
    #[cfg(not(feature = "tokio"))]
    pub fn connect_with<S: Read + Write>(
        self,
        stream: S,
        url: &str,
    ) -> Result<Client<S>, RestError> {
        parse_base_url(url)?;
        Ok(Client::new(stream))
    }

    /// Creates a [`RequestWriter`] for `url`. The writer's `Host` header comes
    /// from [`ParsedUrl::host_header`]. Its base path is the URL's path, if
    /// the path is non-empty.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` is invalid or the writer rejects the host or
    /// path.
    pub fn writer_for(url: &str) -> Result<RequestWriter, RestError> {
        let parsed = parse_base_url(url)?;
        let host_header = parsed.host_header();
        let mut writer = RequestWriter::new(&host_header)?;
        if !parsed.path.is_empty() {
            writer.set_base_path(parsed.path)?;
        }
        Ok(writer)
    }
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// An HTTP/1.1 connection over the transport stream `S`.
///
/// After an error leaves the connection in an unknown state, `poisoned` is
/// set and subsequent requests are refused.
pub struct Client<S> {
    /// Whether the connection must no longer be used for requests.
    pub(crate) poisoned: bool,
    /// The underlying transport the request and response bytes travel over.
    pub(crate) stream: S,
}
impl Client<std::net::TcpStream> {
    /// Starts configuring a new connection; equivalent to [`ClientBuilder::new`].
    #[must_use]
    pub fn builder() -> ClientBuilder {
        ClientBuilder::default()
    }

    /// Changes the socket read timeout on an established connection.
    ///
    /// # Errors
    ///
    /// Returns [`RestError::Io`] if the OS rejects the setting (std documents
    /// that a zero `Duration` is rejected).
    pub fn set_read_timeout(&self, timeout: Option<std::time::Duration>) -> Result<(), RestError> {
        let outcome = self.stream.set_read_timeout(timeout);
        outcome.map_err(RestError::Io)
    }

    /// Enables TCP keep-alive, sending the first probe after `idle` of inactivity.
    ///
    /// # Errors
    ///
    /// Returns [`RestError::Io`] if the socket option cannot be set.
    #[cfg(feature = "socket-opts")]
    pub fn set_tcp_keepalive(&self, idle: std::time::Duration) -> Result<(), RestError> {
        let params = socket2::TcpKeepalive::new().with_time(idle);
        socket2::SockRef::from(&self.stream)
            .set_tcp_keepalive(&params)
            .map_err(RestError::Io)
    }
}
impl<S> Client<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
poisoned: false,
}
}
pub fn is_poisoned(&self) -> bool {
self.poisoned
}
pub fn stream(&self) -> &S {
&self.stream
}
pub fn stream_mut(&mut self) -> &mut S {
&mut self.stream
}
}
#[cfg(not(feature = "tokio"))]
impl<S: Read + Write> Client<S> {
#[allow(clippy::needless_pass_by_value)] pub fn send<'r>(
&mut self,
req: Request<'_>,
reader: &'r mut ResponseReader,
) -> Result<RestResponse<'r>, RestError> {
if self.poisoned {
return Err(RestError::ConnectionPoisoned);
}
if let Err(e) = self.write_all(req.as_bytes()) {
self.poisoned = true;
return Err(e);
}
match self.read_response(reader) {
Ok(resp) => Ok(resp),
Err(e) => self.handle_send_error(e),
}
}
#[cold]
fn handle_send_error<T>(&mut self, err: RestError) -> Result<T, RestError> {
self.poisoned = true;
if let RestError::Io(ref io_err) = err {
if io_err.kind() == std::io::ErrorKind::TimedOut
|| io_err.kind() == std::io::ErrorKind::WouldBlock
{
if self.peek_is_dead() {
return Err(RestError::ConnectionStale);
}
return Err(RestError::ReadTimeout);
}
}
Err(err)
}
#[allow(clippy::unused_self)]
fn peek_is_dead(&self) -> bool {
false
}
fn write_all(&mut self, data: &[u8]) -> Result<(), RestError> {
self.stream.write_all(data)?;
self.stream.flush()?;
Ok(())
}
fn read_into_reader(&mut self, reader: &mut ResponseReader) -> Result<usize, RestError> {
let n = reader.read_from(&mut self.stream)?;
Ok(n)
}
fn read_response<'r>(
&mut self,
reader: &'r mut ResponseReader,
) -> Result<RestResponse<'r>, RestError> {
reader.consume_response();
loop {
match reader.next() {
Ok(Some(_)) => break,
Ok(None) => {}
Err(e) => {
self.poisoned = true;
return Err(e.into());
}
}
match self.read_into_reader(reader) {
Ok(0) => {
self.poisoned = true;
return Err(RestError::ConnectionClosed(
"server closed before response headers",
));
}
Ok(_) => {}
Err(e) => {
self.poisoned = true;
return Err(e);
}
}
}
let status = reader.status();
if matches!(status, 100..=199 | 204 | 304) {
reader.set_body_consumed(0);
return Ok(RestResponse::new(status, 0, reader));
}
if reader.is_chunked() {
let body = self.read_chunked_body(reader)?;
reader.set_body_consumed(reader.body_remaining());
return Ok(RestResponse::new_chunked(status, body, reader));
}
let content_length = match reader.content_length() {
Some(Ok(n)) => n,
Some(Err(())) => {
return Err(RestError::Http(HttpError::Malformed(
"invalid Content-Length header",
)));
}
None => {
self.poisoned = true;
return Err(RestError::Http(HttpError::Malformed(
"no Content-Length and not chunked",
)));
}
};
let max_body = reader.max_body_size_limit();
if max_body > 0 && content_length > max_body {
self.poisoned = true;
return Err(RestError::BodyTooLarge {
size: content_length,
max: max_body,
});
}
while reader.body_remaining() < content_length {
match self.read_into_reader(reader) {
Ok(0) => {
self.poisoned = true;
return Err(RestError::ConnectionClosed(
"server closed during body read",
));
}
Ok(_) => {}
Err(e) => {
self.poisoned = true;
return Err(e);
}
}
}
reader.set_body_consumed(content_length);
Ok(RestResponse::new(status, content_length, reader))
}
fn read_chunked_body(&mut self, reader: &ResponseReader) -> Result<Vec<u8>, RestError> {
use crate::http::ChunkedDecoder;
let max_body = reader.max_body_size_limit();
let mut decoder = ChunkedDecoder::new();
let mut body = Vec::with_capacity(4096);
let mut wire_buf = [0u8; 4096];
let mut decode_buf = [0u8; 4096];
let remainder = reader.remainder();
if !remainder.is_empty() {
let mut pos = 0;
while pos < remainder.len() && !decoder.is_done() {
let (consumed, produced) = decoder
.decode(&remainder[pos..], &mut decode_buf)
.map_err(RestError::Http)?;
pos += consumed;
if produced > 0 {
body.extend_from_slice(&decode_buf[..produced]);
if max_body > 0 && body.len() > max_body {
self.poisoned = true;
return Err(RestError::BodyTooLarge {
size: body.len(),
max: max_body,
});
}
}
if consumed == 0 && produced == 0 {
break;
}
}
}
while !decoder.is_done() {
let n = self.read_wire_bytes(&mut wire_buf)?;
if n == 0 {
self.poisoned = true;
return Err(RestError::ConnectionClosed(
"server closed during chunked body",
));
}
let mut pos = 0;
while pos < n && !decoder.is_done() {
let (consumed, produced) = decoder
.decode(&wire_buf[pos..n], &mut decode_buf)
.map_err(RestError::Http)?;
pos += consumed;
if produced > 0 {
body.extend_from_slice(&decode_buf[..produced]);
if max_body > 0 && body.len() > max_body {
self.poisoned = true;
return Err(RestError::BodyTooLarge {
size: body.len(),
max: max_body,
});
}
}
if consumed == 0 && produced == 0 {
break;
}
}
}
Ok(body)
}
fn read_wire_bytes(&mut self, buf: &mut [u8]) -> Result<usize, RestError> {
Ok(self.stream.read(buf)?)
}
}
// Unit tests for the non-tokio client. Most tests drive `Client` over the
// in-memory `MockStream` below. `keep_alive_sequential_requests` uses a real
// loopback TCP socket with a server thread.
#[cfg(test)]
#[cfg(not(feature = "tokio"))]
mod tests {
use super::*;
use std::io::{Cursor, Read, Write};
use std::net::{TcpListener, TcpStream};
// In-memory transport: records every byte written and replays a canned
// response on read (EOF once the response is exhausted).
struct MockStream {
written: Vec<u8>,
response: Cursor<Vec<u8>>,
}
impl MockStream {
fn new(response: &[u8]) -> Self {
Self {
written: Vec::new(),
response: Cursor::new(response.to_vec()),
}
}
fn written_str(&self) -> &str {
std::str::from_utf8(&self.written).unwrap()
}
}
impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.response.read(buf)
}
}
impl Write for MockStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.written.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
// Builds a `200 OK` response whose Content-Length matches `body`.
fn ok_response(body: &str) -> Vec<u8> {
format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
)
.into_bytes()
}
// Convenience wrapper that builds a GET for `path` and sends it on `conn`.
// No test below calls it, hence the `dead_code` allow.
#[allow(dead_code)]
fn send_get<'r>(
writer: &mut RequestWriter,
conn: &mut Client<MockStream>,
reader: &'r mut ResponseReader,
path: &str,
) -> Result<RestResponse<'r>, RestError> {
let req = writer.get(path).finish()?;
conn.send(req, reader)
}
#[test]
fn get_request_format() {
let resp = ok_response(r#"{"ok":true}"#);
let mock = MockStream::new(&resp);
let mut writer = RequestWriter::new("api.example.com").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/api/v1/status").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), r#"{"ok":true}"#);
let written = conn.stream().written_str();
assert!(written.starts_with("GET /api/v1/status HTTP/1.1\r\n"));
assert!(written.contains("Host: api.example.com\r\n"));
assert!(written.contains("Connection: keep-alive\r\n"));
assert!(written.ends_with("\r\n\r\n"));
}
#[test]
fn post_with_body() {
let resp = ok_response(r#"{"filled":true}"#);
let mock = MockStream::new(&resp);
let mut writer = RequestWriter::new("api.example.com").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let body = br#"{"symbol":"BTC","side":"buy"}"#;
let req = writer.post("/api/v3/order").body(body).finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.status(), 200);
let written = conn.stream().written_str();
assert!(written.starts_with("POST /api/v3/order HTTP/1.1\r\n"));
assert!(written.contains(&format!("Content-Length: {}\r\n", body.len())));
assert!(written.ends_with(std::str::from_utf8(body).unwrap()));
}
#[test]
fn post_body_writer() {
let resp = ok_response(r#"{"ok":true}"#);
let mock = MockStream::new(&resp);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let body = br#"{"symbol":"BTC","side":"buy"}"#;
let req = writer
.post("/order")
.body_writer(|w| {
use std::io::Write;
w.write_all(body)
})
.finish()
.unwrap();
// Check the serialized request before it is consumed by `send`.
let written_before = std::str::from_utf8(req.as_bytes()).unwrap().to_string();
assert!(written_before.contains("Content-Length:"));
assert!(written_before.contains(&format!("{}", body.len())));
assert!(written_before.ends_with(std::str::from_utf8(body).unwrap()));
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.status(), 200);
}
#[test]
fn body_writer_from_headers_phase() {
let mut writer = RequestWriter::new("host").unwrap();
let body = b"test-body";
let req = writer
.post("/order")
.header("X-Custom", "val")
.body_writer(|w| {
use std::io::Write;
w.write_all(body)
})
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.contains("X-Custom: val\r\n"));
assert!(data.contains(&format!("{}", body.len())));
assert!(data.ends_with("test-body"));
}
#[test]
fn body_writer_empty() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.post("/order")
.body_writer(|_w| Ok::<(), std::io::Error>(()))
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.contains("Content-Length:"));
assert!(data.contains("0\r\n\r\n"));
}
// `body(..)` and `body_writer(..)` with the same bytes must serialize identically.
#[test]
fn body_writer_matches_body() {
let mut writer1 = RequestWriter::new("host").unwrap();
let mut writer2 = RequestWriter::new("host").unwrap();
let body = b"identical-content";
let req1 = writer1.post("/test").body(body).finish().unwrap();
let req2 = writer2
.post("/test")
.body_writer(|w| {
use std::io::Write;
w.write_all(body)
})
.finish()
.unwrap();
let d1 = std::str::from_utf8(req1.as_bytes()).unwrap();
let d2 = std::str::from_utf8(req2.as_bytes()).unwrap();
assert_eq!(d1, d2);
}
#[test]
fn all_methods() {
for (method, expected) in [
(super::super::request::Method::Put, "PUT"),
(super::super::request::Method::Delete, "DELETE"),
(super::super::request::Method::Patch, "PATCH"),
] {
let resp = ok_response("{}");
let mock = MockStream::new(&resp);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.request(method, "/test").finish().unwrap();
let _ = conn.send(req, &mut reader).unwrap();
assert!(
conn.stream()
.written_str()
.starts_with(&format!("{expected} /test HTTP/1.1\r\n"))
);
}
}
#[test]
fn default_headers_included() {
let resp = ok_response("{}");
let mock = MockStream::new(&resp);
let mut writer = RequestWriter::new("api.example.com").unwrap();
writer.default_header("X-API-KEY", "secret123").unwrap();
writer
.default_header("Content-Type", "application/json")
.unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let _ = conn.send(req, &mut reader).unwrap();
let written = conn.stream().written_str();
assert!(written.contains("X-API-KEY: secret123\r\n"));
assert!(written.contains("Content-Type: application/json\r\n"));
}
#[test]
fn extra_headers() {
let resp = ok_response("{}");
let mock = MockStream::new(&resp);
let mut writer = RequestWriter::new("api.example.com").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer
.get("/test")
.header("X-Custom", "value1")
.header("Authorization", "Bearer tok")
.finish()
.unwrap();
let _ = conn.send(req, &mut reader).unwrap();
let written = conn.stream().written_str();
assert!(written.contains("X-Custom: value1\r\n"));
assert!(written.contains("Authorization: Bearer tok\r\n"));
}
#[test]
fn query_params_encoded() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.get("/orders")
.query("symbol", "BTC-USD")
.query("limit", "100")
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /orders?symbol=BTC-USD&limit=100 HTTP/1.1\r\n"));
}
// Space, `&` and `=` in a value must be percent-encoded so they cannot split the query.
#[test]
fn query_encodes_special_chars() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.get("/search")
.query("q", "hello world&more=yes")
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /search?q=hello%20world%26more%3Dyes HTTP/1.1\r\n"));
}
#[test]
fn query_raw_no_encoding() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.get("/orders")
.query_raw("symbol", "BTC-USD")
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /orders?symbol=BTC-USD HTTP/1.1\r\n"));
}
#[test]
fn query_then_header() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.get("/orders")
.query("sym", "ETH")
.header("X-Nonce", "123")
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /orders?sym=ETH HTTP/1.1\r\n"));
assert!(data.contains("X-Nonce: 123\r\n"));
}
// A path that already has a query string gets further params appended with `&`.
#[test]
fn path_with_existing_query() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.get("/path?existing=true")
.query("extra", "val")
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /path?existing=true&extra=val HTTP/1.1\r\n"));
}
#[test]
fn base_path_prepended() {
let mut writer = RequestWriter::new("host").unwrap();
writer.set_base_path("/api/v3").unwrap();
let req = writer.get("/orders").finish().unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /api/v3/orders HTTP/1.1\r\n"));
}
#[test]
fn get_raw_skips_query_phase() {
let mut writer = RequestWriter::new("host").unwrap();
let req = writer
.get_raw("/orders?symbol=BTC&limit=100")
.finish()
.unwrap();
let data = std::str::from_utf8(req.as_bytes()).unwrap();
assert!(data.starts_with("GET /orders?symbol=BTC&limit=100 HTTP/1.1\r\n"));
}
// The following tests check that CR/LF cannot be smuggled into the request
// head through any input (header, path, default header, raw query, host).
#[test]
fn crlf_in_header_rejected() {
let mut writer = RequestWriter::new("host").unwrap();
let result = writer.get("/test").header("X-Bad\r\n", "val").finish();
assert!(matches!(result, Err(RestError::CrlfInjection)));
}
#[test]
fn crlf_in_path_rejected() {
let mut writer = RequestWriter::new("host").unwrap();
let result = writer.get("/path\r\nEvil: yes").finish();
assert!(matches!(result, Err(RestError::CrlfInjection)));
}
#[test]
fn crlf_in_default_header_rejected() {
let mut writer = RequestWriter::new("host").unwrap();
assert!(matches!(
writer.default_header("X-Bad\n", "val"),
Err(RestError::CrlfInjection)
));
}
#[test]
fn crlf_in_query_raw_rejected() {
let mut writer = RequestWriter::new("host").unwrap();
let result = writer.get("/test").query_raw("k", "v\r\n").finish();
assert!(matches!(result, Err(RestError::CrlfInjection)));
}
#[test]
fn crlf_in_host_rejected() {
assert!(matches!(
RequestWriter::new("evil.com\r\nX-Injected: yes"),
Err(RestError::CrlfInjection)
));
}
#[test]
fn response_headers_accessible() {
let resp_bytes = b"HTTP/1.1 200 OK\r\nX-Request-Id: abc123\r\nX-RateLimit-Remaining: 42\r\nContent-Length: 2\r\n\r\n{}";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.header("X-Request-Id"), Some("abc123"));
assert_eq!(resp.header("X-RateLimit-Remaining"), Some("42"));
}
// Chunk sizes are hexadecimal: 0x7 = "Mozilla", 0x11 (17) = "Developer Network".
#[test]
fn chunked_encoding_decoded() {
let resp_bytes = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nMozilla\r\n11\r\nDeveloper Network\r\n0\r\n\r\n";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.status(), 200);
assert_eq!(resp.body_str().unwrap(), "MozillaDeveloper Network");
}
#[test]
fn chunked_empty_body() {
let resp_bytes = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.body().len(), 0);
}
#[test]
fn chunked_json_response() {
let body = r#"{"orderId":12345,"status":"FILLED"}"#;
let chunked = format!(
"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n{:x}\r\n{}\r\n0\r\n\r\n",
body.len(),
body
);
let mock = MockStream::new(chunked.as_bytes());
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.body_str().unwrap(), body);
}
#[test]
fn malformed_content_length_rejected() {
let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: abc\r\n\r\nbody";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let result = conn.send(req, &mut reader);
assert!(matches!(result, Err(RestError::Http(_))));
}
// Only the head is supplied: the limit must be enforced from Content-Length
// before any body bytes are read.
#[test]
fn body_too_large_rejected() {
let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: 999999\r\n\r\n";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096).max_body_size(32 * 1024);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let result = conn.send(req, &mut reader);
assert!(matches!(
result,
Err(RestError::BodyTooLarge { size: 999_999, .. })
));
}
// The server (incorrectly) sends body bytes with a 204; the client must still
// report an empty body.
#[test]
fn status_204_no_body() {
let resp_bytes = b"HTTP/1.1 204 No Content\r\nContent-Length: 5\r\n\r\nxxxxx";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.status(), 204);
assert_eq!(resp.body().len(), 0);
}
// Content-Length promises 100 bytes but only 7 arrive before EOF. The failed
// request must poison the connection so the next `send` is refused.
#[test]
fn connection_poisoned_after_io_error() {
let resp_bytes = b"HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\npartial";
let mock = MockStream::new(resp_bytes);
let mut writer = RequestWriter::new("host").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(mock);
let req = writer.get("/test").finish().unwrap();
let result = conn.send(req, &mut reader);
assert!(matches!(result, Err(RestError::ConnectionClosed(_))));
let req = writer.get("/test2").finish().unwrap();
let result = conn.send(req, &mut reader);
assert!(matches!(result, Err(RestError::ConnectionPoisoned)));
}
#[test]
fn url_parsing() {
let parsed = parse_base_url("https://api.binance.com").unwrap();
assert!(parsed.tls);
assert_eq!(parsed.host, "api.binance.com");
assert_eq!(parsed.port, 443);
assert_eq!(parsed.path, "");
let parsed = parse_base_url("http://localhost:8080").unwrap();
assert!(!parsed.tls);
assert_eq!(parsed.host, "localhost");
assert_eq!(parsed.port, 8080);
let parsed = parse_base_url("https://api.example.com/v1/foo").unwrap();
assert_eq!(parsed.path, "/v1/foo");
assert!(parse_base_url("ftp://host").is_err());
assert!(parse_base_url("http://").is_err());
}
#[test]
fn ipv6_url_parsing() {
let parsed = parse_base_url("http://[::1]:8080").unwrap();
assert_eq!(parsed.host, "::1");
assert_eq!(parsed.port, 8080);
let parsed = parse_base_url("http://[::1]").unwrap();
assert_eq!(parsed.host, "::1");
assert_eq!(parsed.port, 80);
assert!(parse_base_url("http://[::1").is_err());
}
// Two requests over one real loopback TCP connection. The server thread
// asserts that it receives them in order and answers each one.
#[test]
fn keep_alive_sequential_requests() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let server = std::thread::spawn(move || {
let (mut tcp, _) = listener.accept().unwrap();
let mut buf = [0u8; 4096];
let n = tcp.read(&mut buf).unwrap();
assert!(
std::str::from_utf8(&buf[..n])
.unwrap()
.contains("GET /first")
);
let body1 = r#"{"id":1}"#;
let resp1 = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body1.len(),
body1
);
tcp.write_all(resp1.as_bytes()).unwrap();
let n = tcp.read(&mut buf).unwrap();
assert!(
std::str::from_utf8(&buf[..n])
.unwrap()
.contains("GET /second")
);
let body2 = r#"{"id":2}"#;
let resp2 = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body2.len(),
body2
);
tcp.write_all(resp2.as_bytes()).unwrap();
});
let tcp = TcpStream::connect(addr).unwrap();
let mut writer = RequestWriter::new("localhost").unwrap();
let mut reader = ResponseReader::new(4096);
let mut conn = Client::new(tcp);
let req = writer.get("/first").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.body_str().unwrap(), r#"{"id":1}"#);
// `resp` borrows `reader`; release it before reusing the reader.
drop(resp);
let req = writer.get("/second").finish().unwrap();
let resp = conn.send(req, &mut reader).unwrap();
assert_eq!(resp.body_str().unwrap(), r#"{"id":2}"#);
server.join().unwrap();
}
#[test]
fn method_display() {
use super::super::request::Method;
assert_eq!(format!("{}", Method::Get), "GET");
assert_eq!(format!("{}", Method::Post), "POST");
assert_eq!(format!("{}", Method::Delete), "DELETE");
}
}