1use anyhow::{Context, Result};
2#[cfg(feature = "openssl-tls")]
3use base64::{Engine as _, engine::general_purpose};
4#[cfg(feature = "openssl-tls")]
5use ed25519_dalek::Signer as Ed25519Signer;
6#[cfg(feature = "openssl-tls")]
7use ed25519_dalek::SigningKey;
8#[cfg(feature = "openssl-tls")]
9use ed25519_dalek::pkcs8::DecodePrivateKey;
10use flate2::read::GzDecoder;
11use hex;
12use hmac::{Hmac, Mac};
13use http::HeaderMap;
14use http::HeaderValue;
15use http::header::ACCEPT_ENCODING;
16use once_cell::sync::OnceCell;
17#[cfg(feature = "openssl-tls")]
18use openssl::{hash::MessageDigest, pkey::PKey, sign::Signer as OpenSslSigner};
19use rand::{RngCore, rngs::OsRng};
20use regex::Captures;
21use regex::Regex;
22use reqwest::Client;
23use reqwest::Proxy;
24use reqwest::{Method, Request};
25use serde::de::DeserializeOwned;
26use serde_json::Number;
27use serde_json::{Value, json};
28use sha2::Sha256;
29use std::fmt;
30use std::fmt::Display;
31use std::hash::BuildHasher;
32use std::sync::LazyLock;
33use std::{
34 collections::BTreeMap,
35 collections::HashMap,
36 io::Read,
37 time::Duration,
38 time::{SystemTime, UNIX_EPOCH},
39};
40#[cfg(feature = "openssl-tls")]
41use std::{fs, path::Path};
42use tokio::time::sleep;
43use url::form_urlencoded;
44use url::{Url, form_urlencoded::Serializer};
45
46use super::config::{
47 ConfigurationRestApi, ConfigurationWebsocketApi, HttpAgent, PrivateKey, ProxyConfig,
48};
49use super::errors::ConnectorError;
50use super::models::{
51 Interval, RateLimitType, RestApiRateLimit, RestApiResponse, StreamId, TimeUnit,
52};
53use super::websocket::WebsocketMessageSendOptions;
54
/// Accepts a caller-supplied request/stream id only when it is exactly 32
/// lowercase hexadecimal characters (the same shape `random_string` produces).
pub(crate) static ID_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[0-9a-f]{32}$").unwrap());
/// Matches a `<name>` placeholder in a websocket stream template, optionally
/// preceded by `@`. Capture group 1 holds the `@` (when present) and group 2
/// holds the placeholder name.
static PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(@)?<([^>]+)>").unwrap());
58
/// Produces request signatures from either an HMAC secret or a private key.
///
/// `get_signature` uses HMAC-SHA256 keyed with `api_secret` when no
/// `private_key` is configured. Otherwise it signs with the private key (RSA or
/// Ed25519), which requires the `openssl-tls` feature. Loaded/parsed key
/// material is cached in the `OnceCell` fields after the first successful load.
#[derive(Default, Clone)]
// NOTE(review): the reason for this allow is not evident from this file —
// possibly a leftover for builds without `openssl-tls`; confirm or remove.
#[allow(dead_code)]
pub struct SignatureGenerator {
    /// Secret used as the HMAC-SHA256 key.
    api_secret: Option<String>,
    /// Private key, either a path to a PEM file or raw PEM bytes.
    private_key: Option<PrivateKey>,
    /// Optional passphrase used when parsing an encrypted PEM private key.
    private_key_passphrase: Option<String>,
    /// PEM text of `private_key`, loaded on first use.
    raw_key_data: OnceCell<String>,
    /// OpenSSL key object parsed from `raw_key_data`, created on first use.
    #[cfg(feature = "openssl-tls")]
    key_object: OnceCell<PKey<openssl::pkey::Private>>,
    /// Ed25519 signing key derived from `key_object`, only populated for Ed25519 keys.
    #[cfg(feature = "openssl-tls")]
    ed25519_signing_key: OnceCell<SigningKey>,
}
85
86impl fmt::Debug for SignatureGenerator {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("SignatureGenerator")
89 .field(
90 "api_secret",
91 &self.api_secret.as_ref().map(|_| "[REDACTED]"),
92 )
93 .field(
94 "private_key",
95 &self.private_key.as_ref().map(|_| "[REDACTED]"),
96 )
97 .field(
98 "private_key_passphrase",
99 &self.private_key_passphrase.as_ref().map(|_| "[REDACTED]"),
100 )
101 .field(
102 "raw_key_data",
103 &self.raw_key_data.get().map(|_| "[REDACTED]"),
104 )
105 .field("key_object", &"[REDACTED]")
106 .field("ed25519_signing_key", &"[REDACTED]")
107 .finish()
108 }
109}
110
111impl SignatureGenerator {
112 #[must_use]
113 pub fn new(
114 api_secret: Option<String>,
115 private_key: Option<PrivateKey>,
116 private_key_passphrase: Option<String>,
117 ) -> Self {
118 SignatureGenerator {
119 api_secret,
120 private_key,
121 private_key_passphrase,
122 raw_key_data: OnceCell::new(),
123 #[cfg(feature = "openssl-tls")]
124 key_object: OnceCell::new(),
125 #[cfg(feature = "openssl-tls")]
126 ed25519_signing_key: OnceCell::new(),
127 }
128 }
129
130 #[cfg(feature = "openssl-tls")]
145 fn get_raw_key_data(&self) -> Result<&String> {
146 self.raw_key_data.get_or_try_init(|| {
147 let pk = self
148 .private_key
149 .as_ref()
150 .ok_or_else(|| anyhow::anyhow!("No private_key provided"))?;
151 match pk {
152 PrivateKey::File(path) => {
153 if Path::new(path).exists() {
154 fs::read_to_string(path)
155 .with_context(|| format!("Failed to read private key file: {path}"))
156 } else {
157 Err(anyhow::anyhow!("Private key file does not exist: {}", path))
158 }
159 }
160 PrivateKey::Raw(bytes) => Ok(String::from_utf8_lossy(bytes).to_string()),
161 }
162 })
163 }
164
165 #[cfg(feature = "openssl-tls")]
180 fn get_key_object(&self) -> Result<&PKey<openssl::pkey::Private>> {
181 self.key_object.get_or_try_init(|| {
182 let key_data = self.get_raw_key_data()?;
183 if let Some(pass) = self.private_key_passphrase.as_ref() {
184 PKey::private_key_from_pem_passphrase(key_data.as_bytes(), pass.as_bytes())
185 .context("Failed to parse private key with passphrase")
186 } else {
187 PKey::private_key_from_pem(key_data.as_bytes())
188 .context("Failed to parse private key")
189 }
190 })
191 }
192
193 #[cfg(feature = "openssl-tls")]
206 fn get_ed25519_signing_key(
207 &self,
208 key_obj: &PKey<openssl::pkey::Private>,
209 ) -> Result<&SigningKey> {
210 self.ed25519_signing_key.get_or_try_init(|| {
211 let der = key_obj
212 .private_key_to_der()
213 .context("Failed to export Ed25519 key to DER")?;
214 SigningKey::from_pkcs8_der(&der)
215 .map_err(|e| anyhow::anyhow!("Failed to parse Ed25519 key: {}", e))
216 })
217 }
218
219 pub fn get_signature(
242 &self,
243 query_params: &BTreeMap<String, Value>,
244 body_params: Option<&BTreeMap<String, Value>>,
245 ) -> Result<String> {
246 let query_str = build_query_string(query_params)?;
247 let params = if let Some(body) = body_params {
248 if body.is_empty() {
249 query_str
250 } else {
251 let body_str = build_query_string(body)?;
252 format!("{query_str}{body_str}")
253 }
254 } else {
255 query_str
256 };
257
258 if self.private_key.is_none() {
259 if let Some(secret) = self.api_secret.as_ref() {
260 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
261 .context("HMAC key initialization failed")?;
262 mac.update(params.as_bytes());
263 let result = mac.finalize().into_bytes();
264 return Ok(hex::encode(result));
265 }
266 }
267
268 if self.private_key.is_some() {
269 #[cfg(feature = "openssl-tls")]
270 {
271 let key_obj = self.get_key_object()?;
272 match key_obj.id() {
273 openssl::pkey::Id::RSA => {
274 let mut signer = OpenSslSigner::new(MessageDigest::sha256(), key_obj)
275 .context("Failed to create RSA signer")?;
276 signer
277 .update(params.as_bytes())
278 .context("Failed to update RSA signer")?;
279 let sig = signer.sign_to_vec().context("RSA signing failed")?;
280 return Ok(general_purpose::STANDARD.encode(sig));
281 }
282 openssl::pkey::Id::ED25519 => {
283 let signing_key = self.get_ed25519_signing_key(key_obj)?;
284 let signature = signing_key.sign(params.as_bytes());
285 return Ok(general_purpose::STANDARD.encode(signature.to_bytes()));
286 }
287 other => {
288 return Err(anyhow::anyhow!(
289 "Unsupported private key type: {:?}. Must be RSA or ED25519.",
290 other
291 ));
292 }
293 }
294 }
295
296 #[cfg(not(feature = "openssl-tls"))]
297 {
298 return Err(anyhow::anyhow!(
299 "Private key signing requires the 'openssl-tls' feature to be enabled."
300 ));
301 }
302 }
303
304 Err(anyhow::anyhow!(
305 "Either 'api_secret' or 'private_key' must be provided for signed requests."
306 ))
307 }
308}
309
310#[must_use]
333pub fn build_client(
334 timeout: u64,
335 keep_alive: bool,
336 proxy: Option<&ProxyConfig>,
337 agent: Option<HttpAgent>,
338) -> Client {
339 let builder = Client::builder().timeout(Duration::from_millis(timeout));
340
341 let mut builder = if keep_alive {
342 builder
343 } else {
344 builder.pool_idle_timeout(Some(Duration::from_secs(0)))
345 };
346
347 if let Some(proxy_conf) = proxy {
348 let protocol = proxy_conf
349 .protocol
350 .clone()
351 .unwrap_or_else(|| "http".to_string());
352 let proxy_url = format!("{}://{}:{}", protocol, proxy_conf.host, proxy_conf.port);
353 let mut proxy_builder = Proxy::all(&proxy_url).expect("Failed to create proxy from URL");
354 if let Some(auth) = &proxy_conf.auth {
355 proxy_builder = proxy_builder.basic_auth(&auth.username, &auth.password);
356 }
357 builder = builder.proxy(proxy_builder);
358 }
359
360 if let Some(HttpAgent(agent_fn)) = agent {
361 builder = (agent_fn)(builder);
362 }
363
364 builder.build().expect("Failed to build reqwest client")
365}
366
367#[must_use]
390pub fn build_user_agent(product: &str) -> String {
391 format!(
392 "{}/{}/{} (Rust/{}; {}; {})",
393 env!("CARGO_PKG_NAME"),
394 product,
395 env!("CARGO_PKG_VERSION"),
396 env!("RUSTC_VERSION"),
397 std::env::consts::OS,
398 std::env::consts::ARCH,
399 )
400}
401
402pub fn validate_time_unit(time_unit: &str) -> Result<Option<&str>, anyhow::Error> {
430 match time_unit {
431 "" => Ok(None),
432 "MILLISECOND" | "MICROSECOND" | "millisecond" | "microsecond" => Ok(Some(time_unit)),
433 _ => Err(anyhow::anyhow!(
434 "time_unit must be either 'MILLISECOND' or 'MICROSECOND'"
435 )),
436 }
437}
438
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics with "Time went backwards" if the system clock is set before the epoch.
#[must_use]
pub fn get_timestamp() -> u128 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(UNIX_EPOCH).expect("Time went backwards");
    since_epoch.as_millis()
}
462
463pub async fn delay(ms: u64) {
475 sleep(Duration::from_millis(ms)).await;
476}
477
478pub fn build_query_string(params: &BTreeMap<String, Value>) -> Result<String, anyhow::Error> {
497 let mut segments = Vec::with_capacity(params.len());
498
499 for (key, value) in params {
500 if value.is_null() {
501 continue;
502 }
503
504 let value_str = match value {
505 Value::String(s) => s.clone(),
506 Value::Bool(b) => b.to_string(),
507 Value::Number(n) => n.to_string(),
508 Value::Array(_) | Value::Object(_) => serde_json::to_string(value)
509 .with_context(|| format!("failed to JSON-serialize `{}`", key))?,
510 Value::Null => unreachable!(),
511 };
512
513 let mut ser = Serializer::new(String::new());
514 ser.append_pair(key, &value_str);
515 segments.push(ser.finish());
516 }
517
518 Ok(segments.join("&"))
519}
520
521#[must_use]
529pub fn should_retry_request(
530 error: &reqwest::Error,
531 method: Option<&str>,
532 retries_left: Option<usize>,
533) -> bool {
534 let method = method.unwrap_or("");
535 let is_retriable_method =
536 method.eq_ignore_ascii_case("GET") || method.eq_ignore_ascii_case("DELETE");
537
538 let status = error.status().map_or(0, |s| s.as_u16());
539 let is_retriable_status = [500, 502, 503, 504].contains(&status);
540
541 let retries_left = retries_left.unwrap_or(0);
542 retries_left > 0 && is_retriable_method && (is_retriable_status || error.status().is_none())
543}
544
545#[must_use]
570pub fn parse_rate_limit_headers<S>(headers: &HashMap<String, String, S>) -> Vec<RestApiRateLimit>
571where
572 S: BuildHasher,
573{
574 let mut rate_limits = Vec::new();
575 let re = Regex::new(r"x-mbx-(used-weight|order-count)-(\d+)([smhd])").unwrap();
576 for (key, value) in headers {
577 let normalized_key = key.to_lowercase();
578 if normalized_key.starts_with("x-mbx-used-weight-")
579 || normalized_key.starts_with("x-mbx-order-count-")
580 {
581 if let Some(caps) = re.captures(&normalized_key) {
582 let interval_num: u32 = caps.get(2).unwrap().as_str().parse().unwrap_or(0);
583 let interval_letter = caps.get(3).unwrap().as_str().to_uppercase();
584 let interval = match interval_letter.as_str() {
585 "S" => Interval::Second,
586 "M" => Interval::Minute,
587 "H" => Interval::Hour,
588 "D" => Interval::Day,
589 _ => continue,
590 };
591 let count: u32 = value.parse().unwrap_or(0);
592 let rate_limit_type = if normalized_key.starts_with("x-mbx-used-weight-") {
593 RateLimitType::RequestWeight
594 } else {
595 RateLimitType::Orders
596 };
597
598 rate_limits.push(RestApiRateLimit {
599 rate_limit_type,
600 interval,
601 interval_num,
602 count,
603 retry_after: headers.get("retry-after").and_then(|v| v.parse().ok()),
604 });
605 }
606 }
607 }
608 rate_limits
609}
610
611pub async fn http_request<T: DeserializeOwned + Send + 'static>(
641 req: Request,
642 configuration: &ConfigurationRestApi,
643) -> Result<RestApiResponse<T>, ConnectorError> {
644 let client = &configuration.client;
645 let retries = configuration.retries as usize;
646 let backoff = configuration.backoff;
647 let mut attempt = 0;
648
649 loop {
650 let req_clone = req
651 .try_clone()
652 .context("Failed to clone request")
653 .map_err(|e| ConnectorError::ConnectorClientError {
654 msg: e.to_string(),
655 code: None,
656 })?;
657 match client.execute(req_clone).await {
658 Ok(response) => {
659 let status = response.status();
660 let headers_map: HashMap<String, String> = response
661 .headers()
662 .iter()
663 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
664 .collect();
665
666 let raw_bytes = match response.bytes().await {
667 Ok(b) => b,
668 Err(e) => {
669 attempt += 1;
670 if attempt <= retries {
671 continue;
672 }
673 return Err(ConnectorError::ConnectorClientError {
674 msg: format!("Failed to get response bytes: {e}"),
675 code: None,
676 });
677 }
678 };
679
680 let content = if headers_map
681 .get("content-encoding")
682 .is_some_and(|enc| enc.to_lowercase().contains("gzip"))
683 {
684 let mut decoder = GzDecoder::new(&raw_bytes[..]);
685 let mut decompressed = String::new();
686 decoder
687 .read_to_string(&mut decompressed)
688 .context("Failed to decompress gzip response")
689 .map_err(|e: anyhow::Error| ConnectorError::ConnectorClientError {
690 msg: e.to_string(),
691 code: None,
692 })?;
693 decompressed
694 } else {
695 String::from_utf8(raw_bytes.to_vec())
696 .context("Failed to convert response to UTF-8")
697 .map_err(|e| ConnectorError::ConnectorClientError {
698 msg: e.to_string(),
699 code: None,
700 })?
701 };
702
703 let rate_limits = parse_rate_limit_headers(&headers_map);
704
705 if status.is_client_error() || status.is_server_error() {
706 let mut err_msg = content.clone();
707 let mut err_code: Option<i64> = None;
708
709 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&content) {
710 if let Some(m) = v.get("msg").and_then(|m| m.as_str()) {
711 err_msg = m.to_string();
712 }
713 err_code = v.get("code").and_then(serde_json::Value::as_i64);
714 }
715
716 match status.as_u16() {
717 400 => {
718 return Err(ConnectorError::BadRequestError {
719 msg: err_msg,
720 code: err_code,
721 });
722 }
723 401 => {
724 return Err(ConnectorError::UnauthorizedError {
725 msg: err_msg,
726 code: err_code,
727 });
728 }
729 403 => {
730 return Err(ConnectorError::ForbiddenError {
731 msg: err_msg,
732 code: err_code,
733 });
734 }
735 404 => {
736 return Err(ConnectorError::NotFoundError {
737 msg: err_msg,
738 code: err_code,
739 });
740 }
741 418 => {
742 return Err(ConnectorError::RateLimitBanError {
743 msg: err_msg,
744 code: err_code,
745 });
746 }
747 429 => {
748 return Err(ConnectorError::TooManyRequestsError {
749 msg: err_msg,
750 code: err_code,
751 });
752 }
753 s if (500..600).contains(&s) => {
754 return Err(ConnectorError::ServerError {
755 msg: format!("Server error: {s}"),
756 status_code: Some(s),
757 });
758 }
759 _ => {
760 return Err(ConnectorError::ConnectorClientError {
761 msg: err_msg,
762 code: err_code,
763 });
764 }
765 }
766 }
767
768 let raw = content.clone();
769 return Ok(RestApiResponse {
770 data_fn: Box::new(move || {
771 Box::pin(async move {
772 let parsed: T = serde_json::from_str(&raw).map_err(|e| {
773 ConnectorError::ConnectorClientError {
774 msg: e.to_string(),
775 code: None,
776 }
777 })?;
778 Ok(parsed)
779 })
780 }),
781 status: status.as_u16(),
782 headers: headers_map,
783 rate_limits: if rate_limits.is_empty() {
784 None
785 } else {
786 Some(rate_limits)
787 },
788 });
789 }
790 Err(e) => {
791 attempt += 1;
792 if should_retry_request(&e, Some(req.method().as_str()), Some(retries - attempt)) {
793 delay(backoff * attempt as u64).await;
794 continue;
795 }
796 return Err(ConnectorError::ConnectorClientError {
797 msg: format!("HTTP request failed: {e}"),
798 code: None,
799 });
800 }
801 }
802 }
803}
804
805pub async fn send_request<T: DeserializeOwned + Send + 'static>(
832 configuration: &ConfigurationRestApi,
833 endpoint: &str,
834 method: Method,
835 mut query_params: BTreeMap<String, Value>,
836 body_params: BTreeMap<String, Value>,
837 time_unit: Option<TimeUnit>,
838 is_signed: bool,
839) -> anyhow::Result<RestApiResponse<T>> {
840 let base = configuration.base_path.as_deref().unwrap_or("");
841 let full_url = reqwest::Url::parse(base)
842 .and_then(|u| u.join(endpoint))
843 .context("Failed to join base URL and endpoint")?
844 .to_string();
845
846 if is_signed {
847 let timestamp = get_timestamp();
848 query_params.insert("timestamp".to_string(), json!(timestamp));
849 }
850
851 let signature = if is_signed {
852 let body_ref = if body_params.is_empty() {
853 None
854 } else {
855 Some(&body_params)
856 };
857 Some(
858 configuration
859 .signature_gen
860 .get_signature(&query_params, body_ref)?,
861 )
862 } else {
863 None
864 };
865
866 let mut url = Url::parse(&full_url)?;
867 {
868 let mut pairs = url.query_pairs_mut();
869 for (key, value) in &query_params {
870 let val_str = match value {
871 Value::String(s) => s.clone(),
872 _ => value.to_string(),
873 };
874 pairs.append_pair(key, &val_str);
875 }
876 if let Some(signature) = &signature {
877 pairs.append_pair("signature", signature);
878 }
879 }
880
881 let mut headers = HeaderMap::new();
882
883 let forbidden = ["host", "authorization", "cookie", ":method", ":path"]
884 .into_iter()
885 .map(str::to_ascii_lowercase)
886 .collect::<std::collections::HashSet<_>>();
887
888 if let Some(custom) = &configuration.custom_headers {
889 for (raw_name, raw_val) in custom {
890 let name = raw_name.trim();
891 if forbidden.contains(&name.to_ascii_lowercase()) {
892 continue;
893 }
894 if let (Ok(header_name), Ok(header_val)) = (
895 name.parse::<reqwest::header::HeaderName>(),
896 HeaderValue::from_str(raw_val),
897 ) {
898 headers.append(header_name, header_val);
899 }
900 }
901 }
902
903 if body_params.is_empty() {
904 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
905 } else {
906 headers.insert(
907 "Content-Type",
908 HeaderValue::from_static("application/x-www-form-urlencoded"),
909 );
910 }
911
912 headers.insert("User-Agent", configuration.user_agent.parse().unwrap());
913 if let Some(api_key) = &configuration.api_key {
914 headers.insert("X-MBX-APIKEY", HeaderValue::from_str(api_key)?);
915 }
916
917 if configuration.compression {
918 headers.insert(ACCEPT_ENCODING, "gzip, deflate, br".parse().unwrap());
919 }
920
921 let time_unit_to_apply = time_unit.or(configuration.time_unit);
922 if let Some(time_unit) = time_unit_to_apply {
923 headers.insert("X-MBX-TIME-UNIT", time_unit.as_upper_str().parse()?);
924 }
925
926 let mut req_builder = configuration.client.request(method, url).headers(headers);
927
928 if !body_params.is_empty() {
929 let mut serializer = form_urlencoded::Serializer::new(String::new());
930 for (key, value) in body_params {
931 let val_str = match value {
932 Value::String(s) => s,
933 _ => value.to_string(),
934 };
935 serializer.append_pair(&key, &val_str);
936 }
937 let body_str = serializer.finish();
938 req_builder = req_builder.body(body_str);
939 }
940
941 let req = req_builder.build()?;
942
943 Ok(http_request::<T>(req, configuration).await?)
944}
945
946#[must_use]
955pub fn random_string() -> String {
956 let mut buf = [0u8; 16];
957 rand::thread_rng().fill_bytes(&mut buf);
958 hex::encode(buf)
959}
960
961#[must_use]
970pub fn random_integer() -> u32 {
971 let mut buf = [0u8; 4];
972 OsRng.fill_bytes(&mut buf);
973 u32::from_ne_bytes(buf)
974}
975
976#[must_use]
987pub fn normalize_stream_id(id: Option<StreamId>, stream_id_is_strictly_number: bool) -> Value {
988 if stream_id_is_strictly_number {
989 let n = match id {
990 Some(StreamId::Number(n)) => n,
991 _ => random_integer(),
992 };
993 return Value::Number(Number::from(n));
994 }
995
996 match id {
997 Some(StreamId::Number(n)) => Value::Number(Number::from(n)),
998 Some(StreamId::Str(s)) => {
999 let out = if ID_REGEX.is_match(&s) {
1000 s
1001 } else {
1002 random_string()
1003 };
1004 Value::String(out)
1005 }
1006 None => Value::String(random_string()),
1007 }
1008}
1009
1010pub fn remove_empty_value<I>(entries: I) -> BTreeMap<String, Value>
1032where
1033 I: IntoIterator<Item = (String, Value)>,
1034{
1035 entries
1036 .into_iter()
1037 .filter(|(_, value)| match value {
1038 Value::Null => false,
1039 Value::String(s) if s.is_empty() => false,
1040 _ => true,
1041 })
1042 .collect()
1043}
1044
1045#[must_use]
1066pub fn sort_object_params(params: &BTreeMap<String, Value>) -> BTreeMap<String, Value> {
1067 let mut sorted = BTreeMap::new();
1068 for (k, v) in params {
1069 sorted.insert(k.clone(), v.clone());
1070 }
1071 sorted
1072}
1073
/// Normalizes a placeholder/variable name for comparison. It lowercases the
/// name, then strips every `_` and `-`, so `"listenKey"` and `"LISTEN_KEY"`
/// both become `"listenkey"`.
fn normalize_ws_streams_key(key: &str) -> String {
    let lowered = key.to_lowercase();
    lowered.chars().filter(|&ch| ch != '_' && ch != '-').collect()
}
1086
1087pub fn replace_websocket_streams_placeholders<V, S>(
1112 input: &str,
1113 variables: &HashMap<&str, V, S>,
1114) -> String
1115where
1116 V: Display,
1117 S: BuildHasher,
1118{
1119 let original = input;
1120
1121 let body = original.strip_prefix('/').unwrap_or(original);
1123
1124 let normalized: HashMap<String, String> = variables
1126 .iter()
1127 .map(|(k, v)| (normalize_ws_streams_key(k), v.to_string()))
1128 .collect();
1129
1130 let replaced = PLACEHOLDER_RE
1132 .replace_all(body, |caps: &Captures| {
1133 let prefix = caps.get(1).map_or("", |m| m.as_str());
1134 let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
1135 let val = normalized.get(&key).cloned().unwrap_or_default();
1136 format!("{prefix}{val}")
1137 })
1138 .into_owned();
1139
1140 let stripped = replaced.trim_end_matches('@').to_string();
1142
1143 let should_lower_head =
1146 original.starts_with('/') && PLACEHOLDER_RE.find(body).is_some_and(|m| m.start() == 0);
1147
1148 if should_lower_head {
1150 if let Some(caps) = PLACEHOLDER_RE.captures(body) {
1151 let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
1152 let first_val = normalized.get(&key).cloned().unwrap_or_default();
1153 if stripped.starts_with(&first_val) {
1154 let tail = &stripped[first_val.len()..];
1155 format!("{}{}", first_val.to_lowercase(), tail)
1156 } else {
1157 stripped.clone()
1158 }
1159 } else {
1160 stripped.clone()
1161 }
1162 } else {
1163 stripped.clone()
1164 }
1165}
1166
1167pub fn build_websocket_api_message(
1185 configuration: &ConfigurationWebsocketApi,
1186 method: &str,
1187 mut payload: BTreeMap<String, Value>,
1188 options: &WebsocketMessageSendOptions,
1189 skip_auth: bool,
1190) -> (String, serde_json::Value) {
1191 let id = payload
1192 .get("id")
1193 .and_then(Value::as_str)
1194 .filter(|s| ID_REGEX.is_match(s))
1195 .map_or_else(random_string, String::from);
1196
1197 payload.remove("id");
1198
1199 let mut params = remove_empty_value(payload);
1200
1201 if (options.with_api_key || options.is_signed) && !skip_auth {
1202 params.insert(
1203 "apiKey".into(),
1204 Value::String(configuration.api_key.clone().expect("API key must be set")),
1205 );
1206 }
1207
1208 if options.is_signed {
1209 let ts = get_timestamp();
1210 let ts_i64 = i64::try_from(ts).expect("timestamp fits in i64");
1211 params.insert("timestamp".into(), Value::Number(ts_i64.into()));
1212
1213 let mut sorted = sort_object_params(¶ms);
1214 if !skip_auth {
1215 let sig = configuration
1216 .signature_gen
1217 .get_signature(&sorted, None)
1218 .expect("signature generation");
1219 sorted.insert("signature".into(), Value::String(sig));
1220 }
1221 params = sorted.into_iter().collect();
1222 }
1223
1224 let request = json!({
1225 "id": id,
1226 "method": method,
1227 "params": params,
1228 });
1229
1230 (id, request)
1231}
1232
1233#[cfg(test)]
1234mod tests {
1235 use crate::TOKIO_SHARED_RT;
1236
1237 mod build_client {
1238 use std::{
1239 sync::{Arc, Mutex},
1240 time::{Duration, Instant},
1241 };
1242
1243 use reqwest::ClientBuilder;
1244
1245 use crate::{
1246 common::utils::build_client,
1247 config::{HttpAgent, ProxyAuth, ProxyConfig},
1248 };
1249
1250 use super::TOKIO_SHARED_RT;
1251
1252 #[test]
1253 fn enforces_timeout() {
1254 TOKIO_SHARED_RT.block_on(async {
1255 let client = build_client(100, true, None, None);
1256 let start = Instant::now();
1257 let res = client.get("http://10.255.255.1").send().await;
1258 assert!(
1259 res.is_err(),
1260 "expected an error (timeout or connect) but got {res:?}"
1261 );
1262 let elapsed = start.elapsed();
1263 assert!(
1264 elapsed < Duration::from_millis(500),
1265 "timed out too slowly: {elapsed:?}"
1266 );
1267 });
1268 }
1269
1270 #[test]
1271 fn builds_with_keep_alive_disabled() {
1272 let client = build_client(200, false, None, None);
1273 let _: reqwest::Client = client;
1274 }
1275
1276 #[test]
1277 #[should_panic(expected = "Failed to create proxy from URL")]
1278 fn invalid_proxy_url_panics() {
1279 let bad_proxy = ProxyConfig {
1280 protocol: Some("http".to_string()),
1281 host: String::new(),
1282 port: 8080,
1283 auth: None,
1284 };
1285 let _ = build_client(1_000, true, Some(&bad_proxy), None);
1286 }
1287
1288 #[test]
1289 fn builds_with_proxy_and_auth() {
1290 let proxy = ProxyConfig {
1291 protocol: Some("https".to_string()),
1292 host: "127.0.0.1".to_string(),
1293 port: 3128,
1294 auth: Some(ProxyAuth {
1295 username: "alice".to_string(),
1296 password: "secret".to_string(),
1297 }),
1298 };
1299 let client = build_client(2_000, true, Some(&proxy), None);
1300 let _: reqwest::Client = client;
1301 }
1302
1303 #[test]
1304 fn custom_agent_invoked() {
1305 let called = Arc::new(Mutex::new(false));
1306 let called_clone = Arc::clone(&called);
1307
1308 let agent = HttpAgent(Arc::new(move |builder: ClientBuilder| {
1309 *called_clone.lock().unwrap() = true;
1310 builder
1311 }));
1312
1313 let client = build_client(1_000, true, None, Some(agent));
1314 assert!(*called.lock().unwrap(), "agent closure wasn’t invoked");
1315 let _: reqwest::Client = client;
1316 }
1317 }
1318
1319 mod build_user_agent {
1320 use crate::common::utils::build_user_agent;
1321
1322 #[test]
1323 fn build_user_agent_contains_crate_product_and_rust_info() {
1324 let product = "product";
1325 let user_agent = build_user_agent(product);
1326
1327 let name = env!("CARGO_PKG_NAME");
1328 let version = env!("CARGO_PKG_VERSION");
1329 let rustc = env!("RUSTC_VERSION");
1330 let os = std::env::consts::OS;
1331 let arch = std::env::consts::ARCH;
1332
1333 let expected_prefix = format!("{name}/{product}/{version} (Rust/");
1334 assert!(
1335 user_agent.starts_with(&expected_prefix),
1336 "prefix mismatch: {user_agent}"
1337 );
1338
1339 assert!(
1340 user_agent.contains(rustc),
1341 "user agent missing RUSTC_VERSION: {user_agent}"
1342 );
1343
1344 assert!(
1345 user_agent.contains(&format!("; {os}")),
1346 "user agent missing OS: {user_agent}"
1347 );
1348 assert!(
1349 user_agent.contains(&format!("; {arch}")),
1350 "user agent missing ARCH: {user_agent}"
1351 );
1352 }
1353
1354 #[test]
1355 fn build_user_agent_is_deterministic() {
1356 let product = "product";
1357 let user_agent1 = build_user_agent(product);
1358 let user_agent2 = build_user_agent(product);
1359 assert_eq!(
1360 user_agent1, user_agent2,
1361 "user agent should be the same on repeated calls"
1362 );
1363 }
1364 }
1365
1366 mod validate_time_unit {
1367 use crate::common::utils::validate_time_unit;
1368
1369 #[test]
1370 fn empty_string_returns_none() {
1371 let res = validate_time_unit("").expect("Should not error on empty string");
1372 assert_eq!(res, None);
1373 }
1374
1375 #[test]
1376 fn uppercase_millisecond() {
1377 let res = validate_time_unit("MILLISECOND").expect("Should accept MILLISECOND");
1378 assert_eq!(res, Some("MILLISECOND"));
1379 }
1380
1381 #[test]
1382 fn uppercase_microsecond() {
1383 let res = validate_time_unit("MICROSECOND").expect("Should accept MICROSECOND");
1384 assert_eq!(res, Some("MICROSECOND"));
1385 }
1386
1387 #[test]
1388 fn lowercase_millisecond() {
1389 let res = validate_time_unit("millisecond").expect("Should accept millisecond");
1390 assert_eq!(res, Some("millisecond"));
1391 }
1392
1393 #[test]
1394 fn lowercase_microsecond() {
1395 let res = validate_time_unit("microsecond").expect("Should accept microsecond");
1396 assert_eq!(res, Some("microsecond"));
1397 }
1398
1399 #[test]
1400 fn invalid_value_returns_err() {
1401 let err = validate_time_unit("SECOND").unwrap_err();
1402 let msg = format!("{err}");
1403 assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1404 }
1405
1406 #[test]
1407 fn partial_match_returns_err() {
1408 let err = validate_time_unit("MILLI").unwrap_err();
1409 let msg = format!("{err}");
1410 assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1411 }
1412 }
1413
1414 mod get_timestamp {
1415 use crate::common::utils::get_timestamp;
1416 use std::{
1417 thread::sleep,
1418 time::{Duration, SystemTime, UNIX_EPOCH},
1419 };
1420
1421 #[test]
1422 fn timestamp_is_within_system_time_bounds() {
1423 let before = SystemTime::now()
1424 .duration_since(UNIX_EPOCH)
1425 .expect("SystemTime before UNIX_EPOCH")
1426 .as_millis();
1427 let ts = get_timestamp();
1428 let after = SystemTime::now()
1429 .duration_since(UNIX_EPOCH)
1430 .expect("SystemTime before UNIX_EPOCH")
1431 .as_millis();
1432
1433 assert!(
1434 ts >= before,
1435 "timestamp {ts} is before captured before time {before}"
1436 );
1437 assert!(
1438 ts <= after,
1439 "timestamp {ts} is after captured after time {after}"
1440 );
1441 }
1442
1443 #[test]
1444 fn timestamps_are_monotonic() {
1445 let t1 = get_timestamp();
1446 sleep(Duration::from_millis(1));
1447 let t2 = get_timestamp();
1448 assert!(
1449 t2 >= t1,
1450 "second timestamp {t2} is not >= first timestamp {t1}"
1451 );
1452 }
1453 }
1454
1455 mod build_query_string {
1456 use std::collections::BTreeMap;
1457
1458 use anyhow::Result;
1459 use serde_json::{Value, json};
1460 use url::form_urlencoded::Serializer;
1461
1462 use crate::common::utils::build_query_string;
1463
1464 fn mk_map(pairs: Vec<(&str, Value)>) -> BTreeMap<String, Value> {
1465 let mut m = BTreeMap::new();
1466 for (k, v) in pairs {
1467 m.insert(k.to_string(), v);
1468 }
1469 m
1470 }
1471
1472 #[test]
1473 fn empty_map_returns_empty_string() -> Result<()> {
1474 let params = BTreeMap::new();
1475 let qs = build_query_string(¶ms)?;
1476 assert_eq!(qs, "");
1477 Ok(())
1478 }
1479
1480 #[test]
1481 fn string_and_number_and_bool() -> Result<()> {
1482 let params = mk_map(vec![
1483 ("foo", json!("bar")),
1484 ("num", json!(42)),
1485 ("flag", json!(true)),
1486 ]);
1487 let qs = build_query_string(¶ms)?;
1488 assert_eq!(qs, "flag=true&foo=bar&num=42");
1489 Ok(())
1490 }
1491
1492 #[test]
1493 fn null_is_skipped() -> Result<()> {
1494 let params = mk_map(vec![("a", json!(true)), ("b", Value::Null)]);
1495 let qs = build_query_string(¶ms)?;
1496 assert_eq!(qs, "a=true");
1497 Ok(())
1498 }
1499
1500 #[test]
1501 fn percent_encode_special_chars() -> Result<()> {
1502 let params = mk_map(vec![
1503 ("space", json!("hello world")),
1504 ("symbols", json!("a/b?c")),
1505 ]);
1506 let qs = build_query_string(¶ms)?;
1507 let mut parts = vec![];
1508 let mut ser = Serializer::new(String::new());
1509 ser.append_pair("space", "hello world");
1510 parts.push(ser.finish());
1511 let mut ser = Serializer::new(String::new());
1512 ser.append_pair("symbols", "a/b?c");
1513 parts.push(ser.finish());
1514 let expected = parts.join("&");
1515 assert_eq!(qs, expected);
1516 Ok(())
1517 }
1518
1519 #[test]
1520 fn primitive_array_json_encoded() -> Result<()> {
1521 let params = mk_map(vec![
1522 ("strs", json!(["a", "b", "c"])),
1523 ("nums", json!([1, 2, 3])),
1524 ("bools", json!([true, false])),
1525 ]);
1526 let qs = build_query_string(¶ms)?;
1527
1528 let mut parts = Vec::new();
1529 for (k, v) in ¶ms {
1530 let json = serde_json::to_string(v)?;
1531 let mut ser = Serializer::new(String::new());
1532 ser.append_pair(k, &json);
1533 parts.push(ser.finish());
1534 }
1535 let expected = parts.join("&");
1536 assert_eq!(qs, expected);
1537 Ok(())
1538 }
1539
/// Nested arrays are JSON-encoded as a whole, not flattened.
#[test]
fn nested_array_json_encoded() -> Result<()> {
    let params = mk_map(vec![("nested", json!([[1, 2], [3, 4]]))]);
    let qs = build_query_string(&params)?;

    let nested_json = serde_json::to_string(&json!([[1, 2], [3, 4]]))?;
    let mut ser = Serializer::new(String::new());
    ser.append_pair("nested", &nested_json);
    let expected = ser.finish();

    assert_eq!(qs, expected);
    Ok(())
}
1553
/// Object values are JSON-encoded and then form-encoded as a single value.
#[test]
fn object_json_encoded() -> Result<()> {
    let params = mk_map(vec![("obj", json!({"k":1, "v":"two"}))]);
    let qs = build_query_string(&params)?;

    let obj_json = serde_json::to_string(&json!({"k":1, "v":"two"}))?;
    let mut ser = Serializer::new(String::new());
    ser.append_pair("obj", &obj_json);
    let expected = ser.finish();

    assert_eq!(qs, expected);
    Ok(())
}
1567
/// An empty array is still emitted (as the JSON text `[]`), not skipped.
#[test]
fn empty_array() {
    let params = mk_map(vec![("foo", json!([]))]);
    let qs = build_query_string(&params).unwrap();

    let json = serde_json::to_string(&json!([])).unwrap();
    let expected = Serializer::new(String::new())
        .append_pair("foo", &json)
        .finish();
    assert_eq!(qs, expected);
}
1579
/// Arrays mixing numbers, strings and booleans are JSON-encoded verbatim.
#[test]
fn mixed_array() {
    let params = mk_map(vec![("mix", json!([1, "x", false]))]);
    let qs = build_query_string(&params).unwrap();

    let json = serde_json::to_string(&json!([1, "x", false])).unwrap();
    let expected = Serializer::new(String::new())
        .append_pair("mix", &json)
        .finish();
    assert_eq!(qs, expected);
}
1591
/// Arrays of objects are JSON-encoded as a single value.
#[test]
fn array_of_objects() {
    let params = mk_map(vec![("objs", json!([{"a":1}, {"b":2}]))]);
    let qs = build_query_string(&params).unwrap();

    let json = serde_json::to_string(&json!([{"a":1}, {"b":2}])).unwrap();
    let expected = Serializer::new(String::new())
        .append_pair("objs", &json)
        .finish();
    assert_eq!(qs, expected);
}
1603
/// An empty object is emitted as the JSON text `{}`, not skipped.
#[test]
fn empty_object() {
    let params = mk_map(vec![("emp", json!({}))]);
    let qs = build_query_string(&params).unwrap();

    let json = serde_json::to_string(&json!({})).unwrap();
    let expected = Serializer::new(String::new())
        .append_pair("emp", &json)
        .finish();
    assert_eq!(qs, expected);
}
1615
/// Floating-point and negative numbers keep their decimal representation; `-` and `.`
/// pass through form encoding unchanged.
#[test]
fn floats_and_negatives() {
    let params = mk_map(vec![("fl", json!(1.23456)), ("neg", json!(-0.001))]);
    let qs = build_query_string(&params).unwrap();
    assert_eq!(qs, "fl=1.23456&neg=-0.001");
}
1622
/// Non-ASCII text and reserved characters in keys and values are form-encoded.
#[test]
fn unicode_and_special_key() {
    let params = mk_map(vec![
        ("こんにちは", json!("世界")),
        ("weird key/?=", json!("val")),
    ]);
    let qs = build_query_string(&params).unwrap();

    // Expected value built in the map's own iteration order; all values are strings here.
    let mut ser = Serializer::new(String::new());
    for (k, v) in &params {
        ser.append_pair(k, v.as_str().unwrap());
    }
    let expected = ser.finish();
    assert_eq!(qs, expected);
}
1640
/// An empty string value is kept as `key=` (unlike `null`, which drops the key).
#[test]
fn empty_string_value() {
    let params = mk_map(vec![("empty", json!(""))]);
    let qs = build_query_string(&params).unwrap();
    assert_eq!(qs, "empty=");
}
1647
/// `null` elements inside an array are preserved in the JSON text; only top-level
/// `null` values are skipped.
#[test]
fn nulls_in_array() {
    let params = mk_map(vec![("a", json!([null, 1, "x"]))]);
    let qs = build_query_string(&params).unwrap();

    let json = serde_json::to_string(&json!([null, 1, "x"])).unwrap();
    let expected = Serializer::new(String::new())
        .append_pair("a", &json)
        .finish();
    assert_eq!(qs, expected);
}
1659
/// Query-significant characters (`=`, `&`, `%`) inside a key are escaped so they cannot
/// break the query structure.
#[test]
fn special_chars_in_key() {
    let params = mk_map(vec![("a=b&c%", json!("val"))]);
    let qs = build_query_string(&params).unwrap();

    let expected = Serializer::new(String::new())
        .append_pair("a=b&c%", "val")
        .finish();
    assert_eq!(qs, expected);
}
1670
/// An empty key is not rejected; it produces `=value`.
#[test]
fn empty_key() {
    let params = mk_map(vec![("", json!("v"))]);
    let qs = build_query_string(&params).unwrap();
    assert_eq!(qs, "=v");
}
1677 }
1678
/// Tests for `SignatureGenerator`: HMAC (api secret), RSA and Ed25519 (private key) paths,
/// determinism, file-based keys, and error reporting. The whole module is compiled only
/// with the `openssl-tls` feature, so individual items need no further `cfg` gating.
#[cfg(feature = "openssl-tls")]
mod signature_generator {
    use base64::{Engine, engine::general_purpose};
    use ed25519_dalek::{SigningKey, ed25519::signature::SignerMut, pkcs8::DecodePrivateKey};
    use hmac::{Hmac, Mac};
    use openssl::{hash::MessageDigest, pkey::PKey, rsa::Rsa, sign::Verifier};
    use serde_json::Value;
    use sha2::Sha256;
    use std::collections::BTreeMap;
    use std::io::Write;
    use tempfile::NamedTempFile;

    use crate::{common::utils::SignatureGenerator, config::PrivateKey};

    /// Reference HMAC-SHA256 of `payload` under `secret`, hex-encoded.
    fn hmac_sha256_hex(secret: &[u8], payload: &str) -> String {
        let mut mac = Hmac::<Sha256>::new_from_slice(secret).unwrap();
        mac.update(payload.as_bytes());
        hex::encode(mac.finalize().into_bytes())
    }

    /// Asserts that the base64 `sig` verifies against `payload` under the PEM public key
    /// `pub_pem`, using a SHA-256 OpenSSL verifier.
    fn assert_rsa_signature_valid(sig: &str, pub_pem: &[u8], payload: &[u8]) {
        let sig_bytes = general_purpose::STANDARD.decode(sig).unwrap();
        let pubkey = PKey::public_key_from_pem(pub_pem).unwrap();
        let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
        verifier.update(payload).unwrap();
        assert!(verifier.verify(&sig_bytes).unwrap());
    }

    /// Rebuilds an `ed25519_dalek` signing key from an OpenSSL PKCS#8 PEM by dropping the
    /// `-----` armor lines and base64-decoding the remaining body to DER.
    fn ed25519_key_from_pem(pem: &[u8]) -> SigningKey {
        let pem_str = std::str::from_utf8(pem).unwrap();
        let b64: String = pem_str
            .lines()
            .filter(|line| !line.starts_with("-----"))
            .collect();
        let der = general_purpose::STANDARD.decode(b64).unwrap();
        SigningKey::from_pkcs8_der(&der).unwrap()
    }

    #[test]
    fn hmac_sha256_signature() {
        let mut params = BTreeMap::new();
        params.insert("b".into(), Value::Number(2.into()));
        params.insert("a".into(), Value::Number(1.into()));

        let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
        let sig = signature_gen
            .get_signature(&params, None)
            .expect("HMAC signing failed");

        // The signed payload uses sorted keys, independent of insertion order.
        assert_eq!(sig, hmac_sha256_hex(b"test-secret", "a=1&b=2"));
    }

    #[test]
    fn hmac_sha256_signature_with_body() {
        let mut query_params = BTreeMap::new();
        query_params.insert("b".into(), Value::Number(2.into()));
        query_params.insert("a".into(), Value::Number(1.into()));

        let mut body_params = BTreeMap::new();
        body_params.insert("d".into(), Value::Number(4.into()));
        body_params.insert("c".into(), Value::Number(3.into()));

        let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
        let sig = signature_gen
            .get_signature(&query_params, Some(&body_params))
            .expect("HMAC signing with body failed");

        // Query and body strings are concatenated with no separator before signing.
        let query_str = "a=1&b=2";
        let body_str = "c=3&d=4";
        let payload = format!("{query_str}{body_str}");

        assert_eq!(sig, hmac_sha256_hex(b"test-secret", &payload));
    }

    #[test]
    fn repeated_hmac_signature() {
        let mut params = BTreeMap::new();
        params.insert("x".into(), Value::String("y".into()));
        let signature_gen = SignatureGenerator::new(Some("abc".into()), None, None);
        let s1 = signature_gen.get_signature(&params, None).unwrap();
        let s2 = signature_gen.get_signature(&params, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn rsa_signature_verification() {
        let mut params = BTreeMap::new();
        params.insert("a".into(), Value::Number(1.into()));
        params.insert("b".into(), Value::Number(2.into()));

        let rsa = Rsa::generate(2048).unwrap();
        let priv_pem = rsa.private_key_to_pem().unwrap();
        let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();

        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let sig = signature_gen
            .get_signature(&params, None)
            .expect("RSA signing failed");

        assert_rsa_signature_valid(&sig, &pub_pem, b"a=1&b=2");
    }

    #[test]
    fn rsa_signature_verification_with_body() {
        let mut query_params = BTreeMap::new();
        query_params.insert("a".into(), Value::Number(1.into()));
        query_params.insert("b".into(), Value::Number(2.into()));

        let mut body_params = BTreeMap::new();
        body_params.insert("c".into(), Value::Number(3.into()));
        body_params.insert("d".into(), Value::Number(4.into()));

        let rsa = Rsa::generate(2048).unwrap();
        let priv_pem = rsa.private_key_to_pem().unwrap();
        let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();

        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let sig = signature_gen
            .get_signature(&query_params, Some(&body_params))
            .expect("RSA signing with body failed");

        assert_rsa_signature_valid(&sig, &pub_pem, b"a=1&b=2c=3&d=4");
    }

    #[test]
    fn repeated_rsa_signature() {
        let mut params = BTreeMap::new();
        params.insert("k".into(), Value::Number(5.into()));
        let rsa = Rsa::generate(2048).unwrap();
        let priv_pem = rsa.private_key_to_pem().unwrap();
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let s1 = signature_gen.get_signature(&params, None).unwrap();
        let s2 = signature_gen.get_signature(&params, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn ed25519_signature_verification() {
        let mut params = BTreeMap::new();
        params.insert("a".into(), Value::Number(1.into()));
        params.insert("b".into(), Value::Number(2.into()));
        let qs = "a=1&b=2";

        let ed = PKey::generate_ed25519().unwrap();
        let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();

        let signature_gen =
            SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
        let sig = signature_gen
            .get_signature(&params, None)
            .expect("Ed25519 signing failed");

        let mut sk = ed25519_key_from_pem(&priv_pem);
        let expected_bytes = sk.sign(qs.as_bytes()).to_bytes();
        let expected_sig = general_purpose::STANDARD.encode(expected_bytes);
        assert_eq!(sig, expected_sig);
    }

    #[test]
    fn ed25519_signature_verification_with_body() {
        let mut query_params = BTreeMap::new();
        query_params.insert("a".into(), Value::Number(1.into()));
        query_params.insert("b".into(), Value::Number(2.into()));
        let qs = "a=1&b=2";

        let mut body_params = BTreeMap::new();
        body_params.insert("c".into(), Value::Number(3.into()));
        body_params.insert("d".into(), Value::Number(4.into()));
        let body_qs = "c=3&d=4";

        let ed = PKey::generate_ed25519().unwrap();
        let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();

        let signature_gen =
            SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
        let sig = signature_gen
            .get_signature(&query_params, Some(&body_params))
            .expect("Ed25519 signing with body failed");

        let mut sk = ed25519_key_from_pem(&priv_pem);
        let payload = format!("{qs}{body_qs}");
        let expected_bytes = sk.sign(payload.as_bytes()).to_bytes();
        let expected_sig = general_purpose::STANDARD.encode(expected_bytes);
        assert_eq!(sig, expected_sig);
    }

    #[test]
    fn repeated_ed25519_signature() {
        let mut params = BTreeMap::new();
        params.insert("m".into(), Value::String("n".into()));
        let ed = PKey::generate_ed25519().unwrap();
        let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
        let s1 = signature_gen.get_signature(&params, None).unwrap();
        let s2 = signature_gen.get_signature(&params, None).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn file_based_key() {
        let rsa = Rsa::generate(1024).unwrap();
        let priv_pem = rsa.private_key_to_pem().unwrap();
        let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();

        // `file` must stay alive until signing is done: dropping it deletes the temp file.
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(&priv_pem).unwrap();
        let path = file.path().to_str().unwrap().to_string();

        let mut params = BTreeMap::new();
        params.insert("z".into(), Value::Number(9.into()));

        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::File(path)), None);
        let sig = signature_gen.get_signature(&params, None).unwrap();

        assert_rsa_signature_valid(&sig, &pub_pem, b"z=9");
    }

    #[test]
    fn unsupported_key_type_error() {
        let mut params = BTreeMap::new();
        params.insert("x".into(), Value::String("y".into()));

        // A P-256 EC key parses fine but is neither RSA nor Ed25519.
        let group =
            openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
        let ec_key = openssl::ec::EcKey::generate(&group).unwrap();
        let pkey_ec = PKey::from_ec_key(ec_key).unwrap();
        let raw = pkey_ec.private_key_to_pem_pkcs8().unwrap();

        let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(raw)), None);
        let err = signature_gen
            .get_signature(&params, None)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Unsupported private key type"));
    }

    #[test]
    fn invalid_private_key_error() {
        let mut params = BTreeMap::new();
        params.insert("foo".into(), Value::String("bar".into()));

        let signature_gen =
            SignatureGenerator::new(None, Some(PrivateKey::Raw(b"not a key".to_vec())), None);
        let err = signature_gen
            .get_signature(&params, None)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Failed to parse private key"));
    }

    #[test]
    fn missing_credentials_error() {
        let mut params = BTreeMap::new();
        params.insert("a".into(), Value::Number(1.into()));

        let signature_gen = SignatureGenerator::new(None, None, None);
        let err = signature_gen
            .get_signature(&params, None)
            .unwrap_err()
            .to_string();
        assert!(err.contains("Either 'api_secret' or 'private_key' must be provided"));
    }
}
1959
1960 mod should_retry_request {
1961 use crate::common::utils::should_retry_request;
1962
1963 use reqwest::{Error, Response};
1964
1965 fn mk_http_error(code: u16) -> Error {
1966 let resp = Response::from(
1967 http::response::Response::builder()
1968 .status(code)
1969 .body("")
1970 .unwrap(),
1971 );
1972 resp.error_for_status().unwrap_err()
1973 }
1974
1975 fn mk_network_error() -> Error {
1976 reqwest::blocking::get("http://256.256.256.256").unwrap_err()
1977 }
1978
1979 #[test]
1980 fn retry_on_retriable_status_and_method() {
1981 let err = mk_http_error(500);
1982 assert!(should_retry_request(&err, Some("GET"), Some(1)));
1983 assert!(should_retry_request(&err, Some("delete"), Some(2)));
1984 }
1985
1986 #[test]
1987 fn retry_when_status_none_and_retriable_method() {
1988 let retriable_methods = ["GET", "DELETE"];
1989
1990 for &method in &retriable_methods {
1991 let err = mk_network_error();
1992 assert!(
1993 should_retry_request(&err, Some(method), Some(1)),
1994 "Should retry when no status and method {method}"
1995 );
1996 }
1997 }
1998
1999 #[test]
2000 fn no_retry_when_no_retries_left() {
2001 let err = mk_http_error(503);
2002 assert!(!should_retry_request(&err, Some("GET"), Some(0)));
2003 }
2004
2005 #[test]
2006 fn no_retry_on_non_retriable_status() {
2007 let non_retriable_statuses = [400, 401, 404, 422];
2008
2009 for &status in &non_retriable_statuses {
2010 let err = mk_http_error(status);
2011 assert!(
2012 !should_retry_request(&err, Some("GET"), Some(2)),
2013 "Should not retry for non-retriable status {status}"
2014 );
2015 }
2016 }
2017
2018 #[test]
2019 fn no_retry_on_non_retriable_method() {
2020 let non_retriable_methods = ["POST", "PUT", "PATCH"];
2021
2022 for &method in &non_retriable_methods {
2023 let err = mk_http_error(500);
2024 assert!(
2025 !should_retry_request(&err, Some(method), Some(2)),
2026 "Should not retry for non-retriable method {method}"
2027 );
2028 }
2029 }
2030
2031 #[test]
2032 fn no_retry_when_status_none_and_non_retriable_method() {
2033 let non_retriable_methods = ["POST", "PUT"];
2034
2035 for &method in &non_retriable_methods {
2036 let err = mk_network_error();
2037 assert!(
2038 !should_retry_request(&err, Some(method), Some(1)),
2039 "Should not retry when no status and method {method}"
2040 );
2041 }
2042 }
2043 }
2044
2045 mod parse_rate_limit_headers_tests {
2046 use crate::common::{
2047 models::{Interval, RateLimitType},
2048 utils::parse_rate_limit_headers,
2049 };
2050 use std::collections::HashMap;
2051
2052 fn mk_headers(pairs: Vec<(&str, &str)>) -> HashMap<String, String> {
2053 let mut m = HashMap::new();
2054 for (k, v) in pairs {
2055 m.insert(k.to_string(), v.to_string());
2056 }
2057 m
2058 }
2059
2060 #[test]
2061 fn single_weight_header() {
2062 let headers = mk_headers(vec![("x-mbx-used-weight-1s", "123")]);
2063 let limits = parse_rate_limit_headers(&headers);
2064 assert_eq!(limits.len(), 1);
2065 let rl = &limits[0];
2066 assert_eq!(rl.rate_limit_type, RateLimitType::RequestWeight);
2067 assert_eq!(rl.interval, Interval::Second);
2068 assert_eq!(rl.interval_num, 1);
2069 assert_eq!(rl.count, 123);
2070 assert_eq!(rl.retry_after, None);
2071 }
2072
2073 #[test]
2074 fn single_order_count_with_retry_after() {
2075 let headers = mk_headers(vec![("x-mbx-order-count-5m", "42"), ("retry-after", "7")]);
2076 let limits = parse_rate_limit_headers(&headers);
2077 assert_eq!(limits.len(), 1);
2078 let rl = &limits[0];
2079 assert_eq!(rl.rate_limit_type, RateLimitType::Orders);
2080 assert_eq!(rl.interval, Interval::Minute);
2081 assert_eq!(rl.interval_num, 5);
2082 assert_eq!(rl.count, 42);
2083 assert_eq!(rl.retry_after, Some(7));
2084 }
2085
2086 #[test]
2087 fn multiple_headers() {
2088 let headers = mk_headers(vec![
2089 ("X-MBX-USED-WEIGHT-1h", "10"),
2090 ("x-mbx-order-count-2d", "20"),
2091 ]);
2092 let mut limits = parse_rate_limit_headers(&headers);
2093 limits.sort_by_key(|r| (r.interval_num, format!("{:?}", r.rate_limit_type)));
2094 assert_eq!(limits.len(), 2);
2095 let w = &limits[0];
2096 assert_eq!(w.rate_limit_type, RateLimitType::RequestWeight);
2097 assert_eq!(w.interval, Interval::Hour);
2098 assert_eq!(w.interval_num, 1);
2099 assert_eq!(w.count, 10);
2100 let o = &limits[1];
2101 assert_eq!(o.rate_limit_type, RateLimitType::Orders);
2102 assert_eq!(o.interval, Interval::Day);
2103 assert_eq!(o.interval_num, 2);
2104 assert_eq!(o.count, 20);
2105 }
2106
2107 #[test]
2108 fn ignores_unknown_and_malformed() {
2109 let headers = mk_headers(vec![
2110 ("x-mbx-used-weight-3x", "5"),
2111 ("random-header", "100"),
2112 ]);
2113 let limits = parse_rate_limit_headers(&headers);
2114 assert!(limits.is_empty());
2115 }
2116 }
2117
2118 mod http_request {
2119 use std::io::Write;
2120
2121 use flate2::{Compression, write::GzEncoder};
2122 use httpmock::MockServer;
2123 use reqwest::{Client, Method, Request};
2124 use serde::Deserialize;
2125
2126 use crate::{
2127 common::utils::http_request, config::ConfigurationRestApi, errors::ConnectorError,
2128 models::RestApiResponse,
2129 };
2130
2131 use super::TOKIO_SHARED_RT;
2132
2133 #[derive(Deserialize, Debug, PartialEq)]
2134 struct Dummy {
2135 foo: String,
2136 }
2137
2138 fn make_config(server_url: &str) -> ConfigurationRestApi {
2139 ConfigurationRestApi::builder()
2140 .api_key("key")
2141 .api_secret("secret")
2142 .base_path(server_url)
2143 .build()
2144 .expect("Failed to build configuration")
2145 }
2146
2147 #[test]
2148 fn http_request_success_plain_text() {
2149 TOKIO_SHARED_RT.block_on(async {
2150 let server = MockServer::start();
2151 let mock = server.mock(|when, then| {
2152 when.method(httpmock::Method::GET).path("/test");
2153 then.status(200)
2154 .header("Content-Type", "application/json")
2155 .body(r#"{"foo":"bar"}"#);
2156 });
2157
2158 let client = Client::new();
2159 let req: Request = client
2160 .request(Method::GET, format!("{}{}", server.url(""), "/test"))
2161 .build()
2162 .unwrap();
2163
2164 let cfg = make_config(&server.url(""));
2165 let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
2166 assert_eq!(resp.status, 200);
2167 let data = resp.data().await.unwrap();
2168 assert_eq!(data, Dummy { foo: "bar".into() });
2169 mock.assert();
2170 });
2171 }
2172
2173 #[test]
2174 fn http_request_success_gzip() {
2175 TOKIO_SHARED_RT.block_on(async {
2176 let server = MockServer::start();
2177 let body = r#"{"foo":"baz"}"#;
2178 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
2179 encoder.write_all(body.as_bytes()).unwrap();
2180 let gz = encoder.finish().unwrap();
2181
2182 let mock = server.mock(|when, then| {
2183 when.method(httpmock::Method::GET).path("/gz");
2184 then.status(200)
2185 .header("Content-Type", "application/json")
2186 .header("Content-Encoding", "gzip")
2187 .body(gz);
2188 });
2189
2190 let client = Client::new();
2191 let req: Request = client
2192 .request(Method::GET, format!("{}{}", server.url(""), "/gz"))
2193 .build()
2194 .unwrap();
2195 let mut cfg = make_config(&server.url(""));
2196 cfg.compression = true;
2197
2198 let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
2199 assert_eq!(resp.status, 200);
2200 let data = resp.data().await.unwrap();
2201 assert_eq!(data, Dummy { foo: "baz".into() });
2202 mock.assert();
2203 });
2204 }
2205
2206 #[test]
2207 fn http_request_client_error_bad_request() {
2208 TOKIO_SHARED_RT.block_on(async {
2209 let server = MockServer::start();
2210 let mock = server.mock(|when, then| {
2211 when.method(httpmock::Method::GET).path("/400");
2212 then.status(400)
2213 .header("Content-Type", "application/json")
2214 .body(r#"{"code":-1121,"msg":"bad request"}"#);
2215 });
2216
2217 let client = Client::new();
2218 let req: Request = client
2219 .request(Method::GET, format!("{}{}", server.url(""), "/400"))
2220 .build()
2221 .unwrap();
2222 let cfg = make_config(&server.url(""));
2223
2224 let result = http_request::<Dummy>(req, &cfg).await;
2225
2226 assert!(matches!(
2227 result,
2228 Err(ConnectorError::BadRequestError { .. })
2229 ));
2230
2231 if let Err(ConnectorError::BadRequestError { msg, code }) = result {
2232 assert_eq!(msg, "bad request");
2233 assert_eq!(code, Some(-1121));
2234 }
2235
2236 mock.assert();
2237 });
2238 }
2239
2240 #[test]
2241 fn http_request_client_error_unauthorized() {
2242 TOKIO_SHARED_RT.block_on(async {
2243 let server = MockServer::start();
2244 let mock = server.mock(|when, then| {
2245 when.method(httpmock::Method::GET).path("/401");
2246 then.status(401)
2247 .header("Content-Type", "application/json")
2248 .body(r#"{"code":-2015,"msg":"unauthorized"}"#);
2249 });
2250
2251 let client = Client::new();
2252 let req: Request = client
2253 .request(Method::GET, format!("{}{}", server.url(""), "/401"))
2254 .build()
2255 .unwrap();
2256 let cfg = make_config(&server.url(""));
2257
2258 let result = http_request::<Dummy>(req, &cfg).await;
2259
2260 assert!(matches!(
2261 result,
2262 Err(ConnectorError::UnauthorizedError { .. })
2263 ));
2264
2265 if let Err(ConnectorError::UnauthorizedError { msg, code }) = result {
2266 assert_eq!(msg, "unauthorized");
2267 assert_eq!(code, Some(-2015));
2268 }
2269
2270 mock.assert();
2271 });
2272 }
2273
2274 #[test]
2275 fn http_request_client_error_forbidden() {
2276 TOKIO_SHARED_RT.block_on(async {
2277 let server = MockServer::start();
2278 let mock = server.mock(|when, then| {
2279 when.method(httpmock::Method::GET).path("/403");
2280 then.status(403)
2281 .header("Content-Type", "application/json")
2282 .body(r#"{"code":-2010,"msg":"forbidden"}"#);
2283 });
2284
2285 let client = Client::new();
2286 let req: Request = client
2287 .request(Method::GET, format!("{}{}", server.url(""), "/403"))
2288 .build()
2289 .unwrap();
2290 let cfg = make_config(&server.url(""));
2291
2292 let result = http_request::<Dummy>(req, &cfg).await;
2293
2294 assert!(matches!(result, Err(ConnectorError::ForbiddenError { .. })));
2295
2296 if let Err(ConnectorError::ForbiddenError { msg, code }) = result {
2297 assert_eq!(msg, "forbidden");
2298 assert_eq!(code, Some(-2010));
2299 }
2300
2301 mock.assert();
2302 });
2303 }
2304
2305 #[test]
2306 fn http_request_client_error_not_found() {
2307 TOKIO_SHARED_RT.block_on(async {
2308 let server = MockServer::start();
2309 let mock = server.mock(|when, then| {
2310 when.method(httpmock::Method::GET).path("/404");
2311 then.status(404)
2312 .header("Content-Type", "application/json")
2313 .body(r#"{"code":-1003,"msg":"not found"}"#);
2314 });
2315
2316 let client = Client::new();
2317 let req: Request = client
2318 .request(Method::GET, format!("{}{}", server.url(""), "/404"))
2319 .build()
2320 .unwrap();
2321 let cfg = make_config(&server.url(""));
2322
2323 let result = http_request::<Dummy>(req, &cfg).await;
2324
2325 assert!(matches!(result, Err(ConnectorError::NotFoundError { .. })));
2326
2327 if let Err(ConnectorError::NotFoundError { msg, code }) = result {
2328 assert_eq!(msg, "not found");
2329 assert_eq!(code, Some(-1003));
2330 }
2331
2332 mock.assert();
2333 });
2334 }
2335
2336 #[test]
2337 fn http_request_client_error_rate_limit_exceeded() {
2338 TOKIO_SHARED_RT.block_on(async {
2339 let server = MockServer::start();
2340 let mock = server.mock(|when, then| {
2341 when.method(httpmock::Method::GET).path("/418");
2342 then.status(418)
2343 .header("Content-Type", "application/json")
2344 .body(r#"{"code":-1003,"msg":"rate limit exceeded"}"#);
2345 });
2346
2347 let client = Client::new();
2348 let req: Request = client
2349 .request(Method::GET, format!("{}{}", server.url(""), "/418"))
2350 .build()
2351 .unwrap();
2352 let cfg = make_config(&server.url(""));
2353
2354 let result = http_request::<Dummy>(req, &cfg).await;
2355
2356 assert!(matches!(
2357 result,
2358 Err(ConnectorError::RateLimitBanError { .. })
2359 ));
2360
2361 if let Err(ConnectorError::RateLimitBanError { msg, code }) = result {
2362 assert_eq!(msg, "rate limit exceeded");
2363 assert_eq!(code, Some(-1003));
2364 }
2365
2366 mock.assert();
2367 });
2368 }
2369
2370 #[test]
2371 fn http_request_client_error_too_many_requests() {
2372 TOKIO_SHARED_RT.block_on(async {
2373 let server = MockServer::start();
2374 let mock = server.mock(|when, then| {
2375 when.method(httpmock::Method::GET).path("/429");
2376 then.status(429)
2377 .header("Content-Type", "application/json")
2378 .body(r#"{"code":-1003,"msg":"too many requests"}"#);
2379 });
2380
2381 let client = Client::new();
2382 let req: Request = client
2383 .request(Method::GET, format!("{}{}", server.url(""), "/429"))
2384 .build()
2385 .unwrap();
2386 let cfg = make_config(&server.url(""));
2387
2388 let result = http_request::<Dummy>(req, &cfg).await;
2389
2390 assert!(matches!(
2391 result,
2392 Err(ConnectorError::TooManyRequestsError { .. })
2393 ));
2394
2395 if let Err(ConnectorError::TooManyRequestsError { msg, code }) = result {
2396 assert_eq!(msg, "too many requests");
2397 assert_eq!(code, Some(-1003));
2398 }
2399
2400 mock.assert();
2401 });
2402 }
2403
2404 #[test]
2405 fn http_request_client_error_server_error() {
2406 TOKIO_SHARED_RT.block_on(async {
2407 let server = MockServer::start();
2408 let mock = server.mock(|when, then| {
2409 when.method(httpmock::Method::GET).path("/500");
2410 then.status(500)
2411 .header("Content-Type", "application/json")
2412 .body(r#"{"code":-1000,"msg":"internal server error"}"#);
2413 });
2414
2415 let client = Client::new();
2416 let req: Request = client
2417 .request(Method::GET, format!("{}{}", server.url(""), "/500"))
2418 .build()
2419 .unwrap();
2420 let cfg = make_config(&server.url(""));
2421
2422 let result = http_request::<Dummy>(req, &cfg).await;
2423
2424 assert!(matches!(result, Err(ConnectorError::ServerError { .. })));
2425
2426 if let Err(ConnectorError::ServerError {
2427 msg,
2428 status_code: Some(500),
2429 }) = result
2430 {
2431 assert_eq!(msg, "Server error: 500".to_string());
2432 }
2433
2434 mock.assert();
2435 });
2436 }
2437
2438 #[test]
2439 fn http_request_unexpected_status_maps_generic() {
2440 TOKIO_SHARED_RT.block_on(async {
2441 let server = MockServer::start();
2442 let code_http = 402;
2443 let mock = server.mock(|when, then| {
2444 when.method(httpmock::Method::GET).path("/402");
2445 then.status(code_http)
2446 .header("Content-Type", "application/json")
2447 .body(r#"{"code":-12345,"msg":"payment required"}"#);
2448 });
2449
2450 let client = Client::new();
2451 let req: Request = client
2452 .request(Method::GET, format!("{}{}", server.url(""), "/402"))
2453 .build()
2454 .unwrap();
2455 let cfg = make_config(&server.url(""));
2456
2457 let result = http_request::<Dummy>(req, &cfg).await;
2458
2459 assert!(matches!(
2460 result,
2461 Err(ConnectorError::ConnectorClientError { .. })
2462 ));
2463
2464 if let Err(ConnectorError::ConnectorClientError { msg, code }) = result {
2465 assert_eq!(msg, "payment required");
2466 assert_eq!(code, Some(-12345));
2467 }
2468
2469 mock.assert();
2470 });
2471 }
2472
2473 #[test]
2474 fn http_request_malformed_json_maps_generic() {
2475 TOKIO_SHARED_RT.block_on(async {
2476 let server = MockServer::start();
2477 let mock = server.mock(|when, then| {
2478 when.method(httpmock::Method::GET).path("/malformed");
2479 then.status(200)
2480 .header("Content-Type", "application/json")
2481 .body("not json");
2482 });
2483
2484 let client = Client::new();
2485 let req: Request = client
2486 .request(Method::GET, format!("{}{}", server.url(""), "/malformed"))
2487 .build()
2488 .unwrap();
2489 let cfg = make_config(&server.url(""));
2490
2491 let resp = http_request::<Dummy>(req, &cfg)
2492 .await
2493 .expect("http_request should succeed even if JSON is bad");
2494
2495 let err = resp
2496 .data()
2497 .await
2498 .expect_err("malformed JSON should turn into ConnectorClientError");
2499
2500 assert!(matches!(err, ConnectorError::ConnectorClientError { .. }));
2501
2502 if let ConnectorError::ConnectorClientError { msg: _, code } = err {
2503 assert_eq!(code, None);
2504 }
2505
2506 mock.assert();
2507 });
2508 }
2509 }
2510
2511 mod send_request {
2512 use anyhow::Result;
2513 use httpmock::prelude::*;
2514 use reqwest::Method;
2515 use serde::Deserialize;
2516 use serde_json::json;
2517 use std::collections::{BTreeMap, HashMap};
2518
2519 use crate::{
2520 common::{models::TimeUnit, utils::send_request},
2521 config::ConfigurationRestApi,
2522 };
2523
2524 use super::TOKIO_SHARED_RT;
2525
2526 #[derive(Deserialize, Debug, PartialEq)]
2527 struct TestResponse {
2528 message: String,
2529 }
2530
/// An unsigned GET with no parameters reaches the endpoint and decodes its JSON body.
#[test]
fn basic_get_request() -> Result<()> {
    TOKIO_SHARED_RT.block_on(async {
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(GET).path("/api/v1/test");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"message": "success"}"#);
        });

        let config = ConfigurationRestApi::builder()
            .api_key("key")
            .api_secret("secret")
            .base_path(server.base_url())
            .compression(false)
            .build()
            .expect("Failed to build configuration");

        // Empty query and body maps; no time unit; unsigned.
        let response = send_request::<TestResponse>(
            &config,
            "/api/v1/test",
            Method::GET,
            BTreeMap::new(),
            BTreeMap::new(),
            None,
            false,
        )
        .await?;

        let TestResponse { message } = response.data().await.unwrap();
        assert_eq!(message, "success");
        Ok(())
    })
}
2570
/// A signed POST with query parameters reaches the endpoint and decodes its JSON body.
#[test]
fn signed_post_request() -> Result<()> {
    TOKIO_SHARED_RT.block_on(async {
        let server = MockServer::start();
        server.mock(|when, then| {
            when.method(POST).path("/api/v3/order");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"message": "order placed"}"#);
        });

        let config = ConfigurationRestApi::builder()
            .api_key("key")
            .api_secret("secret")
            .base_path(server.base_url())
            .compression(false)
            .build()
            .expect("Failed to build configuration");

        let order_params = BTreeMap::from([
            ("symbol".to_string(), json!("ETHUSDT")),
            ("side".to_string(), json!("BUY")),
            ("type".to_string(), json!("MARKET")),
            ("quantity".to_string(), json!("1")),
        ]);

        // The trailing `true` requests signing.
        let response = send_request::<TestResponse>(
            &config,
            "/api/v3/order",
            Method::POST,
            order_params,
            BTreeMap::new(),
            None,
            true,
        )
        .await?;

        let TestResponse { message } = response.data().await.unwrap();
        assert_eq!(message, "order placed");
        Ok(())
    })
}
2614
2615 #[test]
2616 fn signed_post_request_with_body() -> Result<()> {
2617 TOKIO_SHARED_RT.block_on(async {
2618 let server = MockServer::start();
2619
2620 server.mock(|when, then| {
2621 when.method(POST).path("/api/v3/order");
2622 then.status(200)
2623 .header("content-type", "application/json")
2624 .body(r#"{"message": "order placed"}"#);
2625 });
2626
2627 let configuration = ConfigurationRestApi::builder()
2628 .api_key("key")
2629 .api_secret("secret")
2630 .base_path(server.base_url())
2631 .compression(false)
2632 .build()
2633 .expect("Failed to build configuration");
2634
2635 let mut query_params = BTreeMap::new();
2636 query_params.insert("symbol".to_string(), json!("ETHUSDT"));
2637
2638 let mut body_params = BTreeMap::new();
2639 body_params.insert("side".to_string(), json!("BUY"));
2640 body_params.insert("type".to_string(), json!("MARKET"));
2641 body_params.insert("quantity".to_string(), json!("1"));
2642
2643 let result = send_request::<TestResponse>(
2644 &configuration,
2645 "/api/v3/order",
2646 Method::POST,
2647 query_params,
2648 body_params,
2649 None,
2650 true,
2651 )
2652 .await?;
2653
2654 let data = result.data().await.unwrap();
2655 assert_eq!(data.message, "order placed");
2656
2657 Ok(())
2658 })
2659 }
2660
2661 #[test]
2662 fn get_request_with_params() -> Result<()> {
2663 TOKIO_SHARED_RT.block_on(async {
2664 let server = MockServer::start();
2665
2666 server.mock(|when, then| {
2667 when.method(GET)
2668 .path("/api/v1/data")
2669 .query_param("symbol", "BTCUSDT")
2670 .query_param("limit", "10");
2671 then.status(200)
2672 .header("content-type", "application/json")
2673 .body(r#"{"message": "data retrieved"}"#);
2674 });
2675
2676 let configuration = ConfigurationRestApi::builder()
2677 .api_key("key")
2678 .api_secret("secret")
2679 .base_path(server.base_url())
2680 .compression(false)
2681 .build()
2682 .expect("Failed to build configuration");
2683
2684 let mut params = BTreeMap::new();
2685 params.insert("symbol".to_string(), json!("BTCUSDT"));
2686 params.insert("limit".to_string(), json!(10));
2687
2688 let result = send_request::<TestResponse>(
2689 &configuration,
2690 "/api/v1/data",
2691 Method::GET,
2692 params,
2693 BTreeMap::new(),
2694 None,
2695 false,
2696 )
2697 .await?;
2698
2699 let data = result.data().await.unwrap();
2700 assert_eq!(data.message, "data retrieved");
2701
2702 Ok(())
2703 })
2704 }
2705
2706 #[test]
2707 fn invalid_endpoint() {
2708 TOKIO_SHARED_RT.block_on(async {
2709 let server = MockServer::start();
2710
2711 let configuration = ConfigurationRestApi::builder()
2712 .api_key("key")
2713 .api_secret("secret")
2714 .base_path(server.base_url())
2715 .compression(false)
2716 .build()
2717 .expect("Failed to build configuration");
2718
2719 let params = BTreeMap::new();
2720
2721 let result = send_request::<TestResponse>(
2722 &configuration,
2723 "http://invalid",
2724 Method::GET,
2725 params,
2726 BTreeMap::new(),
2727 None,
2728 false,
2729 )
2730 .await;
2731
2732 assert!(result.is_err());
2733 });
2734 }
2735
2736 #[test]
2737 fn missing_signature_on_signed_request() {
2738 TOKIO_SHARED_RT.block_on(async {
2739 let server = MockServer::start();
2740
2741 let configuration = ConfigurationRestApi::builder()
2742 .api_key("key")
2743 .api_secret("secret")
2744 .base_path(server.base_url())
2745 .compression(false)
2746 .build()
2747 .expect("Failed to build configuration");
2748
2749 let mut params = BTreeMap::new();
2750 params.insert("symbol".to_string(), json!("BTCUSDT"));
2751 params.insert("side".to_string(), json!("BUY"));
2752
2753 let result = send_request::<TestResponse>(
2754 &configuration,
2755 "/api/v3/order",
2756 Method::POST,
2757 params,
2758 BTreeMap::new(),
2759 None,
2760 true,
2761 )
2762 .await;
2763
2764 assert!(result.is_err());
2765 });
2766 }
2767
2768 #[test]
2769 fn compression_enabled() -> Result<()> {
2770 TOKIO_SHARED_RT.block_on(async {
2771 let server = MockServer::start();
2772
2773 server.mock(|when, then| {
2774 when.method(GET).path("/api/v1/test");
2775 then.status(200)
2776 .header("content-type", "application/json")
2777 .header("accept-encoding", "gzip, deflate, br")
2778 .body(r#"{"message": "compression enabled"}"#);
2779 });
2780
2781 let configuration = ConfigurationRestApi::builder()
2782 .api_key("key")
2783 .api_secret("secret")
2784 .base_path(server.base_url())
2785 .compression(true)
2786 .build()
2787 .expect("Failed to build configuration");
2788
2789 let params = BTreeMap::new();
2790
2791 let result = send_request::<TestResponse>(
2792 &configuration,
2793 "/api/v1/test",
2794 Method::GET,
2795 params,
2796 BTreeMap::new(),
2797 None,
2798 false,
2799 )
2800 .await?;
2801
2802 let data = result.data().await.unwrap();
2803 assert_eq!(data.message, "compression enabled");
2804
2805 Ok(())
2806 })
2807 }
2808
2809 #[test]
2810 fn get_request_with_time_unit_header() -> Result<()> {
2811 TOKIO_SHARED_RT.block_on(async {
2812 let server = MockServer::start();
2813
2814 server.mock(|when, then| {
2815 when.method(GET)
2816 .path("/api/v1/test")
2817 .header("X-MBX-TIME-UNIT", "MILLISECOND");
2818 then.status(200)
2819 .header("content-type", "application/json")
2820 .body(r#"{"message": "time unit applied"}"#);
2821 });
2822
2823 let configuration = ConfigurationRestApi::builder()
2824 .api_key("key")
2825 .api_secret("secret")
2826 .base_path(server.base_url())
2827 .compression(false)
2828 .time_unit(TimeUnit::Millisecond)
2829 .build()
2830 .expect("Failed to build configuration");
2831
2832 let params = BTreeMap::new();
2833
2834 let result = send_request::<TestResponse>(
2835 &configuration,
2836 "/api/v1/test",
2837 Method::GET,
2838 params,
2839 BTreeMap::new(),
2840 Some(TimeUnit::Millisecond),
2841 false,
2842 )
2843 .await?;
2844
2845 let data = result.data().await.unwrap();
2846 assert_eq!(data.message, "time unit applied");
2847
2848 Ok(())
2849 })
2850 }
2851
2852 #[test]
2853 fn custom_headers_are_sent() -> Result<()> {
2854 TOKIO_SHARED_RT.block_on(async {
2855 let server = MockServer::start();
2856
2857 server.mock(|when, then| {
2858 when.method(GET)
2859 .path("/api/v1/test")
2860 .header("X-My-Test", "all-clear");
2861 then.status(200)
2862 .header("content-type", "application/json")
2863 .body(r#"{"message":"ok"}"#);
2864 });
2865
2866 let mut custom = HashMap::new();
2867 custom.insert("X-My-Test".to_string(), "all-clear".to_string());
2868
2869 let configuration = ConfigurationRestApi::builder()
2870 .api_key("key")
2871 .api_secret("secret")
2872 .base_path(server.base_url())
2873 .compression(false)
2874 .custom_headers(custom)
2875 .build()
2876 .expect("Failed to build configuration");
2877
2878 let params = BTreeMap::new();
2879 let res = send_request::<TestResponse>(
2880 &configuration,
2881 "/api/v1/test",
2882 Method::GET,
2883 params,
2884 BTreeMap::new(),
2885 None,
2886 false,
2887 )
2888 .await?;
2889
2890 let data = res.data().await.unwrap();
2891 assert_eq!(data.message, "ok");
2892
2893 Ok(())
2894 })
2895 }
2896
2897 #[test]
2898 fn custom_header_override_prevention() -> Result<()> {
2899 TOKIO_SHARED_RT.block_on(async {
2900 let server = MockServer::start();
2901
2902 server.mock(|when, then| {
2903 when.method(GET)
2904 .path("/api/v1/test")
2905 .header("content-type", "application/json")
2906 .header("x-mbx-apikey", "key")
2907 .header("X-My-Test", "ok");
2908 then.status(200)
2909 .header("content-type", "application/json")
2910 .body(r#"{"message":"defaults intact"}"#);
2911 });
2912
2913 let mut custom = HashMap::new();
2914 custom.insert("Content-Type".to_string(), "text/plain".to_string());
2915 custom.insert("X-MBX-APIKEY".to_string(), "BAD".to_string());
2916 custom.insert("X-My-Test".to_string(), "ok".to_string());
2917
2918 let configuration = ConfigurationRestApi::builder()
2919 .api_key("key")
2920 .api_secret("secret")
2921 .base_path(server.base_url())
2922 .compression(false)
2923 .custom_headers(custom)
2924 .build()
2925 .expect("Failed to build configuration");
2926
2927 let params = BTreeMap::new();
2928 let res = send_request::<TestResponse>(
2929 &configuration,
2930 "/api/v1/test",
2931 Method::GET,
2932 params,
2933 BTreeMap::new(),
2934 None,
2935 false,
2936 )
2937 .await?;
2938
2939 let data = res.data().await.unwrap();
2940 assert_eq!(data.message, "defaults intact");
2941
2942 Ok(())
2943 })
2944 }
2945
2946 #[test]
2947 fn crlf_in_header_values_are_dropped() -> Result<()> {
2948 TOKIO_SHARED_RT.block_on(async {
2949 let server = MockServer::start();
2950
2951 server.mock(|when, then| {
2952 when.method(GET)
2953 .path("/api/v1/test")
2954 .header("X-Good", "safe");
2955 then.status(200)
2956 .header("content-type", "application/json")
2957 .body(r#"{"message":"clean only"}"#);
2958 });
2959
2960 let mut custom = HashMap::new();
2961 custom.insert("X-Bad".to_string(), "evil\r\ninject".to_string());
2962 custom.insert("X-Good".to_string(), "safe".to_string());
2963
2964 let configuration = ConfigurationRestApi::builder()
2965 .api_key("key")
2966 .api_secret("secret")
2967 .base_path(server.base_url())
2968 .compression(false)
2969 .custom_headers(custom)
2970 .build()
2971 .expect("Failed to build configuration");
2972
2973 let params = BTreeMap::new();
2974 let res = send_request::<TestResponse>(
2975 &configuration,
2976 "/api/v1/test",
2977 Method::GET,
2978 params,
2979 BTreeMap::new(),
2980 None,
2981 false,
2982 )
2983 .await?;
2984
2985 let data = res.data().await.unwrap();
2986 assert_eq!(data.message, "clean only");
2987
2988 Ok(())
2989 })
2990 }
2991 }
2992
2993 mod random_string {
2994 use crate::common::utils::random_string;
2995 use hex;
2996
2997 #[test]
2998 fn length_is_32() {
2999 let s = random_string();
3000 assert_eq!(
3001 s.len(),
3002 32,
3003 "random_string() should be 32 chars, got {}",
3004 s.len()
3005 );
3006 }
3007
3008 #[test]
3009 fn is_valid_lowercase_hex() {
3010 let s = random_string();
3011 assert!(
3012 s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
3013 "random_string() contains invalid hex characters: {s}"
3014 );
3015 }
3016
3017 #[test]
3018 fn decodes_to_16_bytes() {
3019 let s = random_string();
3020 let bytes = hex::decode(&s).expect("random_string() output must be valid hex");
3021 assert_eq!(
3022 bytes.len(),
3023 16,
3024 "hex::decode returned {} bytes",
3025 bytes.len()
3026 );
3027 }
3028
3029 #[test]
3030 fn two_calls_are_different() {
3031 let a = random_string();
3032 let b = random_string();
3033 assert_ne!(
3034 a, b,
3035 "Two calls to random_string() returned the same value: {a}"
3036 );
3037 }
3038 }
3039
3040 mod random_integer {
3041 use crate::common::utils::random_integer;
3042
3043 #[test]
3044 fn is_within_u32_range() {
3045 let n = random_integer();
3046 assert!(
3047 n <= u32::MAX,
3048 "random_integer() should be <= u32::MAX, got {n}"
3049 );
3050 }
3051
3052 #[test]
3053 fn two_calls_can_differ() {
3054 let a = random_integer();
3055 let b = random_integer();
3056 assert_ne!(
3057 a, b,
3058 "Two calls to random_integer() returned the same value: {a}"
3059 );
3060 }
3061 }
3062
3063 mod normalize_stream_id {
3064 use crate::common::utils::{StreamId, normalize_stream_id};
3065 use serde_json::Value;
3066
3067 fn is_lower_hex32(s: &str) -> bool {
3068 s.len() == 32 && s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f'))
3069 }
3070
3071 #[test]
3072 fn valid_hex_string_is_kept() {
3073 let id = "0123456789abcdef0123456789abcdef".to_string();
3074 let out = normalize_stream_id(Some(StreamId::Str(id.clone())), false);
3075
3076 match out {
3077 Value::String(s) => assert_eq!(s, id, "Expected to keep the valid hex id"),
3078 other => panic!("Expected Value::String, got {other:?}"),
3079 }
3080 }
3081
3082 #[test]
3083 fn invalid_hex_string_generates_random_hex() {
3084 let id = "not-hex".to_string();
3085 let out = normalize_stream_id(Some(StreamId::Str(id.clone())), false);
3086
3087 match out {
3088 Value::String(s) => {
3089 assert_eq!(s.len(), 32, "Expected 32-char hex, got {}", s.len());
3090 assert_ne!(s, id, "Expected generated id to differ from input");
3091 assert!(
3092 is_lower_hex32(&s),
3093 "Generated id contains invalid hex characters: {s}"
3094 );
3095 }
3096 other => panic!("Expected Value::String, got {other:?}"),
3097 }
3098 }
3099
3100 #[test]
3101 fn none_generates_random_hex() {
3102 let out = normalize_stream_id(None, false);
3103
3104 match out {
3105 Value::String(s) => {
3106 assert_eq!(s.len(), 32, "Expected 32-char hex, got {}", s.len());
3107 assert!(
3108 is_lower_hex32(&s),
3109 "Generated id contains invalid hex characters: {s}"
3110 );
3111 }
3112 other => panic!("Expected Value::String, got {other:?}"),
3113 }
3114 }
3115
3116 #[test]
3117 fn number_is_kept_when_not_strict() {
3118 let out = normalize_stream_id(Some(StreamId::Number(42)), false);
3119
3120 match out {
3121 Value::Number(n) => {
3122 assert_eq!(n.as_u64(), Some(42), "Expected to keep the numeric id");
3123 }
3124 other => panic!("Expected Value::Number, got {other:?}"),
3125 }
3126 }
3127
3128 #[test]
3129 fn strict_number_forces_number_even_for_valid_hex_string() {
3130 let id = "0123456789abcdef0123456789abcdef".to_string();
3131 let out = normalize_stream_id(Some(StreamId::Str(id)), true);
3132
3133 match out {
3134 Value::Number(n) => {
3135 assert!(
3136 n.as_u64().is_some(),
3137 "Expected unsigned integer JSON number, got {n}"
3138 );
3139 }
3140 other => panic!("Expected Value::Number, got {other:?}"),
3141 }
3142 }
3143
3144 #[test]
3145 fn strict_number_keeps_number_if_provided() {
3146 let out = normalize_stream_id(Some(StreamId::Number(7)), true);
3147
3148 match out {
3149 Value::Number(n) => {
3150 assert_eq!(n.as_u64(), Some(7), "Expected to keep the numeric id");
3151 }
3152 other => panic!("Expected Value::Number, got {other:?}"),
3153 }
3154 }
3155
3156 #[test]
3157 fn strict_number_generates_number_when_none() {
3158 let out = normalize_stream_id(None, true);
3159
3160 match out {
3161 Value::Number(n) => {
3162 assert!(
3163 n.as_u64().is_some(),
3164 "Expected unsigned integer JSON number, got {n}"
3165 );
3166 }
3167 other => panic!("Expected Value::Number, got {other:?}"),
3168 }
3169 }
3170
3171 #[test]
3172 fn strict_number_generates_number_for_invalid_hex_string() {
3173 let out = normalize_stream_id(Some(StreamId::Str("nope".to_string())), true);
3174
3175 match out {
3176 Value::Number(n) => {
3177 assert!(
3178 n.as_u64().is_some(),
3179 "Expected unsigned integer JSON number, got {n}"
3180 );
3181 }
3182 other => panic!("Expected Value::Number, got {other:?}"),
3183 }
3184 }
3185 }
3186 mod remove_empty_value {
3187 use crate::common::utils::remove_empty_value;
3188 use serde_json::{Map, Value};
3189
3190 #[test]
3191 fn filters_out_null_and_empty_strings() {
3192 let entries = vec![
3193 ("key1".to_string(), Value::String("value1".to_string())),
3194 ("key2".to_string(), Value::Null),
3195 ("key3".to_string(), Value::String(String::new())),
3196 ];
3197 let result = remove_empty_value(entries);
3198 assert_eq!(
3199 result.len(),
3200 1,
3201 "expected only one entry, got {}",
3202 result.len()
3203 );
3204 assert_eq!(
3205 result.get("key1"),
3206 Some(&Value::String("value1".to_string()))
3207 );
3208 assert!(!result.contains_key("key2"));
3209 assert!(!result.contains_key("key3"));
3210 }
3211
3212 #[test]
3213 fn retains_other_value_types() {
3214 let entries = vec![
3215 ("bool".to_string(), Value::Bool(true)),
3216 ("num".to_string(), Value::Number(42.into())),
3217 ("arr".to_string(), Value::Array(vec![])),
3218 ("obj".to_string(), Value::Object(Map::default())),
3219 ("nil".to_string(), Value::Null),
3220 ("empty_str".to_string(), Value::String(String::new())),
3221 ];
3222 let result = remove_empty_value(entries);
3223 let keys: Vec<&String> = result.keys().collect();
3224 assert_eq!(keys.len(), 4, "expected 4 entries, got {}", keys.len());
3225 assert!(result.get("bool") == Some(&Value::Bool(true)));
3226 assert!(result.get("num") == Some(&Value::Number(42.into())));
3227 assert!(result.get("arr") == Some(&Value::Array(vec![])));
3228 assert!(result.get("obj") == Some(&Value::Object(Map::default())));
3229 assert!(!result.contains_key("nil"));
3230 assert!(!result.contains_key("empty_str"));
3231 }
3232
3233 #[test]
3234 fn empty_iterator_returns_empty_map() {
3235 let entries: Vec<(String, Value)> = vec![];
3236 let result = remove_empty_value(entries);
3237 assert!(result.is_empty(), "expected an empty map");
3238 }
3239
3240 #[test]
3241 fn keys_are_sorted() {
3242 let entries = vec![
3243 ("c".to_string(), Value::String("foo".to_string())),
3244 ("a".to_string(), Value::String("bar".to_string())),
3245 ("b".to_string(), Value::String("baz".to_string())),
3246 ];
3247 let result = remove_empty_value(entries);
3248 let sorted_keys: Vec<&String> = result.keys().collect();
3249 assert_eq!(
3250 sorted_keys,
3251 [&"a".to_string(), &"b".to_string(), &"c".to_string()]
3252 );
3253 }
3254 }
3255
3256 mod sort_object_params {
3257 use crate::common::utils::sort_object_params;
3258 use serde_json::Value;
3259 use std::collections::BTreeMap;
3260
3261 #[test]
3262 fn sorts_keys() {
3263 let mut params = BTreeMap::new();
3264 params.insert("z".to_string(), Value::String("last".to_string()));
3265 params.insert("a".to_string(), Value::String("first".to_string()));
3266 params.insert("m".to_string(), Value::String("middle".to_string()));
3267
3268 let sorted = sort_object_params(¶ms);
3269 let keys: Vec<&String> = sorted.keys().collect();
3270 assert_eq!(
3271 keys,
3272 [&"a".to_string(), &"m".to_string(), &"z".to_string()],
3273 "Keys should be sorted alphabetically"
3274 );
3275 }
3276
3277 #[test]
3278 fn preserves_values() {
3279 let mut params = BTreeMap::new();
3280 params.insert("one".to_string(), Value::Number(1.into()));
3281 params.insert("two".to_string(), Value::Bool(true));
3282
3283 let sorted = sort_object_params(¶ms);
3284 assert_eq!(sorted.get("one"), Some(&Value::Number(1.into())));
3285 assert_eq!(sorted.get("two"), Some(&Value::Bool(true)));
3286 }
3287
3288 #[test]
3289 fn empty_map_returns_empty() {
3290 let params: BTreeMap<String, Value> = BTreeMap::new();
3291 let sorted = sort_object_params(¶ms);
3292 assert!(sorted.is_empty(), "Expected empty map");
3293 }
3294
3295 #[test]
3296 fn independent_clone() {
3297 let mut params = BTreeMap::new();
3298 params.insert("key".to_string(), Value::String("val".to_string()));
3299
3300 let mut sorted = sort_object_params(¶ms);
3301 sorted.insert("new".to_string(), Value::String("x".to_string()));
3302
3303 assert!(
3304 !params.contains_key("new"),
3305 "Original should not be modified when changing sorted"
3306 );
3307 assert!(
3308 sorted.contains_key("new"),
3309 "Sorted map should reflect its own insertions"
3310 );
3311 }
3312 }
3313
3314 mod normalize_ws_streams_key {
3315 use crate::common::utils::normalize_ws_streams_key;
3316
3317 #[test]
3318 fn returns_empty_for_empty() {
3319 assert_eq!(normalize_ws_streams_key(""), "");
3320 }
3321
3322 #[test]
3323 fn already_normalized_stays_same() {
3324 assert_eq!(normalize_ws_streams_key("streamname"), "streamname");
3325 }
3326
3327 #[test]
3328 fn uppercases_are_lowercased() {
3329 assert_eq!(normalize_ws_streams_key("MyStream"), "mystream");
3330 }
3331
3332 #[test]
3333 fn underscores_are_removed() {
3334 assert_eq!(normalize_ws_streams_key("my_stream_name"), "mystreamname");
3335 }
3336
3337 #[test]
3338 fn hyphens_are_removed() {
3339 assert_eq!(normalize_ws_streams_key("my-stream-name"), "mystreamname");
3340 }
3341
3342 #[test]
3343 fn mixed_underscores_and_hyphens_and_case() {
3344 let input = "Mixed_Case-Stream_Name";
3345 let expected = "mixedcasestreamname";
3346 assert_eq!(normalize_ws_streams_key(input), expected);
3347 }
3348
3349 #[test]
3350 fn retains_other_punctuation() {
3351 assert_eq!(normalize_ws_streams_key("stream.name!"), "stream.name!");
3352 }
3353 }
3354
3355 mod replace_websocket_streams_placeholders {
3356 use crate::common::utils::replace_websocket_streams_placeholders;
3357 use std::collections::HashMap;
3358
3359 #[test]
3360 fn empty_string_unchanged() {
3361 let vars: HashMap<&str, &str> = HashMap::new();
3362 assert_eq!(replace_websocket_streams_placeholders("", &vars), "");
3363 }
3364
3365 #[test]
3366 fn unknown_placeholder_becomes_empty() {
3367 let vars: HashMap<&str, &str> = HashMap::new();
3368 assert_eq!(replace_websocket_streams_placeholders("<foo>", &vars), "");
3369 }
3370
3371 #[test]
3372 fn leading_slash_symbol_lowercases_head() {
3373 let mut vars = HashMap::new();
3374 vars.insert("symbol", "BTC");
3375 assert_eq!(
3376 replace_websocket_streams_placeholders("/<symbol>", &vars),
3377 "btc"
3378 );
3379 }
3380
3381 #[test]
3382 fn no_lowercase_without_slash() {
3383 let mut vars = HashMap::new();
3384 vars.insert("symbol", "BTC");
3385 assert_eq!(
3386 replace_websocket_streams_placeholders("<symbol>", &vars),
3387 "BTC"
3388 );
3389 }
3390
3391 #[test]
3392 fn multiple_placeholders_mid_preserve_ats() {
3393 let mut vars = HashMap::new();
3394 vars.insert("symbol", "BNBUSDT");
3395 vars.insert("levels", "10");
3396 vars.insert("updateSpeed", "1000ms");
3397 let out = replace_websocket_streams_placeholders(
3398 "/<symbol>@depth<levels>@<updateSpeed>",
3399 &vars,
3400 );
3401 assert_eq!(out, "bnbusdt@depth10@1000ms");
3402 }
3403
3404 #[test]
3405 fn trailing_at_removed_when_missing_var() {
3406 let mut vars = HashMap::new();
3407 vars.insert("symbol", "BNBUSDT");
3408 vars.insert("levels", "10");
3409 let out = replace_websocket_streams_placeholders(
3410 "/<symbol>@depth<levels>@<updateSpeed>",
3411 &vars,
3412 );
3413 assert_eq!(out, "bnbusdt@depth10");
3414 }
3415
3416 #[test]
3417 fn custom_key_normalization_and_value() {
3418 let mut vars = HashMap::new();
3419 vars.insert("my-stream_key", "Value");
3420 assert_eq!(
3421 replace_websocket_streams_placeholders("<My_Stream-Key>", &vars),
3422 "Value"
3423 );
3424 }
3425
3426 #[test]
3427 fn text_surrounding_placeholders_intact() {
3428 let mut vars = HashMap::new();
3429 vars.insert("symbol", "ABC");
3430 let input = "pre-<symbol>-post";
3431 assert_eq!(
3432 replace_websocket_streams_placeholders(input, &vars),
3433 "pre-ABC-post"
3434 );
3435 }
3436 }
3437
3438 mod build_websocket_api_message {
3439 use serde_json::{Value, json};
3440 use std::collections::BTreeMap;
3441
3442 use crate::{
3443 common::{
3444 utils::{ID_REGEX, build_websocket_api_message, remove_empty_value},
3445 websocket::WebsocketMessageSendOptions,
3446 },
3447 config::ConfigurationWebsocketApi,
3448 };
3449
3450 fn make_config() -> ConfigurationWebsocketApi {
3451 ConfigurationWebsocketApi::builder()
3452 .api_key("api-key".to_string())
3453 .api_secret("api-secret".to_string())
3454 .build()
3455 .unwrap()
3456 }
3457
3458 #[test]
3459 fn no_auth_or_sign_with_skip_auth() {
3460 let mut payload = BTreeMap::new();
3461 payload.insert("foo".into(), Value::String("bar".into()));
3462 let cfg = make_config();
3463
3464 let (id, req) = build_websocket_api_message(
3465 &cfg,
3466 "method",
3467 payload.clone(),
3468 &WebsocketMessageSendOptions {
3469 with_api_key: true,
3470 is_signed: true,
3471 ..Default::default()
3472 },
3473 true,
3474 );
3475
3476 assert!(ID_REGEX.is_match(&id));
3477 assert_eq!(req["method"], "method");
3478 assert_eq!(req["params"]["foo"], "bar");
3479 assert!(req["params"].get("apiKey").is_none());
3480 assert!(req["params"].get("signature").is_none());
3481 assert!(req["params"]["timestamp"].is_number());
3482 }
3483
3484 #[test]
3485 fn only_api_key_when_not_signed() {
3486 let cfg = make_config();
3487
3488 let (id, req) = build_websocket_api_message(
3489 &cfg,
3490 "method",
3491 BTreeMap::new(),
3492 &WebsocketMessageSendOptions {
3493 with_api_key: true,
3494 is_signed: false,
3495 ..Default::default()
3496 },
3497 false,
3498 );
3499
3500 assert!(ID_REGEX.is_match(&id));
3501 assert_eq!(req["method"], "method");
3502 assert_eq!(req["params"]["apiKey"], "api-key");
3503 assert!(req["params"].get("timestamp").is_none());
3504 assert!(req["params"].get("signature").is_none());
3505 }
3506
3507 #[test]
3508 fn signed_includes_timestamp_and_signature() {
3509 let mut payload = BTreeMap::new();
3510 payload.insert("foo".into(), Value::String("bar".into()));
3511 let cfg = make_config();
3512
3513 let (id, req) = build_websocket_api_message(
3514 &cfg,
3515 "method",
3516 payload.clone(),
3517 &WebsocketMessageSendOptions {
3518 with_api_key: true,
3519 is_signed: true,
3520 ..Default::default()
3521 },
3522 false,
3523 );
3524
3525 assert!(ID_REGEX.is_match(&id));
3526 assert_eq!(req["method"], "method");
3527
3528 let params = &req["params"];
3529 assert_eq!(params["apiKey"], "api-key");
3530
3531 let timestamp = params["timestamp"].as_i64().unwrap();
3532 assert!(timestamp > 0, "timestamp should not be empty");
3533
3534 let sig = params["signature"].as_str().unwrap();
3535 assert!(!sig.is_empty(), "signature should not be empty");
3536 }
3537
3538 #[test]
3539 fn respects_provided_valid_id_and_removes_from_params() {
3540 let mut payload = BTreeMap::new();
3541 let custom = "0123456789abcdef0123456789abcdef".to_string();
3542 payload.insert("id".into(), Value::String(custom.clone()));
3543 payload.insert("foo".into(), Value::Number(123.into()));
3544
3545 let cfg = make_config();
3546 let (id, req) = build_websocket_api_message(
3547 &cfg,
3548 "method",
3549 payload.clone(),
3550 &WebsocketMessageSendOptions::default(),
3551 true,
3552 );
3553
3554 assert_eq!(id, custom);
3555 assert!(req["params"].get("id").is_none());
3556 assert_eq!(req["params"]["foo"], 123);
3557 }
3558
3559 #[test]
3560 fn skip_auth_blocks_api_and_signature_but_keeps_timestamp() {
3561 let mut payload = BTreeMap::new();
3562 payload.insert("foo".into(), Value::String("bar".into()));
3563 let cfg = make_config();
3564
3565 let (_id, req) = build_websocket_api_message(
3566 &cfg,
3567 "method",
3568 payload.clone(),
3569 &WebsocketMessageSendOptions {
3570 with_api_key: true,
3571 is_signed: true,
3572 ..Default::default()
3573 },
3574 true,
3575 );
3576
3577 let p = &req["params"];
3578 assert_eq!(p["foo"], "bar");
3579 assert!(p.get("apiKey").is_none());
3580 assert!(p.get("signature").is_none());
3581 assert!(p["timestamp"].is_number());
3582 }
3583
3584 #[test]
3585 fn random_id_changes_each_call() {
3586 let cfg = make_config();
3587 let (id1, _) = build_websocket_api_message(
3588 &cfg,
3589 "method",
3590 BTreeMap::new(),
3591 &WebsocketMessageSendOptions::default(),
3592 true,
3593 );
3594 let (id2, _) = build_websocket_api_message(
3595 &cfg,
3596 "method",
3597 BTreeMap::new(),
3598 &WebsocketMessageSendOptions::default(),
3599 true,
3600 );
3601 assert!(ID_REGEX.is_match(&id1));
3602 assert!(ID_REGEX.is_match(&id2));
3603 assert_ne!(id1, id2, "IDs should be random and not equal");
3604 }
3605
3606 #[test]
3607 fn null_and_empty_values_are_stripped() {
3608 let mut payload = BTreeMap::new();
3609 payload.insert("a".into(), Value::Null);
3610 payload.insert("b".into(), Value::String(String::new()));
3611 payload.insert("c".into(), Value::String("ok".into()));
3612
3613 let cleaned = remove_empty_value(payload.clone());
3614 assert!(!cleaned.contains_key("a"), "Null should be stripped");
3615 assert!(
3616 !cleaned.contains_key("b"),
3617 "Empty string should be stripped"
3618 );
3619 assert!(cleaned.contains_key("c"), "Non-empty string should be kept");
3620
3621 let cfg = make_config();
3622 let (_id, req) = build_websocket_api_message(
3623 &cfg,
3624 "method",
3625 payload,
3626 &WebsocketMessageSendOptions::default(),
3627 true,
3628 );
3629 let params = &req["params"];
3630 assert!(params.get("a").is_none(), "`a` should not appear");
3631 assert!(params.get("b").is_none(), "`b` should not appear");
3632 assert_eq!(params["c"], "ok", "`c` should be present with value \"ok\"");
3633 }
3634
3635 #[test]
3636 fn provided_invalid_id_gets_replaced() {
3637 let mut payload = BTreeMap::new();
3638 payload.insert("id".into(), Value::String("not-hex-32-chars".into()));
3639 let cfg = make_config();
3640 let (id, _req) = build_websocket_api_message(
3641 &cfg,
3642 "method",
3643 payload,
3644 &WebsocketMessageSendOptions::default(),
3645 true,
3646 );
3647
3648 assert!(ID_REGEX.is_match(&id));
3649 assert_ne!(id, "not-hex-32-chars");
3650 }
3651
3652 #[test]
3653 fn sign_only_includes_api_key_even_when_with_api_key_false() {
3654 let mut payload = BTreeMap::new();
3655 payload.insert("x".into(), json!(1));
3656
3657 let cfg = make_config();
3658 let (_id, req) = build_websocket_api_message(
3659 &cfg,
3660 "method",
3661 payload,
3662 &WebsocketMessageSendOptions {
3663 with_api_key: false,
3664 is_signed: true,
3665 ..Default::default()
3666 },
3667 false,
3668 );
3669 let params = &req["params"];
3670
3671 assert_eq!(params["apiKey"], "api-key");
3672 assert!(params["timestamp"].is_number());
3673 assert!(params["signature"].is_string());
3674 }
3675
3676 #[test]
3677 fn skip_auth_false_without_any_auth_flags() {
3678 let cfg = make_config();
3679 let (_id, req) = build_websocket_api_message(
3680 &cfg,
3681 "method",
3682 BTreeMap::new(),
3683 &WebsocketMessageSendOptions {
3684 with_api_key: false,
3685 is_signed: false,
3686 ..Default::default()
3687 },
3688 false,
3689 );
3690 let params = &req["params"];
3691 assert!(params.as_object().unwrap().is_empty());
3692 }
3693 }
3694}