use std::collections::{BTreeMap, HashMap};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::{
parser::header::{CassandraVersion, ColumnInfo},
platform::Platform,
schema::{Column, TableSchema},
storage::sstable::reader::SSTableReader,
types::{DataType, Value},
Config, Result,
};
/// Tuning knobs for [`SchemaDiscovery`].
#[derive(Debug, Clone)]
pub struct SchemaDiscoveryConfig {
    /// Row-sampling budget used while discovering a table's schema.
    pub max_sample_rows: usize,
    /// NOTE(review): not read anywhere in this file — confirm whether it is still needed.
    pub aggressive_inference: bool,
    /// When true, discovered schemas are cached and served from the cache while fresh.
    pub cache_schemas: bool,
    /// Number of seconds a cached schema remains servable.
    pub cache_ttl_seconds: u64,
    /// NOTE(review): not read anywhere in this file (metadata `version` is always 1).
    pub enable_versioning: bool,
    /// NOTE(review): not read anywhere in this file.
    pub max_versions: usize,
}

impl Default for SchemaDiscoveryConfig {
    /// Samples up to 1 000 rows, caches schemas for one hour, and enables aggressive
    /// inference and versioning (10 versions).
    fn default() -> Self {
        const ONE_HOUR_SECS: u64 = 60 * 60;
        SchemaDiscoveryConfig {
            max_sample_rows: 1_000,
            aggressive_inference: true,
            cache_schemas: true,
            cache_ttl_seconds: ONE_HOUR_SECS,
            enable_versioning: true,
            max_versions: 10,
        }
    }
}
/// Result of a schema discovery run: the inferred schema plus how it was obtained and how
/// trustworthy it looks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveredSchema {
/// Inferred table schema (columns sorted by name when produced by `SchemaDiscovery`).
pub schema: TableSchema,
/// Provenance of the discovery (time, source files, rows sampled, method).
pub metadata: SchemaMetadata,
/// Per-column statistics keyed by column name.
pub column_stats: HashMap<String, ColumnStatistics>,
/// Mean of the per-column `type_confidence` values; 0.0 when there are no columns
/// (as computed by `SchemaDiscovery`).
pub inference_confidence: f64,
/// Outcome of grading the per-column type confidence.
pub validation_status: ValidationStatus,
}
/// Provenance information attached to a [`DiscoveredSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaMetadata {
/// Wall-clock time captured when the discovery run started.
pub discovered_at: SystemTime,
/// SSTable files that were successfully opened during discovery; files that failed to open
/// are not listed.
pub source_files: Vec<PathBuf>,
/// Total number of data rows sampled across all source files.
pub rows_sampled: usize,
/// Cassandra version from the first opened SSTable's header; `None` if nothing was opened.
pub cassandra_version: Option<CassandraVersion>,
/// How the schema was obtained (header metadata, data sampling, both, or neither).
pub discovery_method: DiscoveryMethod,
/// Schema version number; `SchemaDiscovery` currently always sets this to 1.
pub version: u32,
}
/// Statistics gathered for a single column during discovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColumnStatistics {
/// Column name.
pub name: String,
/// Inferred data type, rendered as a string.
pub inferred_type: String,
/// Fraction of non-null samples whose runtime type is the most common one; 0.0 when the
/// column has no non-null samples (for example, a column known only from the header).
pub type_confidence: f64,
/// Fraction of sampled values that were null, in the range 0.0..=1.0 (despite the name,
/// this is not scaled to 0..100).
pub null_percentage: f64,
/// Number of distinct sampled values (compared by their `Debug` text), capped at 1000.
pub unique_values: usize,
/// Mean estimated size in bytes of the non-null sampled values; 0.0 when there are none.
pub avg_size_bytes: f64,
/// Smallest non-null sampled value according to `Value`'s `PartialOrd`.
pub min_value: Option<Value>,
/// Largest non-null sampled value according to `Value`'s `PartialOrd`.
pub max_value: Option<Value>,
/// Detected value patterns.
/// NOTE(review): nothing in this file populates these, so the list is currently always empty.
pub patterns: Vec<String>,
}
/// Grade assigned to a discovered schema, based on per-column type confidence
/// (see `SchemaValidator`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ValidationStatus {
/// Every column's type confidence is at or above the warning threshold.
Valid,
/// At least one column is below the warning threshold, but none is below the error threshold.
WarningsPresent,
/// At least one column is below the error threshold.
Invalid,
/// Not produced by the validator in this file; presumably reserved for callers that skip
/// validation — TODO confirm.
Unknown,
}
/// How a schema was obtained, as classified by `SchemaDiscovery`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DiscoveryMethod {
/// Columns came from SSTable headers and no data rows were sampled.
HeaderMetadata,
/// No column carried header metadata; any columns came from sampled data.
DataSampling,
/// Header metadata was available and at least one data row was sampled.
Hybrid,
/// No SSTable could be opened, so nothing was read from disk.
External,
}
/// Discovers table schemas by reading SSTable headers and sampling row data. Results can be
/// kept in an in-memory cache keyed by `"keyspace.table"` and bounded by a TTL.
// NOTE(review): every field below is read by `impl SchemaDiscovery`; this allow looks stale.
#[allow(dead_code)]
pub struct SchemaDiscovery {
/// Sampling and caching settings.
config: SchemaDiscoveryConfig,
/// Platform handle passed to `SSTableReader::open`.
platform: Arc<Platform>,
/// Core configuration passed to `SSTableReader::open`.
core_config: Config,
/// Cached discovery results keyed by `"keyspace.table"`, each with the instant it was inserted.
schema_cache: Arc<RwLock<HashMap<String, (DiscoveredSchema, Instant)>>>,
/// Maps column evidence to a `DataType`.
type_inference: Arc<TypeInferenceEngine>,
/// Grades discovered schemas from their column statistics.
validator: Arc<SchemaValidator>,
}
impl SchemaDiscovery {
    /// Maximum number of schemas kept in the in-memory cache. Inserting past this limit evicts
    /// the entry that was inserted longest ago.
    const MAX_CACHED_SCHEMAS: usize = 100;

    /// Creates a discovery service with its own type-inference engine, validator and an empty
    /// schema cache.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub async fn new(
        config: SchemaDiscoveryConfig,
        platform: Arc<Platform>,
        core_config: Config,
    ) -> Result<Self> {
        Ok(Self {
            config,
            platform,
            core_config,
            schema_cache: Arc::new(RwLock::new(HashMap::new())),
            type_inference: Arc::new(TypeInferenceEngine::new()),
            validator: Arc::new(SchemaValidator::new()),
        })
    }

    /// Discovers the schema of `keyspace.table` from `sstable_files`.
    ///
    /// With caching enabled, a cached result younger than `cache_ttl_seconds` is returned
    /// without reading any file. The cache key is `"keyspace.table"` only, so a cached result is
    /// reused even when a different file list is passed.
    ///
    /// A result is cached only when at least one SSTable could be opened. Otherwise an empty
    /// schema (from an empty or wholly unreadable file list) would mask real data until the
    /// TTL expires.
    ///
    /// # Errors
    /// Propagates failures from sampling an opened SSTable and from schema inference.
    pub async fn discover_table_schema(
        &self,
        keyspace: &str,
        table: &str,
        sstable_files: &[PathBuf],
    ) -> Result<DiscoveredSchema> {
        let cache_key = format!("{}.{}", keyspace, table);
        if self.config.cache_schemas {
            if let Some(cached) = self.get_cached_schema(&cache_key).await {
                return Ok(cached);
            }
        }
        let discovered = self
            .perform_schema_discovery(keyspace, table, sstable_files)
            .await?;
        // Do not cache a schema that was not backed by any readable SSTable.
        if self.config.cache_schemas && !discovered.metadata.source_files.is_empty() {
            self.cache_schema(cache_key, discovered.clone()).await;
        }
        Ok(discovered)
    }

    /// Reads the SSTables in order, merging header column definitions and sampled rows. It then
    /// infers the schema, per-column statistics, overall confidence and validation status.
    ///
    /// Files that fail to open are skipped. Row sampling is capped at `max_sample_rows` in total
    /// across all files. Once the budget is spent, the remaining files are not opened.
    async fn perform_schema_discovery(
        &self,
        keyspace: &str,
        table: &str,
        sstable_files: &[PathBuf],
    ) -> Result<DiscoveredSchema> {
        let start_time = SystemTime::now();
        let sample_budget = self.config.max_sample_rows;
        let mut source_files = Vec::new();
        let mut all_column_data = HashMap::new();
        let mut total_rows_sampled = 0;
        let mut cassandra_version = None;
        for file_path in sstable_files {
            // Best effort: unreadable files are skipped and not recorded as sources.
            let reader = match self.create_reader(file_path).await {
                Ok(reader) => reader,
                Err(_) => continue,
            };
            source_files.push(file_path.clone());
            if let Ok(header_schema) = self.extract_schema_from_header(&reader).await {
                if cassandra_version.is_none() {
                    let header = reader.header();
                    cassandra_version = Some(header.cassandra_version);
                }
                self.merge_header_schema(&mut all_column_data, header_schema);
            }
            // Ask only for the unused part of the budget so the total across all files never
            // exceeds `max_sample_rows`.
            let remaining = sample_budget.saturating_sub(total_rows_sampled);
            let sampled_data = self.sample_table_data(&reader, remaining).await?;
            total_rows_sampled += sampled_data.len();
            self.analyze_sampled_data(&mut all_column_data, sampled_data);
            if total_rows_sampled >= sample_budget {
                break;
            }
        }

        let schema = self
            .infer_table_schema(keyspace, table, &all_column_data)
            .await?;
        let column_stats = self.calculate_column_statistics(&all_column_data).await;
        let inference_confidence = self.calculate_inference_confidence(&column_stats);
        let validation_status = self.validator.validate_schema(&schema, &column_stats).await;

        let has_header_info = all_column_data.values().any(|cd| cd.header_info.is_some());
        let discovery_method = if source_files.is_empty() {
            DiscoveryMethod::External
        } else if has_header_info {
            if total_rows_sampled > 0 {
                DiscoveryMethod::Hybrid
            } else {
                DiscoveryMethod::HeaderMetadata
            }
        } else {
            DiscoveryMethod::DataSampling
        };

        let metadata = SchemaMetadata {
            discovered_at: start_time,
            source_files,
            rows_sampled: total_rows_sampled,
            cassandra_version,
            discovery_method,
            version: 1,
        };
        Ok(DiscoveredSchema {
            schema,
            metadata,
            column_stats,
            inference_confidence,
            validation_status,
        })
    }

    /// Opens an SSTable with this service's core configuration and platform.
    async fn create_reader(&self, file_path: &Path) -> Result<SSTableReader> {
        SSTableReader::open(file_path, &self.core_config, self.platform.clone()).await
    }

    /// Returns the header's column definitions keyed by column name. This currently never
    /// fails.
    async fn extract_schema_from_header(
        &self,
        reader: &SSTableReader,
    ) -> Result<HashMap<String, ColumnInfo>> {
        let header = reader.header();
        Ok(header
            .columns
            .iter()
            .map(|column_def| (column_def.name.clone(), column_def.clone()))
            .collect())
    }

    /// Samples up to `limit` rows from `reader`.
    ///
    /// Returns no samples when the header lists no columns.
    ///
    /// NOTE(review): each entry's value is stored under the first header column's name only,
    /// so the other columns receive no samples. Confirm this matches what `get_all_entries`
    /// returns. All entries are loaded before `limit` is applied.
    async fn sample_table_data(
        &self,
        reader: &SSTableReader,
        limit: usize,
    ) -> Result<Vec<HashMap<String, Value>>> {
        let header = reader.header();
        let first_column: Option<String> = header.columns.first().map(|col| col.name.clone());
        let all_entries = reader.get_all_entries().await?;
        let samples = match first_column {
            Some(column_name) => all_entries
                .into_iter()
                .take(limit)
                .map(|(_table_id, _row_key, value)| {
                    let mut row_data = HashMap::new();
                    row_data.insert(column_name.clone(), value);
                    row_data
                })
                .collect(),
            None => Vec::new(),
        };
        Ok(samples)
    }

    /// Builds a `TableSchema` whose columns are sorted by name.
    ///
    /// Key columns are not identified: partition/clustering key lists are empty, and every
    /// column is nullable, non-static and has no default.
    async fn infer_table_schema(
        &self,
        keyspace: &str,
        table: &str,
        column_data: &HashMap<String, ColumnData>,
    ) -> Result<TableSchema> {
        let mut columns = Vec::with_capacity(column_data.len());
        for (name, data) in column_data {
            let data_type = self.type_inference.infer_column_type(data).await;
            columns.push(Column {
                name: name.clone(),
                data_type: data_type.to_string(),
                nullable: true,
                default: None,
                is_static: false,
            });
        }
        columns.sort_by(|a, b| a.name.cmp(&b.name));
        Ok(TableSchema {
            keyspace: keyspace.to_string(),
            table: table.to_string(),
            partition_keys: vec![],
            clustering_keys: vec![],
            columns,
            comments: HashMap::new(),
        })
    }

    /// Summarises each column's accumulated samples into `ColumnStatistics`.
    async fn calculate_column_statistics(
        &self,
        column_data: &HashMap<String, ColumnData>,
    ) -> HashMap<String, ColumnStatistics> {
        let mut stats = HashMap::with_capacity(column_data.len());
        for (name, data) in column_data {
            let inferred_type = self
                .type_inference
                .infer_column_type(data)
                .await
                .to_string();
            let stat = ColumnStatistics {
                name: name.clone(),
                inferred_type,
                type_confidence: data.calculate_type_confidence(),
                null_percentage: data.calculate_null_percentage(),
                unique_values: data.unique_values.len(),
                avg_size_bytes: data.calculate_average_size(),
                min_value: data.min_value.clone(),
                max_value: data.max_value.clone(),
                patterns: data.detected_patterns.clone(),
            };
            stats.insert(name.clone(), stat);
        }
        stats
    }

    /// Mean per-column type confidence; 0.0 when there are no columns.
    fn calculate_inference_confidence(
        &self,
        column_stats: &HashMap<String, ColumnStatistics>,
    ) -> f64 {
        if column_stats.is_empty() {
            return 0.0;
        }
        let total_confidence: f64 = column_stats.values().map(|stat| stat.type_confidence).sum();
        total_confidence / column_stats.len() as f64
    }

    /// Returns a clone of the cached schema for `cache_key` if it is younger than the TTL.
    async fn get_cached_schema(&self, cache_key: &str) -> Option<DiscoveredSchema> {
        let cache = self.schema_cache.read().await;
        let ttl = Duration::from_secs(self.config.cache_ttl_seconds);
        cache
            .get(cache_key)
            .filter(|(_, cached_at)| cached_at.elapsed() < ttl)
            .map(|(schema, _)| schema.clone())
    }

    /// Inserts `schema` under `cache_key`, first dropping expired entries. If the cache is then
    /// larger than `MAX_CACHED_SCHEMAS`, the oldest entry is evicted.
    async fn cache_schema(&self, cache_key: String, schema: DiscoveredSchema) {
        let ttl = Duration::from_secs(self.config.cache_ttl_seconds);
        let mut cache = self.schema_cache.write().await;
        // Expired entries can never be served by `get_cached_schema`. Dropping them stops them
        // from forcing the eviction of live entries.
        cache.retain(|_, (_, cached_at)| cached_at.elapsed() < ttl);
        cache.insert(cache_key, (schema, Instant::now()));
        if cache.len() > Self::MAX_CACHED_SCHEMAS {
            let oldest_key = cache
                .iter()
                .min_by_key(|(_, (_, time))| *time)
                .map(|(key, _)| key.clone());
            if let Some(key) = oldest_key {
                cache.remove(&key);
            }
        }
    }

    /// Records header column definitions, overwriting any header info already stored for a
    /// column.
    fn merge_header_schema(
        &self,
        column_data: &mut HashMap<String, ColumnData>,
        header_columns: HashMap<String, ColumnInfo>,
    ) {
        for (name, column_info) in header_columns {
            let entry = column_data.entry(name).or_insert_with(ColumnData::new);
            entry.header_info = Some(column_info);
        }
    }

    /// Feeds every sampled value into its column's accumulator, creating accumulators as needed.
    fn analyze_sampled_data(
        &self,
        column_data: &mut HashMap<String, ColumnData>,
        samples: Vec<HashMap<String, Value>>,
    ) {
        for (column_name, value) in samples.into_iter().flatten() {
            column_data
                .entry(column_name)
                .or_insert_with(ColumnData::new)
                .add_sample_value(value);
        }
    }
}
/// Per-column accumulator filled from SSTable headers and sampled values during discovery.
#[derive(Debug)]
struct ColumnData {
/// Column definition from an SSTable header, if any header listed this column.
header_info: Option<ColumnInfo>,
/// Non-null sampled values (nulls are only counted in `null_count`).
sample_values: Vec<Value>,
/// Occurrence count per distinct value, keyed by the value's `Debug` text. New keys stop
/// being recorded once 1000 distinct values are present.
unique_values: BTreeMap<String, usize>,
/// Number of `Value::Null` samples seen.
null_count: usize,
/// Smallest non-null sample according to `Value`'s `PartialOrd`.
min_value: Option<Value>,
/// Largest non-null sample according to `Value`'s `PartialOrd`.
max_value: Option<Value>,
/// NOTE(review): never populated in this file.
detected_patterns: Vec<String>,
/// Count of non-null samples per runtime type name (from `ValueExt::type_name`).
type_frequency: HashMap<String, usize>,
}
impl ColumnData {
fn new() -> Self {
Self {
header_info: None,
sample_values: Vec::new(),
unique_values: BTreeMap::new(),
null_count: 0,
min_value: None,
max_value: None,
detected_patterns: Vec::new(),
type_frequency: HashMap::new(),
}
}
fn add_sample_value(&mut self, value: Value) {
if value == Value::Null {
self.null_count += 1;
} else {
let type_name = value.type_name();
*self.type_frequency.entry(type_name).or_insert(0) += 1;
if self.unique_values.len() < 1000 {
let value_str = format!("{:?}", value);
*self.unique_values.entry(value_str).or_insert(0) += 1;
}
if self.min_value.is_none() || Some(&value) < self.min_value.as_ref() {
self.min_value = Some(value.clone());
}
if self.max_value.is_none() || Some(&value) > self.max_value.as_ref() {
self.max_value = Some(value.clone());
}
self.sample_values.push(value);
}
}
fn calculate_type_confidence(&self) -> f64 {
if self.type_frequency.is_empty() {
return 0.0;
}
let total_samples = self.type_frequency.values().sum::<usize>();
let max_frequency = *self.type_frequency.values().max().unwrap_or(&0);
max_frequency as f64 / total_samples as f64
}
fn calculate_null_percentage(&self) -> f64 {
let total = self.sample_values.len() + self.null_count;
if total == 0 {
0.0
} else {
self.null_count as f64 / total as f64
}
}
fn calculate_average_size(&self) -> f64 {
if self.sample_values.is_empty() {
0.0
} else {
let total_size: usize = self.sample_values.iter().map(|v| v.estimate_size()).sum();
total_size as f64 / self.sample_values.len() as f64
}
}
}
/// Maps column evidence (header metadata or sampled runtime types) to a [`DataType`].
struct TypeInferenceEngine;

impl TypeInferenceEngine {
    /// Creates the (stateless) engine.
    fn new() -> Self {
        Self
    }

    /// Infers a column's data type.
    ///
    /// Header metadata wins when present. Otherwise the most frequent runtime type among the
    /// non-null samples is used. Ties are broken by the lexicographically smallest type name, so
    /// the result does not depend on `HashMap` iteration order. A column with no evidence
    /// falls back to `DataType::Text`.
    async fn infer_column_type(&self, column_data: &ColumnData) -> DataType {
        if let Some(header_info) = &column_data.header_info {
            return self.convert_cql_type_to_data_type(&header_info.column_type);
        }
        column_data
            .type_frequency
            .iter()
            .max_by(|(name_a, count_a), (name_b, count_b)| {
                // Higher count wins; on equal counts the smaller name compares as "greater".
                count_a.cmp(count_b).then_with(|| name_b.cmp(name_a))
            })
            .map(|(type_name, _)| self.string_to_data_type(type_name))
            .unwrap_or(DataType::Text)
    }

    /// Maps a CQL type name (case-insensitive) to a `DataType`.
    ///
    /// Unrecognised names, including parameterised collection types such as `list<int>`, fall
    /// back to `DataType::Text`.
    /// NOTE(review): if headers report marshal class names instead of CQL names, every column
    /// maps to Text — confirm the header format.
    fn convert_cql_type_to_data_type(&self, type_name: &str) -> DataType {
        match type_name.to_lowercase().as_str() {
            "text" | "varchar" | "ascii" => DataType::Text,
            "int" => DataType::Integer,
            "bigint" => DataType::BigInt,
            "boolean" => DataType::Boolean,
            "double" | "float" => DataType::Float,
            "uuid" => DataType::Uuid,
            "timestamp" => DataType::Timestamp,
            "blob" => DataType::Blob,
            _ => DataType::Text,
        }
    }

    /// Maps a runtime type name, as produced by `ValueExt::type_name`, to a `DataType`.
    ///
    /// The accepted spellings must match `ValueExt::type_name` exactly (e.g. "BigInteger",
    /// "UUID"); anything unmapped falls back to `DataType::Text`.
    fn string_to_data_type(&self, type_name: &str) -> DataType {
        match type_name {
            "Text" => DataType::Text,
            "Integer" => DataType::Integer,
            "BigInteger" => DataType::BigInt,
            // 32-bit floats share `DataType::Float`, like CQL `float` does above.
            "Float" | "Float32" => DataType::Float,
            "Boolean" => DataType::Boolean,
            "UUID" => DataType::Uuid,
            "Timestamp" => DataType::Timestamp,
            "Blob" => DataType::Blob,
            _ => DataType::Text,
        }
    }
}
/// Grades a discovered schema from its per-column type confidence.
///
/// NOTE(review): a column with no non-null samples reports a type confidence of 0.0. Any such
/// column (for example, one known only from a header) therefore makes the schema `Invalid`.
struct SchemaValidator;

impl SchemaValidator {
    /// A column whose type confidence is below this counts as an error.
    const ERROR_CONFIDENCE: f64 = 0.3;
    /// A column whose type confidence is below this counts as a warning.
    const WARNING_CONFIDENCE: f64 = 0.5;

    /// Creates the (stateless) validator.
    fn new() -> Self {
        SchemaValidator
    }

    /// Returns `Invalid` if any column is below the error threshold. Otherwise it returns
    /// `WarningsPresent` if any column is below the warning threshold, and `Valid` if neither
    /// applies. The schema itself is not inspected.
    async fn validate_schema(
        &self,
        _schema: &TableSchema,
        column_stats: &HashMap<String, ColumnStatistics>,
    ) -> ValidationStatus {
        let any_below = |threshold: f64| {
            column_stats
                .values()
                .any(|stat| stat.type_confidence < threshold)
        };
        if any_below(Self::ERROR_CONFIDENCE) {
            ValidationStatus::Invalid
        } else if any_below(Self::WARNING_CONFIDENCE) {
            ValidationStatus::WarningsPresent
        } else {
            ValidationStatus::Valid
        }
    }
}
/// Profiling helpers on [`Value`] used by column statistics.
trait ValueExt {
    /// Name of the runtime variant; used as the key of `ColumnData::type_frequency`.
    fn type_name(&self) -> String;
    /// Heuristic size in bytes: fixed per-variant sizes for scalars, byte length for
    /// variable-width payloads, and a fixed overhead plus element sizes for collections.
    fn estimate_size(&self) -> usize;
}

impl ValueExt for Value {
    fn type_name(&self) -> String {
        let name: &'static str = match self {
            Value::Null => "Null",
            Value::Text(_) => "Text",
            Value::Integer(_) => "Integer",
            Value::BigInt(_) => "BigInteger",
            Value::Counter(_) => "Counter",
            Value::Float(_) => "Float",
            Value::Boolean(_) => "Boolean",
            Value::Uuid(_) => "UUID",
            Value::Timestamp(_) => "Timestamp",
            Value::Date(_) => "Date",
            Value::Time(_) => "Time",
            Value::Inet(_) => "Inet",
            Value::Blob(_) => "Blob",
            Value::List(_) => "List",
            Value::Set(_) => "Set",
            Value::Map(_) => "Map",
            Value::Json(_) => "JSON",
            Value::TinyInt(_) => "TinyInt",
            Value::SmallInt(_) => "SmallInt",
            Value::Float32(_) => "Float32",
            Value::Tuple(_) => "Tuple",
            Value::Udt(_) => "UDT",
            Value::Frozen(_) => "Frozen",
            Value::Varint(_) => "Varint",
            Value::Decimal { .. } => "Decimal",
            Value::Duration { .. } => "Duration",
            Value::Tombstone(_) => "Tombstone",
        };
        name.to_string()
    }

    fn estimate_size(&self) -> usize {
        match self {
            // Fixed-size scalars, grouped by size.
            Value::Null => 0,
            Value::Boolean(_) | Value::TinyInt(_) => 1,
            Value::SmallInt(_) => 2,
            Value::Integer(_) | Value::Date(_) | Value::Float32(_) => 4,
            Value::BigInt(_)
            | Value::Counter(_)
            | Value::Float(_)
            | Value::Timestamp(_)
            | Value::Time(_)
            | Value::Tombstone(_) => 8,
            Value::Duration { .. } => 12,
            Value::Uuid(_) => 16,
            // Fixed guesses for payloads that are not measured.
            Value::Udt(_) => 32,
            Value::Json(_) => 64,
            // Variable-width payloads: their byte length.
            Value::Text(s) => s.len(),
            Value::Inet(bytes) => bytes.len(),
            Value::Blob(b) => b.len(),
            Value::Varint(data) => data.len(),
            Value::Decimal { unscaled, .. } => 4 + unscaled.len(),
            // Wrappers and collections: recurse into the contents plus a fixed overhead.
            Value::Frozen(f) => f.estimate_size(),
            Value::List(items) => 8 + items.iter().map(|v| v.estimate_size()).sum::<usize>(),
            Value::Set(items) => 8 + items.iter().map(|v| v.estimate_size()).sum::<usize>(),
            Value::Tuple(t) => 8 + t.iter().map(|v| v.estimate_size()).sum::<usize>(),
            Value::Map(map) => {
                16 + map
                    .iter()
                    .map(|(k, v)| k.estimate_size() + v.estimate_size())
                    .sum::<usize>()
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly constructed discovery service starts with an empty schema cache.
    #[tokio::test]
    async fn test_schema_discovery_creation() {
        let _temp_dir = TempDir::new().unwrap();
        let core_config = Config::default();
        let platform = Arc::new(Platform::new(&core_config).await.unwrap());
        let discovery =
            SchemaDiscovery::new(SchemaDiscoveryConfig::default(), platform, core_config)
                .await
                .unwrap();
        if discovery.config.cache_schemas {
            assert!(discovery.schema_cache.read().await.is_empty());
        }
    }

    /// Three distinct text values plus one null: 25% nulls, three uniques, uniform type.
    #[test]
    fn test_column_data_analysis() {
        let mut column_data = ColumnData::new();
        let samples = vec![
            Value::Text("test1".to_string()),
            Value::Text("test2".to_string()),
            Value::Null,
            Value::Text("test3".to_string()),
        ];
        for sample in samples {
            column_data.add_sample_value(sample);
        }

        assert_eq!(column_data.calculate_null_percentage(), 0.25);
        assert_eq!(column_data.unique_values.len(), 3);
        assert!(column_data.calculate_type_confidence() > 0.7);
    }
}