use super::{
DataFrame, DataQuery, OwnedDataRow, SharedAsyncProvider, Timeframe,
async_provider::AsyncDataError,
};
use crate::snapshot::{
CacheKeySnapshot, CachedDataSnapshot, DEFAULT_CHUNK_LEN, DataCacheSnapshot, LiveBufferSnapshot,
SnapshotStore, load_chunked_vec, store_chunked_vec,
};
use anyhow::Result as AnyResult;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use tokio::task::JoinHandle;
/// Identifies one cached series: a data id paired with a timeframe.
///
/// Derives `Hash` and `Eq`, so it can be used directly as a `HashMap` key.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Identifier of the series.
    pub id: String,
    /// Timeframe of the series.
    pub timeframe: Timeframe,
}

impl CacheKey {
    /// Builds a key from an owned id and a timeframe.
    pub fn new(id: String, timeframe: Timeframe) -> Self {
        CacheKey { id, timeframe }
    }
}
/// A historical frame held by the cache together with a cursor.
#[derive(Debug, Clone)]
pub struct CachedData {
    /// The loaded historical frame.
    pub historical: DataFrame,
    /// Cursor into `historical`; `new` initialises it to 0.
    pub current_index: usize,
}

impl CachedData {
    /// Wraps `historical` with the cursor positioned at index 0.
    pub fn new(historical: DataFrame) -> Self {
        let current_index = 0;
        CachedData {
            historical,
            current_index,
        }
    }

    /// Number of rows in the wrapped historical frame.
    pub fn row_count(&self) -> usize {
        let frame = &self.historical;
        frame.row_count()
    }
}
/// Cache of historical frames plus per-series live rows fed by background
/// subscription tasks.
///
/// The three maps are each behind an `Arc`, so clones of a `DataCache` share
/// the same historical data, live buffers and subscription registry.
#[derive(Clone)]
pub struct DataCache {
    /// Provider used to load historical frames and to open/close live subscriptions.
    provider: SharedAsyncProvider,
    /// Historical frames keyed by series; filled by `prefetch`.
    historical: Arc<RwLock<HashMap<CacheKey, CachedData>>>,
    /// Rows received from live subscriptions, appended in arrival order per series.
    live_buffer: Arc<RwLock<HashMap<CacheKey, Vec<OwnedDataRow>>>>,
    /// Join handles of the spawned subscription tasks, one per subscribed series.
    subscriptions: Arc<Mutex<HashMap<CacheKey, JoinHandle<()>>>>,
    /// Runtime handle on which subscription tasks are spawned.
    runtime: tokio::runtime::Handle,
}
impl DataCache {
pub fn new(provider: SharedAsyncProvider, runtime: tokio::runtime::Handle) -> Self {
Self {
provider,
historical: Arc::new(RwLock::new(HashMap::new())),
live_buffer: Arc::new(RwLock::new(HashMap::new())),
subscriptions: Arc::new(Mutex::new(HashMap::new())),
runtime,
}
}
/// Test-only constructor: seeds the historical map from `data` and wires in
/// a `NullAsyncProvider`.
///
/// A fresh current-thread runtime is built here and only a clone of its
/// handle is stored; the runtime itself is a local and is dropped when this
/// function returns.
/// NOTE(review): tasks spawned through that stored handle after return
/// presumably never run — confirm before exercising live subscriptions on a
/// cache built this way.
#[cfg(test)]
pub(crate) fn from_test_data(data: HashMap<CacheKey, DataFrame>) -> Self {
    let test_runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .expect("test tokio runtime");
    let runtime = test_runtime.handle().clone();

    let mut frames = HashMap::with_capacity(data.len());
    for (key, frame) in data {
        frames.insert(key, CachedData::new(frame));
    }

    Self {
        provider: Arc::new(super::async_provider::NullAsyncProvider),
        historical: Arc::new(RwLock::new(frames)),
        live_buffer: Arc::default(),
        subscriptions: Arc::default(),
        runtime,
    }
}
/// Loads every query concurrently through the provider and caches the
/// resulting frames as historical data.
///
/// Every load that succeeds is cached, replacing any frame already stored
/// under the same key. If one or more loads fail, the successful ones are
/// still cached and the first error (in `queries` order) is returned.
///
/// # Errors
/// Returns the first `AsyncDataError` produced by the provider, if any.
///
/// # Panics
/// Panics if the historical lock is poisoned.
pub async fn prefetch(&self, queries: Vec<DataQuery>) -> Result<(), AsyncDataError> {
    use futures::future::join_all;

    // The queries are owned, so move each one into its future instead of cloning it.
    let loads = queries.into_iter().map(|query| {
        let provider = self.provider.clone();
        async move {
            let outcome = provider.load(&query).await;
            (query, outcome)
        }
    });
    let outcomes = join_all(loads).await;

    // The lock is taken only after all awaits are done.
    let mut historical = self.historical.write().unwrap();
    let mut first_error = None;
    for (query, outcome) in outcomes {
        match outcome {
            Ok(df) => {
                let key = CacheKey::new(query.id, query.timeframe);
                historical.insert(key, CachedData::new(df));
            }
            Err(err) => {
                // Remember only the earliest failure; keep caching the rest so the
                // result does not depend on where the failing query sat in the list.
                if first_error.is_none() {
                    first_error = Some(err);
                }
            }
        }
    }

    match first_error {
        Some(err) => Err(err),
        None => Ok(()),
    }
}
/// Returns the row at `index` in the combined (historical, then live) view
/// of a series.
///
/// Indices below the historical row count address the historical frame;
/// higher indices address the live buffer, offset by the historical count.
/// A series with no historical frame is treated as having zero historical
/// rows, which matches how `row_count` counts it.
///
/// Returns `None` when the index is out of range, when a historical row
/// cannot be read or converted, or when the live-buffer lock is poisoned.
///
/// # Panics
/// Panics if the historical lock is poisoned.
pub fn get_row(&self, id: &str, timeframe: &Timeframe, index: usize) -> Option<OwnedDataRow> {
    let key = CacheKey::new(id.to_string(), *timeframe);
    let historical = self.historical.read().unwrap();

    let hist_len = match historical.get(&key) {
        Some(cached) => {
            let hist_len = cached.row_count();
            if index < hist_len {
                // A historical index must never fall through to the live buffer:
                // an unreadable row is reported as `None`, not as live row 0.
                let row = cached.historical.get_row(index)?;
                return OwnedDataRow::from_data_row(&row);
            }
            hist_len
        }
        None => 0,
    };

    // `index >= hist_len` holds here, so the subtraction cannot underflow.
    let live = self.live_buffer.read().ok()?;
    let live_rows = live.get(&key)?;
    live_rows.get(index - hist_len).cloned()
}
/// Collects the rows in `start..end` of the combined (historical, then
/// live) view of a series.
///
/// Historical rows that cannot be read or converted are skipped. Live rows
/// are appended when `end` reaches past the historical row count. A series
/// with no historical frame is treated as having zero historical rows, which
/// matches how `row_count` counts it. An empty or inverted range yields an
/// empty vector, and a poisoned live-buffer lock yields only the historical
/// part.
///
/// # Panics
/// Panics if the historical lock is poisoned.
pub fn get_row_range(
    &self,
    id: &str,
    timeframe: &Timeframe,
    start: usize,
    end: usize,
) -> Vec<OwnedDataRow> {
    let key = CacheKey::new(id.to_string(), *timeframe);
    // No pre-allocation from `end - start`: callers may pass a very large `end`.
    let mut rows = Vec::new();
    let historical = self.historical.read().unwrap();

    let hist_len = match historical.get(&key) {
        Some(cached) => {
            let hist_len = cached.row_count();
            rows.extend((start..end.min(hist_len)).filter_map(|i| {
                let row = cached.historical.get_row(i)?;
                OwnedDataRow::from_data_row(&row)
            }));
            hist_len
        }
        None => 0,
    };

    if end > hist_len {
        if let Ok(live) = self.live_buffer.read() {
            if let Some(live_rows) = live.get(&key) {
                // Translate the requested window into live-buffer coordinates.
                let live_start = start.saturating_sub(hist_len);
                let live_end = end - hist_len;
                rows.extend(
                    live_rows
                        .iter()
                        .skip(live_start)
                        .take(live_end.saturating_sub(live_start))
                        .cloned(),
                );
            }
        }
    }

    rows
}
/// Starts a background task that appends rows received from the provider's
/// live feed to this series' live buffer.
///
/// Does nothing if the series already has a registered subscription. The
/// subscription registry stays locked from the existence check until the new
/// task handle is stored, so two concurrent calls for the same series cannot
/// both open a feed and leave an untracked task behind.
///
/// # Errors
/// Returns the provider's error if the subscription cannot be opened.
///
/// # Panics
/// Panics if the subscription lock is poisoned.
pub fn subscribe_live(&self, id: &str, timeframe: &Timeframe) -> Result<(), AsyncDataError> {
    let key = CacheKey::new(id.to_string(), *timeframe);

    // Held for the whole check-subscribe-register sequence (see above).
    let mut subscriptions = self.subscriptions.lock().unwrap();
    if subscriptions.contains_key(&key) {
        return Ok(());
    }

    let mut rx = self.provider.subscribe(id, timeframe)?;
    let live_buffer = self.live_buffer.clone();
    let task_key = key.clone();
    let handle = self.runtime.spawn(async move {
        while let Some(df) = rx.recv().await {
            // The write guard lives only inside this block, never across an await.
            // Frames that arrive while the lock is poisoned are dropped.
            if let Ok(mut buffer) = live_buffer.write() {
                let rows = buffer.entry(task_key.clone()).or_default();
                rows.extend((0..df.row_count()).filter_map(|i| {
                    let row = df.get_row(i)?;
                    OwnedDataRow::from_data_row(&row)
                }));
            }
        }
    });

    subscriptions.insert(key, handle);
    Ok(())
}
/// Stops the live subscription for a series and discards its buffered live
/// rows.
///
/// Aborts the registered task (if any), asks the provider to unsubscribe
/// (its result is ignored), then removes the series' live buffer. The
/// subscription registry stays locked for the whole call.
///
/// # Panics
/// Panics if the subscription lock is poisoned.
pub fn unsubscribe_live(&self, symbol: &str, timeframe: &Timeframe) {
    let key = CacheKey::new(symbol.to_string(), *timeframe);
    let mut active = self.subscriptions.lock().unwrap();

    let removed = active.remove(&key);
    if let Some(task) = removed {
        task.abort();
    }

    // Best effort: a provider-side failure does not stop local cleanup.
    let _ = self.provider.unsubscribe(symbol, timeframe);

    // A poisoned live-buffer lock is skipped silently.
    let _ = self
        .live_buffer
        .write()
        .map(|mut pending| pending.remove(&key));
}
/// Total number of rows known for a series: historical rows plus buffered
/// live rows.
///
/// Either part counts as 0 when the series has no entry in that map; the
/// live part also counts as 0 when the live-buffer lock is poisoned.
///
/// # Panics
/// Panics if the historical lock is poisoned.
pub fn row_count(&self, id: &str, timeframe: &Timeframe) -> usize {
    let key = CacheKey::new(id.to_string(), *timeframe);
    let historical = self.historical.read().unwrap();

    let hist_count = match historical.get(&key) {
        Some(cached) => cached.row_count(),
        None => 0,
    };
    let live_count = match self.live_buffer.read() {
        Ok(live) => live.get(&key).map_or(0, Vec::len),
        Err(_) => 0,
    };

    hist_count + live_count
}
/// Reports whether a historical frame is cached for the series.
///
/// Only the historical map is consulted; live rows are not considered.
///
/// # Panics
/// Panics if the historical lock is poisoned.
pub fn has_cached(&self, symbol: &str, timeframe: &Timeframe) -> bool {
    let lookup = CacheKey::new(symbol.to_string(), *timeframe);
    let frames = self.historical.read().unwrap();
    frames.get(&lookup).is_some()
}
/// Lists the `(id, timeframe)` pair of every series that has a cached
/// historical frame, in the map's iteration order.
///
/// # Panics
/// Panics if the historical lock is poisoned.
pub fn cached_keys(&self) -> Vec<(String, Timeframe)> {
    let frames = self.historical.read().unwrap();
    let mut pairs = Vec::with_capacity(frames.len());
    for key in frames.keys() {
        pairs.push((key.id.clone(), key.timeframe));
    }
    pairs
}
/// Empties the cache: stops every live subscription, then drops all
/// historical frames and all buffered live rows.
///
/// For each subscription the task is aborted and the provider is asked to
/// unsubscribe, as `unsubscribe_live` does for a single series. Provider
/// errors are ignored so that clearing always completes. A poisoned
/// live-buffer lock is skipped silently.
///
/// # Panics
/// Panics if the subscription or historical lock is poisoned.
pub fn clear(&self) {
    {
        let mut subscriptions = self.subscriptions.lock().unwrap();
        for (key, handle) in subscriptions.drain() {
            handle.abort();
            // Aborting the task alone leaves the provider-side subscription open.
            let _ = self.provider.unsubscribe(key.id.as_str(), &key.timeframe);
        }
    } // subscription lock released before touching the data maps

    self.historical.write().unwrap().clear();

    if let Ok(mut buffer) = self.live_buffer.write() {
        buffer.clear();
    }
}
pub fn provider(&self) -> SharedAsyncProvider {
self.provider.clone()
}
/// Placeholder for snapshotting the cache; currently always fails.
///
/// # Errors
/// Always returns an error describing the pending W17-snapshot-resume work.
pub fn snapshot(&self, store: &SnapshotStore) -> AnyResult<DataCacheSnapshot> {
    // These no-op bindings only mention the arguments, fields and snapshot
    // items; they have no runtime effect.
    // NOTE(review): presumably kept so the snapshot imports stay referenced
    // while the real implementation is absent — confirm.
    let _: Option<CacheKeySnapshot> = None;
    let _: Option<CachedDataSnapshot> = None;
    let _: Option<LiveBufferSnapshot> = None;
    let _ = store_chunked_vec::<u8>;
    let _ = (
        store,
        &self.historical,
        &self.live_buffer,
        DEFAULT_CHUNK_LEN,
    );

    Err(anyhow::anyhow!(
        "DataCache::snapshot: W17-snapshot-resume surface — \
         DataFrame / cached-row (de)serializers were deleted alongside \
         the kind-threaded `slot_to_serializable` rebuild. The kinded \
         replacement uses `store_chunked_vec` over the parallel \
         (bits, NativeKind) per-row track. Tracked as \
         W17-snapshot-resume per docs/cluster-audits/phase-2d-playbook.md §3. \
         ADR-006 §2.7.4 (snapshot serialization deferral) + §2.7.5.1 \
         (post-proof wire-format shape for new HeapKinds)."
    ))
}
/// Placeholder for restoring the cache from a snapshot; currently always
/// fails and leaves the cache untouched.
///
/// # Errors
/// Always returns an error describing the pending W17-snapshot-resume work.
pub fn restore_from_snapshot(
    &self,
    _snapshot: DataCacheSnapshot,
    _store: &SnapshotStore,
) -> AnyResult<()> {
    // No-op mention of the loader; it has no runtime effect.
    // NOTE(review): presumably kept so the import stays referenced — confirm.
    let _ = load_chunked_vec::<OwnedDataRow>;

    Err(anyhow::anyhow!(
        "DataCache::restore_from_snapshot: W17-snapshot-resume \
         surface — symmetric to `snapshot()`. The kinded \
         `serializable_to_slot(sv, expected_kind, store)` inverse \
         reconstructs row-storage parallel kind tracks from the \
         persisted discriminator. Tracked as W17-snapshot-resume per \
         docs/cluster-audits/phase-2d-playbook.md §3. ADR-006 \
         §2.7.4 + §2.7.5.1."
    ))
}
}
/// Aborts the live-subscription tasks when the last `DataCache` sharing the
/// subscription registry is dropped.
///
/// `DataCache` is `Clone` and its clones share one registry through an
/// `Arc`, so dropping a single clone must not abort tasks that the remaining
/// clones still rely on.
impl Drop for DataCache {
    fn drop(&mut self) {
        // Take this instance's reference out of the struct so it can be consumed.
        let shared = std::mem::take(&mut self.subscriptions);

        // `Arc::into_inner` yields the registry to exactly one owner — the last
        // one — even when several clones are dropped concurrently.
        if let Some(registry) = Arc::into_inner(shared) {
            // Recover the map from a poisoned mutex instead of panicking: a panic
            // inside `drop` during unwinding would abort the process.
            let tasks = registry
                .into_inner()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            for handle in tasks.into_values() {
                handle.abort();
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::data::{DataQuery, NullAsyncProvider};
use crate::snapshot::SnapshotStore;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Keys compare equal exactly when both id and timeframe match.
#[test]
fn test_cache_key() {
    let make = |id: &str| CacheKey::new(id.to_string(), Timeframe::d1());
    let aapl = make("AAPL");
    let aapl_again = make("AAPL");
    let msft = make("MSFT");
    assert_eq!(aapl, aapl_again);
    assert_ne!(aapl, msft);
}
/// Wrapping a freshly created frame reports zero rows and a cursor at 0.
#[test]
fn test_cached_data() {
    let cached = CachedData::new(DataFrame::new("TEST", Timeframe::d1()));
    assert_eq!(cached.row_count(), 0);
    assert_eq!(cached.current_index, 0);
}
/// In-memory provider for tests: serves frames from a fixed map and counts
/// how many times `load` futures have run.
#[derive(Clone)]
struct TestAsyncProvider {
    /// Frames served by `load`, keyed by series.
    frames: Arc<HashMap<CacheKey, DataFrame>>,
    /// Incremented each time a `load` future is polled to its first statement.
    load_calls: Arc<AtomicUsize>,
}

impl crate::data::AsyncDataProvider for TestAsyncProvider {
    /// Returns a clone of the stored frame for the query, or
    /// `SymbolNotFound` carrying the query id when none is stored.
    fn load<'a>(
        &'a self,
        query: &'a DataQuery,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<DataFrame, crate::data::AsyncDataError>>
                + Send
                + 'a,
        >,
    > {
        let lookup = CacheKey::new(query.id.clone(), query.timeframe);
        let frames = Arc::clone(&self.frames);
        let counter = Arc::clone(&self.load_calls);
        Box::pin(async move {
            counter.fetch_add(1, Ordering::SeqCst);
            match frames.get(&lookup) {
                Some(frame) => Ok(frame.clone()),
                None => Err(crate::data::AsyncDataError::SymbolNotFound(
                    query.id.clone(),
                )),
            }
        })
    }

    /// True when a frame is stored for the series.
    fn has_data(&self, symbol: &str, timeframe: &Timeframe) -> bool {
        self.frames
            .get(&CacheKey::new(symbol.to_string(), *timeframe))
            .is_some()
    }

    /// Ids of all stored series, one entry per stored key.
    fn symbols(&self) -> Vec<String> {
        let mut ids = Vec::with_capacity(self.frames.len());
        for key in self.frames.keys() {
            ids.push(key.id.clone());
        }
        ids
    }
}
/// Never called. Its signature and body mention the test module's imports
/// and the `TestAsyncProvider` helper.
/// NOTE(review): presumably this exists only to keep those names referenced
/// (avoiding unused-import / dead-code warnings) while the tests that used
/// them are absent — confirm, and delete together with the unused imports
/// once those tests return.
// `dead_code` is allowed because nothing calls this function.
#[allow(dead_code)]
fn _unused_test_imports(
    _provider: TestAsyncProvider,
    _df: DataFrame,
    _query: DataQuery,
    _kind: NullAsyncProvider,
    _store: SnapshotStore,
    _arc: Arc<()>,
    _atomic: AtomicUsize,
    _ordering: Ordering,
) {
    // Mentions both `SystemTime` and `UNIX_EPOCH` from the `std::time` import.
    let _ = (SystemTime::UNIX_EPOCH, UNIX_EPOCH);
}
/// `snapshot()` must fail with an error that carries the W17 marker and the
/// ADR-006 §2.7.4 citation.
#[test]
fn test_w17_data_cache_snapshot_returns_structured_error() {
    let tmp = tempfile::tempdir().expect("tempdir");
    let store = SnapshotStore::new(tmp.path()).expect("snapshot store");
    let cache = DataCache::from_test_data(HashMap::new());

    let msg = cache
        .snapshot(&store)
        .expect_err("expected Err, got Ok")
        .to_string();

    let has_marker = msg.contains("W17-snapshot-resume surface");
    assert!(has_marker, "missing W17 marker; got: {msg}");

    let has_cite = msg.contains("§2.7.4");
    assert!(has_cite, "missing ADR-006 §2.7.4 cite; got: {msg}");
}
}