use datafusion::prelude::SessionContext;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::prelude::*;
use crate::test_utils::{create_tpc_h_context as create_tpc_h_context_uncached, ScaleFactor};
// Map from cache key (built as `format!("tpc_h_{scale:?}")`) to a shared TPC-H context.
type ContextCacheMap = HashMap<String, Arc<SessionContext>>;
// The map behind an async `RwLock` so concurrent readers don't contend, wrapped in
// `Arc` for cheap sharing.
type ContextCache = Arc<RwLock<ContextCacheMap>>;
// Process-wide, lazily initialized cache of TPC-H `SessionContext`s, one per scale factor.
static CONTEXT_CACHE: Lazy<ContextCache> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Returns a cached TPC-H [`SessionContext`] for `scale`, creating and caching
/// one on first use.
///
/// Uses double-checked locking: a cheap shared read-lock fast path first, then
/// an exclusive write lock under which the cache is re-checked (via the
/// `HashMap` entry API, so the re-check and the insert cost a single hash
/// lookup) so that concurrent callers racing on a miss build the context only
/// once. The write lock is deliberately held across the (potentially slow)
/// context creation — that is what guarantees a single creation per scale.
///
/// # Errors
/// Propagates any error from `create_tpc_h_context_uncached`.
pub async fn get_or_create_tpc_h_context(scale: ScaleFactor) -> Result<Arc<SessionContext>> {
    use std::collections::hash_map::Entry;

    let cache_key = format!("tpc_h_{scale:?}");

    // Fast path: shared read lock; concurrent cache hits don't block each other.
    {
        let cache = CONTEXT_CACHE.read().await;
        if let Some(ctx) = cache.get(&cache_key) {
            tracing::debug!("Using cached TPC-H context for scale {:?}", scale);
            return Ok(Arc::clone(ctx));
        }
    }

    // Slow path: exclusive lock. The entry API re-checks for a context that a
    // racing task may have inserted between our read and write locks, and
    // avoids a second hash lookup when we insert.
    let mut cache = CONTEXT_CACHE.write().await;
    match cache.entry(cache_key) {
        Entry::Occupied(entry) => {
            tracing::debug!(
                "Using cached TPC-H context for scale {:?} (created by another thread)",
                scale
            );
            Ok(Arc::clone(entry.get()))
        }
        Entry::Vacant(entry) => {
            tracing::info!("Creating new TPC-H context for scale {:?}", scale);
            let ctx = Arc::new(create_tpc_h_context_uncached(scale).await?);
            entry.insert(Arc::clone(&ctx));
            Ok(ctx)
        }
    }
}
/// Drops every cached TPC-H context, forcing the next lookup to rebuild.
#[allow(dead_code)]
pub async fn clear_context_cache() {
    CONTEXT_CACHE.write().await.clear();
    tracing::info!("Cleared TPC-H context cache");
}
/// Number of distinct scale factors currently held in the cache.
#[allow(dead_code)]
pub async fn cache_size() -> usize {
    CONTEXT_CACHE.read().await.len()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A second request for the same scale must hit the cache and yield the
    /// identical `Arc`; a different scale must add a new entry.
    #[tokio::test]
    async fn test_context_caching() -> Result<()> {
        let first = get_or_create_tpc_h_context(ScaleFactor::SF01).await?;
        let baseline = cache_size().await;
        assert!(baseline > 0, "Cache should contain at least one entry");

        // Same scale: no growth, and pointer-identical context.
        let second = get_or_create_tpc_h_context(ScaleFactor::SF01).await?;
        assert_eq!(
            cache_size().await,
            baseline,
            "Cache size shouldn't change for same scale"
        );
        assert!(
            Arc::ptr_eq(&first, &second),
            "Should return the same cached instance"
        );

        // Different scale: cache grows and a distinct context comes back.
        let other = get_or_create_tpc_h_context(ScaleFactor::SF1).await?;
        assert!(
            cache_size().await > baseline,
            "Cache should grow with different scale"
        );
        assert!(
            !Arc::ptr_eq(&first, &other),
            "Different scales should have different instances"
        );
        Ok(())
    }

    /// Ten tasks racing on the same scale must all observe one shared context.
    #[tokio::test]
    async fn test_concurrent_cache_access() -> Result<()> {
        use tokio::task::JoinSet;

        let scale = ScaleFactor::SF01;
        // Prime the cache so the spawned tasks exercise the cached path.
        let reference_ctx = get_or_create_tpc_h_context(scale).await?;

        let mut workers = JoinSet::new();
        for task_id in 0..10 {
            workers.spawn(async move {
                let ctx = get_or_create_tpc_h_context(ScaleFactor::SF01)
                    .await
                    .expect("Failed to get context");
                (task_id, ctx)
            });
        }

        let mut seen = Vec::new();
        while let Some(joined) = workers.join_next().await {
            let (_id, ctx) = joined.expect("Task failed");
            seen.push(ctx);
        }

        assert_eq!(seen.len(), 10);
        assert!(
            seen.iter().all(|ctx| Arc::ptr_eq(&reference_ctx, ctx)),
            "All concurrent accesses should return the same cached instance"
        );
        Ok(())
    }
}