#![allow(dead_code)]
use crate::concurrent_document_text_storage::{ConcurrentDocumentTextStorage, ConcurrentStorageConfig};
use crate::document_text_storage::DocumentTextStorage;
use crate::error::ShardexError;
use crate::identifiers::DocumentId;
use parking_lot::Mutex;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use tokio::sync::{RwLock, Semaphore};
use tokio::time::timeout;
/// One recorded access to a document's text: which document, when, and where
/// it falls in the global access sequence.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct AccessEntry {
    /// Document that was accessed.
    document_id: DocumentId,
    /// Wall-clock time of the access; used for temporal-window filtering.
    timestamp: SystemTime,
    /// Monotonic position in the overall access sequence.
    #[allow(dead_code)] sequence_position: usize,
}
impl AccessEntry {
    /// Builds an entry stamped with the current system time.
    #[allow(dead_code)]
    fn new(document_id: DocumentId, sequence_position: usize) -> Self {
        let timestamp = SystemTime::now();
        Self {
            document_id,
            timestamp,
            sequence_position,
        }
    }

    /// How long ago this access happened; treated as zero if the system
    /// clock moved backwards since the entry was created.
    #[allow(dead_code)]
    fn age(&self) -> Duration {
        match self.timestamp.elapsed() {
            Ok(elapsed) => elapsed,
            Err(_) => Duration::ZERO,
        }
    }
}
/// Tracks how strongly pairs of documents tend to be accessed together.
#[derive(Debug, Default)]
#[allow(dead_code)]
struct CooccurrenceMap {
    /// doc -> (partner doc -> accumulated association strength).
    patterns: HashMap<DocumentId, HashMap<DocumentId, f64>>,
    /// Cap on partners kept per document; the weakest is evicted past this.
    max_patterns_per_document: usize,
}
impl CooccurrenceMap {
#[allow(dead_code)]
fn new(max_patterns_per_document: usize) -> Self {
Self {
patterns: HashMap::new(),
max_patterns_per_document,
}
}
#[allow(dead_code)]
fn record_cooccurrence(&mut self, doc1: DocumentId, doc2: DocumentId, weight: f64) {
if doc1 == doc2 {
return;
}
let entry = self.patterns.entry(doc1).or_default();
*entry.entry(doc2).or_default() += weight;
if entry.len() > self.max_patterns_per_document {
if let Some((&weakest_doc, _)) = entry.iter().min_by(|a, b| a.1.partial_cmp(b.1).unwrap()) {
entry.remove(&weakest_doc);
}
}
}
#[allow(dead_code)]
fn get_predicted_documents(&self, document_id: DocumentId, limit: usize) -> Vec<(DocumentId, f64)> {
self.patterns
.get(&document_id)
.map(|patterns| {
let mut sorted: Vec<_> = patterns.iter().map(|(&id, &score)| (id, score)).collect();
sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
sorted.truncate(limit);
sorted
})
.unwrap_or_default()
}
#[allow(dead_code)]
fn cleanup(&mut self, min_strength: f64) {
self.patterns.retain(|_, patterns| {
patterns.retain(|_, &mut strength| strength >= min_strength);
!patterns.is_empty()
});
}
}
/// Records recent document accesses and derives predictions about which
/// documents are likely to be read next.
#[derive(Debug)]
#[allow(dead_code)]
struct AccessPatternTracker {
    /// Chronological access log, oldest at the front.
    access_history: VecDeque<AccessEntry>,
    /// Hard cap on the history length; oldest entries are evicted past it.
    max_history_size: usize,
    /// Only accesses within this window contribute to co-occurrence links.
    temporal_window: Duration,
    /// Pairwise association strengths learned from the history.
    cooccurrence: CooccurrenceMap,
    /// Monotonic counter assigning a sequence position to each access.
    sequence_counter: usize,
}
impl AccessPatternTracker {
    /// Creates a tracker with a bounded history, a temporal window for
    /// co-occurrence correlation, and a capped co-occurrence map.
    #[allow(dead_code)]
    fn new(max_history_size: usize, temporal_window: Duration, max_cooccurrence_patterns: usize) -> Self {
        Self {
            access_history: VecDeque::with_capacity(max_history_size),
            max_history_size,
            temporal_window,
            cooccurrence: CooccurrenceMap::new(max_cooccurrence_patterns),
            sequence_counter: 0,
        }
    }

    /// Appends an access to the history (evicting the oldest beyond
    /// capacity) and strengthens co-occurrence links between this document
    /// and up to the 10 most recent accesses inside the temporal window.
    #[allow(dead_code)]
    fn record_access(&mut self, document_id: DocumentId) {
        let entry = AccessEntry::new(document_id, self.sequence_counter);
        self.access_history.push_back(entry);
        self.sequence_counter += 1;
        while self.access_history.len() > self.max_history_size {
            self.access_history.pop_front();
        }
        // NOTE(review): the window below includes the entry just pushed; the
        // resulting self-pair is discarded by record_cooccurrence's
        // doc1 == doc2 guard, so it is only wasted work, not a wrong link.
        let recent_accesses: Vec<_> = self.access_history
            .iter()
            .rev()
            .take(10)
            .filter(|entry| entry.age() <= self.temporal_window)
            .map(|entry| entry.document_id)
            .collect();
        for other_doc in recent_accesses {
            // Weight decays with how far the partner's most recent access
            // sits from the end of the history.
            let distance = if let Some(pos) = self
                .access_history
                .iter()
                .rposition(|e| e.document_id == other_doc)
            {
                self.access_history.len() - pos
            } else {
                continue;
            };
            let weight = 1.0 / (distance as f64 + 1.0);
            self.cooccurrence
                .record_cooccurrence(document_id, other_doc, weight);
        }
    }

    /// Predicts up to `limit` distinct documents likely to be accessed after
    /// `current_document`, blending sequence-following evidence (what
    /// historically came immediately next) with co-occurrence strength.
    #[allow(dead_code)]
    fn predict_next_documents(&self, current_document: DocumentId, limit: usize) -> Vec<DocumentId> {
        let mut predictions = Vec::new();
        if let Some(_current_pos) = self
            .access_history
            .iter()
            .rposition(|entry| entry.document_id == current_document)
        {
            let mut sequence_scores: HashMap<DocumentId, f64> = HashMap::new();
            for (i, entry) in self.access_history.iter().enumerate() {
                if entry.document_id == current_document && i + 1 < self.access_history.len() {
                    let next_doc = self.access_history[i + 1].document_id;
                    // More recent occurrences (closer to the history's end)
                    // contribute larger weights.
                    let distance_from_current = (self.access_history.len() - i) as f64;
                    let weight = 1.0 / distance_from_current;
                    *sequence_scores.entry(next_doc).or_default() += weight;
                }
            }
            let mut seq_predictions: Vec<_> = sequence_scores.into_iter().collect();
            seq_predictions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            // Half of the budget goes to sequence-based candidates.
            predictions.extend(
                seq_predictions
                    .into_iter()
                    .take(limit / 2)
                    .map(|(doc, _)| doc),
            );
        }
        let cooccurrence_predictions = self
            .cooccurrence
            .get_predicted_documents(current_document, limit);
        predictions.extend(cooccurrence_predictions.into_iter().map(|(doc, _)| doc));
        predictions.retain(|&doc| doc != current_document);
        // NOTE(review): sort + dedup discards the relevance ordering built
        // above; the returned list is id-ordered — confirm intended.
        predictions.sort();
        predictions.dedup();
        predictions.truncate(limit);
        predictions
    }

    /// Drops history entries older than the temporal window and prunes
    /// co-occurrence links weaker than 0.1.
    #[allow(dead_code)]
    fn cleanup(&mut self) {
        let cutoff_time = SystemTime::now() - self.temporal_window;
        while let Some(entry) = self.access_history.front() {
            if entry.timestamp < cutoff_time {
                self.access_history.pop_front();
            } else {
                break;
            }
        }
        self.cooccurrence.cleanup(0.1);
    }

    /// Returns (history length, number of documents with co-occurrence data).
    #[allow(dead_code)]
    fn get_stats(&self) -> (usize, usize) {
        (self.access_history.len(), self.cooccurrence.patterns.len())
    }
}
/// Tuning knobs for the async document-text storage layer.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct AsyncStorageConfig {
    /// Configuration forwarded to the underlying concurrent storage layer.
    pub concurrent_config: ConcurrentStorageConfig,
    /// Maximum number of documents held in the read-ahead buffer.
    pub read_ahead_buffer_size: usize,
    /// How long a read-ahead entry remains valid after insertion.
    pub read_ahead_ttl: Duration,
    /// Permit count for the semaphore bounding in-flight async operations.
    pub max_concurrent_async_ops: usize,
    /// Timeout applied to individual async read/write operations.
    pub default_timeout: Duration,
    /// Size of the read-ahead window (documents considered per trigger).
    pub read_ahead_window: usize,
    /// Period of the background cleanup task.
    pub cleanup_interval: Duration,
    /// Maximum number of accesses retained by the pattern tracker.
    pub max_access_history: usize,
    /// Temporal window used when correlating accesses for prediction.
    pub prediction_temporal_window: Duration,
    /// Cap on co-occurrence partners tracked per document.
    pub max_cooccurrence_patterns: usize,
    /// Number of documents predicted (and pre-fetched) per access.
    pub prediction_count: usize,
}
impl Default for AsyncStorageConfig {
fn default() -> Self {
Self {
concurrent_config: ConcurrentStorageConfig::default(),
read_ahead_buffer_size: 1000,
read_ahead_ttl: Duration::from_secs(300),
max_concurrent_async_ops: 200,
default_timeout: Duration::from_secs(30),
read_ahead_window: 10,
cleanup_interval: Duration::from_secs(60),
max_access_history: 1000,
prediction_temporal_window: Duration::from_secs(1800), max_cooccurrence_patterns: 50,
prediction_count: 5,
}
}
}
/// A cached document text plus bookkeeping for TTL expiry and usage counting.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ReadAheadEntry {
    /// Document this cached text belongs to.
    #[allow(dead_code)] document_id: DocumentId,
    /// Full document text.
    text: String,
    /// Insertion time; the entry expires `ttl` after this instant.
    created_at: SystemTime,
    /// Number of buffer hits served by this entry.
    access_count: u64,
}
impl ReadAheadEntry {
#[allow(dead_code)]
fn new(document_id: DocumentId, text: String) -> Self {
Self {
document_id,
text,
created_at: SystemTime::now(),
access_count: 0,
}
}
#[allow(dead_code)]
fn is_expired(&self, ttl: Duration) -> bool {
self.created_at.elapsed().unwrap_or(Duration::ZERO) > ttl
}
#[allow(dead_code)]
fn touch(&mut self) {
self.access_count += 1;
}
}
/// Bounded LRU cache with TTL expiry for pre-loaded document text.
#[derive(Debug)]
#[allow(dead_code)]
struct ReadAheadBuffer {
    /// Cached entries keyed by document.
    entries: HashMap<DocumentId, ReadAheadEntry>,
    /// Recency list: index 0 is least recently used.
    access_order: Vec<DocumentId>,
    /// Maximum number of entries retained.
    max_size: usize,
    /// Lifetime of each entry after insertion.
    ttl: Duration,
}
impl ReadAheadBuffer {
    /// Creates an empty LRU buffer holding at most `max_size` entries, each
    /// valid for `ttl` after insertion.
    fn new(max_size: usize, ttl: Duration) -> Self {
        Self {
            entries: HashMap::with_capacity(max_size),
            access_order: Vec::with_capacity(max_size),
            max_size,
            ttl,
        }
    }

    /// Returns the cached text for `document_id` if present and unexpired,
    /// promoting it to most-recently-used. Expired entries are evicted on
    /// the spot and reported as misses.
    fn get(&mut self, document_id: &DocumentId) -> Option<String> {
        if let Some(entry) = self.entries.get_mut(document_id) {
            if !entry.is_expired(self.ttl) {
                entry.touch();
                // Promote to the MRU end of the recency list.
                if let Some(pos) = self.access_order.iter().position(|id| id == document_id) {
                    self.access_order.remove(pos);
                }
                self.access_order.push(*document_id);
                return Some(entry.text.clone());
            } else {
                self.entries.remove(document_id);
                self.access_order.retain(|id| id != document_id);
            }
        }
        None
    }

    /// Inserts or overwrites the cached text for `document_id`, evicting the
    /// least-recently-used entries as needed to respect `max_size`.
    fn put(&mut self, document_id: DocumentId, text: String) {
        // Drop any stale entry for this id before the capacity check. The
        // previous code cleared only the recency slot, so overwriting a key
        // while at capacity evicted an unrelated document even though the
        // net size did not grow.
        if self.entries.remove(&document_id).is_some() {
            self.access_order.retain(|id| id != &document_id);
        }
        while self.entries.len() >= self.max_size {
            if let Some(oldest_id) = self.access_order.first().copied() {
                self.entries.remove(&oldest_id);
                self.access_order.remove(0);
            } else {
                // Size/recency bookkeeping diverged; bail rather than spin.
                break;
            }
        }
        self.entries.insert(document_id, ReadAheadEntry::new(document_id, text));
        self.access_order.push(document_id);
    }

    /// Removes every expired entry and returns how many were dropped.
    fn cleanup_expired(&mut self) -> usize {
        let original_len = self.entries.len();
        let expired_ids: Vec<_> = self
            .entries
            .iter()
            .filter(|(_, entry)| entry.is_expired(self.ttl))
            .map(|(id, _)| *id)
            .collect();
        for id in expired_ids {
            self.entries.remove(&id);
            self.access_order.retain(|entry_id| entry_id != &id);
        }
        original_len - self.entries.len()
    }

    /// Number of currently buffered documents.
    fn len(&self) -> usize {
        self.entries.len()
    }

    /// Maximum number of documents the buffer will hold.
    fn capacity(&self) -> usize {
        self.max_size
    }

    /// Discards every entry along with all recency information.
    fn clear(&mut self) {
        self.entries.clear();
        self.access_order.clear();
    }
}
/// Counters and derived statistics for the async storage layer.
#[derive(Debug, Clone, Default)]
pub struct AsyncStorageMetrics {
    /// Total async read attempts (successes plus failures).
    pub async_reads: u64,
    /// Async reads that completed successfully.
    pub successful_async_reads: u64,
    /// Async reads that failed or timed out.
    pub failed_async_reads: u64,
    /// Total async write attempts (successes plus failures).
    pub async_writes: u64,
    /// Async writes that completed successfully.
    pub successful_async_writes: u64,
    /// Async writes that failed or timed out.
    pub failed_async_writes: u64,
    /// Reads served directly from the read-ahead buffer.
    pub read_ahead_hits: u64,
    /// Reads that fell through to the storage layer.
    pub read_ahead_misses: u64,
    /// Times predictive read-ahead was triggered.
    pub read_ahead_predictions: u64,
    /// Running average latency across all async reads and writes, in ms.
    pub avg_async_latency_ms: f64,
    /// Operations that hit the configured timeout.
    pub timeout_errors: u64,
    /// Iterations executed by the background cleanup task.
    pub background_tasks_executed: u64,
}
impl AsyncStorageMetrics {
    /// Fraction of async reads that succeeded; 0.0 when none were attempted.
    pub fn async_read_success_ratio(&self) -> f64 {
        match self.async_reads {
            0 => 0.0,
            total => self.successful_async_reads as f64 / total as f64,
        }
    }

    /// Fraction of async writes that succeeded; 0.0 when none were attempted.
    pub fn async_write_success_ratio(&self) -> f64 {
        match self.async_writes {
            0 => 0.0,
            total => self.successful_async_writes as f64 / total as f64,
        }
    }

    /// Share of read lookups served from the read-ahead buffer; 0.0 when no
    /// lookups have happened yet.
    pub fn read_ahead_hit_ratio(&self) -> f64 {
        match self.read_ahead_hits + self.read_ahead_misses {
            0 => 0.0,
            lookups => self.read_ahead_hits as f64 / lookups as f64,
        }
    }

    /// Combined count of async read and write operations.
    pub fn total_async_operations(&self) -> u64 {
        self.async_reads + self.async_writes
    }
}
/// Async facade over `ConcurrentDocumentTextStorage` adding bounded
/// concurrency, per-operation timeouts, metrics, and predictive read-ahead
/// caching driven by observed access patterns.
pub struct AsyncDocumentTextStorage {
    /// Underlying concurrent storage engine.
    storage: Arc<ConcurrentDocumentTextStorage>,
    /// LRU + TTL cache of pre-loaded document text.
    read_ahead_buffer: Arc<RwLock<ReadAheadBuffer>>,
    /// Caps the number of concurrently executing async operations.
    async_semaphore: Arc<Semaphore>,
    /// Static configuration captured at construction.
    config: AsyncStorageConfig,
    /// Operation counters and latency statistics.
    metrics: Arc<Mutex<AsyncStorageMetrics>>,
    /// Access history driving read-ahead predictions.
    access_tracker: Arc<RwLock<AccessPatternTracker>>,
    /// Handles for spawned maintenance tasks; aborted on shutdown/drop.
    background_tasks: Arc<Mutex<Vec<tokio::task::JoinHandle<()>>>>,
}
impl AsyncDocumentTextStorage {
    /// Wraps `storage` in the concurrent layer, starts its background write
    /// processor, and spawns this layer's periodic cleanup task.
    ///
    /// # Errors
    /// Propagates failures from starting either background worker.
    pub async fn new(storage: DocumentTextStorage, config: AsyncStorageConfig) -> Result<Self, ShardexError> {
        let concurrent_storage = ConcurrentDocumentTextStorage::new(storage, config.concurrent_config.clone());
        concurrent_storage.start_background_processor().await?;
        let read_ahead_buffer = Arc::new(RwLock::new(ReadAheadBuffer::new(
            config.read_ahead_buffer_size,
            config.read_ahead_ttl,
        )));
        let access_tracker = Arc::new(RwLock::new(AccessPatternTracker::new(
            config.max_access_history,
            config.prediction_temporal_window,
            config.max_cooccurrence_patterns,
        )));
        let async_storage = Self {
            storage: Arc::new(concurrent_storage),
            read_ahead_buffer,
            async_semaphore: Arc::new(Semaphore::new(config.max_concurrent_async_ops)),
            config,
            metrics: Arc::new(Mutex::new(AsyncStorageMetrics::default())),
            access_tracker,
            background_tasks: Arc::new(Mutex::new(Vec::new())),
        };
        async_storage.start_background_cleanup().await?;
        Ok(async_storage)
    }

    /// Spawns the periodic maintenance loop: expires read-ahead entries,
    /// prunes stale access-pattern state, and counts each iteration in the
    /// metrics. The task runs until aborted by shutdown/drop.
    async fn start_background_cleanup(&self) -> Result<(), ShardexError> {
        let read_ahead_buffer = Arc::clone(&self.read_ahead_buffer);
        let access_tracker = Arc::clone(&self.access_tracker);
        let metrics = Arc::clone(&self.metrics);
        let cleanup_interval = self.config.cleanup_interval;
        let cleanup_task = tokio::spawn(async move {
            let mut interval = tokio::time::interval(cleanup_interval);
            loop {
                interval.tick().await;
                // Expire stale read-ahead entries.
                let expired_count = {
                    let mut buffer = read_ahead_buffer.write().await;
                    buffer.cleanup_expired()
                };
                if expired_count > 0 {
                    log::debug!("Cleaned up {} expired read-ahead entries", expired_count);
                }
                // Prune access history outside the temporal window plus weak
                // co-occurrence links.
                {
                    let mut tracker = access_tracker.write().await;
                    tracker.cleanup();
                    let (history_size, pattern_count) = tracker.get_stats();
                    log::trace!(
                        "Access tracker stats: {} history entries, {} pattern groups",
                        history_size,
                        pattern_count
                    );
                }
                {
                    let mut metrics_guard = metrics.lock();
                    metrics_guard.background_tasks_executed += 1;
                }
            }
        });
        let mut tasks = self.background_tasks.lock();
        tasks.push(cleanup_task);
        Ok(())
    }

    /// Reads a document's full text asynchronously.
    ///
    /// Checks the read-ahead buffer first; on a miss, delegates to the
    /// concurrent storage layer under `default_timeout`. Successful reads
    /// update access patterns, populate the buffer, and trigger predictive
    /// read-ahead for likely-next documents.
    ///
    /// # Errors
    /// Returns the underlying storage error, or an `InvalidInput` error when
    /// the concurrency permit cannot be acquired or the operation times out.
    pub async fn get_text_async(&self, document_id: DocumentId) -> Result<String, ShardexError> {
        // Bound the number of in-flight async operations.
        let _permit = self
            .async_semaphore
            .acquire()
            .await
            .map_err(|_| ShardexError::InvalidInput {
                field: "async_semaphore".to_string(),
                reason: "Failed to acquire async semaphore permit".to_string(),
                suggestion: "Retry the operation".to_string(),
            })?;
        let start_time = Instant::now();
        // Fast path: serve from the read-ahead buffer.
        {
            let mut buffer = self.read_ahead_buffer.write().await;
            if let Some(text) = buffer.get(&document_id) {
                // Release the buffer lock before awaiting other locks below.
                drop(buffer);
                {
                    let mut tracker = self.access_tracker.write().await;
                    tracker.record_access(document_id);
                }
                self.trigger_read_ahead(document_id).await;
                self.record_read_ahead_hit();
                self.record_async_read_success(start_time.elapsed().as_millis() as f64);
                return Ok(text);
            }
        }
        self.record_read_ahead_miss();
        // Slow path: fetch from concurrent storage with a timeout.
        let result = timeout(
            self.config.default_timeout,
            self.storage.get_text_concurrent(document_id),
        )
        .await;
        match result {
            Ok(Ok(text)) => {
                {
                    let mut tracker = self.access_tracker.write().await;
                    tracker.record_access(document_id);
                }
                // Cache the fetched text for future hits.
                {
                    let mut buffer = self.read_ahead_buffer.write().await;
                    buffer.put(document_id, text.clone());
                }
                self.trigger_read_ahead(document_id).await;
                self.record_async_read_success(start_time.elapsed().as_millis() as f64);
                Ok(text)
            }
            Ok(Err(e)) => {
                self.record_async_read_failure(start_time.elapsed().as_millis() as f64);
                Err(e)
            }
            Err(_) => {
                self.record_timeout_error();
                self.record_async_read_failure(start_time.elapsed().as_millis() as f64);
                Err(ShardexError::InvalidInput {
                    field: "async_operation".to_string(),
                    reason: "Async operation timed out".to_string(),
                    suggestion: "Increase timeout or check storage performance".to_string(),
                })
            }
        }
    }

    /// Stores a document's text via the batched write path, then mirrors it
    /// into the read-ahead buffer so subsequent reads hit the cache.
    ///
    /// # Errors
    /// Returns the underlying storage error, or an `InvalidInput` error when
    /// the concurrency permit cannot be acquired or the operation times out.
    pub async fn store_text_async(&self, document_id: DocumentId, text: String) -> Result<(), ShardexError> {
        let _permit = self
            .async_semaphore
            .acquire()
            .await
            .map_err(|_| ShardexError::InvalidInput {
                field: "async_semaphore".to_string(),
                reason: "Failed to acquire async semaphore permit".to_string(),
                suggestion: "Retry the operation".to_string(),
            })?;
        let start_time = Instant::now();
        let result = timeout(
            self.config.default_timeout,
            self.storage.store_text_batched(document_id, text.clone()),
        )
        .await;
        match result {
            Ok(Ok(())) => {
                // Keep the read-ahead buffer coherent with the new content.
                {
                    let mut buffer = self.read_ahead_buffer.write().await;
                    buffer.put(document_id, text);
                }
                self.record_async_write_success(start_time.elapsed().as_millis() as f64);
                Ok(())
            }
            Ok(Err(e)) => {
                self.record_async_write_failure(start_time.elapsed().as_millis() as f64);
                Err(e)
            }
            Err(_) => {
                self.record_timeout_error();
                self.record_async_write_failure(start_time.elapsed().as_millis() as f64);
                Err(ShardexError::InvalidInput {
                    field: "async_operation".to_string(),
                    reason: "Async operation timed out".to_string(),
                    suggestion: "Increase timeout or check storage performance".to_string(),
                })
            }
        }
    }

    /// Stores many documents concurrently, one spawned task per document,
    /// returning per-document results in input order.
    ///
    /// NOTE(review): this path does not acquire the async semaphore or apply
    /// the per-operation timeout, so a large batch can exceed
    /// `max_concurrent_async_ops` — confirm intended.
    pub async fn store_texts_batch_async(
        &self,
        documents: Vec<(DocumentId, String)>,
    ) -> Result<Vec<Result<(), ShardexError>>, ShardexError> {
        let batch_size = documents.len();
        let start_time = Instant::now();
        let mut tasks = Vec::new();
        for (doc_id, text) in documents {
            let storage = Arc::clone(&self.storage);
            let task = tokio::spawn(async move { storage.store_text_batched(doc_id, text).await });
            tasks.push(task);
        }
        let mut results = Vec::new();
        let mut successful_count = 0;
        for task in tasks {
            match task.await {
                Ok(Ok(())) => {
                    results.push(Ok(()));
                    successful_count += 1;
                }
                Ok(Err(e)) => {
                    results.push(Err(e));
                }
                Err(_) => {
                    // JoinError: the spawned task panicked or was cancelled.
                    results.push(Err(ShardexError::InvalidInput {
                        field: "batch_task".to_string(),
                        reason: "Batch task was cancelled".to_string(),
                        suggestion: "Retry the batch operation".to_string(),
                    }));
                }
            }
        }
        // Attribute the average per-document latency to each recorded op.
        // NOTE(review): an empty batch makes this NaN, but both loops below
        // then run zero times, so nothing bogus is recorded.
        let avg_latency = start_time.elapsed().as_millis() as f64 / batch_size as f64;
        for _ in 0..successful_count {
            self.record_async_write_success(avg_latency);
        }
        for _ in 0..(batch_size - successful_count) {
            self.record_async_write_failure(avg_latency);
        }
        Ok(results)
    }

    /// Extracts `length` units of text starting at `start` from a document,
    /// delegating to the concurrent layer under the default timeout.
    /// Does not consult the read-ahead buffer or update access patterns.
    ///
    /// # Errors
    /// Returns the underlying storage error, or an `InvalidInput` error when
    /// the concurrency permit cannot be acquired or the operation times out.
    pub async fn extract_text_substring_async(
        &self,
        document_id: DocumentId,
        start: u32,
        length: u32,
    ) -> Result<String, ShardexError> {
        let _permit = self
            .async_semaphore
            .acquire()
            .await
            .map_err(|_| ShardexError::InvalidInput {
                field: "async_semaphore".to_string(),
                reason: "Failed to acquire async semaphore permit".to_string(),
                suggestion: "Retry the operation".to_string(),
            })?;
        let start_time = Instant::now();
        let result = timeout(
            self.config.default_timeout,
            self.storage
                .extract_text_substring_concurrent(document_id, start, length),
        )
        .await;
        match result {
            Ok(Ok(text)) => {
                self.record_async_read_success(start_time.elapsed().as_millis() as f64);
                Ok(text)
            }
            Ok(Err(e)) => {
                self.record_async_read_failure(start_time.elapsed().as_millis() as f64);
                Err(e)
            }
            Err(_) => {
                self.record_timeout_error();
                self.record_async_read_failure(start_time.elapsed().as_millis() as f64);
                Err(ShardexError::InvalidInput {
                    field: "async_operation".to_string(),
                    reason: "Async operation timed out".to_string(),
                    suggestion: "Increase timeout or check storage performance".to_string(),
                })
            }
        }
    }

    /// Fire-and-forget predictive pre-loading: asks the access tracker which
    /// documents are likely next and spawns a background fetch for each one
    /// not already buffered. Fetch failures are logged, never surfaced.
    async fn trigger_read_ahead(&self, document_id: DocumentId) {
        self.record_read_ahead_prediction();
        let predicted_documents = {
            let tracker = self.access_tracker.read().await;
            tracker.predict_next_documents(document_id, self.config.prediction_count)
        };
        if predicted_documents.is_empty() {
            log::trace!("No predictions available for document {}", document_id);
            return;
        }
        let prediction_count = predicted_documents.len();
        log::trace!(
            "Predicting {} documents for read-ahead: {:?}",
            prediction_count,
            predicted_documents
        );
        for predicted_id in predicted_documents {
            // Skip documents that are already buffered.
            {
                let buffer = self.read_ahead_buffer.read().await;
                if buffer.entries.contains_key(&predicted_id) {
                    log::trace!("Document {} already in read-ahead buffer, skipping", predicted_id);
                    continue;
                }
            }
            let storage = Arc::clone(&self.storage);
            let read_ahead_buffer = Arc::clone(&self.read_ahead_buffer);
            let predicted_doc_id = predicted_id;
            // Pre-load in the background; the caller is never blocked.
            tokio::spawn(async move {
                match storage.get_text_concurrent(predicted_doc_id).await {
                    Ok(text) => {
                        let mut buffer = read_ahead_buffer.write().await;
                        buffer.put(predicted_doc_id, text);
                        log::trace!("Pre-loaded document {} into read-ahead buffer", predicted_doc_id);
                    }
                    Err(e) => {
                        log::debug!(
                            "Failed to pre-load document {} for read-ahead: {:?}",
                            predicted_doc_id,
                            e
                        );
                    }
                }
            });
        }
        log::trace!("Triggered read-ahead prediction for {} documents", prediction_count);
    }

    /// Eagerly loads the given documents into the read-ahead buffer.
    /// Individual load failures are logged and skipped; always returns Ok.
    pub async fn warm_read_ahead_buffer(&self, document_ids: Vec<DocumentId>) -> Result<(), ShardexError> {
        for document_id in document_ids {
            match self.storage.get_text_concurrent(document_id).await {
                Ok(text) => {
                    let mut buffer = self.read_ahead_buffer.write().await;
                    buffer.put(document_id, text);
                }
                Err(e) => {
                    log::warn!("Failed to warm read-ahead buffer for document {}: {:?}", document_id, e);
                }
            }
        }
        Ok(())
    }

    /// Orderly shutdown: aborts maintenance tasks, flushes the pending write
    /// queue, stops the background processor, and clears the cache.
    ///
    /// # Errors
    /// Propagates failures from flushing or stopping the concurrent layer.
    pub async fn shutdown(&self) -> Result<(), ShardexError> {
        {
            let mut tasks = self.background_tasks.lock();
            for task in tasks.drain(..) {
                task.abort();
            }
        }
        self.storage.flush_write_queue().await?;
        self.storage.stop_background_processor().await?;
        {
            let mut buffer = self.read_ahead_buffer.write().await;
            buffer.clear();
        }
        Ok(())
    }

    /// Returns a snapshot of the current metrics.
    pub fn get_metrics(&self) -> AsyncStorageMetrics {
        let metrics = self.metrics.lock();
        metrics.clone()
    }

    /// Returns (current size, capacity) of the read-ahead buffer.
    pub async fn read_ahead_info(&self) -> (usize, usize) {
        let buffer = self.read_ahead_buffer.read().await;
        (buffer.len(), buffer.capacity())
    }

    /// Empties the read-ahead buffer.
    pub async fn clear_read_ahead_buffer(&self) {
        let mut buffer = self.read_ahead_buffer.write().await;
        buffer.clear();
    }

    /// Records a successful async read and folds its latency into the average.
    fn record_async_read_success(&self, latency_ms: f64) {
        let mut metrics = self.metrics.lock();
        metrics.async_reads += 1;
        metrics.successful_async_reads += 1;
        self.update_avg_latency(&mut metrics, latency_ms);
    }

    /// Records a failed async read and folds its latency into the average.
    fn record_async_read_failure(&self, latency_ms: f64) {
        let mut metrics = self.metrics.lock();
        metrics.async_reads += 1;
        metrics.failed_async_reads += 1;
        self.update_avg_latency(&mut metrics, latency_ms);
    }

    /// Records a successful async write and folds its latency into the average.
    fn record_async_write_success(&self, latency_ms: f64) {
        let mut metrics = self.metrics.lock();
        metrics.async_writes += 1;
        metrics.successful_async_writes += 1;
        self.update_avg_latency(&mut metrics, latency_ms);
    }

    /// Records a failed async write and folds its latency into the average.
    fn record_async_write_failure(&self, latency_ms: f64) {
        let mut metrics = self.metrics.lock();
        metrics.async_writes += 1;
        metrics.failed_async_writes += 1;
        self.update_avg_latency(&mut metrics, latency_ms);
    }

    /// Counts a read served from the read-ahead buffer.
    fn record_read_ahead_hit(&self) {
        let mut metrics = self.metrics.lock();
        metrics.read_ahead_hits += 1;
    }

    /// Counts a read that missed the read-ahead buffer.
    fn record_read_ahead_miss(&self) {
        let mut metrics = self.metrics.lock();
        metrics.read_ahead_misses += 1;
    }

    /// Counts a predictive read-ahead trigger.
    fn record_read_ahead_prediction(&self) {
        let mut metrics = self.metrics.lock();
        metrics.read_ahead_predictions += 1;
    }

    /// Counts an operation that hit the configured timeout.
    fn record_timeout_error(&self) {
        let mut metrics = self.metrics.lock();
        metrics.timeout_errors += 1;
    }

    /// Incrementally updates the running average latency. Reads and writes
    /// share a single combined average; callers increment the relevant
    /// counter before invoking this, so `total_ops` is always >= 1.
    fn update_avg_latency(&self, metrics: &mut AsyncStorageMetrics, latency_ms: f64) {
        let total_ops = metrics.total_async_operations();
        if total_ops == 1 {
            metrics.avg_async_latency_ms = latency_ms;
        } else {
            metrics.avg_async_latency_ms =
                ((metrics.avg_async_latency_ms * (total_ops - 1) as f64) + latency_ms) / total_ops as f64;
        }
    }
}
impl Drop for AsyncDocumentTextStorage {
    /// Best-effort cleanup: aborts any remaining background maintenance
    /// tasks, but only when a Tokio runtime is still available to run the
    /// abort. Outside a runtime the tasks are simply left to the executor.
    fn drop(&mut self) {
        let handle = match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle,
            Err(_) => return,
        };
        let background_tasks = Arc::clone(&self.background_tasks);
        handle.spawn(async move {
            for task in background_tasks.lock().drain(..) {
                task.abort();
            }
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::document_text_storage::DocumentTextStorage;
    use tempfile::TempDir;

    // Construction succeeds and all metrics start at zero.
    #[tokio::test]
    async fn test_async_storage_creation() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig::default();
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let metrics = async_storage.get_metrics();
        assert_eq!(metrics.async_reads, 0);
        assert_eq!(metrics.async_writes, 0);
        async_storage.shutdown().await.unwrap();
    }

    // A stored document round-trips through the async read path, and the
    // success counters reflect exactly one read and one write.
    #[tokio::test]
    async fn test_async_read_write() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig::default();
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let doc_id = DocumentId::new();
        let text = "Async test document content";
        async_storage
            .store_text_async(doc_id, text.to_string())
            .await
            .unwrap();
        let retrieved = async_storage.get_text_async(doc_id).await.unwrap();
        assert_eq!(retrieved, text);
        let metrics = async_storage.get_metrics();
        assert_eq!(metrics.async_reads, 1);
        assert_eq!(metrics.async_writes, 1);
        assert_eq!(metrics.successful_async_reads, 1);
        assert_eq!(metrics.successful_async_writes, 1);
        async_storage.shutdown().await.unwrap();
    }

    // Stores populate the buffer, so subsequent reads register cache hits.
    #[tokio::test]
    async fn test_read_ahead_buffer() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig::default();
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let doc_id = DocumentId::new();
        let text = "Read-ahead test content";
        async_storage
            .store_text_async(doc_id, text.to_string())
            .await
            .unwrap();
        let _ = async_storage.get_text_async(doc_id).await.unwrap();
        let retrieved = async_storage.get_text_async(doc_id).await.unwrap();
        assert_eq!(retrieved, text);
        let metrics = async_storage.get_metrics();
        assert!(metrics.read_ahead_hits > 0);
        async_storage.shutdown().await.unwrap();
    }

    // Batch store returns one Ok per document and every document is
    // subsequently readable.
    #[tokio::test]
    async fn test_batch_async_operations() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig::default();
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let documents = vec![
            (DocumentId::new(), "Document 1".to_string()),
            (DocumentId::new(), "Document 2".to_string()),
            (DocumentId::new(), "Document 3".to_string()),
        ];
        let _doc_ids: Vec<_> = documents.iter().map(|(id, _)| *id).collect();
        let results = async_storage
            .store_texts_batch_async(documents.clone())
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for result in results {
            assert!(result.is_ok());
        }
        for (doc_id, expected_text) in documents {
            let retrieved = async_storage.get_text_async(doc_id).await.unwrap();
            assert_eq!(retrieved, expected_text);
        }
        async_storage.shutdown().await.unwrap();
    }

    // Direct LRU semantics: touching doc2 protects it, so inserting doc4 at
    // capacity evicts the least-recently-used doc1.
    #[test]
    fn test_read_ahead_buffer_functionality() {
        let mut buffer = ReadAheadBuffer::new(3, Duration::from_secs(60));
        let doc1 = DocumentId::new();
        let doc2 = DocumentId::new();
        let doc3 = DocumentId::new();
        let doc4 = DocumentId::new();
        buffer.put(doc1, "Text 1".to_string());
        buffer.put(doc2, "Text 2".to_string());
        buffer.put(doc3, "Text 3".to_string());
        assert_eq!(buffer.len(), 3);
        let text = buffer.get(&doc2);
        assert_eq!(text, Some("Text 2".to_string()));
        buffer.put(doc4, "Text 4".to_string());
        assert_eq!(buffer.len(), 3);
        assert!(buffer.get(&doc1).is_none());
        assert!(buffer.get(&doc4).is_some());
    }

    // Derived ratios are computed from the raw counters as expected.
    #[test]
    fn test_async_metrics_calculations() {
        let metrics = AsyncStorageMetrics {
            successful_async_reads: 80,
            async_reads: 100,
            successful_async_writes: 90,
            async_writes: 100,
            read_ahead_hits: 70,
            read_ahead_misses: 30,
            ..Default::default()
        };
        assert_eq!(metrics.async_read_success_ratio(), 0.8);
        assert_eq!(metrics.async_write_success_ratio(), 0.9);
        assert_eq!(metrics.read_ahead_hit_ratio(), 0.7);
        assert_eq!(metrics.total_async_operations(), 200);
    }

    // Recording a few accesses yields non-empty predictions and stats.
    #[test]
    fn test_access_pattern_tracker() {
        let mut tracker = AccessPatternTracker::new(100, Duration::from_secs(60), 50);
        let doc1 = DocumentId::new();
        let doc2 = DocumentId::new();
        let doc3 = DocumentId::new();
        tracker.record_access(doc1);
        tracker.record_access(doc2);
        tracker.record_access(doc3);
        let predictions = tracker.predict_next_documents(doc1, 5);
        assert!(!predictions.is_empty());
        let (history_size, pattern_count) = tracker.get_stats();
        assert_eq!(history_size, 3);
        assert!(pattern_count > 0);
    }

    // Predictions are ranked by strength, and cleanup drops weak links.
    #[test]
    fn test_cooccurrence_map() {
        let mut cooccur = CooccurrenceMap::new(10);
        let doc1 = DocumentId::new();
        let doc2 = DocumentId::new();
        let doc3 = DocumentId::new();
        cooccur.record_cooccurrence(doc1, doc2, 1.0);
        cooccur.record_cooccurrence(doc1, doc3, 0.5);
        cooccur.record_cooccurrence(doc2, doc3, 0.8);
        let predictions = cooccur.get_predicted_documents(doc1, 2);
        assert_eq!(predictions.len(), 2);
        assert_eq!(predictions[0].0, doc2);
        assert_eq!(predictions[0].1, 1.0);
        assert_eq!(predictions[1].0, doc3);
        assert_eq!(predictions[1].1, 0.5);
        cooccur.cleanup(0.6);
        let predictions_after_cleanup = cooccur.get_predicted_documents(doc1, 2);
        assert_eq!(predictions_after_cleanup.len(), 1);
    }

    // A repeated access cycle produces predictions from every position.
    #[test]
    fn test_access_pattern_sequence_prediction() {
        let mut tracker = AccessPatternTracker::new(100, Duration::from_secs(300), 50);
        let docs = [
            DocumentId::new(),
            DocumentId::new(),
            DocumentId::new(),
            DocumentId::new(),
        ];
        for _ in 0..3 {
            for doc in &docs[0..3] {
                tracker.record_access(*doc);
            }
        }
        let predictions_from_doc1 = tracker.predict_next_documents(docs[0], 3);
        assert!(!predictions_from_doc1.is_empty());
        let predictions_from_doc2 = tracker.predict_next_documents(docs[1], 3);
        assert!(!predictions_from_doc2.is_empty());
    }

    // A successful read triggers at least one read-ahead prediction.
    #[tokio::test]
    async fn test_read_ahead_prediction_integration() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig::default();
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let doc_id = DocumentId::new();
        let text = "Test document for prediction".to_string();
        async_storage
            .store_text_async(doc_id, text.clone())
            .await
            .unwrap();
        let _ = async_storage.get_text_async(doc_id).await.unwrap();
        let metrics = async_storage.get_metrics();
        assert!(
            metrics.read_ahead_predictions > 0,
            "Expected at least one read-ahead prediction to be triggered"
        );
        async_storage.shutdown().await.unwrap();
    }

    // With a tiny temporal window and fast cleanup interval, the background
    // task prunes the access history shortly after it goes stale.
    #[tokio::test]
    async fn test_access_pattern_cleanup() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig {
            prediction_temporal_window: Duration::from_millis(50),
            cleanup_interval: Duration::from_millis(100),
            ..AsyncStorageConfig::default()
        };
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let doc_id = DocumentId::new();
        let text = "Test document for cleanup".to_string();
        async_storage.store_text_async(doc_id, text).await.unwrap();
        let _ = async_storage.get_text_async(doc_id).await.unwrap();
        tokio::time::sleep(Duration::from_millis(200)).await;
        let (history_size, _) = {
            let tracker = async_storage.access_tracker.read().await;
            tracker.get_stats()
        };
        assert!(history_size <= 1);
        async_storage.shutdown().await.unwrap();
    }

    // 200 deterministic pseudo-random reads over 100 documents finish well
    // within the time budget and trigger read-ahead predictions.
    #[tokio::test]
    async fn test_prediction_performance_with_many_documents() {
        let temp_dir = TempDir::new().unwrap();
        let storage = DocumentTextStorage::create(&temp_dir, 1024 * 1024).unwrap();
        let config = AsyncStorageConfig {
            max_access_history: 500,
            prediction_count: 10,
            ..AsyncStorageConfig::default()
        };
        let async_storage = AsyncDocumentTextStorage::new(storage, config)
            .await
            .unwrap();
        let num_docs = 100;
        let mut doc_ids = Vec::new();
        for i in 0..num_docs {
            let doc_id = DocumentId::new();
            let text = format!("Document number {}", i);
            doc_ids.push(doc_id);
            async_storage.store_text_async(doc_id, text).await.unwrap();
        }
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let start_time = std::time::Instant::now();
        for i in 0..200 {
            // Hash the loop counter to get a repeatable access pattern.
            let mut hasher = DefaultHasher::new();
            i.hash(&mut hasher);
            let index = (hasher.finish() as usize) % num_docs;
            let _ = async_storage.get_text_async(doc_ids[index]).await.unwrap();
        }
        let access_time = start_time.elapsed();
        assert!(access_time.as_millis() < 5000);
        let metrics = async_storage.get_metrics();
        assert!(
            metrics.read_ahead_predictions > 0,
            "Expected predictions with {} reads",
            metrics.async_reads
        );
        assert_eq!(metrics.async_reads, 200);
        async_storage.shutdown().await.unwrap();
    }
}