use crate::core::{ValidationResult, ValidationSuite};
use crate::error::{Result, TermError};
use crate::sources::DataSource;
use crate::telemetry::TermTelemetry;
use arrow::record_batch::RecordBatch;
use datafusion::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info, instrument, span, Level};
/// Validates data across multiple registered sources through a shared
/// DataFusion `SessionContext`, optionally caching query results in memory.
pub struct MultiSourceValidator {
// Session context that sources register with and all SQL runs against.
ctx: SessionContext,
// Registered sources, keyed by the name they were added under.
sources: HashMap<String, Arc<dyn DataSource>>,
// Cached results of previously executed queries; key derived from the SQL text.
query_cache: HashMap<String, CachedResult>,
// Optional telemetry sink handed to each source at registration time.
telemetry: Option<Arc<TermTelemetry>>,
// Whether `execute_query` consults and populates the cache.
enable_caching: bool,
// Upper bound, in bytes, on total cached result size.
max_cache_size: usize,
// Running total of bytes currently held in `query_cache`.
current_cache_size: usize,
}
/// A cached query result together with the bookkeeping needed for eviction.
#[derive(Debug, Clone)]
struct CachedResult {
// Collected record batches returned by the query.
data: Vec<RecordBatch>,
// Insertion time; entries are evicted oldest-first.
cached_at: Instant,
// Approximate in-memory size of `data`, counted against the cache budget.
size_bytes: usize,
}
impl MultiSourceValidator {
    /// Creates a validator with a fresh DataFusion `SessionContext`,
    /// caching enabled, and the default 100 MiB cache budget.
    pub fn new() -> Self {
        Self::with_context(SessionContext::new())
    }

    /// Creates a validator that registers sources with, and runs queries
    /// against, the supplied session context.
    pub fn with_context(ctx: SessionContext) -> Self {
        Self {
            ctx,
            sources: HashMap::new(),
            query_cache: HashMap::new(),
            telemetry: None,
            enable_caching: true,
            // Default cache budget: 100 MiB.
            max_cache_size: 100 * 1024 * 1024,
            current_cache_size: 0,
        }
    }

    /// Attaches a telemetry sink that is passed to each source on registration.
    pub fn with_telemetry(mut self, telemetry: Arc<TermTelemetry>) -> Self {
        self.telemetry = Some(telemetry);
        self
    }

    /// Enables or disables the query-result cache.
    pub fn with_caching(mut self, enable: bool) -> Self {
        self.enable_caching = enable;
        self
    }

    /// Sets the maximum number of bytes the query-result cache may hold.
    pub fn with_max_cache_size(mut self, size_bytes: usize) -> Self {
        self.max_cache_size = size_bytes;
        self
    }

    /// Registers a data source under `name` with the session context.
    ///
    /// # Errors
    /// Returns a `TermError::data_source` if the source fails to register.
    #[instrument(skip(self, source, name))]
    pub async fn add_source<S: DataSource + 'static>(
        &mut self,
        name: impl Into<String>,
        source: S,
    ) -> Result<()> {
        let name = name.into();
        info!("Adding data source: {}", name);
        let source = Arc::new(source);
        source
            .register_with_telemetry(&self.ctx, &name, self.telemetry.as_ref())
            .await
            .map_err(|e| {
                TermError::data_source(
                    "multi_source",
                    format!("Failed to register source '{name}': {e}"),
                )
            })?;
        self.sources.insert(name.clone(), source);
        info!("Successfully added data source: {}", name);
        Ok(())
    }

    /// Returns the underlying DataFusion session context.
    pub fn context(&self) -> &SessionContext {
        &self.ctx
    }

    /// Looks up a registered source by name.
    pub fn get_source(&self, name: &str) -> Option<&Arc<dyn DataSource>> {
        self.sources.get(name)
    }

    /// Lists the names of all registered sources (in no particular order).
    pub fn list_sources(&self) -> Vec<String> {
        self.sources.keys().cloned().collect()
    }

    /// Runs a validation suite against the session context and logs the outcome.
    ///
    /// # Errors
    /// Propagates any error raised while the suite executes.
    #[instrument(skip(self, suite), fields(suite_name = %suite.name()))]
    pub async fn run_suite(&self, suite: &ValidationSuite) -> Result<ValidationResult> {
        let span = span!(Level::INFO, "multi_source_validation", suite = %suite.name());
        let _enter = span.enter();
        info!(
            "Running validation suite '{}' with {} registered sources",
            suite.name(),
            self.sources.len()
        );
        if self.enable_caching {
            self.cleanup_cache();
        }
        let result = suite.run(&self.ctx).await?;
        match &result {
            ValidationResult::Success { report, .. } => {
                info!(
                    "Validation suite '{}' succeeded: {} checks passed",
                    suite.name(),
                    report.metrics.total_checks
                );
            }
            ValidationResult::Failure { report } => {
                info!(
                    "Validation suite '{}' failed: {} issues found",
                    suite.name(),
                    report.issues.len()
                );
            }
        }
        Ok(result)
    }

    /// Executes a SQL query, serving repeated queries from the in-memory
    /// cache when caching is enabled.
    ///
    /// # Errors
    /// Returns a `TermError::data_source` if query planning or result
    /// collection fails.
    #[instrument(skip(self))]
    pub async fn execute_query(&mut self, sql: &str) -> Result<Vec<RecordBatch>> {
        // Key the cache on the full SQL text. The previous 64-bit hash key
        // could (however rarely) collide and hand back a different query's
        // results; the exact text can never do that.
        let cache_key = sql.to_string();
        if self.enable_caching {
            if let Some(cached) = self.query_cache.get(&cache_key) {
                debug!("Cache hit for query");
                return Ok(cached.data.clone());
            }
        }
        debug!("Executing query: {}", sql);
        let df = self.ctx.sql(sql).await.map_err(|e| {
            TermError::data_source("multi_source", format!("Query execution failed: {e}"))
        })?;
        let batches = df.collect().await.map_err(|e| {
            TermError::data_source("multi_source", format!("Failed to collect results: {e}"))
        })?;
        if self.enable_caching {
            self.cache_result(cache_key, batches.clone());
        }
        Ok(batches)
    }

    /// Inserts a query result into the cache, evicting older entries first
    /// when the cache budget would be exceeded.
    fn cache_result(&mut self, key: String, data: Vec<RecordBatch>) {
        let size_bytes: usize = data.iter().map(|batch| batch.get_array_memory_size()).sum();
        // A result larger than the entire budget can never fit; caching it
        // would empty the cache and still blow the limit.
        if size_bytes > self.max_cache_size {
            debug!("Result of {} bytes exceeds cache budget; not caching", size_bytes);
            return;
        }
        // Re-inserting under an existing key must release the old entry's
        // bytes first; otherwise `current_cache_size` drifts upward forever.
        if let Some(old) = self.query_cache.remove(&key) {
            self.current_cache_size -= old.size_bytes;
        }
        if self.current_cache_size + size_bytes > self.max_cache_size {
            self.evict_cache_entries(size_bytes);
        }
        let cached = CachedResult {
            data,
            cached_at: Instant::now(),
            size_bytes,
        };
        self.current_cache_size += size_bytes;
        self.query_cache.insert(key, cached);
    }

    /// Evicts cache entries, oldest first, until `needed_bytes` more can fit
    /// within `max_cache_size`.
    fn evict_cache_entries(&mut self, needed_bytes: usize) {
        // Snapshot (key, age, size) so we can mutate the map while walking
        // victims oldest-first.
        let mut victims: Vec<(String, Instant, usize)> = self
            .query_cache
            .iter()
            .map(|(key, cached)| (key.clone(), cached.cached_at, cached.size_bytes))
            .collect();
        victims.sort_by_key(|&(_, cached_at, _)| cached_at);
        // Compare against the size the cache WILL have after pending
        // removals. The original compared against the untouched
        // `current_cache_size`, so it either evicted nothing or evicted
        // every single entry.
        let mut projected_size = self.current_cache_size;
        for (key, _, size) in victims {
            if projected_size + needed_bytes <= self.max_cache_size {
                break;
            }
            self.query_cache.remove(&key);
            projected_size -= size;
            self.current_cache_size -= size;
            debug!("Evicted cache entry to free {} bytes", size);
        }
    }

    /// Logs current cache occupancy. Takes `&self`, so it cannot mutate the
    /// cache; actual eviction happens in `cache_result`/`evict_cache_entries`.
    fn cleanup_cache(&self) {
        debug!(
            "Cache cleanup: {} entries, {} bytes",
            self.query_cache.len(),
            self.current_cache_size
        );
    }

    /// Returns a snapshot of current cache occupancy.
    pub fn cache_stats(&self) -> CacheStats {
        CacheStats {
            entries: self.query_cache.len(),
            size_bytes: self.current_cache_size,
            max_size_bytes: self.max_cache_size,
            // Hit/miss counters are not tracked yet; reported as 0.0.
            hit_rate: 0.0,
        }
    }
}
impl Default for MultiSourceValidator {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of the validator's query-cache occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
/// Number of cached query results.
pub entries: usize,
/// Total bytes currently held by the cache.
pub size_bytes: usize,
/// Configured cache budget in bytes.
pub max_size_bytes: usize,
/// Cache hit rate; hits are not tracked yet, so this is currently always 0.0.
pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sources::CsvSource;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `data` into a fresh temp file carrying a `.csv` suffix.
    fn create_test_csv(data: &str) -> Result<NamedTempFile> {
        let mut file = NamedTempFile::with_suffix(".csv")?;
        write!(file, "{data}")?;
        file.flush()?;
        Ok(file)
    }

    #[tokio::test]
    async fn test_multi_source_validator_creation() {
        // A brand-new validator has no sources and caching on by default.
        let validator = MultiSourceValidator::new();
        assert!(validator.enable_caching);
        assert_eq!(validator.sources.len(), 0);
    }

    #[tokio::test]
    async fn test_add_source() -> Result<()> {
        let temp_file = create_test_csv("id,name\n1,Alice\n2,Bob")?;
        let path = temp_file.path().to_string_lossy().into_owned();
        let mut validator = MultiSourceValidator::new();
        validator.add_source("test_data", CsvSource::new(path)?).await?;
        // The source is registered and retrievable under its name.
        assert_eq!(validator.sources.len(), 1);
        assert!(validator.get_source("test_data").is_some());
        Ok(())
    }

    #[tokio::test]
    async fn test_list_sources() -> Result<()> {
        let csv_data = "id,value\n1,100";
        let temp_file1 = create_test_csv(csv_data)?;
        let temp_file2 = create_test_csv(csv_data)?;
        let mut validator = MultiSourceValidator::new();
        // Register both sources, then confirm listing reports each name.
        for (name, file) in [("source1", &temp_file1), ("source2", &temp_file2)] {
            let path = file.path().to_string_lossy().into_owned();
            validator.add_source(name, CsvSource::new(path)?).await?;
        }
        let sources = validator.list_sources();
        assert_eq!(sources.len(), 2);
        assert!(sources.iter().any(|s| s == "source1"));
        assert!(sources.iter().any(|s| s == "source2"));
        Ok(())
    }

    #[tokio::test]
    async fn test_cache_configuration() {
        // Builder methods toggle caching and resize the budget.
        let validator = MultiSourceValidator::new()
            .with_caching(false)
            .with_max_cache_size(1024 * 1024);
        assert_eq!(validator.max_cache_size, 1024 * 1024);
        assert!(!validator.enable_caching);
    }

    #[tokio::test]
    async fn test_cache_stats() {
        // An untouched validator reports an empty cache at the default budget.
        let stats = MultiSourceValidator::new().cache_stats();
        assert_eq!(stats.entries, 0);
        assert_eq!(stats.size_bytes, 0);
        assert_eq!(stats.max_size_bytes, 100 * 1024 * 1024);
    }
}