1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use std::time::SystemTime;
4
5use chrono::{DateTime, Utc};
6use tokio::sync::RwLock;
7
8use crate::cache::state::{
9 BarsMap, CacheState, StockBarsRequest, collect_cached_hits, normalize_option_symbols,
10 normalize_stock_symbols,
11};
12use crate::cache::stats::CacheStats;
13use crate::options::{self, OptionsFeed, SnapshotsRequest as OptionSnapshotsRequest};
14use crate::stocks::{self, DataFeed, SnapshotsRequest as StockSnapshotsRequest};
15use crate::{Client, Error};
16
17#[derive(Clone)]
18pub struct CachedClientConfig {
19 pub stocks_feed: Arc<dyn Fn() -> DataFeed + Send + Sync>,
20 pub options_feed: OptionsFeed,
21}
22
23impl Default for CachedClientConfig {
24 fn default() -> Self {
25 Self {
26 stocks_feed: Arc::new(|| stocks::preferred_feed(false)),
27 options_feed: options::preferred_feed(),
28 }
29 }
30}
31
32pub struct CachedClient {
33 raw: Client,
34 config: CachedClientConfig,
35 state: RwLock<CacheState>,
36}
37
38impl CachedClient {
39 #[must_use]
40 pub fn new(raw: Client) -> Self {
41 Self::with_config(raw, CachedClientConfig::default())
42 }
43
44 #[must_use]
45 pub fn with_config(raw: Client, config: CachedClientConfig) -> Self {
46 Self {
47 raw,
48 config,
49 state: RwLock::new(CacheState::default()),
50 }
51 }
52
53 #[must_use]
54 pub fn raw(&self) -> &Client {
55 &self.raw
56 }
57
58 pub async fn stocks<S: AsRef<str>>(
59 &self,
60 symbols: &[S],
61 ) -> Result<HashMap<String, stocks::Snapshot>, Error> {
62 let requested = normalize_stock_symbols(symbols);
63 if requested.is_empty() {
64 return Ok(HashMap::new());
65 }
66
67 let resolved = unique_resolved_symbols(&requested);
68 let (mut hits, missing) = {
69 let state = self.state.read().await;
70 collect_cached_hits(&resolved, &state.stocks.values, &state.stocks.empty)
71 };
72
73 if !missing.is_empty() {
74 let fetched = self.fetch_stocks(&missing).await?;
75 let mut state = self.state.write().await;
76 for symbol in &missing {
77 state.stocks.subscribed.insert(symbol.clone());
78 if fetched.contains_key(symbol) {
79 state.stocks.empty.remove(symbol);
80 } else {
81 state.stocks.empty.insert(symbol.clone());
82 }
83 }
84 for (symbol, snapshot) in &fetched {
85 state.stocks.values.insert(symbol.clone(), snapshot.clone());
86 }
87 state.stocks.updated_at = Some(SystemTime::now());
88 hits.extend(fetched);
89 }
90
91 Ok(requested
92 .into_iter()
93 .filter_map(|(original, resolved)| {
94 hits.get(&resolved)
95 .cloned()
96 .map(|snapshot| (original, snapshot))
97 })
98 .collect())
99 }
100
101 pub async fn stock(&self, symbol: &str) -> Option<stocks::Snapshot> {
102 self.stocks(&[symbol])
103 .await
104 .ok()?
105 .into_iter()
106 .next()
107 .map(|(_, snapshot)| snapshot)
108 }
109
110 pub async fn options<S: AsRef<str>>(
111 &self,
112 contracts: &[S],
113 ) -> Result<HashMap<String, options::Snapshot>, Error> {
114 let requested = normalize_option_symbols(contracts);
115 if requested.is_empty() {
116 return Ok(HashMap::new());
117 }
118
119 let (mut hits, missing) = {
120 let state = self.state.read().await;
121 collect_cached_hits(&requested, &state.options.values, &state.options.empty)
122 };
123
124 if !missing.is_empty() {
125 let fetched = self.fetch_options(&missing).await?;
126 let mut state = self.state.write().await;
127 for contract in &missing {
128 state.options.subscribed.insert(contract.clone());
129 if fetched.contains_key(contract) {
130 state.options.empty.remove(contract);
131 } else {
132 state.options.empty.insert(contract.clone());
133 }
134 }
135 for (contract, snapshot) in &fetched {
136 state
137 .options
138 .values
139 .insert(contract.clone(), snapshot.clone());
140 }
141 state.options.updated_at = Some(SystemTime::now());
142 hits.extend(fetched);
143 }
144
145 Ok(requested
146 .into_iter()
147 .filter_map(|contract| hits.remove_entry(&contract))
148 .collect())
149 }
150
151 pub async fn option(&self, contract: &str) -> Option<options::Snapshot> {
152 self.options(&[contract]).await.ok()?.remove(contract)
153 }
154
155 pub async fn watch_stocks(&self, symbols: &[String]) {
156 let normalized = normalize_stock_symbols(symbols);
157 let mut state = self.state.write().await;
158 for (_, symbol) in normalized {
159 state.stocks.subscribed.insert(symbol);
160 }
161 }
162
163 pub async fn watch_options(&self, contracts: &[String]) {
164 let normalized = normalize_option_symbols(contracts);
165 let mut state = self.state.write().await;
166 for contract in normalized {
167 state.options.subscribed.insert(contract);
168 }
169 }
170
171 pub async fn refresh_stocks(&self) -> Result<usize, Error> {
172 let symbols = {
173 let state = self.state.read().await;
174 state.stocks.subscribed.iter().cloned().collect::<Vec<_>>()
175 };
176 if symbols.is_empty() {
177 return Ok(0);
178 }
179
180 let fetched = self.fetch_stocks(&symbols).await?;
181 let count = fetched.len();
182
183 let mut state = self.state.write().await;
184 for (symbol, snapshot) in fetched {
185 state.stocks.values.insert(symbol, snapshot);
186 }
187 state.stocks.updated_at = Some(SystemTime::now());
188 Ok(count)
189 }
190
191 pub async fn refresh_options(&self) -> Result<usize, Error> {
192 let contracts = {
193 let state = self.state.read().await;
194 state.options.subscribed.iter().cloned().collect::<Vec<_>>()
195 };
196 if contracts.is_empty() {
197 return Ok(0);
198 }
199
200 let fetched = self.fetch_options(&contracts).await?;
201 let count = fetched.len();
202
203 let mut state = self.state.write().await;
204 for (contract, snapshot) in fetched {
205 state.options.values.insert(contract, snapshot);
206 }
207 state.options.updated_at = Some(SystemTime::now());
208 Ok(count)
209 }
210
211 pub async fn watch_bars(&self, request: StockBarsRequest) {
212 let request = request.normalized();
213 let mut state = self.state.write().await;
214 state
215 .bars
216 .requests
217 .entry(request.key.clone())
218 .and_modify(|current| current.merge_from(&request))
219 .or_insert(request);
220 }
221
222 pub async fn bars(&self, key: &str) -> Result<HashMap<String, Vec<stocks::BarPoint>>, Error> {
223 let request = self.bars_request(key).await?;
224 let missing = {
225 let state = self.state.read().await;
226 let cached = state.bars.values.get(key);
227 let empty = state.bars.empty.get(key);
228
229 request
230 .symbols
231 .iter()
232 .filter(|symbol| {
233 !cached.is_some_and(|bars| bars.contains_key(*symbol))
234 && !empty.is_some_and(|values| values.contains(*symbol))
235 })
236 .cloned()
237 .collect::<Vec<_>>()
238 };
239
240 if missing.is_empty() {
241 let state = self.state.read().await;
242 return Ok(state.bars.values.get(key).cloned().unwrap_or_default());
243 }
244
245 self.fetch_missing_bars(key, &request, &missing).await
246 }
247
248 pub async fn bar(&self, key: &str, symbol: &str) -> Option<Vec<stocks::BarPoint>> {
249 let resolved = stocks::display_stock_symbol(symbol);
250 {
251 let state = self.state.read().await;
252 if let Some(values) = state.bars.values.get(key)
253 && let Some(bars) = values.get(&resolved)
254 {
255 return Some(bars.clone());
256 }
257 if state
258 .bars
259 .empty
260 .get(key)
261 .is_some_and(|symbols| symbols.contains(&resolved))
262 {
263 return None;
264 }
265 }
266
267 self.bars(key).await.ok()?.get(&resolved).cloned()
268 }
269
270 pub async fn refresh_bars(&self, key: &str) -> Result<usize, Error> {
271 let request = self.bars_request(key).await?;
272 let fetched = self.fetch_bars_request(&request, &request.symbols).await?;
273 let count = fetched.len();
274
275 let missing: HashSet<String> = request
276 .symbols
277 .iter()
278 .filter(|symbol| !fetched.contains_key(*symbol))
279 .cloned()
280 .collect();
281
282 let mut state = self.state.write().await;
283 state.bars.values.insert(key.to_string(), fetched);
284 state.bars.empty.insert(key.to_string(), missing);
285 state
286 .bars
287 .updated_at
288 .insert(key.to_string(), SystemTime::now());
289 Ok(count)
290 }
291
292 pub async fn clear_options(&self) {
293 let mut state = self.state.write().await;
294 state.options.subscribed.clear();
295 state.options.values.clear();
296 state.options.empty.clear();
297 state.options.updated_at = None;
298 }
299
300 pub async fn stats(&self) -> CacheStats {
301 let state = self.state.read().await;
302 CacheStats {
303 subscribed_symbols: state.stocks.subscribed.len(),
304 subscribed_contracts: state.options.subscribed.len(),
305 subscribed_bar_requests: state.bars.requests.len(),
306 cached_stocks: state.stocks.values.len(),
307 cached_options: state.options.values.len(),
308 cached_bar_symbols: state.bars.values.values().map(HashMap::len).sum(),
309 stocks_updated_at: format_timestamp(state.stocks.updated_at),
310 options_updated_at: format_timestamp(state.options.updated_at),
311 bars_updated_at: state
312 .bars
313 .updated_at
314 .iter()
315 .map(|(key, value)| {
316 (
317 key.clone(),
318 format_timestamp(Some(*value)).unwrap_or_default(),
319 )
320 })
321 .collect(),
322 }
323 }
324
325 async fn fetch_stocks(
326 &self,
327 symbols: &[String],
328 ) -> Result<HashMap<String, stocks::Snapshot>, Error> {
329 self.raw
330 .stocks()
331 .snapshots(StockSnapshotsRequest {
332 symbols: symbols.to_vec(),
333 feed: Some((self.config.stocks_feed)()),
334 currency: None,
335 })
336 .await
337 }
338
339 async fn fetch_options(
340 &self,
341 contracts: &[String],
342 ) -> Result<HashMap<String, options::Snapshot>, Error> {
343 self.raw
344 .options()
345 .snapshots_all(OptionSnapshotsRequest {
346 symbols: contracts.to_vec(),
347 feed: Some(self.config.options_feed),
348 limit: Some(1000),
349 page_token: None,
350 })
351 .await
352 .map(|response| response.snapshots)
353 }
354
355 async fn bars_request(&self, key: &str) -> Result<StockBarsRequest, Error> {
356 let key = key.trim();
357 if key.is_empty() {
358 return Err(Error::InvalidRequest(
359 "bars key is invalid: must not be empty".to_owned(),
360 ));
361 }
362
363 let state = self.state.read().await;
364 state
365 .bars
366 .requests
367 .get(key)
368 .cloned()
369 .ok_or_else(|| Error::InvalidRequest(format!("bars key is unknown: {key}")))
370 }
371
372 async fn fetch_missing_bars(
373 &self,
374 key: &str,
375 request: &StockBarsRequest,
376 missing: &[String],
377 ) -> Result<HashMap<String, Vec<stocks::BarPoint>>, Error> {
378 let fetched = self.fetch_bars_request(request, missing).await?;
379 let missing_empty: HashSet<String> = missing
380 .iter()
381 .filter(|symbol| !fetched.contains_key(*symbol))
382 .cloned()
383 .collect();
384
385 let mut state = self.state.write().await;
386 let key = key.to_string();
387 {
388 let cached = state.bars.values.entry(key.clone()).or_default();
389 for (symbol, bars) in &fetched {
390 cached.insert(symbol.clone(), bars.clone());
391 }
392 }
393 {
394 let empty = state.bars.empty.entry(key.clone()).or_default();
395 for symbol in missing {
396 if missing_empty.contains(symbol) {
397 empty.insert(symbol.clone());
398 } else {
399 empty.remove(symbol);
400 }
401 }
402 }
403 state.bars.updated_at.insert(key.clone(), SystemTime::now());
404
405 Ok(state.bars.values.get(&key).cloned().unwrap_or_default())
406 }
407
408 async fn fetch_bars_request(
409 &self,
410 request: &StockBarsRequest,
411 symbols: &[String],
412 ) -> Result<BarsMap, Error> {
413 if symbols.is_empty() {
414 return Ok(HashMap::new());
415 }
416
417 let mut merged = HashMap::new();
418 let chunk_size = request.chunk_size.max(1);
419 let daily = request.timeframe == stocks::TimeFrame::day_1();
420
421 for chunk in symbols.chunks(chunk_size) {
422 let response = self
423 .raw
424 .stocks()
425 .bars_all(stocks::BarsRequest {
426 symbols: chunk.to_vec(),
427 timeframe: request.timeframe.clone(),
428 start: request.start.clone(),
429 end: request.end.clone(),
430 limit: Some(request.limit),
431 adjustment: request.adjustment.clone(),
432 feed: request.feed,
433 sort: None,
434 asof: None,
435 currency: request.currency.clone(),
436 page_token: None,
437 })
438 .await?;
439
440 for (symbol, bars) in response.bars {
441 merged.insert(
442 symbol,
443 bars.into_iter().map(|bar| bar.point(daily)).collect(),
444 );
445 }
446 }
447
448 Ok(merged)
449 }
450}
451
/// Returns the resolved half of each `(requested, resolved)` pair, without
/// duplicates, keeping the order of first appearance.
fn unique_resolved_symbols(requested: &[(String, String)]) -> Vec<String> {
    let mut already_seen: HashSet<&str> = HashSet::with_capacity(requested.len());
    requested
        .iter()
        .map(|(_, resolved)| resolved.as_str())
        .filter(|resolved| already_seen.insert(*resolved))
        .map(String::from)
        .collect()
}
462
463fn format_timestamp(value: Option<SystemTime>) -> Option<String> {
464 value.map(|value| {
465 DateTime::<Utc>::from(value)
466 .format("%Y-%m-%d %H:%M:%S")
467 .to_string()
468 })
469}