Skip to main content

alpaca_data/cache/
client.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::SystemTime;
4
5use chrono::{DateTime, Utc};
6use tokio::sync::RwLock;
7
8use crate::cache::state::{
9    BarsMap, CacheState, StockBarsRequest, collect_cached_hits, normalize_option_symbols,
10    normalize_stock_symbols,
11};
12use crate::cache::stats::CacheStats;
13use crate::options::{self, OptionsFeed, SnapshotsRequest as OptionSnapshotsRequest};
14use crate::stocks::{self, DataFeed, SnapshotsRequest as StockSnapshotsRequest};
15use crate::{Client, Error};
16
17#[derive(Clone)]
18pub struct CachedClientConfig {
19    pub stocks_feed: Arc<dyn Fn() -> DataFeed + Send + Sync>,
20    pub options_feed: OptionsFeed,
21}
22
23impl Default for CachedClientConfig {
24    fn default() -> Self {
25        Self {
26            stocks_feed: Arc::new(|| stocks::preferred_feed(false)),
27            options_feed: options::preferred_feed(),
28        }
29    }
30}
31
32pub struct CachedClient {
33    raw: Client,
34    config: CachedClientConfig,
35    state: RwLock<CacheState>,
36}
37
38impl CachedClient {
39    #[must_use]
40    pub fn new(raw: Client) -> Self {
41        Self::with_config(raw, CachedClientConfig::default())
42    }
43
44    #[must_use]
45    pub fn with_config(raw: Client, config: CachedClientConfig) -> Self {
46        Self {
47            raw,
48            config,
49            state: RwLock::new(CacheState::default()),
50        }
51    }
52
53    #[must_use]
54    pub fn raw(&self) -> &Client {
55        &self.raw
56    }
57
58    pub async fn stocks<S: AsRef<str>>(
59        &self,
60        symbols: &[S],
61    ) -> Result<HashMap<String, stocks::Snapshot>, Error> {
62        let requested = normalize_stock_symbols(symbols);
63        if requested.is_empty() {
64            return Ok(HashMap::new());
65        }
66
67        let resolved = unique_resolved_symbols(&requested);
68        let (mut hits, missing) = {
69            let state = self.state.read().await;
70            collect_cached_hits(&resolved, &state.stocks.values, &state.stocks.empty)
71        };
72
73        if !missing.is_empty() {
74            let fetched = self.fetch_stocks(&missing).await?;
75            let mut state = self.state.write().await;
76            for symbol in &missing {
77                state.stocks.subscribed.insert(symbol.clone());
78                if fetched.contains_key(symbol) {
79                    state.stocks.empty.remove(symbol);
80                } else {
81                    state.stocks.empty.insert(symbol.clone());
82                }
83            }
84            for (symbol, snapshot) in &fetched {
85                state.stocks.values.insert(symbol.clone(), snapshot.clone());
86            }
87            state.stocks.updated_at = Some(SystemTime::now());
88            hits.extend(fetched);
89        }
90
91        Ok(requested
92            .into_iter()
93            .filter_map(|(original, resolved)| {
94                hits.get(&resolved)
95                    .cloned()
96                    .map(|snapshot| (original, snapshot))
97            })
98            .collect())
99    }
100
101    pub async fn stock(&self, symbol: &str) -> Option<stocks::Snapshot> {
102        self.stocks(&[symbol])
103            .await
104            .ok()?
105            .into_iter()
106            .next()
107            .map(|(_, snapshot)| snapshot)
108    }
109
110    pub async fn options<S: AsRef<str>>(
111        &self,
112        contracts: &[S],
113    ) -> Result<HashMap<String, options::Snapshot>, Error> {
114        let requested = normalize_option_symbols(contracts);
115        if requested.is_empty() {
116            return Ok(HashMap::new());
117        }
118
119        let (mut hits, missing) = {
120            let state = self.state.read().await;
121            collect_cached_hits(&requested, &state.options.values, &state.options.empty)
122        };
123
124        if !missing.is_empty() {
125            let fetched = self.fetch_options(&missing).await?;
126            let mut state = self.state.write().await;
127            for contract in &missing {
128                state.options.subscribed.insert(contract.clone());
129                if fetched.contains_key(contract) {
130                    state.options.empty.remove(contract);
131                } else {
132                    state.options.empty.insert(contract.clone());
133                }
134            }
135            for (contract, snapshot) in &fetched {
136                state
137                    .options
138                    .values
139                    .insert(contract.clone(), snapshot.clone());
140            }
141            state.options.updated_at = Some(SystemTime::now());
142            hits.extend(fetched);
143        }
144
145        Ok(requested
146            .into_iter()
147            .filter_map(|contract| hits.remove_entry(&contract))
148            .collect())
149    }
150
151    pub async fn option(&self, contract: &str) -> Option<options::Snapshot> {
152        self.options(&[contract]).await.ok()?.remove(contract)
153    }
154
155    pub async fn watch_stocks(&self, symbols: &[String]) {
156        let normalized = normalize_stock_symbols(symbols);
157        let mut state = self.state.write().await;
158        for (_, symbol) in normalized {
159            state.stocks.subscribed.insert(symbol);
160        }
161    }
162
163    pub async fn watch_options(&self, contracts: &[String]) {
164        let normalized = normalize_option_symbols(contracts);
165        let mut state = self.state.write().await;
166        for contract in normalized {
167            state.options.subscribed.insert(contract);
168        }
169    }
170
171    pub async fn refresh_stocks(&self) -> Result<usize, Error> {
172        let symbols = {
173            let state = self.state.read().await;
174            state.stocks.subscribed.iter().cloned().collect::<Vec<_>>()
175        };
176        if symbols.is_empty() {
177            return Ok(0);
178        }
179
180        let fetched = self.fetch_stocks(&symbols).await?;
181        let count = fetched.len();
182
183        let mut state = self.state.write().await;
184        for (symbol, snapshot) in fetched {
185            state.stocks.values.insert(symbol, snapshot);
186        }
187        state.stocks.updated_at = Some(SystemTime::now());
188        Ok(count)
189    }
190
191    pub async fn refresh_options(&self) -> Result<usize, Error> {
192        let contracts = {
193            let state = self.state.read().await;
194            state.options.subscribed.iter().cloned().collect::<Vec<_>>()
195        };
196        if contracts.is_empty() {
197            return Ok(0);
198        }
199
200        let fetched = self.fetch_options(&contracts).await?;
201        let count = fetched.len();
202
203        let mut state = self.state.write().await;
204        for (contract, snapshot) in fetched {
205            state.options.values.insert(contract, snapshot);
206        }
207        state.options.updated_at = Some(SystemTime::now());
208        Ok(count)
209    }
210
211    pub async fn watch_bars(&self, request: StockBarsRequest) {
212        let request = request.normalized();
213        let mut state = self.state.write().await;
214        state
215            .bars
216            .requests
217            .entry(request.key.clone())
218            .and_modify(|current| current.merge_from(&request))
219            .or_insert(request);
220    }
221
222    pub async fn bars(&self, key: &str) -> Result<HashMap<String, Vec<stocks::BarPoint>>, Error> {
223        let request = self.bars_request(key).await?;
224        let missing = {
225            let state = self.state.read().await;
226            let cached = state.bars.values.get(key);
227            let empty = state.bars.empty.get(key);
228
229            request
230                .symbols
231                .iter()
232                .filter(|symbol| {
233                    !cached.is_some_and(|bars| bars.contains_key(*symbol))
234                        && !empty.is_some_and(|values| values.contains(*symbol))
235                })
236                .cloned()
237                .collect::<Vec<_>>()
238        };
239
240        if missing.is_empty() {
241            let state = self.state.read().await;
242            return Ok(state.bars.values.get(key).cloned().unwrap_or_default());
243        }
244
245        self.fetch_missing_bars(key, &request, &missing).await
246    }
247
248    pub async fn bar(&self, key: &str, symbol: &str) -> Option<Vec<stocks::BarPoint>> {
249        let resolved = stocks::display_stock_symbol(symbol);
250        {
251            let state = self.state.read().await;
252            if let Some(values) = state.bars.values.get(key)
253                && let Some(bars) = values.get(&resolved)
254            {
255                return Some(bars.clone());
256            }
257            if state
258                .bars
259                .empty
260                .get(key)
261                .is_some_and(|symbols| symbols.contains(&resolved))
262            {
263                return None;
264            }
265        }
266
267        self.bars(key).await.ok()?.get(&resolved).cloned()
268    }
269
270    pub async fn refresh_bars(&self, key: &str) -> Result<usize, Error> {
271        let request = self.bars_request(key).await?;
272        let fetched = self.fetch_bars_request(&request, &request.symbols).await?;
273        let count = fetched.len();
274
275        let missing: HashSet<String> = request
276            .symbols
277            .iter()
278            .filter(|symbol| !fetched.contains_key(*symbol))
279            .cloned()
280            .collect();
281
282        let mut state = self.state.write().await;
283        state.bars.values.insert(key.to_string(), fetched);
284        state.bars.empty.insert(key.to_string(), missing);
285        state
286            .bars
287            .updated_at
288            .insert(key.to_string(), SystemTime::now());
289        Ok(count)
290    }
291
292    pub async fn clear_options(&self) {
293        let mut state = self.state.write().await;
294        state.options.subscribed.clear();
295        state.options.values.clear();
296        state.options.empty.clear();
297        state.options.updated_at = None;
298    }
299
300    pub async fn stats(&self) -> CacheStats {
301        let state = self.state.read().await;
302        CacheStats {
303            subscribed_symbols: state.stocks.subscribed.len(),
304            subscribed_contracts: state.options.subscribed.len(),
305            subscribed_bar_requests: state.bars.requests.len(),
306            cached_stocks: state.stocks.values.len(),
307            cached_options: state.options.values.len(),
308            cached_bar_symbols: state.bars.values.values().map(HashMap::len).sum(),
309            stocks_updated_at: format_timestamp(state.stocks.updated_at),
310            options_updated_at: format_timestamp(state.options.updated_at),
311            bars_updated_at: state
312                .bars
313                .updated_at
314                .iter()
315                .map(|(key, value)| {
316                    (
317                        key.clone(),
318                        format_timestamp(Some(*value)).unwrap_or_default(),
319                    )
320                })
321                .collect(),
322        }
323    }
324
325    async fn fetch_stocks(
326        &self,
327        symbols: &[String],
328    ) -> Result<HashMap<String, stocks::Snapshot>, Error> {
329        self.raw
330            .stocks()
331            .snapshots(StockSnapshotsRequest {
332                symbols: symbols.to_vec(),
333                feed: Some((self.config.stocks_feed)()),
334                currency: None,
335            })
336            .await
337    }
338
339    async fn fetch_options(
340        &self,
341        contracts: &[String],
342    ) -> Result<HashMap<String, options::Snapshot>, Error> {
343        self.raw
344            .options()
345            .snapshots_all(OptionSnapshotsRequest {
346                symbols: contracts.to_vec(),
347                feed: Some(self.config.options_feed),
348                limit: Some(1000),
349                page_token: None,
350            })
351            .await
352            .map(|response| response.snapshots)
353    }
354
355    async fn bars_request(&self, key: &str) -> Result<StockBarsRequest, Error> {
356        let key = key.trim();
357        if key.is_empty() {
358            return Err(Error::InvalidRequest(
359                "bars key is invalid: must not be empty".to_owned(),
360            ));
361        }
362
363        let state = self.state.read().await;
364        state
365            .bars
366            .requests
367            .get(key)
368            .cloned()
369            .ok_or_else(|| Error::InvalidRequest(format!("bars key is unknown: {key}")))
370    }
371
372    async fn fetch_missing_bars(
373        &self,
374        key: &str,
375        request: &StockBarsRequest,
376        missing: &[String],
377    ) -> Result<HashMap<String, Vec<stocks::BarPoint>>, Error> {
378        let fetched = self.fetch_bars_request(request, missing).await?;
379        let missing_empty: HashSet<String> = missing
380            .iter()
381            .filter(|symbol| !fetched.contains_key(*symbol))
382            .cloned()
383            .collect();
384
385        let mut state = self.state.write().await;
386        let key = key.to_string();
387        {
388            let cached = state.bars.values.entry(key.clone()).or_default();
389            for (symbol, bars) in &fetched {
390                cached.insert(symbol.clone(), bars.clone());
391            }
392        }
393        {
394            let empty = state.bars.empty.entry(key.clone()).or_default();
395            for symbol in missing {
396                if missing_empty.contains(symbol) {
397                    empty.insert(symbol.clone());
398                } else {
399                    empty.remove(symbol);
400                }
401            }
402        }
403        state.bars.updated_at.insert(key.clone(), SystemTime::now());
404
405        Ok(state.bars.values.get(&key).cloned().unwrap_or_default())
406    }
407
408    async fn fetch_bars_request(
409        &self,
410        request: &StockBarsRequest,
411        symbols: &[String],
412    ) -> Result<BarsMap, Error> {
413        if symbols.is_empty() {
414            return Ok(HashMap::new());
415        }
416
417        let mut merged = HashMap::new();
418        let chunk_size = request.chunk_size.max(1);
419        let daily = request.timeframe == stocks::TimeFrame::day_1();
420
421        for chunk in symbols.chunks(chunk_size) {
422            let response = self
423                .raw
424                .stocks()
425                .bars_all(stocks::BarsRequest {
426                    symbols: chunk.to_vec(),
427                    timeframe: request.timeframe.clone(),
428                    start: request.start.clone(),
429                    end: request.end.clone(),
430                    limit: Some(request.limit),
431                    adjustment: request.adjustment.clone(),
432                    feed: request.feed,
433                    sort: None,
434                    asof: None,
435                    currency: request.currency.clone(),
436                    page_token: None,
437                })
438                .await?;
439
440            for (symbol, bars) in response.bars {
441                merged.insert(
442                    symbol,
443                    bars.into_iter().map(|bar| bar.point(daily)).collect(),
444                );
445            }
446        }
447
448        Ok(merged)
449    }
450}
451
/// Collects the resolved (second) symbol of each `(original, resolved)` pair.
///
/// Repeats are dropped while keeping the order in which each symbol first appears.
fn unique_resolved_symbols(requested: &[(String, String)]) -> Vec<String> {
    let mut already_seen: HashSet<&str> = HashSet::with_capacity(requested.len());
    requested
        .iter()
        .map(|(_, resolved)| resolved.as_str())
        .filter(|resolved| already_seen.insert(*resolved))
        .map(str::to_owned)
        .collect()
}
462
463fn format_timestamp(value: Option<SystemTime>) -> Option<String> {
464    value.map(|value| {
465        DateTime::<Utc>::from(value)
466            .format("%Y-%m-%d %H:%M:%S")
467            .to_string()
468    })
469}