binance_sdk/common/
utils.rs

1use anyhow::{Context, Result, bail};
2use base64::{Engine as _, engine::general_purpose};
3use ed25519_dalek::Signer as Ed25519Signer;
4use ed25519_dalek::SigningKey;
5use ed25519_dalek::pkcs8::DecodePrivateKey;
6use flate2::read::GzDecoder;
7use hex;
8use hmac::{Hmac, Mac};
9use http::HeaderMap;
10use http::header::ACCEPT_ENCODING;
11use once_cell::sync::OnceCell;
12use openssl::{hash::MessageDigest, pkey::PKey, sign::Signer as OpenSslSigner};
13use rand::RngCore;
14use regex::Captures;
15use regex::Regex;
16use reqwest::Client;
17use reqwest::Proxy;
18use reqwest::{Method, Request};
19use serde::de::DeserializeOwned;
20use serde_json::{Value, json};
21use sha2::Sha256;
22use std::fmt::Display;
23use std::hash::BuildHasher;
24use std::sync::LazyLock;
25use std::{
26    collections::BTreeMap,
27    collections::HashMap,
28    fs,
29    io::Read,
30    path::Path,
31    time::Duration,
32    time::{SystemTime, UNIX_EPOCH},
33};
34use tokio::time::sleep;
35use tracing::info;
36use url::{Url, form_urlencoded::Serializer};
37
38use super::config::HttpAgent;
39use super::config::ProxyConfig;
40use super::config::{ConfigurationRestApi, PrivateKey};
41use super::errors::ConnectorError;
42use super::models::TimeUnit;
43use super::models::{Interval, RateLimitType, RestApiRateLimit, RestApiResponse};
44
45static PLACEHOLDER_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(@)?<([^>]+)>").unwrap());
46
47/// A generator for creating cryptographic signatures with support for various key types and configurations.
48///
49/// This struct manages different authentication mechanisms including API secrets, private keys,
50/// and supports multiple key formats (file-based or raw bytes). It uses lazy initialization
51/// for key loading and supports different cryptographic key types like OpenSSL private keys
52/// and Ed25519 signing keys.
53///
54/// # Fields
55/// * `api_secret`: Optional API secret for signature generation
56/// * `private_key`: Optional private key source (file or raw bytes)
57/// * `private_key_passphrase`: Optional passphrase for decrypting private keys
58/// * `raw_key_data`: Lazily initialized raw key data as a string
59/// * `key_object`: Lazily initialized OpenSSL private key
60/// * `ed25519_signing_key`: Lazily initialized Ed25519 signing key
61#[derive(Debug, Default, Clone)]
62pub struct SignatureGenerator {
63    api_secret: Option<String>,
64    private_key: Option<PrivateKey>,
65    private_key_passphrase: Option<String>,
66    raw_key_data: OnceCell<String>,
67    key_object: OnceCell<PKey<openssl::pkey::Private>>,
68    ed25519_signing_key: OnceCell<SigningKey>,
69}
70
71impl SignatureGenerator {
72    #[must_use]
73    pub fn new(
74        api_secret: Option<String>,
75        private_key: Option<PrivateKey>,
76        private_key_passphrase: Option<String>,
77    ) -> Self {
78        SignatureGenerator {
79            api_secret,
80            private_key,
81            private_key_passphrase,
82            raw_key_data: OnceCell::new(),
83            key_object: OnceCell::new(),
84            ed25519_signing_key: OnceCell::new(),
85        }
86    }
87
88    /// Retrieves the raw key data from a private key source.
89    ///
90    /// This method lazily initializes the raw key data by reading it from either a file path
91    /// or a raw byte array. If the key is from a file, it checks for file existence before reading.
92    /// If the key is provided as raw bytes, it converts them to a UTF-8 string.
93    ///
94    /// # Returns
95    /// A reference to the raw key data as a `String`.
96    ///
97    /// # Errors
98    /// Returns an error if:
99    /// - No private key is provided
100    /// - The private key file does not exist
101    /// - The private key file cannot be read
102    fn get_raw_key_data(&self) -> Result<&String> {
103        self.raw_key_data.get_or_try_init(|| {
104            let pk = self
105                .private_key
106                .as_ref()
107                .ok_or_else(|| anyhow::anyhow!("No private_key provided"))?;
108            match pk {
109                PrivateKey::File(path) => {
110                    if Path::new(path).exists() {
111                        fs::read_to_string(path)
112                            .with_context(|| format!("Failed to read private key file: {path}"))
113                    } else {
114                        Err(anyhow::anyhow!("Private key file does not exist: {}", path))
115                    }
116                }
117                PrivateKey::Raw(bytes) => Ok(String::from_utf8_lossy(bytes).to_string()),
118            }
119        })
120    }
121
122    /// Retrieves the private key object, lazily initializing it from raw key data.
123    ///
124    /// This method attempts to parse the private key from PEM format, supporting both
125    /// passphrase-protected and unprotected keys. It uses the raw key data obtained
126    /// from `get_raw_key_data()` and attempts to create an OpenSSL private key object.
127    ///
128    /// # Returns
129    /// A reference to the parsed private key as a `PKey<openssl::pkey::Private>`.
130    ///
131    /// # Errors
132    /// Returns an error if:
133    /// - The key cannot be parsed from PEM format
134    /// - A passphrase is required but incorrect
135    /// - The key data is invalid
136    fn get_key_object(&self) -> Result<&PKey<openssl::pkey::Private>> {
137        self.key_object.get_or_try_init(|| {
138            let key_data = self.get_raw_key_data()?;
139            if let Some(pass) = self.private_key_passphrase.as_ref() {
140                PKey::private_key_from_pem_passphrase(key_data.as_bytes(), pass.as_bytes())
141                    .context("Failed to parse private key with passphrase")
142            } else {
143                PKey::private_key_from_pem(key_data.as_bytes())
144                    .context("Failed to parse private key")
145            }
146        })
147    }
148
149    /// Retrieves the Ed25519 signing key, lazily initializing it from raw key data.
150    ///
151    /// This method attempts to parse an Ed25519 private key from a PEM-formatted input,
152    /// extracting the base64-encoded key material and converting it to a `SigningKey`.
153    ///
154    /// # Returns
155    /// A reference to the parsed Ed25519 signing key.
156    ///
157    /// # Errors
158    /// Returns an error if:
159    /// - The key cannot be base64 decoded
160    /// - The key cannot be parsed from PKCS8 DER format
161    fn get_ed25519_signing_key(&self) -> Result<&SigningKey> {
162        self.ed25519_signing_key.get_or_try_init(|| {
163            let key_data = self.get_raw_key_data()?;
164            let b64 = key_data
165                .lines()
166                .filter(|l| !l.starts_with("-----"))
167                .collect::<String>();
168            let der = general_purpose::STANDARD
169                .decode(b64)
170                .context("Failed to base64 decode Ed25519 PEM")?;
171            SigningKey::from_pkcs8_der(&der)
172                .map_err(|e| anyhow::anyhow!("Failed to parse Ed25519 key: {}", e))
173        })
174    }
175
176    /// Generates a signature for the given query parameters using either HMAC-SHA256 or asymmetric key signing.
177    ///
178    /// # Arguments
179    ///
180    /// * `query_params` - A map of query parameters to be signed
181    ///
182    /// # Returns
183    ///
184    /// A base64-encoded signature string
185    ///
186    /// # Errors
187    ///
188    /// Returns an error if:
189    /// - No API secret or private key is provided
190    /// - Key initialization fails
191    /// - Signing process encounters an error
192    /// - An unsupported key type is used
193    ///
194    /// # Supported Key Types
195    /// - HMAC with API secret
196    /// - RSA private key
197    /// - ED25519 private key
198    pub fn get_signature(&self, query_params: &BTreeMap<String, Value>) -> Result<String> {
199        let params = build_query_string(query_params)?;
200
201        if let Some(secret) = self.api_secret.as_ref() {
202            if self.private_key.is_none() {
203                let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes())
204                    .context("HMAC key initialization failed")?;
205                mac.update(params.as_bytes());
206                let result = mac.finalize().into_bytes();
207                return Ok(hex::encode(result));
208            }
209        }
210
211        if self.private_key.is_some() {
212            let key_obj = self.get_key_object()?;
213            match key_obj.id() {
214                openssl::pkey::Id::RSA => {
215                    let mut signer = OpenSslSigner::new(MessageDigest::sha256(), key_obj)
216                        .context("Failed to create RSA signer")?;
217                    signer
218                        .update(params.as_bytes())
219                        .context("Failed to update RSA signer")?;
220                    let sig = signer.sign_to_vec().context("RSA signing failed")?;
221                    return Ok(general_purpose::STANDARD.encode(sig));
222                }
223                openssl::pkey::Id::ED25519 => {
224                    let signing_key = self.get_ed25519_signing_key()?;
225                    let signature = signing_key.sign(params.as_bytes());
226                    return Ok(general_purpose::STANDARD.encode(signature.to_bytes()));
227                }
228                other => {
229                    return Err(anyhow::anyhow!(
230                        "Unsupported private key type: {:?}. Must be RSA or ED25519.",
231                        other
232                    ));
233                }
234            }
235        }
236
237        Err(anyhow::anyhow!(
238            "Either 'api_secret' or 'private_key' must be provided for signed requests."
239        ))
240    }
241}
242
243/// Builds a reqwest HTTP client with configurable timeout, keep-alive, proxy, and custom agent settings.
244///
245/// # Arguments
246///
247/// * `timeout` - Timeout duration in milliseconds for HTTP requests
248/// * `keep_alive` - Whether to enable HTTP keep-alive connections
249/// * `proxy` - Optional proxy configuration for routing requests
250/// * `agent` - Optional custom HTTP agent configuration function
251///
252/// # Returns
253///
254/// A configured `reqwest::Client` instance
255///
256/// # Panics
257///
258/// Panics if the client cannot be built with the provided configuration
259///
260/// # Examples
261///
262///
263/// let client = `build_client(5000`, true, None, None);
264///
265#[must_use]
266pub fn build_client(
267    timeout: u64,
268    keep_alive: bool,
269    proxy: Option<&ProxyConfig>,
270    agent: Option<HttpAgent>,
271) -> Client {
272    let builder = Client::builder().timeout(Duration::from_millis(timeout));
273
274    let mut builder = if keep_alive {
275        builder
276    } else {
277        builder.pool_idle_timeout(Some(Duration::from_secs(0)))
278    };
279
280    if let Some(proxy_conf) = proxy {
281        let protocol = proxy_conf
282            .protocol
283            .clone()
284            .unwrap_or_else(|| "http".to_string());
285        let proxy_url = format!("{}://{}:{}", protocol, proxy_conf.host, proxy_conf.port);
286        let mut proxy_builder = Proxy::all(&proxy_url).expect("Failed to create proxy from URL");
287        if let Some(auth) = &proxy_conf.auth {
288            proxy_builder = proxy_builder.basic_auth(&auth.username, &auth.password);
289        }
290        builder = builder.proxy(proxy_builder);
291    }
292
293    if let Some(HttpAgent(agent_fn)) = agent {
294        builder = (agent_fn)(builder);
295    }
296
297    info!("Client builder {:?}", builder);
298
299    builder.build().expect("Failed to build reqwest client")
300}
301
302/// Generates a user agent string for the current module.
303///
304/// # Arguments
305///
306/// * `product` - A string slice representing the product.
307///
308/// # Returns
309///
310/// A formatted user agent string containing:
311/// - Package name
312/// - Product
313/// - Package version
314/// - Rust compiler version
315/// - Operating system
316/// - Architecture
317///
318/// # Examples
319///
320///
321/// let `user_agent` = `build_user_agent("spot`");
322/// // Might return something like: "`binance_sdk/spot/1.0.0` (Rust/rustc 1.87.0; linux; `x86_64`;)"
323///
324#[must_use]
325pub fn build_user_agent(product: &str) -> String {
326    format!(
327        "{}/{}/{} (Rust/{}; {}; {})",
328        env!("CARGO_PKG_NAME"),
329        product,
330        env!("CARGO_PKG_VERSION"),
331        env!("RUSTC_VERSION"),
332        std::env::consts::OS,
333        std::env::consts::ARCH,
334    )
335}
336
337/// Validates the time unit string and returns an optional normalized time unit.
338///
339/// # Arguments
340///
341/// * `time_unit` - A string representing the time unit to validate.
342///
343/// # Returns
344///
345/// * `Ok(None)` if an empty string is provided
346/// * `Ok(Some(time_unit))` if the time unit is 'MILLISECOND', 'MICROSECOND', 'millisecond', or 'microsecond'
347/// * `Err` with an error message if an invalid time unit is provided
348///
349/// # Errors
350///
351/// Returns `Err(anyhow::Error)` if `time_unit` is non-empty and not one of the allowed values.
352///
353/// # Examples
354///
355/// let result = `validate_time_unit("MILLISECOND`");
356/// `assert!(result.is_ok())`;
357///
358/// let result = `validate_time_unit`("");
359/// `assert!(result.is_ok()` && `result.unwrap().is_none()`);
360///
361/// let result = `validate_time_unit("SECOND`");
362/// `assert!(result.is_err())`;
363///
364pub fn validate_time_unit(time_unit: &str) -> Result<Option<&str>, anyhow::Error> {
365    match time_unit {
366        "" => Ok(None),
367        "MILLISECOND" | "MICROSECOND" | "millisecond" | "microsecond" => Ok(Some(time_unit)),
368        _ => Err(anyhow::anyhow!(
369            "time_unit must be either 'MILLISECOND' or 'MICROSECOND'"
370        )),
371    }
372}
373
/// Returns the number of milliseconds elapsed since the Unix epoch.
///
/// # Returns
///
/// * The current wall-clock time as a `u128` millisecond count.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
///
/// # Examples
///
/// ```ignore
/// let timestamp = get_timestamp();
/// println!("Current timestamp: {}", timestamp);
/// ```
#[must_use]
pub fn get_timestamp() -> u128 {
    let now = SystemTime::now();
    let since_epoch = now
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    since_epoch.as_millis()
}
397
398/// Asynchronously pauses the current task for a specified number of milliseconds.
399///
400/// # Arguments
401///
402/// * `ms` - The number of milliseconds to pause the task.
403///
404/// # Examples
405///
406///
407/// let _ = delay(100).await; // Pause for 100 milliseconds
408///
409pub async fn delay(ms: u64) {
410    sleep(Duration::from_millis(ms)).await;
411}
412
413/// Builds a query string from a map of key-value parameters.
414///
415/// Converts various JSON `Value` types into URL query string segments, handling:
416/// - Strings, booleans, and numbers as direct key-value pairs
417/// - Arrays of strings, booleans, or numbers as comma-separated values
418/// - Nested arrays serialized as JSON strings
419///
420/// # Arguments
421///
422/// * `params` - A map of parameter names to their corresponding JSON values
423///
424/// # Returns
425///
426/// * `Result<String, anyhow::Error>` - A query string with URL-encoded parameters, or an error
427///
428/// # Errors
429///
430/// Returns an error if an object value is encountered or JSON serialization fails
431pub fn build_query_string(params: &BTreeMap<String, Value>) -> Result<String, anyhow::Error> {
432    let mut segments = Vec::with_capacity(params.len());
433
434    for (key, value) in params {
435        match value {
436            Value::Null => {}
437            Value::String(s) => {
438                let mut ser = Serializer::new(String::new());
439                ser.append_pair(key, s);
440                segments.push(ser.finish());
441            }
442            Value::Bool(b) => {
443                let val = b.to_string();
444                let mut ser = Serializer::new(String::new());
445                ser.append_pair(key, &val);
446                segments.push(ser.finish());
447            }
448            Value::Number(n) => {
449                let val = n.to_string();
450                let mut ser = Serializer::new(String::new());
451                ser.append_pair(key, &val);
452                segments.push(ser.finish());
453            }
454            Value::Array(arr)
455                if arr
456                    .iter()
457                    .all(|v| matches!(v, Value::String(_) | Value::Bool(_) | Value::Number(_))) =>
458            {
459                let mut parts = Vec::with_capacity(arr.len());
460                for v in arr {
461                    match v {
462                        Value::String(s) => parts.push(s.clone()),
463                        Value::Bool(b) => parts.push(b.to_string()),
464                        Value::Number(n) => parts.push(n.to_string()),
465                        _ => unreachable!(),
466                    }
467                }
468                segments.push(format!("{}={}", key, parts.join(",")));
469            }
470            Value::Array(arr) => {
471                let json =
472                    serde_json::to_string(arr).context("Failed to JSON-serialize nested array")?;
473                let mut ser = Serializer::new(String::new());
474                ser.append_pair(key, &json);
475                segments.push(ser.finish());
476            }
477            Value::Object(_) => {
478                bail!("Cannot serialize object for key `{}` in query params", key);
479            }
480        }
481    }
482
483    Ok(segments.join("&"))
484}
485
486/// Determines whether a request should be retried based on:
487/// - HTTP method (only GET or DELETE are retriable)
488/// - HTTP status (500, 502, 503, 504)
489/// - Number of retries left.
490///
491/// `error` is the reqwest error, `method` is the HTTP method (e.g. "GET"),
492/// and `retries_left` is the number of remaining retries.
493#[must_use]
494pub fn should_retry_request(
495    error: &reqwest::Error,
496    method: Option<&str>,
497    retries_left: Option<usize>,
498) -> bool {
499    let method = method.unwrap_or("");
500    let is_retriable_method =
501        method.eq_ignore_ascii_case("GET") || method.eq_ignore_ascii_case("DELETE");
502
503    let status = error.status().map_or(0, |s| s.as_u16());
504    let is_retriable_status = [500, 502, 503, 504].contains(&status);
505
506    let retries_left = retries_left.unwrap_or(0);
507    retries_left > 0 && is_retriable_method && (is_retriable_status || error.status().is_none())
508}
509
510/// Parses rate limit headers from a `HashMap` of headers and returns a vector of `RestApiRateLimit`.
511///
512/// This function extracts rate limit information from headers with specific patterns (x-mbx-used-weight or x-mbx-order-count)
513/// and converts them into `RestApiRateLimit` structures. It handles different intervals (seconds, minutes, hours, days)
514/// and distinguishes between request weight and order rate limits.
515///
516/// # Arguments
517///
518/// * `headers` - A reference to a `HashMap` containing HTTP headers
519///
520/// # Returns
521///
522/// A `Vec<RestApiRateLimit>` containing parsed rate limit information
523///
524/// # Panics
525///
526/// * If the static regex fails to compile (via `Regex::new(...).unwrap()`), which can only happen if the literal pattern is invalid.  
527/// * If a matching header’s key doesn’t actually contain both capture groups (so `caps.get(2).unwrap()` or `caps.get(3).unwrap()` fails).
528///
529/// # Examples
530///
531/// let headers: `HashMap`<String, String> = // ... headers with rate limit information
532/// let `rate_limits` = `parse_rate_limit_headers(&headers)`;
533///
534#[must_use]
535pub fn parse_rate_limit_headers<S>(headers: &HashMap<String, String, S>) -> Vec<RestApiRateLimit>
536where
537    S: BuildHasher,
538{
539    let mut rate_limits = Vec::new();
540    let re = Regex::new(r"x-mbx-(used-weight|order-count)-(\d+)([smhd])").unwrap();
541    for (key, value) in headers {
542        let normalized_key = key.to_lowercase();
543        if normalized_key.starts_with("x-mbx-used-weight-")
544            || normalized_key.starts_with("x-mbx-order-count-")
545        {
546            if let Some(caps) = re.captures(&normalized_key) {
547                let interval_num: u32 = caps.get(2).unwrap().as_str().parse().unwrap_or(0);
548                let interval_letter = caps.get(3).unwrap().as_str().to_uppercase();
549                let interval = match interval_letter.as_str() {
550                    "S" => Interval::Second,
551                    "M" => Interval::Minute,
552                    "H" => Interval::Hour,
553                    "D" => Interval::Day,
554                    _ => continue,
555                };
556                let count: u32 = value.parse().unwrap_or(0);
557                let rate_limit_type = if normalized_key.starts_with("x-mbx-used-weight-") {
558                    RateLimitType::RequestWeight
559                } else {
560                    RateLimitType::Orders
561                };
562                rate_limits.push(RestApiRateLimit {
563                    rate_limit_type,
564                    interval,
565                    interval_num,
566                    count,
567                    retry_after: headers.get("retry-after").and_then(|v| v.parse().ok()),
568                });
569            }
570        }
571    }
572    rate_limits
573}
574
575/// Sends an HTTP request with retry and error handling capabilities.
576///
577/// # Parameters
578///
579/// - `req`: The HTTP request to be sent
580/// - `configuration`: REST API configuration containing client, retry settings, and other parameters
581///
582/// # Returns
583///
584/// A `Result` containing a `RestApiResponse` with deserialized data, or a `ConnectorError` if the request fails
585///
586/// # Errors
587///
588/// Returns various `ConnectorError` types based on HTTP response status, such as:
589/// - `BadRequestError`
590/// - `UnauthorizedError`
591/// - `ForbiddenError`
592/// - `NotFoundError`
593/// - `RateLimitBanError`
594/// - `TooManyRequestsError`
595/// - `ServerError`
596/// - `ConnectorClientError`
597///
598/// # Behavior
599///
600/// - Supports request retries with configurable backoff
601/// - Handles gzip-encoded responses
602/// - Parses rate limit headers
603/// - Provides detailed error handling for different HTTP status codes
604pub async fn http_request<T: DeserializeOwned + Send + 'static>(
605    req: Request,
606    configuration: &ConfigurationRestApi,
607) -> Result<RestApiResponse<T>, ConnectorError> {
608    let client = &configuration.client;
609    let retries = configuration.retries as usize;
610    let backoff = configuration.backoff;
611    let mut attempt = 0;
612
613    loop {
614        let req_clone = req
615            .try_clone()
616            .context("Failed to clone request")
617            .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?;
618        match client.execute(req_clone).await {
619            Ok(response) => {
620                let status = response.status();
621                let headers_map: HashMap<String, String> = response
622                    .headers()
623                    .iter()
624                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
625                    .collect();
626
627                let raw_bytes = match response.bytes().await {
628                    Ok(b) => b,
629                    Err(e) => {
630                        attempt += 1;
631                        if attempt <= retries {
632                            continue;
633                        }
634                        return Err(ConnectorError::ConnectorClientError(format!(
635                            "Failed to get response bytes: {e}"
636                        )));
637                    }
638                };
639
640                let content = if headers_map
641                    .get("content-encoding")
642                    .is_some_and(|enc| enc.to_lowercase().contains("gzip"))
643                {
644                    let mut decoder = GzDecoder::new(&raw_bytes[..]);
645                    let mut decompressed = String::new();
646                    decoder
647                        .read_to_string(&mut decompressed)
648                        .context("Failed to decompress gzip response")
649                        .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?;
650                    decompressed
651                } else {
652                    String::from_utf8(raw_bytes.to_vec())
653                        .context("Failed to convert response to UTF-8")
654                        .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?
655                };
656
657                let rate_limits = parse_rate_limit_headers(&headers_map);
658
659                if status.is_client_error() || status.is_server_error() {
660                    let error_msg = serde_json::from_str::<serde_json::Value>(&content)
661                        .ok()
662                        .and_then(|v| {
663                            v.get("msg")
664                                .and_then(|m| m.as_str())
665                                .map(std::string::ToString::to_string)
666                        })
667                        .unwrap_or_else(|| content.clone());
668
669                    match status.as_u16() {
670                        400 => return Err(ConnectorError::BadRequestError(error_msg)),
671                        401 => return Err(ConnectorError::UnauthorizedError(error_msg)),
672                        403 => return Err(ConnectorError::ForbiddenError(error_msg)),
673                        404 => return Err(ConnectorError::NotFoundError(error_msg)),
674                        418 => return Err(ConnectorError::RateLimitBanError(error_msg)),
675                        429 => return Err(ConnectorError::TooManyRequestsError(error_msg)),
676                        s if (500..600).contains(&s) => {
677                            return Err(ConnectorError::ServerError {
678                                msg: format!("Server error: {s}"),
679                                status_code: Some(s),
680                            });
681                        }
682                        _ => return Err(ConnectorError::ConnectorClientError(error_msg)),
683                    }
684                }
685
686                let raw = content.clone();
687                return Ok(RestApiResponse {
688                    data_fn: Box::new(move || {
689                        Box::pin(async move {
690                            let parsed: T = serde_json::from_str(&raw)
691                                .map_err(|e| ConnectorError::ConnectorClientError(e.to_string()))?;
692                            Ok(parsed)
693                        })
694                    }),
695                    status: status.as_u16(),
696                    headers: headers_map,
697                    rate_limits: if rate_limits.is_empty() {
698                        None
699                    } else {
700                        Some(rate_limits)
701                    },
702                });
703            }
704            Err(e) => {
705                attempt += 1;
706                if should_retry_request(&e, Some(req.method().as_str()), Some(retries - attempt)) {
707                    delay(backoff * attempt as u64).await;
708                    continue;
709                }
710                return Err(ConnectorError::ConnectorClientError(format!(
711                    "HTTP request failed: {e}"
712                )));
713            }
714        }
715    }
716}
717
718/// Sends an HTTP request to a REST API endpoint with optional authentication and configuration.
719///
720/// # Parameters
721///
722/// - `configuration`: REST API configuration containing client, base path, and authentication details
723/// - `endpoint`: The specific API endpoint path to send the request to
724/// - `method`: HTTP method for the request (GET, POST, etc.)
725/// - `params`: Parameters to be sent with the request, as a key-value map
726/// - `time_unit`: Optional time unit for the request header
727/// - `is_signed`: Optional flag to indicate whether the request requires authentication
728///
729/// # Returns
730///
731/// A `RestApiResponse` containing the deserialized response data, or an error if the request fails
732///
733/// # Panics
734///
735/// This function will panic if any of the following `.unwrap()` calls fail:
736/// - Parsing the literal `"application/json"` into a header value (should never fail)  
737/// - Parsing `configuration.user_agent` or `configuration.api_key` into header values  
738/// - Parsing the literal `"gzip, deflate, br"` into a header value when `compression` is enabled  
739///
740/// # Errors
741///
742/// Returns an `anyhow::Result` which can contain various connector-related errors during request processing
743pub async fn send_request<T: DeserializeOwned + Send + 'static>(
744    configuration: &ConfigurationRestApi,
745    endpoint: &str,
746    method: Method,
747    mut params: BTreeMap<String, Value>,
748    time_unit: Option<TimeUnit>,
749    is_signed: bool,
750) -> anyhow::Result<RestApiResponse<T>> {
751    let base = configuration.base_path.as_deref().unwrap_or("");
752    let full_url = reqwest::Url::parse(base)
753        .and_then(|u| u.join(endpoint))
754        .context("Failed to join base URL and endpoint")?
755        .to_string();
756
757    if is_signed {
758        let timestamp = get_timestamp();
759        params.insert("timestamp".to_string(), json!(timestamp));
760        let signature = configuration.signature_gen.get_signature(&params)?;
761        params.insert("signature".to_string(), Value::String(signature));
762    }
763
764    let mut url = Url::parse(&full_url)?;
765    {
766        let mut pairs = url.query_pairs_mut();
767        for (key, value) in &params {
768            let val_str = match value {
769                Value::String(s) => s.clone(),
770                _ => value.to_string(),
771            };
772            pairs.append_pair(key, &val_str);
773        }
774    }
775
776    let mut headers = HeaderMap::new();
777    headers.insert("Content-Type", "application/json".parse().unwrap());
778    headers.insert("User-Agent", configuration.user_agent.parse().unwrap());
779    if let Some(api_key) = &configuration.api_key {
780        headers.insert("X-MBX-APIKEY", api_key.parse().unwrap());
781    }
782
783    if configuration.compression {
784        headers.insert(ACCEPT_ENCODING, "gzip, deflate, br".parse().unwrap());
785    }
786
787    let time_unit_to_apply = time_unit.or(configuration.time_unit);
788    if let Some(time_unit) = time_unit_to_apply {
789        headers.insert("X-MBX-TIME-UNIT", time_unit.as_upper_str().parse()?);
790    }
791
792    let req_builder = configuration.client.request(method, url).headers(headers);
793    let req = req_builder.build()?;
794
795    Ok(http_request::<T>(req, configuration).await?)
796}
797
798/// Generates a random hexadecimal string of 32 characters.
799///
800/// Uses the thread-local random number generator to fill a 16-byte buffer,
801/// which is then encoded into a hexadecimal string.
802///
803/// # Returns
804///
805/// A randomly generated 32-character hexadecimal string.
806#[must_use]
807pub fn random_string() -> String {
808    let mut buf = [0u8; 16];
809    rand::thread_rng().fill_bytes(&mut buf);
810    hex::encode(buf)
811}
812
813/// Removes entries with empty or null values from an iterator of key-value pairs.
814///
815/// # Arguments
816///
817/// * `entries` - An iterator of key-value pairs where keys are strings and values are of type `Value`.
818///
819/// # Returns
820///
821/// A `BTreeMap` containing only the key-value pairs where the value is neither `null` nor an empty string.
822///
823/// # Examples
824///
825///
826/// let entries = vec![
827///     ("`key1".to_string()`, `Value::String("value1".to_string())`),
828///     ("`key2".to_string()`, `Value::Null`),
829///     ("`key3".to_string()`, `Value::String("".to_string())`),
830/// ];
831/// let filtered = `remove_empty_value(entries)`;
832/// // filtered will only contain the first key-value pair
833///
834pub fn remove_empty_value<I>(entries: I) -> BTreeMap<String, Value>
835where
836    I: IntoIterator<Item = (String, Value)>,
837{
838    entries
839        .into_iter()
840        .filter(|(_, value)| match value {
841            Value::Null => false,
842            Value::String(s) if s.is_empty() => false,
843            _ => true,
844        })
845        .collect()
846}
847
848/// Creates a sorted copy of a `BTreeMap` of parameters.
849///
850/// # Arguments
851///
852/// * `params` - A reference to a `BTreeMap` containing string keys and Value values.
853///
854/// # Returns
855///
856/// A new `BTreeMap` with the same key-value pairs as the input, sorted by keys.
857///
858/// # Examples
859///
860///
861/// let params = `BTreeMap::from`([
862///     ("`z".to_string()`, `Value::String("value1".to_string())`),
863///     ("`a".to_string()`, `Value::String("value2".to_string())`),
864/// ]);
865/// let `sorted_params` = `sort_object_params(&params)`;
866/// // `sorted_params` will have keys sorted in ascending order
867///
868#[must_use]
869pub fn sort_object_params(params: &BTreeMap<String, Value>) -> BTreeMap<String, Value> {
870    let mut sorted = BTreeMap::new();
871    for (k, v) in params {
872        sorted.insert(k.clone(), v.clone());
873    }
874    sorted
875}
876
/// Canonicalizes a WebSocket streams key so that differently formatted spellings compare equal.
///
/// The key is lowercased first, and then every `_` and `-` character is dropped.
///
/// # Arguments
///
/// * `key` - The key to canonicalize
///
/// # Returns
///
/// The lowercase key with all underscores and hyphens removed
fn normalize_ws_streams_key(key: &str) -> String {
    key.to_lowercase()
        .chars()
        .filter(|&ch| ch != '_' && ch != '-')
        .collect()
}
889
890/// Replaces placeholders in a WebSocket stream key with corresponding values from a variables map.
891///
892/// # Arguments
893///
894/// * `input` - The input string containing placeholders to be replaced
895/// * `variables` - A `HashMap` of key-value pairs used for placeholder substitution
896///
897/// # Returns
898///
899/// A modified string with placeholders replaced by their corresponding values,
900/// with special handling for normalization, lowercasing, and '@' symbol stripping.
901///
902/// # Panics
903///
904/// Panics if the input string contains an invalid placeholder format.
905///
906/// # Examples
907///
908///
909/// let input = "/<symbol>@ticker";
910/// let variables = `HashMap::from`([("symbol", "BTCUSDT")]);
911/// let result = `replace_websocket_streams_placeholders(input`, &variables);
912/// // Possible result: "btcusdt@ticker"
913///
914pub fn replace_websocket_streams_placeholders<V, S>(
915    input: &str,
916    variables: &HashMap<&str, V, S>,
917) -> String
918where
919    V: Display,
920    S: BuildHasher,
921{
922    let original = input;
923
924    // Drop a leading slash for processing
925    let body = original.strip_prefix('/').unwrap_or(original);
926
927    // Normalize variables into String→String map
928    let normalized: HashMap<String, String> = variables
929        .iter()
930        .map(|(k, v)| (normalize_ws_streams_key(k), v.to_string()))
931        .collect();
932
933    // Replace all placeholders, preserving any '@' prefix captured by the regex
934    let replaced = PLACEHOLDER_RE
935        .replace_all(body, |caps: &Captures| {
936            let prefix = caps.get(1).map_or("", |m| m.as_str());
937            let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
938            let val = normalized.get(&key).cloned().unwrap_or_default();
939            format!("{prefix}{val}")
940        })
941        .into_owned();
942
943    // Strip any trailing '@'
944    let stripped = replaced.trim_end_matches('@').to_string();
945
946    // Only lowercase head if original started with '/' and first placeholder at start
947    // (cases where `symbol` or `pair` are used and they are not lower-cased)
948    let should_lower_head =
949        original.starts_with('/') && PLACEHOLDER_RE.find(body).is_some_and(|m| m.start() == 0);
950
951    // Lowercase only that first placeholder's value
952    let result = if should_lower_head {
953        if let Some(caps) = PLACEHOLDER_RE.captures(body) {
954            let key = normalize_ws_streams_key(caps.get(2).unwrap().as_str());
955            let first_val = normalized.get(&key).cloned().unwrap_or_default();
956            if stripped.starts_with(&first_val) {
957                let tail = &stripped[first_val.len()..];
958                format!("{}{}", first_val.to_lowercase(), tail)
959            } else {
960                stripped.clone()
961            }
962        } else {
963            stripped.clone()
964        }
965    } else {
966        stripped.clone()
967    };
968
969    result
970}
971
972#[cfg(test)]
973mod tests {
974    use crate::TOKIO_SHARED_RT;
975
976    mod build_client {
977        use std::{
978            sync::{Arc, Mutex},
979            time::{Duration, Instant},
980        };
981
982        use reqwest::ClientBuilder;
983
984        use crate::{
985            common::utils::build_client,
986            config::{HttpAgent, ProxyAuth, ProxyConfig},
987        };
988
989        use super::TOKIO_SHARED_RT;
990
991        #[test]
992        fn enforces_timeout() {
993            TOKIO_SHARED_RT.block_on(async {
994                let client = build_client(100, true, None, None);
995                let start = Instant::now();
996                let res = client.get("http://10.255.255.1").send().await;
997                assert!(
998                    res.is_err(),
999                    "expected an error (timeout or connect) but got {res:?}"
1000                );
1001                let elapsed = start.elapsed();
1002                assert!(
1003                    elapsed < Duration::from_millis(500),
1004                    "timed out too slowly: {elapsed:?}"
1005                );
1006            });
1007        }
1008
1009        #[test]
1010        fn builds_with_keep_alive_disabled() {
1011            let client = build_client(200, false, None, None);
1012            let _: reqwest::Client = client;
1013        }
1014
1015        #[test]
1016        #[should_panic(expected = "Failed to create proxy from URL")]
1017        fn invalid_proxy_url_panics() {
1018            let bad_proxy = ProxyConfig {
1019                protocol: Some("http".to_string()),
1020                host: String::new(),
1021                port: 8080,
1022                auth: None,
1023            };
1024            let _ = build_client(1_000, true, Some(&bad_proxy), None);
1025        }
1026
1027        #[test]
1028        fn builds_with_proxy_and_auth() {
1029            let proxy = ProxyConfig {
1030                protocol: Some("https".to_string()),
1031                host: "127.0.0.1".to_string(),
1032                port: 3128,
1033                auth: Some(ProxyAuth {
1034                    username: "alice".to_string(),
1035                    password: "secret".to_string(),
1036                }),
1037            };
1038            let client = build_client(2_000, true, Some(&proxy), None);
1039            let _: reqwest::Client = client;
1040        }
1041
1042        #[test]
1043        fn custom_agent_invoked() {
1044            let called = Arc::new(Mutex::new(false));
1045            let called_clone = Arc::clone(&called);
1046
1047            let agent = HttpAgent(Arc::new(move |builder: ClientBuilder| {
1048                *called_clone.lock().unwrap() = true;
1049                builder
1050            }));
1051
1052            let client = build_client(1_000, true, None, Some(agent));
1053            assert!(*called.lock().unwrap(), "agent closure wasn’t invoked");
1054            let _: reqwest::Client = client;
1055        }
1056    }
1057
1058    mod build_user_agent {
1059        use crate::common::utils::build_user_agent;
1060
1061        #[test]
1062        fn build_user_agent_contains_crate_product_and_rust_info() {
1063            let product = "product";
1064            let user_agent = build_user_agent(product);
1065
1066            let name = env!("CARGO_PKG_NAME");
1067            let version = env!("CARGO_PKG_VERSION");
1068            let rustc = env!("RUSTC_VERSION");
1069            let os = std::env::consts::OS;
1070            let arch = std::env::consts::ARCH;
1071
1072            let expected_prefix = format!("{name}/{product}/{version} (Rust/");
1073            assert!(
1074                user_agent.starts_with(&expected_prefix),
1075                "prefix mismatch: {user_agent}"
1076            );
1077
1078            assert!(
1079                user_agent.contains(rustc),
1080                "user agent missing RUSTC_VERSION: {user_agent}"
1081            );
1082
1083            assert!(
1084                user_agent.contains(&format!("; {os}")),
1085                "user agent missing OS: {user_agent}"
1086            );
1087            assert!(
1088                user_agent.contains(&format!("; {arch}")),
1089                "user agent missing ARCH: {user_agent}"
1090            );
1091        }
1092
1093        #[test]
1094        fn build_user_agent_is_deterministic() {
1095            let product = "product";
1096            let user_agent1 = build_user_agent(product);
1097            let user_agent2 = build_user_agent(product);
1098            assert_eq!(
1099                user_agent1, user_agent2,
1100                "user agent should be the same on repeated calls"
1101            );
1102        }
1103    }
1104
1105    mod validate_time_unit {
1106        use crate::common::utils::validate_time_unit;
1107
1108        #[test]
1109        fn empty_string_returns_none() {
1110            let res = validate_time_unit("").expect("Should not error on empty string");
1111            assert_eq!(res, None);
1112        }
1113
1114        #[test]
1115        fn uppercase_millisecond() {
1116            let res = validate_time_unit("MILLISECOND").expect("Should accept MILLISECOND");
1117            assert_eq!(res, Some("MILLISECOND"));
1118        }
1119
1120        #[test]
1121        fn uppercase_microsecond() {
1122            let res = validate_time_unit("MICROSECOND").expect("Should accept MICROSECOND");
1123            assert_eq!(res, Some("MICROSECOND"));
1124        }
1125
1126        #[test]
1127        fn lowercase_millisecond() {
1128            let res = validate_time_unit("millisecond").expect("Should accept millisecond");
1129            assert_eq!(res, Some("millisecond"));
1130        }
1131
1132        #[test]
1133        fn lowercase_microsecond() {
1134            let res = validate_time_unit("microsecond").expect("Should accept microsecond");
1135            assert_eq!(res, Some("microsecond"));
1136        }
1137
1138        #[test]
1139        fn invalid_value_returns_err() {
1140            let err = validate_time_unit("SECOND").unwrap_err();
1141            let msg = format!("{err}");
1142            assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1143        }
1144
1145        #[test]
1146        fn partial_match_returns_err() {
1147            let err = validate_time_unit("MILLI").unwrap_err();
1148            let msg = format!("{err}");
1149            assert!(msg.contains("time_unit must be either 'MILLISECOND' or 'MICROSECOND'"));
1150        }
1151    }
1152
1153    mod get_timestamp {
1154        use crate::common::utils::get_timestamp;
1155        use std::{
1156            thread::sleep,
1157            time::{Duration, SystemTime, UNIX_EPOCH},
1158        };
1159
1160        #[test]
1161        fn timestamp_is_within_system_time_bounds() {
1162            let before = SystemTime::now()
1163                .duration_since(UNIX_EPOCH)
1164                .expect("SystemTime before UNIX_EPOCH")
1165                .as_millis();
1166            let ts = get_timestamp();
1167            let after = SystemTime::now()
1168                .duration_since(UNIX_EPOCH)
1169                .expect("SystemTime before UNIX_EPOCH")
1170                .as_millis();
1171
1172            assert!(
1173                ts >= before,
1174                "timestamp {ts} is before captured before time {before}"
1175            );
1176            assert!(
1177                ts <= after,
1178                "timestamp {ts} is after captured after time {after}"
1179            );
1180        }
1181
1182        #[test]
1183        fn timestamps_are_monotonic() {
1184            let t1 = get_timestamp();
1185            sleep(Duration::from_millis(1));
1186            let t2 = get_timestamp();
1187            assert!(
1188                t2 >= t1,
1189                "second timestamp {t2} is not >= first timestamp {t1}"
1190            );
1191        }
1192    }
1193
1194    mod build_query_string {
1195        use std::collections::BTreeMap;
1196
1197        use anyhow::Result;
1198        use serde_json::{Value, json};
1199        use url::form_urlencoded::Serializer;
1200
1201        use crate::common::utils::build_query_string;
1202
1203        fn mk_map(pairs: Vec<(&str, Value)>) -> BTreeMap<String, Value> {
1204            let mut m = BTreeMap::new();
1205            for (k, v) in pairs {
1206                m.insert(k.to_string(), v);
1207            }
1208            m
1209        }
1210
1211        #[test]
1212        fn empty_map_returns_empty_string() -> Result<()> {
1213            let params = BTreeMap::new();
1214            let qs = build_query_string(&params)?;
1215            assert_eq!(qs, "");
1216            Ok(())
1217        }
1218
1219        #[test]
1220        fn string_and_number() -> Result<()> {
1221            let params = mk_map(vec![("foo", json!("bar")), ("num", json!(42))]);
1222            let qs = build_query_string(&params)?;
1223            assert_eq!(qs, "foo=bar&num=42");
1224            Ok(())
1225        }
1226
1227        #[test]
1228        fn bool_and_null_skipped() -> Result<()> {
1229            let params = mk_map(vec![("a", json!(true)), ("b", Value::Null)]);
1230            let qs = build_query_string(&params)?;
1231            assert_eq!(qs, "a=true");
1232            Ok(())
1233        }
1234
1235        #[test]
1236        fn flat_array() -> Result<()> {
1237            let params = mk_map(vec![("list", json!(vec!["x", "y", "z"]))]);
1238            let qs = build_query_string(&params)?;
1239            assert_eq!(qs, "list=x,y,z");
1240            Ok(())
1241        }
1242
1243        #[test]
1244        fn nested_array_json_encoded() -> Result<()> {
1245            let params = mk_map(vec![("nested", json!([[1, 2], [3, 4]]))]);
1246            let qs = build_query_string(&params)?;
1247
1248            let nested_json = serde_json::to_string(&json!([[1, 2], [3, 4]]))?;
1249            let mut ser = Serializer::new(String::new());
1250            ser.append_pair("nested", &nested_json);
1251            let expected = ser.finish();
1252
1253            assert_eq!(qs, expected);
1254            Ok(())
1255        }
1256
1257        #[test]
1258        fn object_not_supported() {
1259            let params = mk_map(vec![("obj", json!({"k":1}))]);
1260            let err = build_query_string(&params).unwrap_err();
1261            let msg = format!("{err}");
1262            assert!(msg.contains("Cannot serialize object for key `obj`"));
1263        }
1264    }
1265
1266    mod signature_generator {
1267        use base64::{Engine, engine::general_purpose};
1268        use ed25519_dalek::{SigningKey, ed25519::signature::SignerMut, pkcs8::DecodePrivateKey};
1269        use hex;
1270        use hmac::{Hmac, Mac};
1271        use openssl::{hash::MessageDigest, pkey::PKey, rsa::Rsa, sign::Verifier};
1272        use serde_json::Value;
1273        use sha2::Sha256;
1274        use std::collections::BTreeMap;
1275        use std::io::Write;
1276        use tempfile::NamedTempFile;
1277
1278        use crate::{common::utils::SignatureGenerator, config::PrivateKey};
1279
1280        #[test]
1281        fn hmac_sha256_signature() {
1282            let mut params = BTreeMap::new();
1283            params.insert("b".into(), Value::Number(2.into()));
1284            params.insert("a".into(), Value::Number(1.into()));
1285
1286            let signature_gen = SignatureGenerator::new(Some("test-secret".into()), None, None);
1287            let sig = signature_gen
1288                .get_signature(&params)
1289                .expect("HMAC signing failed");
1290
1291            let mut mac = Hmac::<Sha256>::new_from_slice(b"test-secret").unwrap();
1292            let qs = "a=1&b=2";
1293            mac.update(qs.as_bytes());
1294            let expected = hex::encode(mac.finalize().into_bytes());
1295
1296            assert_eq!(sig, expected);
1297        }
1298
1299        #[test]
1300        fn repeated_hmac_signature() {
1301            let mut params = BTreeMap::new();
1302            params.insert("x".into(), Value::String("y".into()));
1303            let signature_gen = SignatureGenerator::new(Some("abc".into()), None, None);
1304            let s1 = signature_gen.get_signature(&params).unwrap();
1305            let s2 = signature_gen.get_signature(&params).unwrap();
1306            assert_eq!(s1, s2);
1307        }
1308
1309        #[test]
1310        fn rsa_signature_verification() {
1311            let mut params = BTreeMap::new();
1312            params.insert("a".into(), Value::Number(1.into()));
1313            params.insert("b".into(), Value::Number(2.into()));
1314
1315            let rsa = Rsa::generate(2048).unwrap();
1316            let priv_pem = rsa.private_key_to_pem().unwrap();
1317            let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1318
1319            let signature_gen =
1320                SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1321            let sig = signature_gen
1322                .get_signature(&params)
1323                .expect("RSA signing failed");
1324
1325            let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1326            let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1327            let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1328            verifier.update(b"a=1&b=2").unwrap();
1329            assert!(verifier.verify(&sig_bytes).unwrap());
1330        }
1331
1332        #[test]
1333        fn repeated_rsa_signature() {
1334            let mut params = BTreeMap::new();
1335            params.insert("k".into(), Value::Number(5.into()));
1336            let rsa = Rsa::generate(2048).unwrap();
1337            let priv_pem = rsa.private_key_to_pem().unwrap();
1338            let signature_gen =
1339                SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem)), None);
1340            let s1 = signature_gen.get_signature(&params).unwrap();
1341            let s2 = signature_gen.get_signature(&params).unwrap();
1342            assert_eq!(s1, s2);
1343        }
1344
1345        #[test]
1346        fn ed25519_signature_verification() {
1347            let mut params = BTreeMap::new();
1348            params.insert("a".into(), Value::Number(1.into()));
1349            params.insert("b".into(), Value::Number(2.into()));
1350            let qs = "a=1&b=2";
1351
1352            let ed = PKey::generate_ed25519().unwrap();
1353            let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1354
1355            let signature_gen =
1356                SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1357            let sig = signature_gen
1358                .get_signature(&params)
1359                .expect("Ed25519 signing failed");
1360
1361            let pem_str = String::from_utf8(priv_pem).unwrap();
1362            let b64 = pem_str
1363                .lines()
1364                .filter(|l| !l.starts_with("-----"))
1365                .collect::<String>();
1366            let der = general_purpose::STANDARD.decode(b64).unwrap();
1367            let mut sk = SigningKey::from_pkcs8_der(&der).unwrap();
1368            let expected_bytes = sk.sign(qs.as_bytes()).to_bytes();
1369            let expected_sig = general_purpose::STANDARD.encode(expected_bytes);
1370            assert_eq!(sig, expected_sig);
1371        }
1372
1373        #[test]
1374        fn repeated_ed25519_signature() {
1375            let mut params = BTreeMap::new();
1376            params.insert("m".into(), Value::String("n".into()));
1377            let ed = PKey::generate_ed25519().unwrap();
1378            let priv_pem = ed.private_key_to_pem_pkcs8().unwrap();
1379            let signature_gen =
1380                SignatureGenerator::new(None, Some(PrivateKey::Raw(priv_pem.clone())), None);
1381            let s1 = signature_gen.get_signature(&params).unwrap();
1382            let s2 = signature_gen.get_signature(&params).unwrap();
1383            assert_eq!(s1, s2);
1384        }
1385
1386        #[test]
1387        fn file_based_key() {
1388            let rsa = Rsa::generate(1024).unwrap();
1389            let priv_pem = rsa.private_key_to_pem().unwrap();
1390            let pub_pem = rsa.public_key_to_pem_pkcs1().unwrap();
1391
1392            let mut file = NamedTempFile::new().unwrap();
1393            file.write_all(&priv_pem).unwrap();
1394            let path = file.path().to_str().unwrap().to_string();
1395
1396            let mut params = BTreeMap::new();
1397            params.insert("z".into(), Value::Number(9.into()));
1398
1399            let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::File(path)), None);
1400            let sig = signature_gen.get_signature(&params).unwrap();
1401
1402            let sig_bytes = general_purpose::STANDARD.decode(&sig).unwrap();
1403            let pubkey = PKey::public_key_from_pem(&pub_pem).unwrap();
1404            let mut verifier = Verifier::new(MessageDigest::sha256(), &pubkey).unwrap();
1405            verifier.update(b"z=9").unwrap();
1406            assert!(verifier.verify(&sig_bytes).unwrap());
1407        }
1408
1409        #[test]
1410        fn unsupported_key_type_error() {
1411            let mut params = BTreeMap::new();
1412            params.insert("x".into(), Value::String("y".into()));
1413
1414            let group =
1415                openssl::ec::EcGroup::from_curve_name(openssl::nid::Nid::X9_62_PRIME256V1).unwrap();
1416            let ec_key = openssl::ec::EcKey::generate(&group).unwrap();
1417            let pkey_ec = PKey::from_ec_key(ec_key).unwrap();
1418            let raw = pkey_ec.private_key_to_pem_pkcs8().unwrap();
1419
1420            let signature_gen = SignatureGenerator::new(None, Some(PrivateKey::Raw(raw)), None);
1421            let err = signature_gen
1422                .get_signature(&params)
1423                .unwrap_err()
1424                .to_string();
1425            assert!(err.contains("Unsupported private key type"));
1426        }
1427
1428        #[test]
1429        fn invalid_private_key_error() {
1430            let mut params = BTreeMap::new();
1431            params.insert("foo".into(), Value::String("bar".into()));
1432
1433            let signature_gen =
1434                SignatureGenerator::new(None, Some(PrivateKey::Raw(b"not a key".to_vec())), None);
1435            let err = signature_gen
1436                .get_signature(&params)
1437                .unwrap_err()
1438                .to_string();
1439            assert!(err.contains("Failed to parse private key"));
1440        }
1441
1442        #[test]
1443        fn missing_credentials_error() {
1444            let mut params = BTreeMap::new();
1445            params.insert("a".into(), Value::Number(1.into()));
1446
1447            let signature_gen = SignatureGenerator::new(None, None, None);
1448            let err = signature_gen
1449                .get_signature(&params)
1450                .unwrap_err()
1451                .to_string();
1452            assert!(err.contains("Either 'api_secret' or 'private_key' must be provided"));
1453        }
1454    }
1455
1456    mod should_retry_request {
1457        use crate::common::utils::should_retry_request;
1458
1459        use reqwest::{Error, Response};
1460
1461        fn mk_http_error(code: u16) -> Error {
1462            let resp = Response::from(
1463                http::response::Response::builder()
1464                    .status(code)
1465                    .body("")
1466                    .unwrap(),
1467            );
1468            resp.error_for_status().unwrap_err()
1469        }
1470
1471        fn mk_network_error() -> Error {
1472            reqwest::blocking::get("http://256.256.256.256").unwrap_err()
1473        }
1474
1475        #[test]
1476        fn retry_on_retriable_status_and_method() {
1477            let err = mk_http_error(500);
1478            assert!(should_retry_request(&err, Some("GET"), Some(1)));
1479            assert!(should_retry_request(&err, Some("delete"), Some(2)));
1480        }
1481
1482        #[test]
1483        fn retry_when_status_none_and_retriable_method() {
1484            let retriable_methods = ["GET", "DELETE"];
1485
1486            for &method in &retriable_methods {
1487                let err = mk_network_error();
1488                assert!(
1489                    should_retry_request(&err, Some(method), Some(1)),
1490                    "Should retry when no status and method {method}"
1491                );
1492            }
1493        }
1494
1495        #[test]
1496        fn no_retry_when_no_retries_left() {
1497            let err = mk_http_error(503);
1498            assert!(!should_retry_request(&err, Some("GET"), Some(0)));
1499        }
1500
1501        #[test]
1502        fn no_retry_on_non_retriable_status() {
1503            let non_retriable_statuses = [400, 401, 404, 422];
1504
1505            for &status in &non_retriable_statuses {
1506                let err = mk_http_error(status);
1507                assert!(
1508                    !should_retry_request(&err, Some("GET"), Some(2)),
1509                    "Should not retry for non-retriable status {status}"
1510                );
1511            }
1512        }
1513
1514        #[test]
1515        fn no_retry_on_non_retriable_method() {
1516            let non_retriable_methods = ["POST", "PUT", "PATCH"];
1517
1518            for &method in &non_retriable_methods {
1519                let err = mk_http_error(500);
1520                assert!(
1521                    !should_retry_request(&err, Some(method), Some(2)),
1522                    "Should not retry for non-retriable method {method}"
1523                );
1524            }
1525        }
1526
1527        #[test]
1528        fn no_retry_when_status_none_and_non_retriable_method() {
1529            let non_retriable_methods = ["POST", "PUT"];
1530
1531            for &method in &non_retriable_methods {
1532                let err = mk_network_error();
1533                assert!(
1534                    !should_retry_request(&err, Some(method), Some(1)),
1535                    "Should not retry when no status and method {method}"
1536                );
1537            }
1538        }
1539    }
1540
1541    mod parse_rate_limit_headers_tests {
1542        use crate::common::{
1543            models::{Interval, RateLimitType},
1544            utils::parse_rate_limit_headers,
1545        };
1546        use std::collections::HashMap;
1547
1548        fn mk_headers(pairs: Vec<(&str, &str)>) -> HashMap<String, String> {
1549            let mut m = HashMap::new();
1550            for (k, v) in pairs {
1551                m.insert(k.to_string(), v.to_string());
1552            }
1553            m
1554        }
1555
1556        #[test]
1557        fn single_weight_header() {
1558            let headers = mk_headers(vec![("x-mbx-used-weight-1s", "123")]);
1559            let limits = parse_rate_limit_headers(&headers);
1560            assert_eq!(limits.len(), 1);
1561            let rl = &limits[0];
1562            assert_eq!(rl.rate_limit_type, RateLimitType::RequestWeight);
1563            assert_eq!(rl.interval, Interval::Second);
1564            assert_eq!(rl.interval_num, 1);
1565            assert_eq!(rl.count, 123);
1566            assert_eq!(rl.retry_after, None);
1567        }
1568
1569        #[test]
1570        fn single_order_count_with_retry_after() {
1571            let headers = mk_headers(vec![("x-mbx-order-count-5m", "42"), ("retry-after", "7")]);
1572            let limits = parse_rate_limit_headers(&headers);
1573            assert_eq!(limits.len(), 1);
1574            let rl = &limits[0];
1575            assert_eq!(rl.rate_limit_type, RateLimitType::Orders);
1576            assert_eq!(rl.interval, Interval::Minute);
1577            assert_eq!(rl.interval_num, 5);
1578            assert_eq!(rl.count, 42);
1579            assert_eq!(rl.retry_after, Some(7));
1580        }
1581
1582        #[test]
1583        fn multiple_headers() {
1584            let headers = mk_headers(vec![
1585                ("X-MBX-USED-WEIGHT-1h", "10"),
1586                ("x-mbx-order-count-2d", "20"),
1587            ]);
1588            let mut limits = parse_rate_limit_headers(&headers);
1589            limits.sort_by_key(|r| (r.interval_num, format!("{:?}", r.rate_limit_type)));
1590            assert_eq!(limits.len(), 2);
1591            let w = &limits[0];
1592            assert_eq!(w.rate_limit_type, RateLimitType::RequestWeight);
1593            assert_eq!(w.interval, Interval::Hour);
1594            assert_eq!(w.interval_num, 1);
1595            assert_eq!(w.count, 10);
1596            let o = &limits[1];
1597            assert_eq!(o.rate_limit_type, RateLimitType::Orders);
1598            assert_eq!(o.interval, Interval::Day);
1599            assert_eq!(o.interval_num, 2);
1600            assert_eq!(o.count, 20);
1601        }
1602
1603        #[test]
1604        fn ignores_unknown_and_malformed() {
1605            let headers = mk_headers(vec![
1606                ("x-mbx-used-weight-3x", "5"),
1607                ("random-header", "100"),
1608            ]);
1609            let limits = parse_rate_limit_headers(&headers);
1610            assert!(limits.is_empty());
1611        }
1612    }
1613
1614    mod http_request {
1615        use std::io::Write;
1616
1617        use flate2::{Compression, write::GzEncoder};
1618        use httpmock::MockServer;
1619        use reqwest::{Client, Method, Request};
1620        use serde::Deserialize;
1621
1622        use crate::{
1623            common::utils::http_request, config::ConfigurationRestApi, errors::ConnectorError,
1624            models::RestApiResponse,
1625        };
1626
1627        use super::TOKIO_SHARED_RT;
1628
1629        #[derive(Deserialize, Debug, PartialEq)]
1630        struct Dummy {
1631            foo: String,
1632        }
1633
1634        fn make_config(server_url: &str) -> ConfigurationRestApi {
1635            ConfigurationRestApi::builder()
1636                .api_key("key")
1637                .api_secret("secret")
1638                .base_path(server_url)
1639                .build()
1640                .expect("Failed to build configuration")
1641        }
1642
1643        #[test]
1644        fn http_request_success_plain_text() {
1645            TOKIO_SHARED_RT.block_on(async {
1646                let server = MockServer::start();
1647                let mock = server.mock(|when, then| {
1648                    when.method(httpmock::Method::GET).path("/test");
1649                    then.status(200)
1650                        .header("Content-Type", "application/json")
1651                        .body(r#"{"foo":"bar"}"#);
1652                });
1653
1654                let client = Client::new();
1655                let req: Request = client
1656                    .request(Method::GET, format!("{}{}", server.url(""), "/test"))
1657                    .build()
1658                    .unwrap();
1659
1660                let cfg = make_config(&server.url(""));
1661                let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
1662                assert_eq!(resp.status, 200);
1663                let data = resp.data().await.unwrap();
1664                assert_eq!(data, Dummy { foo: "bar".into() });
1665                mock.assert();
1666            });
1667        }
1668
1669        #[test]
1670        fn http_request_success_gzip() {
1671            TOKIO_SHARED_RT.block_on(async {
1672                let server = MockServer::start();
1673                let body = r#"{"foo":"baz"}"#;
1674                let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
1675                encoder.write_all(body.as_bytes()).unwrap();
1676                let gz = encoder.finish().unwrap();
1677
1678                let mock = server.mock(|when, then| {
1679                    when.method(httpmock::Method::GET).path("/gz");
1680                    then.status(200)
1681                        .header("Content-Type", "application/json")
1682                        .header("Content-Encoding", "gzip")
1683                        .body(gz);
1684                });
1685
1686                let client = Client::new();
1687                let req: Request = client
1688                    .request(Method::GET, format!("{}{}", server.url(""), "/gz"))
1689                    .build()
1690                    .unwrap();
1691                let mut cfg = make_config(&server.url(""));
1692                cfg.compression = true;
1693
1694                let resp: RestApiResponse<Dummy> = http_request(req, &cfg).await.unwrap();
1695                assert_eq!(resp.status, 200);
1696                let data = resp.data().await.unwrap();
1697                assert_eq!(data, Dummy { foo: "baz".into() });
1698                mock.assert();
1699            });
1700        }
1701
1702        #[test]
1703        fn http_request_client_error_bad_request() {
1704            TOKIO_SHARED_RT.block_on(async {
1705                let server = MockServer::start();
1706                let mock = server.mock(|when, then| {
1707                    when.method(httpmock::Method::GET).path("/400");
1708                    then.status(400)
1709                        .header("Content-Type", "application/json")
1710                        .body(r#"{"msg":"bad request"}"#);
1711                });
1712
1713                let client = Client::new();
1714                let req: Request = client
1715                    .request(Method::GET, format!("{}{}", server.url(""), "/400"))
1716                    .build()
1717                    .unwrap();
1718                let cfg = make_config(&server.url(""));
1719
1720                let result = http_request::<Dummy>(req, &cfg).await;
1721                assert!(matches!(result, Err(ConnectorError::BadRequestError(_))));
1722                if let Err(ConnectorError::BadRequestError(msg)) = result {
1723                    assert_eq!(msg, "bad request");
1724                }
1725                mock.assert();
1726            });
1727        }
1728
1729        #[test]
1730        fn http_request_client_error_unauthorized() {
1731            TOKIO_SHARED_RT.block_on(async {
1732                let server = MockServer::start();
1733                let mock = server.mock(|when, then| {
1734                    when.method(httpmock::Method::GET).path("/401");
1735                    then.status(401)
1736                        .header("Content-Type", "application/json")
1737                        .body(r#"{"msg":"unauthorized"}"#);
1738                });
1739
1740                let client = Client::new();
1741                let req: Request = client
1742                    .request(Method::GET, format!("{}{}", server.url(""), "/401"))
1743                    .build()
1744                    .unwrap();
1745                let cfg = make_config(&server.url(""));
1746
1747                let result = http_request::<Dummy>(req, &cfg).await;
1748                assert!(matches!(result, Err(ConnectorError::UnauthorizedError(_))));
1749                if let Err(ConnectorError::UnauthorizedError(msg)) = result {
1750                    assert_eq!(msg, "unauthorized");
1751                }
1752                mock.assert();
1753            });
1754        }
1755
1756        #[test]
1757        fn http_request_client_error_forbidden() {
1758            TOKIO_SHARED_RT.block_on(async {
1759                let server = MockServer::start();
1760                let mock = server.mock(|when, then| {
1761                    when.method(httpmock::Method::GET).path("/403");
1762                    then.status(403)
1763                        .header("Content-Type", "application/json")
1764                        .body(r#"{"msg":"forbidden"}"#);
1765                });
1766
1767                let client = Client::new();
1768                let req: Request = client
1769                    .request(Method::GET, format!("{}{}", server.url(""), "/403"))
1770                    .build()
1771                    .unwrap();
1772                let cfg = make_config(&server.url(""));
1773
1774                let result = http_request::<Dummy>(req, &cfg).await;
1775                assert!(matches!(result, Err(ConnectorError::ForbiddenError(_))));
1776                if let Err(ConnectorError::ForbiddenError(msg)) = result {
1777                    assert_eq!(msg, "forbidden");
1778                }
1779                mock.assert();
1780            });
1781        }
1782
1783        #[test]
1784        fn http_request_client_error_not_found() {
1785            TOKIO_SHARED_RT.block_on(async {
1786                let server = MockServer::start();
1787                let mock = server.mock(|when, then| {
1788                    when.method(httpmock::Method::GET).path("/404");
1789                    then.status(404)
1790                        .header("Content-Type", "application/json")
1791                        .body(r#"{"msg":"not found"}"#);
1792                });
1793
1794                let client = Client::new();
1795                let req: Request = client
1796                    .request(Method::GET, format!("{}{}", server.url(""), "/404"))
1797                    .build()
1798                    .unwrap();
1799                let cfg = make_config(&server.url(""));
1800
1801                let result = http_request::<Dummy>(req, &cfg).await;
1802                assert!(matches!(result, Err(ConnectorError::NotFoundError(_))));
1803                if let Err(ConnectorError::NotFoundError(msg)) = result {
1804                    assert_eq!(msg, "not found");
1805                }
1806                mock.assert();
1807            });
1808        }
1809
1810        #[test]
1811        fn http_request_client_error_rate_limit_exceeded() {
1812            TOKIO_SHARED_RT.block_on(async {
1813                let server = MockServer::start();
1814                let mock = server.mock(|when, then| {
1815                    when.method(httpmock::Method::GET).path("/418");
1816                    then.status(418)
1817                        .header("Content-Type", "application/json")
1818                        .body(r#"{"msg":"rate limit exceeded"}"#);
1819                });
1820
1821                let client = Client::new();
1822                let req: Request = client
1823                    .request(Method::GET, format!("{}{}", server.url(""), "/418"))
1824                    .build()
1825                    .unwrap();
1826                let cfg = make_config(&server.url(""));
1827
1828                let result = http_request::<Dummy>(req, &cfg).await;
1829                assert!(matches!(result, Err(ConnectorError::RateLimitBanError(_))));
1830                if let Err(ConnectorError::RateLimitBanError(msg)) = result {
1831                    assert_eq!(msg, "rate limit exceeded");
1832                }
1833                mock.assert();
1834            });
1835        }
1836
1837        #[test]
1838        fn http_request_client_error_too_many_requests() {
1839            TOKIO_SHARED_RT.block_on(async {
1840                let server = MockServer::start();
1841                let mock = server.mock(|when, then| {
1842                    when.method(httpmock::Method::GET).path("/429");
1843                    then.status(429)
1844                        .header("Content-Type", "application/json")
1845                        .body(r#"{"msg":"too many requests"}"#);
1846                });
1847
1848                let client = Client::new();
1849                let req: Request = client
1850                    .request(Method::GET, format!("{}{}", server.url(""), "/429"))
1851                    .build()
1852                    .unwrap();
1853                let cfg = make_config(&server.url(""));
1854
1855                let result = http_request::<Dummy>(req, &cfg).await;
1856                assert!(matches!(
1857                    result,
1858                    Err(ConnectorError::TooManyRequestsError(_))
1859                ));
1860                if let Err(ConnectorError::TooManyRequestsError(msg)) = result {
1861                    assert_eq!(msg, "too many requests");
1862                }
1863                mock.assert();
1864            });
1865        }
1866
1867        #[test]
1868        fn http_request_client_error_server_error() {
1869            TOKIO_SHARED_RT.block_on(async {
1870                let server = MockServer::start();
1871                let mock = server.mock(|when, then| {
1872                    when.method(httpmock::Method::GET).path("/500");
1873                    then.status(500)
1874                        .header("Content-Type", "application/json")
1875                        .body(r#"{"msg":"internal server error"}"#);
1876                });
1877
1878                let client = Client::new();
1879                let req: Request = client
1880                    .request(Method::GET, format!("{}{}", server.url(""), "/500"))
1881                    .build()
1882                    .unwrap();
1883                let cfg = make_config(&server.url(""));
1884
1885                let result = http_request::<Dummy>(req, &cfg).await;
1886                assert!(matches!(result, Err(ConnectorError::ServerError { .. })));
1887                if let Err(ConnectorError::ServerError {
1888                    msg,
1889                    status_code: Some(500),
1890                }) = result
1891                {
1892                    assert_eq!(msg, "Server error: 500".to_string());
1893                }
1894                mock.assert();
1895            });
1896        }
1897
1898        #[test]
1899        fn http_request_unexpected_status_maps_generic() {
1900            TOKIO_SHARED_RT.block_on(async {
1901                let server = MockServer::start();
1902                let code = 402;
1903                let mock = server.mock(|when, then| {
1904                    when.method(httpmock::Method::GET).path("/402");
1905                    then.status(code).body("error text");
1906                });
1907
1908                let client = Client::new();
1909                let req: Request = client
1910                    .request(Method::GET, format!("{}{}", server.url(""), "/402"))
1911                    .build()
1912                    .unwrap();
1913                let cfg = make_config(&server.url(""));
1914
1915                let result = http_request::<Dummy>(req, &cfg).await;
1916                assert!(matches!(
1917                    result,
1918                    Err(ConnectorError::ConnectorClientError(_))
1919                ));
1920                mock.assert();
1921            });
1922        }
1923
1924        #[test]
1925        fn http_request_malformed_json_maps_generic() {
1926            TOKIO_SHARED_RT.block_on(async {
1927                let server = MockServer::start();
1928                let mock = server.mock(|when, then| {
1929                    when.method(httpmock::Method::GET).path("/malformed");
1930                    then.status(200)
1931                        .header("Content-Type", "application/json")
1932                        .body("not json");
1933                });
1934
1935                let client = Client::new();
1936                let req: Request = client
1937                    .request(Method::GET, format!("{}{}", server.url(""), "/malformed"))
1938                    .build()
1939                    .unwrap();
1940                let cfg = make_config(&server.url(""));
1941
1942                // 1) HTTP layer still “succeeds”:
1943                let resp = http_request::<Dummy>(req, &cfg)
1944                    .await
1945                    .expect("http_request should succeed even if JSON is bad");
1946
1947                // 2) only when we call `.data().await` do we hit the parse‐error:
1948                let err = resp
1949                    .data() // or however you invoke that boxed future
1950                    .await
1951                    .expect_err("malformed JSON should turn into ConnectorClientError");
1952
1953                assert!(matches!(err, ConnectorError::ConnectorClientError(_)));
1954
1955                mock.assert();
1956            });
1957        }
1958    }
1959
1960    mod send_request {
1961        use anyhow::Result;
1962        use httpmock::prelude::*;
1963        use reqwest::Method;
1964        use serde::Deserialize;
1965        use serde_json::json;
1966        use std::collections::BTreeMap;
1967
1968        use crate::{
1969            common::{models::TimeUnit, utils::send_request},
1970            config::ConfigurationRestApi,
1971        };
1972
1973        use super::TOKIO_SHARED_RT;
1974
1975        #[derive(Deserialize, Debug, PartialEq)]
1976        struct TestResponse {
1977            message: String,
1978        }
1979
1980        #[test]
1981        fn basic_get_request() -> Result<()> {
1982            TOKIO_SHARED_RT.block_on(async {
1983                let server = MockServer::start();
1984
1985                server.mock(|when, then| {
1986                    when.method(GET).path("/api/v1/test");
1987                    then.status(200)
1988                        .header("content-type", "application/json")
1989                        .body(r#"{"message": "success"}"#);
1990                });
1991
1992                let configuration = ConfigurationRestApi::builder()
1993                    .api_key("key")
1994                    .api_secret("secret")
1995                    .base_path(server.base_url())
1996                    .compression(false)
1997                    .build()
1998                    .expect("Failed to build configuration");
1999
2000                let params = BTreeMap::new();
2001
2002                let result = send_request::<TestResponse>(
2003                    &configuration,
2004                    "/api/v1/test",
2005                    Method::GET,
2006                    params,
2007                    None,
2008                    false,
2009                )
2010                .await?;
2011
2012                let data = result.data().await.unwrap();
2013                assert_eq!(data.message, "success");
2014
2015                Ok(())
2016            })
2017        }
2018
2019        #[test]
2020        fn signed_post_request() -> Result<()> {
2021            TOKIO_SHARED_RT.block_on(async {
2022                let server = MockServer::start();
2023
2024                server.mock(|when, then| {
2025                    when.method(POST).path("/api/v3/order");
2026                    then.status(200)
2027                        .header("content-type", "application/json")
2028                        .body(r#"{"message": "order placed"}"#);
2029                });
2030
2031                let configuration = ConfigurationRestApi::builder()
2032                    .api_key("key")
2033                    .api_secret("secret")
2034                    .base_path(server.base_url())
2035                    .compression(false)
2036                    .build()
2037                    .expect("Failed to build configuration");
2038
2039                let mut params = BTreeMap::new();
2040                params.insert("symbol".to_string(), json!("ETHUSDT"));
2041                params.insert("side".to_string(), json!("BUY"));
2042                params.insert("type".to_string(), json!("MARKET"));
2043                params.insert("quantity".to_string(), json!("1"));
2044
2045                let result = send_request::<TestResponse>(
2046                    &configuration,
2047                    "/api/v3/order",
2048                    Method::POST,
2049                    params,
2050                    None,
2051                    true,
2052                )
2053                .await?;
2054
2055                let data = result.data().await.unwrap();
2056                assert_eq!(data.message, "order placed");
2057
2058                Ok(())
2059            })
2060        }
2061
2062        #[test]
2063        fn get_request_with_params() -> Result<()> {
2064            TOKIO_SHARED_RT.block_on(async {
2065                let server = MockServer::start();
2066
2067                server.mock(|when, then| {
2068                    when.method(GET)
2069                        .path("/api/v1/data")
2070                        .query_param("symbol", "BTCUSDT")
2071                        .query_param("limit", "10");
2072                    then.status(200)
2073                        .header("content-type", "application/json")
2074                        .body(r#"{"message": "data retrieved"}"#);
2075                });
2076
2077                let configuration = ConfigurationRestApi::builder()
2078                    .api_key("key")
2079                    .api_secret("secret")
2080                    .base_path(server.base_url())
2081                    .compression(false)
2082                    .build()
2083                    .expect("Failed to build configuration");
2084
2085                let mut params = BTreeMap::new();
2086                params.insert("symbol".to_string(), json!("BTCUSDT"));
2087                params.insert("limit".to_string(), json!(10));
2088
2089                let result = send_request::<TestResponse>(
2090                    &configuration,
2091                    "/api/v1/data",
2092                    Method::GET,
2093                    params,
2094                    None,
2095                    false,
2096                )
2097                .await?;
2098
2099                let data = result.data().await.unwrap();
2100                assert_eq!(data.message, "data retrieved");
2101
2102                Ok(())
2103            })
2104        }
2105
2106        #[test]
2107        fn invalid_endpoint() {
2108            TOKIO_SHARED_RT.block_on(async {
2109                let server = MockServer::start();
2110
2111                let configuration = ConfigurationRestApi::builder()
2112                    .api_key("key")
2113                    .api_secret("secret")
2114                    .base_path(server.base_url())
2115                    .compression(false)
2116                    .build()
2117                    .expect("Failed to build configuration");
2118
2119                let params = BTreeMap::new();
2120
2121                let result = send_request::<TestResponse>(
2122                    &configuration,
2123                    "http://invalid",
2124                    Method::GET,
2125                    params,
2126                    None,
2127                    false,
2128                )
2129                .await;
2130
2131                assert!(result.is_err());
2132            });
2133        }
2134
2135        #[test]
2136        fn missing_signature_on_signed_request() {
2137            TOKIO_SHARED_RT.block_on(async {
2138                let server = MockServer::start();
2139
2140                let configuration = ConfigurationRestApi::builder()
2141                    .api_key("key")
2142                    .api_secret("secret")
2143                    .base_path(server.base_url())
2144                    .compression(false)
2145                    .build()
2146                    .expect("Failed to build configuration");
2147
2148                let mut params = BTreeMap::new();
2149                params.insert("symbol".to_string(), json!("BTCUSDT"));
2150                params.insert("side".to_string(), json!("BUY"));
2151
2152                let result = send_request::<TestResponse>(
2153                    &configuration,
2154                    "/api/v3/order",
2155                    Method::POST,
2156                    params,
2157                    None,
2158                    true,
2159                )
2160                .await;
2161
2162                assert!(result.is_err());
2163            });
2164        }
2165
2166        #[test]
2167        fn compression_enabled() -> Result<()> {
2168            TOKIO_SHARED_RT.block_on(async {
2169                let server = MockServer::start();
2170
2171                server.mock(|when, then| {
2172                    when.method(GET).path("/api/v1/test");
2173                    then.status(200)
2174                        .header("content-type", "application/json")
2175                        .header("accept-encoding", "gzip, deflate, br")
2176                        .body(r#"{"message": "compression enabled"}"#);
2177                });
2178
2179                let configuration = ConfigurationRestApi::builder()
2180                    .api_key("key")
2181                    .api_secret("secret")
2182                    .base_path(server.base_url())
2183                    .compression(true)
2184                    .build()
2185                    .expect("Failed to build configuration");
2186
2187                let params = BTreeMap::new();
2188
2189                let result = send_request::<TestResponse>(
2190                    &configuration,
2191                    "/api/v1/test",
2192                    Method::GET,
2193                    params,
2194                    None,
2195                    false,
2196                )
2197                .await?;
2198
2199                let data = result.data().await.unwrap();
2200                assert_eq!(data.message, "compression enabled");
2201
2202                Ok(())
2203            })
2204        }
2205
2206        #[test]
2207        fn get_request_with_time_unit_header() -> Result<()> {
2208            TOKIO_SHARED_RT.block_on(async {
2209                let server = MockServer::start();
2210
2211                server.mock(|when, then| {
2212                    when.method(GET)
2213                        .path("/api/v1/test")
2214                        .header("X-MBX-TIME-UNIT", "MILLISECOND");
2215                    then.status(200)
2216                        .header("content-type", "application/json")
2217                        .body(r#"{"message": "time unit applied"}"#);
2218                });
2219
2220                let configuration = ConfigurationRestApi::builder()
2221                    .api_key("key")
2222                    .api_secret("secret")
2223                    .base_path(server.base_url())
2224                    .compression(false)
2225                    .time_unit(TimeUnit::Millisecond)
2226                    .build()
2227                    .expect("Failed to build configuration");
2228
2229                let params = BTreeMap::new();
2230
2231                let result = send_request::<TestResponse>(
2232                    &configuration,
2233                    "/api/v1/test",
2234                    Method::GET,
2235                    params,
2236                    Some(TimeUnit::Millisecond),
2237                    false,
2238                )
2239                .await?;
2240
2241                let data = result.data().await.unwrap();
2242                assert_eq!(data.message, "time unit applied");
2243
2244                Ok(())
2245            })
2246        }
2247    }
2248
2249    mod random_string {
2250        use crate::common::utils::random_string;
2251        use hex;
2252
2253        #[test]
2254        fn length_is_32() {
2255            let s = random_string();
2256            assert_eq!(
2257                s.len(),
2258                32,
2259                "random_string() should be 32 chars, got {}",
2260                s.len()
2261            );
2262        }
2263
2264        #[test]
2265        fn is_valid_lowercase_hex() {
2266            let s = random_string();
2267            assert!(
2268                s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
2269                "random_string() contains invalid hex characters: {s}"
2270            );
2271        }
2272
2273        #[test]
2274        fn decodes_to_16_bytes() {
2275            let s = random_string();
2276            let bytes = hex::decode(&s).expect("random_string() output must be valid hex");
2277            assert_eq!(
2278                bytes.len(),
2279                16,
2280                "hex::decode returned {} bytes",
2281                bytes.len()
2282            );
2283        }
2284
2285        #[test]
2286        fn two_calls_are_different() {
2287            let a = random_string();
2288            let b = random_string();
2289            assert_ne!(
2290                a, b,
2291                "Two calls to random_string() returned the same value: {a}"
2292            );
2293        }
2294    }
2295
2296    mod remove_empty_value {
2297        use crate::common::utils::remove_empty_value;
2298        use serde_json::{Map, Value};
2299
2300        #[test]
2301        fn filters_out_null_and_empty_strings() {
2302            let entries = vec![
2303                ("key1".to_string(), Value::String("value1".to_string())),
2304                ("key2".to_string(), Value::Null),
2305                ("key3".to_string(), Value::String(String::new())),
2306            ];
2307            let result = remove_empty_value(entries);
2308            assert_eq!(
2309                result.len(),
2310                1,
2311                "expected only one entry, got {}",
2312                result.len()
2313            );
2314            assert_eq!(
2315                result.get("key1"),
2316                Some(&Value::String("value1".to_string()))
2317            );
2318            assert!(!result.contains_key("key2"));
2319            assert!(!result.contains_key("key3"));
2320        }
2321
2322        #[test]
2323        fn retains_other_value_types() {
2324            let entries = vec![
2325                ("bool".to_string(), Value::Bool(true)),
2326                ("num".to_string(), Value::Number(42.into())),
2327                ("arr".to_string(), Value::Array(vec![])),
2328                ("obj".to_string(), Value::Object(Map::default())),
2329                ("nil".to_string(), Value::Null),
2330                ("empty_str".to_string(), Value::String(String::new())),
2331            ];
2332            let result = remove_empty_value(entries);
2333            let keys: Vec<&String> = result.keys().collect();
2334            assert_eq!(keys.len(), 4, "expected 4 entries, got {}", keys.len());
2335            assert!(result.get("bool") == Some(&Value::Bool(true)));
2336            assert!(result.get("num") == Some(&Value::Number(42.into())));
2337            assert!(result.get("arr") == Some(&Value::Array(vec![])));
2338            assert!(result.get("obj") == Some(&Value::Object(Map::default())));
2339            assert!(!result.contains_key("nil"));
2340            assert!(!result.contains_key("empty_str"));
2341        }
2342
2343        #[test]
2344        fn empty_iterator_returns_empty_map() {
2345            let entries: Vec<(String, Value)> = vec![];
2346            let result = remove_empty_value(entries);
2347            assert!(result.is_empty(), "expected an empty map");
2348        }
2349
2350        #[test]
2351        fn keys_are_sorted() {
2352            let entries = vec![
2353                ("c".to_string(), Value::String("foo".to_string())),
2354                ("a".to_string(), Value::String("bar".to_string())),
2355                ("b".to_string(), Value::String("baz".to_string())),
2356            ];
2357            let result = remove_empty_value(entries);
2358            let sorted_keys: Vec<&String> = result.keys().collect();
2359            assert_eq!(
2360                sorted_keys,
2361                [&"a".to_string(), &"b".to_string(), &"c".to_string()]
2362            );
2363        }
2364    }
2365
2366    mod sort_object_params {
2367        use crate::common::utils::sort_object_params;
2368        use serde_json::Value;
2369        use std::collections::BTreeMap;
2370
2371        #[test]
2372        fn sorts_keys() {
2373            let mut params = BTreeMap::new();
2374            params.insert("z".to_string(), Value::String("last".to_string()));
2375            params.insert("a".to_string(), Value::String("first".to_string()));
2376            params.insert("m".to_string(), Value::String("middle".to_string()));
2377
2378            let sorted = sort_object_params(&params);
2379            let keys: Vec<&String> = sorted.keys().collect();
2380            assert_eq!(
2381                keys,
2382                [&"a".to_string(), &"m".to_string(), &"z".to_string()],
2383                "Keys should be sorted alphabetically"
2384            );
2385        }
2386
2387        #[test]
2388        fn preserves_values() {
2389            let mut params = BTreeMap::new();
2390            params.insert("one".to_string(), Value::Number(1.into()));
2391            params.insert("two".to_string(), Value::Bool(true));
2392
2393            let sorted = sort_object_params(&params);
2394            assert_eq!(sorted.get("one"), Some(&Value::Number(1.into())));
2395            assert_eq!(sorted.get("two"), Some(&Value::Bool(true)));
2396        }
2397
2398        #[test]
2399        fn empty_map_returns_empty() {
2400            let params: BTreeMap<String, Value> = BTreeMap::new();
2401            let sorted = sort_object_params(&params);
2402            assert!(sorted.is_empty(), "Expected empty map");
2403        }
2404
2405        #[test]
2406        fn independent_clone() {
2407            let mut params = BTreeMap::new();
2408            params.insert("key".to_string(), Value::String("val".to_string()));
2409
2410            let mut sorted = sort_object_params(&params);
2411            sorted.insert("new".to_string(), Value::String("x".to_string()));
2412
2413            assert!(
2414                !params.contains_key("new"),
2415                "Original should not be modified when changing sorted"
2416            );
2417            assert!(
2418                sorted.contains_key("new"),
2419                "Sorted map should reflect its own insertions"
2420            );
2421        }
2422    }
2423
2424    mod normalize_ws_streams_key {
2425        use crate::common::utils::normalize_ws_streams_key;
2426
2427        #[test]
2428        fn returns_empty_for_empty() {
2429            assert_eq!(normalize_ws_streams_key(""), "");
2430        }
2431
2432        #[test]
2433        fn already_normalized_stays_same() {
2434            assert_eq!(normalize_ws_streams_key("streamname"), "streamname");
2435        }
2436
2437        #[test]
2438        fn uppercases_are_lowercased() {
2439            assert_eq!(normalize_ws_streams_key("MyStream"), "mystream");
2440        }
2441
2442        #[test]
2443        fn underscores_are_removed() {
2444            assert_eq!(normalize_ws_streams_key("my_stream_name"), "mystreamname");
2445        }
2446
2447        #[test]
2448        fn hyphens_are_removed() {
2449            assert_eq!(normalize_ws_streams_key("my-stream-name"), "mystreamname");
2450        }
2451
2452        #[test]
2453        fn mixed_underscores_and_hyphens_and_case() {
2454            let input = "Mixed_Case-Stream_Name";
2455            let expected = "mixedcasestreamname";
2456            assert_eq!(normalize_ws_streams_key(input), expected);
2457        }
2458
2459        #[test]
2460        fn retains_other_punctuation() {
2461            assert_eq!(normalize_ws_streams_key("stream.name!"), "stream.name!");
2462        }
2463    }
2464
2465    mod replace_websocket_streams_placeholders {
2466        use crate::common::utils::replace_websocket_streams_placeholders;
2467        use std::collections::HashMap;
2468
2469        #[test]
2470        fn empty_string_unchanged() {
2471            let vars: HashMap<&str, &str> = HashMap::new();
2472            assert_eq!(replace_websocket_streams_placeholders("", &vars), "");
2473        }
2474
2475        #[test]
2476        fn unknown_placeholder_becomes_empty() {
2477            let vars: HashMap<&str, &str> = HashMap::new();
2478            assert_eq!(replace_websocket_streams_placeholders("<foo>", &vars), "");
2479        }
2480
2481        #[test]
2482        fn leading_slash_symbol_lowercases_head() {
2483            let mut vars = HashMap::new();
2484            vars.insert("symbol", "BTC");
2485            assert_eq!(
2486                replace_websocket_streams_placeholders("/<symbol>", &vars),
2487                "btc"
2488            );
2489        }
2490
2491        #[test]
2492        fn no_lowercase_without_slash() {
2493            let mut vars = HashMap::new();
2494            vars.insert("symbol", "BTC");
2495            assert_eq!(
2496                replace_websocket_streams_placeholders("<symbol>", &vars),
2497                "BTC"
2498            );
2499        }
2500
2501        #[test]
2502        fn multiple_placeholders_mid_preserve_ats() {
2503            let mut vars = HashMap::new();
2504            vars.insert("symbol", "BNBUSDT");
2505            vars.insert("levels", "10");
2506            vars.insert("updateSpeed", "1000ms");
2507            let out = replace_websocket_streams_placeholders(
2508                "/<symbol>@depth<levels>@<updateSpeed>",
2509                &vars,
2510            );
2511            assert_eq!(out, "bnbusdt@depth10@1000ms");
2512        }
2513
2514        #[test]
2515        fn trailing_at_removed_when_missing_var() {
2516            let mut vars = HashMap::new();
2517            vars.insert("symbol", "BNBUSDT");
2518            vars.insert("levels", "10");
2519            let out = replace_websocket_streams_placeholders(
2520                "/<symbol>@depth<levels>@<updateSpeed>",
2521                &vars,
2522            );
2523            assert_eq!(out, "bnbusdt@depth10");
2524        }
2525
2526        #[test]
2527        fn custom_key_normalization_and_value() {
2528            let mut vars = HashMap::new();
2529            vars.insert("my-stream_key", "Value");
2530            assert_eq!(
2531                replace_websocket_streams_placeholders("<My_Stream-Key>", &vars),
2532                "Value"
2533            );
2534        }
2535
2536        #[test]
2537        fn text_surrounding_placeholders_intact() {
2538            let mut vars = HashMap::new();
2539            vars.insert("symbol", "ABC");
2540            let input = "pre-<symbol>-post";
2541            assert_eq!(
2542                replace_websocket_streams_placeholders(input, &vars),
2543                "pre-ABC-post"
2544            );
2545        }
2546    }
2547}