use anyhow::{Context, Result};
#[cfg(feature = "openssl-tls")]
use base64::{Engine as _, engine::general_purpose};
#[cfg(feature = "openssl-tls")]
use ed25519_dalek::Signer as Ed25519Signer;
#[cfg(feature = "openssl-tls")]
use ed25519_dalek::SigningKey;
#[cfg(feature = "openssl-tls")]
use ed25519_dalek::pkcs8::DecodePrivateKey;
use flate2::read::GzDecoder;
use hex;
use hmac::{Hmac, Mac};
use http::HeaderMap;
use http::HeaderValue;
use http::header::ACCEPT_ENCODING;
use once_cell::sync::OnceCell;
#[cfg(feature = "openssl-tls")]
use openssl::{hash::MessageDigest, pkey::PKey, sign::Signer as OpenSslSigner};
use rand::{RngCore, rngs::OsRng};
use regex::Captures;
use regex::Regex;
use reqwest::Client;
use reqwest::Proxy;
use reqwest::{Method, Request};
use serde::de::DeserializeOwned;
use serde_json::Number;
use serde_json::{Value, json};
use sha2::Sha256;
use std::fmt;
use std::fmt::Display;
use std::hash::BuildHasher;
use std::sync::LazyLock;
use std::{
collections::BTreeMap,
collections::HashMap,
io::Read,
time::Duration,
time::{SystemTime, UNIX_EPOCH},
};
#[cfg(feature = "openssl-tls")]
use std::{fs, path::Path};
use tokio::time::sleep;
use url::form_urlencoded;
use url::{Url, form_urlencoded::Serializer};
use super::config::{
ConfigurationRestApi, ConfigurationWebsocketApi, HttpAgent, PrivateKey, ProxyConfig,
};
use super::errors::ConnectorError;
use super::models::{
Interval, RateLimitType, RestApiRateLimit, RestApiResponse, StreamId, TimeUnit,
};
use super::websocket::WebsocketMessageSendOptions;
/// Accepts a 32-character lowercase hexadecimal id — the shape produced by
/// `random_string` (16 random bytes, hex-encoded). Used to decide whether a
/// caller-supplied id can be reused as-is.
pub(crate) static ID_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
/// Matches a `<name>` placeholder, optionally preceded by `@`.
/// Capture group 1 is the optional `@`; capture group 2 is the placeholder name.
static PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(@)?<([^>]+)>").unwrap());
/// Produces request signatures from either an HMAC secret or a private key.
///
/// Key material (file contents and parsed key objects) is loaded lazily on the
/// first signing request and cached in `OnceCell`s after the first successful load.
#[derive(Default, Clone)]
// Some fields (`raw_key_data`, `private_key_passphrase`) are only read by the
// `openssl-tls`-gated methods, so they are unused when that feature is disabled.
#[allow(dead_code)]
pub struct SignatureGenerator {
    /// Secret used for HMAC-SHA256 signing when no private key is configured.
    api_secret: Option<String>,
    /// PEM private key (RSA or Ed25519), given as a file path or raw bytes.
    private_key: Option<PrivateKey>,
    /// Passphrase for an encrypted PEM private key.
    private_key_passphrase: Option<String>,
    /// Cached PEM text of `private_key` (file contents or decoded raw bytes).
    raw_key_data: OnceCell<String>,
    /// Cached parsed OpenSSL key object.
    #[cfg(feature = "openssl-tls")]
    key_object: OnceCell<PKey<openssl::pkey::Private>>,
    /// Cached Ed25519 signing key derived from `key_object`.
    #[cfg(feature = "openssl-tls")]
    ed25519_signing_key: OnceCell<SigningKey>,
}
/// Debug output that never reveals key material: optional secrets are shown
/// only as present (`Some("[REDACTED]")`) or absent (`None`).
impl fmt::Debug for SignatureGenerator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        // Maps "is this secret set?" to the masked value shown in the output.
        let mask = |present: bool| present.then_some(REDACTED);

        f.debug_struct("SignatureGenerator")
            .field("api_secret", &mask(self.api_secret.is_some()))
            .field("private_key", &mask(self.private_key.is_some()))
            .field(
                "private_key_passphrase",
                &mask(self.private_key_passphrase.is_some()),
            )
            .field("raw_key_data", &mask(self.raw_key_data.get().is_some()))
            .field("key_object", &REDACTED)
            .field("ed25519_signing_key", &REDACTED)
            .finish()
    }
}
impl SignatureGenerator {
    /// Creates a signature generator from the supplied credentials.
    ///
    /// Nothing is read or parsed here; key material is loaded on first use.
    #[must_use]
    pub fn new(
        api_secret: Option<String>,
        private_key: Option<PrivateKey>,
        private_key_passphrase: Option<String>,
    ) -> Self {
        // All remaining fields are empty `OnceCell`s, i.e. their `Default`.
        Self {
            api_secret,
            private_key,
            private_key_passphrase,
            ..Self::default()
        }
    }

    /// Returns the PEM text of the configured private key, loading it once.
    ///
    /// # Errors
    /// Fails if no private key is configured, the key file is missing, or the
    /// file cannot be read.
    #[cfg(feature = "openssl-tls")]
    fn get_raw_key_data(&self) -> Result<&String> {
        self.raw_key_data.get_or_try_init(|| {
            let Some(key) = self.private_key.as_ref() else {
                return Err(anyhow::anyhow!("No private_key provided"));
            };
            match key {
                PrivateKey::Raw(bytes) => Ok(String::from_utf8_lossy(bytes).into_owned()),
                PrivateKey::File(path) if !Path::new(path).exists() => {
                    Err(anyhow::anyhow!("Private key file does not exist: {}", path))
                }
                PrivateKey::File(path) => fs::read_to_string(path)
                    .with_context(|| format!("Failed to read private key file: {path}")),
            }
        })
    }

    /// Parses (once) the PEM private key, using the passphrase when one is set.
    ///
    /// # Errors
    /// Propagates loading errors and PEM parse failures.
    #[cfg(feature = "openssl-tls")]
    fn get_key_object(&self) -> Result<&PKey<openssl::pkey::Private>> {
        self.key_object.get_or_try_init(|| {
            let pem = self.get_raw_key_data()?.as_bytes();
            match self.private_key_passphrase.as_deref() {
                Some(passphrase) => {
                    PKey::private_key_from_pem_passphrase(pem, passphrase.as_bytes())
                        .context("Failed to parse private key with passphrase")
                }
                None => PKey::private_key_from_pem(pem).context("Failed to parse private key"),
            }
        })
    }

    /// Converts (once) an OpenSSL Ed25519 key into an `ed25519_dalek` signing key
    /// via its PKCS#8 DER encoding.
    ///
    /// # Errors
    /// Fails if the key cannot be exported to DER or the DER is not a valid
    /// Ed25519 PKCS#8 key.
    #[cfg(feature = "openssl-tls")]
    fn get_ed25519_signing_key(
        &self,
        key_obj: &PKey<openssl::pkey::Private>,
    ) -> Result<&SigningKey> {
        self.ed25519_signing_key.get_or_try_init(|| {
            let pkcs8_der = key_obj
                .private_key_to_der()
                .context("Failed to export Ed25519 key to DER")?;
            SigningKey::from_pkcs8_der(&pkcs8_der)
                .map_err(|err| anyhow::anyhow!("Failed to parse Ed25519 key: {}", err))
        })
    }

    /// Signs the query string (followed directly by the body string, if any).
    ///
    /// A configured private key takes precedence: RSA keys give a base64
    /// RSA-SHA256 signature and Ed25519 keys give a base64 Ed25519 signature.
    /// Without a private key, `api_secret` gives a hex HMAC-SHA256.
    ///
    /// # Errors
    /// Fails if neither credential is configured, if key loading or signing
    /// fails, or if private-key signing is requested without `openssl-tls`.
    pub fn get_signature(
        &self,
        query_params: &BTreeMap<String, Value>,
        body_params: Option<&BTreeMap<String, Value>>,
    ) -> Result<String> {
        let mut payload = build_query_string(query_params)?;
        if let Some(body) = body_params.filter(|b| !b.is_empty()) {
            payload.push_str(&build_query_string(body)?);
        }

        match (&self.private_key, &self.api_secret) {
            (Some(_), _) => self.sign_with_private_key(&payload),
            (None, Some(secret)) => {
                let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
                    .context("HMAC key initialization failed")?;
                mac.update(payload.as_bytes());
                Ok(hex::encode(mac.finalize().into_bytes()))
            }
            (None, None) => Err(anyhow::anyhow!(
                "Either 'api_secret' or 'private_key' must be provided for signed requests."
            )),
        }
    }

    /// Signs `payload` with the configured RSA or Ed25519 private key.
    #[cfg(feature = "openssl-tls")]
    fn sign_with_private_key(&self, payload: &str) -> Result<String> {
        let key_obj = self.get_key_object()?;
        match key_obj.id() {
            openssl::pkey::Id::RSA => {
                let mut signer = OpenSslSigner::new(MessageDigest::sha256(), key_obj)
                    .context("Failed to create RSA signer")?;
                signer
                    .update(payload.as_bytes())
                    .context("Failed to update RSA signer")?;
                let sig = signer.sign_to_vec().context("RSA signing failed")?;
                Ok(general_purpose::STANDARD.encode(sig))
            }
            openssl::pkey::Id::ED25519 => {
                let signing_key = self.get_ed25519_signing_key(key_obj)?;
                let signature = signing_key.sign(payload.as_bytes());
                Ok(general_purpose::STANDARD.encode(signature.to_bytes()))
            }
            other => Err(anyhow::anyhow!(
                "Unsupported private key type: {:?}. Must be RSA or ED25519.",
                other
            )),
        }
    }

    /// Private-key signing is unavailable without the `openssl-tls` feature.
    // `&self` and `payload` are kept so the signature matches the feature-enabled variant.
    #[cfg(not(feature = "openssl-tls"))]
    #[allow(clippy::unused_self)]
    fn sign_with_private_key(&self, _payload: &str) -> Result<String> {
        Err(anyhow::anyhow!(
            "Private key signing requires the 'openssl-tls' feature to be enabled."
        ))
    }
}
/// Builds the `reqwest` client used for REST calls.
///
/// * `timeout` – total request timeout in milliseconds.
/// * `keep_alive` – when `false`, idle pooled connections are dropped immediately.
/// * `proxy` – optional proxy (`protocol` defaults to `http`), with optional basic auth.
/// * `agent` – optional hook that may further customise the `ClientBuilder`.
///
/// # Panics
/// Panics if the proxy URL is invalid or the client cannot be built.
#[must_use]
pub fn build_client(
    timeout: u64,
    keep_alive: bool,
    proxy: Option<&ProxyConfig>,
    agent: Option<HttpAgent>,
) -> Client {
    let mut builder = Client::builder().timeout(Duration::from_millis(timeout));
    if !keep_alive {
        // A zero idle timeout evicts pooled connections as soon as they go idle.
        builder = builder.pool_idle_timeout(Some(Duration::from_secs(0)));
    }

    if let Some(proxy_cfg) = proxy {
        let scheme = proxy_cfg.protocol.as_deref().unwrap_or("http");
        let proxy_address = format!("{scheme}://{}:{}", proxy_cfg.host, proxy_cfg.port);
        let mut upstream = Proxy::all(&proxy_address).expect("Failed to create proxy from URL");
        if let Some(creds) = proxy_cfg.auth.as_ref() {
            upstream = upstream.basic_auth(&creds.username, &creds.password);
        }
        builder = builder.proxy(upstream);
    }

    if let Some(HttpAgent(customize)) = agent {
        builder = customize(builder);
    }

    builder.build().expect("Failed to build reqwest client")
}
/// Builds the `User-Agent` string:
/// `<crate>/<product>/<crate version> (Rust/<rustc version>; <os>; <arch>)`.
#[must_use]
pub fn build_user_agent(product: &str) -> String {
    let crate_name = env!("CARGO_PKG_NAME");
    let crate_version = env!("CARGO_PKG_VERSION");
    let rustc_version = env!("RUSTC_VERSION");
    let (os, arch) = (std::env::consts::OS, std::env::consts::ARCH);
    format!("{crate_name}/{product}/{crate_version} (Rust/{rustc_version}; {os}; {arch})")
}
/// Validates a time-unit string.
///
/// Returns `Ok(None)` for the empty string and `Ok(Some(time_unit))` for
/// `MILLISECOND`/`MICROSECOND` (all-uppercase or all-lowercase).
///
/// # Errors
/// Any other value is rejected.
pub fn validate_time_unit(time_unit: &str) -> Result<Option<&str>, anyhow::Error> {
    const ACCEPTED: [&str; 4] = ["MILLISECOND", "MICROSECOND", "millisecond", "microsecond"];

    if time_unit.is_empty() {
        Ok(None)
    } else if ACCEPTED.contains(&time_unit) {
        Ok(Some(time_unit))
    } else {
        Err(anyhow::anyhow!(
            "time_unit must be either 'MILLISECOND' or 'MICROSECOND'"
        ))
    }
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock is set before 1970-01-01.
#[must_use]
pub fn get_timestamp() -> u128 {
    // `UNIX_EPOCH.elapsed()` is `SystemTime::now().duration_since(UNIX_EPOCH)`.
    UNIX_EPOCH
        .elapsed()
        .expect("Time went backwards")
        .as_millis()
}
/// Asynchronously sleeps for `ms` milliseconds on the Tokio timer.
pub async fn delay(ms: u64) {
    let pause = Duration::from_millis(ms);
    sleep(pause).await;
}
/// Serializes `params` as an `application/x-www-form-urlencoded` string in key
/// order (the `BTreeMap` order).
///
/// `null` values are skipped. Strings are used verbatim, booleans and numbers via
/// their textual form, and arrays/objects as compact JSON. Keys and values are
/// percent-encoded.
///
/// # Errors
/// Fails if an array/object value cannot be serialized to JSON.
pub fn build_query_string(params: &BTreeMap<String, Value>) -> Result<String, anyhow::Error> {
    // A single serializer inserts `&` between pairs, giving the same output as
    // encoding each pair separately and joining with `&`.
    let mut serializer = Serializer::new(String::new());
    for (key, value) in params {
        let encoded = match value {
            Value::Null => continue,
            Value::String(s) => s.clone(),
            Value::Bool(b) => b.to_string(),
            Value::Number(n) => n.to_string(),
            Value::Array(_) | Value::Object(_) => serde_json::to_string(value)
                .with_context(|| format!("failed to JSON-serialize `{}`", key))?,
        };
        serializer.append_pair(key, &encoded);
    }
    Ok(serializer.finish())
}
/// Decides whether a failed request should be retried.
///
/// A retry is allowed only when all of the following hold:
/// - `retries_left` is greater than zero (`None` counts as zero);
/// - the method is GET or DELETE, compared case-insensitively (`None` is not retriable);
/// - the error carries no HTTP status (e.g. a connect/timeout failure), or its
///   status is 500, 502, 503 or 504.
#[must_use]
pub fn should_retry_request(
    error: &reqwest::Error,
    method: Option<&str>,
    retries_left: Option<usize>,
) -> bool {
    let has_budget = retries_left.is_some_and(|left| left > 0);
    let method_is_retriable = method.is_some_and(|m| {
        m.eq_ignore_ascii_case("GET") || m.eq_ignore_ascii_case("DELETE")
    });
    let status_is_retriable = match error.status() {
        None => true,
        Some(code) => matches!(code.as_u16(), 500 | 502 | 503 | 504),
    };
    has_budget && method_is_retriable && status_is_retriable
}
#[must_use]
pub fn parse_rate_limit_headers<S>(headers: &HashMap<String, String, S>) -> Vec<RestApiRateLimit>
where
S: BuildHasher,
{
let mut rate_limits = Vec::new();
let re = Regex::new(r"x-mbx-(used-weight|order-count)-(\d+)([smhd])").unwrap();
for (key, value) in headers {
let normalized_key = key.to_lowercase();
if normalized_key.starts_with("x-mbx-used-weight-")
|| normalized_key.starts_with("x-mbx-order-count-")
{
if let Some(caps) = re.captures(&normalized_key) {
let interval_num: u32 = caps.get(2).unwrap().as_str().parse().unwrap_or(0);
let interval_letter = caps.get(3).unwrap().as_str().to_uppercase();
let interval = match interval_letter.as_str() {
"S" => Interval::Second,
"M" => Interval::Minute,
"H" => Interval::Hour,
"D" => Interval::Day,
_ => continue,
};
let count: u32 = value.parse().unwrap_or(0);
let rate_limit_type = if normalized_key.starts_with("x-mbx-used-weight-") {
RateLimitType::RequestWeight
} else {
RateLimitType::Orders
};
rate_limits.push(RestApiRateLimit {
rate_limit_type,
interval,
interval_num,
count,
retry_after: headers.get("retry-after").and_then(|v| v.parse().ok()),
});
}
}
}
rate_limits
}
pub async fn http_request<T: DeserializeOwned + Send + 'static>(
req: Request,
configuration: &ConfigurationRestApi,
) -> Result<RestApiResponse<T>, ConnectorError> {
let client = &configuration.client;
let retries = configuration.retries as usize;
let backoff = configuration.backoff;
let mut attempt = 0;
loop {
let req_clone = req
.try_clone()
.context("Failed to clone request")
.map_err(|e| ConnectorError::ConnectorClientError {
msg: e.to_string(),
code: None,
})?;
match client.execute(req_clone).await {
Ok(response) => {
let status = response.status();
let headers_map: HashMap<String, String> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let raw_bytes = match response.bytes().await {
Ok(b) => b,
Err(e) => {
attempt += 1;
if attempt <= retries {
continue;
}
return Err(ConnectorError::ConnectorClientError {
msg: format!("Failed to get response bytes: {e}"),
code: None,
});
}
};
let content = if headers_map
.get("content-encoding")
.is_some_and(|enc| enc.to_lowercase().contains("gzip"))
{
let mut decoder = GzDecoder::new(&raw_bytes[..]);
let mut decompressed = String::new();
decoder
.read_to_string(&mut decompressed)
.context("Failed to decompress gzip response")
.map_err(|e: anyhow::Error| ConnectorError::ConnectorClientError {
msg: e.to_string(),
code: None,
})?;
decompressed
} else {
String::from_utf8(raw_bytes.to_vec())
.context("Failed to convert response to UTF-8")
.map_err(|e| ConnectorError::ConnectorClientError {
msg: e.to_string(),
code: None,
})?
};
let rate_limits = parse_rate_limit_headers(&headers_map);
if status.is_client_error() || status.is_server_error() {
let mut err_msg = content.clone();
let mut err_code: Option<i64> = None;
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&content) {
if let Some(m) = v.get("msg").and_then(|m| m.as_str()) {
err_msg = m.to_string();
}
err_code = v.get("code").and_then(serde_json::Value::as_i64);
}
match status.as_u16() {
400 => {
return Err(ConnectorError::BadRequestError {
msg: err_msg,
code: err_code,
});
}
401 => {
return Err(ConnectorError::UnauthorizedError {
msg: err_msg,
code: err_code,
});
}
403 => {
return Err(ConnectorError::ForbiddenError {
msg: err_msg,
code: err_code,
});
}
404 => {
return Err(ConnectorError::NotFoundError {
msg: err_msg,
code: err_code,
});
}
418 => {
return Err(ConnectorError::RateLimitBanError {
msg: err_msg,
code: err_code,
});
}
429 => {
return Err(ConnectorError::TooManyRequestsError {
msg: err_msg,
code: err_code,
});
}
s if (500..600).contains(&s) => {
return Err(ConnectorError::ServerError {
msg: format!("Server error: {s}"),
status_code: Some(s),
});
}
_ => {
return Err(ConnectorError::ConnectorClientError {
msg: err_msg,
code: err_code,
});
}
}
}
let raw = content.clone();
return Ok(RestApiResponse {
data_fn: Box::new(move || {
Box::pin(async move {
let parsed: T = serde_json::from_str(&raw).map_err(|e| {
ConnectorError::ConnectorClientError {
msg: e.to_string(),
code: None,
}
})?;
Ok(parsed)
})
}),
status: status.as_u16(),
headers: headers_map,
rate_limits: if rate_limits.is_empty() {
None
} else {
Some(rate_limits)
},
});
}
Err(e) => {
attempt += 1;
if should_retry_request(&e, Some(req.method().as_str()), Some(retries - attempt)) {
delay(backoff * attempt as u64).await;
continue;
}
return Err(ConnectorError::ConnectorClientError {
msg: format!("HTTP request failed: {e}"),
code: None,
});
}
}
}
}
/// Builds, optionally signs, and sends a REST request, returning the typed response.
///
/// - The URL is `configuration.base_path` joined with `endpoint`.
/// - When `is_signed` is true:
///   - a `timestamp` query parameter (epoch ms) is added;
///   - a `signature` is computed over the query string followed by the form body,
///     and appended last to the query.
/// - `Value::Null` parameters are omitted from both the URL and the body, just as
///   `build_query_string` omits them when signing. The transmitted parameters
///   therefore always match what was signed.
/// - Custom headers from the configuration are applied, except for a few
///   reserved names. Entries with invalid names or values are skipped.
/// - A non-empty body is sent form-urlencoded. Otherwise `Content-Type` is
///   `application/json`.
///
/// # Errors
/// Fails if:
/// - the URL cannot be built;
/// - signing fails;
/// - the API key, user agent or time-unit header value is invalid;
/// - the request cannot be built;
/// - [`http_request`] fails.
pub async fn send_request<T: DeserializeOwned + Send + 'static>(
    configuration: &ConfigurationRestApi,
    endpoint: &str,
    method: Method,
    mut query_params: BTreeMap<String, Value>,
    body_params: BTreeMap<String, Value>,
    time_unit: Option<TimeUnit>,
    is_signed: bool,
) -> anyhow::Result<RestApiResponse<T>> {
    // Header names that `custom_headers` may not override (compared case-insensitively).
    const FORBIDDEN_HEADERS: [&str; 5] = ["host", "authorization", "cookie", ":method", ":path"];

    let base = configuration.base_path.as_deref().unwrap_or("");
    let mut url = Url::parse(base)
        .and_then(|u| u.join(endpoint))
        .context("Failed to join base URL and endpoint")?;

    let signature = if is_signed {
        query_params.insert("timestamp".to_string(), json!(get_timestamp()));
        let body_ref = (!body_params.is_empty()).then_some(&body_params);
        Some(
            configuration
                .signature_gen
                .get_signature(&query_params, body_ref)?,
        )
    } else {
        None
    };

    {
        let mut pairs = url.query_pairs_mut();
        for (key, value) in &query_params {
            match value {
                // Skipped here because `build_query_string` skips it when signing;
                // sending `key=null` would invalidate the signature.
                Value::Null => {}
                Value::String(s) => {
                    pairs.append_pair(key, s);
                }
                other => {
                    pairs.append_pair(key, &other.to_string());
                }
            }
        }
        if let Some(signature) = &signature {
            pairs.append_pair("signature", signature);
        }
    }

    let mut headers = HeaderMap::new();
    if let Some(custom) = &configuration.custom_headers {
        for (raw_name, raw_val) in custom {
            let name = raw_name.trim();
            if FORBIDDEN_HEADERS
                .iter()
                .any(|forbidden| forbidden.eq_ignore_ascii_case(name))
            {
                continue;
            }
            if let (Ok(header_name), Ok(header_val)) = (
                name.parse::<reqwest::header::HeaderName>(),
                HeaderValue::from_str(raw_val),
            ) {
                headers.append(header_name, header_val);
            }
        }
    }

    let content_type = if body_params.is_empty() {
        "application/json"
    } else {
        "application/x-www-form-urlencoded"
    };
    headers.insert("Content-Type", HeaderValue::from_static(content_type));
    // Propagate (rather than panic on) a user agent with invalid header characters.
    headers.insert(
        "User-Agent",
        HeaderValue::from_str(&configuration.user_agent)?,
    );
    if let Some(api_key) = &configuration.api_key {
        headers.insert("X-MBX-APIKEY", HeaderValue::from_str(api_key)?);
    }
    if configuration.compression {
        headers.insert(ACCEPT_ENCODING, HeaderValue::from_static("gzip, deflate, br"));
    }
    if let Some(time_unit) = time_unit.or(configuration.time_unit) {
        headers.insert("X-MBX-TIME-UNIT", time_unit.as_upper_str().parse()?);
    }

    let mut req_builder = configuration.client.request(method, url).headers(headers);
    if !body_params.is_empty() {
        let mut serializer = form_urlencoded::Serializer::new(String::new());
        for (key, value) in body_params {
            match value {
                // Kept consistent with the signed payload (see the query loop above).
                Value::Null => {}
                Value::String(s) => {
                    serializer.append_pair(&key, &s);
                }
                other => {
                    serializer.append_pair(&key, &other.to_string());
                }
            }
        }
        req_builder = req_builder.body(serializer.finish());
    }

    let req = req_builder.build()?;
    Ok(http_request::<T>(req, configuration).await?)
}
/// Returns 16 random bytes from the thread-local RNG, hex-encoded as 32
/// lowercase hexadecimal characters.
#[must_use]
pub fn random_string() -> String {
    let mut rng = rand::thread_rng();
    let mut id_bytes = [0u8; 16];
    rng.fill_bytes(&mut id_bytes);
    hex::encode(id_bytes)
}
/// Returns a random `u32` built from four bytes of the operating-system RNG
/// (`OsRng`), in native byte order.
#[must_use]
pub fn random_integer() -> u32 {
    let mut raw = [0u8; 4];
    OsRng.fill_bytes(&mut raw);
    u32::from_ne_bytes(raw)
}
/// Converts an optional caller-supplied stream id into the JSON value to send.
///
/// - A numeric id is always kept as a JSON number.
/// - When `stream_id_is_strictly_number` is set, any non-numeric or missing id is
///   replaced by a random `u32`.
/// - Otherwise a string id is kept if it matches `ID_REGEX`. An invalid string id
///   or a missing id is replaced by a random 32-char hex string.
#[must_use]
pub fn normalize_stream_id(id: Option<StreamId>, stream_id_is_strictly_number: bool) -> Value {
    match (id, stream_id_is_strictly_number) {
        (Some(StreamId::Number(n)), _) => Value::Number(Number::from(n)),
        (_, true) => Value::Number(Number::from(random_integer())),
        (Some(StreamId::Str(s)), false) if ID_REGEX.is_match(&s) => Value::String(s),
        (Some(StreamId::Str(_)) | None, false) => Value::String(random_string()),
    }
}
pub fn remove_empty_value<I>(entries: I) -> BTreeMap<String, Value>
where
I: IntoIterator<Item = (String, Value)>,
{
entries
.into_iter()
.filter(|(_, value)| match value {
Value::Null => false,
Value::String(s) if s.is_empty() => false,
_ => true,
})
.collect()
}
/// Returns a key-ordered copy of `params`.
///
/// A `BTreeMap` already iterates in key order, so this is a plain clone.
#[must_use]
pub fn sort_object_params(params: &BTreeMap<String, Value>) -> BTreeMap<String, Value> {
    params.clone()
}
/// Canonicalises a placeholder/variable name for lookup: lowercase, with every
/// `_` and `-` removed.
fn normalize_ws_streams_key(key: &str) -> String {
    // Lowercase the whole string first (as `str::to_lowercase` does, including
    // context-sensitive mappings), then drop the separators.
    let lowered = key.to_lowercase();
    lowered.chars().filter(|c| !matches!(c, '_' | '-')).collect()
}
/// Substitutes `<name>` / `@<name>` placeholders in a stream path.
///
/// The substitution works as follows:
/// - A leading `/` is stripped from the result.
/// - Placeholder names and `variables` keys are matched after normalisation
///   (lowercase, `_`/`-` removed). A missing variable becomes the empty string.
/// - Any trailing `@` characters left over (e.g. from empty optional parts) are removed.
/// - When the input starts with `/` and its body starts with a placeholder, that
///   leading substituted value is lowercased.
pub fn replace_websocket_streams_placeholders<V, S>(
    input: &str,
    variables: &HashMap<&str, V, S>,
) -> String
where
    V: Display,
    S: BuildHasher,
{
    let has_leading_slash = input.starts_with('/');
    let body = input.strip_prefix('/').unwrap_or(input);

    let lookup: HashMap<String, String> = variables
        .iter()
        .map(|(name, value)| (normalize_ws_streams_key(name), value.to_string()))
        .collect();
    // Value for a raw placeholder name, or "" when no variable matches.
    let value_for = |raw_name: &str| -> String {
        lookup
            .get(&normalize_ws_streams_key(raw_name))
            .cloned()
            .unwrap_or_default()
    };

    let substituted = PLACEHOLDER_RE.replace_all(body, |caps: &Captures| {
        let at_sign = caps.get(1).map_or("", |m| m.as_str());
        format!("{at_sign}{}", value_for(&caps[2]))
    });
    let result = substituted.trim_end_matches('@').to_string();

    if !has_leading_slash {
        return result;
    }
    // `captures` returns the same leftmost match that `find` would.
    let Some(head) = PLACEHOLDER_RE.captures(body) else {
        return result;
    };
    if !head.get(0).is_some_and(|m| m.start() == 0) {
        return result;
    }
    let head_value = value_for(&head[2]);
    match result.strip_prefix(head_value.as_str()) {
        Some(rest) => format!("{}{}", head_value.to_lowercase(), rest),
        None => result,
    }
}
/// Builds a WebSocket API request frame `{ "id", "method", "params" }`.
///
/// The frame is assembled as follows:
/// - `id` is taken from `payload["id"]` when that is a string matching `ID_REGEX`;
///   otherwise a fresh random id is generated. The `id` entry never appears in `params`.
/// - Null and empty-string values are dropped from `payload`.
/// - When `options.with_api_key` or `options.is_signed` is set and `skip_auth` is
///   false, the configured API key is added as `apiKey`.
/// - When `options.is_signed` is set, a millisecond `timestamp` is added. Unless
///   `skip_auth` is true, a `signature` over the key-ordered params is added too.
///
/// Returns the request id together with the JSON request.
///
/// # Panics
/// Panics if:
/// - an API key is required but `configuration.api_key` is `None`;
/// - the timestamp does not fit in an `i64`;
/// - signature generation fails.
pub fn build_websocket_api_message(
    configuration: &ConfigurationWebsocketApi,
    method: &str,
    mut payload: BTreeMap<String, Value>,
    options: &WebsocketMessageSendOptions,
    skip_auth: bool,
) -> (String, serde_json::Value) {
    let id = payload
        .remove("id")
        .as_ref()
        .and_then(Value::as_str)
        .filter(|s| ID_REGEX.is_match(s))
        .map_or_else(random_string, String::from);

    let mut params = remove_empty_value(payload);

    if (options.with_api_key || options.is_signed) && !skip_auth {
        params.insert(
            "apiKey".into(),
            Value::String(configuration.api_key.clone().expect("API key must be set")),
        );
    }

    if options.is_signed {
        let ts = get_timestamp();
        let ts_i64 = i64::try_from(ts).expect("timestamp fits in i64");
        params.insert("timestamp".into(), Value::Number(ts_i64.into()));
        if !skip_auth {
            // `params` is a BTreeMap, so it is already ordered by key; signing it
            // directly is equivalent to signing a sorted copy.
            let sig = configuration
                .signature_gen
                .get_signature(&params, None)
                .expect("signature generation");
            params.insert("signature".into(), Value::String(sig));
        }
    }

    let request = json!({
        "id": id,
        "method": method,
        "params": params,
    });
    (id, request)
}
#[cfg(test)]
mod tests {
use crate::TOKIO_SHARED_RT;
mod build_client {
    //! Tests for `build_client`: timeout enforcement, keep-alive, proxy setup
    //! and the custom agent hook.
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::{Duration, Instant},
    };

    use reqwest::ClientBuilder;

    use super::TOKIO_SHARED_RT;
    use crate::{
        common::utils::build_client,
        config::{HttpAgent, ProxyAuth, ProxyConfig},
    };

    /// A request to a non-routable address must fail quickly under a 100 ms timeout.
    #[test]
    fn enforces_timeout() {
        TOKIO_SHARED_RT.block_on(async {
            let client = build_client(100, true, None, None);
            let started = Instant::now();
            let outcome = client.get("http://10.255.255.1").send().await;
            let elapsed = started.elapsed();
            assert!(
                outcome.is_err(),
                "expected an error (timeout or connect) but got {outcome:?}"
            );
            assert!(
                elapsed < Duration::from_millis(500),
                "timed out too slowly: {elapsed:?}"
            );
        });
    }

    #[test]
    fn builds_with_keep_alive_disabled() {
        let _client: reqwest::Client = build_client(200, false, None, None);
    }

    /// An empty proxy host produces an invalid proxy URL, which panics.
    #[test]
    #[should_panic(expected = "Failed to create proxy from URL")]
    fn invalid_proxy_url_panics() {
        let bad_proxy = ProxyConfig {
            protocol: Some("http".to_string()),
            host: String::new(),
            port: 8080,
            auth: None,
        };
        let _ = build_client(1_000, true, Some(&bad_proxy), None);
    }

    #[test]
    fn builds_with_proxy_and_auth() {
        let credentials = ProxyAuth {
            username: "alice".to_string(),
            password: "secret".to_string(),
        };
        let proxy = ProxyConfig {
            protocol: Some("https".to_string()),
            host: "127.0.0.1".to_string(),
            port: 3128,
            auth: Some(credentials),
        };
        let _client: reqwest::Client = build_client(2_000, true, Some(&proxy), None);
    }

    #[test]
    fn custom_agent_invoked() {
        let invoked = Arc::new(AtomicBool::new(false));
        let invoked_in_agent = Arc::clone(&invoked);
        let agent = HttpAgent(Arc::new(move |builder: ClientBuilder| {
            invoked_in_agent.store(true, Ordering::SeqCst);
            builder
        }));
        let client = build_client(1_000, true, None, Some(agent));
        assert!(
            invoked.load(Ordering::SeqCst),
            "agent closure wasn’t invoked"
        );
        let _: reqwest::Client = client;
    }
}
mod build_user_agent {
    //! Tests for `build_user_agent`: format and determinism.
    use crate::common::utils::build_user_agent;

    const PRODUCT: &str = "product";

    #[test]
    fn build_user_agent_contains_crate_product_and_rust_info() {
        let user_agent = build_user_agent(PRODUCT);

        let expected_prefix = format!(
            "{}/{}/{} (Rust/",
            env!("CARGO_PKG_NAME"),
            PRODUCT,
            env!("CARGO_PKG_VERSION")
        );
        assert!(
            user_agent.starts_with(&expected_prefix),
            "prefix mismatch: {user_agent}"
        );
        assert!(
            user_agent.contains(env!("RUSTC_VERSION")),
            "user agent missing RUSTC_VERSION: {user_agent}"
        );

        let os_part = format!("; {}", std::env::consts::OS);
        assert!(
            user_agent.contains(&os_part),
            "user agent missing OS: {user_agent}"
        );
        let arch_part = format!("; {}", std::env::consts::ARCH);
        assert!(
            user_agent.contains(&arch_part),
            "user agent missing ARCH: {user_agent}"
        );
    }

    #[test]
    fn build_user_agent_is_deterministic() {
        let first = build_user_agent(PRODUCT);
        let second = build_user_agent(PRODUCT);
        assert_eq!(
            first, second,
            "user agent should be the same on repeated calls"
        );
    }
}
mod validate_time_unit {
    //! Tests for `validate_time_unit`: accepted spellings and rejections.
    use crate::common::utils::validate_time_unit;

    const EXPECTED_ERR: &str = "time_unit must be either 'MILLISECOND' or 'MICROSECOND'";

    #[test]
    fn empty_string_returns_none() {
        assert_eq!(
            validate_time_unit("").expect("Should not error on empty string"),
            None
        );
    }

    #[test]
    fn uppercase_millisecond() {
        assert_eq!(
            validate_time_unit("MILLISECOND").expect("Should accept MILLISECOND"),
            Some("MILLISECOND")
        );
    }

    #[test]
    fn uppercase_microsecond() {
        assert_eq!(
            validate_time_unit("MICROSECOND").expect("Should accept MICROSECOND"),
            Some("MICROSECOND")
        );
    }

    #[test]
    fn lowercase_millisecond() {
        assert_eq!(
            validate_time_unit("millisecond").expect("Should accept millisecond"),
            Some("millisecond")
        );
    }

    #[test]
    fn lowercase_microsecond() {
        assert_eq!(
            validate_time_unit("microsecond").expect("Should accept microsecond"),
            Some("microsecond")
        );
    }

    #[test]
    fn invalid_value_returns_err() {
        let err = validate_time_unit("SECOND").unwrap_err();
        assert!(err.to_string().contains(EXPECTED_ERR));
    }

    #[test]
    fn partial_match_returns_err() {
        let err = validate_time_unit("MILLI").unwrap_err();
        assert!(err.to_string().contains(EXPECTED_ERR));
    }
}
mod get_timestamp {
    //! Tests for `get_timestamp`. These compare against the wall clock, which is
    //! not guaranteed monotonic, so they assume no clock adjustment mid-test.
    use std::{
        thread::sleep,
        time::{Duration, SystemTime, UNIX_EPOCH},
    };

    use crate::common::utils::get_timestamp;

    /// Reference reading of the wall clock in epoch milliseconds.
    fn now_millis() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("SystemTime before UNIX_EPOCH")
            .as_millis()
    }

    #[test]
    fn timestamp_is_within_system_time_bounds() {
        let before = now_millis();
        let ts = get_timestamp();
        let after = now_millis();
        assert!(
            ts >= before,
            "timestamp {ts} is before captured before time {before}"
        );
        assert!(
            ts <= after,
            "timestamp {ts} is after captured after time {after}"
        );
    }

    #[test]
    fn timestamps_are_monotonic() {
        let first = get_timestamp();
        sleep(Duration::from_millis(1));
        let second = get_timestamp();
        assert!(
            second >= first,
            "second timestamp {second} is not >= first timestamp {first}"
        );
    }
}
mod build_query_string {
    //! Tests for `build_query_string`. They cover:
    //! - key ordering;
    //! - null skipping;
    //! - form-urlencoding of keys and values;
    //! - JSON encoding of array and object values.
    use std::collections::BTreeMap;

    use anyhow::Result;
    use serde_json::{Value, json};
    use url::form_urlencoded::Serializer;

    use crate::common::utils::build_query_string;

    /// Builds a parameter map from `(key, value)` pairs.
    fn mk_map(pairs: Vec<(&str, Value)>) -> BTreeMap<String, Value> {
        let mut m = BTreeMap::new();
        for (k, v) in pairs {
            m.insert(k.to_string(), v);
        }
        m
    }

    #[test]
    fn empty_map_returns_empty_string() -> Result<()> {
        let params = BTreeMap::new();
        let qs = build_query_string(&params)?;
        assert_eq!(qs, "");
        Ok(())
    }

    #[test]
    fn string_and_number_and_bool() -> Result<()> {
        let params = mk_map(vec![
            ("foo", json!("bar")),
            ("num", json!(42)),
            ("flag", json!(true)),
        ]);
        let qs = build_query_string(&params)?;
        assert_eq!(qs, "flag=true&foo=bar&num=42");
        Ok(())
    }

    #[test]
    fn null_is_skipped() -> Result<()> {
        let params = mk_map(vec![("a", json!(true)), ("b", Value::Null)]);
        let qs = build_query_string(&params)?;
        assert_eq!(qs, "a=true");
        Ok(())
    }

    #[test]
    fn percent_encode_special_chars() -> Result<()> {
        let params = mk_map(vec![
            ("space", json!("hello world")),
            ("symbols", json!("a/b?c")),
        ]);
        let qs = build_query_string(&params)?;
        let mut parts = vec![];
        let mut ser = Serializer::new(String::new());
        ser.append_pair("space", "hello world");
        parts.push(ser.finish());
        let mut ser = Serializer::new(String::new());
        ser.append_pair("symbols", "a/b?c");
        parts.push(ser.finish());
        let expected = parts.join("&");
        assert_eq!(qs, expected);
        Ok(())
    }

    #[test]
    fn primitive_array_json_encoded() -> Result<()> {
        let params = mk_map(vec![
            ("strs", json!(["a", "b", "c"])),
            ("nums", json!([1, 2, 3])),
            ("bools", json!([true, false])),
        ]);
        let qs = build_query_string(&params)?;
        let mut parts = Vec::new();
        for (k, v) in &params {
            let json = serde_json::to_string(v)?;
            let mut ser = Serializer::new(String::new());
            ser.append_pair(k, &json);
            parts.push(ser.finish());
        }
        let expected = parts.join("&");
        assert_eq!(qs, expected);
        Ok(())
    }

    #[test]
    fn nested_array_json_encoded() -> Result<()> {
        let params = mk_map(vec![("nested", json!([[1, 2], [3, 4]]))]);
        let qs = build_query_string(&params)?;
        let nested_json = serde_json::to_string(&json!([[1, 2], [3, 4]]))?;
        let mut ser = Serializer::new(String::new());
        ser.append_pair("nested", &nested_json);
        let expected = ser.finish();
        assert_eq!(qs, expected);
        Ok(())
    }

    #[test]
    fn object_json_encoded() -> Result<()> {
        let params = mk_map(vec![("obj", json!({"k":1, "v":"two"}))]);
        let qs = build_query_string(&params)?;
        let obj_json = serde_json::to_string(&json!({"k":1, "v":"two"}))?;
        let mut ser = Serializer::new(String::new());
        ser.append_pair("obj", &obj_json);
        let expected = ser.finish();
        assert_eq!(qs, expected);
        Ok(())
    }

    #[test]
    fn empty_array() {
        let params = mk_map(vec![("foo", json!([]))]);
        let qs = build_query_string(&params).unwrap();
        let json = serde_json::to_string(&json!([])).unwrap();
        let expected = Serializer::new(String::new())
            .append_pair("foo", &json)
            .finish();
        assert_eq!(qs, expected);
    }

    #[test]
    fn mixed_array() {
        let params = mk_map(vec![("mix", json!([1, "x", false]))]);
        let qs = build_query_string(&params).unwrap();
        let json = serde_json::to_string(&json!([1, "x", false])).unwrap();
        let expected = Serializer::new(String::new())
            .append_pair("mix", &json)
            .finish();
        assert_eq!(qs, expected);
    }

    #[test]
    fn array_of_objects() {
        let params = mk_map(vec![("objs", json!([{"a":1}, {"b":2}]))]);
        let qs = build_query_string(&params).unwrap();
        let json = serde_json::to_string(&json!([{"a":1}, {"b":2}])).unwrap();
        let expected = Serializer::new(String::new())
            .append_pair("objs", &json)
            .finish();
        assert_eq!(qs, expected);
    }

    #[test]
    fn empty_object() {
        let params = mk_map(vec![("emp", json!({}))]);
        let qs = build_query_string(&params).unwrap();
        let json = serde_json::to_string(&json!({})).unwrap();
        let expected = Serializer::new(String::new())
            .append_pair("emp", &json)
            .finish();
        assert_eq!(qs, expected);
    }

    #[test]
    fn floats_and_negatives() {
        let params = mk_map(vec![("fl", json!(1.23456)), ("neg", json!(-0.001))]);
        let qs = build_query_string(&params).unwrap();
        assert_eq!(qs, "fl=1.23456&neg=-0.001");
    }

    #[test]
    fn unicode_and_special_key() {
        let params = mk_map(vec![
            ("こんにちは", json!("世界")),
            ("weird key/?=", json!("val")),
        ]);
        let qs = build_query_string(&params).unwrap();
        let mut parts = Vec::new();
        for (k, v) in &params {
            let mut ser = Serializer::new(String::new());
            ser.append_pair(k, v.as_str().unwrap());
            parts.push(ser.finish());
        }
        let expected = parts.join("&");
        assert_eq!(qs, expected);
    }

    #[test]
    fn empty_string_value() {
        let params = mk_map(vec![("empty", json!(""))]);
        let qs = build_query_string(&params).unwrap();
        assert_eq!(qs, "empty=");
    }

    #[test]
    fn nulls_in_array() {
        let params = mk_map(vec![("a", json!([null, 1, "x"]))]);
        let qs = build_query_string(&params).unwrap();
        let json = serde_json::to_string(&json!([null, 1, "x"])).unwrap();
        let expected = Serializer::new(String::new())
            .append_pair("a", &json)
            .finish();
        assert_eq!(qs, expected);
    }

    #[test]
    fn special_chars_in_key() {
        let params = mk_map(vec![("a=b&c%", json!("val"))]);
        let qs = build_query_string(&params).unwrap();
        let expected = Serializer::new(String::new())
            .append_pair("a=b&c%", "val")
            .finish();
        assert_eq!(qs, expected);
    }

    #[test]
    fn empty_key() {
        let params = mk_map(vec![("", json!("v"))]);
        let qs = build_query_string(&params).unwrap();
        assert_eq!(qs, "=v");
    }
}
#[cfg(feature = "openssl-tls")]
mod signature_generator {
    // Each produced signature is checked against an independent reference:
    // - HMAC is recomputed locally.
    // - RSA is verified with the matching public key.
    // - Ed25519 is recomputed with `ed25519-dalek` from the same PKCS#8 key.
    use base64::{Engine, engine::general_purpose};
    use ed25519_dalek::{Signer, SigningKey, pkcs8::DecodePrivateKey};
    use hmac::{Hmac, Mac};
    // The whole module is already gated on `openssl-tls`, so no per-item `cfg` is needed.
    use openssl::{hash::MessageDigest, pkey::PKey, rsa::Rsa, sign::Verifier};
    use serde_json::Value;
    use sha2::Sha256;
    use std::collections::BTreeMap;
    use std::io::Write;
    use tempfile::NamedTempFile;

    use crate::{common::utils::SignatureGenerator, config::PrivateKey};

    /// Reference HMAC-SHA256 of `payload` under `secret`, hex-encoded.
    fn hmac_hex(secret: &[u8], payload: &str) -> String {
        let mut mac =
            Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts keys of any length");
        mac.update(payload.as_bytes());
        hex::encode(mac.finalize().into_bytes())
    }

    /// Generates an RSA key of `bits` bits.
    ///
    /// Returns `(private key PEM, PKCS#1 public key PEM)`.
    fn rsa_pem_pair(bits: u32) -> (Vec<u8>, Vec<u8>) {
        let rsa = Rsa::generate(bits).expect("RSA key generation failed");
        let priv_pem = rsa
            .private_key_to_pem()
            .expect("RSA private key PEM export failed");
        let pub_pem = rsa
            .public_key_to_pem_pkcs1()
            .expect("RSA public key PEM export failed");
        (priv_pem, pub_pem)
    }

    /// Asserts that the base64 `sig` is a valid SHA-256 RSA signature of `payload`
    /// under the public key in `pub_pem`.
    fn assert_rsa_signature(pub_pem: &[u8], sig: impl AsRef<[u8]>, payload: &[u8]) {
        let sig_bytes = general_purpose::STANDARD
            .decode(sig)
            .expect("signature is not valid base64");
        let pubkey = PKey::public_key_from_pem(pub_pem).expect("invalid public key PEM");
        let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
        verifier.update(payload).unwrap();
        assert!(
            verifier.verify(&sig_bytes).unwrap(),
            "RSA signature did not verify"
        );
    }

    /// Generates a fresh Ed25519 key and returns it as PKCS#8 PEM.
    fn ed25519_pkcs8_pem() -> Vec<u8> {
        let key = PKey::generate_ed25519().expect("Ed25519 key generation failed");
        key.private_key_to_pem_pkcs8()
            .expect("Ed25519 PKCS#8 PEM export failed")
    }

    /// Signs `payload` with the PKCS#8 PEM key using `ed25519-dalek` and base64-encodes the
    /// result, giving a reference independent of the OpenSSL path under test.
    fn ed25519_reference_signature(priv_pem: &[u8], payload: &str) -> String {
        let pem = std::str::from_utf8(priv_pem).expect("PEM is not valid UTF-8");
        // Drop the BEGIN/END armour lines and join the base64 body back together.
        let b64: String = pem.lines().filter(|l| !l.starts_with("-----")).collect();
        let der = general_purpose::STANDARD
            .decode(b64)
            .expect("PEM body is not valid base64");
        let sk = SigningKey::from_pkcs8_der(&der).expect("not a PKCS#8 Ed25519 key");
        general_purpose::STANDARD.encode(sk.sign(payload.as_bytes()).to_bytes())
    }

    #[test]
    fn hmac_sha256_signature() {
        let mut params = BTreeMap::new();
        // Inserted out of order on purpose: the signed payload must be key-sorted.
        params.insert("b".into(), Value::Number(2.into()));
        params.insert("a".into(), Value::Number(1.into()));
        let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
        let sig = signature_gen
            .get_signature(&params, None)
            .expect("HMAC signing failed");
        assert_eq!(sig, hmac_hex(b"test-secret", "a=1&b=2"));
    }

    #[test]
    fn hmac_sha256_signature_with_body() {
        let mut query_params = BTreeMap::new();
        query_params.insert("b".into(), Value::Number(2.into()));
        query_params.insert("a".into(), Value::Number(1.into()));
        let mut body_params = BTreeMap::new();
        body_params.insert("d".into(), Value::Number(4.into()));
        body_params.insert("c".into(), Value::Number(3.into()));
        let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
        let sig = signature_gen
            .get_signature(&query_params, Some(&body_params))
            .expect("HMAC signing with body failed");
        // The query string and the body string are concatenated with no separator.
        assert_eq!(sig, hmac_hex(b"test-secret", "a=1&b=2c=3&d=4"));
    }

    #[test]
    fn repeated_hmac_signature() {
        let mut params = BTreeMap::new();
        params.insert("x".into(), Value::String("y".into()));
        let signature_gen = SignatureGenerator::new(Some("abc".into()), None, None);
        let s1 = signature_gen.get_signature(&params, None).unwrap();
        let s2 = signature_gen.get_signature(&params, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn rsa_signature_verification() {
        let mut params = BTreeMap::new();
        params.insert("a".into(), Value::Number(1.into()));
        params.insert("b".into(), Value::Number(2.into()));
        let (priv_pem, pub_pem) = rsa_pem_pair(2048);
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let sig = signature_gen
            .get_signature(&params, None)
            .expect("RSA signing failed");
        assert_rsa_signature(&pub_pem, &sig, b"a=1&b=2");
    }

    #[test]
    fn rsa_signature_verification_with_body() {
        let mut query_params = BTreeMap::new();
        query_params.insert("a".into(), Value::Number(1.into()));
        query_params.insert("b".into(), Value::Number(2.into()));
        let mut body_params = BTreeMap::new();
        body_params.insert("c".into(), Value::Number(3.into()));
        body_params.insert("d".into(), Value::Number(4.into()));
        let (priv_pem, pub_pem) = rsa_pem_pair(2048);
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let sig = signature_gen
            .get_signature(&query_params, Some(&body_params))
            .expect("RSA signing with body failed");
        assert_rsa_signature(&pub_pem, &sig, b"a=1&b=2c=3&d=4");
    }

    #[test]
    fn repeated_rsa_signature() {
        let mut params = BTreeMap::new();
        params.insert("k".into(), Value::Number(5.into()));
        let (priv_pem, _) = rsa_pem_pair(2048);
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let s1 = signature_gen.get_signature(&params, None).unwrap();
        let s2 = signature_gen.get_signature(&params, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn ed25519_signature_verification() {
        let mut params = BTreeMap::new();
        params.insert("a".into(), Value::Number(1.into()));
        params.insert("b".into(), Value::Number(2.into()));
        let priv_pem = ed25519_pkcs8_pem();
        let signature_gen =
            SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
        let sig = signature_gen
            .get_signature(&params, None)
            .expect("Ed25519 signing failed");
        assert_eq!(sig, ed25519_reference_signature(&priv_pem, "a=1&b=2"));
    }

    #[test]
    fn ed25519_signature_verification_with_body() {
        let mut query_params = BTreeMap::new();
        query_params.insert("a".into(), Value::Number(1.into()));
        query_params.insert("b".into(), Value::Number(2.into()));
        let mut body_params = BTreeMap::new();
        body_params.insert("c".into(), Value::Number(3.into()));
        body_params.insert("d".into(), Value::Number(4.into()));
        let priv_pem = ed25519_pkcs8_pem();
        let signature_gen =
            SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
        let sig = signature_gen
            .get_signature(&query_params, Some(&body_params))
            .expect("Ed25519 signing with body failed");
        assert_eq!(
            sig,
            ed25519_reference_signature(&priv_pem, "a=1&b=2c=3&d=4")
        );
    }

    #[test]
    fn repeated_ed25519_signature() {
        let mut params = BTreeMap::new();
        params.insert("m".into(), Value::String("n".into()));
        let priv_pem = ed25519_pkcs8_pem();
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let s1 = signature_gen.get_signature(&params, None).unwrap();
        let s2 = signature_gen.get_signature(&params, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn file_based_key() {
        let (priv_pem, pub_pem) = rsa_pem_pair(1024);
        // `file` must outlive the signing call: dropping a NamedTempFile deletes it.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&priv_pem).unwrap();
        let path = file.path().to_str().unwrap().to_string();
        let mut params = BTreeMap::new();
        params.insert("z".into(), Value::Number(9.into()));
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::File(path)), None);
        let sig = signature_gen.get_signature(&params, None).unwrap();
        assert_rsa_signature(&pub_pem, &sig, b"z=9");
    }

    #[test]
    fn unsupported_key_type_error() {
        let mut params = BTreeMap::new();
        params.insert("x".into(), Value::String("y".into()));
        // A P-256 EC key is valid PKCS#8 but neither RSA nor Ed25519.
        let group =
            openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
        let ec_key = openssl::ec::EcKey::generate(&group).unwrap();
        let pkey_ec = PKey::from_ec_key(ec_key).unwrap();
        let raw = pkey_ec.private_key_to_pem_pkcs8().unwrap();
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(raw)), None);
        let err = signature_gen
            .get_signature(&params, None)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("Unsupported private key type"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn invalid_private_key_error() {
        let mut params = BTreeMap::new();
        params.insert("foo".into(), Value::String("bar".into()));
        let signature_gen =
            SignatureGenerator::new(None, Some(PrivateKey::Raw(b"not a key".to_vec())), None);
        let err = signature_gen
            .get_signature(&params, None)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("Failed to parse private key"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn missing_credentials_error() {
        let mut params = BTreeMap::new();
        params.insert("a".into(), Value::Number(1.into()));
        let signature_gen = SignatureGenerator::new(None, None, None);
        let err = signature_gen
            .get_signature(&params, None)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("Either 'api_secret' or 'private_key' must be provided"),
            "unexpected error: {err}"
        );
    }
}
mod should_retry_request {
    use crate::common::utils::should_retry_request;
    use reqwest::{Error, Response};

    /// Produces a `reqwest::Error` that carries the HTTP status `code`.
    fn mk_http_error(code: u16) -> Error {
        let raw = http::response::Response::builder()
            .status(code)
            .body("")
            .unwrap();
        Response::from(raw).error_for_status().unwrap_err()
    }

    /// Produces a `reqwest::Error` with no HTTP status attached.
    ///
    /// NOTE(review): `256.256.256.256` is not a valid IPv4 address. The request is therefore
    /// expected to fail while it is being built, before any network I/O — confirm if this
    /// error's kind matters to `should_retry_request`.
    fn mk_network_error() -> Error {
        reqwest::blocking::get("http://256.256.256.256").unwrap_err()
    }

    #[test]
    fn retry_on_retriable_status_and_method() {
        let server_error = mk_http_error(500);
        // The method comparison is exercised in both upper and lower case.
        assert!(should_retry_request(&server_error, Some("GET"), Some(1)));
        assert!(should_retry_request(&server_error, Some("delete"), Some(2)));
    }

    #[test]
    fn retry_when_status_none_and_retriable_method() {
        for method in ["GET", "DELETE"] {
            let status_less = mk_network_error();
            let retried = should_retry_request(&status_less, Some(method), Some(1));
            assert!(retried, "Should retry when no status and method {method}");
        }
    }

    #[test]
    fn no_retry_when_no_retries_left() {
        let unavailable = mk_http_error(503);
        let retried = should_retry_request(&unavailable, Some("GET"), Some(0));
        assert!(!retried);
    }

    #[test]
    fn no_retry_on_non_retriable_status() {
        for status in [400, 401, 404, 422] {
            let client_error = mk_http_error(status);
            let retried = should_retry_request(&client_error, Some("GET"), Some(2));
            assert!(
                !retried,
                "Should not retry for non-retriable status {status}"
            );
        }
    }

    #[test]
    fn no_retry_on_non_retriable_method() {
        for method in ["POST", "PUT", "PATCH"] {
            let server_error = mk_http_error(500);
            let retried = should_retry_request(&server_error, Some(method), Some(2));
            assert!(
                !retried,
                "Should not retry for non-retriable method {method}"
            );
        }
    }

    #[test]
    fn no_retry_when_status_none_and_non_retriable_method() {
        for method in ["POST", "PUT"] {
            let status_less = mk_network_error();
            let retried = should_retry_request(&status_less, Some(method), Some(1));
            assert!(
                !retried,
                "Should not retry when no status and method {method}"
            );
        }
    }
}
mod parse_rate_limit_headers_tests {
    use crate::common::{
        models::{Interval, RateLimitType},
        utils::parse_rate_limit_headers,
    };
    use std::collections::HashMap;

    /// Converts `(name, value)` pairs into an owned header map.
    ///
    /// A repeated name keeps its last value.
    fn mk_headers(pairs: Vec<(&str, &str)>) -> HashMap<String, String> {
        pairs
            .into_iter()
            .map(|(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn single_weight_header() {
        let headers = mk_headers(vec![("x-mbx-used-weight-1s", "123")]);
        let limits = parse_rate_limit_headers(&headers);
        assert_eq!(limits.len(), 1);
        let limit = &limits[0];
        assert_eq!(limit.rate_limit_type, RateLimitType::RequestWeight);
        assert_eq!(limit.interval, Interval::Second);
        assert_eq!(limit.interval_num, 1);
        assert_eq!(limit.count, 123);
        assert_eq!(limit.retry_after, None);
    }

    #[test]
    fn single_order_count_with_retry_after() {
        let headers = mk_headers(vec![("x-mbx-order-count-5m", "42"), ("retry-after", "7")]);
        let limits = parse_rate_limit_headers(&headers);
        assert_eq!(limits.len(), 1);
        let limit = &limits[0];
        assert_eq!(limit.rate_limit_type, RateLimitType::Orders);
        assert_eq!(limit.interval, Interval::Minute);
        assert_eq!(limit.interval_num, 5);
        assert_eq!(limit.count, 42);
        assert_eq!(limit.retry_after, Some(7));
    }

    #[test]
    fn multiple_headers() {
        // Header names differ in case, so matching is expected to be case-insensitive.
        let headers = mk_headers(vec![
            ("X-MBX-USED-WEIGHT-1h", "10"),
            ("x-mbx-order-count-2d", "20"),
        ]);
        let mut limits = parse_rate_limit_headers(&headers);
        // Output order is not guaranteed (HashMap input), so sort before comparing.
        limits.sort_by_key(|r| (r.interval_num, format!("{:?}", r.rate_limit_type)));
        assert_eq!(limits.len(), 2);
        let (weight, orders) = (&limits[0], &limits[1]);
        assert_eq!(weight.rate_limit_type, RateLimitType::RequestWeight);
        assert_eq!(weight.interval, Interval::Hour);
        assert_eq!(weight.interval_num, 1);
        assert_eq!(weight.count, 10);
        assert_eq!(orders.rate_limit_type, RateLimitType::Orders);
        assert_eq!(orders.interval, Interval::Day);
        assert_eq!(orders.interval_num, 2);
        assert_eq!(orders.count, 20);
    }

    #[test]
    fn ignores_unknown_and_malformed() {
        let headers = mk_headers(vec![
            ("x-mbx-used-weight-3x", "5"),
            ("random-header", "100"),
        ]);
        assert!(parse_rate_limit_headers(&headers).is_empty());
    }
}
mod http_request {
use std::io::Write;
use flate2::{Compression, write::GzEncoder};
use httpmock::MockServer;
use reqwest::{Client, Method, Request};
use serde::Deserialize;
use crate::{
common::utils::http_request, config::ConfigurationRestApi, errors::ConnectorError,
models::RestApiResponse,
};
use super::TOKIO_SHARED_RT;
#[derive(Deserialize, Debug, PartialEq)]
struct Dummy {
foo: String,
}
fn make_config(server_url: &str) -> ConfigurationRestApi {
ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server_url)
.build()
.expect("Failed to build configuration")
}
#[test]
fn http_request_success_plain_text() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/test");
then.status(200)
.header("Content-Type", "application/json")
.body(r#"{"foo":"bar"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/test"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
assert_eq!(resp.status, 200);
let data = resp.data().await.unwrap();
assert_eq!(data, Dummy { foo: "bar".into() });
mock.assert();
});
}
#[test]
fn http_request_success_gzip() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let body = r#"{"foo":"baz"}"#;
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(body.as_bytes()).unwrap();
let gz = encoder.finish().unwrap();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/gz");
then.status(200)
.header("Content-Type", "application/json")
.header("Content-Encoding", "gzip")
.body(gz);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/gz"))
.build()
.unwrap();
let mut cfg = make_config(&server.url(""));
cfg.compression = true;
let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
assert_eq!(resp.status, 200);
let data = resp.data().await.unwrap();
assert_eq!(data, Dummy { foo: "baz".into() });
mock.assert();
});
}
#[test]
fn http_request_client_error_bad_request() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/400");
then.status(400)
.header("Content-Type", "application/json")
.body(r#"{"code":-1121,"msg":"bad request"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/400"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(
result,
Err(ConnectorError::BadRequestError { .. })
));
if let Err(ConnectorError::BadRequestError { msg, code }) = result {
assert_eq!(msg, "bad request");
assert_eq!(code, Some(-1121));
}
mock.assert();
});
}
#[test]
fn http_request_client_error_unauthorized() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/401");
then.status(401)
.header("Content-Type", "application/json")
.body(r#"{"code":-2015,"msg":"unauthorized"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/401"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(
result,
Err(ConnectorError::UnauthorizedError { .. })
));
if let Err(ConnectorError::UnauthorizedError { msg, code }) = result {
assert_eq!(msg, "unauthorized");
assert_eq!(code, Some(-2015));
}
mock.assert();
});
}
#[test]
fn http_request_client_error_forbidden() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/403");
then.status(403)
.header("Content-Type", "application/json")
.body(r#"{"code":-2010,"msg":"forbidden"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/403"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(result, Err(ConnectorError::ForbiddenError { .. })));
if let Err(ConnectorError::ForbiddenError { msg, code }) = result {
assert_eq!(msg, "forbidden");
assert_eq!(code, Some(-2010));
}
mock.assert();
});
}
#[test]
fn http_request_client_error_not_found() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/404");
then.status(404)
.header("Content-Type", "application/json")
.body(r#"{"code":-1003,"msg":"not found"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/404"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(result, Err(ConnectorError::NotFoundError { .. })));
if let Err(ConnectorError::NotFoundError { msg, code }) = result {
assert_eq!(msg, "not found");
assert_eq!(code, Some(-1003));
}
mock.assert();
});
}
#[test]
fn http_request_client_error_rate_limit_exceeded() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/418");
then.status(418)
.header("Content-Type", "application/json")
.body(r#"{"code":-1003,"msg":"rate limit exceeded"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/418"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(
result,
Err(ConnectorError::RateLimitBanError { .. })
));
if let Err(ConnectorError::RateLimitBanError { msg, code }) = result {
assert_eq!(msg, "rate limit exceeded");
assert_eq!(code, Some(-1003));
}
mock.assert();
});
}
#[test]
fn http_request_client_error_too_many_requests() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/429");
then.status(429)
.header("Content-Type", "application/json")
.body(r#"{"code":-1003,"msg":"too many requests"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/429"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(
result,
Err(ConnectorError::TooManyRequestsError { .. })
));
if let Err(ConnectorError::TooManyRequestsError { msg, code }) = result {
assert_eq!(msg, "too many requests");
assert_eq!(code, Some(-1003));
}
mock.assert();
});
}
#[test]
fn http_request_client_error_server_error() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/500");
then.status(500)
.header("Content-Type", "application/json")
.body(r#"{"code":-1000,"msg":"internal server error"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/500"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(result, Err(ConnectorError::ServerError { .. })));
if let Err(ConnectorError::ServerError {
msg,
status_code: Some(500),
}) = result
{
assert_eq!(msg, "Server error: 500".to_string());
}
mock.assert();
});
}
#[test]
fn http_request_unexpected_status_maps_generic() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let code_http = 402;
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/402");
then.status(code_http)
.header("Content-Type", "application/json")
.body(r#"{"code":-12345,"msg":"payment required"}"#);
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/402"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let result = http_request::<Dummy>(req, &cfg).await;
assert!(matches!(
result,
Err(ConnectorError::ConnectorClientError { .. })
));
if let Err(ConnectorError::ConnectorClientError { msg, code }) = result {
assert_eq!(msg, "payment required");
assert_eq!(code, Some(-12345));
}
mock.assert();
});
}
#[test]
fn http_request_malformed_json_maps_generic() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let mock = server.mock(|when, then| {
when.method(httpmock::Method::GET).path("/malformed");
then.status(200)
.header("Content-Type", "application/json")
.body("not json");
});
let client = Client::new();
let req: Request = client
.request(Method::GET, format!("{}{}", server.url(""), "/malformed"))
.build()
.unwrap();
let cfg = make_config(&server.url(""));
let resp = http_request::<Dummy>(req, &cfg)
.await
.expect("http_request should succeed even if JSON is bad");
let err = resp
.data()
.await
.expect_err("malformed JSON should turn into ConnectorClientError");
assert!(matches!(err, ConnectorError::ConnectorClientError { .. }));
if let ConnectorError::ConnectorClientError { msg: _, code } = err {
assert_eq!(code, None);
}
mock.assert();
});
}
}
mod send_request {
use anyhow::Result;
use httpmock::prelude::*;
use reqwest::Method;
use serde::Deserialize;
use serde_json::json;
use std::collections::{BTreeMap, HashMap};
use crate::{
common::{models::TimeUnit, utils::send_request},
config::ConfigurationRestApi,
};
use super::TOKIO_SHARED_RT;
#[derive(Deserialize, Debug, PartialEq)]
struct TestResponse {
message: String,
}
#[test]
fn basic_get_request() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/api/v1/test");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message": "success"}"#);
});
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let result = send_request::<TestResponse>(
&configuration,
"/api/v1/test",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await?;
let data = result.data().await.unwrap();
assert_eq!(data.message, "success");
Ok(())
})
}
#[test]
fn signed_post_request() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(POST).path("/api/v3/order");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message": "order placed"}"#);
});
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.build()
.expect("Failed to build configuration");
let mut params = BTreeMap::new();
params.insert("symbol".to_string(), json!("ETHUSDT"));
params.insert("side".to_string(), json!("BUY"));
params.insert("type".to_string(), json!("MARKET"));
params.insert("quantity".to_string(), json!("1"));
let result = send_request::<TestResponse>(
&configuration,
"/api/v3/order",
Method::POST,
params,
BTreeMap::new(),
None,
true,
)
.await?;
let data = result.data().await.unwrap();
assert_eq!(data.message, "order placed");
Ok(())
})
}
#[test]
fn signed_post_request_with_body() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(POST).path("/api/v3/order");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message": "order placed"}"#);
});
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.build()
.expect("Failed to build configuration");
let mut query_params = BTreeMap::new();
query_params.insert("symbol".to_string(), json!("ETHUSDT"));
let mut body_params = BTreeMap::new();
body_params.insert("side".to_string(), json!("BUY"));
body_params.insert("type".to_string(), json!("MARKET"));
body_params.insert("quantity".to_string(), json!("1"));
let result = send_request::<TestResponse>(
&configuration,
"/api/v3/order",
Method::POST,
query_params,
body_params,
None,
true,
)
.await?;
let data = result.data().await.unwrap();
assert_eq!(data.message, "order placed");
Ok(())
})
}
#[test]
fn get_request_with_params() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET)
.path("/api/v1/data")
.query_param("symbol", "BTCUSDT")
.query_param("limit", "10");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message": "data retrieved"}"#);
});
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.build()
.expect("Failed to build configuration");
let mut params = BTreeMap::new();
params.insert("symbol".to_string(), json!("BTCUSDT"));
params.insert("limit".to_string(), json!(10));
let result = send_request::<TestResponse>(
&configuration,
"/api/v1/data",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await?;
let data = result.data().await.unwrap();
assert_eq!(data.message, "data retrieved");
Ok(())
})
}
#[test]
fn invalid_endpoint() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let result = send_request::<TestResponse>(
&configuration,
"http://invalid",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await;
assert!(result.is_err());
});
}
#[test]
fn missing_signature_on_signed_request() {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.build()
.expect("Failed to build configuration");
let mut params = BTreeMap::new();
params.insert("symbol".to_string(), json!("BTCUSDT"));
params.insert("side".to_string(), json!("BUY"));
let result = send_request::<TestResponse>(
&configuration,
"/api/v3/order",
Method::POST,
params,
BTreeMap::new(),
None,
true,
)
.await;
assert!(result.is_err());
});
}
#[test]
fn compression_enabled() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET).path("/api/v1/test");
then.status(200)
.header("content-type", "application/json")
.header("accept-encoding", "gzip, deflate, br")
.body(r#"{"message": "compression enabled"}"#);
});
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(true)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let result = send_request::<TestResponse>(
&configuration,
"/api/v1/test",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await?;
let data = result.data().await.unwrap();
assert_eq!(data.message, "compression enabled");
Ok(())
})
}
#[test]
fn get_request_with_time_unit_header() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET)
.path("/api/v1/test")
.header("X-MBX-TIME-UNIT", "MILLISECOND");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message": "time unit applied"}"#);
});
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.time_unit(TimeUnit::Millisecond)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let result = send_request::<TestResponse>(
&configuration,
"/api/v1/test",
Method::GET,
params,
BTreeMap::new(),
Some(TimeUnit::Millisecond),
false,
)
.await?;
let data = result.data().await.unwrap();
assert_eq!(data.message, "time unit applied");
Ok(())
})
}
#[test]
fn custom_headers_are_sent() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET)
.path("/api/v1/test")
.header("X-My-Test", "all-clear");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message":"ok"}"#);
});
let mut custom = HashMap::new();
custom.insert("X-My-Test".to_string(), "all-clear".to_string());
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.custom_headers(custom)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let res = send_request::<TestResponse>(
&configuration,
"/api/v1/test",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await?;
let data = res.data().await.unwrap();
assert_eq!(data.message, "ok");
Ok(())
})
}
#[test]
fn custom_header_override_prevention() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET)
.path("/api/v1/test")
.header("content-type", "application/json")
.header("x-mbx-apikey", "key")
.header("X-My-Test", "ok");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message":"defaults intact"}"#);
});
let mut custom = HashMap::new();
custom.insert("Content-Type".to_string(), "text/plain".to_string());
custom.insert("X-MBX-APIKEY".to_string(), "BAD".to_string());
custom.insert("X-My-Test".to_string(), "ok".to_string());
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.custom_headers(custom)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let res = send_request::<TestResponse>(
&configuration,
"/api/v1/test",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await?;
let data = res.data().await.unwrap();
assert_eq!(data.message, "defaults intact");
Ok(())
})
}
#[test]
fn crlf_in_header_values_are_dropped() -> Result<()> {
TOKIO_SHARED_RT.block_on(async {
let server = MockServer::start();
server.mock(|when, then| {
when.method(GET)
.path("/api/v1/test")
.header("X-Good", "safe");
then.status(200)
.header("content-type", "application/json")
.body(r#"{"message":"clean only"}"#);
});
let mut custom = HashMap::new();
custom.insert("X-Bad".to_string(), "evil\r\ninject".to_string());
custom.insert("X-Good".to_string(), "safe".to_string());
let configuration = ConfigurationRestApi::builder()
.api_key("key")
.api_secret("secret")
.base_path(server.base_url())
.compression(false)
.custom_headers(custom)
.build()
.expect("Failed to build configuration");
let params = BTreeMap::new();
let res = send_request::<TestResponse>(
&configuration,
"/api/v1/test",
Method::GET,
params,
BTreeMap::new(),
None,
false,
)
.await?;
let data = res.data().await.unwrap();
assert_eq!(data.message, "clean only");
Ok(())
})
}
}
mod random_string {
use crate::common::utils::random_string;
use hex;
#[test]
fn length_is_32() {
let s = random_string();
assert_eq!(
s.len(),
32,
"random_string() should be 32 chars, got {}",
s.len()
);
}
#[test]
fn is_valid_lowercase_hex() {
let s = random_string();
assert!(
s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
"random_string() contains invalid hex characters: {s}"
);
}
#[test]
fn decodes_to_16_bytes() {
let s = random_string();
let bytes = hex::decode(&s).expect("random_string() output must be valid hex");
assert_eq!(
bytes.len(),
16,
"hex::decode returned {} bytes",
bytes.len()
);
}
#[test]
fn two_calls_are_different() {
let a = random_string();
let b = random_string();
assert_ne!(
a, b,
"Two calls to random_string() returned the same value: {a}"
);
}
}
mod random_integer {
    use crate::common::utils::random_integer;

    #[test]
    fn is_within_u32_range() {
        let n = random_integer();
        // NOTE(review): if `random_integer` returns `u32`, this bound holds by construction
        // (and clippy's `absurd_extreme_comparisons` would flag it) — confirm the return type.
        assert!(
            u32::MAX >= n,
            "random_integer() should be <= u32::MAX, got {n}"
        );
    }

    #[test]
    fn two_calls_can_differ() {
        let (a, b) = (random_integer(), random_integer());
        assert_ne!(
            a, b,
            "Two calls to random_integer() returned the same value: {a}"
        );
    }
}
mod normalize_stream_id {
use crate::common::utils::{StreamId, normalize_stream_id};
use serde_json::Value;
/// Returns `true` when `s` is exactly 32 bytes, all of them lowercase ASCII hex digits.
fn is_lower_hex32(s: &str) -> bool {
    const HEX_LEN: usize = 32;
    if s.len() != HEX_LEN {
        return false;
    }
    // Any non-ASCII byte is rejected, just like a non-hex char would be.
    s.bytes()
        .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b))
}
#[test]
fn valid_hex_string_is_kept() {
    // A well-formed 32-char lowercase hex id must pass through unchanged.
    let id = String::from("0123456789abcdef0123456789abcdef");
    match normalize_stream_id(Some(StreamId::Str(id.clone())), false) {
        Value::String(kept) => assert_eq!(kept, id, "Expected to keep the valid hex id"),
        other => panic!("Expected Value::String, got {other:?}"),
    }
}
#[test]
fn invalid_hex_string_generates_random_hex() {
let id = "not-hex".to_string();
let out = normalize_stream_id(Some(StreamId::Str(id.clone())), false);
match out {
Value::String(s) => {
assert_eq!(s.len(), 32, "Expected 32-char hex, got {}", s.len());
assert_ne!(s, id, "Expected generated id to differ from input");
assert!(
is_lower_hex32(&s),
"Generated id contains invalid hex characters: {s}"
);
}
other => panic!("Expected Value::String, got {other:?}"),
}
}
#[test]
fn none_generates_random_hex() {
let out = normalize_stream_id(None, false);
match out {
Value::String(s) => {
assert_eq!(s.len(), 32, "Expected 32-char hex, got {}", s.len());
assert!(
is_lower_hex32(&s),
"Generated id contains invalid hex characters: {s}"
);
}
other => panic!("Expected Value::String, got {other:?}"),
}
}
#[test]
fn number_is_kept_when_not_strict() {
let out = normalize_stream_id(Some(StreamId::Number(42)), false);
match out {
Value::Number(n) => {
assert_eq!(n.as_u64(), Some(42), "Expected to keep the numeric id");
}
other => panic!("Expected Value::Number, got {other:?}"),
}
}
#[test]
fn strict_number_forces_number_even_for_valid_hex_string() {
let id = "0123456789abcdef0123456789abcdef".to_string();
let out = normalize_stream_id(Some(StreamId::Str(id)), true);
match out {
Value::Number(n) => {
assert!(
n.as_u64().is_some(),
"Expected unsigned integer JSON number, got {n}"
);
}
other => panic!("Expected Value::Number, got {other:?}"),
}
}
#[test]
fn strict_number_keeps_number_if_provided() {
let out = normalize_stream_id(Some(StreamId::Number(7)), true);
match out {
Value::Number(n) => {
assert_eq!(n.as_u64(), Some(7), "Expected to keep the numeric id");
}
other => panic!("Expected Value::Number, got {other:?}"),
}
}
#[test]
fn strict_number_generates_number_when_none() {
let out = normalize_stream_id(None, true);
match out {
Value::Number(n) => {
assert!(
n.as_u64().is_some(),
"Expected unsigned integer JSON number, got {n}"
);
}
other => panic!("Expected Value::Number, got {other:?}"),
}
}
#[test]
fn strict_number_generates_number_for_invalid_hex_string() {
let out = normalize_stream_id(Some(StreamId::Str("nope".to_string())), true);
match out {
Value::Number(n) => {
assert!(
n.as_u64().is_some(),
"Expected unsigned integer JSON number, got {n}"
);
}
other => panic!("Expected Value::Number, got {other:?}"),
}
}
}
mod remove_empty_value {
    //! Tests for `remove_empty_value`: `Null` and empty-string values are
    //! dropped, every other value type is kept, and the result is key-ordered.
    use crate::common::utils::remove_empty_value;
    use serde_json::{Map, Value};

    #[test]
    fn filters_out_null_and_empty_strings() {
        let entries = vec![
            ("key1".to_string(), Value::String("value1".to_string())),
            ("key2".to_string(), Value::Null),
            ("key3".to_string(), Value::String(String::new())),
        ];
        let result = remove_empty_value(entries);
        assert_eq!(
            result.len(),
            1,
            "expected only one entry, got {}",
            result.len()
        );
        assert_eq!(
            result.get("key1"),
            Some(&Value::String("value1".to_string()))
        );
        assert!(!result.contains_key("key2"));
        assert!(!result.contains_key("key3"));
    }

    #[test]
    fn retains_other_value_types() {
        let entries = vec![
            ("bool".to_string(), Value::Bool(true)),
            ("num".to_string(), Value::Number(42.into())),
            ("arr".to_string(), Value::Array(vec![])),
            ("obj".to_string(), Value::Object(Map::default())),
            ("nil".to_string(), Value::Null),
            ("empty_str".to_string(), Value::String(String::new())),
        ];
        let result = remove_empty_value(entries);
        // Count directly instead of collecting the keys into a throwaway Vec.
        assert_eq!(result.len(), 4, "expected 4 entries, got {}", result.len());
        // `assert_eq!` (rather than `assert!(a == b)`) prints both sides on failure.
        assert_eq!(result.get("bool"), Some(&Value::Bool(true)));
        assert_eq!(result.get("num"), Some(&Value::Number(42.into())));
        assert_eq!(result.get("arr"), Some(&Value::Array(vec![])));
        assert_eq!(result.get("obj"), Some(&Value::Object(Map::default())));
        assert!(!result.contains_key("nil"));
        assert!(!result.contains_key("empty_str"));
    }

    #[test]
    fn empty_iterator_returns_empty_map() {
        let entries: Vec<(String, Value)> = vec![];
        let result = remove_empty_value(entries);
        assert!(result.is_empty(), "expected an empty map");
    }

    #[test]
    fn keys_are_sorted() {
        let entries = vec![
            ("c".to_string(), Value::String("foo".to_string())),
            ("a".to_string(), Value::String("bar".to_string())),
            ("b".to_string(), Value::String("baz".to_string())),
        ];
        let result = remove_empty_value(entries);
        let sorted_keys: Vec<&str> = result.keys().map(String::as_str).collect();
        assert_eq!(sorted_keys, ["a", "b", "c"]);
    }
}
mod sort_object_params {
    //! Tests for `sort_object_params`: it returns an independent, key-sorted
    //! copy of the input map with values preserved.
    //!
    //! Every call site passes `&params`. A prior copy of this module read
    //! `¶ms` (an HTML `&para` entity mis-decoding), which does not lex as Rust.
    use crate::common::utils::sort_object_params;
    use serde_json::Value;
    use std::collections::BTreeMap;

    #[test]
    fn sorts_keys() {
        let mut params = BTreeMap::new();
        params.insert("z".to_string(), Value::String("last".to_string()));
        params.insert("a".to_string(), Value::String("first".to_string()));
        params.insert("m".to_string(), Value::String("middle".to_string()));
        let sorted = sort_object_params(&params);
        let keys: Vec<&String> = sorted.keys().collect();
        assert_eq!(
            keys,
            [&"a".to_string(), &"m".to_string(), &"z".to_string()],
            "Keys should be sorted alphabetically"
        );
    }

    #[test]
    fn preserves_values() {
        let mut params = BTreeMap::new();
        params.insert("one".to_string(), Value::Number(1.into()));
        params.insert("two".to_string(), Value::Bool(true));
        let sorted = sort_object_params(&params);
        assert_eq!(sorted.get("one"), Some(&Value::Number(1.into())));
        assert_eq!(sorted.get("two"), Some(&Value::Bool(true)));
    }

    #[test]
    fn empty_map_returns_empty() {
        let params: BTreeMap<String, Value> = BTreeMap::new();
        let sorted = sort_object_params(&params);
        assert!(sorted.is_empty(), "Expected empty map");
    }

    #[test]
    fn independent_clone() {
        let mut params = BTreeMap::new();
        params.insert("key".to_string(), Value::String("val".to_string()));
        let mut sorted = sort_object_params(&params);
        // Mutating the returned map must not leak back into the input.
        sorted.insert("new".to_string(), Value::String("x".to_string()));
        assert!(
            !params.contains_key("new"),
            "Original should not be modified when changing sorted"
        );
        assert!(
            sorted.contains_key("new"),
            "Sorted map should reflect its own insertions"
        );
    }
}
mod normalize_ws_streams_key {
    //! Tests for `normalize_ws_streams_key`: it lowercases and strips `_` and
    //! `-`, leaving all other characters untouched.
    use crate::common::utils::normalize_ws_streams_key;

    /// Asserts that normalizing `input` yields exactly `expected`.
    fn check(input: &str, expected: &str) {
        assert_eq!(normalize_ws_streams_key(input), expected);
    }

    #[test]
    fn returns_empty_for_empty() {
        check("", "");
    }

    #[test]
    fn already_normalized_stays_same() {
        check("streamname", "streamname");
    }

    #[test]
    fn uppercases_are_lowercased() {
        check("MyStream", "mystream");
    }

    #[test]
    fn underscores_are_removed() {
        check("my_stream_name", "mystreamname");
    }

    #[test]
    fn hyphens_are_removed() {
        check("my-stream-name", "mystreamname");
    }

    #[test]
    fn mixed_underscores_and_hyphens_and_case() {
        check("Mixed_Case-Stream_Name", "mixedcasestreamname");
    }

    #[test]
    fn retains_other_punctuation() {
        check("stream.name!", "stream.name!");
    }
}
mod replace_websocket_streams_placeholders {
use crate::common::utils::replace_websocket_streams_placeholders;
use std::collections::HashMap;
#[test]
fn empty_string_unchanged() {
let vars: HashMap<&str, &str> = HashMap::new();
assert_eq!(replace_websocket_streams_placeholders("", &vars), "");
}
#[test]
fn unknown_placeholder_becomes_empty() {
let vars: HashMap<&str, &str> = HashMap::new();
assert_eq!(replace_websocket_streams_placeholders("<foo>", &vars), "");
}
#[test]
fn leading_slash_symbol_lowercases_head() {
let mut vars = HashMap::new();
vars.insert("symbol", "BTC");
assert_eq!(
replace_websocket_streams_placeholders("/<symbol>", &vars),
"btc"
);
}
#[test]
fn no_lowercase_without_slash() {
let mut vars = HashMap::new();
vars.insert("symbol", "BTC");
assert_eq!(
replace_websocket_streams_placeholders("<symbol>", &vars),
"BTC"
);
}
#[test]
fn multiple_placeholders_mid_preserve_ats() {
let mut vars = HashMap::new();
vars.insert("symbol", "BNBUSDT");
vars.insert("levels", "10");
vars.insert("updateSpeed", "1000ms");
let out = replace_websocket_streams_placeholders(
"/<symbol>@depth<levels>@<updateSpeed>",
&vars,
);
assert_eq!(out, "bnbusdt@depth10@1000ms");
}
#[test]
fn trailing_at_removed_when_missing_var() {
let mut vars = HashMap::new();
vars.insert("symbol", "BNBUSDT");
vars.insert("levels", "10");
let out = replace_websocket_streams_placeholders(
"/<symbol>@depth<levels>@<updateSpeed>",
&vars,
);
assert_eq!(out, "bnbusdt@depth10");
}
#[test]
fn custom_key_normalization_and_value() {
let mut vars = HashMap::new();
vars.insert("my-stream_key", "Value");
assert_eq!(
replace_websocket_streams_placeholders("<My_Stream-Key>", &vars),
"Value"
);
}
#[test]
fn text_surrounding_placeholders_intact() {
let mut vars = HashMap::new();
vars.insert("symbol", "ABC");
let input = "pre-<symbol>-post";
assert_eq!(
replace_websocket_streams_placeholders(input, &vars),
"pre-ABC-post"
);
}
}
mod build_websocket_api_message {
    //! Tests for `build_websocket_api_message(config, method, payload, options,
    //! skip_auth)`, which returns `(id, request_json)`.
    use serde_json::{Value, json};
    use std::collections::BTreeMap;

    use crate::{
        common::{
            utils::{ID_REGEX, build_websocket_api_message, remove_empty_value},
            websocket::WebsocketMessageSendOptions,
        },
        config::ConfigurationWebsocketApi,
    };

    /// Builds a websocket API configuration with fixed test credentials.
    fn make_config() -> ConfigurationWebsocketApi {
        ConfigurationWebsocketApi::builder()
            .api_key("api-key".to_string())
            .api_secret("api-secret".to_string())
            .build()
            .unwrap()
    }

    /// Send options with only the two auth flags set; every other field defaulted.
    fn auth_opts(with_api_key: bool, is_signed: bool) -> WebsocketMessageSendOptions {
        WebsocketMessageSendOptions {
            with_api_key,
            is_signed,
            ..Default::default()
        }
    }

    #[test]
    fn no_auth_or_sign_with_skip_auth() {
        let cfg = make_config();
        let mut payload = BTreeMap::new();
        payload.insert("foo".into(), Value::String("bar".into()));

        let (id, req) =
            build_websocket_api_message(&cfg, "method", payload, &auth_opts(true, true), true);

        assert!(ID_REGEX.is_match(&id));
        assert_eq!(req["method"], "method");
        let params = &req["params"];
        assert_eq!(params["foo"], "bar");
        assert!(params.get("apiKey").is_none());
        assert!(params.get("signature").is_none());
        assert!(params["timestamp"].is_number());
    }

    #[test]
    fn only_api_key_when_not_signed() {
        let cfg = make_config();
        let (id, req) = build_websocket_api_message(
            &cfg,
            "method",
            BTreeMap::new(),
            &auth_opts(true, false),
            false,
        );

        assert!(ID_REGEX.is_match(&id));
        assert_eq!(req["method"], "method");
        let params = &req["params"];
        assert_eq!(params["apiKey"], "api-key");
        assert!(params.get("timestamp").is_none());
        assert!(params.get("signature").is_none());
    }

    #[test]
    fn signed_includes_timestamp_and_signature() {
        let cfg = make_config();
        let mut payload = BTreeMap::new();
        payload.insert("foo".into(), Value::String("bar".into()));

        let (id, req) =
            build_websocket_api_message(&cfg, "method", payload, &auth_opts(true, true), false);

        assert!(ID_REGEX.is_match(&id));
        assert_eq!(req["method"], "method");
        let params = &req["params"];
        assert_eq!(params["apiKey"], "api-key");
        let timestamp = params["timestamp"].as_i64().unwrap();
        assert!(timestamp > 0, "timestamp should not be empty");
        let sig = params["signature"].as_str().unwrap();
        assert!(!sig.is_empty(), "signature should not be empty");
    }

    #[test]
    fn respects_provided_valid_id_and_removes_from_params() {
        let custom = "0123456789abcdef0123456789abcdef".to_string();
        let mut payload = BTreeMap::new();
        payload.insert("id".into(), Value::String(custom.clone()));
        payload.insert("foo".into(), Value::Number(123.into()));
        let cfg = make_config();

        let (id, req) = build_websocket_api_message(
            &cfg,
            "method",
            payload,
            &WebsocketMessageSendOptions::default(),
            true,
        );

        // A well-formed caller id becomes the message id and leaves `params`.
        assert_eq!(id, custom);
        assert!(req["params"].get("id").is_none());
        assert_eq!(req["params"]["foo"], 123);
    }

    #[test]
    fn skip_auth_blocks_api_and_signature_but_keeps_timestamp() {
        let cfg = make_config();
        let mut payload = BTreeMap::new();
        payload.insert("foo".into(), Value::String("bar".into()));

        let (_id, req) =
            build_websocket_api_message(&cfg, "method", payload, &auth_opts(true, true), true);

        let p = &req["params"];
        assert_eq!(p["foo"], "bar");
        assert!(p.get("apiKey").is_none());
        assert!(p.get("signature").is_none());
        assert!(p["timestamp"].is_number());
    }

    #[test]
    fn random_id_changes_each_call() {
        let cfg = make_config();
        let fresh_id = || {
            build_websocket_api_message(
                &cfg,
                "method",
                BTreeMap::new(),
                &WebsocketMessageSendOptions::default(),
                true,
            )
            .0
        };
        let id1 = fresh_id();
        let id2 = fresh_id();

        assert!(ID_REGEX.is_match(&id1));
        assert!(ID_REGEX.is_match(&id2));
        assert_ne!(id1, id2, "IDs should be random and not equal");
    }

    #[test]
    fn null_and_empty_values_are_stripped() {
        let mut payload = BTreeMap::new();
        payload.insert("a".into(), Value::Null);
        payload.insert("b".into(), Value::String(String::new()));
        payload.insert("c".into(), Value::String("ok".into()));

        // First confirm the standalone cleaner's behavior on the same input...
        let cleaned = remove_empty_value(payload.clone());
        assert!(!cleaned.contains_key("a"), "Null should be stripped");
        assert!(
            !cleaned.contains_key("b"),
            "Empty string should be stripped"
        );
        assert!(cleaned.contains_key("c"), "Non-empty string should be kept");

        // ...then that the message builder applies it to `params`.
        let cfg = make_config();
        let (_id, req) = build_websocket_api_message(
            &cfg,
            "method",
            payload,
            &WebsocketMessageSendOptions::default(),
            true,
        );
        let params = &req["params"];
        assert!(params.get("a").is_none(), "`a` should not appear");
        assert!(params.get("b").is_none(), "`b` should not appear");
        assert_eq!(params["c"], "ok", "`c` should be present with value \"ok\"");
    }

    #[test]
    fn provided_invalid_id_gets_replaced() {
        let mut payload = BTreeMap::new();
        payload.insert("id".into(), Value::String("not-hex-32-chars".into()));
        let cfg = make_config();

        let (id, _req) = build_websocket_api_message(
            &cfg,
            "method",
            payload,
            &WebsocketMessageSendOptions::default(),
            true,
        );

        assert!(ID_REGEX.is_match(&id));
        assert_ne!(id, "not-hex-32-chars");
    }

    #[test]
    fn sign_only_includes_api_key_even_when_with_api_key_false() {
        let mut payload = BTreeMap::new();
        payload.insert("x".into(), json!(1));
        let cfg = make_config();

        let (_id, req) =
            build_websocket_api_message(&cfg, "method", payload, &auth_opts(false, true), false);

        let params = &req["params"];
        assert_eq!(params["apiKey"], "api-key");
        assert!(params["timestamp"].is_number());
        assert!(params["signature"].is_string());
    }

    #[test]
    fn skip_auth_false_without_any_auth_flags() {
        let cfg = make_config();
        let (_id, req) = build_websocket_api_message(
            &cfg,
            "method",
            BTreeMap::new(),
            &auth_opts(false, false),
            false,
        );
        let params = &req["params"];
        assert!(params.as_object().unwrap().is_empty());
    }
}
}