1use anyhow::{Context, Result};
2use base64::{Engine as _, engine::general_purpose};
3use ed25519_dalek::Signer as Ed25519Signer;
4use ed25519_dalek::SigningKey;
5use ed25519_dalek::pkcs8::DecodePrivateKey;
6use flate2::read::GzDecoder;
7use hex;
8use hmac::{Hmac, Mac};
9use http::HeaderMap;
10use http::HeaderValue;
11use http::header::ACCEPT_ENCODING;
12use once_cell::sync::OnceCell;
13use openssl::{hash::MessageDigest, pkey::PKey, sign::Signer as OpenSslSigner};
14use rand::{RngCore, rngs::OsRng};
15use regex::Captures;
16use regex::Regex;
17use reqwest::Client;
18use reqwest::Proxy;
19use reqwest::{Method, Request};
20use serde::de::DeserializeOwned;
21use serde_json::Number;
22use serde_json::{Value, json};
23use sha2::Sha256;
24use std::fmt::Display;
25use std::hash::BuildHasher;
26use std::sync::LazyLock;
27use std::{
28 collections::BTreeMap,
29 collections::HashMap,
30 fs,
31 io::Read,
32 path::Path,
33 time::Duration,
34 time::{SystemTime, UNIX_EPOCH},
35};
36use tokio::time::sleep;
37use tracing::info;
38use url::form_urlencoded;
39use url::{Url, form_urlencoded::Serializer};
40
41use super::config::{
42 ConfigurationRestApi, ConfigurationWebsocketApi, HttpAgent, PrivateKey, ProxyConfig,
43};
44use super::errors::ConnectorError;
45use super::models::{
46 Interval, RateLimitType, RestApiRateLimit, RestApiResponse, StreamId, TimeUnit,
47};
48use super::websocket::WebsocketMessageSendOptions;
49
50pub(crate) static ID_REGEX: LazyLock<Regex> =
51 LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
52static PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(@)?<([^>]+)>").unwrap());
53
54#[derive(Debug, Default, Clone)]
69pub struct SignatureGenerator {
70 api_secret: Option<String>,
71 private_key: Option<PrivateKey>,
72 private_key_passphrase: Option<String>,
73 raw_key_data: OnceCell<String>,
74 key_object: OnceCell<PKey<openssl::pkey::Private>>,
75 ed25519_signing_key: OnceCell<SigningKey>,
76}
77
78impl SignatureGenerator {
79 #[must_use]
80 pub fn new(
81 api_secret: Option<String>,
82 private_key: Option<PrivateKey>,
83 private_key_passphrase: Option<String>,
84 ) -> Self {
85 SignatureGenerator {
86 api_secret,
87 private_key,
88 private_key_passphrase,
89 raw_key_data: OnceCell::new(),
90 key_object: OnceCell::new(),
91 ed25519_signing_key: OnceCell::new(),
92 }
93 }
94
95 fn get_raw_key_data(&self) -> Result<&String> {
110 self.raw_key_data.get_or_try_init(|| {
111 let pk = self
112 .private_key
113 .as_ref()
114 .ok_or_else(|| anyhow::anyhow!("No private_key provided"))?;
115 match pk {
116 PrivateKey::File(path) => {
117 if Path::new(path).exists() {
118 fs::read_to_string(path)
119 .with_context(|| format!("Failed to read private key file: {path}"))
120 } else {
121 Err(anyhow::anyhow!("Private key file does not exist: {}", path))
122 }
123 }
124 PrivateKey::Raw(bytes) => Ok(String::from_utf8_lossy(bytes).to_string()),
125 }
126 })
127 }
128
129 fn get_key_object(&self) -> Result<&PKey<openssl::pkey::Private>> {
144 self.key_object.get_or_try_init(|| {
145 let key_data = self.get_raw_key_data()?;
146 if let Some(pass) = self.private_key_passphrase.as_ref() {
147 PKey::private_key_from_pem_passphrase(key_data.as_bytes(), pass.as_bytes())
148 .context("Failed to parse private key with passphrase")
149 } else {
150 PKey::private_key_from_pem(key_data.as_bytes())
151 .context("Failed to parse private key")
152 }
153 })
154 }
155
156 fn get_ed25519_signing_key(
169 &self,
170 key_obj: &PKey<openssl::pkey::Private>,
171 ) -> Result<&SigningKey> {
172 self.ed25519_signing_key.get_or_try_init(|| {
173 let der = key_obj
174 .private_key_to_der()
175 .context("Failed to export Ed25519 key to DER")?;
176 SigningKey::from_pkcs8_der(&der)
177 .map_err(|e| anyhow::anyhow!("Failed to parse Ed25519 key: {}", e))
178 })
179 }
180
181 pub fn get_signature(
204 &self,
205 query_params: &BTreeMap<String, Value>,
206 body_params: Option<&BTreeMap<String, Value>>,
207 ) -> Result<String> {
208 let query_str = build_query_string(query_params)?;
209 let params = if let Some(body) = body_params {
210 if body.is_empty() {
211 query_str
212 } else {
213 let body_str = build_query_string(body)?;
214 format!("{query_str}{body_str}")
215 }
216 } else {
217 query_str
218 };
219
220 if self.private_key.is_none() {
221 if let Some(secret) = self.api_secret.as_ref() {
222 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
223 .context("HMAC key initialization failed")?;
224 mac.update(params.as_bytes());
225 let result = mac.finalize().into_bytes();
226 return Ok(hex::encode(result));
227 }
228 }
229
230 if self.private_key.is_some() {
231 let key_obj = self.get_key_object()?;
232 match key_obj.id() {
233 openssl::pkey::Id::RSA => {
234 let mut signer = OpenSslSigner::new(MessageDigest::sha256(), key_obj)
235 .context("Failed to create RSA signer")?;
236 signer
237 .update(params.as_bytes())
238 .context("Failed to update RSA signer")?;
239 let sig = signer.sign_to_vec().context("RSA signing failed")?;
240 return Ok(general_purpose::STANDARD.encode(sig));
241 }
242 openssl::pkey::Id::ED25519 => {
243 let signing_key = self.get_ed25519_signing_key(key_obj)?;
244 let signature = signing_key.sign(params.as_bytes());
245 return Ok(general_purpose::STANDARD.encode(signature.to_bytes()));
246 }
247 other => {
248 return Err(anyhow::anyhow!(
249 "Unsupported private key type: {:?}. Must be RSA or ED25519.",
250 other
251 ));
252 }
253 }
254 }
255
256 Err(anyhow::anyhow!(
257 "Either 'api_secret' or 'private_key' must be provided for signed requests."
258 ))
259 }
260}
261
262#[must_use]
285pub fn build_client(
286 timeout: u64,
287 keep_alive: bool,
288 proxy: Option<&ProxyConfig>,
289 agent: Option<HttpAgent>,
290) -> Client {
291 let builder = Client::builder().timeout(Duration::from_millis(timeout));
292
293 let mut builder = if keep_alive {
294 builder
295 } else {
296 builder.pool_idle_timeout(Some(Duration::from_secs(0)))
297 };
298
299 if let Some(proxy_conf) = proxy {
300 let protocol = proxy_conf
301 .protocol
302 .clone()
303 .unwrap_or_else(|| "http".to_string());
304 let proxy_url = format!("{}://{}:{}", protocol, proxy_conf.host, proxy_conf.port);
305 let mut proxy_builder = Proxy::all(&proxy_url).expect("Failed to create proxy from URL");
306 if let Some(auth) = &proxy_conf.auth {
307 proxy_builder = proxy_builder.basic_auth(&auth.username, &auth.password);
308 }
309 builder = builder.proxy(proxy_builder);
310 }
311
312 if let Some(HttpAgent(agent_fn)) = agent {
313 builder = (agent_fn)(builder);
314 }
315
316 info!("Client builder {:?}", builder);
317
318 builder.build().expect("Failed to build reqwest client")
319}
320
321#[must_use]
344pub fn build_user_agent(product: &str) -> String {
345 format!(
346 "{}/{}/{} (Rust/{}; {}; {})",
347 env!("CARGO_PKG_NAME"),
348 product,
349 env!("CARGO_PKG_VERSION"),
350 env!("RUSTC_VERSION"),
351 std::env::consts::OS,
352 std::env::consts::ARCH,
353 )
354}
355
356pub fn validate_time_unit(time_unit: &str) -> Result<Option<&str>, anyhow::Error> {
384 match time_unit {
385 "" => Ok(None),
386 "MILLISECOND" | "MICROSECOND" | "millisecond" | "microsecond" => Ok(Some(time_unit)),
387 _ => Err(anyhow::anyhow!(
388 "time_unit must be either 'MILLISECOND' or 'MICROSECOND'"
389 )),
390 }
391}
392
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
#[must_use]
pub fn get_timestamp() -> u128 {
    let since_epoch = UNIX_EPOCH.elapsed().expect("Time went backwards");
    since_epoch.as_millis()
}
416
417pub async fn delay(ms: u64) {
429 sleep(Duration::from_millis(ms)).await;
430}
431
432pub fn build_query_string(params: &BTreeMap<String, Value>) -> Result<String, anyhow::Error> {
451 let mut segments = Vec::with_capacity(params.len());
452
453 for (key, value) in params {
454 if value.is_null() {
455 continue;
456 }
457
458 let value_str = match value {
459 Value::String(s) => s.clone(),
460 Value::Bool(b) => b.to_string(),
461 Value::Number(n) => n.to_string(),
462 Value::Array(_) | Value::Object(_) => serde_json::to_string(value)
463 .with_context(|| format!("failed to JSON-serialize `{}`", key))?,
464 Value::Null => unreachable!(),
465 };
466
467 let mut ser = Serializer::new(String::new());
468 ser.append_pair(key, &value_str);
469 segments.push(ser.finish());
470 }
471
472 Ok(segments.join("&"))
473}
474
475#[must_use]
483pub fn should_retry_request(
484 error: &reqwest::Error,
485 method: Option<&str>,
486 retries_left: Option<usize>,
487) -> bool {
488 let method = method.unwrap_or("");
489 let is_retriable_method =
490 method.eq_ignore_ascii_case("GET") || method.eq_ignore_ascii_case("DELETE");
491
492 let status = error.status().map_or(0, |s| s.as_u16());
493 let is_retriable_status = [500, 502, 503, 504].contains(&status);
494
495 let retries_left = retries_left.unwrap_or(0);
496 retries_left > 0 && is_retriable_method && (is_retriable_status || error.status().is_none())
497}
498
499#[must_use]
524pub fn parse_rate_limit_headers<S>(headers: &HashMap<String, String, S>) -> Vec<RestApiRateLimit>
525where
526 S: BuildHasher,
527{
528 let mut rate_limits = Vec::new();
529 let re = Regex::new(r"x-mbx-(used-weight|order-count)-(\d+)([smhd])").unwrap();
530 for (key, value) in headers {
531 let normalized_key = key.to_lowercase();
532 if normalized_key.starts_with("x-mbx-used-weight-")
533 || normalized_key.starts_with("x-mbx-order-count-")
534 {
535 if let Some(caps) = re.captures(&normalized_key) {
536 let interval_num: u32 = caps.get(2).unwrap().as_str().parse().unwrap_or(0);
537 let interval_letter = caps.get(3).unwrap().as_str().to_uppercase();
538 let interval = match interval_letter.as_str() {
539 "S" => Interval::Second,
540 "M" => Interval::Minute,
541 "H" => Interval::Hour,
542 "D" => Interval::Day,
543 _ => continue,
544 };
545 let count: u32 = value.parse().unwrap_or(0);
546 let rate_limit_type = if normalized_key.starts_with("x-mbx-used-weight-") {
547 RateLimitType::RequestWeight
548 } else {
549 RateLimitType::Orders
550 };
551
552 rate_limits.push(RestApiRateLimit {
553 rate_limit_type,
554 interval,
555 interval_num,
556 count,
557 retry_after: headers.get("retry-after").and_then(|v| v.parse().ok()),
558 });
559 }
560 }
561 }
562 rate_limits
563}
564
565pub async fn http_request<T: DeserializeOwned + Send + 'static>(
595 req: Request,
596 configuration: &ConfigurationRestApi,
597) -> Result<RestApiResponse<T>, ConnectorError> {
598 let client = &configuration.client;
599 let retries = configuration.retries as usize;
600 let backoff = configuration.backoff;
601 let mut attempt = 0;
602
603 loop {
604 let req_clone = req
605 .try_clone()
606 .context("Failed to clone request")
607 .map_err(|e| ConnectorError::ConnectorClientError {
608 msg: e.to_string(),
609 code: None,
610 })?;
611 match client.execute(req_clone).await {
612 Ok(response) => {
613 let status = response.status();
614 let headers_map: HashMap<String, String> = response
615 .headers()
616 .iter()
617 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
618 .collect();
619
620 let raw_bytes = match response.bytes().await {
621 Ok(b) => b,
622 Err(e) => {
623 attempt += 1;
624 if attempt <= retries {
625 continue;
626 }
627 return Err(ConnectorError::ConnectorClientError {
628 msg: format!("Failed to get response bytes: {e}"),
629 code: None,
630 });
631 }
632 };
633
634 let content = if headers_map
635 .get("content-encoding")
636 .is_some_and(|enc| enc.to_lowercase().contains("gzip"))
637 {
638 let mut decoder = GzDecoder::new(&raw_bytes[..]);
639 let mut decompressed = String::new();
640 decoder
641 .read_to_string(&mut decompressed)
642 .context("Failed to decompress gzip response")
643 .map_err(|e| ConnectorError::ConnectorClientError {
644 msg: e.to_string(),
645 code: None,
646 })?;
647 decompressed
648 } else {
649 String::from_utf8(raw_bytes.to_vec())
650 .context("Failed to convert response to UTF-8")
651 .map_err(|e| ConnectorError::ConnectorClientError {
652 msg: e.to_string(),
653 code: None,
654 })?
655 };
656
657 let rate_limits = parse_rate_limit_headers(&headers_map);
658
659 if status.is_client_error() || status.is_server_error() {
660 let mut err_msg = content.clone();
661 let mut err_code: Option<i64> = None;
662
663 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&content) {
664 if let Some(m) = v.get("msg").and_then(|m| m.as_str()) {
665 err_msg = m.to_string();
666 }
667 err_code = v.get("code").and_then(serde_json::Value::as_i64);
668 }
669
670 match status.as_u16() {
671 400 => {
672 return Err(ConnectorError::BadRequestError {
673 msg: err_msg,
674 code: err_code,
675 });
676 }
677 401 => {
678 return Err(ConnectorError::UnauthorizedError {
679 msg: err_msg,
680 code: err_code,
681 });
682 }
683 403 => {
684 return Err(ConnectorError::ForbiddenError {
685 msg: err_msg,
686 code: err_code,
687 });
688 }
689 404 => {
690 return Err(ConnectorError::NotFoundError {
691 msg: err_msg,
692 code: err_code,
693 });
694 }
695 418 => {
696 return Err(ConnectorError::RateLimitBanError {
697 msg: err_msg,
698 code: err_code,
699 });
700 }
701 429 => {
702 return Err(ConnectorError::TooManyRequestsError {
703 msg: err_msg,
704 code: err_code,
705 });
706 }
707 s if (500..600).contains(&s) => {
708 return Err(ConnectorError::ServerError {
709 msg: format!("Server error: {s}"),
710 status_code: Some(s),
711 });
712 }
713 _ => {
714 return Err(ConnectorError::ConnectorClientError {
715 msg: err_msg,
716 code: err_code,
717 });
718 }
719 }
720 }
721
722 let raw = content.clone();
723 return Ok(RestApiResponse {
724 data_fn: Box::new(move || {
725 Box::pin(async move {
726 let parsed: T = serde_json::from_str(&raw).map_err(|e| {
727 ConnectorError::ConnectorClientError {
728 msg: e.to_string(),
729 code: None,
730 }
731 })?;
732 Ok(parsed)
733 })
734 }),
735 status: status.as_u16(),
736 headers: headers_map,
737 rate_limits: if rate_limits.is_empty() {
738 None
739 } else {
740 Some(rate_limits)
741 },
742 });
743 }
744 Err(e) => {
745 attempt += 1;
746 if should_retry_request(&e, Some(req.method().as_str()), Some(retries - attempt)) {
747 delay(backoff * attempt as u64).await;
748 continue;
749 }
750 return Err(ConnectorError::ConnectorClientError {
751 msg: format!("HTTP request failed: {e}"),
752 code: None,
753 });
754 }
755 }
756 }
757}
758
759pub async fn send_request<T: DeserializeOwned + Send + 'static>(
786 configuration: &ConfigurationRestApi,
787 endpoint: &str,
788 method: Method,
789 mut query_params: BTreeMap<String, Value>,
790 body_params: BTreeMap<String, Value>,
791 time_unit: Option<TimeUnit>,
792 is_signed: bool,
793) -> anyhow::Result<RestApiResponse<T>> {
794 let base = configuration.base_path.as_deref().unwrap_or("");
795 let full_url = reqwest::Url::parse(base)
796 .and_then(|u| u.join(endpoint))
797 .context("Failed to join base URL and endpoint")?
798 .to_string();
799
800 if is_signed {
801 let timestamp = get_timestamp();
802 query_params.insert("timestamp".to_string(), json!(timestamp));
803 }
804
805 let signature = if is_signed {
806 let body_ref = if body_params.is_empty() {
807 None
808 } else {
809 Some(&body_params)
810 };
811 Some(
812 configuration
813 .signature_gen
814 .get_signature(&query_params, body_ref)?,
815 )
816 } else {
817 None
818 };
819
820 let mut url = Url::parse(&full_url)?;
821 {
822 let mut pairs = url.query_pairs_mut();
823 for (key, value) in &query_params {
824 let val_str = match value {
825 Value::String(s) => s.clone(),
826 _ => value.to_string(),
827 };
828 pairs.append_pair(key, &val_str);
829 }
830 if let Some(signature) = &signature {
831 pairs.append_pair("signature", signature);
832 }
833 }
834
835 let mut headers = HeaderMap::new();
836
837 let forbidden = ["host", "authorization", "cookie", ":method", ":path"]
838 .into_iter()
839 .map(str::to_ascii_lowercase)
840 .collect::<std::collections::HashSet<_>>();
841
842 if let Some(custom) = &configuration.custom_headers {
843 for (raw_name, raw_val) in custom {
844 let name = raw_name.trim();
845 if forbidden.contains(&name.to_ascii_lowercase()) {
846 continue;
847 }
848 if let (Ok(header_name), Ok(header_val)) = (
849 name.parse::<reqwest::header::HeaderName>(),
850 HeaderValue::from_str(raw_val),
851 ) {
852 headers.append(header_name, header_val);
853 }
854 }
855 }
856
857 if body_params.is_empty() {
858 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
859 } else {
860 headers.insert(
861 "Content-Type",
862 HeaderValue::from_static("application/x-www-form-urlencoded"),
863 );
864 }
865
866 headers.insert("User-Agent", configuration.user_agent.parse().unwrap());
867 if let Some(api_key) = &configuration.api_key {
868 headers.insert("X-MBX-APIKEY", HeaderValue::from_str(api_key)?);
869 }
870
871 if configuration.compression {
872 headers.insert(ACCEPT_ENCODING, "gzip, deflate, br".parse().unwrap());
873 }
874
875 let time_unit_to_apply = time_unit.or(configuration.time_unit);
876 if let Some(time_unit) = time_unit_to_apply {
877 headers.insert("X-MBX-TIME-UNIT", time_unit.as_upper_str().parse()?);
878 }
879
880 let mut req_builder = configuration.client.request(method, url).headers(headers);
881
882 if !body_params.is_empty() {
883 let mut serializer = form_urlencoded::Serializer::new(String::new());
884 for (key, value) in body_params {
885 let val_str = match value {
886 Value::String(s) => s,
887 _ => value.to_string(),
888 };
889 serializer.append_pair(&key, &val_str);
890 }
891 let body_str = serializer.finish();
892 req_builder = req_builder.body(body_str);
893 }
894
895 let req = req_builder.build()?;
896
897 Ok(http_request::<T>(req, configuration).await?)
898}
899
900#[must_use]
909pub fn random_string() -> String {
910 let mut buf = [0u8; 16];
911 rand::thread_rng().fill_bytes(&mut buf);
912 hex::encode(buf)
913}
914
915#[must_use]
924pub fn random_integer() -> u32 {
925 let mut buf = [0u8; 4];
926 OsRng.fill_bytes(&mut buf);
927 u32::from_ne_bytes(buf)
928}
929
930#[must_use]
941pub fn normalize_stream_id(id: Option<StreamId>, stream_id_is_strictly_number: bool) -> Value {
942 if stream_id_is_strictly_number {
943 let n = match id {
944 Some(StreamId::Number(n)) => n,
945 _ => random_integer(),
946 };
947 return Value::Number(Number::from(n));
948 }
949
950 match id {
951 Some(StreamId::Number(n)) => Value::Number(Number::from(n)),
952 Some(StreamId::Str(s)) => {
953 let out = if ID_REGEX.is_match(&s) {
954 s
955 } else {
956 random_string()
957 };
958 Value::String(out)
959 }
960 None => Value::String(random_string()),
961 }
962}
963
964pub fn remove_empty_value<I>(entries: I) -> BTreeMap<String, Value>
986where
987 I: IntoIterator<Item = (String, Value)>,
988{
989 entries
990 .into_iter()
991 .filter(|(_, value)| match value {
992 Value::Null => false,
993 Value::String(s) if s.is_empty() => false,
994 _ => true,
995 })
996 .collect()
997}
998
999#[must_use]
1020pub fn sort_object_params(params: &BTreeMap<String, Value>) -> BTreeMap<String, Value> {
1021 let mut sorted = BTreeMap::new();
1022 for (k, v) in params {
1023 sorted.insert(k.clone(), v.clone());
1024 }
1025 sorted
1026}
1027
/// Canonicalises a placeholder or variable name: lowercases it, then removes
/// every `_` and `-`, so `update_speed`, `Update-Speed` and `updateSpeed` all
/// compare equal.
fn normalize_ws_streams_key(key: &str) -> String {
    // Lowercase first (as the original does) so context-sensitive lowercasing
    // sees the same neighbouring characters.
    key.to_lowercase()
        .chars()
        .filter(|c| !matches!(c, '_' | '-'))
        .collect()
}
1040
1041pub fn replace_websocket_streams_placeholders<V, S>(
1066 input: &str,
1067 variables: &HashMap<&str, V, S>,
1068) -> String
1069where
1070 V: Display,
1071 S: BuildHasher,
1072{
1073 let original = input;
1074
1075 let body = original.strip_prefix('/').unwrap_or(original);
1077
1078 let normalized: HashMap<String, String> = variables
1080 .iter()
1081 .map(|(k, v)| (normalize_ws_streams_key(k), v.to_string()))
1082 .collect();
1083
1084 let replaced = PLACEHOLDER_RE
1086 .replace_all(body, |caps: &Captures| {
1087 let prefix = caps.get(1).map_or("", |m| m.as_str());
1088 let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
1089 let val = normalized.get(&key).cloned().unwrap_or_default();
1090 format!("{prefix}{val}")
1091 })
1092 .into_owned();
1093
1094 let stripped = replaced.trim_end_matches('@').to_string();
1096
1097 let should_lower_head =
1100 original.starts_with('/') && PLACEHOLDER_RE.find(body).is_some_and(|m| m.start() == 0);
1101
1102 if should_lower_head {
1104 if let Some(caps) = PLACEHOLDER_RE.captures(body) {
1105 let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
1106 let first_val = normalized.get(&key).cloned().unwrap_or_default();
1107 if stripped.starts_with(&first_val) {
1108 let tail = &stripped[first_val.len()..];
1109 format!("{}{}", first_val.to_lowercase(), tail)
1110 } else {
1111 stripped.clone()
1112 }
1113 } else {
1114 stripped.clone()
1115 }
1116 } else {
1117 stripped.clone()
1118 }
1119}
1120
1121pub fn build_websocket_api_message(
1139 configuration: &ConfigurationWebsocketApi,
1140 method: &str,
1141 mut payload: BTreeMap<String, Value>,
1142 options: &WebsocketMessageSendOptions,
1143 skip_auth: bool,
1144) -> (String, serde_json::Value) {
1145 let id = payload
1146 .get("id")
1147 .and_then(Value::as_str)
1148 .filter(|s| ID_REGEX.is_match(s))
1149 .map_or_else(random_string, String::from);
1150
1151 payload.remove("id");
1152
1153 let mut params = remove_empty_value(payload);
1154
1155 if (options.with_api_key || options.is_signed) && !skip_auth {
1156 params.insert(
1157 "apiKey".into(),
1158 Value::String(configuration.api_key.clone().expect("API key must be set")),
1159 );
1160 }
1161
1162 if options.is_signed {
1163 let ts = get_timestamp();
1164 let ts_i64 = i64::try_from(ts).expect("timestamp fits in i64");
1165 params.insert("timestamp".into(), Value::Number(ts_i64.into()));
1166
1167 let mut sorted = sort_object_params(¶ms);
1168 if !skip_auth {
1169 let sig = configuration
1170 .signature_gen
1171 .get_signature(&sorted, None)
1172 .expect("signature generation");
1173 sorted.insert("signature".into(), Value::String(sig));
1174 }
1175 params = sorted.into_iter().collect();
1176 }
1177
1178 let request = json!({
1179 "id": id,
1180 "method": method,
1181 "params": params,
1182 });
1183
1184 (id, request)
1185}
1186
1187#[cfg(test)]
1188mod tests {
1189 use crate::TOKIO_SHARED_RT;
1190
1191 mod build_client {
1192 use std::{
1193 sync::{Arc, Mutex},
1194 time::{Duration, Instant},
1195 };
1196
1197 use reqwest::ClientBuilder;
1198
1199 use crate::{
1200 common::utils::build_client,
1201 config::{HttpAgent, ProxyAuth, ProxyConfig},
1202 };
1203
1204 use super::TOKIO_SHARED_RT;
1205
1206 #[test]
1207 fn enforces_timeout() {
1208 TOKIO_SHARED_RT.block_on(async {
1209 let client = build_client(100, true, None, None);
1210 let start = Instant::now();
1211 let res = client.get("http://10.255.255.1").send().await;
1212 assert!(
1213 res.is_err(),
1214 "expected an error (timeout or connect) but got {res:?}"
1215 );
1216 let elapsed = start.elapsed();
1217 assert!(
1218 elapsed < Duration::from_millis(500),
1219 "timed out too slowly: {elapsed:?}"
1220 );
1221 });
1222 }
1223
1224 #[test]
1225 fn builds_with_keep_alive_disabled() {
1226 let client = build_client(200, false, None, None);
1227 let _: reqwest::Client = client;
1228 }
1229
1230 #[test]
1231 #[should_panic(expected = "Failed to create proxy from URL")]
1232 fn invalid_proxy_url_panics() {
1233 let bad_proxy = ProxyConfig {
1234 protocol: Some("http".to_string()),
1235 host: String::new(),
1236 port: 8080,
1237 auth: None,
1238 };
1239 let _ = build_client(1_000, true, Some(&bad_proxy), None);
1240 }
1241
1242 #[test]
1243 fn builds_with_proxy_and_auth() {
1244 let proxy = ProxyConfig {
1245 protocol: Some("https".to_string()),
1246 host: "127.0.0.1".to_string(),
1247 port: 3128,
1248 auth: Some(ProxyAuth {
1249 username: "alice".to_string(),
1250 password: "secret".to_string(),
1251 }),
1252 };
1253 let client = build_client(2_000, true, Some(&proxy), None);
1254 let _: reqwest::Client = client;
1255 }
1256
1257 #[test]
1258 fn custom_agent_invoked() {
1259 let called = Arc::new(Mutex::new(false));
1260 let called_clone = Arc::clone(&called);
1261
1262 let agent = HttpAgent(Arc::new(move |builder: ClientBuilder| {
1263 *called_clone.lock().unwrap() = true;
1264 builder
1265 }));
1266
1267 let client = build_client(1_000, true, None, Some(agent));
1268 assert!(*called.lock().unwrap(), "agent closure wasn’t invoked");
1269 let _: reqwest::Client = client;
1270 }
1271 }
1272
1273 mod build_user_agent {
1274 use crate::common::utils::build_user_agent;
1275
1276 #[test]
1277 fn build_user_agent_contains_crate_product_and_rust_info() {
1278 let product = "product";
1279 let user_agent = build_user_agent(product);
1280
1281 let name = env!("CARGO_PKG_NAME");
1282 let version = env!("CARGO_PKG_VERSION");
1283 let rustc = env!("RUSTC_VERSION");
1284 let os = std::env::consts::OS;
1285 let arch = std::env::consts::ARCH;
1286
1287 let expected_prefix = format!("{name}/{product}/{version} (Rust/");
1288 assert!(
1289 user_agent.starts_with(&expected_prefix),
1290 "prefix mismatch: {user_agent}"
1291 );
1292
1293 assert!(
1294 user_agent.contains(rustc),
1295 "user agent missing RUSTC_VERSION: {user_agent}"
1296 );
1297
1298 assert!(
1299 user_agent.contains(&format!("; {os}")),
1300 "user agent missing OS: {user_agent}"
1301 );
1302 assert!(
1303 user_agent.contains(&format!("; {arch}")),
1304 "user agent missing ARCH: {user_agent}"
1305 );
1306 }
1307
1308 #[test]
1309 fn build_user_agent_is_deterministic() {
1310 let product = "product";
1311 let user_agent1 = build_user_agent(product);
1312 let user_agent2 = build_user_agent(product);
1313 assert_eq!(
1314 user_agent1, user_agent2,
1315 "user agent should be the same on repeated calls"
1316 );
1317 }
1318 }
1319
1320 mod validate_time_unit {
1321 use crate::common::utils::validate_time_unit;
1322
1323 #[test]
1324 fn empty_string_returns_none() {
1325 let res = validate_time_unit("").expect("Should not error on empty string");
1326 assert_eq!(res, None);
1327 }
1328
1329 #[test]
1330 fn uppercase_millisecond() {
1331 let res = validate_time_unit("MILLISECOND").expect("Should accept MILLISECOND");
1332 assert_eq!(res, Some("MILLISECOND"));
1333 }
1334
1335 #[test]
1336 fn uppercase_microsecond() {
1337 let res = validate_time_unit("MICROSECOND").expect("Should accept MICROSECOND");
1338 assert_eq!(res, Some("MICROSECOND"));
1339 }
1340
1341 #[test]
1342 fn lowercase_millisecond() {
1343 let res = validate_time_unit("millisecond").expect("Should accept millisecond");
1344 assert_eq!(res, Some("millisecond"));
1345 }
1346
1347 #[test]
1348 fn lowercase_microsecond() {
1349 let res = validate_time_unit("microsecond").expect("Should accept microsecond");
1350 assert_eq!(res, Some("microsecond"));
1351 }
1352
1353 #[test]
1354 fn invalid_value_returns_err() {
1355 let err = validate_time_unit("SECOND").unwrap_err();
1356 let msg = format!("{err}");
1357 assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1358 }
1359
1360 #[test]
1361 fn partial_match_returns_err() {
1362 let err = validate_time_unit("MILLI").unwrap_err();
1363 let msg = format!("{err}");
1364 assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1365 }
1366 }
1367
1368 mod get_timestamp {
1369 use crate::common::utils::get_timestamp;
1370 use std::{
1371 thread::sleep,
1372 time::{Duration, SystemTime, UNIX_EPOCH},
1373 };
1374
1375 #[test]
1376 fn timestamp_is_within_system_time_bounds() {
1377 let before = SystemTime::now()
1378 .duration_since(UNIX_EPOCH)
1379 .expect("SystemTime before UNIX_EPOCH")
1380 .as_millis();
1381 let ts = get_timestamp();
1382 let after = SystemTime::now()
1383 .duration_since(UNIX_EPOCH)
1384 .expect("SystemTime before UNIX_EPOCH")
1385 .as_millis();
1386
1387 assert!(
1388 ts >= before,
1389 "timestamp {ts} is before captured before time {before}"
1390 );
1391 assert!(
1392 ts <= after,
1393 "timestamp {ts} is after captured after time {after}"
1394 );
1395 }
1396
1397 #[test]
1398 fn timestamps_are_monotonic() {
1399 let t1 = get_timestamp();
1400 sleep(Duration::from_millis(1));
1401 let t2 = get_timestamp();
1402 assert!(
1403 t2 >= t1,
1404 "second timestamp {t2} is not >= first timestamp {t1}"
1405 );
1406 }
1407 }
1408
1409 mod build_query_string {
1410 use std::collections::BTreeMap;
1411
1412 use anyhow::Result;
1413 use serde_json::{Value, json};
1414 use url::form_urlencoded::Serializer;
1415
1416 use crate::common::utils::build_query_string;
1417
1418 fn mk_map(pairs: Vec<(&str, Value)>) -> BTreeMap<String, Value> {
1419 let mut m = BTreeMap::new();
1420 for (k, v) in pairs {
1421 m.insert(k.to_string(), v);
1422 }
1423 m
1424 }
1425
1426 #[test]
1427 fn empty_map_returns_empty_string() -> Result<()> {
1428 let params = BTreeMap::new();
1429 let qs = build_query_string(¶ms)?;
1430 assert_eq!(qs, "");
1431 Ok(())
1432 }
1433
1434 #[test]
1435 fn string_and_number_and_bool() -> Result<()> {
1436 let params = mk_map(vec![
1437 ("foo", json!("bar")),
1438 ("num", json!(42)),
1439 ("flag", json!(true)),
1440 ]);
1441 let qs = build_query_string(¶ms)?;
1442 assert_eq!(qs, "flag=true&foo=bar&num=42");
1443 Ok(())
1444 }
1445
1446 #[test]
1447 fn null_is_skipped() -> Result<()> {
1448 let params = mk_map(vec![("a", json!(true)), ("b", Value::Null)]);
1449 let qs = build_query_string(¶ms)?;
1450 assert_eq!(qs, "a=true");
1451 Ok(())
1452 }
1453
1454 #[test]
1455 fn percent_encode_special_chars() -> Result<()> {
1456 let params = mk_map(vec![
1457 ("space", json!("hello world")),
1458 ("symbols", json!("a/b?c")),
1459 ]);
1460 let qs = build_query_string(¶ms)?;
1461 let mut parts = vec![];
1462 let mut ser = Serializer::new(String::new());
1463 ser.append_pair("space", "hello world");
1464 parts.push(ser.finish());
1465 let mut ser = Serializer::new(String::new());
1466 ser.append_pair("symbols", "a/b?c");
1467 parts.push(ser.finish());
1468 let expected = parts.join("&");
1469 assert_eq!(qs, expected);
1470 Ok(())
1471 }
1472
1473 #[test]
1474 fn primitive_array_json_encoded() -> Result<()> {
1475 let params = mk_map(vec![
1476 ("strs", json!(["a", "b", "c"])),
1477 ("nums", json!([1, 2, 3])),
1478 ("bools", json!([true, false])),
1479 ]);
1480 let qs = build_query_string(¶ms)?;
1481
1482 let mut parts = Vec::new();
1483 for (k, v) in ¶ms {
1484 let json = serde_json::to_string(v)?;
1485 let mut ser = Serializer::new(String::new());
1486 ser.append_pair(k, &json);
1487 parts.push(ser.finish());
1488 }
1489 let expected = parts.join("&");
1490 assert_eq!(qs, expected);
1491 Ok(())
1492 }
1493
1494 #[test]
1495 fn nested_array_json_encoded() -> Result<()> {
1496 let params = mk_map(vec![("nested", json!([[1, 2], [3, 4]]))]);
1497 let qs = build_query_string(¶ms)?;
1498
1499 let nested_json = serde_json::to_string(&json!([[1, 2], [3, 4]]))?;
1500 let mut ser = Serializer::new(String::new());
1501 ser.append_pair("nested", &nested_json);
1502 let expected = ser.finish();
1503
1504 assert_eq!(qs, expected);
1505 Ok(())
1506 }
1507
1508 #[test]
1509 fn object_json_encoded() -> Result<()> {
1510 let params = mk_map(vec![("obj", json!({"k":1, "v":"two"}))]);
1511 let qs = build_query_string(¶ms)?;
1512
1513 let obj_json = serde_json::to_string(&json!({"k":1, "v":"two"}))?;
1514 let mut ser = Serializer::new(String::new());
1515 ser.append_pair("obj", &obj_json);
1516 let expected = ser.finish();
1517
1518 assert_eq!(qs, expected);
1519 Ok(())
1520 }
1521
1522 #[test]
1523 fn empty_array() {
1524 let params = mk_map(vec![("foo", json!([]))]);
1525 let qs = build_query_string(¶ms).unwrap();
1526
1527 let json = serde_json::to_string(&json!([])).unwrap();
1528 let expected = Serializer::new(String::new())
1529 .append_pair("foo", &json)
1530 .finish();
1531 assert_eq!(qs, expected);
1532 }
1533
1534 #[test]
1535 fn mixed_array() {
1536 let params = mk_map(vec![("mix", json!([1, "x", false]))]);
1537 let qs = build_query_string(¶ms).unwrap();
1538
1539 let json = serde_json::to_string(&json!([1, "x", false])).unwrap();
1540 let expected = Serializer::new(String::new())
1541 .append_pair("mix", &json)
1542 .finish();
1543 assert_eq!(qs, expected);
1544 }
1545
1546 #[test]
1547 fn array_of_objects() {
1548 let params = mk_map(vec![("objs", json!([{"a":1}, {"b":2}]))]);
1549 let qs = build_query_string(¶ms).unwrap();
1550
1551 let json = serde_json::to_string(&json!([{"a":1}, {"b":2}])).unwrap();
1552 let expected = Serializer::new(String::new())
1553 .append_pair("objs", &json)
1554 .finish();
1555 assert_eq!(qs, expected);
1556 }
1557
1558 #[test]
1559 fn empty_object() {
1560 let params = mk_map(vec![("emp", json!({}))]);
1561 let qs = build_query_string(¶ms).unwrap();
1562
1563 let json = serde_json::to_string(&json!({})).unwrap();
1564 let expected = Serializer::new(String::new())
1565 .append_pair("emp", &json)
1566 .finish();
1567 assert_eq!(qs, expected);
1568 }
1569
1570 #[test]
1571 fn floats_and_negatives() {
1572 let params = mk_map(vec![("fl", json!(1.23456)), ("neg", json!(-0.001))]);
1573 let qs = build_query_string(¶ms).unwrap();
1574 assert_eq!(qs, "fl=1.23456&neg=-0.001");
1575 }
1576
1577 #[test]
1578 fn unicode_and_special_key() {
1579 let params = mk_map(vec![
1580 ("こんにちは", json!("世界")),
1581 ("weird key/?=", json!("val")),
1582 ]);
1583 let qs = build_query_string(¶ms).unwrap();
1584
1585 let mut parts = Vec::new();
1586 for (k, v) in ¶ms {
1587 let mut ser = Serializer::new(String::new());
1588 ser.append_pair(k, v.as_str().unwrap());
1589 parts.push(ser.finish());
1590 }
1591 let expected = parts.join("&");
1592 assert_eq!(qs, expected);
1593 }
1594
1595 #[test]
1596 fn empty_string_value() {
1597 let params = mk_map(vec![("empty", json!(""))]);
1598 let qs = build_query_string(¶ms).unwrap();
1599 assert_eq!(qs, "empty=");
1600 }
1601
1602 #[test]
1603 fn nulls_in_array() {
1604 let params = mk_map(vec![("a", json!([null, 1, "x"]))]);
1605 let qs = build_query_string(¶ms).unwrap();
1606
1607 let json = serde_json::to_string(&json!([null, 1, "x"])).unwrap();
1608 let expected = Serializer::new(String::new())
1609 .append_pair("a", &json)
1610 .finish();
1611 assert_eq!(qs, expected);
1612 }
1613
1614 #[test]
1615 fn special_chars_in_key() {
1616 let params = mk_map(vec![("a=b&c%", json!("val"))]);
1617 let qs = build_query_string(¶ms).unwrap();
1618
1619 let expected = Serializer::new(String::new())
1620 .append_pair("a=b&c%", "val")
1621 .finish();
1622 assert_eq!(qs, expected);
1623 }
1624
1625 #[test]
1626 fn empty_key() {
1627 let params = mk_map(vec![("", json!("v"))]);
1628 let qs = build_query_string(¶ms).unwrap();
1629 assert_eq!(qs, "=v");
1630 }
1631 }
1632
1633 mod signature_generator {
1634 use base64::{Engine, engine::general_purpose};
1635 use ed25519_dalek::{SigningKey, ed25519::signature::SignerMut, pkcs8::DecodePrivateKey};
1636 use hex;
1637 use hmac::{Hmac, Mac};
1638 use openssl::{hash::MessageDigest, pkey::PKey, rsa::Rsa, sign::Verifier};
1639 use serde_json::Value;
1640 use sha2::Sha256;
1641 use std::collections::BTreeMap;
1642 use std::io::Write;
1643 use tempfile::NamedTempFile;
1644
1645 use crate::{common::utils::SignatureGenerator, config::PrivateKey};
1646
1647 #[test]
1648 fn hmac_sha256_signature() {
1649 let mut params = BTreeMap::new();
1650 params.insert("b".into(), Value::Number(2.into()));
1651 params.insert("a".into(), Value::Number(1.into()));
1652
1653 let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
1654 let sig = signature_gen
1655 .get_signature(¶ms, None)
1656 .expect("HMAC signing failed");
1657
1658 let mut mac = Hmac::<Sha256>::new_from_slice(b"test-secret").unwrap();
1659 let qs = "a=1&b=2";
1660 mac.update(qs.as_bytes());
1661 let expected = hex::encode(mac.finalize().into_bytes());
1662
1663 assert_eq!(sig, expected);
1664 }
1665
1666 #[test]
1667 fn hmac_sha256_signature_with_body() {
1668 let mut query_params = BTreeMap::new();
1669 query_params.insert("b".into(), Value::Number(2.into()));
1670 query_params.insert("a".into(), Value::Number(1.into()));
1671
1672 let mut body_params = BTreeMap::new();
1673 body_params.insert("d".into(), Value::Number(4.into()));
1674 body_params.insert("c".into(), Value::Number(3.into()));
1675
1676 let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
1677 let sig = signature_gen
1678 .get_signature(&query_params, Some(&body_params))
1679 .expect("HMAC signing with body failed");
1680
1681 let query_str = "a=1&b=2";
1682 let body_str = "c=3&d=4";
1683
1684 let payload = format!("{query_str}{body_str}");
1685
1686 let mut mac = Hmac::<Sha256>::new_from_slice(b"test-secret").unwrap();
1687 mac.update(payload.as_bytes());
1688 let expected = hex::encode(mac.finalize().into_bytes());
1689
1690 assert_eq!(sig, expected);
1691 }
1692
1693 #[test]
1694 fn repeated_hmac_signature() {
1695 let mut params = BTreeMap::new();
1696 params.insert("x".into(), Value::String("y".into()));
1697 let signature_gen = SignatureGenerator::new(Some("abc".into()), None, None);
1698 let s1 = signature_gen.get_signature(¶ms, None).unwrap();
1699 let s2 = signature_gen.get_signature(¶ms, None).unwrap();
1700 assert_eq!(s1, s2);
1701 }
1702
1703 #[test]
1704 fn rsa_signature_verification() {
1705 let mut params = BTreeMap::new();
1706 params.insert("a".into(), Value::Number(1.into()));
1707 params.insert("b".into(), Value::Number(2.into()));
1708
1709 let rsa = Rsa::generate(2048).unwrap();
1710 let priv_pem = rsa.private_key_to_pem().unwrap();
1711 let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1712
1713 let signature_gen =
1714 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1715 let sig = signature_gen
1716 .get_signature(¶ms, None)
1717 .expect("RSA signing failed");
1718
1719 let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1720 let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1721 let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1722 verifier.update(b"a=1&b=2").unwrap();
1723 assert!(verifier.verify(&sig_bytes).unwrap());
1724 }
1725
1726 #[test]
1727 fn rsa_signature_verification_with_body() {
1728 let mut query_params = BTreeMap::new();
1729 query_params.insert("a".into(), Value::Number(1.into()));
1730 query_params.insert("b".into(), Value::Number(2.into()));
1731
1732 let mut body_params = BTreeMap::new();
1733 body_params.insert("c".into(), Value::Number(3.into()));
1734 body_params.insert("d".into(), Value::Number(4.into()));
1735
1736 let rsa = Rsa::generate(2048).unwrap();
1737 let priv_pem = rsa.private_key_to_pem().unwrap();
1738 let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1739
1740 let signature_gen =
1741 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1742 let sig = signature_gen
1743 .get_signature(&query_params, Some(&body_params))
1744 .expect("RSA signing with body failed");
1745
1746 let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1747 let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1748 let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1749 verifier.update(b"a=1&b=2c=3&d=4").unwrap();
1750 assert!(verifier.verify(&sig_bytes).unwrap());
1751 }
1752
1753 #[test]
1754 fn repeated_rsa_signature() {
1755 let mut params = BTreeMap::new();
1756 params.insert("k".into(), Value::Number(5.into()));
1757 let rsa = Rsa::generate(2048).unwrap();
1758 let priv_pem = rsa.private_key_to_pem().unwrap();
1759 let signature_gen =
1760 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
1761 let s1 = signature_gen.get_signature(¶ms, None).unwrap();
1762 let s2 = signature_gen.get_signature(¶ms, None).unwrap();
1763 assert_eq!(s1, s2);
1764 }
1765
1766 #[test]
1767 fn ed25519_signature_verification() {
1768 let mut params = BTreeMap::new();
1769 params.insert("a".into(), Value::Number(1.into()));
1770 params.insert("b".into(), Value::Number(2.into()));
1771 let qs = "a=1&b=2";
1772
1773 let ed = PKey::generate_ed25519().unwrap();
1774 let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1775
1776 let signature_gen =
1777 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1778 let sig = signature_gen
1779 .get_signature(¶ms, None)
1780 .expect("Ed25519 signing failed");
1781
1782 let pem_str = String::from_utf8(priv_pem).unwrap();
1783 let b64 = pem_str
1784 .lines()
1785 .filter(|l| !l.starts_with("-----"))
1786 .collect::<String>();
1787 let der = general_purpose::STANDARD.decode(b64).unwrap();
1788 let mut sk = SigningKey::from_pkcs8_der(&der).unwrap();
1789 let expected_bytes = sk.sign(qs.as_bytes()).to_bytes();
1790 let expected_sig = general_purpose::STANDARD.encode(expected_bytes);
1791 assert_eq!(sig, expected_sig);
1792 }
1793
1794 #[test]
1795 fn ed25519_signature_verification_with_body() {
1796 let mut query_params = BTreeMap::new();
1797 query_params.insert("a".into(), Value::Number(1.into()));
1798 query_params.insert("b".into(), Value::Number(2.into()));
1799 let qs = "a=1&b=2";
1800
1801 let mut body_params = BTreeMap::new();
1802 body_params.insert("c".into(), Value::Number(3.into()));
1803 body_params.insert("d".into(), Value::Number(4.into()));
1804 let body_qs = "c=3&d=4";
1805
1806 let ed = PKey::generate_ed25519().unwrap();
1807 let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1808
1809 let signature_gen =
1810 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1811 let sig = signature_gen
1812 .get_signature(&query_params, Some(&body_params))
1813 .expect("Ed25519 signing with body failed");
1814
1815 let pem_str = String::from_utf8(priv_pem).unwrap();
1816 let b64 = pem_str
1817 .lines()
1818 .filter(|l| !l.starts_with("-----"))
1819 .collect::<String>();
1820 let der = general_purpose::STANDARD.decode(b64).unwrap();
1821 let mut sk = SigningKey::from_pkcs8_der(&der).unwrap();
1822 let payload = format!("{qs}{body_qs}");
1823 let expected_bytes = sk.sign(payload.as_bytes()).to_bytes();
1824 let expected_sig = general_purpose::STANDARD.encode(expected_bytes);
1825 assert_eq!(sig, expected_sig);
1826 }
1827
1828 #[test]
1829 fn repeated_ed25519_signature() {
1830 let mut params = BTreeMap::new();
1831 params.insert("m".into(), Value::String("n".into()));
1832 let ed = PKey::generate_ed25519().unwrap();
1833 let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1834 let signature_gen =
1835 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1836 let s1 = signature_gen.get_signature(¶ms, None).unwrap();
1837 let s2 = signature_gen.get_signature(¶ms, None).unwrap();
1838 assert_eq!(s1, s2);
1839 }
1840
1841 #[test]
1842 fn file_based_key() {
1843 let rsa = Rsa::generate(1024).unwrap();
1844 let priv_pem = rsa.private_key_to_pem().unwrap();
1845 let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1846
1847 let mut file = NamedTempFile::new().unwrap();
1848 file.write_all(&priv_pem).unwrap();
1849 let path = file.path().to_str().unwrap().to_string();
1850
1851 let mut params = BTreeMap::new();
1852 params.insert("z".into(), Value::Number(9.into()));
1853
1854 let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::File(path)), None);
1855 let sig = signature_gen.get_signature(¶ms, None).unwrap();
1856
1857 let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1858 let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1859 let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1860 verifier.update(b"z=9").unwrap();
1861 assert!(verifier.verify(&sig_bytes).unwrap());
1862 }
1863
1864 #[test]
1865 fn unsupported_key_type_error() {
1866 let mut params = BTreeMap::new();
1867 params.insert("x".into(), Value::String("y".into()));
1868
1869 let group =
1870 openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
1871 let ec_key = openssl::ec::EcKey::generate(&group).unwrap();
1872 let pkey_ec = PKey::from_ec_key(ec_key).unwrap();
1873 let raw = pkey_ec.private_key_to_pem_pkcs8().unwrap();
1874
1875 let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(raw)), None);
1876 let err = signature_gen
1877 .get_signature(¶ms, None)
1878 .unwrap_err()
1879 .to_string();
1880 assert!(err.contains("Unsupported private key type"));
1881 }
1882
1883 #[test]
1884 fn invalid_private_key_error() {
1885 let mut params = BTreeMap::new();
1886 params.insert("foo".into(), Value::String("bar".into()));
1887
1888 let signature_gen =
1889 SignatureGenerator::new(None, Some(PrivateKey::Raw(b"not a key".to_vec())), None);
1890 let err = signature_gen
1891 .get_signature(¶ms, None)
1892 .unwrap_err()
1893 .to_string();
1894 assert!(err.contains("Failed to parse private key"));
1895 }
1896
1897 #[test]
1898 fn missing_credentials_error() {
1899 let mut params = BTreeMap::new();
1900 params.insert("a".into(), Value::Number(1.into()));
1901
1902 let signature_gen = SignatureGenerator::new(None, None, None);
1903 let err = signature_gen
1904 .get_signature(¶ms, None)
1905 .unwrap_err()
1906 .to_string();
1907 assert!(err.contains("Either 'api_secret' or 'private_key' must be provided"));
1908 }
1909 }
1910
1911 mod should_retry_request {
1912 use crate::common::utils::should_retry_request;
1913
1914 use reqwest::{Error, Response};
1915
1916 fn mk_http_error(code: u16) -> Error {
1917 let resp = Response::from(
1918 http::response::Response::builder()
1919 .status(code)
1920 .body("")
1921 .unwrap(),
1922 );
1923 resp.error_for_status().unwrap_err()
1924 }
1925
1926 fn mk_network_error() -> Error {
1927 reqwest::blocking::get("http://256.256.256.256").unwrap_err()
1928 }
1929
1930 #[test]
1931 fn retry_on_retriable_status_and_method() {
1932 let err = mk_http_error(500);
1933 assert!(should_retry_request(&err, Some("GET"), Some(1)));
1934 assert!(should_retry_request(&err, Some("delete"), Some(2)));
1935 }
1936
1937 #[test]
1938 fn retry_when_status_none_and_retriable_method() {
1939 let retriable_methods = ["GET", "DELETE"];
1940
1941 for &method in &retriable_methods {
1942 let err = mk_network_error();
1943 assert!(
1944 should_retry_request(&err, Some(method), Some(1)),
1945 "Should retry when no status and method {method}"
1946 );
1947 }
1948 }
1949
1950 #[test]
1951 fn no_retry_when_no_retries_left() {
1952 let err = mk_http_error(503);
1953 assert!(!should_retry_request(&err, Some("GET"), Some(0)));
1954 }
1955
1956 #[test]
1957 fn no_retry_on_non_retriable_status() {
1958 let non_retriable_statuses = [400, 401, 404, 422];
1959
1960 for &status in &non_retriable_statuses {
1961 let err = mk_http_error(status);
1962 assert!(
1963 !should_retry_request(&err, Some("GET"), Some(2)),
1964 "Should not retry for non-retriable status {status}"
1965 );
1966 }
1967 }
1968
1969 #[test]
1970 fn no_retry_on_non_retriable_method() {
1971 let non_retriable_methods = ["POST", "PUT", "PATCH"];
1972
1973 for &method in &non_retriable_methods {
1974 let err = mk_http_error(500);
1975 assert!(
1976 !should_retry_request(&err, Some(method), Some(2)),
1977 "Should not retry for non-retriable method {method}"
1978 );
1979 }
1980 }
1981
1982 #[test]
1983 fn no_retry_when_status_none_and_non_retriable_method() {
1984 let non_retriable_methods = ["POST", "PUT"];
1985
1986 for &method in &non_retriable_methods {
1987 let err = mk_network_error();
1988 assert!(
1989 !should_retry_request(&err, Some(method), Some(1)),
1990 "Should not retry when no status and method {method}"
1991 );
1992 }
1993 }
1994 }
1995
1996 mod parse_rate_limit_headers_tests {
1997 use crate::common::{
1998 models::{Interval, RateLimitType},
1999 utils::parse_rate_limit_headers,
2000 };
2001 use std::collections::HashMap;
2002
2003 fn mk_headers(pairs: Vec<(&str, &str)>) -> HashMap<String, String> {
2004 let mut m = HashMap::new();
2005 for (k, v) in pairs {
2006 m.insert(k.to_string(), v.to_string());
2007 }
2008 m
2009 }
2010
2011 #[test]
2012 fn single_weight_header() {
2013 let headers = mk_headers(vec![("x-mbx-used-weight-1s", "123")]);
2014 let limits = parse_rate_limit_headers(&headers);
2015 assert_eq!(limits.len(), 1);
2016 let rl = &limits[0];
2017 assert_eq!(rl.rate_limit_type, RateLimitType::RequestWeight);
2018 assert_eq!(rl.interval, Interval::Second);
2019 assert_eq!(rl.interval_num, 1);
2020 assert_eq!(rl.count, 123);
2021 assert_eq!(rl.retry_after, None);
2022 }
2023
2024 #[test]
2025 fn single_order_count_with_retry_after() {
2026 let headers = mk_headers(vec![("x-mbx-order-count-5m", "42"), ("retry-after", "7")]);
2027 let limits = parse_rate_limit_headers(&headers);
2028 assert_eq!(limits.len(), 1);
2029 let rl = &limits[0];
2030 assert_eq!(rl.rate_limit_type, RateLimitType::Orders);
2031 assert_eq!(rl.interval, Interval::Minute);
2032 assert_eq!(rl.interval_num, 5);
2033 assert_eq!(rl.count, 42);
2034 assert_eq!(rl.retry_after, Some(7));
2035 }
2036
2037 #[test]
2038 fn multiple_headers() {
2039 let headers = mk_headers(vec![
2040 ("X-MBX-USED-WEIGHT-1h", "10"),
2041 ("x-mbx-order-count-2d", "20"),
2042 ]);
2043 let mut limits = parse_rate_limit_headers(&headers);
2044 limits.sort_by_key(|r| (r.interval_num, format!("{:?}", r.rate_limit_type)));
2045 assert_eq!(limits.len(), 2);
2046 let w = &limits[0];
2047 assert_eq!(w.rate_limit_type, RateLimitType::RequestWeight);
2048 assert_eq!(w.interval, Interval::Hour);
2049 assert_eq!(w.interval_num, 1);
2050 assert_eq!(w.count, 10);
2051 let o = &limits[1];
2052 assert_eq!(o.rate_limit_type, RateLimitType::Orders);
2053 assert_eq!(o.interval, Interval::Day);
2054 assert_eq!(o.interval_num, 2);
2055 assert_eq!(o.count, 20);
2056 }
2057
2058 #[test]
2059 fn ignores_unknown_and_malformed() {
2060 let headers = mk_headers(vec![
2061 ("x-mbx-used-weight-3x", "5"),
2062 ("random-header", "100"),
2063 ]);
2064 let limits = parse_rate_limit_headers(&headers);
2065 assert!(limits.is_empty());
2066 }
2067 }
2068
2069 mod http_request {
2070 use std::io::Write;
2071
2072 use flate2::{Compression, write::GzEncoder};
2073 use httpmock::MockServer;
2074 use reqwest::{Client, Method, Request};
2075 use serde::Deserialize;
2076
2077 use crate::{
2078 common::utils::http_request, config::ConfigurationRestApi, errors::ConnectorError,
2079 models::RestApiResponse,
2080 };
2081
2082 use super::TOKIO_SHARED_RT;
2083
2084 #[derive(Deserialize, Debug, PartialEq)]
2085 struct Dummy {
2086 foo: String,
2087 }
2088
2089 fn make_config(server_url: &str) -> ConfigurationRestApi {
2090 ConfigurationRestApi::builder()
2091 .api_key("key")
2092 .api_secret("secret")
2093 .base_path(server_url)
2094 .build()
2095 .expect("Failed to build configuration")
2096 }
2097
2098 #[test]
2099 fn http_request_success_plain_text() {
2100 TOKIO_SHARED_RT.block_on(async {
2101 let server = MockServer::start();
2102 let mock = server.mock(|when, then| {
2103 when.method(httpmock::Method::GET).path("/test");
2104 then.status(200)
2105 .header("Content-Type", "application/json")
2106 .body(r#"{"foo":"bar"}"#);
2107 });
2108
2109 let client = Client::new();
2110 let req: Request = client
2111 .request(Method::GET, format!("{}{}", server.url(""), "/test"))
2112 .build()
2113 .unwrap();
2114
2115 let cfg = make_config(&server.url(""));
2116 let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
2117 assert_eq!(resp.status, 200);
2118 let data = resp.data().await.unwrap();
2119 assert_eq!(data, Dummy { foo: "bar".into() });
2120 mock.assert();
2121 });
2122 }
2123
2124 #[test]
2125 fn http_request_success_gzip() {
2126 TOKIO_SHARED_RT.block_on(async {
2127 let server = MockServer::start();
2128 let body = r#"{"foo":"baz"}"#;
2129 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
2130 encoder.write_all(body.as_bytes()).unwrap();
2131 let gz = encoder.finish().unwrap();
2132
2133 let mock = server.mock(|when, then| {
2134 when.method(httpmock::Method::GET).path("/gz");
2135 then.status(200)
2136 .header("Content-Type", "application/json")
2137 .header("Content-Encoding", "gzip")
2138 .body(gz);
2139 });
2140
2141 let client = Client::new();
2142 let req: Request = client
2143 .request(Method::GET, format!("{}{}", server.url(""), "/gz"))
2144 .build()
2145 .unwrap();
2146 let mut cfg = make_config(&server.url(""));
2147 cfg.compression = true;
2148
2149 let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
2150 assert_eq!(resp.status, 200);
2151 let data = resp.data().await.unwrap();
2152 assert_eq!(data, Dummy { foo: "baz".into() });
2153 mock.assert();
2154 });
2155 }
2156
2157 #[test]
2158 fn http_request_client_error_bad_request() {
2159 TOKIO_SHARED_RT.block_on(async {
2160 let server = MockServer::start();
2161 let mock = server.mock(|when, then| {
2162 when.method(httpmock::Method::GET).path("/400");
2163 then.status(400)
2164 .header("Content-Type", "application/json")
2165 .body(r#"{"code":-1121,"msg":"bad request"}"#);
2166 });
2167
2168 let client = Client::new();
2169 let req: Request = client
2170 .request(Method::GET, format!("{}{}", server.url(""), "/400"))
2171 .build()
2172 .unwrap();
2173 let cfg = make_config(&server.url(""));
2174
2175 let result = http_request::<Dummy>(req, &cfg).await;
2176
2177 assert!(matches!(
2178 result,
2179 Err(ConnectorError::BadRequestError { .. })
2180 ));
2181
2182 if let Err(ConnectorError::BadRequestError { msg, code }) = result {
2183 assert_eq!(msg, "bad request");
2184 assert_eq!(code, Some(-1121));
2185 }
2186
2187 mock.assert();
2188 });
2189 }
2190
2191 #[test]
2192 fn http_request_client_error_unauthorized() {
2193 TOKIO_SHARED_RT.block_on(async {
2194 let server = MockServer::start();
2195 let mock = server.mock(|when, then| {
2196 when.method(httpmock::Method::GET).path("/401");
2197 then.status(401)
2198 .header("Content-Type", "application/json")
2199 .body(r#"{"code":-2015,"msg":"unauthorized"}"#);
2200 });
2201
2202 let client = Client::new();
2203 let req: Request = client
2204 .request(Method::GET, format!("{}{}", server.url(""), "/401"))
2205 .build()
2206 .unwrap();
2207 let cfg = make_config(&server.url(""));
2208
2209 let result = http_request::<Dummy>(req, &cfg).await;
2210
2211 assert!(matches!(
2212 result,
2213 Err(ConnectorError::UnauthorizedError { .. })
2214 ));
2215
2216 if let Err(ConnectorError::UnauthorizedError { msg, code }) = result {
2217 assert_eq!(msg, "unauthorized");
2218 assert_eq!(code, Some(-2015));
2219 }
2220
2221 mock.assert();
2222 });
2223 }
2224
2225 #[test]
2226 fn http_request_client_error_forbidden() {
2227 TOKIO_SHARED_RT.block_on(async {
2228 let server = MockServer::start();
2229 let mock = server.mock(|when, then| {
2230 when.method(httpmock::Method::GET).path("/403");
2231 then.status(403)
2232 .header("Content-Type", "application/json")
2233 .body(r#"{"code":-2010,"msg":"forbidden"}"#);
2234 });
2235
2236 let client = Client::new();
2237 let req: Request = client
2238 .request(Method::GET, format!("{}{}", server.url(""), "/403"))
2239 .build()
2240 .unwrap();
2241 let cfg = make_config(&server.url(""));
2242
2243 let result = http_request::<Dummy>(req, &cfg).await;
2244
2245 assert!(matches!(result, Err(ConnectorError::ForbiddenError { .. })));
2246
2247 if let Err(ConnectorError::ForbiddenError { msg, code }) = result {
2248 assert_eq!(msg, "forbidden");
2249 assert_eq!(code, Some(-2010));
2250 }
2251
2252 mock.assert();
2253 });
2254 }
2255
2256 #[test]
2257 fn http_request_client_error_not_found() {
2258 TOKIO_SHARED_RT.block_on(async {
2259 let server = MockServer::start();
2260 let mock = server.mock(|when, then| {
2261 when.method(httpmock::Method::GET).path("/404");
2262 then.status(404)
2263 .header("Content-Type", "application/json")
2264 .body(r#"{"code":-1003,"msg":"not found"}"#);
2265 });
2266
2267 let client = Client::new();
2268 let req: Request = client
2269 .request(Method::GET, format!("{}{}", server.url(""), "/404"))
2270 .build()
2271 .unwrap();
2272 let cfg = make_config(&server.url(""));
2273
2274 let result = http_request::<Dummy>(req, &cfg).await;
2275
2276 assert!(matches!(result, Err(ConnectorError::NotFoundError { .. })));
2277
2278 if let Err(ConnectorError::NotFoundError { msg, code }) = result {
2279 assert_eq!(msg, "not found");
2280 assert_eq!(code, Some(-1003));
2281 }
2282
2283 mock.assert();
2284 });
2285 }
2286
2287 #[test]
2288 fn http_request_client_error_rate_limit_exceeded() {
2289 TOKIO_SHARED_RT.block_on(async {
2290 let server = MockServer::start();
2291 let mock = server.mock(|when, then| {
2292 when.method(httpmock::Method::GET).path("/418");
2293 then.status(418)
2294 .header("Content-Type", "application/json")
2295 .body(r#"{"code":-1003,"msg":"rate limit exceeded"}"#);
2296 });
2297
2298 let client = Client::new();
2299 let req: Request = client
2300 .request(Method::GET, format!("{}{}", server.url(""), "/418"))
2301 .build()
2302 .unwrap();
2303 let cfg = make_config(&server.url(""));
2304
2305 let result = http_request::<Dummy>(req, &cfg).await;
2306
2307 assert!(matches!(
2308 result,
2309 Err(ConnectorError::RateLimitBanError { .. })
2310 ));
2311
2312 if let Err(ConnectorError::RateLimitBanError { msg, code }) = result {
2313 assert_eq!(msg, "rate limit exceeded");
2314 assert_eq!(code, Some(-1003));
2315 }
2316
2317 mock.assert();
2318 });
2319 }
2320
2321 #[test]
2322 fn http_request_client_error_too_many_requests() {
2323 TOKIO_SHARED_RT.block_on(async {
2324 let server = MockServer::start();
2325 let mock = server.mock(|when, then| {
2326 when.method(httpmock::Method::GET).path("/429");
2327 then.status(429)
2328 .header("Content-Type", "application/json")
2329 .body(r#"{"code":-1003,"msg":"too many requests"}"#);
2330 });
2331
2332 let client = Client::new();
2333 let req: Request = client
2334 .request(Method::GET, format!("{}{}", server.url(""), "/429"))
2335 .build()
2336 .unwrap();
2337 let cfg = make_config(&server.url(""));
2338
2339 let result = http_request::<Dummy>(req, &cfg).await;
2340
2341 assert!(matches!(
2342 result,
2343 Err(ConnectorError::TooManyRequestsError { .. })
2344 ));
2345
2346 if let Err(ConnectorError::TooManyRequestsError { msg, code }) = result {
2347 assert_eq!(msg, "too many requests");
2348 assert_eq!(code, Some(-1003));
2349 }
2350
2351 mock.assert();
2352 });
2353 }
2354
2355 #[test]
2356 fn http_request_client_error_server_error() {
2357 TOKIO_SHARED_RT.block_on(async {
2358 let server = MockServer::start();
2359 let mock = server.mock(|when, then| {
2360 when.method(httpmock::Method::GET).path("/500");
2361 then.status(500)
2362 .header("Content-Type", "application/json")
2363 .body(r#"{"code":-1000,"msg":"internal server error"}"#);
2364 });
2365
2366 let client = Client::new();
2367 let req: Request = client
2368 .request(Method::GET, format!("{}{}", server.url(""), "/500"))
2369 .build()
2370 .unwrap();
2371 let cfg = make_config(&server.url(""));
2372
2373 let result = http_request::<Dummy>(req, &cfg).await;
2374
2375 assert!(matches!(result, Err(ConnectorError::ServerError { .. })));
2376
2377 if let Err(ConnectorError::ServerError {
2378 msg,
2379 status_code: Some(500),
2380 }) = result
2381 {
2382 assert_eq!(msg, "Server error: 500".to_string());
2383 }
2384
2385 mock.assert();
2386 });
2387 }
2388
2389 #[test]
2390 fn http_request_unexpected_status_maps_generic() {
2391 TOKIO_SHARED_RT.block_on(async {
2392 let server = MockServer::start();
2393 let code_http = 402;
2394 let mock = server.mock(|when, then| {
2395 when.method(httpmock::Method::GET).path("/402");
2396 then.status(code_http)
2397 .header("Content-Type", "application/json")
2398 .body(r#"{"code":-12345,"msg":"payment required"}"#);
2399 });
2400
2401 let client = Client::new();
2402 let req: Request = client
2403 .request(Method::GET, format!("{}{}", server.url(""), "/402"))
2404 .build()
2405 .unwrap();
2406 let cfg = make_config(&server.url(""));
2407
2408 let result = http_request::<Dummy>(req, &cfg).await;
2409
2410 assert!(matches!(
2411 result,
2412 Err(ConnectorError::ConnectorClientError { .. })
2413 ));
2414
2415 if let Err(ConnectorError::ConnectorClientError { msg, code }) = result {
2416 assert_eq!(msg, "payment required");
2417 assert_eq!(code, Some(-12345));
2418 }
2419
2420 mock.assert();
2421 });
2422 }
2423
2424 #[test]
2425 fn http_request_malformed_json_maps_generic() {
2426 TOKIO_SHARED_RT.block_on(async {
2427 let server = MockServer::start();
2428 let mock = server.mock(|when, then| {
2429 when.method(httpmock::Method::GET).path("/malformed");
2430 then.status(200)
2431 .header("Content-Type", "application/json")
2432 .body("not json");
2433 });
2434
2435 let client = Client::new();
2436 let req: Request = client
2437 .request(Method::GET, format!("{}{}", server.url(""), "/malformed"))
2438 .build()
2439 .unwrap();
2440 let cfg = make_config(&server.url(""));
2441
2442 let resp = http_request::<Dummy>(req, &cfg)
2443 .await
2444 .expect("http_request should succeed even if JSON is bad");
2445
2446 let err = resp
2447 .data()
2448 .await
2449 .expect_err("malformed JSON should turn into ConnectorClientError");
2450
2451 assert!(matches!(err, ConnectorError::ConnectorClientError { .. }));
2452
2453 if let ConnectorError::ConnectorClientError { msg: _, code } = err {
2454 assert_eq!(code, None);
2455 }
2456
2457 mock.assert();
2458 });
2459 }
2460 }
2461
2462 mod send_request {
2463 use anyhow::Result;
2464 use httpmock::prelude::*;
2465 use reqwest::Method;
2466 use serde::Deserialize;
2467 use serde_json::json;
2468 use std::collections::{BTreeMap, HashMap};
2469
2470 use crate::{
2471 common::{models::TimeUnit, utils::send_request},
2472 config::ConfigurationRestApi,
2473 };
2474
2475 use super::TOKIO_SHARED_RT;
2476
2477 #[derive(Deserialize, Debug, PartialEq)]
2478 struct TestResponse {
2479 message: String,
2480 }
2481
2482 #[test]
2483 fn basic_get_request() -> Result<()> {
2484 TOKIO_SHARED_RT.block_on(async {
2485 let server = MockServer::start();
2486
2487 server.mock(|when, then| {
2488 when.method(GET).path("/api/v1/test");
2489 then.status(200)
2490 .header("content-type", "application/json")
2491 .body(r#"{"message": "success"}"#);
2492 });
2493
2494 let configuration = ConfigurationRestApi::builder()
2495 .api_key("key")
2496 .api_secret("secret")
2497 .base_path(server.base_url())
2498 .compression(false)
2499 .build()
2500 .expect("Failed to build configuration");
2501
2502 let params = BTreeMap::new();
2503
2504 let result = send_request::<TestResponse>(
2505 &configuration,
2506 "/api/v1/test",
2507 Method::GET,
2508 params,
2509 BTreeMap::new(),
2510 None,
2511 false,
2512 )
2513 .await?;
2514
2515 let data = result.data().await.unwrap();
2516 assert_eq!(data.message, "success");
2517
2518 Ok(())
2519 })
2520 }
2521
2522 #[test]
2523 fn signed_post_request() -> Result<()> {
2524 TOKIO_SHARED_RT.block_on(async {
2525 let server = MockServer::start();
2526
2527 server.mock(|when, then| {
2528 when.method(POST).path("/api/v3/order");
2529 then.status(200)
2530 .header("content-type", "application/json")
2531 .body(r#"{"message": "order placed"}"#);
2532 });
2533
2534 let configuration = ConfigurationRestApi::builder()
2535 .api_key("key")
2536 .api_secret("secret")
2537 .base_path(server.base_url())
2538 .compression(false)
2539 .build()
2540 .expect("Failed to build configuration");
2541
2542 let mut params = BTreeMap::new();
2543 params.insert("symbol".to_string(), json!("ETHUSDT"));
2544 params.insert("side".to_string(), json!("BUY"));
2545 params.insert("type".to_string(), json!("MARKET"));
2546 params.insert("quantity".to_string(), json!("1"));
2547
2548 let result = send_request::<TestResponse>(
2549 &configuration,
2550 "/api/v3/order",
2551 Method::POST,
2552 params,
2553 BTreeMap::new(),
2554 None,
2555 true,
2556 )
2557 .await?;
2558
2559 let data = result.data().await.unwrap();
2560 assert_eq!(data.message, "order placed");
2561
2562 Ok(())
2563 })
2564 }
2565
2566 #[test]
2567 fn signed_post_request_with_body() -> Result<()> {
2568 TOKIO_SHARED_RT.block_on(async {
2569 let server = MockServer::start();
2570
2571 server.mock(|when, then| {
2572 when.method(POST).path("/api/v3/order");
2573 then.status(200)
2574 .header("content-type", "application/json")
2575 .body(r#"{"message": "order placed"}"#);
2576 });
2577
2578 let configuration = ConfigurationRestApi::builder()
2579 .api_key("key")
2580 .api_secret("secret")
2581 .base_path(server.base_url())
2582 .compression(false)
2583 .build()
2584 .expect("Failed to build configuration");
2585
2586 let mut query_params = BTreeMap::new();
2587 query_params.insert("symbol".to_string(), json!("ETHUSDT"));
2588
2589 let mut body_params = BTreeMap::new();
2590 body_params.insert("side".to_string(), json!("BUY"));
2591 body_params.insert("type".to_string(), json!("MARKET"));
2592 body_params.insert("quantity".to_string(), json!("1"));
2593
2594 let result = send_request::<TestResponse>(
2595 &configuration,
2596 "/api/v3/order",
2597 Method::POST,
2598 query_params,
2599 body_params,
2600 None,
2601 true,
2602 )
2603 .await?;
2604
2605 let data = result.data().await.unwrap();
2606 assert_eq!(data.message, "order placed");
2607
2608 Ok(())
2609 })
2610 }
2611
2612 #[test]
2613 fn get_request_with_params() -> Result<()> {
2614 TOKIO_SHARED_RT.block_on(async {
2615 let server = MockServer::start();
2616
2617 server.mock(|when, then| {
2618 when.method(GET)
2619 .path("/api/v1/data")
2620 .query_param("symbol", "BTCUSDT")
2621 .query_param("limit", "10");
2622 then.status(200)
2623 .header("content-type", "application/json")
2624 .body(r#"{"message": "data retrieved"}"#);
2625 });
2626
2627 let configuration = ConfigurationRestApi::builder()
2628 .api_key("key")
2629 .api_secret("secret")
2630 .base_path(server.base_url())
2631 .compression(false)
2632 .build()
2633 .expect("Failed to build configuration");
2634
2635 let mut params = BTreeMap::new();
2636 params.insert("symbol".to_string(), json!("BTCUSDT"));
2637 params.insert("limit".to_string(), json!(10));
2638
2639 let result = send_request::<TestResponse>(
2640 &configuration,
2641 "/api/v1/data",
2642 Method::GET,
2643 params,
2644 BTreeMap::new(),
2645 None,
2646 false,
2647 )
2648 .await?;
2649
2650 let data = result.data().await.unwrap();
2651 assert_eq!(data.message, "data retrieved");
2652
2653 Ok(())
2654 })
2655 }
2656
2657 #[test]
2658 fn invalid_endpoint() {
2659 TOKIO_SHARED_RT.block_on(async {
2660 let server = MockServer::start();
2661
2662 let configuration = ConfigurationRestApi::builder()
2663 .api_key("key")
2664 .api_secret("secret")
2665 .base_path(server.base_url())
2666 .compression(false)
2667 .build()
2668 .expect("Failed to build configuration");
2669
2670 let params = BTreeMap::new();
2671
2672 let result = send_request::<TestResponse>(
2673 &configuration,
2674 "http://invalid",
2675 Method::GET,
2676 params,
2677 BTreeMap::new(),
2678 None,
2679 false,
2680 )
2681 .await;
2682
2683 assert!(result.is_err());
2684 });
2685 }
2686
2687 #[test]
2688 fn missing_signature_on_signed_request() {
2689 TOKIO_SHARED_RT.block_on(async {
2690 let server = MockServer::start();
2691
2692 let configuration = ConfigurationRestApi::builder()
2693 .api_key("key")
2694 .api_secret("secret")
2695 .base_path(server.base_url())
2696 .compression(false)
2697 .build()
2698 .expect("Failed to build configuration");
2699
2700 let mut params = BTreeMap::new();
2701 params.insert("symbol".to_string(), json!("BTCUSDT"));
2702 params.insert("side".to_string(), json!("BUY"));
2703
2704 let result = send_request::<TestResponse>(
2705 &configuration,
2706 "/api/v3/order",
2707 Method::POST,
2708 params,
2709 BTreeMap::new(),
2710 None,
2711 true,
2712 )
2713 .await;
2714
2715 assert!(result.is_err());
2716 });
2717 }
2718
2719 #[test]
2720 fn compression_enabled() -> Result<()> {
2721 TOKIO_SHARED_RT.block_on(async {
2722 let server = MockServer::start();
2723
2724 server.mock(|when, then| {
2725 when.method(GET).path("/api/v1/test");
2726 then.status(200)
2727 .header("content-type", "application/json")
2728 .header("accept-encoding", "gzip, deflate, br")
2729 .body(r#"{"message": "compression enabled"}"#);
2730 });
2731
2732 let configuration = ConfigurationRestApi::builder()
2733 .api_key("key")
2734 .api_secret("secret")
2735 .base_path(server.base_url())
2736 .compression(true)
2737 .build()
2738 .expect("Failed to build configuration");
2739
2740 let params = BTreeMap::new();
2741
2742 let result = send_request::<TestResponse>(
2743 &configuration,
2744 "/api/v1/test",
2745 Method::GET,
2746 params,
2747 BTreeMap::new(),
2748 None,
2749 false,
2750 )
2751 .await?;
2752
2753 let data = result.data().await.unwrap();
2754 assert_eq!(data.message, "compression enabled");
2755
2756 Ok(())
2757 })
2758 }
2759
2760 #[test]
2761 fn get_request_with_time_unit_header() -> Result<()> {
2762 TOKIO_SHARED_RT.block_on(async {
2763 let server = MockServer::start();
2764
2765 server.mock(|when, then| {
2766 when.method(GET)
2767 .path("/api/v1/test")
2768 .header("X-MBX-TIME-UNIT", "MILLISECOND");
2769 then.status(200)
2770 .header("content-type", "application/json")
2771 .body(r#"{"message": "time unit applied"}"#);
2772 });
2773
2774 let configuration = ConfigurationRestApi::builder()
2775 .api_key("key")
2776 .api_secret("secret")
2777 .base_path(server.base_url())
2778 .compression(false)
2779 .time_unit(TimeUnit::Millisecond)
2780 .build()
2781 .expect("Failed to build configuration");
2782
2783 let params = BTreeMap::new();
2784
2785 let result = send_request::<TestResponse>(
2786 &configuration,
2787 "/api/v1/test",
2788 Method::GET,
2789 params,
2790 BTreeMap::new(),
2791 Some(TimeUnit::Millisecond),
2792 false,
2793 )
2794 .await?;
2795
2796 let data = result.data().await.unwrap();
2797 assert_eq!(data.message, "time unit applied");
2798
2799 Ok(())
2800 })
2801 }
2802
2803 #[test]
2804 fn custom_headers_are_sent() -> Result<()> {
2805 TOKIO_SHARED_RT.block_on(async {
2806 let server = MockServer::start();
2807
2808 server.mock(|when, then| {
2809 when.method(GET)
2810 .path("/api/v1/test")
2811 .header("X-My-Test", "all-clear");
2812 then.status(200)
2813 .header("content-type", "application/json")
2814 .body(r#"{"message":"ok"}"#);
2815 });
2816
2817 let mut custom = HashMap::new();
2818 custom.insert("X-My-Test".to_string(), "all-clear".to_string());
2819
2820 let configuration = ConfigurationRestApi::builder()
2821 .api_key("key")
2822 .api_secret("secret")
2823 .base_path(server.base_url())
2824 .compression(false)
2825 .custom_headers(custom)
2826 .build()
2827 .expect("Failed to build configuration");
2828
2829 let params = BTreeMap::new();
2830 let res = send_request::<TestResponse>(
2831 &configuration,
2832 "/api/v1/test",
2833 Method::GET,
2834 params,
2835 BTreeMap::new(),
2836 None,
2837 false,
2838 )
2839 .await?;
2840
2841 let data = res.data().await.unwrap();
2842 assert_eq!(data.message, "ok");
2843
2844 Ok(())
2845 })
2846 }
2847
2848 #[test]
2849 fn custom_header_override_prevention() -> Result<()> {
2850 TOKIO_SHARED_RT.block_on(async {
2851 let server = MockServer::start();
2852
2853 server.mock(|when, then| {
2854 when.method(GET)
2855 .path("/api/v1/test")
2856 .header("content-type", "application/json")
2857 .header("x-mbx-apikey", "key")
2858 .header("X-My-Test", "ok");
2859 then.status(200)
2860 .header("content-type", "application/json")
2861 .body(r#"{"message":"defaults intact"}"#);
2862 });
2863
2864 let mut custom = HashMap::new();
2865 custom.insert("Content-Type".to_string(), "text/plain".to_string());
2866 custom.insert("X-MBX-APIKEY".to_string(), "BAD".to_string());
2867 custom.insert("X-My-Test".to_string(), "ok".to_string());
2868
2869 let configuration = ConfigurationRestApi::builder()
2870 .api_key("key")
2871 .api_secret("secret")
2872 .base_path(server.base_url())
2873 .compression(false)
2874 .custom_headers(custom)
2875 .build()
2876 .expect("Failed to build configuration");
2877
2878 let params = BTreeMap::new();
2879 let res = send_request::<TestResponse>(
2880 &configuration,
2881 "/api/v1/test",
2882 Method::GET,
2883 params,
2884 BTreeMap::new(),
2885 None,
2886 false,
2887 )
2888 .await?;
2889
2890 let data = res.data().await.unwrap();
2891 assert_eq!(data.message, "defaults intact");
2892
2893 Ok(())
2894 })
2895 }
2896
2897 #[test]
2898 fn crlf_in_header_values_are_dropped() -> Result<()> {
2899 TOKIO_SHARED_RT.block_on(async {
2900 let server = MockServer::start();
2901
2902 server.mock(|when, then| {
2903 when.method(GET)
2904 .path("/api/v1/test")
2905 .header("X-Good", "safe");
2906 then.status(200)
2907 .header("content-type", "application/json")
2908 .body(r#"{"message":"clean only"}"#);
2909 });
2910
2911 let mut custom = HashMap::new();
2912 custom.insert("X-Bad".to_string(), "evil\r\ninject".to_string());
2913 custom.insert("X-Good".to_string(), "safe".to_string());
2914
2915 let configuration = ConfigurationRestApi::builder()
2916 .api_key("key")
2917 .api_secret("secret")
2918 .base_path(server.base_url())
2919 .compression(false)
2920 .custom_headers(custom)
2921 .build()
2922 .expect("Failed to build configuration");
2923
2924 let params = BTreeMap::new();
2925 let res = send_request::<TestResponse>(
2926 &configuration,
2927 "/api/v1/test",
2928 Method::GET,
2929 params,
2930 BTreeMap::new(),
2931 None,
2932 false,
2933 )
2934 .await?;
2935
2936 let data = res.data().await.unwrap();
2937 assert_eq!(data.message, "clean only");
2938
2939 Ok(())
2940 })
2941 }
2942 }
2943
2944 mod random_string {
2945 use crate::common::utils::random_string;
2946 use hex;
2947
2948 #[test]
2949 fn length_is_32() {
2950 let s = random_string();
2951 assert_eq!(
2952 s.len(),
2953 32,
2954 "random_string() should be 32 chars, got {}",
2955 s.len()
2956 );
2957 }
2958
2959 #[test]
2960 fn is_valid_lowercase_hex() {
2961 let s = random_string();
2962 assert!(
2963 s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
2964 "random_string() contains invalid hex characters: {s}"
2965 );
2966 }
2967
2968 #[test]
2969 fn decodes_to_16_bytes() {
2970 let s = random_string();
2971 let bytes = hex::decode(&s).expect("random_string() output must be valid hex");
2972 assert_eq!(
2973 bytes.len(),
2974 16,
2975 "hex::decode returned {} bytes",
2976 bytes.len()
2977 );
2978 }
2979
2980 #[test]
2981 fn two_calls_are_different() {
2982 let a = random_string();
2983 let b = random_string();
2984 assert_ne!(
2985 a, b,
2986 "Two calls to random_string() returned the same value: {a}"
2987 );
2988 }
2989 }
2990
2991 mod random_integer {
2992 use crate::common::utils::random_integer;
2993
2994 #[test]
2995 fn is_within_u32_range() {
2996 let n = random_integer();
2997 assert!(
2998 n <= u32::MAX,
2999 "random_integer() should be <= u32::MAX, got {n}"
3000 );
3001 }
3002
3003 #[test]
3004 fn two_calls_can_differ() {
3005 let a = random_integer();
3006 let b = random_integer();
3007 assert_ne!(
3008 a, b,
3009 "Two calls to random_integer() returned the same value: {a}"
3010 );
3011 }
3012 }
3013
3014 mod normalize_stream_id {
3015 use crate::common::utils::{StreamId, normalize_stream_id};
3016 use serde_json::Value;
3017
3018 fn is_lower_hex32(s: &str) -> bool {
3019 s.len() == 32 && s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f'))
3020 }
3021
3022 #[test]
3023 fn valid_hex_string_is_kept() {
3024 let id = "0123456789abcdef0123456789abcdef".to_string();
3025 let out = normalize_stream_id(Some(StreamId::Str(id.clone())), false);
3026
3027 match out {
3028 Value::String(s) => assert_eq!(s, id, "Expected to keep the valid hex id"),
3029 other => panic!("Expected Value::String, got {other:?}"),
3030 }
3031 }
3032
3033 #[test]
3034 fn invalid_hex_string_generates_random_hex() {
3035 let id = "not-hex".to_string();
3036 let out = normalize_stream_id(Some(StreamId::Str(id.clone())), false);
3037
3038 match out {
3039 Value::String(s) => {
3040 assert_eq!(s.len(), 32, "Expected 32-char hex, got {}", s.len());
3041 assert_ne!(s, id, "Expected generated id to differ from input");
3042 assert!(
3043 is_lower_hex32(&s),
3044 "Generated id contains invalid hex characters: {s}"
3045 );
3046 }
3047 other => panic!("Expected Value::String, got {other:?}"),
3048 }
3049 }
3050
3051 #[test]
3052 fn none_generates_random_hex() {
3053 let out = normalize_stream_id(None, false);
3054
3055 match out {
3056 Value::String(s) => {
3057 assert_eq!(s.len(), 32, "Expected 32-char hex, got {}", s.len());
3058 assert!(
3059 is_lower_hex32(&s),
3060 "Generated id contains invalid hex characters: {s}"
3061 );
3062 }
3063 other => panic!("Expected Value::String, got {other:?}"),
3064 }
3065 }
3066
3067 #[test]
3068 fn number_is_kept_when_not_strict() {
3069 let out = normalize_stream_id(Some(StreamId::Number(42)), false);
3070
3071 match out {
3072 Value::Number(n) => {
3073 assert_eq!(n.as_u64(), Some(42), "Expected to keep the numeric id");
3074 }
3075 other => panic!("Expected Value::Number, got {other:?}"),
3076 }
3077 }
3078
3079 #[test]
3080 fn strict_number_forces_number_even_for_valid_hex_string() {
3081 let id = "0123456789abcdef0123456789abcdef".to_string();
3082 let out = normalize_stream_id(Some(StreamId::Str(id)), true);
3083
3084 match out {
3085 Value::Number(n) => {
3086 assert!(
3087 n.as_u64().is_some(),
3088 "Expected unsigned integer JSON number, got {n}"
3089 );
3090 }
3091 other => panic!("Expected Value::Number, got {other:?}"),
3092 }
3093 }
3094
3095 #[test]
3096 fn strict_number_keeps_number_if_provided() {
3097 let out = normalize_stream_id(Some(StreamId::Number(7)), true);
3098
3099 match out {
3100 Value::Number(n) => {
3101 assert_eq!(n.as_u64(), Some(7), "Expected to keep the numeric id");
3102 }
3103 other => panic!("Expected Value::Number, got {other:?}"),
3104 }
3105 }
3106
3107 #[test]
3108 fn strict_number_generates_number_when_none() {
3109 let out = normalize_stream_id(None, true);
3110
3111 match out {
3112 Value::Number(n) => {
3113 assert!(
3114 n.as_u64().is_some(),
3115 "Expected unsigned integer JSON number, got {n}"
3116 );
3117 }
3118 other => panic!("Expected Value::Number, got {other:?}"),
3119 }
3120 }
3121
3122 #[test]
3123 fn strict_number_generates_number_for_invalid_hex_string() {
3124 let out = normalize_stream_id(Some(StreamId::Str("nope".to_string())), true);
3125
3126 match out {
3127 Value::Number(n) => {
3128 assert!(
3129 n.as_u64().is_some(),
3130 "Expected unsigned integer JSON number, got {n}"
3131 );
3132 }
3133 other => panic!("Expected Value::Number, got {other:?}"),
3134 }
3135 }
3136 }
3137 mod remove_empty_value {
3138 use crate::common::utils::remove_empty_value;
3139 use serde_json::{Map, Value};
3140
3141 #[test]
3142 fn filters_out_null_and_empty_strings() {
3143 let entries = vec![
3144 ("key1".to_string(), Value::String("value1".to_string())),
3145 ("key2".to_string(), Value::Null),
3146 ("key3".to_string(), Value::String(String::new())),
3147 ];
3148 let result = remove_empty_value(entries);
3149 assert_eq!(
3150 result.len(),
3151 1,
3152 "expected only one entry, got {}",
3153 result.len()
3154 );
3155 assert_eq!(
3156 result.get("key1"),
3157 Some(&Value::String("value1".to_string()))
3158 );
3159 assert!(!result.contains_key("key2"));
3160 assert!(!result.contains_key("key3"));
3161 }
3162
3163 #[test]
3164 fn retains_other_value_types() {
3165 let entries = vec![
3166 ("bool".to_string(), Value::Bool(true)),
3167 ("num".to_string(), Value::Number(42.into())),
3168 ("arr".to_string(), Value::Array(vec![])),
3169 ("obj".to_string(), Value::Object(Map::default())),
3170 ("nil".to_string(), Value::Null),
3171 ("empty_str".to_string(), Value::String(String::new())),
3172 ];
3173 let result = remove_empty_value(entries);
3174 let keys: Vec<&String> = result.keys().collect();
3175 assert_eq!(keys.len(), 4, "expected 4 entries, got {}", keys.len());
3176 assert!(result.get("bool") == Some(&Value::Bool(true)));
3177 assert!(result.get("num") == Some(&Value::Number(42.into())));
3178 assert!(result.get("arr") == Some(&Value::Array(vec![])));
3179 assert!(result.get("obj") == Some(&Value::Object(Map::default())));
3180 assert!(!result.contains_key("nil"));
3181 assert!(!result.contains_key("empty_str"));
3182 }
3183
3184 #[test]
3185 fn empty_iterator_returns_empty_map() {
3186 let entries: Vec<(String, Value)> = vec![];
3187 let result = remove_empty_value(entries);
3188 assert!(result.is_empty(), "expected an empty map");
3189 }
3190
3191 #[test]
3192 fn keys_are_sorted() {
3193 let entries = vec![
3194 ("c".to_string(), Value::String("foo".to_string())),
3195 ("a".to_string(), Value::String("bar".to_string())),
3196 ("b".to_string(), Value::String("baz".to_string())),
3197 ];
3198 let result = remove_empty_value(entries);
3199 let sorted_keys: Vec<&String> = result.keys().collect();
3200 assert_eq!(
3201 sorted_keys,
3202 [&"a".to_string(), &"b".to_string(), &"c".to_string()]
3203 );
3204 }
3205 }
3206
3207 mod sort_object_params {
3208 use crate::common::utils::sort_object_params;
3209 use serde_json::Value;
3210 use std::collections::BTreeMap;
3211
3212 #[test]
3213 fn sorts_keys() {
3214 let mut params = BTreeMap::new();
3215 params.insert("z".to_string(), Value::String("last".to_string()));
3216 params.insert("a".to_string(), Value::String("first".to_string()));
3217 params.insert("m".to_string(), Value::String("middle".to_string()));
3218
3219 let sorted = sort_object_params(¶ms);
3220 let keys: Vec<&String> = sorted.keys().collect();
3221 assert_eq!(
3222 keys,
3223 [&"a".to_string(), &"m".to_string(), &"z".to_string()],
3224 "Keys should be sorted alphabetically"
3225 );
3226 }
3227
3228 #[test]
3229 fn preserves_values() {
3230 let mut params = BTreeMap::new();
3231 params.insert("one".to_string(), Value::Number(1.into()));
3232 params.insert("two".to_string(), Value::Bool(true));
3233
3234 let sorted = sort_object_params(¶ms);
3235 assert_eq!(sorted.get("one"), Some(&Value::Number(1.into())));
3236 assert_eq!(sorted.get("two"), Some(&Value::Bool(true)));
3237 }
3238
3239 #[test]
3240 fn empty_map_returns_empty() {
3241 let params: BTreeMap<String, Value> = BTreeMap::new();
3242 let sorted = sort_object_params(¶ms);
3243 assert!(sorted.is_empty(), "Expected empty map");
3244 }
3245
3246 #[test]
3247 fn independent_clone() {
3248 let mut params = BTreeMap::new();
3249 params.insert("key".to_string(), Value::String("val".to_string()));
3250
3251 let mut sorted = sort_object_params(¶ms);
3252 sorted.insert("new".to_string(), Value::String("x".to_string()));
3253
3254 assert!(
3255 !params.contains_key("new"),
3256 "Original should not be modified when changing sorted"
3257 );
3258 assert!(
3259 sorted.contains_key("new"),
3260 "Sorted map should reflect its own insertions"
3261 );
3262 }
3263 }
3264
3265 mod normalize_ws_streams_key {
3266 use crate::common::utils::normalize_ws_streams_key;
3267
3268 #[test]
3269 fn returns_empty_for_empty() {
3270 assert_eq!(normalize_ws_streams_key(""), "");
3271 }
3272
3273 #[test]
3274 fn already_normalized_stays_same() {
3275 assert_eq!(normalize_ws_streams_key("streamname"), "streamname");
3276 }
3277
3278 #[test]
3279 fn uppercases_are_lowercased() {
3280 assert_eq!(normalize_ws_streams_key("MyStream"), "mystream");
3281 }
3282
3283 #[test]
3284 fn underscores_are_removed() {
3285 assert_eq!(normalize_ws_streams_key("my_stream_name"), "mystreamname");
3286 }
3287
3288 #[test]
3289 fn hyphens_are_removed() {
3290 assert_eq!(normalize_ws_streams_key("my-stream-name"), "mystreamname");
3291 }
3292
3293 #[test]
3294 fn mixed_underscores_and_hyphens_and_case() {
3295 let input = "Mixed_Case-Stream_Name";
3296 let expected = "mixedcasestreamname";
3297 assert_eq!(normalize_ws_streams_key(input), expected);
3298 }
3299
3300 #[test]
3301 fn retains_other_punctuation() {
3302 assert_eq!(normalize_ws_streams_key("stream.name!"), "stream.name!");
3303 }
3304 }
3305
3306 mod replace_websocket_streams_placeholders {
3307 use crate::common::utils::replace_websocket_streams_placeholders;
3308 use std::collections::HashMap;
3309
3310 #[test]
3311 fn empty_string_unchanged() {
3312 let vars: HashMap<&str, &str> = HashMap::new();
3313 assert_eq!(replace_websocket_streams_placeholders("", &vars), "");
3314 }
3315
3316 #[test]
3317 fn unknown_placeholder_becomes_empty() {
3318 let vars: HashMap<&str, &str> = HashMap::new();
3319 assert_eq!(replace_websocket_streams_placeholders("<foo>", &vars), "");
3320 }
3321
3322 #[test]
3323 fn leading_slash_symbol_lowercases_head() {
3324 let mut vars = HashMap::new();
3325 vars.insert("symbol", "BTC");
3326 assert_eq!(
3327 replace_websocket_streams_placeholders("/<symbol>", &vars),
3328 "btc"
3329 );
3330 }
3331
3332 #[test]
3333 fn no_lowercase_without_slash() {
3334 let mut vars = HashMap::new();
3335 vars.insert("symbol", "BTC");
3336 assert_eq!(
3337 replace_websocket_streams_placeholders("<symbol>", &vars),
3338 "BTC"
3339 );
3340 }
3341
3342 #[test]
3343 fn multiple_placeholders_mid_preserve_ats() {
3344 let mut vars = HashMap::new();
3345 vars.insert("symbol", "BNBUSDT");
3346 vars.insert("levels", "10");
3347 vars.insert("updateSpeed", "1000ms");
3348 let out = replace_websocket_streams_placeholders(
3349 "/<symbol>@depth<levels>@<updateSpeed>",
3350 &vars,
3351 );
3352 assert_eq!(out, "bnbusdt@depth10@1000ms");
3353 }
3354
3355 #[test]
3356 fn trailing_at_removed_when_missing_var() {
3357 let mut vars = HashMap::new();
3358 vars.insert("symbol", "BNBUSDT");
3359 vars.insert("levels", "10");
3360 let out = replace_websocket_streams_placeholders(
3361 "/<symbol>@depth<levels>@<updateSpeed>",
3362 &vars,
3363 );
3364 assert_eq!(out, "bnbusdt@depth10");
3365 }
3366
3367 #[test]
3368 fn custom_key_normalization_and_value() {
3369 let mut vars = HashMap::new();
3370 vars.insert("my-stream_key", "Value");
3371 assert_eq!(
3372 replace_websocket_streams_placeholders("<My_Stream-Key>", &vars),
3373 "Value"
3374 );
3375 }
3376
3377 #[test]
3378 fn text_surrounding_placeholders_intact() {
3379 let mut vars = HashMap::new();
3380 vars.insert("symbol", "ABC");
3381 let input = "pre-<symbol>-post";
3382 assert_eq!(
3383 replace_websocket_streams_placeholders(input, &vars),
3384 "pre-ABC-post"
3385 );
3386 }
3387 }
3388
3389 mod build_websocket_api_message {
3390 use serde_json::{Value, json};
3391 use std::collections::BTreeMap;
3392
3393 use crate::{
3394 common::{
3395 utils::{ID_REGEX, build_websocket_api_message, remove_empty_value},
3396 websocket::WebsocketMessageSendOptions,
3397 },
3398 config::ConfigurationWebsocketApi,
3399 };
3400
3401 fn make_config() -> ConfigurationWebsocketApi {
3402 ConfigurationWebsocketApi::builder()
3403 .api_key("api-key".to_string())
3404 .api_secret("api-secret".to_string())
3405 .build()
3406 .unwrap()
3407 }
3408
3409 #[test]
3410 fn no_auth_or_sign_with_skip_auth() {
3411 let mut payload = BTreeMap::new();
3412 payload.insert("foo".into(), Value::String("bar".into()));
3413 let cfg = make_config();
3414
3415 let (id, req) = build_websocket_api_message(
3416 &cfg,
3417 "method",
3418 payload.clone(),
3419 &WebsocketMessageSendOptions {
3420 with_api_key: true,
3421 is_signed: true,
3422 ..Default::default()
3423 },
3424 true,
3425 );
3426
3427 assert!(ID_REGEX.is_match(&id));
3428 assert_eq!(req["method"], "method");
3429 assert_eq!(req["params"]["foo"], "bar");
3430 assert!(req["params"].get("apiKey").is_none());
3431 assert!(req["params"].get("signature").is_none());
3432 assert!(req["params"]["timestamp"].is_number());
3433 }
3434
3435 #[test]
3436 fn only_api_key_when_not_signed() {
3437 let cfg = make_config();
3438
3439 let (id, req) = build_websocket_api_message(
3440 &cfg,
3441 "method",
3442 BTreeMap::new(),
3443 &WebsocketMessageSendOptions {
3444 with_api_key: true,
3445 is_signed: false,
3446 ..Default::default()
3447 },
3448 false,
3449 );
3450
3451 assert!(ID_REGEX.is_match(&id));
3452 assert_eq!(req["method"], "method");
3453 assert_eq!(req["params"]["apiKey"], "api-key");
3454 assert!(req["params"].get("timestamp").is_none());
3455 assert!(req["params"].get("signature").is_none());
3456 }
3457
3458 #[test]
3459 fn signed_includes_timestamp_and_signature() {
3460 let mut payload = BTreeMap::new();
3461 payload.insert("foo".into(), Value::String("bar".into()));
3462 let cfg = make_config();
3463
3464 let (id, req) = build_websocket_api_message(
3465 &cfg,
3466 "method",
3467 payload.clone(),
3468 &WebsocketMessageSendOptions {
3469 with_api_key: true,
3470 is_signed: true,
3471 ..Default::default()
3472 },
3473 false,
3474 );
3475
3476 assert!(ID_REGEX.is_match(&id));
3477 assert_eq!(req["method"], "method");
3478
3479 let params = &req["params"];
3480 assert_eq!(params["apiKey"], "api-key");
3481
3482 let timestamp = params["timestamp"].as_i64().unwrap();
3483 assert!(timestamp > 0, "timestamp should not be empty");
3484
3485 let sig = params["signature"].as_str().unwrap();
3486 assert!(!sig.is_empty(), "signature should not be empty");
3487 }
3488
3489 #[test]
3490 fn respects_provided_valid_id_and_removes_from_params() {
3491 let mut payload = BTreeMap::new();
3492 let custom = "0123456789abcdef0123456789abcdef".to_string();
3493 payload.insert("id".into(), Value::String(custom.clone()));
3494 payload.insert("foo".into(), Value::Number(123.into()));
3495
3496 let cfg = make_config();
3497 let (id, req) = build_websocket_api_message(
3498 &cfg,
3499 "method",
3500 payload.clone(),
3501 &WebsocketMessageSendOptions::default(),
3502 true,
3503 );
3504
3505 assert_eq!(id, custom);
3506 assert!(req["params"].get("id").is_none());
3507 assert_eq!(req["params"]["foo"], 123);
3508 }
3509
3510 #[test]
3511 fn skip_auth_blocks_api_and_signature_but_keeps_timestamp() {
3512 let mut payload = BTreeMap::new();
3513 payload.insert("foo".into(), Value::String("bar".into()));
3514 let cfg = make_config();
3515
3516 let (_id, req) = build_websocket_api_message(
3517 &cfg,
3518 "method",
3519 payload.clone(),
3520 &WebsocketMessageSendOptions {
3521 with_api_key: true,
3522 is_signed: true,
3523 ..Default::default()
3524 },
3525 true,
3526 );
3527
3528 let p = &req["params"];
3529 assert_eq!(p["foo"], "bar");
3530 assert!(p.get("apiKey").is_none());
3531 assert!(p.get("signature").is_none());
3532 assert!(p["timestamp"].is_number());
3533 }
3534
3535 #[test]
3536 fn random_id_changes_each_call() {
3537 let cfg = make_config();
3538 let (id1, _) = build_websocket_api_message(
3539 &cfg,
3540 "method",
3541 BTreeMap::new(),
3542 &WebsocketMessageSendOptions::default(),
3543 true,
3544 );
3545 let (id2, _) = build_websocket_api_message(
3546 &cfg,
3547 "method",
3548 BTreeMap::new(),
3549 &WebsocketMessageSendOptions::default(),
3550 true,
3551 );
3552 assert!(ID_REGEX.is_match(&id1));
3553 assert!(ID_REGEX.is_match(&id2));
3554 assert_ne!(id1, id2, "IDs should be random and not equal");
3555 }
3556
3557 #[test]
3558 fn null_and_empty_values_are_stripped() {
3559 let mut payload = BTreeMap::new();
3560 payload.insert("a".into(), Value::Null);
3561 payload.insert("b".into(), Value::String(String::new()));
3562 payload.insert("c".into(), Value::String("ok".into()));
3563
3564 let cleaned = remove_empty_value(payload.clone());
3565 assert!(!cleaned.contains_key("a"), "Null should be stripped");
3566 assert!(
3567 !cleaned.contains_key("b"),
3568 "Empty string should be stripped"
3569 );
3570 assert!(cleaned.contains_key("c"), "Non-empty string should be kept");
3571
3572 let cfg = make_config();
3573 let (_id, req) = build_websocket_api_message(
3574 &cfg,
3575 "method",
3576 payload,
3577 &WebsocketMessageSendOptions::default(),
3578 true,
3579 );
3580 let params = &req["params"];
3581 assert!(params.get("a").is_none(), "`a` should not appear");
3582 assert!(params.get("b").is_none(), "`b` should not appear");
3583 assert_eq!(params["c"], "ok", "`c` should be present with value \"ok\"");
3584 }
3585
3586 #[test]
3587 fn provided_invalid_id_gets_replaced() {
3588 let mut payload = BTreeMap::new();
3589 payload.insert("id".into(), Value::String("not-hex-32-chars".into()));
3590 let cfg = make_config();
3591 let (id, _req) = build_websocket_api_message(
3592 &cfg,
3593 "method",
3594 payload,
3595 &WebsocketMessageSendOptions::default(),
3596 true,
3597 );
3598
3599 assert!(ID_REGEX.is_match(&id));
3600 assert_ne!(id, "not-hex-32-chars");
3601 }
3602
3603 #[test]
3604 fn sign_only_includes_api_key_even_when_with_api_key_false() {
3605 let mut payload = BTreeMap::new();
3606 payload.insert("x".into(), json!(1));
3607
3608 let cfg = make_config();
3609 let (_id, req) = build_websocket_api_message(
3610 &cfg,
3611 "method",
3612 payload,
3613 &WebsocketMessageSendOptions {
3614 with_api_key: false,
3615 is_signed: true,
3616 ..Default::default()
3617 },
3618 false,
3619 );
3620 let params = &req["params"];
3621
3622 assert_eq!(params["apiKey"], "api-key");
3623 assert!(params["timestamp"].is_number());
3624 assert!(params["signature"].is_string());
3625 }
3626
3627 #[test]
3628 fn skip_auth_false_without_any_auth_flags() {
3629 let cfg = make_config();
3630 let (_id, req) = build_websocket_api_message(
3631 &cfg,
3632 "method",
3633 BTreeMap::new(),
3634 &WebsocketMessageSendOptions {
3635 with_api_key: false,
3636 is_signed: false,
3637 ..Default::default()
3638 },
3639 false,
3640 );
3641 let params = &req["params"];
3642 assert!(params.as_object().unwrap().is_empty());
3643 }
3644 }
3645}