1use anyhow::{Context, Result, bail};
2use base64::{Engine as _, engine::general_purpose};
3use ed25519_dalek::Signer as Ed25519Signer;
4use ed25519_dalek::SigningKey;
5use ed25519_dalek::pkcs8::DecodePrivateKey;
6use flate2::read::GzDecoder;
7use hex;
8use hmac::{Hmac, Mac};
9use http::HeaderMap;
10use http::header::ACCEPT_ENCODING;
11use once_cell::sync::OnceCell;
12use openssl::{hash::MessageDigest, pkey::PKey, sign::Signer as OpenSslSigner};
13use rand::RngCore;
14use regex::Captures;
15use regex::Regex;
16use reqwest::Client;
17use reqwest::Proxy;
18use reqwest::{Method, Request};
19use serde::de::DeserializeOwned;
20use serde_json::{Value, json};
21use sha2::Sha256;
22use std::fmt::Display;
23use std::hash::BuildHasher;
24use std::sync::LazyLock;
25use std::{
26 collections::BTreeMap,
27 collections::HashMap,
28 fs,
29 io::Read,
30 path::Path,
31 time::Duration,
32 time::{SystemTime, UNIX_EPOCH},
33};
34use tokio::time::sleep;
35use tracing::info;
36use url::{Url, form_urlencoded::Serializer};
37
38use super::config::HttpAgent;
39use super::config::ProxyConfig;
40use super::config::{ConfigurationRestApi, PrivateKey};
41use super::errors::ConnectorError;
42use super::models::TimeUnit;
43use super::models::{Interval, RateLimitType, RestApiRateLimit, RestApiResponse};
44
45static PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(@)?<([^>]+)>").unwrap());
46
47#[derive(Debug, Default, Clone)]
62pub struct SignatureGenerator {
63 api_secret: Option<String>,
64 private_key: Option<PrivateKey>,
65 private_key_passphrase: Option<String>,
66 raw_key_data: OnceCell<String>,
67 key_object: OnceCell<PKey<openssl::pkey::Private>>,
68 ed25519_signing_key: OnceCell<SigningKey>,
69}
70
71impl SignatureGenerator {
72 #[must_use]
73 pub fn new(
74 api_secret: Option<String>,
75 private_key: Option<PrivateKey>,
76 private_key_passphrase: Option<String>,
77 ) -> Self {
78 SignatureGenerator {
79 api_secret,
80 private_key,
81 private_key_passphrase,
82 raw_key_data: OnceCell::new(),
83 key_object: OnceCell::new(),
84 ed25519_signing_key: OnceCell::new(),
85 }
86 }
87
88 fn get_raw_key_data(&self) -> Result<&String> {
103 self.raw_key_data.get_or_try_init(|| {
104 let pk = self
105 .private_key
106 .as_ref()
107 .ok_or_else(|| anyhow::anyhow!("No private_key provided"))?;
108 match pk {
109 PrivateKey::File(path) => {
110 if Path::new(path).exists() {
111 fs::read_to_string(path)
112 .with_context(|| format!("Failed to read private key file: {path}"))
113 } else {
114 Err(anyhow::anyhow!("Private key file does not exist: {}", path))
115 }
116 }
117 PrivateKey::Raw(bytes) => Ok(String::from_utf8_lossy(bytes).to_string()),
118 }
119 })
120 }
121
122 fn get_key_object(&self) -> Result<&PKey<openssl::pkey::Private>> {
137 self.key_object.get_or_try_init(|| {
138 let key_data = self.get_raw_key_data()?;
139 if let Some(pass) = self.private_key_passphrase.as_ref() {
140 PKey::private_key_from_pem_passphrase(key_data.as_bytes(), pass.as_bytes())
141 .context("Failed to parse private key with passphrase")
142 } else {
143 PKey::private_key_from_pem(key_data.as_bytes())
144 .context("Failed to parse private key")
145 }
146 })
147 }
148
149 fn get_ed25519_signing_key(&self) -> Result<&SigningKey> {
162 self.ed25519_signing_key.get_or_try_init(|| {
163 let key_data = self.get_raw_key_data()?;
164 let b64 = key_data
165 .lines()
166 .filter(|l| !l.starts_with("-----"))
167 .collect::<String>();
168 let der = general_purpose::STANDARD
169 .decode(b64)
170 .context("Failed to base64 decode Ed25519 PEM")?;
171 SigningKey::from_pkcs8_der(&der)
172 .map_err(|e| anyhow::anyhow!("Failed to parse Ed25519 key: {}", e))
173 })
174 }
175
176 pub fn get_signature(&self, query_params: &BTreeMap<String, Value>) -> Result<String> {
199 let params = build_query_string(query_params)?;
200
201 if let Some(secret) = self.api_secret.as_ref() {
202 if self.private_key.is_none() {
203 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
204 .context("HMAC key initialization failed")?;
205 mac.update(params.as_bytes());
206 let result = mac.finalize().into_bytes();
207 return Ok(hex::encode(result));
208 }
209 }
210
211 if self.private_key.is_some() {
212 let key_obj = self.get_key_object()?;
213 match key_obj.id() {
214 openssl::pkey::Id::RSA => {
215 let mut signer = OpenSslSigner::new(MessageDigest::sha256(), key_obj)
216 .context("Failed to create RSA signer")?;
217 signer
218 .update(params.as_bytes())
219 .context("Failed to update RSA signer")?;
220 let sig = signer.sign_to_vec().context("RSA signing failed")?;
221 return Ok(general_purpose::STANDARD.encode(sig));
222 }
223 openssl::pkey::Id::ED25519 => {
224 let signing_key = self.get_ed25519_signing_key()?;
225 let signature = signing_key.sign(params.as_bytes());
226 return Ok(general_purpose::STANDARD.encode(signature.to_bytes()));
227 }
228 other => {
229 return Err(anyhow::anyhow!(
230 "Unsupported private key type: {:?}. Must be RSA or ED25519.",
231 other
232 ));
233 }
234 }
235 }
236
237 Err(anyhow::anyhow!(
238 "Either 'api_secret' or 'private_key' must be provided for signed requests."
239 ))
240 }
241}
242
243#[must_use]
266pub fn build_client(
267 timeout: u64,
268 keep_alive: bool,
269 proxy: Option<&ProxyConfig>,
270 agent: Option<HttpAgent>,
271) -> Client {
272 let builder = Client::builder().timeout(Duration::from_millis(timeout));
273
274 let mut builder = if keep_alive {
275 builder
276 } else {
277 builder.pool_idle_timeout(Some(Duration::from_secs(0)))
278 };
279
280 if let Some(proxy_conf) = proxy {
281 let protocol = proxy_conf
282 .protocol
283 .clone()
284 .unwrap_or_else(|| "http".to_string());
285 let proxy_url = format!("{}://{}:{}", protocol, proxy_conf.host, proxy_conf.port);
286 let mut proxy_builder = Proxy::all(&proxy_url).expect("Failed to create proxy from URL");
287 if let Some(auth) = &proxy_conf.auth {
288 proxy_builder = proxy_builder.basic_auth(&auth.username, &auth.password);
289 }
290 builder = builder.proxy(proxy_builder);
291 }
292
293 if let Some(HttpAgent(agent_fn)) = agent {
294 builder = (agent_fn)(builder);
295 }
296
297 info!("Client builder {:?}", builder);
298
299 builder.build().expect("Failed to build reqwest client")
300}
301
302#[must_use]
325pub fn build_user_agent(product: &str) -> String {
326 format!(
327 "{}/{}/{} (Rust/{}; {}; {})",
328 env!("CARGO_PKG_NAME"),
329 product,
330 env!("CARGO_PKG_VERSION"),
331 env!("RUSTC_VERSION"),
332 std::env::consts::OS,
333 std::env::consts::ARCH,
334 )
335}
336
337pub fn validate_time_unit(time_unit: &str) -> Result<Option<&str>, anyhow::Error> {
365 match time_unit {
366 "" => Ok(None),
367 "MILLISECOND" | "MICROSECOND" | "millisecond" | "microsecond" => Ok(Some(time_unit)),
368 _ => Err(anyhow::anyhow!(
369 "time_unit must be either 'MILLISECOND' or 'MICROSECOND'"
370 )),
371 }
372}
373
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
#[must_use]
pub fn get_timestamp() -> u128 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_millis()
}
397
398pub async fn delay(ms: u64) {
410 sleep(Duration::from_millis(ms)).await;
411}
412
413pub fn build_query_string(params: &BTreeMap<String, Value>) -> Result<String, anyhow::Error> {
432 let mut segments = Vec::with_capacity(params.len());
433
434 for (key, value) in params {
435 match value {
436 Value::Null => {}
437 Value::String(s) => {
438 let mut ser = Serializer::new(String::new());
439 ser.append_pair(key, s);
440 segments.push(ser.finish());
441 }
442 Value::Bool(b) => {
443 let val = b.to_string();
444 let mut ser = Serializer::new(String::new());
445 ser.append_pair(key, &val);
446 segments.push(ser.finish());
447 }
448 Value::Number(n) => {
449 let val = n.to_string();
450 let mut ser = Serializer::new(String::new());
451 ser.append_pair(key, &val);
452 segments.push(ser.finish());
453 }
454 Value::Array(arr)
455 if arr
456 .iter()
457 .all(|v| matches!(v, Value::String(_) | Value::Bool(_) | Value::Number(_))) =>
458 {
459 let mut parts = Vec::with_capacity(arr.len());
460 for v in arr {
461 match v {
462 Value::String(s) => parts.push(s.clone()),
463 Value::Bool(b) => parts.push(b.to_string()),
464 Value::Number(n) => parts.push(n.to_string()),
465 _ => unreachable!(),
466 }
467 }
468 segments.push(format!("{}={}", key, parts.join(",")));
469 }
470 Value::Array(arr) => {
471 let json =
472 serde_json::to_string(arr).context("Failed to JSON-serialize nested array")?;
473 let mut ser = Serializer::new(String::new());
474 ser.append_pair(key, &json);
475 segments.push(ser.finish());
476 }
477 Value::Object(_) => {
478 bail!("Cannot serialize object for key `{}` in query params", key);
479 }
480 }
481 }
482
483 Ok(segments.join("&"))
484}
485
486#[must_use]
494pub fn should_retry_request(
495 error: &reqwest::Error,
496 method: Option<&str>,
497 retries_left: Option<usize>,
498) -> bool {
499 let method = method.unwrap_or("");
500 let is_retriable_method =
501 method.eq_ignore_ascii_case("GET") || method.eq_ignore_ascii_case("DELETE");
502
503 let status = error.status().map_or(0, |s| s.as_u16());
504 let is_retriable_status = [500, 502, 503, 504].contains(&status);
505
506 let retries_left = retries_left.unwrap_or(0);
507 retries_left > 0 && is_retriable_method && (is_retriable_status || error.status().is_none())
508}
509
510#[must_use]
535pub fn parse_rate_limit_headers<S>(headers: &HashMap<String, String, S>) -> Vec<RestApiRateLimit>
536where
537 S: BuildHasher,
538{
539 let mut rate_limits = Vec::new();
540 let re = Regex::new(r"x-mbx-(used-weight|order-count)-(\d+)([smhd])").unwrap();
541 for (key, value) in headers {
542 let normalized_key = key.to_lowercase();
543 if normalized_key.starts_with("x-mbx-used-weight-")
544 || normalized_key.starts_with("x-mbx-order-count-")
545 {
546 if let Some(caps) = re.captures(&normalized_key) {
547 let interval_num: u32 = caps.get(2).unwrap().as_str().parse().unwrap_or(0);
548 let interval_letter = caps.get(3).unwrap().as_str().to_uppercase();
549 let interval = match interval_letter.as_str() {
550 "S" => Interval::Second,
551 "M" => Interval::Minute,
552 "H" => Interval::Hour,
553 "D" => Interval::Day,
554 _ => continue,
555 };
556 let count: u32 = value.parse().unwrap_or(0);
557 let rate_limit_type = if normalized_key.starts_with("x-mbx-used-weight-") {
558 RateLimitType::RequestWeight
559 } else {
560 RateLimitType::Orders
561 };
562 rate_limits.push(RestApiRateLimit {
563 rate_limit_type,
564 interval,
565 interval_num,
566 count,
567 retry_after: headers.get("retry-after").and_then(|v| v.parse().ok()),
568 });
569 }
570 }
571 }
572 rate_limits
573}
574
575pub async fn http_request<T: DeserializeOwned + Send + 'static>(
605 req: Request,
606 configuration: &ConfigurationRestApi,
607) -> Result<RestApiResponse<T>, ConnectorError> {
608 let client = &configuration.client;
609 let retries = configuration.retries as usize;
610 let backoff = configuration.backoff;
611 let mut attempt = 0;
612
613 loop {
614 let req_clone = req
615 .try_clone()
616 .context("Failed to clone request")
617 .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?;
618 match client.execute(req_clone).await {
619 Ok(response) => {
620 let status = response.status();
621 let headers_map: HashMap<String, String> = response
622 .headers()
623 .iter()
624 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
625 .collect();
626
627 let raw_bytes = match response.bytes().await {
628 Ok(b) => b,
629 Err(e) => {
630 attempt += 1;
631 if attempt <= retries {
632 continue;
633 }
634 return Err(ConnectorError::ConnectorClientError(format!(
635 "Failed to get response bytes: {e}"
636 )));
637 }
638 };
639
640 let content = if headers_map
641 .get("content-encoding")
642 .is_some_and(|enc| enc.to_lowercase().contains("gzip"))
643 {
644 let mut decoder = GzDecoder::new(&raw_bytes[..]);
645 let mut decompressed = String::new();
646 decoder
647 .read_to_string(&mut decompressed)
648 .context("Failed to decompress gzip response")
649 .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?;
650 decompressed
651 } else {
652 String::from_utf8(raw_bytes.to_vec())
653 .context("Failed to convert response to UTF-8")
654 .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?
655 };
656
657 let rate_limits = parse_rate_limit_headers(&headers_map);
658
659 if status.is_client_error() || status.is_server_error() {
660 let error_msg = serde_json::from_str::<serde_json::Value>(&content)
661 .ok()
662 .and_then(|v| {
663 v.get("msg")
664 .and_then(|m| m.as_str())
665 .map(std::string::ToString::to_string)
666 })
667 .unwrap_or_else(|| content.clone());
668
669 match status.as_u16() {
670 400 => return Err(ConnectorError::BadRequestError(error_msg)),
671 401 => return Err(ConnectorError::UnauthorizedError(error_msg)),
672 403 => return Err(ConnectorError::ForbiddenError(error_msg)),
673 404 => return Err(ConnectorError::NotFoundError(error_msg)),
674 418 => return Err(ConnectorError::RateLimitBanError(error_msg)),
675 429 => return Err(ConnectorError::TooManyRequestsError(error_msg)),
676 s if (500..600).contains(&s) => {
677 return Err(ConnectorError::ServerError {
678 msg: format!("Server error: {s}"),
679 status_code: Some(s),
680 });
681 }
682 _ => return Err(ConnectorError::ConnectorClientError(error_msg)),
683 }
684 }
685
686 let raw = content.clone();
687 return Ok(RestApiResponse {
688 data_fn: Box::new(move || {
689 Box::pin(async move {
690 let parsed: T = serde_json::from_str(&raw)
691 .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?;
692 Ok(parsed)
693 })
694 }),
695 status: status.as_u16(),
696 headers: headers_map,
697 rate_limits: if rate_limits.is_empty() {
698 None
699 } else {
700 Some(rate_limits)
701 },
702 });
703 }
704 Err(e) => {
705 attempt += 1;
706 if should_retry_request(&e, Some(req.method().as_str()), Some(retries - attempt)) {
707 delay(backoff * attempt as u64).await;
708 continue;
709 }
710 return Err(ConnectorError::ConnectorClientError(format!(
711 "HTTP request failed: {e}"
712 )));
713 }
714 }
715 }
716}
717
718pub async fn send_request<T: DeserializeOwned + Send + 'static>(
744 configuration: &ConfigurationRestApi,
745 endpoint: &str,
746 method: Method,
747 mut params: BTreeMap<String, Value>,
748 time_unit: Option<TimeUnit>,
749 is_signed: bool,
750) -> anyhow::Result<RestApiResponse<T>> {
751 let base = configuration.base_path.as_deref().unwrap_or("");
752 let full_url = reqwest::Url::parse(base)
753 .and_then(|u| u.join(endpoint))
754 .context("Failed to join base URL and endpoint")?
755 .to_string();
756
757 if is_signed {
758 let timestamp = get_timestamp();
759 params.insert("timestamp".to_string(), json!(timestamp));
760 let signature = configuration.signature_gen.get_signature(¶ms)?;
761 params.insert("signature".to_string(), Value::String(signature));
762 }
763
764 let mut url = Url::parse(&full_url)?;
765 {
766 let mut pairs = url.query_pairs_mut();
767 for (key, value) in ¶ms {
768 let val_str = match value {
769 Value::String(s) => s.clone(),
770 _ => value.to_string(),
771 };
772 pairs.append_pair(key, &val_str);
773 }
774 }
775
776 let mut headers = HeaderMap::new();
777 headers.insert("Content-Type", "application/json".parse().unwrap());
778 headers.insert("User-Agent", configuration.user_agent.parse().unwrap());
779 if let Some(api_key) = &configuration.api_key {
780 headers.insert("X-MBX-APIKEY", api_key.parse().unwrap());
781 }
782
783 if configuration.compression {
784 headers.insert(ACCEPT_ENCODING, "gzip, deflate, br".parse().unwrap());
785 }
786
787 let time_unit_to_apply = time_unit.or(configuration.time_unit);
788 if let Some(time_unit) = time_unit_to_apply {
789 headers.insert("X-MBX-TIME-UNIT", time_unit.as_upper_str().parse()?);
790 }
791
792 let req_builder = configuration.client.request(method, url).headers(headers);
793 let req = req_builder.build()?;
794
795 Ok(http_request::<T>(req, configuration).await?)
796}
797
798#[must_use]
807pub fn random_string() -> String {
808 let mut buf = [0u8; 16];
809 rand::thread_rng().fill_bytes(&mut buf);
810 hex::encode(buf)
811}
812
813pub fn remove_empty_value<I>(entries: I) -> BTreeMap<String, Value>
835where
836 I: IntoIterator<Item = (String, Value)>,
837{
838 entries
839 .into_iter()
840 .filter(|(_, value)| match value {
841 Value::Null => false,
842 Value::String(s) if s.is_empty() => false,
843 _ => true,
844 })
845 .collect()
846}
847
848#[must_use]
869pub fn sort_object_params(params: &BTreeMap<String, Value>) -> BTreeMap<String, Value> {
870 let mut sorted = BTreeMap::new();
871 for (k, v) in params {
872 sorted.insert(k.clone(), v.clone());
873 }
874 sorted
875}
876
/// Normalises a stream placeholder/variable name for lookup: lowercases the whole string,
/// then removes every `_` and `-`.
fn normalize_ws_streams_key(key: &str) -> String {
    key.to_lowercase()
        .chars()
        .filter(|c| *c != '_' && *c != '-')
        .collect()
}
889
890pub fn replace_websocket_streams_placeholders<V, S>(
915 input: &str,
916 variables: &HashMap<&str, V, S>,
917) -> String
918where
919 V: Display,
920 S: BuildHasher,
921{
922 let original = input;
923
924 let body = original.strip_prefix('/').unwrap_or(original);
926
927 let normalized: HashMap<String, String> = variables
929 .iter()
930 .map(|(k, v)| (normalize_ws_streams_key(k), v.to_string()))
931 .collect();
932
933 let replaced = PLACEHOLDER_RE
935 .replace_all(body, |caps: &Captures| {
936 let prefix = caps.get(1).map_or("", |m| m.as_str());
937 let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
938 let val = normalized.get(&key).cloned().unwrap_or_default();
939 format!("{prefix}{val}")
940 })
941 .into_owned();
942
943 let stripped = replaced.trim_end_matches('@').to_string();
945
946 let should_lower_head =
949 original.starts_with('/') && PLACEHOLDER_RE.find(body).is_some_and(|m| m.start() == 0);
950
951 let result = if should_lower_head {
953 if let Some(caps) = PLACEHOLDER_RE.captures(body) {
954 let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
955 let first_val = normalized.get(&key).cloned().unwrap_or_default();
956 if stripped.starts_with(&first_val) {
957 let tail = &stripped[first_val.len()..];
958 format!("{}{}", first_val.to_lowercase(), tail)
959 } else {
960 stripped.clone()
961 }
962 } else {
963 stripped.clone()
964 }
965 } else {
966 stripped.clone()
967 };
968
969 result
970}
971
972#[cfg(test)]
973mod tests {
974 use crate::TOKIO_SHARED_RT;
975
976 mod build_client {
977 use std::{
978 sync::{Arc, Mutex},
979 time::{Duration, Instant},
980 };
981
982 use reqwest::ClientBuilder;
983
984 use crate::{
985 common::utils::build_client,
986 config::{HttpAgent, ProxyAuth, ProxyConfig},
987 };
988
989 use super::TOKIO_SHARED_RT;
990
991 #[test]
992 fn enforces_timeout() {
993 TOKIO_SHARED_RT.block_on(async {
994 let client = build_client(100, true, None, None);
995 let start = Instant::now();
996 let res = client.get("http://10.255.255.1").send().await;
997 assert!(
998 res.is_err(),
999 "expected an error (timeout or connect) but got {res:?}"
1000 );
1001 let elapsed = start.elapsed();
1002 assert!(
1003 elapsed < Duration::from_millis(500),
1004 "timed out too slowly: {elapsed:?}"
1005 );
1006 });
1007 }
1008
1009 #[test]
1010 fn builds_with_keep_alive_disabled() {
1011 let client = build_client(200, false, None, None);
1012 let _: reqwest::Client = client;
1013 }
1014
1015 #[test]
1016 #[should_panic(expected = "Failed to create proxy from URL")]
1017 fn invalid_proxy_url_panics() {
1018 let bad_proxy = ProxyConfig {
1019 protocol: Some("http".to_string()),
1020 host: String::new(),
1021 port: 8080,
1022 auth: None,
1023 };
1024 let _ = build_client(1_000, true, Some(&bad_proxy), None);
1025 }
1026
1027 #[test]
1028 fn builds_with_proxy_and_auth() {
1029 let proxy = ProxyConfig {
1030 protocol: Some("https".to_string()),
1031 host: "127.0.0.1".to_string(),
1032 port: 3128,
1033 auth: Some(ProxyAuth {
1034 username: "alice".to_string(),
1035 password: "secret".to_string(),
1036 }),
1037 };
1038 let client = build_client(2_000, true, Some(&proxy), None);
1039 let _: reqwest::Client = client;
1040 }
1041
1042 #[test]
1043 fn custom_agent_invoked() {
1044 let called = Arc::new(Mutex::new(false));
1045 let called_clone = Arc::clone(&called);
1046
1047 let agent = HttpAgent(Arc::new(move |builder: ClientBuilder| {
1048 *called_clone.lock().unwrap() = true;
1049 builder
1050 }));
1051
1052 let client = build_client(1_000, true, None, Some(agent));
1053 assert!(*called.lock().unwrap(), "agent closure wasn’t invoked");
1054 let _: reqwest::Client = client;
1055 }
1056 }
1057
1058 mod build_user_agent {
1059 use crate::common::utils::build_user_agent;
1060
1061 #[test]
1062 fn build_user_agent_contains_crate_product_and_rust_info() {
1063 let product = "product";
1064 let user_agent = build_user_agent(product);
1065
1066 let name = env!("CARGO_PKG_NAME");
1067 let version = env!("CARGO_PKG_VERSION");
1068 let rustc = env!("RUSTC_VERSION");
1069 let os = std::env::consts::OS;
1070 let arch = std::env::consts::ARCH;
1071
1072 let expected_prefix = format!("{name}/{product}/{version} (Rust/");
1073 assert!(
1074 user_agent.starts_with(&expected_prefix),
1075 "prefix mismatch: {user_agent}"
1076 );
1077
1078 assert!(
1079 user_agent.contains(rustc),
1080 "user agent missing RUSTC_VERSION: {user_agent}"
1081 );
1082
1083 assert!(
1084 user_agent.contains(&format!("; {os}")),
1085 "user agent missing OS: {user_agent}"
1086 );
1087 assert!(
1088 user_agent.contains(&format!("; {arch}")),
1089 "user agent missing ARCH: {user_agent}"
1090 );
1091 }
1092
1093 #[test]
1094 fn build_user_agent_is_deterministic() {
1095 let product = "product";
1096 let user_agent1 = build_user_agent(product);
1097 let user_agent2 = build_user_agent(product);
1098 assert_eq!(
1099 user_agent1, user_agent2,
1100 "user agent should be the same on repeated calls"
1101 );
1102 }
1103 }
1104
1105 mod validate_time_unit {
1106 use crate::common::utils::validate_time_unit;
1107
1108 #[test]
1109 fn empty_string_returns_none() {
1110 let res = validate_time_unit("").expect("Should not error on empty string");
1111 assert_eq!(res, None);
1112 }
1113
1114 #[test]
1115 fn uppercase_millisecond() {
1116 let res = validate_time_unit("MILLISECOND").expect("Should accept MILLISECOND");
1117 assert_eq!(res, Some("MILLISECOND"));
1118 }
1119
1120 #[test]
1121 fn uppercase_microsecond() {
1122 let res = validate_time_unit("MICROSECOND").expect("Should accept MICROSECOND");
1123 assert_eq!(res, Some("MICROSECOND"));
1124 }
1125
1126 #[test]
1127 fn lowercase_millisecond() {
1128 let res = validate_time_unit("millisecond").expect("Should accept millisecond");
1129 assert_eq!(res, Some("millisecond"));
1130 }
1131
1132 #[test]
1133 fn lowercase_microsecond() {
1134 let res = validate_time_unit("microsecond").expect("Should accept microsecond");
1135 assert_eq!(res, Some("microsecond"));
1136 }
1137
1138 #[test]
1139 fn invalid_value_returns_err() {
1140 let err = validate_time_unit("SECOND").unwrap_err();
1141 let msg = format!("{err}");
1142 assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1143 }
1144
1145 #[test]
1146 fn partial_match_returns_err() {
1147 let err = validate_time_unit("MILLI").unwrap_err();
1148 let msg = format!("{err}");
1149 assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1150 }
1151 }
1152
1153 mod get_timestamp {
1154 use crate::common::utils::get_timestamp;
1155 use std::{
1156 thread::sleep,
1157 time::{Duration, SystemTime, UNIX_EPOCH},
1158 };
1159
1160 #[test]
1161 fn timestamp_is_within_system_time_bounds() {
1162 let before = SystemTime::now()
1163 .duration_since(UNIX_EPOCH)
1164 .expect("SystemTime before UNIX_EPOCH")
1165 .as_millis();
1166 let ts = get_timestamp();
1167 let after = SystemTime::now()
1168 .duration_since(UNIX_EPOCH)
1169 .expect("SystemTime before UNIX_EPOCH")
1170 .as_millis();
1171
1172 assert!(
1173 ts >= before,
1174 "timestamp {ts} is before captured before time {before}"
1175 );
1176 assert!(
1177 ts <= after,
1178 "timestamp {ts} is after captured after time {after}"
1179 );
1180 }
1181
1182 #[test]
1183 fn timestamps_are_monotonic() {
1184 let t1 = get_timestamp();
1185 sleep(Duration::from_millis(1));
1186 let t2 = get_timestamp();
1187 assert!(
1188 t2 >= t1,
1189 "second timestamp {t2} is not >= first timestamp {t1}"
1190 );
1191 }
1192 }
1193
1194 mod build_query_string {
1195 use std::collections::BTreeMap;
1196
1197 use anyhow::Result;
1198 use serde_json::{Value, json};
1199 use url::form_urlencoded::Serializer;
1200
1201 use crate::common::utils::build_query_string;
1202
1203 fn mk_map(pairs: Vec<(&str, Value)>) -> BTreeMap<String, Value> {
1204 let mut m = BTreeMap::new();
1205 for (k, v) in pairs {
1206 m.insert(k.to_string(), v);
1207 }
1208 m
1209 }
1210
1211 #[test]
1212 fn empty_map_returns_empty_string() -> Result<()> {
1213 let params = BTreeMap::new();
1214 let qs = build_query_string(¶ms)?;
1215 assert_eq!(qs, "");
1216 Ok(())
1217 }
1218
1219 #[test]
1220 fn string_and_number() -> Result<()> {
1221 let params = mk_map(vec![("foo", json!("bar")), ("num", json!(42))]);
1222 let qs = build_query_string(¶ms)?;
1223 assert_eq!(qs, "foo=bar&num=42");
1224 Ok(())
1225 }
1226
1227 #[test]
1228 fn bool_and_null_skipped() -> Result<()> {
1229 let params = mk_map(vec![("a", json!(true)), ("b", Value::Null)]);
1230 let qs = build_query_string(¶ms)?;
1231 assert_eq!(qs, "a=true");
1232 Ok(())
1233 }
1234
1235 #[test]
1236 fn flat_array() -> Result<()> {
1237 let params = mk_map(vec![("list", json!(vec!["x", "y", "z"]))]);
1238 let qs = build_query_string(¶ms)?;
1239 assert_eq!(qs, "list=x,y,z");
1240 Ok(())
1241 }
1242
1243 #[test]
1244 fn nested_array_json_encoded() -> Result<()> {
1245 let params = mk_map(vec![("nested", json!([[1, 2], [3, 4]]))]);
1246 let qs = build_query_string(¶ms)?;
1247
1248 let nested_json = serde_json::to_string(&json!([[1, 2], [3, 4]]))?;
1249 let mut ser = Serializer::new(String::new());
1250 ser.append_pair("nested", &nested_json);
1251 let expected = ser.finish();
1252
1253 assert_eq!(qs, expected);
1254 Ok(())
1255 }
1256
1257 #[test]
1258 fn object_not_supported() {
1259 let params = mk_map(vec![("obj", json!({"k":1}))]);
1260 let err = build_query_string(¶ms).unwrap_err();
1261 let msg = format!("{err}");
1262 assert!(msg.contains("Cannot serialize object for key `obj`"));
1263 }
1264 }
1265
1266 mod signature_generator {
1267 use base64::{Engine, engine::general_purpose};
1268 use ed25519_dalek::{SigningKey, ed25519::signature::SignerMut, pkcs8::DecodePrivateKey};
1269 use hex;
1270 use hmac::{Hmac, Mac};
1271 use openssl::{hash::MessageDigest, pkey::PKey, rsa::Rsa, sign::Verifier};
1272 use serde_json::Value;
1273 use sha2::Sha256;
1274 use std::collections::BTreeMap;
1275 use std::io::Write;
1276 use tempfile::NamedTempFile;
1277
1278 use crate::{common::utils::SignatureGenerator, config::PrivateKey};
1279
1280 #[test]
1281 fn hmac_sha256_signature() {
1282 let mut params = BTreeMap::new();
1283 params.insert("b".into(), Value::Number(2.into()));
1284 params.insert("a".into(), Value::Number(1.into()));
1285
1286 let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
1287 let sig = signature_gen
1288 .get_signature(¶ms)
1289 .expect("HMAC signing failed");
1290
1291 let mut mac = Hmac::<Sha256>::new_from_slice(b"test-secret").unwrap();
1292 let qs = "a=1&b=2";
1293 mac.update(qs.as_bytes());
1294 let expected = hex::encode(mac.finalize().into_bytes());
1295
1296 assert_eq!(sig, expected);
1297 }
1298
1299 #[test]
1300 fn repeated_hmac_signature() {
1301 let mut params = BTreeMap::new();
1302 params.insert("x".into(), Value::String("y".into()));
1303 let signature_gen = SignatureGenerator::new(Some("abc".into()), None, None);
1304 let s1 = signature_gen.get_signature(¶ms).unwrap();
1305 let s2 = signature_gen.get_signature(¶ms).unwrap();
1306 assert_eq!(s1, s2);
1307 }
1308
1309 #[test]
1310 fn rsa_signature_verification() {
1311 let mut params = BTreeMap::new();
1312 params.insert("a".into(), Value::Number(1.into()));
1313 params.insert("b".into(), Value::Number(2.into()));
1314
1315 let rsa = Rsa::generate(2048).unwrap();
1316 let priv_pem = rsa.private_key_to_pem().unwrap();
1317 let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1318
1319 let signature_gen =
1320 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1321 let sig = signature_gen
1322 .get_signature(¶ms)
1323 .expect("RSA signing failed");
1324
1325 let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1326 let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1327 let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1328 verifier.update(b"a=1&b=2").unwrap();
1329 assert!(verifier.verify(&sig_bytes).unwrap());
1330 }
1331
1332 #[test]
1333 fn repeated_rsa_signature() {
1334 let mut params = BTreeMap::new();
1335 params.insert("k".into(), Value::Number(5.into()));
1336 let rsa = Rsa::generate(2048).unwrap();
1337 let priv_pem = rsa.private_key_to_pem().unwrap();
1338 let signature_gen =
1339 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
1340 let s1 = signature_gen.get_signature(¶ms).unwrap();
1341 let s2 = signature_gen.get_signature(¶ms).unwrap();
1342 assert_eq!(s1, s2);
1343 }
1344
1345 #[test]
1346 fn ed25519_signature_verification() {
1347 let mut params = BTreeMap::new();
1348 params.insert("a".into(), Value::Number(1.into()));
1349 params.insert("b".into(), Value::Number(2.into()));
1350 let qs = "a=1&b=2";
1351
1352 let ed = PKey::generate_ed25519().unwrap();
1353 let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1354
1355 let signature_gen =
1356 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1357 let sig = signature_gen
1358 .get_signature(¶ms)
1359 .expect("Ed25519 signing failed");
1360
1361 let pem_str = String::from_utf8(priv_pem).unwrap();
1362 let b64 = pem_str
1363 .lines()
1364 .filter(|l| !l.starts_with("-----"))
1365 .collect::<String>();
1366 let der = general_purpose::STANDARD.decode(b64).unwrap();
1367 let mut sk = SigningKey::from_pkcs8_der(&der).unwrap();
1368 let expected_bytes = sk.sign(qs.as_bytes()).to_bytes();
1369 let expected_sig = general_purpose::STANDARD.encode(expected_bytes);
1370 assert_eq!(sig, expected_sig);
1371 }
1372
1373 #[test]
1374 fn repeated_ed25519_signature() {
1375 let mut params = BTreeMap::new();
1376 params.insert("m".into(), Value::String("n".into()));
1377 let ed = PKey::generate_ed25519().unwrap();
1378 let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1379 let signature_gen =
1380 SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1381 let s1 = signature_gen.get_signature(¶ms).unwrap();
1382 let s2 = signature_gen.get_signature(¶ms).unwrap();
1383 assert_eq!(s1, s2);
1384 }
1385
1386 #[test]
1387 fn file_based_key() {
1388 let rsa = Rsa::generate(1024).unwrap();
1389 let priv_pem = rsa.private_key_to_pem().unwrap();
1390 let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1391
1392 let mut file = NamedTempFile::new().unwrap();
1393 file.write_all(&priv_pem).unwrap();
1394 let path = file.path().to_str().unwrap().to_string();
1395
1396 let mut params = BTreeMap::new();
1397 params.insert("z".into(), Value::Number(9.into()));
1398
1399 let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::File(path)), None);
1400 let sig = signature_gen.get_signature(¶ms).unwrap();
1401
1402 let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1403 let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1404 let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1405 verifier.update(b"z=9").unwrap();
1406 assert!(verifier.verify(&sig_bytes).unwrap());
1407 }
1408
1409 #[test]
1410 fn unsupported_key_type_error() {
1411 let mut params = BTreeMap::new();
1412 params.insert("x".into(), Value::String("y".into()));
1413
1414 let group =
1415 openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
1416 let ec_key = openssl::ec::EcKey::generate(&group).unwrap();
1417 let pkey_ec = PKey::from_ec_key(ec_key).unwrap();
1418 let raw = pkey_ec.private_key_to_pem_pkcs8().unwrap();
1419
1420 let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(raw)), None);
1421 let err = signature_gen
1422 .get_signature(¶ms)
1423 .unwrap_err()
1424 .to_string();
1425 assert!(err.contains("Unsupported private key type"));
1426 }
1427
1428 #[test]
1429 fn invalid_private_key_error() {
1430 let mut params = BTreeMap::new();
1431 params.insert("foo".into(), Value::String("bar".into()));
1432
1433 let signature_gen =
1434 SignatureGenerator::new(None, Some(PrivateKey::Raw(b"not a key".to_vec())), None);
1435 let err = signature_gen
1436 .get_signature(¶ms)
1437 .unwrap_err()
1438 .to_string();
1439 assert!(err.contains("Failed to parse private key"));
1440 }
1441
1442 #[test]
1443 fn missing_credentials_error() {
1444 let mut params = BTreeMap::new();
1445 params.insert("a".into(), Value::Number(1.into()));
1446
1447 let signature_gen = SignatureGenerator::new(None, None, None);
1448 let err = signature_gen
1449 .get_signature(¶ms)
1450 .unwrap_err()
1451 .to_string();
1452 assert!(err.contains("Either 'api_secret' or 'private_key' must be provided"));
1453 }
1454 }
1455
1456 mod should_retry_request {
1457 use crate::common::utils::should_retry_request;
1458
1459 use reqwest::{Error, Response};
1460
1461 fn mk_http_error(code: u16) -> Error {
1462 let resp = Response::from(
1463 http::response::Response::builder()
1464 .status(code)
1465 .body("")
1466 .unwrap(),
1467 );
1468 resp.error_for_status().unwrap_err()
1469 }
1470
1471 fn mk_network_error() -> Error {
1472 reqwest::blocking::get("http://256.256.256.256").unwrap_err()
1473 }
1474
1475 #[test]
1476 fn retry_on_retriable_status_and_method() {
1477 let err = mk_http_error(500);
1478 assert!(should_retry_request(&err, Some("GET"), Some(1)));
1479 assert!(should_retry_request(&err, Some("delete"), Some(2)));
1480 }
1481
1482 #[test]
1483 fn retry_when_status_none_and_retriable_method() {
1484 let retriable_methods = ["GET", "DELETE"];
1485
1486 for &method in &retriable_methods {
1487 let err = mk_network_error();
1488 assert!(
1489 should_retry_request(&err, Some(method), Some(1)),
1490 "Should retry when no status and method {method}"
1491 );
1492 }
1493 }
1494
1495 #[test]
1496 fn no_retry_when_no_retries_left() {
1497 let err = mk_http_error(503);
1498 assert!(!should_retry_request(&err, Some("GET"), Some(0)));
1499 }
1500
1501 #[test]
1502 fn no_retry_on_non_retriable_status() {
1503 let non_retriable_statuses = [400, 401, 404, 422];
1504
1505 for &status in &non_retriable_statuses {
1506 let err = mk_http_error(status);
1507 assert!(
1508 !should_retry_request(&err, Some("GET"), Some(2)),
1509 "Should not retry for non-retriable status {status}"
1510 );
1511 }
1512 }
1513
1514 #[test]
1515 fn no_retry_on_non_retriable_method() {
1516 let non_retriable_methods = ["POST", "PUT", "PATCH"];
1517
1518 for &method in &non_retriable_methods {
1519 let err = mk_http_error(500);
1520 assert!(
1521 !should_retry_request(&err, Some(method), Some(2)),
1522 "Should not retry for non-retriable method {method}"
1523 );
1524 }
1525 }
1526
1527 #[test]
1528 fn no_retry_when_status_none_and_non_retriable_method() {
1529 let non_retriable_methods = ["POST", "PUT"];
1530
1531 for &method in &non_retriable_methods {
1532 let err = mk_network_error();
1533 assert!(
1534 !should_retry_request(&err, Some(method), Some(1)),
1535 "Should not retry when no status and method {method}"
1536 );
1537 }
1538 }
1539 }
1540
1541 mod parse_rate_limit_headers_tests {
1542 use crate::common::{
1543 models::{Interval, RateLimitType},
1544 utils::parse_rate_limit_headers,
1545 };
1546 use std::collections::HashMap;
1547
1548 fn mk_headers(pairs: Vec<(&str, &str)>) -> HashMap<String, String> {
1549 let mut m = HashMap::new();
1550 for (k, v) in pairs {
1551 m.insert(k.to_string(), v.to_string());
1552 }
1553 m
1554 }
1555
1556 #[test]
1557 fn single_weight_header() {
1558 let headers = mk_headers(vec![("x-mbx-used-weight-1s", "123")]);
1559 let limits = parse_rate_limit_headers(&headers);
1560 assert_eq!(limits.len(), 1);
1561 let rl = &limits[0];
1562 assert_eq!(rl.rate_limit_type, RateLimitType::RequestWeight);
1563 assert_eq!(rl.interval, Interval::Second);
1564 assert_eq!(rl.interval_num, 1);
1565 assert_eq!(rl.count, 123);
1566 assert_eq!(rl.retry_after, None);
1567 }
1568
1569 #[test]
1570 fn single_order_count_with_retry_after() {
1571 let headers = mk_headers(vec![("x-mbx-order-count-5m", "42"), ("retry-after", "7")]);
1572 let limits = parse_rate_limit_headers(&headers);
1573 assert_eq!(limits.len(), 1);
1574 let rl = &limits[0];
1575 assert_eq!(rl.rate_limit_type, RateLimitType::Orders);
1576 assert_eq!(rl.interval, Interval::Minute);
1577 assert_eq!(rl.interval_num, 5);
1578 assert_eq!(rl.count, 42);
1579 assert_eq!(rl.retry_after, Some(7));
1580 }
1581
1582 #[test]
1583 fn multiple_headers() {
1584 let headers = mk_headers(vec![
1585 ("X-MBX-USED-WEIGHT-1h", "10"),
1586 ("x-mbx-order-count-2d", "20"),
1587 ]);
1588 let mut limits = parse_rate_limit_headers(&headers);
1589 limits.sort_by_key(|r| (r.interval_num, format!("{:?}", r.rate_limit_type)));
1590 assert_eq!(limits.len(), 2);
1591 let w = &limits[0];
1592 assert_eq!(w.rate_limit_type, RateLimitType::RequestWeight);
1593 assert_eq!(w.interval, Interval::Hour);
1594 assert_eq!(w.interval_num, 1);
1595 assert_eq!(w.count, 10);
1596 let o = &limits[1];
1597 assert_eq!(o.rate_limit_type, RateLimitType::Orders);
1598 assert_eq!(o.interval, Interval::Day);
1599 assert_eq!(o.interval_num, 2);
1600 assert_eq!(o.count, 20);
1601 }
1602
1603 #[test]
1604 fn ignores_unknown_and_malformed() {
1605 let headers = mk_headers(vec![
1606 ("x-mbx-used-weight-3x", "5"),
1607 ("random-header", "100"),
1608 ]);
1609 let limits = parse_rate_limit_headers(&headers);
1610 assert!(limits.is_empty());
1611 }
1612 }
1613
1614 mod http_request {
1615 use std::io::Write;
1616
1617 use flate2::{Compression, write::GzEncoder};
1618 use httpmock::MockServer;
1619 use reqwest::{Client, Method, Request};
1620 use serde::Deserialize;
1621
1622 use crate::{
1623 common::utils::http_request, config::ConfigurationRestApi, errors::ConnectorError,
1624 models::RestApiResponse,
1625 };
1626
1627 use super::TOKIO_SHARED_RT;
1628
1629 #[derive(Deserialize, Debug, PartialEq)]
1630 struct Dummy {
1631 foo: String,
1632 }
1633
1634 fn make_config(server_url: &str) -> ConfigurationRestApi {
1635 ConfigurationRestApi::builder()
1636 .api_key("key")
1637 .api_secret("secret")
1638 .base_path(server_url)
1639 .build()
1640 .expect("Failed to build configuration")
1641 }
1642
1643 #[test]
1644 fn http_request_success_plain_text() {
1645 TOKIO_SHARED_RT.block_on(async {
1646 let server = MockServer::start();
1647 let mock = server.mock(|when, then| {
1648 when.method(httpmock::Method::GET).path("/test");
1649 then.status(200)
1650 .header("Content-Type", "application/json")
1651 .body(r#"{"foo":"bar"}"#);
1652 });
1653
1654 let client = Client::new();
1655 let req: Request = client
1656 .request(Method::GET, format!("{}{}", server.url(""), "/test"))
1657 .build()
1658 .unwrap();
1659
1660 let cfg = make_config(&server.url(""));
1661 let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
1662 assert_eq!(resp.status, 200);
1663 let data = resp.data().await.unwrap();
1664 assert_eq!(data, Dummy { foo: "bar".into() });
1665 mock.assert();
1666 });
1667 }
1668
1669 #[test]
1670 fn http_request_success_gzip() {
1671 TOKIO_SHARED_RT.block_on(async {
1672 let server = MockServer::start();
1673 let body = r#"{"foo":"baz"}"#;
1674 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
1675 encoder.write_all(body.as_bytes()).unwrap();
1676 let gz = encoder.finish().unwrap();
1677
1678 let mock = server.mock(|when, then| {
1679 when.method(httpmock::Method::GET).path("/gz");
1680 then.status(200)
1681 .header("Content-Type", "application/json")
1682 .header("Content-Encoding", "gzip")
1683 .body(gz);
1684 });
1685
1686 let client = Client::new();
1687 let req: Request = client
1688 .request(Method::GET, format!("{}{}", server.url(""), "/gz"))
1689 .build()
1690 .unwrap();
1691 let mut cfg = make_config(&server.url(""));
1692 cfg.compression = true;
1693
1694 let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
1695 assert_eq!(resp.status, 200);
1696 let data = resp.data().await.unwrap();
1697 assert_eq!(data, Dummy { foo: "baz".into() });
1698 mock.assert();
1699 });
1700 }
1701
1702 #[test]
1703 fn http_request_client_error_bad_request() {
1704 TOKIO_SHARED_RT.block_on(async {
1705 let server = MockServer::start();
1706 let mock = server.mock(|when, then| {
1707 when.method(httpmock::Method::GET).path("/400");
1708 then.status(400)
1709 .header("Content-Type", "application/json")
1710 .body(r#"{"msg":"bad request"}"#);
1711 });
1712
1713 let client = Client::new();
1714 let req: Request = client
1715 .request(Method::GET, format!("{}{}", server.url(""), "/400"))
1716 .build()
1717 .unwrap();
1718 let cfg = make_config(&server.url(""));
1719
1720 let result = http_request::<Dummy>(req, &cfg).await;
1721 assert!(matches!(result, Err(ConnectorError::BadRequestError(_))));
1722 if let Err(ConnectorError::BadRequestError(msg)) = result {
1723 assert_eq!(msg, "bad request");
1724 }
1725 mock.assert();
1726 });
1727 }
1728
1729 #[test]
1730 fn http_request_client_error_unauthorized() {
1731 TOKIO_SHARED_RT.block_on(async {
1732 let server = MockServer::start();
1733 let mock = server.mock(|when, then| {
1734 when.method(httpmock::Method::GET).path("/401");
1735 then.status(401)
1736 .header("Content-Type", "application/json")
1737 .body(r#"{"msg":"unauthorized"}"#);
1738 });
1739
1740 let client = Client::new();
1741 let req: Request = client
1742 .request(Method::GET, format!("{}{}", server.url(""), "/401"))
1743 .build()
1744 .unwrap();
1745 let cfg = make_config(&server.url(""));
1746
1747 let result = http_request::<Dummy>(req, &cfg).await;
1748 assert!(matches!(result, Err(ConnectorError::UnauthorizedError(_))));
1749 if let Err(ConnectorError::UnauthorizedError(msg)) = result {
1750 assert_eq!(msg, "unauthorized");
1751 }
1752 mock.assert();
1753 });
1754 }
1755
1756 #[test]
1757 fn http_request_client_error_forbidden() {
1758 TOKIO_SHARED_RT.block_on(async {
1759 let server = MockServer::start();
1760 let mock = server.mock(|when, then| {
1761 when.method(httpmock::Method::GET).path("/403");
1762 then.status(403)
1763 .header("Content-Type", "application/json")
1764 .body(r#"{"msg":"forbidden"}"#);
1765 });
1766
1767 let client = Client::new();
1768 let req: Request = client
1769 .request(Method::GET, format!("{}{}", server.url(""), "/403"))
1770 .build()
1771 .unwrap();
1772 let cfg = make_config(&server.url(""));
1773
1774 let result = http_request::<Dummy>(req, &cfg).await;
1775 assert!(matches!(result, Err(ConnectorError::ForbiddenError(_))));
1776 if let Err(ConnectorError::ForbiddenError(msg)) = result {
1777 assert_eq!(msg, "forbidden");
1778 }
1779 mock.assert();
1780 });
1781 }
1782
1783 #[test]
1784 fn http_request_client_error_not_found() {
1785 TOKIO_SHARED_RT.block_on(async {
1786 let server = MockServer::start();
1787 let mock = server.mock(|when, then| {
1788 when.method(httpmock::Method::GET).path("/404");
1789 then.status(404)
1790 .header("Content-Type", "application/json")
1791 .body(r#"{"msg":"not found"}"#);
1792 });
1793
1794 let client = Client::new();
1795 let req: Request = client
1796 .request(Method::GET, format!("{}{}", server.url(""), "/404"))
1797 .build()
1798 .unwrap();
1799 let cfg = make_config(&server.url(""));
1800
1801 let result = http_request::<Dummy>(req, &cfg).await;
1802 assert!(matches!(result, Err(ConnectorError::NotFoundError(_))));
1803 if let Err(ConnectorError::NotFoundError(msg)) = result {
1804 assert_eq!(msg, "not found");
1805 }
1806 mock.assert();
1807 });
1808 }
1809
1810 #[test]
1811 fn http_request_client_error_rate_limit_exceeded() {
1812 TOKIO_SHARED_RT.block_on(async {
1813 let server = MockServer::start();
1814 let mock = server.mock(|when, then| {
1815 when.method(httpmock::Method::GET).path("/418");
1816 then.status(418)
1817 .header("Content-Type", "application/json")
1818 .body(r#"{"msg":"rate limit exceeded"}"#);
1819 });
1820
1821 let client = Client::new();
1822 let req: Request = client
1823 .request(Method::GET, format!("{}{}", server.url(""), "/418"))
1824 .build()
1825 .unwrap();
1826 let cfg = make_config(&server.url(""));
1827
1828 let result = http_request::<Dummy>(req, &cfg).await;
1829 assert!(matches!(result, Err(ConnectorError::RateLimitBanError(_))));
1830 if let Err(ConnectorError::RateLimitBanError(msg)) = result {
1831 assert_eq!(msg, "rate limit exceeded");
1832 }
1833 mock.assert();
1834 });
1835 }
1836
1837 #[test]
1838 fn http_request_client_error_too_many_requests() {
1839 TOKIO_SHARED_RT.block_on(async {
1840 let server = MockServer::start();
1841 let mock = server.mock(|when, then| {
1842 when.method(httpmock::Method::GET).path("/429");
1843 then.status(429)
1844 .header("Content-Type", "application/json")
1845 .body(r#"{"msg":"too many requests"}"#);
1846 });
1847
1848 let client = Client::new();
1849 let req: Request = client
1850 .request(Method::GET, format!("{}{}", server.url(""), "/429"))
1851 .build()
1852 .unwrap();
1853 let cfg = make_config(&server.url(""));
1854
1855 let result = http_request::<Dummy>(req, &cfg).await;
1856 assert!(matches!(
1857 result,
1858 Err(ConnectorError::TooManyRequestsError(_))
1859 ));
1860 if let Err(ConnectorError::TooManyRequestsError(msg)) = result {
1861 assert_eq!(msg, "too many requests");
1862 }
1863 mock.assert();
1864 });
1865 }
1866
1867 #[test]
1868 fn http_request_client_error_server_error() {
1869 TOKIO_SHARED_RT.block_on(async {
1870 let server = MockServer::start();
1871 let mock = server.mock(|when, then| {
1872 when.method(httpmock::Method::GET).path("/500");
1873 then.status(500)
1874 .header("Content-Type", "application/json")
1875 .body(r#"{"msg":"internal server error"}"#);
1876 });
1877
1878 let client = Client::new();
1879 let req: Request = client
1880 .request(Method::GET, format!("{}{}", server.url(""), "/500"))
1881 .build()
1882 .unwrap();
1883 let cfg = make_config(&server.url(""));
1884
1885 let result = http_request::<Dummy>(req, &cfg).await;
1886 assert!(matches!(result, Err(ConnectorError::ServerError { .. })));
1887 if let Err(ConnectorError::ServerError {
1888 msg,
1889 status_code: Some(500),
1890 }) = result
1891 {
1892 assert_eq!(msg, "Server error: 500".to_string());
1893 }
1894 mock.assert();
1895 });
1896 }
1897
1898 #[test]
1899 fn http_request_unexpected_status_maps_generic() {
1900 TOKIO_SHARED_RT.block_on(async {
1901 let server = MockServer::start();
1902 let code = 402;
1903 let mock = server.mock(|when, then| {
1904 when.method(httpmock::Method::GET).path("/402");
1905 then.status(code).body("error text");
1906 });
1907
1908 let client = Client::new();
1909 let req: Request = client
1910 .request(Method::GET, format!("{}{}", server.url(""), "/402"))
1911 .build()
1912 .unwrap();
1913 let cfg = make_config(&server.url(""));
1914
1915 let result = http_request::<Dummy>(req, &cfg).await;
1916 assert!(matches!(
1917 result,
1918 Err(ConnectorError::ConnectorClientError(_))
1919 ));
1920 mock.assert();
1921 });
1922 }
1923
1924 #[test]
1925 fn http_request_malformed_json_maps_generic() {
1926 TOKIO_SHARED_RT.block_on(async {
1927 let server = MockServer::start();
1928 let mock = server.mock(|when, then| {
1929 when.method(httpmock::Method::GET).path("/malformed");
1930 then.status(200)
1931 .header("Content-Type", "application/json")
1932 .body("not json");
1933 });
1934
1935 let client = Client::new();
1936 let req: Request = client
1937 .request(Method::GET, format!("{}{}", server.url(""), "/malformed"))
1938 .build()
1939 .unwrap();
1940 let cfg = make_config(&server.url(""));
1941
1942 let resp = http_request::<Dummy>(req, &cfg)
1944 .await
1945 .expect("http_request should succeed even if JSON is bad");
1946
1947 let err = resp
1949 .data() .await
1951 .expect_err("malformed JSON should turn into ConnectorClientError");
1952
1953 assert!(matches!(err, ConnectorError::ConnectorClientError(_)));
1954
1955 mock.assert();
1956 });
1957 }
1958 }
1959
1960 mod send_request {
1961 use anyhow::Result;
1962 use httpmock::prelude::*;
1963 use reqwest::Method;
1964 use serde::Deserialize;
1965 use serde_json::json;
1966 use std::collections::BTreeMap;
1967
1968 use crate::{
1969 common::{models::TimeUnit, utils::send_request},
1970 config::ConfigurationRestApi,
1971 };
1972
1973 use super::TOKIO_SHARED_RT;
1974
1975 #[derive(Deserialize, Debug, PartialEq)]
1976 struct TestResponse {
1977 message: String,
1978 }
1979
1980 #[test]
1981 fn basic_get_request() -> Result<()> {
1982 TOKIO_SHARED_RT.block_on(async {
1983 let server = MockServer::start();
1984
1985 server.mock(|when, then| {
1986 when.method(GET).path("/api/v1/test");
1987 then.status(200)
1988 .header("content-type", "application/json")
1989 .body(r#"{"message": "success"}"#);
1990 });
1991
1992 let configuration = ConfigurationRestApi::builder()
1993 .api_key("key")
1994 .api_secret("secret")
1995 .base_path(server.base_url())
1996 .compression(false)
1997 .build()
1998 .expect("Failed to build configuration");
1999
2000 let params = BTreeMap::new();
2001
2002 let result = send_request::<TestResponse>(
2003 &configuration,
2004 "/api/v1/test",
2005 Method::GET,
2006 params,
2007 None,
2008 false,
2009 )
2010 .await?;
2011
2012 let data = result.data().await.unwrap();
2013 assert_eq!(data.message, "success");
2014
2015 Ok(())
2016 })
2017 }
2018
2019 #[test]
2020 fn signed_post_request() -> Result<()> {
2021 TOKIO_SHARED_RT.block_on(async {
2022 let server = MockServer::start();
2023
2024 server.mock(|when, then| {
2025 when.method(POST).path("/api/v3/order");
2026 then.status(200)
2027 .header("content-type", "application/json")
2028 .body(r#"{"message": "order placed"}"#);
2029 });
2030
2031 let configuration = ConfigurationRestApi::builder()
2032 .api_key("key")
2033 .api_secret("secret")
2034 .base_path(server.base_url())
2035 .compression(false)
2036 .build()
2037 .expect("Failed to build configuration");
2038
2039 let mut params = BTreeMap::new();
2040 params.insert("symbol".to_string(), json!("ETHUSDT"));
2041 params.insert("side".to_string(), json!("BUY"));
2042 params.insert("type".to_string(), json!("MARKET"));
2043 params.insert("quantity".to_string(), json!("1"));
2044
2045 let result = send_request::<TestResponse>(
2046 &configuration,
2047 "/api/v3/order",
2048 Method::POST,
2049 params,
2050 None,
2051 true,
2052 )
2053 .await?;
2054
2055 let data = result.data().await.unwrap();
2056 assert_eq!(data.message, "order placed");
2057
2058 Ok(())
2059 })
2060 }
2061
2062 #[test]
2063 fn get_request_with_params() -> Result<()> {
2064 TOKIO_SHARED_RT.block_on(async {
2065 let server = MockServer::start();
2066
2067 server.mock(|when, then| {
2068 when.method(GET)
2069 .path("/api/v1/data")
2070 .query_param("symbol", "BTCUSDT")
2071 .query_param("limit", "10");
2072 then.status(200)
2073 .header("content-type", "application/json")
2074 .body(r#"{"message": "data retrieved"}"#);
2075 });
2076
2077 let configuration = ConfigurationRestApi::builder()
2078 .api_key("key")
2079 .api_secret("secret")
2080 .base_path(server.base_url())
2081 .compression(false)
2082 .build()
2083 .expect("Failed to build configuration");
2084
2085 let mut params = BTreeMap::new();
2086 params.insert("symbol".to_string(), json!("BTCUSDT"));
2087 params.insert("limit".to_string(), json!(10));
2088
2089 let result = send_request::<TestResponse>(
2090 &configuration,
2091 "/api/v1/data",
2092 Method::GET,
2093 params,
2094 None,
2095 false,
2096 )
2097 .await?;
2098
2099 let data = result.data().await.unwrap();
2100 assert_eq!(data.message, "data retrieved");
2101
2102 Ok(())
2103 })
2104 }
2105
2106 #[test]
2107 fn invalid_endpoint() {
2108 TOKIO_SHARED_RT.block_on(async {
2109 let server = MockServer::start();
2110
2111 let configuration = ConfigurationRestApi::builder()
2112 .api_key("key")
2113 .api_secret("secret")
2114 .base_path(server.base_url())
2115 .compression(false)
2116 .build()
2117 .expect("Failed to build configuration");
2118
2119 let params = BTreeMap::new();
2120
2121 let result = send_request::<TestResponse>(
2122 &configuration,
2123 "http://invalid",
2124 Method::GET,
2125 params,
2126 None,
2127 false,
2128 )
2129 .await;
2130
2131 assert!(result.is_err());
2132 });
2133 }
2134
2135 #[test]
2136 fn missing_signature_on_signed_request() {
2137 TOKIO_SHARED_RT.block_on(async {
2138 let server = MockServer::start();
2139
2140 let configuration = ConfigurationRestApi::builder()
2141 .api_key("key")
2142 .api_secret("secret")
2143 .base_path(server.base_url())
2144 .compression(false)
2145 .build()
2146 .expect("Failed to build configuration");
2147
2148 let mut params = BTreeMap::new();
2149 params.insert("symbol".to_string(), json!("BTCUSDT"));
2150 params.insert("side".to_string(), json!("BUY"));
2151
2152 let result = send_request::<TestResponse>(
2153 &configuration,
2154 "/api/v3/order",
2155 Method::POST,
2156 params,
2157 None,
2158 true,
2159 )
2160 .await;
2161
2162 assert!(result.is_err());
2163 });
2164 }
2165
2166 #[test]
2167 fn compression_enabled() -> Result<()> {
2168 TOKIO_SHARED_RT.block_on(async {
2169 let server = MockServer::start();
2170
2171 server.mock(|when, then| {
2172 when.method(GET).path("/api/v1/test");
2173 then.status(200)
2174 .header("content-type", "application/json")
2175 .header("accept-encoding", "gzip, deflate, br")
2176 .body(r#"{"message": "compression enabled"}"#);
2177 });
2178
2179 let configuration = ConfigurationRestApi::builder()
2180 .api_key("key")
2181 .api_secret("secret")
2182 .base_path(server.base_url())
2183 .compression(true)
2184 .build()
2185 .expect("Failed to build configuration");
2186
2187 let params = BTreeMap::new();
2188
2189 let result = send_request::<TestResponse>(
2190 &configuration,
2191 "/api/v1/test",
2192 Method::GET,
2193 params,
2194 None,
2195 false,
2196 )
2197 .await?;
2198
2199 let data = result.data().await.unwrap();
2200 assert_eq!(data.message, "compression enabled");
2201
2202 Ok(())
2203 })
2204 }
2205
2206 #[test]
2207 fn get_request_with_time_unit_header() -> Result<()> {
2208 TOKIO_SHARED_RT.block_on(async {
2209 let server = MockServer::start();
2210
2211 server.mock(|when, then| {
2212 when.method(GET)
2213 .path("/api/v1/test")
2214 .header("X-MBX-TIME-UNIT", "MILLISECOND");
2215 then.status(200)
2216 .header("content-type", "application/json")
2217 .body(r#"{"message": "time unit applied"}"#);
2218 });
2219
2220 let configuration = ConfigurationRestApi::builder()
2221 .api_key("key")
2222 .api_secret("secret")
2223 .base_path(server.base_url())
2224 .compression(false)
2225 .time_unit(TimeUnit::Millisecond)
2226 .build()
2227 .expect("Failed to build configuration");
2228
2229 let params = BTreeMap::new();
2230
2231 let result = send_request::<TestResponse>(
2232 &configuration,
2233 "/api/v1/test",
2234 Method::GET,
2235 params,
2236 Some(TimeUnit::Millisecond),
2237 false,
2238 )
2239 .await?;
2240
2241 let data = result.data().await.unwrap();
2242 assert_eq!(data.message, "time unit applied");
2243
2244 Ok(())
2245 })
2246 }
2247 }
2248
2249 mod random_string {
2250 use crate::common::utils::random_string;
2251 use hex;
2252
2253 #[test]
2254 fn length_is_32() {
2255 let s = random_string();
2256 assert_eq!(
2257 s.len(),
2258 32,
2259 "random_string() should be 32 chars, got {}",
2260 s.len()
2261 );
2262 }
2263
2264 #[test]
2265 fn is_valid_lowercase_hex() {
2266 let s = random_string();
2267 assert!(
2268 s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
2269 "random_string() contains invalid hex characters: {s}"
2270 );
2271 }
2272
2273 #[test]
2274 fn decodes_to_16_bytes() {
2275 let s = random_string();
2276 let bytes = hex::decode(&s).expect("random_string() output must be valid hex");
2277 assert_eq!(
2278 bytes.len(),
2279 16,
2280 "hex::decode returned {} bytes",
2281 bytes.len()
2282 );
2283 }
2284
2285 #[test]
2286 fn two_calls_are_different() {
2287 let a = random_string();
2288 let b = random_string();
2289 assert_ne!(
2290 a, b,
2291 "Two calls to random_string() returned the same value: {a}"
2292 );
2293 }
2294 }
2295
2296 mod remove_empty_value {
2297 use crate::common::utils::remove_empty_value;
2298 use serde_json::{Map, Value};
2299
2300 #[test]
2301 fn filters_out_null_and_empty_strings() {
2302 let entries = vec![
2303 ("key1".to_string(), Value::String("value1".to_string())),
2304 ("key2".to_string(), Value::Null),
2305 ("key3".to_string(), Value::String(String::new())),
2306 ];
2307 let result = remove_empty_value(entries);
2308 assert_eq!(
2309 result.len(),
2310 1,
2311 "expected only one entry, got {}",
2312 result.len()
2313 );
2314 assert_eq!(
2315 result.get("key1"),
2316 Some(&Value::String("value1".to_string()))
2317 );
2318 assert!(!result.contains_key("key2"));
2319 assert!(!result.contains_key("key3"));
2320 }
2321
2322 #[test]
2323 fn retains_other_value_types() {
2324 let entries = vec![
2325 ("bool".to_string(), Value::Bool(true)),
2326 ("num".to_string(), Value::Number(42.into())),
2327 ("arr".to_string(), Value::Array(vec![])),
2328 ("obj".to_string(), Value::Object(Map::default())),
2329 ("nil".to_string(), Value::Null),
2330 ("empty_str".to_string(), Value::String(String::new())),
2331 ];
2332 let result = remove_empty_value(entries);
2333 let keys: Vec<&String> = result.keys().collect();
2334 assert_eq!(keys.len(), 4, "expected 4 entries, got {}", keys.len());
2335 assert!(result.get("bool") == Some(&Value::Bool(true)));
2336 assert!(result.get("num") == Some(&Value::Number(42.into())));
2337 assert!(result.get("arr") == Some(&Value::Array(vec![])));
2338 assert!(result.get("obj") == Some(&Value::Object(Map::default())));
2339 assert!(!result.contains_key("nil"));
2340 assert!(!result.contains_key("empty_str"));
2341 }
2342
2343 #[test]
2344 fn empty_iterator_returns_empty_map() {
2345 let entries: Vec<(String, Value)> = vec![];
2346 let result = remove_empty_value(entries);
2347 assert!(result.is_empty(), "expected an empty map");
2348 }
2349
2350 #[test]
2351 fn keys_are_sorted() {
2352 let entries = vec![
2353 ("c".to_string(), Value::String("foo".to_string())),
2354 ("a".to_string(), Value::String("bar".to_string())),
2355 ("b".to_string(), Value::String("baz".to_string())),
2356 ];
2357 let result = remove_empty_value(entries);
2358 let sorted_keys: Vec<&String> = result.keys().collect();
2359 assert_eq!(
2360 sorted_keys,
2361 [&"a".to_string(), &"b".to_string(), &"c".to_string()]
2362 );
2363 }
2364 }
2365
2366 mod sort_object_params {
2367 use crate::common::utils::sort_object_params;
2368 use serde_json::Value;
2369 use std::collections::BTreeMap;
2370
2371 #[test]
2372 fn sorts_keys() {
2373 let mut params = BTreeMap::new();
2374 params.insert("z".to_string(), Value::String("last".to_string()));
2375 params.insert("a".to_string(), Value::String("first".to_string()));
2376 params.insert("m".to_string(), Value::String("middle".to_string()));
2377
2378 let sorted = sort_object_params(¶ms);
2379 let keys: Vec<&String> = sorted.keys().collect();
2380 assert_eq!(
2381 keys,
2382 [&"a".to_string(), &"m".to_string(), &"z".to_string()],
2383 "Keys should be sorted alphabetically"
2384 );
2385 }
2386
2387 #[test]
2388 fn preserves_values() {
2389 let mut params = BTreeMap::new();
2390 params.insert("one".to_string(), Value::Number(1.into()));
2391 params.insert("two".to_string(), Value::Bool(true));
2392
2393 let sorted = sort_object_params(¶ms);
2394 assert_eq!(sorted.get("one"), Some(&Value::Number(1.into())));
2395 assert_eq!(sorted.get("two"), Some(&Value::Bool(true)));
2396 }
2397
2398 #[test]
2399 fn empty_map_returns_empty() {
2400 let params: BTreeMap<String, Value> = BTreeMap::new();
2401 let sorted = sort_object_params(¶ms);
2402 assert!(sorted.is_empty(), "Expected empty map");
2403 }
2404
2405 #[test]
2406 fn independent_clone() {
2407 let mut params = BTreeMap::new();
2408 params.insert("key".to_string(), Value::String("val".to_string()));
2409
2410 let mut sorted = sort_object_params(¶ms);
2411 sorted.insert("new".to_string(), Value::String("x".to_string()));
2412
2413 assert!(
2414 !params.contains_key("new"),
2415 "Original should not be modified when changing sorted"
2416 );
2417 assert!(
2418 sorted.contains_key("new"),
2419 "Sorted map should reflect its own insertions"
2420 );
2421 }
2422 }
2423
2424 mod normalize_ws_streams_key {
2425 use crate::common::utils::normalize_ws_streams_key;
2426
2427 #[test]
2428 fn returns_empty_for_empty() {
2429 assert_eq!(normalize_ws_streams_key(""), "");
2430 }
2431
2432 #[test]
2433 fn already_normalized_stays_same() {
2434 assert_eq!(normalize_ws_streams_key("streamname"), "streamname");
2435 }
2436
2437 #[test]
2438 fn uppercases_are_lowercased() {
2439 assert_eq!(normalize_ws_streams_key("MyStream"), "mystream");
2440 }
2441
2442 #[test]
2443 fn underscores_are_removed() {
2444 assert_eq!(normalize_ws_streams_key("my_stream_name"), "mystreamname");
2445 }
2446
2447 #[test]
2448 fn hyphens_are_removed() {
2449 assert_eq!(normalize_ws_streams_key("my-stream-name"), "mystreamname");
2450 }
2451
2452 #[test]
2453 fn mixed_underscores_and_hyphens_and_case() {
2454 let input = "Mixed_Case-Stream_Name";
2455 let expected = "mixedcasestreamname";
2456 assert_eq!(normalize_ws_streams_key(input), expected);
2457 }
2458
2459 #[test]
2460 fn retains_other_punctuation() {
2461 assert_eq!(normalize_ws_streams_key("stream.name!"), "stream.name!");
2462 }
2463 }
2464
2465 mod replace_websocket_streams_placeholders {
2466 use crate::common::utils::replace_websocket_streams_placeholders;
2467 use std::collections::HashMap;
2468
2469 #[test]
2470 fn empty_string_unchanged() {
2471 let vars: HashMap<&str, &str> = HashMap::new();
2472 assert_eq!(replace_websocket_streams_placeholders("", &vars), "");
2473 }
2474
2475 #[test]
2476 fn unknown_placeholder_becomes_empty() {
2477 let vars: HashMap<&str, &str> = HashMap::new();
2478 assert_eq!(replace_websocket_streams_placeholders("<foo>", &vars), "");
2479 }
2480
2481 #[test]
2482 fn leading_slash_symbol_lowercases_head() {
2483 let mut vars = HashMap::new();
2484 vars.insert("symbol", "BTC");
2485 assert_eq!(
2486 replace_websocket_streams_placeholders("/<symbol>", &vars),
2487 "btc"
2488 );
2489 }
2490
2491 #[test]
2492 fn no_lowercase_without_slash() {
2493 let mut vars = HashMap::new();
2494 vars.insert("symbol", "BTC");
2495 assert_eq!(
2496 replace_websocket_streams_placeholders("<symbol>", &vars),
2497 "BTC"
2498 );
2499 }
2500
2501 #[test]
2502 fn multiple_placeholders_mid_preserve_ats() {
2503 let mut vars = HashMap::new();
2504 vars.insert("symbol", "BNBUSDT");
2505 vars.insert("levels", "10");
2506 vars.insert("updateSpeed", "1000ms");
2507 let out = replace_websocket_streams_placeholders(
2508 "/<symbol>@depth<levels>@<updateSpeed>",
2509 &vars,
2510 );
2511 assert_eq!(out, "bnbusdt@depth10@1000ms");
2512 }
2513
2514 #[test]
2515 fn trailing_at_removed_when_missing_var() {
2516 let mut vars = HashMap::new();
2517 vars.insert("symbol", "BNBUSDT");
2518 vars.insert("levels", "10");
2519 let out = replace_websocket_streams_placeholders(
2520 "/<symbol>@depth<levels>@<updateSpeed>",
2521 &vars,
2522 );
2523 assert_eq!(out, "bnbusdt@depth10");
2524 }
2525
2526 #[test]
2527 fn custom_key_normalization_and_value() {
2528 let mut vars = HashMap::new();
2529 vars.insert("my-stream_key", "Value");
2530 assert_eq!(
2531 replace_websocket_streams_placeholders("<My_Stream-Key>", &vars),
2532 "Value"
2533 );
2534 }
2535
2536 #[test]
2537 fn text_surrounding_placeholders_intact() {
2538 let mut vars = HashMap::new();
2539 vars.insert("symbol", "ABC");
2540 let input = "pre-<symbol>-post";
2541 assert_eq!(
2542 replace_websocket_streams_placeholders(input, &vars),
2543 "pre-ABC-post"
2544 );
2545 }
2546 }
2547}