alpaca_base/
utils.rs

1//! Utility functions and helpers for the Alpaca API.
2
3#![allow(missing_docs)]
4
5use crate::error::{AlpacaError, Result};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::fmt;
9use uuid::Uuid;
10
11/// Generate a random client order ID
12pub fn generate_client_order_id() -> String {
13    Uuid::new_v4().to_string()
14}
15
16/// Parse a string to a decimal value with validation
17pub fn parse_decimal(value: &str) -> Result<f64> {
18    value
19        .parse::<f64>()
20        .map_err(|_| AlpacaError::InvalidData(format!("Invalid decimal value: {}", value)))
21}
22
/// Render `value` with exactly `precision` digits after the decimal point.
///
/// Uses the standard formatter's precision support, so rounding follows `{:.N}`.
pub fn format_decimal(value: f64, precision: usize) -> String {
    format!("{value:.precision$}")
}
27
28/// Validate symbol format
29pub fn validate_symbol(symbol: &str) -> Result<()> {
30    if symbol.is_empty() {
31        return Err(AlpacaError::InvalidData(
32            "Symbol cannot be empty".to_string(),
33        ));
34    }
35
36    if symbol.len() > 12 {
37        return Err(AlpacaError::InvalidData("Symbol too long".to_string()));
38    }
39
40    if !symbol
41        .chars()
42        .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-')
43    {
44        return Err(AlpacaError::InvalidData(
45            "Invalid symbol format".to_string(),
46        ));
47    }
48
49    Ok(())
50}
51
52/// Validate quantity value
53pub fn validate_quantity(qty: &str) -> Result<()> {
54    let value = parse_decimal(qty)?;
55
56    if value <= 0.0 {
57        return Err(AlpacaError::InvalidData(
58            "Quantity must be positive".to_string(),
59        ));
60    }
61
62    Ok(())
63}
64
65/// Validate price value
66pub fn validate_price(price: &str) -> Result<()> {
67    let value = parse_decimal(price)?;
68
69    if value <= 0.0 {
70        return Err(AlpacaError::InvalidData(
71            "Price must be positive".to_string(),
72        ));
73    }
74
75    Ok(())
76}
77
78/// Convert timestamp to RFC3339 format
79pub fn timestamp_to_rfc3339(timestamp: DateTime<Utc>) -> String {
80    timestamp.to_rfc3339()
81}
82
83/// Parse RFC3339 timestamp
84pub fn parse_rfc3339(timestamp: &str) -> Result<DateTime<Utc>> {
85    DateTime::parse_from_rfc3339(timestamp)
86        .map(|dt| dt.with_timezone(&Utc))
87        .map_err(|e| AlpacaError::InvalidData(format!("Invalid timestamp format: {}", e)))
88}
89
/// Fixed-window rate limiter for API requests.
///
/// At most `requests_per_minute` calls to [`RateLimiter::can_make_request`]
/// return `true` within each 60-second window. A new window starts on the first
/// check made after the previous one has elapsed.
///
/// Elapsed time is measured with the monotonic [`std::time::Instant`] clock.
/// Wall-clock adjustments therefore cannot stall the window or reset it early.
#[derive(Debug)]
pub struct RateLimiter {
    /// Quota of requests granted per window.
    requests_per_minute: u32,
    /// Start of the current window.
    last_reset: std::time::Instant,
    /// Requests granted so far in the current window.
    current_count: u32,
}

impl RateLimiter {
    /// Length of one rate-limiting window.
    const WINDOW: std::time::Duration = std::time::Duration::from_secs(60);

    /// Create a limiter granting `requests_per_minute` requests per window,
    /// with the first window starting now.
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            requests_per_minute,
            last_reset: std::time::Instant::now(),
            current_count: 0,
        }
    }

    /// Whether the window that began at `last_reset` is over at `now`.
    fn window_elapsed(&self, now: std::time::Instant) -> bool {
        now.saturating_duration_since(self.last_reset) >= Self::WINDOW
    }

    /// Try to consume one request slot.
    ///
    /// If the current window has elapsed, a new one is started first. Returns
    /// `true` and counts the request when a slot is available. Otherwise returns
    /// `false` and leaves the count unchanged.
    pub fn can_make_request(&mut self) -> bool {
        let now = std::time::Instant::now();

        if self.window_elapsed(now) {
            self.last_reset = now;
            self.current_count = 0;
        }

        if self.current_count < self.requests_per_minute {
            self.current_count += 1;
            true
        } else {
            false
        }
    }

    /// Number of requests still available.
    ///
    /// If the current window has already elapsed, the full quota is reported,
    /// because the next [`RateLimiter::can_make_request`] call will start a new
    /// window rather than honour the stale count.
    pub fn remaining_requests(&self) -> u32 {
        if self.window_elapsed(std::time::Instant::now()) {
            self.requests_per_minute
        } else {
            self.requests_per_minute.saturating_sub(self.current_count)
        }
    }
}
131
132/// Pagination parameters for API requests
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct Pagination {
135    pub page_token: Option<String>,
136    pub limit: Option<u32>,
137}
138
139impl Default for Pagination {
140    fn default() -> Self {
141        Self {
142            page_token: None,
143            limit: Some(100),
144        }
145    }
146}
147
148/// Response wrapper with pagination information
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct PaginatedResponse<T> {
151    pub data: Vec<T>,
152    pub next_page_token: Option<String>,
153}
154
/// Retry/back-off settings for API requests.
///
/// This type only carries values; the retry loop that interprets them lives
/// elsewhere. [`Default`] yields 3 retries, a 1 000 ms initial delay, a
/// 30 000 ms delay cap, and a back-off multiplier of 2.0.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retries.
    pub max_retries: u32,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Factor applied to the delay between successive retries.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        const ONE_SECOND_MS: u64 = 1_000;

        Self {
            backoff_multiplier: 2.0,
            max_delay_ms: 30 * ONE_SECOND_MS,
            initial_delay_ms: ONE_SECOND_MS,
            max_retries: 3,
        }
    }
}
174
175/// Logger configuration
176pub fn init_logger() -> Result<()> {
177    tracing_subscriber::fmt()
178        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
179        .try_init()
180        .map_err(|e| AlpacaError::Config(format!("Failed to initialize logger: {}", e)))
181}
182
183/// URL builder helper
184#[derive(Debug)]
185pub struct UrlBuilder {
186    base_url: String,
187    path: String,
188    query_params: Vec<(String, String)>,
189}
190
191impl UrlBuilder {
192    /// Create a new URL builder
193    pub fn new(base_url: &str) -> Self {
194        Self {
195            base_url: base_url.to_string(),
196            path: String::new(),
197            query_params: Vec::new(),
198        }
199    }
200
201    /// Add path segment
202    pub fn path(mut self, segment: &str) -> Self {
203        if !self.path.is_empty() && !self.path.ends_with('/') {
204            self.path.push('/');
205        }
206        self.path.push_str(segment);
207        self
208    }
209
210    /// Add query parameter
211    pub fn query<T: fmt::Display>(mut self, key: &str, value: T) -> Self {
212        self.query_params.push((key.to_string(), value.to_string()));
213        self
214    }
215
216    /// Add optional query parameter
217    pub fn query_opt<T: fmt::Display>(self, key: &str, value: Option<T>) -> Self {
218        match value {
219            Some(v) => self.query(key, v),
220            None => self,
221        }
222    }
223
224    /// Build the final URL
225    pub fn build(self) -> Result<String> {
226        let mut url = format!(
227            "{}/{}",
228            self.base_url.trim_end_matches('/'),
229            self.path.trim_start_matches('/')
230        );
231
232        if !self.query_params.is_empty() {
233            url.push('?');
234            for (i, (key, value)) in self.query_params.iter().enumerate() {
235                if i > 0 {
236                    url.push('&');
237                }
238                url.push_str(&urlencoding::encode(key));
239                url.push('=');
240                url.push_str(&urlencoding::encode(value));
241            }
242        }
243
244        Ok(url)
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_client_order_id() {
        let id1 = generate_client_order_id();
        let id2 = generate_client_order_id();
        assert_ne!(id1, id2);
        assert!(Uuid::parse_str(&id1).is_ok());
    }

    #[test]
    fn test_parse_decimal() {
        assert_eq!(parse_decimal("1.5").unwrap(), 1.5);
        assert_eq!(parse_decimal("-2").unwrap(), -2.0);
        assert!(parse_decimal("abc").is_err());
        assert!(parse_decimal("").is_err());
    }

    #[test]
    fn test_format_decimal() {
        assert_eq!(format_decimal(1.23456, 2), "1.23");
        assert_eq!(format_decimal(2.0, 0), "2");
        assert_eq!(format_decimal(-0.5, 3), "-0.500");
    }

    #[test]
    fn test_validate_symbol() {
        assert!(validate_symbol("AAPL").is_ok());
        assert!(validate_symbol("BRK.A").is_ok());
        assert!(validate_symbol("").is_err());
        assert!(validate_symbol("VERYLONGSYMBOL").is_err());
    }

    #[test]
    fn test_validate_quantity() {
        assert!(validate_quantity("100").is_ok());
        assert!(validate_quantity("0.5").is_ok());
        assert!(validate_quantity("0").is_err());
        assert!(validate_quantity("-10").is_err());
        assert!(validate_quantity("invalid").is_err());
    }

    #[test]
    fn test_validate_price() {
        assert!(validate_price("10.5").is_ok());
        assert!(validate_price("0").is_err());
        assert!(validate_price("-1").is_err());
        assert!(validate_price("abc").is_err());
    }

    #[test]
    fn test_rfc3339_round_trip() {
        let parsed = parse_rfc3339("2024-01-02T03:04:05+02:00").unwrap();
        assert_eq!(timestamp_to_rfc3339(parsed), "2024-01-02T01:04:05+00:00");
        assert!(parse_rfc3339("not a timestamp").is_err());
    }

    #[test]
    fn test_url_builder() {
        let url = UrlBuilder::new("https://api.example.com")
            .path("v2/orders")
            .query("symbol", "AAPL")
            .query("limit", 100)
            .build()
            .unwrap();

        assert_eq!(
            url,
            "https://api.example.com/v2/orders?symbol=AAPL&limit=100"
        );
    }

    #[test]
    fn test_url_builder_query_opt_and_encoding() {
        let url = UrlBuilder::new("https://api.example.com/")
            .path("v2/orders")
            .query_opt("limit", Some(5))
            .query_opt::<u32>("after", None)
            .query("q", "a b")
            .build()
            .unwrap();

        assert_eq!(url, "https://api.example.com/v2/orders?limit=5&q=a%20b");
    }

    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new(2);
        assert_eq!(limiter.remaining_requests(), 2);
        assert!(limiter.can_make_request());
        assert!(limiter.can_make_request());
        assert!(!limiter.can_make_request());
        assert_eq!(limiter.remaining_requests(), 0);

        let mut closed = RateLimiter::new(0);
        assert!(!closed.can_make_request());
    }

    #[test]
    fn test_defaults() {
        let page = Pagination::default();
        assert!(page.page_token.is_none());
        assert_eq!(page.limit, Some(100));

        let retry = RetryConfig::default();
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.initial_delay_ms, 1000);
        assert_eq!(retry.max_delay_ms, 30000);
        assert_eq!(retry.backoff_multiplier, 2.0);
    }
}