#[cfg(feature = "memory_compression")]
use std::collections::HashSet;
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, Weak};
use std::time::{Duration, Instant};
use crate::error::{CoreError, CoreResult, ErrorContext};
/// Minimum correlation strength (0.0–1.0) required before related data is prefetched.
const DEFAULT_CORRELATION_THRESHOLD: f64 = 0.6;
/// Time window within which accesses to different datasets are considered correlated.
const DEFAULT_CORRELATION_WINDOW: Duration = Duration::from_secs(60);
/// Number of co-occurrences required before a correlation is acted upon.
const DEFAULT_MIN_OCCURRENCES: usize = 3;
/// Correlations not re-observed within this period are discarded.
const DEFAULT_CORRELATION_EXPIRY: Duration = Duration::from_secs(3600);
/// Identifies a dataset by file path, memory address, or plain name.
///
/// Identity is resolved in that priority order by the manual `PartialEq`/`Hash`
/// impls below; `name` is always present and acts as the fallback identifier.
#[derive(Debug, Clone, Eq)]
pub struct DatasetId {
    /// Backing file path, if the dataset is file-based.
    pub path: Option<PathBuf>,
    /// Address of the in-memory buffer, if the dataset is memory-based.
    pub memory_address: Option<usize>,
    /// Human-readable name; used for equality only when neither `path` nor
    /// `memory_address` is present on both sides of a comparison.
    pub name: String,
}
impl PartialEq for DatasetId {
    /// Compares by path when both sides have one, else by memory address when
    /// both sides have one, else by name.
    ///
    /// NOTE(review): this relation is not transitive (a path-backed id and a
    /// name-only id with the same name compare equal, while two path-backed
    /// ids with different paths and the same name do not), so the `Eq` derive
    /// over-promises. It can also disagree with the `Hash` impl below: two
    /// values equal via the name fallback may hash different fields. Mixing
    /// differently-constructed ids as keys in one `HashMap` may therefore
    /// misbehave — confirm all call sites construct ids consistently.
    fn eq(&self, other: &Self) -> bool {
        if let (Some(ref self_path), Some(ref other_path)) = (&self.path, &other.path) {
            return self_path == other_path;
        }
        if let (Some(self_addr), Some(other_addr)) = (self.memory_address, other.memory_address) {
            return self_addr == other_addr;
        }
        self.name == other.name
    }
}
impl Hash for DatasetId {
    /// Hashes the highest-priority identity field present: path, else memory
    /// address, else name.
    ///
    /// NOTE(review): this is only consistent with `PartialEq` above when both
    /// values carry the same kind of identity. A path-backed id and a name-only
    /// id with the same name compare equal yet hash different fields, which
    /// violates the `Hash`/`Eq` contract for `HashMap` keys — confirm keys of
    /// mixed construction never share a map.
    fn hash<H: Hasher>(&self, state: &mut H) {
        if let Some(ref path) = self.path {
            path.hash(state);
        } else if let Some(addr) = self.memory_address {
            addr.hash(state);
        } else {
            self.name.hash(state);
        }
    }
}
impl DatasetId {
    /// Builds a file-backed id; the dataset name is taken from the final path
    /// component, falling back to "unnamed_dataset" for paths without one
    /// (e.g. `/` or a path ending in `..`).
    pub fn from_path(path: impl AsRef<Path>) -> Self {
        let path_buf = path.as_ref().to_path_buf();
        let name = match path_buf.file_name() {
            Some(file_name) => file_name.to_string_lossy().into_owned(),
            None => String::from("unnamed_dataset"),
        };
        Self {
            name,
            memory_address: None,
            path: Some(path_buf),
        }
    }

    /// Builds an id for an in-memory dataset located at `address`, with a
    /// human-readable display name.
    pub fn from_address(address: usize, name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            memory_address: Some(address),
            path: None,
        }
    }

    /// Builds an id identified purely by name.
    pub fn from_name(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            memory_address: None,
            path: None,
        }
    }
}
/// A single logical access to a dataset, as reported by callers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DataAccess {
    /// Dataset being accessed.
    pub dataset: DatasetId,
    /// Flat element index within the dataset.
    pub index: usize,
    /// Kind of operation performed.
    pub access_type: AccessType,
    /// Size of the access in bytes, when known.
    pub size: Option<usize>,
    /// Multi-dimensional indices of the access, when known.
    pub dimensions: Option<Vec<usize>>,
}

/// Kind of operation performed on a dataset.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    Read,
    Write,
    ReadWrite,
    /// Access to metadata only, not element data.
    Metadata,
}

/// Internal history entry: an access plus when it happened and how long it took.
#[derive(Debug, Clone)]
struct AccessRecord {
    /// The recorded access.
    access: DataAccess,
    /// When the access was recorded.
    timestamp: Instant,
    /// Wall-clock duration, filled in later by `complete_access`; `None` until then.
    duration: Option<Duration>,
}
/// Tuning knobs for cross-file prefetching.
#[derive(Debug, Clone)]
pub struct CrossFilePrefetchConfig {
    /// Minimum correlation strength (0.0–1.0) required to trigger prefetching.
    pub correlation_threshold: f64,
    /// Time window within which accesses to different datasets count as correlated.
    pub correlation_window: Duration,
    /// Minimum co-occurrences before a correlation is acted upon.
    pub min_occurrences: usize,
    /// Upper bound on the number of related datasets prefetched per access.
    pub max_prefetch_datasets: usize,
    /// Upper bound on the number of elements prefetched per related dataset.
    pub max_prefetch_elements: usize,
    /// Prefetch an entire related dataset when the correlation is very strong.
    pub prefetch_entire_file: bool,
    /// Correlations not re-observed within this period are dropped.
    pub correlation_expiry: Duration,
    /// Whether new correlations are learned from observed accesses.
    pub enable_learning: bool,
}
impl Default for CrossFilePrefetchConfig {
fn default() -> Self {
Self {
correlation_threshold: DEFAULT_CORRELATION_THRESHOLD,
correlation_window: DEFAULT_CORRELATION_WINDOW,
min_occurrences: DEFAULT_MIN_OCCURRENCES,
max_prefetch_datasets: 3,
max_prefetch_elements: 100,
prefetch_entire_file: false,
correlation_expiry: DEFAULT_CORRELATION_EXPIRY,
enable_learning: true,
}
}
}
/// Fluent builder for [`CrossFilePrefetchConfig`], starting from its defaults.
#[derive(Debug, Clone)]
pub struct CrossFilePrefetchConfigBuilder {
    /// Configuration being assembled; returned by `build`.
    config: CrossFilePrefetchConfig,
}
impl CrossFilePrefetchConfigBuilder {
    /// Starts a builder from [`CrossFilePrefetchConfig::default`].
    pub fn new() -> Self {
        Self {
            config: CrossFilePrefetchConfig::default(),
        }
    }
    /// Sets the minimum correlation strength; clamped into `[0.0, 1.0]`.
    pub fn with_correlation_threshold(mut self, threshold: f64) -> Self {
        self.config.correlation_threshold = threshold.clamp(0.0, 1.0);
        self
    }
    /// Sets the time window within which accesses are considered correlated.
    pub const fn with_correlation_window(mut self, window: Duration) -> Self {
        self.config.correlation_window = window;
        self
    }
    /// Sets how many co-occurrences are required before acting on a correlation.
    pub const fn with_min_occurrences(mut self, occurrences: usize) -> Self {
        self.config.min_occurrences = occurrences;
        self
    }
    /// Sets the maximum number of related datasets prefetched per access.
    pub const fn with_max_prefetch_datasets(mut self, maxdatasets: usize) -> Self {
        self.config.max_prefetch_datasets = maxdatasets;
        self
    }
    /// Sets the maximum number of elements prefetched per related dataset.
    pub const fn with_max_prefetch_elements(mut self, maxelements: usize) -> Self {
        self.config.max_prefetch_elements = maxelements;
        self
    }
    /// Enables or disables whole-dataset prefetching for very strong correlations.
    pub const fn with_prefetch_entire_file(mut self, enable: bool) -> Self {
        self.config.prefetch_entire_file = enable;
        self
    }
    /// Sets how long an unobserved correlation survives before being dropped.
    pub const fn with_correlation_expiry(mut self, expiry: Duration) -> Self {
        self.config.correlation_expiry = expiry;
        self
    }
    /// Enables or disables learning of new correlations.
    pub const fn with_enable_learning(mut self, enable: bool) -> Self {
        self.config.enable_learning = enable;
        self
    }
    /// Finalizes and returns the configuration.
    pub fn build(self) -> CrossFilePrefetchConfig {
        self.config
    }
}
impl Default for CrossFilePrefetchConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// A learned directional association: accesses to `primary` predict accesses
/// to `related` shortly afterwards.
#[derive(Debug, Clone)]
struct DatasetCorrelation {
    /// Dataset whose accesses act as the predictor.
    #[allow(dead_code)]
    primary: DatasetId,
    /// Dataset predicted to be accessed next.
    related: DatasetId,
    /// Confidence in `[0.0, 1.0)`; grows toward 1.0 with each co-occurrence.
    strength: f64,
    /// Number of co-occurrences observed so far.
    occurrences: usize,
    /// Last time the association was observed; used for expiry.
    last_observed: Instant,
    /// Maps a primary index to the related indices observed to follow it.
    index_correlations: HashMap<usize, Vec<usize>>,
}
impl DatasetCorrelation {
    /// Sanity cap on extrapolated indices; predictions at or beyond this are
    /// discarded as extrapolation artifacts (was an inline magic number).
    const MAX_PREDICTED_INDEX: usize = 1_000_000;

    /// Creates an empty correlation between `primary` and `related`.
    fn new(primary: DatasetId, related: DatasetId) -> Self {
        Self {
            primary,
            related,
            strength: 0.0,
            occurrences: 0,
            last_observed: Instant::now(),
            index_correlations: HashMap::new(),
        }
    }

    /// Records one more co-occurrence: bumps the occurrence count, refreshes
    /// the observation time, moves `strength` toward 1.0 via an exponential
    /// moving average, and appends any previously unseen related indices for
    /// `primary_index`.
    fn update_model(&mut self, primary_index: usize, relatedindices: &[usize]) {
        self.occurrences += 1;
        self.last_observed = Instant::now();
        // strength_n = 0.9 * strength_{n-1} + 0.1, converging toward 1.0.
        self.strength = (self.strength * 0.9) + 0.1;
        let entry = self.index_correlations.entry(primary_index).or_default();
        for &related_index in relatedindices {
            // Linear scan keeps insertion order; lists are expected to stay
            // short (bounded by max_prefetch_elements in practice).
            if !entry.contains(&related_index) {
                entry.push(related_index);
            }
        }
    }

    /// True while the correlation has been observed within `expiry`.
    fn is_valid(&self, expiry: Duration) -> bool {
        self.last_observed.elapsed() <= expiry
    }

    /// Returns up to `maxcount` indices in the related dataset expected to be
    /// needed after `primary_index`: exact observations first, then a linear
    /// extrapolation, then a neighborhood around the index as a last resort.
    fn get_related_indices(&self, primary_index: usize, maxcount: usize) -> Vec<usize> {
        if let Some(indices) = self.index_correlations.get(&primary_index) {
            indices.iter().take(maxcount).copied().collect()
        } else if let Some(predicted) = self.predict_from_pattern(primary_index, maxcount) {
            predicted
        } else {
            // No usable pattern: guess the same index plus its neighbors.
            let mut nearby = Vec::with_capacity(maxcount);
            nearby.push(primary_index);
            for i in 1..=maxcount / 2 {
                if primary_index >= i {
                    nearby.push(primary_index - i);
                }
                nearby.push(primary_index + i);
            }
            nearby.into_iter().take(maxcount).collect()
        }
    }

    /// Fits a line through the two lowest observed (primary, first-related)
    /// index pairs and extrapolates a related index for `primary_index`.
    ///
    /// Returns `None` when fewer than two usable observations exist, the two
    /// primaries coincide, or the extrapolation is implausibly large.
    /// (`_maxcount` is accepted for signature symmetry with the callers but a
    /// single predicted index is returned.)
    fn predict_from_pattern(&self, primary_index: usize, _maxcount: usize) -> Option<Vec<usize>> {
        if self.index_correlations.len() < 2 {
            return None;
        }
        // One (primary, first-related) sample per observed primary index.
        let mut samples: Vec<(usize, usize)> = self
            .index_correlations
            .iter()
            .filter_map(|(&primary, related)| related.first().map(|&r| (primary, r)))
            .collect();
        samples.sort_by_key(|&(primary, _)| primary);
        if let [(p1, r1), (p2, r2), ..] = samples[..] {
            if p2 != p1 {
                // Linear fit: related = scale * primary + offset.
                let scale = (r2 as f64 - r1 as f64) / (p2 as f64 - p1 as f64);
                let offset = r1 as f64 - scale * p1 as f64;
                // NOTE: a negative extrapolation saturates to 0 via the cast.
                let predicted = (scale * primary_index as f64 + offset).round() as usize;
                if predicted < Self::MAX_PREDICTED_INDEX {
                    return Some(vec![predicted]);
                }
            }
        }
        None
    }
}
/// Learns cross-dataset access correlations and drives prefetching.
///
/// Not internally synchronized; wrap in [`CrossFilePrefetchRegistry`] (or a
/// lock of your own) for shared use.
pub struct CrossFilePrefetchManager {
    /// Tuning parameters.
    config: CrossFilePrefetchConfig,
    /// Bounded, time-ordered history of recent accesses (oldest at the front).
    access_history: VecDeque<AccessRecord>,
    /// primary dataset -> (related dataset -> learned correlation).
    correlations: HashMap<DatasetId, HashMap<DatasetId, DatasetCorrelation>>,
    /// Registered prefetchers, held weakly so dropping one disables it.
    datasets: HashMap<DatasetId, Weak<dyn DatasetPrefetcher>>,
    /// Last (index, time) each dataset was accessed.
    last_dataset_access: HashMap<DatasetId, (usize, Instant)>,
}
impl CrossFilePrefetchManager {
    /// Maximum number of access records retained for correlation learning
    /// (was duplicated as a literal in two places).
    const MAX_HISTORY: usize = 1000;

    /// Creates a manager with no history, correlations, or registered datasets.
    pub fn new(config: CrossFilePrefetchConfig) -> Self {
        Self {
            config,
            access_history: VecDeque::with_capacity(Self::MAX_HISTORY),
            correlations: HashMap::new(),
            datasets: HashMap::new(),
            last_dataset_access: HashMap::new(),
        }
    }

    /// Records an access, updates correlation models (when learning is
    /// enabled), and prefetches from correlated datasets.
    ///
    /// # Errors
    /// Propagates errors from the prefetchers of related datasets.
    pub fn record_access(&mut self, access: DataAccess) -> CoreResult<()> {
        let now = Instant::now();
        self.access_history.push_back(AccessRecord {
            access: access.clone(),
            timestamp: now,
            duration: None,
        });
        // Bound the history so memory use stays constant.
        while self.access_history.len() > Self::MAX_HISTORY {
            self.access_history.pop_front();
        }
        self.last_dataset_access
            .insert(access.dataset.clone(), (access.index, now));
        if self.config.enable_learning {
            self.update_correlations(&access);
        }
        self.prefetch_related_data(&access)
    }

    /// Attaches a duration to the most recent matching, still-open record.
    pub fn complete_access(&mut self, dataset: &DatasetId, index: usize, duration: Duration) {
        if let Some(record) = self.access_history.iter_mut().rev().find(|r| {
            r.access.dataset == *dataset && r.access.index == index && r.duration.is_none()
        }) {
            record.duration = Some(duration);
        }
    }

    /// Registers a prefetcher for `dataset`. Only a weak reference is kept,
    /// so dropping the prefetcher elsewhere disables prefetching for it.
    pub fn register_dataset(&mut self, dataset: DatasetId, prefetcher: Arc<dyn DatasetPrefetcher>) {
        self.datasets.insert(dataset, Arc::downgrade(&prefetcher));
    }

    /// Removes a dataset and every correlation that references it.
    pub fn unregister_dataset(&mut self, dataset: &DatasetId) {
        self.datasets.remove(dataset);
        self.correlations.remove(dataset);
        for related_map in self.correlations.values_mut() {
            related_map.remove(dataset);
        }
    }

    /// Learns correlations: each dataset accessed within the configured
    /// correlation window before `access` is treated as a predictor of it.
    fn update_correlations(&mut self, access: &DataAccess) {
        // Honor `config.correlation_window`: a previous revision silently
        // ignored the setting and used a hard-coded 100 ms window.
        // `checked_sub` avoids the panic path of `Instant - Duration` when the
        // window exceeds process uptime; `None` means "everything is recent".
        let recent_threshold = Instant::now().checked_sub(self.config.correlation_window);
        // Most recent record per *other* dataset inside the window.
        let mut recent_by_dataset: HashMap<&DatasetId, &AccessRecord> = HashMap::new();
        for record in self.access_history.iter().rev() {
            if let Some(threshold) = recent_threshold {
                // History is time-ordered, so stop at the first record that
                // falls outside the window.
                if record.timestamp < threshold {
                    break;
                }
            }
            if record.access.dataset == access.dataset {
                continue;
            }
            // Iterating newest-first, so the first record seen per dataset is
            // the most recent one; `or_insert` keeps it.
            recent_by_dataset.entry(&record.access.dataset).or_insert(record);
        }
        for (related_dataset, recent_record) in recent_by_dataset {
            let correlation = self
                .correlations
                .entry(related_dataset.clone())
                .or_default()
                .entry(access.dataset.clone())
                .or_insert_with(|| {
                    DatasetCorrelation::new(related_dataset.clone(), access.dataset.clone())
                });
            correlation.update_model(recent_record.access.index, &[access.index]);
        }
        self.clean_expired_correlations();
    }

    /// Drops correlations not observed within `config.correlation_expiry`,
    /// along with primary entries left with no related datasets.
    fn clean_expired_correlations(&mut self) {
        let expiry = self.config.correlation_expiry;
        self.correlations.retain(|_, related_map| {
            related_map.retain(|_, corr| corr.is_valid(expiry));
            !related_map.is_empty()
        });
    }

    /// Prefetches data from datasets correlated with the one just accessed,
    /// strongest correlations first, bounded by the configured limits.
    fn prefetch_related_data(&self, access: &DataAccess) -> CoreResult<()> {
        let related_datasets = match self.correlations.get(&access.dataset) {
            Some(map) => map,
            None => return Ok(()),
        };
        let mut candidates: Vec<_> = related_datasets.values().collect();
        // Strongest first; NaN strengths order as equal instead of panicking.
        candidates.sort_by(|a, b| {
            b.strength
                .partial_cmp(&a.strength)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut prefetch_count = 0;
        for correlation in candidates {
            if correlation.strength < self.config.correlation_threshold {
                continue;
            }
            // Require the configured number of observations before acting.
            // (`config.min_occurrences` was previously never enforced.)
            if correlation.occurrences < self.config.min_occurrences {
                continue;
            }
            if let Some(weak_prefetcher) = self.datasets.get(&correlation.related) {
                if let Some(prefetcher) = weak_prefetcher.upgrade() {
                    let indices = correlation
                        .get_related_indices(access.index, self.config.max_prefetch_elements);
                    if self.config.prefetch_entire_file && correlation.strength > 0.9 {
                        prefetcher.prefetch_all()?;
                    } else {
                        prefetcher.prefetch_indices(&indices)?;
                    }
                    prefetch_count += 1;
                    if prefetch_count >= self.config.max_prefetch_datasets {
                        break;
                    }
                }
            }
        }
        Ok(())
    }

    /// Returns `(related dataset, strength)` pairs for `dataset`, strongest first.
    pub fn get_correlations(&self, dataset: &DatasetId) -> Vec<(DatasetId, f64)> {
        let mut result = Vec::new();
        if let Some(related_map) = self.correlations.get(dataset) {
            for (related, correlation) in related_map {
                result.push((related.clone(), correlation.strength));
            }
        }
        // NaN-safe descending sort (previously panicked on partial_cmp failure).
        result.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        result
    }

    /// Returns the single strongest correlation for `dataset`, if any.
    pub fn get_strongest_related(&self, dataset: &DatasetId) -> Option<(DatasetId, f64)> {
        self.correlations
            .get(dataset)?
            .iter()
            .max_by(|(_, a), (_, b)| {
                a.strength
                    .partial_cmp(&b.strength)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(k, v)| (k.clone(), v.strength))
    }

    /// Lists every dataset that has recorded at least one access.
    pub fn get_active_datasets(&self) -> Vec<DatasetId> {
        self.last_dataset_access.keys().cloned().collect()
    }
}
/// Implemented by dataset handles that can load data ahead of use.
pub trait DatasetPrefetcher: Send + Sync {
    /// Prefetches the elements at `indices`.
    fn prefetch_indices(&self, indices: &[usize]) -> CoreResult<()>;
    /// Prefetches the entire dataset.
    fn prefetch_all(&self) -> CoreResult<()>;
    /// Returns the identifier of the dataset this prefetcher serves.
    fn get_dataset_id(&self) -> DatasetId;
}

/// Thread-safe facade over [`CrossFilePrefetchManager`]; also provides the
/// process-wide singleton used by the registration extension traits.
pub struct CrossFilePrefetchRegistry {
    /// Shared manager; every public method locks it for the duration of a call.
    manager: Arc<Mutex<CrossFilePrefetchManager>>,
}
impl CrossFilePrefetchRegistry {
pub fn new(config: CrossFilePrefetchConfig) -> Self {
Self {
manager: Arc::new(Mutex::new(CrossFilePrefetchManager::new(config))),
}
}
pub fn global() -> &'static Self {
use std::sync::Once;
static INIT: Once = Once::new();
static mut INSTANCE: Option<CrossFilePrefetchRegistry> = None;
INIT.call_once(|| {
let registry = CrossFilePrefetchRegistry::new(CrossFilePrefetchConfig::default());
unsafe {
INSTANCE = Some(registry);
}
});
#[allow(static_mut_refs)]
unsafe {
INSTANCE.as_ref().expect("Operation failed")
}
}
pub fn record_access(&self, access: DataAccess) -> CoreResult<()> {
match self.manager.lock() {
Ok(mut manager) => manager.record_access(access),
Err(_) => Err(CoreError::MutexError(ErrorContext::new(
"Failed to acquire lock on cross-file prefetch manager".to_string(),
))),
}
}
pub fn complete_access(
&self,
dataset: &DatasetId,
index: usize,
duration: Duration,
) -> CoreResult<()> {
match self.manager.lock() {
Ok(mut manager) => {
manager.complete_access(dataset, index, duration);
Ok(())
}
Err(_) => Err(CoreError::MutexError(ErrorContext::new(
"Failed to acquire lock on cross-file prefetch manager".to_string(),
))),
}
}
pub fn register_dataset(
&self,
dataset: DatasetId,
prefetcher: Arc<dyn DatasetPrefetcher>,
) -> CoreResult<()> {
match self.manager.lock() {
Ok(mut manager) => {
manager.register_dataset(dataset, prefetcher);
Ok(())
}
Err(_) => Err(CoreError::MutexError(ErrorContext::new(
"Failed to acquire lock on cross-file prefetch manager".to_string(),
))),
}
}
pub fn unregister_dataset(&self, dataset: &DatasetId) -> CoreResult<()> {
match self.manager.lock() {
Ok(mut manager) => {
manager.unregister_dataset(dataset);
Ok(())
}
Err(_) => Err(CoreError::MutexError(ErrorContext::new(
"Failed to acquire lock on cross-file prefetch manager".to_string(),
))),
}
}
pub fn get_correlations(&self, dataset: &DatasetId) -> CoreResult<Vec<(DatasetId, f64)>> {
match self.manager.lock() {
Ok(manager) => Ok(manager.get_correlations(dataset)),
Err(_) => Err(CoreError::MutexError(ErrorContext::new(
"Failed to acquire lock on cross-file prefetch manager".to_string(),
))),
}
}
pub fn get_active_datasets(&self) -> CoreResult<Vec<DatasetId>> {
match self.manager.lock() {
Ok(manager) => Ok(manager.get_active_datasets()),
Err(_) => Err(CoreError::MutexError(ErrorContext::new(
"Failed to acquire lock on cross-file prefetch manager".to_string(),
))),
}
}
}
/// Prefetcher for compressed memory-mapped arrays: translates element indices
/// to compression blocks and preloads whole blocks.
#[cfg(feature = "memory_compression")]
pub struct CompressedArrayPrefetcher<A: Clone + Copy + Send + Sync + 'static> {
    /// Identifier reported to the cross-file prefetch machinery.
    dataset_id: DatasetId,
    /// The compressed array whose blocks are preloaded.
    array: Arc<super::compressed_memmap::CompressedMemMappedArray<A>>,
}
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + Send + Sync + 'static> CompressedArrayPrefetcher<A> {
    /// Creates a prefetcher serving `array` under `dataset_id`.
    pub fn new(
        dataset_id: DatasetId,
        array: Arc<super::compressed_memmap::CompressedMemMappedArray<A>>,
    ) -> Self {
        Self { dataset_id, array }
    }
}
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + Send + Sync + 'static> DatasetPrefetcher for CompressedArrayPrefetcher<A> {
    /// Preloads every compression block containing at least one of `indices`.
    fn prefetch_indices(&self, indices: &[usize]) -> CoreResult<()> {
        if indices.is_empty() {
            return Ok(());
        }
        let block_size = self.array.block_size();
        // Map element indices to their owning blocks, deduplicating via the set.
        let blocks: HashSet<usize> = indices.iter().map(|&idx| idx / block_size).collect();
        for block in blocks {
            self.array.preload_block(block)?;
        }
        Ok(())
    }

    /// Preloads every block of the array.
    fn prefetch_all(&self) -> CoreResult<()> {
        for block in 0..self.array.num_blocks() {
            self.array.preload_block(block)?;
        }
        Ok(())
    }

    /// Returns the identifier this prefetcher was created with.
    fn get_dataset_id(&self) -> DatasetId {
        self.dataset_id.clone()
    }
}
/// Prefetcher for plain memory-mapped arrays.
pub struct MemoryMappedArrayPrefetcher<A: Clone + Copy + Send + Sync + 'static> {
    /// Identifier reported to the cross-file prefetch machinery.
    dataset_id: DatasetId,
    /// The memory-mapped array to fault into memory.
    array: Arc<super::memmap::MemoryMappedArray<A>>,
    /// Intended prefetch granularity; not yet consulted by the impl below
    /// (hence the dead_code allowance).
    #[allow(dead_code)]
    chunk_size: usize,
}
impl<A: Clone + Copy + Send + Sync + 'static> MemoryMappedArrayPrefetcher<A> {
    /// Creates a prefetcher serving `array` under `dataset_id`.
    pub fn new(
        dataset_id: DatasetId,
        array: Arc<super::memmap::MemoryMappedArray<A>>,
        chunk_size: usize,
    ) -> Self {
        Self {
            dataset_id,
            array,
            chunk_size,
        }
    }
}
impl<A: Clone + Copy + Send + Sync + 'static> DatasetPrefetcher for MemoryMappedArrayPrefetcher<A> {
    /// "Prefetches" the requested indices.
    ///
    /// NOTE(review): the individual `indices` are ignored — `as_array` maps
    /// the whole array, so this is currently equivalent to `prefetch_all`,
    /// and `chunk_size` is unused. Presumably a finer-grained touch of only
    /// the regions containing `indices` was intended — confirm against the
    /// `MemoryMappedArray` API.
    fn prefetch_indices(&self, indices: &[usize]) -> CoreResult<()> {
        if indices.is_empty() {
            return Ok(());
        }
        self.array.as_array::<crate::ndarray::IxDyn>()?;
        Ok(())
    }
    /// Materializes the whole array view, pulling its contents into memory.
    fn prefetch_all(&self) -> CoreResult<()> {
        self.array.as_array::<crate::ndarray::IxDyn>()?;
        Ok(())
    }
    /// Returns the identifier this prefetcher was created with.
    fn get_dataset_id(&self) -> DatasetId {
        self.dataset_id.clone()
    }
}
/// Extension trait: registers a compressed array with the global prefetch
/// registry in one call.
#[cfg(feature = "memory_compression")]
pub trait CompressedArrayPrefetchExt<A: Clone + Copy + Send + Sync + 'static> {
    /// Registers `self` under `dataset_id` and returns the created prefetcher.
    #[allow(dead_code)]
    fn register_with_cross_file_prefetch(
        &self,
        dataset_id: DatasetId,
    ) -> CoreResult<Arc<CompressedArrayPrefetcher<A>>>;
}
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + Send + Sync + 'static> CompressedArrayPrefetchExt<A>
    for super::compressed_memmap::CompressedMemMappedArray<A>
{
    fn register_with_cross_file_prefetch(
        &self,
        dataset_id: DatasetId,
    ) -> CoreResult<Arc<CompressedArrayPrefetcher<A>>> {
        // Clone the array handle into an Arc so the prefetcher owns its own
        // reference independent of this borrow.
        let array = Arc::new((*self).clone());
        let prefetcher = Arc::new(CompressedArrayPrefetcher::new(dataset_id.clone(), array));
        CrossFilePrefetchRegistry::global().register_dataset(dataset_id, prefetcher.clone())?;
        Ok(prefetcher)
    }
}
/// Extension trait: registers a memory-mapped array with the global prefetch
/// registry in one call.
pub trait MemoryMappedArrayPrefetchExt<A: Clone + Copy + Send + Sync + 'static> {
    /// Registers `self` under `dataset_id` and returns the created prefetcher.
    /// `chunk_size` is forwarded to the prefetcher as its intended granularity.
    #[allow(dead_code)]
    fn register_with_cross_file_prefetch(
        &self,
        dataset_id: DatasetId,
        chunk_size: usize,
    ) -> CoreResult<Arc<MemoryMappedArrayPrefetcher<A>>>;
}
impl<A: Clone + Copy + Send + Sync + 'static> MemoryMappedArrayPrefetchExt<A>
    for super::memmap::MemoryMappedArray<A>
{
    fn register_with_cross_file_prefetch(
        &self,
        dataset_id: DatasetId,
        chunk_size: usize,
    ) -> CoreResult<Arc<MemoryMappedArrayPrefetcher<A>>> {
        // Clone the array handle into an Arc so the prefetcher owns its own
        // reference independent of this borrow.
        let array = Arc::new((*self).clone());
        let prefetcher = Arc::new(MemoryMappedArrayPrefetcher::new(
            dataset_id.clone(),
            array,
            chunk_size,
        ));
        CrossFilePrefetchRegistry::global().register_dataset(dataset_id, prefetcher.clone())?;
        Ok(prefetcher)
    }
}
/// Wraps an array together with a [`DatasetId`] so element accesses can be
/// reported to the global cross-file prefetch registry.
#[allow(dead_code)]
pub struct TrackedArray<A: Clone + Copy + 'static + Send + Sync, T> {
    /// The wrapped array.
    array: T,
    /// Identity used when recording accesses.
    dataset_id: DatasetId,
    /// Ties the element type `A` to the wrapper without storing one.
    phantom: std::marker::PhantomData<A>,
}
#[allow(dead_code)]
impl<A: Clone + Copy + 'static + Send + Sync, T> TrackedArray<A, T> {
    /// Wraps `array` under the identity `datasetid`.
    pub fn new(array: T, datasetid: DatasetId) -> Self {
        Self {
            array,
            dataset_id: datasetid,
            phantom: std::marker::PhantomData,
        }
    }
    /// Borrows the wrapped array.
    pub const fn inner(&self) -> &T {
        &self.array
    }
    /// Mutably borrows the wrapped array.
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.array
    }
    /// Returns the identity accesses are recorded under.
    pub const fn dataset_id(&self) -> &DatasetId {
        &self.dataset_id
    }
    /// Reports one element access to the global prefetch registry.
    fn record_access(
        &self,
        index: usize,
        access_type: AccessType,
        dimensions: Option<Vec<usize>>,
    ) -> CoreResult<()> {
        let access = DataAccess {
            dataset: self.dataset_id.clone(),
            index,
            access_type,
            size: None,
            dimensions,
        };
        CrossFilePrefetchRegistry::global().record_access(access)
    }
}
#[cfg(feature = "memory_compression")]
impl<A: Clone + Copy + 'static + Send + Sync>
    TrackedArray<A, super::compressed_memmap::CompressedMemMappedArray<A>>
{
    /// Reads one element at the multi-dimensional `indices`, recording the
    /// access (and its duration) with the global prefetch registry.
    #[allow(dead_code)]
    pub fn get(&self, indices: &[usize]) -> CoreResult<A> {
        let start = Instant::now();
        // Row-major flattening: flat = sum(indices[i] * prod(shape[i+1..])).
        let mut flat_index = 0;
        let mut stride = 1;
        for i in (0..indices.len()).rev() {
            flat_index += indices[i] * stride;
            if i > 0 {
                stride *= self.array.shape()[i];
            }
        }
        self.record_access(flat_index, AccessType::Read, Some(indices.to_vec()))?;
        let result = self.array.get(indices);
        // Timing is reported whether or not the read succeeded; the duration
        // covers the whole call.
        CrossFilePrefetchRegistry::global().complete_access(
            &self.dataset_id,
            flat_index,
            start.elapsed(),
        )?;
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Equality follows path, then address, then name (see the PartialEq impl).
    #[test]
    fn test_dataset_id() {
        let path_id = DatasetId::from_path("/path/to/data.bin");
        let mem_id = DatasetId::from_address(0x12345678, "memory_dataset");
        let name_id = DatasetId::from_name("named_dataset");
        assert_eq!(path_id, DatasetId::from_path("/path/to/data.bin"));
        // Same address compares equal even though the names differ.
        assert_eq!(mem_id, DatasetId::from_address(0x12345678, "other_name"));
        assert_eq!(name_id, DatasetId::from_name("named_dataset"));
        assert_ne!(path_id, mem_id);
        assert_ne!(path_id, name_id);
        assert_ne!(mem_id, name_id);
    }

    // update_model bumps occurrences, grows strength, and appends only
    // previously unseen related indices for a given primary index.
    #[test]
    fn test_correlation_update() {
        let primary = DatasetId::from_name("dataset1");
        let related = DatasetId::from_name("dataset2");
        let mut correlation = DatasetCorrelation::new(primary.clone(), related.clone());
        assert_eq!(correlation.strength, 0.0);
        assert_eq!(correlation.occurrences, 0);
        correlation.update_model(10, &[20, 30, 40]);
        assert!(correlation.strength > 0.0);
        assert_eq!(correlation.occurrences, 1);
        assert!(correlation.index_correlations.contains_key(&10));
        assert_eq!(correlation.index_correlations[&10], vec![20, 30, 40]);
        // 20 and 30 are duplicates and must not be re-appended; 50 is new.
        correlation.update_model(10, &[20, 30, 50]);
        assert!(correlation.strength > 0.1);
        assert_eq!(correlation.occurrences, 2);
        assert_eq!(correlation.index_correlations[&10], vec![20, 30, 40, 50]);
    }

    // Interleaved accesses to two datasets should yield a dataset1→dataset2
    // correlation observable via get_correlations.
    #[test]
    fn test_prefetch_manager() {
        let config = CrossFilePrefetchConfig {
            correlation_threshold: 0.5,
            min_occurrences: 2,
            ..Default::default()
        };
        let mut manager = CrossFilePrefetchManager::new(config);
        let dataset1 = DatasetId::from_name("dataset1");
        let dataset2 = DatasetId::from_name("dataset2");
        for i in 0..5 {
            let access1 = DataAccess {
                dataset: dataset1.clone(),
                index: i,
                access_type: AccessType::Read,
                size: None,
                dimensions: None,
            };
            let access2 = DataAccess {
                dataset: dataset2.clone(),
                index: i * 2,
                access_type: AccessType::Read,
                size: None,
                dimensions: None,
            };
            manager.record_access(access1).expect("Operation failed");
            manager.record_access(access2).expect("Operation failed");
        }
        let correlations = manager.get_correlations(&dataset1);
        assert!(!correlations.is_empty());
        assert_eq!(correlations[0].0, dataset2);
        assert!(correlations[0].1 > 0.0);
    }

    // Test double that records which indices (or whole-dataset requests) it
    // was asked to prefetch.
    struct MockPrefetcher {
        dataset_id: DatasetId,
        prefetched_indices: Arc<Mutex<Vec<usize>>>,
        prefetched_all: Arc<Mutex<bool>>,
    }
    impl MockPrefetcher {
        fn new(datasetid: DatasetId) -> Self {
            Self {
                dataset_id: datasetid,
                prefetched_indices: Arc::new(Mutex::new(Vec::new())),
                prefetched_all: Arc::new(Mutex::new(false)),
            }
        }
    }
    impl DatasetPrefetcher for MockPrefetcher {
        fn prefetch_indices(&self, indices: &[usize]) -> CoreResult<()> {
            let mut prefetched = self.prefetched_indices.lock().expect("Operation failed");
            prefetched.extend_from_slice(indices);
            Ok(())
        }
        fn prefetch_all(&self) -> CoreResult<()> {
            let mut prefetched_all = self.prefetched_all.lock().expect("Operation failed");
            *prefetched_all = true;
            Ok(())
        }
        fn get_dataset_id(&self) -> DatasetId {
            self.dataset_id.clone()
        }
    }

    // End to end: after learning (dataset1 index i → dataset2 index 2i), an
    // access to dataset1 index 10 should trigger a prefetch of dataset2
    // index 20 via the linear extrapolation in predict_from_pattern.
    #[test]
    fn test_cross_file_prefetching() {
        let config = CrossFilePrefetchConfig {
            correlation_threshold: 0.01,
            min_occurrences: 1,
            max_prefetch_datasets: 5,
            ..Default::default()
        };
        let mut manager = CrossFilePrefetchManager::new(config);
        let dataset1 = DatasetId::from_name("dataset1");
        let dataset2 = DatasetId::from_name("dataset2");
        let prefetcher1 = Arc::new(MockPrefetcher::new(dataset1.clone()));
        let prefetcher2 = Arc::new(MockPrefetcher::new(dataset2.clone()));
        manager.register_dataset(dataset1.clone(), prefetcher1.clone());
        manager.register_dataset(dataset2.clone(), prefetcher2.clone());
        for i in 0..3 {
            let access1 = DataAccess {
                dataset: dataset1.clone(),
                index: i,
                access_type: AccessType::Read,
                size: None,
                dimensions: None,
            };
            let access2 = DataAccess {
                dataset: dataset2.clone(),
                index: i * 2,
                access_type: AccessType::Read,
                size: None,
                dimensions: None,
            };
            manager.record_access(access1).expect("Operation failed");
            manager.record_access(access2).expect("Operation failed");
        }
        let access = DataAccess {
            dataset: dataset1.clone(),
            index: 10,
            access_type: AccessType::Read,
            size: None,
            dimensions: None,
        };
        manager.record_access(access).expect("Operation failed");
        let prefetched = prefetcher2
            .prefetched_indices
            .lock()
            .expect("Operation failed");
        assert!(!prefetched.is_empty());
        assert!(prefetched.contains(&20));
    }
}