pub mod aggregator;
pub mod cql_parser;
pub mod discovery;
#[cfg(feature = "experimental")]
pub mod json_exporter;
pub mod parser;
pub mod registry;
pub use aggregator::{
AggregatorConfig, LoadErrorType, LoadResult, SchemaAggregator, SchemaLoadError,
SchemaLoadWarning,
};
pub use cql_parser::{
cql_type_to_type_id, extract_table_name, parse_cql_schema, parse_cql_schema_with_visitor,
parse_create_table, table_name_matches,
};
pub use discovery::{
ColumnDefinition, DiscoveryMethod, IndexDefinition, SchemaDiscoveryConfig,
SchemaDiscoveryEngine, SchemaInfo, SchemaMetadata, TableOptions, TypeInfo, UDTDefinition,
ValidationError, ValidationResults, ValidationStatus, ValidationWarning,
};
pub use registry::{
ParsingContext, RegistryStatistics, SchemaChange, SchemaChangeType, SchemaQuery,
SchemaRegistry, SchemaRegistryConfig, SchemaSource, SchemaValidationStatus, SchemaValidator,
SchemaVersion, ValidationReport,
};
pub use parser::SchemaParser;
#[cfg(feature = "experimental")]
pub use json_exporter::{
JsonClusteringKey, JsonColumn, JsonExportConfig, JsonExporter, JsonFormat, JsonIndex,
JsonMetadata, JsonPerformanceMetrics, JsonPrimaryKey, JsonSchema, JsonTable, JsonTableOptions,
JsonUDT, JsonValidationResults,
};
/// Alternate name for [`Column`]; both names refer to the same type.
pub type ColumnSpec = Column;
use crate::error::{Error, Result};
use crate::parser::header::SSTableHeader;
use crate::parser::types::CqlTypeId;
use crate::storage::StorageEngine;
use crate::types::{ComparatorType, UdtTypeDef};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Description of a single table: its identity, key layout and columns.
///
/// The struct is (de)serializable with serde; `TableSchema::from_json` and
/// `TableSchema::to_file` use its JSON form. `TableSchema::validate` checks
/// the invariants between the fields (non-empty names, at least one partition
/// key, contiguous key positions, keys present in `columns`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableSchema {
/// Keyspace the table belongs to; must be non-empty to pass validation.
pub keyspace: String,
/// Table name; must be non-empty to pass validation.
pub table: String,
/// Partition key columns; validation requires at least one.
pub partition_keys: Vec<KeyColumn>,
/// Clustering key columns; may be empty.
pub clustering_keys: Vec<ClusteringColumn>,
/// Column definitions. Validation requires every key column to appear
/// here by name.
pub columns: Vec<Column>,
/// Free-form string-to-string map; defaults to empty when absent from the
/// serialized form. NOTE(review): presumably keyed by column name -- confirm.
#[serde(default)]
pub comments: HashMap<String, String>,
}
/// One component of a table's partition key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyColumn {
/// Column name; expected to match an entry in `TableSchema::columns`.
pub name: String,
/// CQL type as text (e.g. `"bigint"`); serialized under the key `type`.
#[serde(rename = "type")]
pub data_type: String,
/// Zero-based ordinal of this component within the partition key.
pub position: usize,
}
/// One component of a table's clustering key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusteringColumn {
/// Column name; expected to match an entry in `TableSchema::columns`.
pub name: String,
/// CQL type as text; serialized under the key `type`.
#[serde(rename = "type")]
pub data_type: String,
/// Zero-based ordinal of this component within the clustering key.
pub position: usize,
/// Sort direction; falls back to `ClusteringOrder::Asc` when absent from
/// the serialized form.
#[serde(default)]
pub order: ClusteringOrder,
}
/// Sort direction of a clustering column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ClusteringOrder {
/// Ascending order; this is the `Default` value.
#[default]
Asc,
/// Descending order.
Desc,
}
impl From<&str> for ClusteringOrder {
fn from(s: &str) -> Self {
match s.to_uppercase().as_str() {
"DESC" => ClusteringOrder::Desc,
_ => ClusteringOrder::Asc,
}
}
}
impl std::fmt::Display for ClusteringOrder {
    /// Writes the CQL keyword for the order: `ASC` or `DESC`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            ClusteringOrder::Asc => "ASC",
            ClusteringOrder::Desc => "DESC",
        };
        f.write_str(keyword)
    }
}
/// Definition of a single table column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Column {
/// Column name.
pub name: String,
/// CQL type as text (parsed with `CqlType::parse` during schema
/// validation); serialized under the key `type`.
#[serde(rename = "type")]
pub data_type: String,
/// Whether the column may be null; `false` when absent from the serialized form.
#[serde(default)]
pub nullable: bool,
/// Optional default value as arbitrary JSON; `None` when absent.
#[serde(default)]
pub default: Option<serde_json::Value>,
/// Whether the column is declared static; `false` when absent.
#[serde(default)]
pub is_static: bool,
}
/// Structured representation of a CQL data type.
///
/// Values are normally produced by `CqlType::parse` from the textual type
/// stored in a schema; nested types are represented recursively.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum CqlType {
// --- Scalar types ---
Boolean,
TinyInt,
SmallInt,
Int,
BigInt,
Counter,
Float,
Double,
Decimal,
Text,
Ascii,
// Note: `CqlType::parse` maps the text "varchar" to `Text`, not to this variant.
Varchar,
Blob,
Timestamp,
Date,
Time,
Uuid,
TimeUuid,
Inet,
Duration,
Varint,
// --- Parameterized types ---
/// `list<T>` with the element type.
List(Box<CqlType>),
/// `set<T>` with the element type.
Set(Box<CqlType>),
/// `map<K, V>` with key and value types.
Map(Box<CqlType>, Box<CqlType>),
/// `tuple<...>` with the component types in order.
Tuple(Vec<CqlType>),
// `Udt(name, fields)`: a user-defined type referenced by name, optionally
// carrying inline `(field name, field type)` pairs (callers in this file
// often pass an empty list).
// `Frozen(inner)`: a `frozen<...>` wrapper around another type.
Udt(String, Vec<(String, CqlType)>), Frozen(Box<CqlType>),
/// Any type name that is not otherwise recognized, kept as text.
Custom(String),
}
/// In-memory registry of user-defined type (UDT) definitions.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UdtRegistry {
// Outer map is keyed by keyspace name, inner map by UDT name.
udts: HashMap<String, HashMap<String, UdtTypeDef>>,
}
impl UdtRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            udts: HashMap::new(),
        }
    }

    /// Creates a registry pre-populated with the built-in UDTs of the
    /// `system` keyspace (`address`, `person`, `contact_info`).
    pub fn with_cassandra5_defaults() -> Self {
        let mut registry = Self::new();
        registry.load_cassandra5_system_udts();
        registry
    }

    /// Registers `udt_def` under its own keyspace and name, replacing any
    /// previous definition with the same name. No validation is performed;
    /// see [`UdtRegistry::register_udt_with_validation`] for the checked variant.
    pub fn register_udt(&mut self, udt_def: UdtTypeDef) {
        let keyspace_udts = self.udts.entry(udt_def.keyspace.clone()).or_default();
        keyspace_udts.insert(udt_def.name.clone(), udt_def);
    }

    /// Looks up a UDT by keyspace and name.
    pub fn get_udt(&self, keyspace: &str, name: &str) -> Option<&UdtTypeDef> {
        self.udts.get(keyspace)?.get(name)
    }

    /// Returns all UDTs registered in `keyspace`, keyed by name, or `None`
    /// when the keyspace has no entry.
    pub fn get_keyspace_udts(&self, keyspace: &str) -> Option<&HashMap<String, UdtTypeDef>> {
        self.udts.get(keyspace)
    }

    /// Returns the names of all UDTs in `keyspace` (empty when the keyspace
    /// is unknown). The order follows hash-map iteration and is unspecified.
    pub fn list_udt_names(&self, keyspace: &str) -> Vec<&str> {
        self.udts
            .get(keyspace)
            .map(|udts| udts.keys().map(|s| s.as_str()).collect())
            .unwrap_or_default()
    }

    /// Returns `true` when a UDT called `name` is registered in `keyspace`.
    pub fn contains_udt(&self, keyspace: &str, name: &str) -> bool {
        self.get_udt(keyspace, name).is_some()
    }

    /// Removes and returns the UDT `name` from `keyspace`, if present.
    pub fn remove_udt(&mut self, keyspace: &str, name: &str) -> Option<UdtTypeDef> {
        self.udts.get_mut(keyspace)?.remove(name)
    }

    /// Drops every UDT registered in `keyspace`.
    pub fn clear_keyspace(&mut self, keyspace: &str) {
        self.udts.remove(keyspace);
    }

    /// Total number of registered UDTs across all keyspaces.
    pub fn total_udts(&self) -> usize {
        self.udts.values().map(|udts| udts.len()).sum()
    }

    /// Registers the built-in `system.address`, `system.person` and
    /// `system.contact_info` definitions.
    fn load_cassandra5_system_udts(&mut self) {
        let address_udt = UdtTypeDef::new("system".to_string(), "address".to_string())
            .with_field("street".to_string(), CqlType::Text, true)
            .with_field("street2".to_string(), CqlType::Text, true)
            .with_field("city".to_string(), CqlType::Text, true)
            .with_field("state".to_string(), CqlType::Text, true)
            .with_field("zip_code".to_string(), CqlType::Text, true)
            .with_field("country".to_string(), CqlType::Text, true)
            .with_field(
                "coordinates".to_string(),
                CqlType::Tuple(vec![CqlType::Double, CqlType::Double]),
                true,
            );
        self.register_udt(address_udt);

        let person_udt = UdtTypeDef::new("system".to_string(), "person".to_string())
            .with_field("id".to_string(), CqlType::Uuid, false)
            .with_field("first_name".to_string(), CqlType::Text, false)
            .with_field("last_name".to_string(), CqlType::Text, false)
            .with_field("middle_name".to_string(), CqlType::Text, true)
            .with_field("age".to_string(), CqlType::Int, true)
            .with_field("email".to_string(), CqlType::Text, true)
            .with_field(
                "phone_numbers".to_string(),
                CqlType::Set(Box::new(CqlType::Text)),
                true,
            )
            .with_field(
                "addresses".to_string(),
                CqlType::List(Box::new(CqlType::Udt("address".to_string(), vec![]))),
                true,
            )
            .with_field(
                "metadata".to_string(),
                CqlType::Map(Box::new(CqlType::Text), Box::new(CqlType::Text)),
                true,
            );
        self.register_udt(person_udt);

        let contact_info_udt = UdtTypeDef::new("system".to_string(), "contact_info".to_string())
            .with_field(
                "person".to_string(),
                CqlType::Udt("person".to_string(), vec![]),
                false,
            )
            .with_field(
                "primary_address".to_string(),
                CqlType::Udt("address".to_string(), vec![]),
                true,
            )
            .with_field(
                "emergency_contacts".to_string(),
                CqlType::List(Box::new(CqlType::Udt("person".to_string(), vec![]))),
                true,
            )
            .with_field("last_updated".to_string(), CqlType::Timestamp, true);
        self.register_udt(contact_info_udt);
    }

    /// Looks up a UDT and verifies that every UDT its fields reference is
    /// registered in the same keyspace.
    ///
    /// # Errors
    /// Returns a schema error when the UDT itself or any directly referenced
    /// UDT is missing from `keyspace`.
    pub fn resolve_udt_with_dependencies(
        &self,
        keyspace: &str,
        name: &str,
    ) -> crate::Result<&UdtTypeDef> {
        let udt = self.get_udt(keyspace, name).ok_or_else(|| {
            crate::Error::schema(format!(
                "UDT '{}' not found in keyspace '{}'",
                name, keyspace
            ))
        })?;
        for field in &udt.fields {
            self.validate_field_type_dependencies(&field.field_type, keyspace)?;
        }
        Ok(udt)
    }

    /// Recursively checks that every UDT named inside `field_type` exists in
    /// `keyspace`. Inline field lists carried by `CqlType::Udt` are not
    /// inspected; only the referenced name is checked.
    fn validate_field_type_dependencies(
        &self,
        field_type: &CqlType,
        keyspace: &str,
    ) -> crate::Result<()> {
        match field_type {
            CqlType::Udt(udt_name, _) => {
                if !self.contains_udt(keyspace, udt_name) {
                    return Err(crate::Error::schema(format!(
                        "UDT dependency '{}' not found in keyspace '{}'",
                        udt_name, keyspace
                    )));
                }
            }
            CqlType::List(inner) | CqlType::Set(inner) | CqlType::Frozen(inner) => {
                self.validate_field_type_dependencies(inner, keyspace)?;
            }
            CqlType::Map(key_type, value_type) => {
                self.validate_field_type_dependencies(key_type, keyspace)?;
                self.validate_field_type_dependencies(value_type, keyspace)?;
            }
            CqlType::Tuple(field_types) => {
                for tuple_field_type in field_types {
                    self.validate_field_type_dependencies(tuple_field_type, keyspace)?;
                }
            }
            _ => {}
        }
        Ok(())
    }

    /// Returns the UDTs in `keyspace` whose fields directly reference
    /// `udt_name` (the UDT itself is excluded).
    pub fn get_dependent_udts(&self, keyspace: &str, udt_name: &str) -> Vec<&UdtTypeDef> {
        let mut dependents = Vec::new();
        if let Some(keyspace_udts) = self.udts.get(keyspace) {
            for udt in keyspace_udts.values() {
                if udt.name == udt_name {
                    continue;
                }
                if self.udt_depends_on(udt, udt_name) {
                    dependents.push(udt);
                }
            }
        }
        dependents
    }

    /// `true` when any field of `udt` directly references `target_udt`.
    fn udt_depends_on(&self, udt: &UdtTypeDef, target_udt: &str) -> bool {
        udt.fields
            .iter()
            .any(|field| self.field_type_depends_on(&field.field_type, target_udt))
    }

    /// `true` when `field_type` names `target_udt` anywhere in its nesting.
    // `self` is only threaded through the recursion; kept as a method for
    // symmetry with the other helpers.
    #[allow(clippy::only_used_in_recursion)]
    fn field_type_depends_on(&self, field_type: &CqlType, target_udt: &str) -> bool {
        match field_type {
            CqlType::Udt(udt_name, _) => udt_name == target_udt,
            CqlType::List(inner) | CqlType::Set(inner) | CqlType::Frozen(inner) => {
                self.field_type_depends_on(inner, target_udt)
            }
            CqlType::Map(key_type, value_type) => {
                self.field_type_depends_on(key_type, target_udt)
                    || self.field_type_depends_on(value_type, target_udt)
            }
            CqlType::Tuple(field_types) => field_types
                .iter()
                .any(|ft| self.field_type_depends_on(ft, target_udt)),
            _ => false,
        }
    }

    /// Registers `udt_def` after checking that all referenced UDTs exist in
    /// its keyspace and that the registration would not introduce a
    /// dependency cycle (direct or through already-registered UDTs).
    ///
    /// # Errors
    /// Returns a schema error when a dependency is missing or a cycle would
    /// be created; the registry is left unchanged in that case.
    pub fn register_udt_with_validation(&mut self, udt_def: UdtTypeDef) -> crate::Result<()> {
        for field in &udt_def.fields {
            self.validate_field_type_dependencies(&field.field_type, &udt_def.keyspace)?;
        }
        if self.would_create_circular_dependency(&udt_def) {
            return Err(crate::Error::schema(format!(
                "Registering UDT '{}' would create circular dependency",
                udt_def.name
            )));
        }
        self.register_udt(udt_def);
        Ok(())
    }

    /// Returns `true` when `udt_def` can reach its own name by following UDT
    /// references through the definitions registered in its keyspace.
    ///
    /// A direct self-reference is the simplest case. Indirect cycles arise
    /// when an existing UDT is re-registered so that it depends on one of its
    /// own dependents (a -> b -> a).
    fn would_create_circular_dependency(&self, udt_def: &UdtTypeDef) -> bool {
        let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
        let mut pending: Vec<&str> = Vec::new();
        for field in &udt_def.fields {
            Self::collect_udt_references(&field.field_type, &mut pending);
        }

        while let Some(name) = pending.pop() {
            if name == udt_def.name {
                return true;
            }
            // Each registered UDT only needs to be expanded once.
            if !visited.insert(name) {
                continue;
            }
            if let Some(dependency) = self.get_udt(&udt_def.keyspace, name) {
                for field in &dependency.fields {
                    Self::collect_udt_references(&field.field_type, &mut pending);
                }
            }
        }
        false
    }

    /// Appends the name of every UDT referenced inside `field_type` to `out`,
    /// using the same traversal rules as `field_type_depends_on`.
    fn collect_udt_references<'a>(field_type: &'a CqlType, out: &mut Vec<&'a str>) {
        match field_type {
            CqlType::Udt(udt_name, _) => out.push(udt_name.as_str()),
            CqlType::List(inner) | CqlType::Set(inner) | CqlType::Frozen(inner) => {
                Self::collect_udt_references(inner, out);
            }
            CqlType::Map(key_type, value_type) => {
                Self::collect_udt_references(key_type, out);
                Self::collect_udt_references(value_type, out);
            }
            CqlType::Tuple(field_types) => {
                for tuple_field_type in field_types {
                    Self::collect_udt_references(tuple_field_type, out);
                }
            }
            _ => {}
        }
    }

    /// Renders every UDT in `keyspace` as a `CREATE TYPE` statement.
    ///
    /// Statements are returned in hash-map iteration order, which is
    /// unspecified; an unknown keyspace yields an empty vector.
    pub fn export_definitions(&self, keyspace: &str) -> Vec<String> {
        let mut definitions = Vec::new();
        if let Some(keyspace_udts) = self.udts.get(keyspace) {
            for udt in keyspace_udts.values() {
                let mut def = format!("CREATE TYPE {}.{} (\n", keyspace, udt.name);
                for (i, field) in udt.fields.iter().enumerate() {
                    if i > 0 {
                        def.push_str(",\n");
                    }
                    def.push_str(&format!(
                        "  {} {}",
                        field.name,
                        self.format_cql_type(&field.field_type)
                    ));
                }
                def.push_str("\n);");
                definitions.push(def);
            }
        }
        definitions
    }

    /// Formats a [`CqlType`] as CQL type text (`Varchar` is rendered as `text`).
    // `self` is only threaded through the recursion.
    #[allow(clippy::only_used_in_recursion)]
    fn format_cql_type(&self, cql_type: &CqlType) -> String {
        match cql_type {
            CqlType::Boolean => "boolean".to_string(),
            CqlType::TinyInt => "tinyint".to_string(),
            CqlType::SmallInt => "smallint".to_string(),
            CqlType::Int => "int".to_string(),
            CqlType::BigInt => "bigint".to_string(),
            CqlType::Counter => "counter".to_string(),
            CqlType::Float => "float".to_string(),
            CqlType::Double => "double".to_string(),
            CqlType::Text | CqlType::Varchar => "text".to_string(),
            CqlType::Ascii => "ascii".to_string(),
            CqlType::Blob => "blob".to_string(),
            CqlType::Timestamp => "timestamp".to_string(),
            CqlType::Date => "date".to_string(),
            CqlType::Time => "time".to_string(),
            CqlType::Uuid => "uuid".to_string(),
            CqlType::TimeUuid => "timeuuid".to_string(),
            CqlType::Inet => "inet".to_string(),
            CqlType::Duration => "duration".to_string(),
            CqlType::Varint => "varint".to_string(),
            CqlType::Decimal => "decimal".to_string(),
            CqlType::List(inner) => format!("list<{}>", self.format_cql_type(inner)),
            CqlType::Set(inner) => format!("set<{}>", self.format_cql_type(inner)),
            CqlType::Map(key, value) => format!(
                "map<{}, {}>",
                self.format_cql_type(key),
                self.format_cql_type(value)
            ),
            CqlType::Udt(name, _) => name.clone(),
            CqlType::Tuple(types) => {
                let type_strs: Vec<String> =
                    types.iter().map(|t| self.format_cql_type(t)).collect();
                format!("tuple<{}>", type_strs.join(", "))
            }
            CqlType::Frozen(inner) => format!("frozen<{}>", self.format_cql_type(inner)),
            CqlType::Custom(name) => name.clone(),
        }
    }
}
impl TableSchema {
    /// Builds a schema from the column metadata of an SSTable header.
    ///
    /// Primary-key columns are split into partition and clustering keys,
    /// ordered by their `key_position` and renumbered from 0. Clustering
    /// order is always `Asc`. Every header column is listed in `columns`,
    /// with `nullable` set for non-key columns and `is_static` taken from
    /// the header.
    ///
    /// # Errors
    /// Returns a schema error when a key column has no `key_position`, when
    /// the header has no partition key, or when the resulting schema fails
    /// [`TableSchema::validate`].
    pub fn from_sstable_header(header: &SSTableHeader) -> Result<Self> {
        let mut partition_cols = Vec::new();
        let mut clustering_cols = Vec::new();
        for col in header.columns.iter().filter(|c| c.is_primary_key) {
            if col.is_clustering {
                clustering_cols.push(col);
            } else {
                partition_cols.push(col);
            }
        }

        // Every key column must carry its ordinal; partition keys are
        // reported before clustering keys.
        if let Some(col) = partition_cols.iter().find(|c| c.key_position.is_none()) {
            return Err(Error::schema(format!(
                "Partition key column '{}' missing key_position in SSTable header",
                col.name
            )));
        }
        if let Some(col) = clustering_cols.iter().find(|c| c.key_position.is_none()) {
            return Err(Error::schema(format!(
                "Clustering key column '{}' missing key_position in SSTable header",
                col.name
            )));
        }

        // All positions are `Some` here, so ordering the options orders the
        // positions without needing to unwrap.
        partition_cols.sort_by_key(|c| c.key_position);
        clustering_cols.sort_by_key(|c| c.key_position);

        let partition_keys: Vec<KeyColumn> = partition_cols
            .iter()
            .enumerate()
            .map(|(position, col)| KeyColumn {
                name: col.name.clone(),
                data_type: col.column_type.clone(),
                position,
            })
            .collect();
        let clustering_keys: Vec<ClusteringColumn> = clustering_cols
            .iter()
            .enumerate()
            .map(|(position, col)| ClusteringColumn {
                name: col.name.clone(),
                data_type: col.column_type.clone(),
                position,
                order: ClusteringOrder::Asc,
            })
            .collect();
        let columns: Vec<Column> = header
            .columns
            .iter()
            .map(|col| Column {
                name: col.name.clone(),
                data_type: col.column_type.clone(),
                nullable: !col.is_primary_key,
                default: None,
                // Propagate the header's flag instead of discarding it.
                is_static: col.is_static,
            })
            .collect();

        if partition_keys.is_empty() {
            return Err(Error::schema(
                "No partition keys found in SSTable header".to_string(),
            ));
        }

        let schema = TableSchema {
            keyspace: header.keyspace.clone(),
            table: header.table_name.clone(),
            partition_keys,
            clustering_keys,
            columns,
            comments: HashMap::new(),
        };
        schema.validate()?;
        Ok(schema)
    }

    /// Reads a JSON schema from `path` and validates it.
    ///
    /// # Errors
    /// Returns a schema error when the file cannot be read, is not valid
    /// JSON for this type, or fails validation.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path)
            .map_err(|e| Error::schema(format!("Failed to read schema file: {}", e)))?;
        Self::from_json(&content)
    }

    /// Parses a JSON schema document and validates it.
    ///
    /// # Errors
    /// Returns a schema error on malformed JSON or a failed validation.
    pub fn from_json(json: &str) -> Result<Self> {
        let schema: TableSchema = serde_json::from_str(json)
            .map_err(|e| Error::schema(format!("Invalid JSON schema: {}", e)))?;
        schema.validate()?;
        Ok(schema)
    }

    /// Writes the schema to `path` as pretty-printed JSON.
    ///
    /// # Errors
    /// Returns a serialization error when encoding fails and a schema error
    /// when the file cannot be written.
    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string_pretty(self)
            .map_err(|e| Error::serialization(format!("Failed to serialize schema: {}", e)))?;
        fs::write(path, json)
            .map_err(|e| Error::schema(format!("Failed to write schema file: {}", e)))?;
        Ok(())
    }

    /// Checks the structural invariants of the schema.
    ///
    /// # Errors
    /// Returns a schema error when the keyspace or table name is empty, there
    /// is no partition key, key positions are not exactly `0..n`, a column
    /// type cannot be parsed, or a key column is missing from `columns`.
    pub fn validate(&self) -> Result<()> {
        if self.keyspace.is_empty() {
            return Err(Error::schema("Keyspace name cannot be empty".to_string()));
        }
        if self.table.is_empty() {
            return Err(Error::schema("Table name cannot be empty".to_string()));
        }
        if self.partition_keys.is_empty() {
            return Err(Error::schema(
                "Table must have at least one partition key".to_string(),
            ));
        }

        Self::ensure_contiguous_positions(
            "Partition",
            self.partition_keys.iter().map(|k| k.position),
        )?;
        Self::ensure_contiguous_positions(
            "Clustering",
            self.clustering_keys.iter().map(|k| k.position),
        )?;

        for column in &self.columns {
            CqlType::parse(&column.data_type).map_err(|e| {
                Error::schema(format!(
                    "Invalid data type '{}' for column '{}': {}",
                    column.data_type, column.name, e
                ))
            })?;
        }

        for key in &self.partition_keys {
            if self.get_column(&key.name).is_none() {
                return Err(Error::schema(format!(
                    "Partition key '{}' not found in columns list",
                    key.name
                )));
            }
        }
        for key in &self.clustering_keys {
            if self.get_column(&key.name).is_none() {
                return Err(Error::schema(format!(
                    "Clustering key '{}' not found in columns list",
                    key.name
                )));
            }
        }
        Ok(())
    }

    /// Verifies that `positions`, once sorted, are exactly `0, 1, 2, ...`.
    /// `kind` ("Partition" / "Clustering") only feeds the error message.
    fn ensure_contiguous_positions(
        kind: &str,
        positions: impl Iterator<Item = usize>,
    ) -> Result<()> {
        let mut sorted: Vec<usize> = positions.collect();
        sorted.sort_unstable();
        for (expected, &actual) in sorted.iter().enumerate() {
            if actual != expected {
                return Err(Error::schema(format!(
                    "{} key positions must be contiguous starting from 0, found gap at position {}",
                    kind, expected
                )));
            }
        }
        Ok(())
    }

    /// Returns the first column called `name`, if any.
    pub fn get_column(&self, name: &str) -> Option<&Column> {
        self.columns.iter().find(|c| c.name == name)
    }

    /// `true` when `name` is one of the partition key columns.
    pub fn is_partition_key(&self, name: &str) -> bool {
        self.partition_keys.iter().any(|k| k.name == name)
    }

    /// `true` when `name` is one of the clustering key columns.
    pub fn is_clustering_key(&self, name: &str) -> bool {
        self.clustering_keys.iter().any(|k| k.name == name)
    }

    /// Partition key columns sorted by `position`.
    pub fn ordered_partition_keys(&self) -> Vec<&KeyColumn> {
        let mut keys = self.partition_keys.iter().collect::<Vec<_>>();
        keys.sort_by_key(|k| k.position);
        keys
    }

    /// Clustering key columns sorted by `position`.
    pub fn ordered_clustering_keys(&self) -> Vec<&ClusteringColumn> {
        let mut keys = self.clustering_keys.iter().collect::<Vec<_>>();
        keys.sort_by_key(|k| k.position);
        keys
    }

    /// Returns the comparator for the column called `column_name`.
    ///
    /// # Errors
    /// Returns a schema error when the column does not exist, or propagates
    /// the error from parsing its type or deriving the comparator.
    pub fn get_column_comparator(&self, column_name: &str) -> Result<ComparatorType> {
        let column = self
            .get_column(column_name)
            .ok_or_else(|| Error::Schema(format!("Column '{}' not found", column_name)))?;
        let cql_type = CqlType::parse(&column.data_type)?;
        ComparatorType::from_cql_type(&cql_type)
    }

    /// Returns a comparator for every column, keyed by column name.
    ///
    /// # Errors
    /// Fails on the first column whose type cannot be parsed or converted.
    pub fn get_all_comparators(&self) -> Result<HashMap<String, ComparatorType>> {
        let mut comparators = HashMap::new();
        for column in &self.columns {
            let cql_type = CqlType::parse(&column.data_type)?;
            let comparator = ComparatorType::from_cql_type(&cql_type)?;
            comparators.insert(column.name.clone(), comparator);
        }
        Ok(comparators)
    }

    /// Comparators of the partition key columns, in key order.
    ///
    /// # Errors
    /// Fails on the first key whose type cannot be parsed or converted.
    pub fn get_partition_key_comparators(&self) -> Result<Vec<ComparatorType>> {
        self.ordered_partition_keys()
            .into_iter()
            .map(|key| {
                let cql_type = CqlType::parse(&key.data_type)?;
                ComparatorType::from_cql_type(&cql_type)
            })
            .collect()
    }

    /// Comparators of the clustering key columns, in key order.
    ///
    /// # Errors
    /// Fails on the first key whose type cannot be parsed or converted.
    pub fn get_clustering_key_comparators(&self) -> Result<Vec<ComparatorType>> {
        self.ordered_clustering_keys()
            .into_iter()
            .map(|key| {
                let cql_type = CqlType::parse(&key.data_type)?;
                ComparatorType::from_cql_type(&cql_type)
            })
            .collect()
    }

    /// Checks whether the column's type is compatible with `expected_type`
    /// (given as CQL type text), by comparing their comparators structurally.
    ///
    /// # Errors
    /// Fails when the column is unknown or either type cannot be converted
    /// into a comparator.
    pub fn is_column_type_compatible(
        &self,
        column_name: &str,
        expected_type: &str,
    ) -> Result<bool> {
        let column_comparator = self.get_column_comparator(column_name)?;
        let expected_cql_type = CqlType::parse(expected_type)?;
        let expected_comparator = ComparatorType::from_cql_type(&expected_cql_type)?;
        Ok(self.comparators_are_compatible(&column_comparator, &expected_comparator))
    }

    /// Structural compatibility of two comparators: same variant, and for
    /// nested variants, compatible components. UDTs match on type name and
    /// keyspace; custom comparators match on name.
    // `self` is only threaded through the recursion.
    #[allow(clippy::only_used_in_recursion)]
    fn comparators_are_compatible(&self, left: &ComparatorType, right: &ComparatorType) -> bool {
        match (left, right) {
            (ComparatorType::Boolean, ComparatorType::Boolean)
            | (ComparatorType::TinyInt, ComparatorType::TinyInt)
            | (ComparatorType::SmallInt, ComparatorType::SmallInt)
            | (ComparatorType::Int, ComparatorType::Int)
            | (ComparatorType::BigInt, ComparatorType::BigInt)
            | (ComparatorType::Float32, ComparatorType::Float32)
            | (ComparatorType::Float, ComparatorType::Float)
            | (ComparatorType::Text, ComparatorType::Text)
            | (ComparatorType::Blob, ComparatorType::Blob)
            | (ComparatorType::Timestamp, ComparatorType::Timestamp)
            | (ComparatorType::Uuid, ComparatorType::Uuid)
            | (ComparatorType::Json, ComparatorType::Json) => true,
            (ComparatorType::List(l_elem), ComparatorType::List(r_elem)) => {
                self.comparators_are_compatible(l_elem, r_elem)
            }
            (ComparatorType::Set(l_elem), ComparatorType::Set(r_elem)) => {
                self.comparators_are_compatible(l_elem, r_elem)
            }
            (ComparatorType::Map(l_key, l_val), ComparatorType::Map(r_key, r_val)) => {
                self.comparators_are_compatible(l_key, r_key)
                    && self.comparators_are_compatible(l_val, r_val)
            }
            (ComparatorType::Tuple(l_fields), ComparatorType::Tuple(r_fields)) => {
                l_fields.len() == r_fields.len()
                    && l_fields
                        .iter()
                        .zip(r_fields.iter())
                        .all(|(l, r)| self.comparators_are_compatible(l, r))
            }
            (
                ComparatorType::Udt {
                    type_name: l_name,
                    keyspace: l_ks,
                    ..
                },
                ComparatorType::Udt {
                    type_name: r_name,
                    keyspace: r_ks,
                    ..
                },
            ) => l_name == r_name && l_ks == r_ks,
            (ComparatorType::Frozen(l_inner), ComparatorType::Frozen(r_inner)) => {
                self.comparators_are_compatible(l_inner, r_inner)
            }
            (ComparatorType::Custom(l_name), ComparatorType::Custom(r_name)) => l_name == r_name,
            _ => false,
        }
    }

    /// Test-only constructor: a table with a single `int` partition key `id`.
    #[cfg(test)]
    pub fn new_for_testing(keyspace: &str, table: &str) -> Self {
        Self {
            keyspace: keyspace.to_string(),
            table: table.to_string(),
            partition_keys: vec![KeyColumn {
                name: "id".to_string(),
                data_type: "int".to_string(),
                position: 0,
            }],
            clustering_keys: vec![],
            columns: vec![Column {
                name: "id".to_string(),
                data_type: "int".to_string(),
                nullable: false,
                default: None,
                is_static: false,
            }],
            comments: HashMap::new(),
        }
    }
}
impl CqlType {
    /// Splits a comma-separated list of type expressions at nesting depth 0,
    /// so commas inside `<...>` do not split. Parts are trimmed and empty
    /// parts are dropped.
    ///
    /// # Errors
    /// Returns a schema error when angle brackets are unbalanced.
    fn split_top_level_types(type_str: &str) -> Result<Vec<&str>> {
        let mut parts = Vec::new();
        let mut depth = 0usize;
        let mut start = 0usize;
        for (index, ch) in type_str.char_indices() {
            match ch {
                '<' => depth += 1,
                '>' => {
                    if depth == 0 {
                        return Err(Error::schema(format!(
                            "Invalid nested type syntax: {}",
                            type_str
                        )));
                    }
                    depth -= 1;
                }
                ',' if depth == 0 => {
                    parts.push(type_str[start..index].trim());
                    start = index + ch.len_utf8();
                }
                _ => {}
            }
        }
        if depth != 0 {
            return Err(Error::schema(format!(
                "Unbalanced nested type syntax: {}",
                type_str
            )));
        }
        parts.push(type_str[start..].trim());
        Ok(parts.into_iter().filter(|part| !part.is_empty()).collect())
    }

    /// Returns the text between `prefix` (e.g. `"list<"`) and a trailing `>`,
    /// or `None` when `type_str` does not have that shape.
    fn strip_wrapper<'a>(type_str: &'a str, prefix: &str) -> Option<&'a str> {
        type_str.strip_prefix(prefix)?.strip_suffix('>')
    }

    /// Maps an already-lowercased primitive type name (or alias) to its
    /// variant. This is the single source of truth for what `parse` treats
    /// as a primitive.
    fn parse_primitive(lowercase: &str) -> Option<Self> {
        let parsed = match lowercase {
            "boolean" | "bool" => CqlType::Boolean,
            "tinyint" => CqlType::TinyInt,
            "smallint" => CqlType::SmallInt,
            "int" | "integer" => CqlType::Int,
            "bigint" | "long" => CqlType::BigInt,
            "counter" => CqlType::Counter,
            "float" => CqlType::Float,
            "double" => CqlType::Double,
            "decimal" => CqlType::Decimal,
            "text" | "varchar" => CqlType::Text,
            "ascii" => CqlType::Ascii,
            "blob" => CqlType::Blob,
            "timestamp" => CqlType::Timestamp,
            "date" => CqlType::Date,
            "time" => CqlType::Time,
            "uuid" => CqlType::Uuid,
            "timeuuid" => CqlType::TimeUuid,
            "inet" => CqlType::Inet,
            "duration" => CqlType::Duration,
            "varint" => CqlType::Varint,
            _ => return None,
        };
        Some(parsed)
    }

    /// Parses CQL type text into a [`CqlType`].
    ///
    /// Recognizes `frozen<>`, `list<>`, `set<>`, `map<,>` and `tuple<...>`
    /// wrappers (lowercase only) and primitive names in any letter case.
    /// Everything else becomes `Custom`: names made only of alphanumerics,
    /// `_` and `.` that are not entirely ASCII-lowercase are tagged as
    /// `Custom("udt:<name>")`; all other text is stored verbatim.
    ///
    /// # Errors
    /// Returns a schema error for a `map<...>` that does not have exactly two
    /// type arguments, or for unbalanced angle brackets inside `map`/`tuple`.
    pub fn parse(type_str: &str) -> Result<Self> {
        let type_str = type_str.trim();

        if let Some(inner) = Self::strip_wrapper(type_str, "frozen<") {
            return Ok(CqlType::Frozen(Box::new(Self::parse(inner)?)));
        }
        if let Some(inner) = Self::strip_wrapper(type_str, "list<") {
            return Ok(CqlType::List(Box::new(Self::parse(inner)?)));
        }
        if let Some(inner) = Self::strip_wrapper(type_str, "set<") {
            return Ok(CqlType::Set(Box::new(Self::parse(inner)?)));
        }
        if let Some(inner) = Self::strip_wrapper(type_str, "map<") {
            let parts = Self::split_top_level_types(inner)?;
            if parts.len() != 2 {
                return Err(Error::schema(format!("Invalid map type: {}", type_str)));
            }
            return Ok(CqlType::Map(
                Box::new(Self::parse(parts[0])?),
                Box::new(Self::parse(parts[1])?),
            ));
        }
        if let Some(inner) = Self::strip_wrapper(type_str, "tuple<") {
            let types = Self::split_top_level_types(inner)?
                .into_iter()
                .map(Self::parse)
                .collect::<Result<Vec<_>>>()?;
            return Ok(CqlType::Tuple(types));
        }

        // Primitives are matched case-insensitively and take precedence over
        // the UDT heuristic below, so e.g. "VARINT" is never mistaken for a
        // user-defined type.
        if let Some(primitive) = Self::parse_primitive(&type_str.to_lowercase()) {
            return Ok(primitive);
        }

        let identifier_like = type_str
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '.');
        let all_ascii_lowercase = type_str.chars().all(|c| c.is_ascii_lowercase());
        if identifier_like && !all_ascii_lowercase {
            return Ok(CqlType::Custom(format!("udt:{}", type_str)));
        }
        Ok(CqlType::Custom(type_str.to_string()))
    }

    /// Returns the size in bytes this module assumes for fixed-width types,
    /// or `None` for variable-length, nested and custom types. `Frozen`
    /// defers to its inner type.
    // NOTE(review): `Inet` is reported as 16 bytes here; confirm callers do
    // not rely on this for 4-byte IPv4 values.
    pub fn fixed_size(&self) -> Option<usize> {
        match self {
            CqlType::Boolean => Some(1),
            CqlType::TinyInt => Some(1),
            CqlType::SmallInt => Some(2),
            CqlType::Int => Some(4),
            CqlType::BigInt => Some(8),
            CqlType::Counter => Some(8),
            CqlType::Float => Some(4),
            CqlType::Double => Some(8),
            CqlType::Timestamp => Some(8),
            CqlType::Date => Some(4),
            CqlType::Time => Some(8),
            CqlType::Uuid | CqlType::TimeUuid => Some(16),
            CqlType::Inet => Some(16),
            CqlType::Text
            | CqlType::Ascii
            | CqlType::Varchar
            | CqlType::Blob
            | CqlType::Decimal
            | CqlType::Duration
            | CqlType::Varint => None,
            CqlType::List(_)
            | CqlType::Set(_)
            | CqlType::Map(_, _)
            | CqlType::Tuple(_)
            | CqlType::Udt(_, _) => None,
            CqlType::Frozen(inner) => inner.fixed_size(),
            CqlType::Custom(_) => None,
        }
    }

    /// `true` for `List`, `Set` and `Map` (a `Frozen` wrapper is not unwrapped).
    pub fn is_collection(&self) -> bool {
        matches!(
            self,
            CqlType::List(_) | CqlType::Set(_) | CqlType::Map(_, _)
        )
    }
}
/// Async facade over the table schemas and UDT definitions known to the process.
#[derive(Debug)]
pub struct SchemaManager {
// Storage engine handle; held but not read by any method in this file.
#[allow(dead_code)]
storage: Arc<StorageEngine>,
// Cached table schemas. Keys are `"keyspace.table"` when registered from CQL
// or a registry, and the bare table name when created by `load_schema`.
schemas: Arc<RwLock<HashMap<String, TableSchema>>>,
// Shared UDT registry; may be the same instance a `SchemaRegistry` hands out.
pub(crate) udt_registry: Arc<RwLock<UdtRegistry>>,
}
impl SchemaManager {
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let config = Config::default();
let platform = Arc::new(crate::platform::Platform::new(&config).await?);
let storage = Arc::new(
StorageEngine::open(
path.as_ref(),
&config,
platform,
#[cfg(feature = "state_machine")]
None,
)
.await?,
);
Ok(Self {
storage,
schemas: Arc::new(RwLock::new(HashMap::new())),
udt_registry: Arc::new(RwLock::new(UdtRegistry::new())),
})
}
pub async fn new_with_storage(storage: Arc<StorageEngine>, _config: &Config) -> Result<Self> {
let manager = Self {
storage,
schemas: Arc::new(RwLock::new(HashMap::new())),
udt_registry: Arc::new(RwLock::new(UdtRegistry::new())),
};
manager.load_default_udts().await;
Ok(manager)
}
pub async fn new_with_registry(
storage: Arc<StorageEngine>,
registry: Arc<tokio::sync::RwLock<registry::SchemaRegistry>>,
_config: &Config,
) -> Result<Self> {
let (loaded_schemas, udt_registry) = {
let registry_guard = registry.read().await;
let schemas = registry_guard.list_schemas(None).await?;
let udt_reg = registry_guard.get_udt_registry();
(schemas, udt_reg)
};
let mut schemas_map = HashMap::new();
for schema in loaded_schemas {
let table_id = format!("{}.{}", schema.keyspace, schema.table);
schemas_map.insert(table_id, schema);
}
let manager = Self {
storage,
schemas: Arc::new(RwLock::new(schemas_map)),
udt_registry,
};
Ok(manager)
}
async fn load_default_udts(&self) {
let address_udt = UdtTypeDef::new("test_keyspace".to_string(), "address".to_string())
.with_field("street".to_string(), CqlType::Text, true)
.with_field("city".to_string(), CqlType::Text, true)
.with_field("state".to_string(), CqlType::Text, true)
.with_field("zip_code".to_string(), CqlType::Text, true)
.with_field("country".to_string(), CqlType::Text, true);
self.udt_registry.write().await.register_udt(address_udt);
let person_udt = UdtTypeDef::new("test_keyspace".to_string(), "person".to_string())
.with_field("name".to_string(), CqlType::Text, true)
.with_field("age".to_string(), CqlType::Int, true)
.with_field("email".to_string(), CqlType::Text, true)
.with_field(
"addresses".to_string(),
CqlType::List(Box::new(CqlType::Udt(
"address".to_string(),
vec![
("street".to_string(), CqlType::Text),
("city".to_string(), CqlType::Text),
("state".to_string(), CqlType::Text),
("zip_code".to_string(), CqlType::Text),
("country".to_string(), CqlType::Text),
],
))),
true,
)
.with_field(
"contact_info".to_string(),
CqlType::Map(Box::new(CqlType::Text), Box::new(CqlType::Text)),
true,
);
self.udt_registry.write().await.register_udt(person_udt);
let company_udt = UdtTypeDef::new("test_keyspace".to_string(), "company".to_string())
.with_field("name".to_string(), CqlType::Text, false)
.with_field(
"headquarters".to_string(),
CqlType::Udt(
"address".to_string(),
vec![
("street".to_string(), CqlType::Text),
("city".to_string(), CqlType::Text),
("state".to_string(), CqlType::Text),
("zip_code".to_string(), CqlType::Text),
("country".to_string(), CqlType::Text),
],
),
true,
)
.with_field(
"employees".to_string(),
CqlType::Set(Box::new(CqlType::Udt("person".to_string(), vec![]))),
true,
)
.with_field("founded_year".to_string(), CqlType::Int, true);
self.udt_registry.write().await.register_udt(company_udt);
}
pub async fn register_udt(&self, udt_def: UdtTypeDef) {
self.udt_registry.write().await.register_udt(udt_def);
}
pub async fn get_udt(&self, keyspace: &str, name: &str) -> Option<UdtTypeDef> {
self.udt_registry
.read()
.await
.get_udt(keyspace, name)
.cloned()
}
pub async fn load_schema(&self, table_name: &str) -> Result<TableSchema> {
let schemas = self.schemas.read().await;
if let Some(schema) = schemas.get(table_name) {
return Ok(schema.clone());
}
drop(schemas);
let schema = self.create_default_schema(table_name);
self.schemas
.write()
.await
.insert(table_name.to_string(), schema.clone());
Ok(schema)
}
fn create_default_schema(&self, table_name: &str) -> TableSchema {
TableSchema {
keyspace: "default".to_string(),
table: table_name.to_string(),
partition_keys: vec![KeyColumn {
name: "id".to_string(),
data_type: "uuid".to_string(),
position: 0,
}],
clustering_keys: vec![],
columns: vec![Column {
name: "id".to_string(),
data_type: "uuid".to_string(),
nullable: false,
default: None,
is_static: false,
}],
comments: HashMap::new(),
}
}
pub async fn parse_and_register_cql_schema(&self, cql: &str) -> Result<TableSchema> {
let schema = cql_parser::parse_cql_schema(cql)?;
let table_key = format!("{}.{}", schema.keyspace, schema.table);
self.schemas
.write()
.await
.insert(table_key.clone(), schema.clone());
Ok(schema)
}
pub async fn find_schema_by_table(
&self,
keyspace: &Option<String>,
table: &str,
) -> Option<TableSchema> {
let schemas = self.schemas.read().await;
if let Some(ks) = keyspace {
let key = format!("{}.{}", ks, table);
if let Some(schema) = schemas.get(&key) {
return Some(schema.clone());
}
}
schemas
.values()
.find(|schema| {
cql_parser::table_name_matches(
&Some(schema.keyspace.clone()),
&schema.table,
keyspace,
table,
)
})
.cloned()
}
pub fn extract_table_info(&self, cql: &str) -> Result<(Option<String>, String)> {
cql_parser::extract_table_name(cql)
}
pub fn cql_type_to_internal(&self, cql_type: &str) -> Result<CqlTypeId> {
cql_parser::cql_type_to_type_id(cql_type)
}
pub async fn get_table_schema(&self, table_name: &str) -> Result<TableSchema> {
if let Some(schema) = self.find_schema_by_table(&None, table_name).await {
Ok(schema)
} else {
Err(Error::Schema(format!(
"Table schema not found: {}",
table_name
)))
}
}
}
// Unit tests for schema parsing, validation and the schema manager.
#[cfg(test)]
mod tests {
use super::*;
// A minimal, valid JSON schema deserializes and passes validation.
#[test]
fn test_schema_validation() {
let schema_json = r#"
{
"keyspace": "test",
"table": "users",
"partition_keys": [
{"name": "id", "type": "bigint", "position": 0}
],
"clustering_keys": [],
"columns": [
{"name": "id", "type": "bigint", "nullable": false},
{"name": "name", "type": "text", "nullable": true}
]
}
"#;
let schema = TableSchema::from_json(schema_json).unwrap();
assert_eq!(schema.keyspace, "test");
assert_eq!(schema.table, "users");
assert_eq!(schema.partition_keys.len(), 1);
assert_eq!(schema.columns.len(), 2);
}
// Primitive, collection and nested tuple types parse into the expected variants.
#[test]
fn test_cql_type_parsing() {
assert_eq!(CqlType::parse("text").unwrap(), CqlType::Text);
assert_eq!(CqlType::parse("bigint").unwrap(), CqlType::BigInt);
match CqlType::parse("list<int>").unwrap() {
CqlType::List(inner) => assert_eq!(*inner, CqlType::Int),
_ => panic!("Expected List type"),
}
match CqlType::parse("map<text, bigint>").unwrap() {
CqlType::Map(key, value) => {
assert_eq!(*key, CqlType::Text);
assert_eq!(*value, CqlType::BigInt);
}
_ => panic!("Expected Map type"),
}
// Commas inside nested `<...>` must not split the tuple's top-level fields.
match CqlType::parse("tuple<text, list<int>, map<text, text>>").unwrap() {
CqlType::Tuple(fields) => {
assert_eq!(fields.len(), 3);
assert_eq!(fields[0], CqlType::Text);
assert_eq!(fields[1], CqlType::List(Box::new(CqlType::Int)));
assert_eq!(
fields[2],
CqlType::Map(Box::new(CqlType::Text), Box::new(CqlType::Text))
);
}
_ => panic!("Expected Tuple type"),
}
}
#[test]
fn test_schema_validation_failures() {
// A schema without any partition key is rejected by validation.
let invalid_schema = r#"
{
"keyspace": "test",
"table": "users",
"partition_keys": [],
"clustering_keys": [],
"columns": []
}
"#;
assert!(TableSchema::from_json(invalid_schema).is_err());
// An unrecognized type name is NOT an error: `CqlType::parse` falls back
// to `CqlType::Custom`, so the schema is accepted.
let invalid_type = r#"
{
"keyspace": "test",
"table": "users",
"partition_keys": [
{"name": "id", "type": "invalid_type", "position": 0}
],
"clustering_keys": [],
"columns": [
{"name": "id", "type": "invalid_type", "nullable": false}
]
}
"#;
assert!(TableSchema::from_json(invalid_type).is_ok());
}
// Ten concurrent `load_schema` calls over three distinct table names must
// leave exactly those three names in the cache.
#[tokio::test]
async fn test_concurrent_schema_access() {
let config = Config::default();
let platform = Arc::new(crate::platform::Platform::new(&config).await.unwrap());
let temp_dir = tempfile::tempdir().unwrap();
let storage = Arc::new(
StorageEngine::open(
temp_dir.path(),
&config,
platform,
#[cfg(feature = "state_machine")]
None,
)
.await
.unwrap(),
);
let manager = Arc::new(
SchemaManager::new_with_storage(storage, &config)
.await
.unwrap(),
);
let mut handles = vec![];
for i in 0..10 {
let m = Arc::clone(&manager);
let handle = tokio::spawn(async move {
// `i % 3` makes several tasks race on the same table name.
let table = format!("table_{}", i % 3); m.load_schema(&table).await.unwrap()
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
let schemas = manager.schemas.read().await;
assert!(schemas.len() <= 3); assert!(schemas.contains_key("table_0"));
assert!(schemas.contains_key("table_1"));
assert!(schemas.contains_key("table_2"));
}
// A header with one partition key and one regular column converts into a
// schema with matching keyspace, table, key and column counts.
#[test]
fn test_schema_from_sstable_header() {
use crate::parser::header::{
CassandraVersion, ColumnInfo, CompressionInfo, SSTableHeader, SSTableStats,
};
use std::collections::HashMap;
let columns = vec![
ColumnInfo {
name: "id".to_string(),
column_type: "int".to_string(),
is_primary_key: true,
key_position: Some(0),
is_static: false,
is_clustering: false,
},
ColumnInfo {
name: "name".to_string(),
column_type: "text".to_string(),
is_primary_key: false,
key_position: None,
is_static: false,
is_clustering: false,
},
];
let header = SSTableHeader {
cassandra_version: CassandraVersion::V5_0Bti,
version: 1,
table_id: [0; 16],
keyspace: "test_ks".to_string(),
table_name: "test_table".to_string(),
generation: 1,
compression: CompressionInfo {
algorithm: "NONE".to_string(),
chunk_size: 0,
parameters: HashMap::new(),
},
stats: SSTableStats::default(),
columns,
properties: HashMap::new(),
};
let schema = TableSchema::from_sstable_header(&header).unwrap();
assert_eq!(schema.keyspace, "test_ks");
assert_eq!(schema.table, "test_table");
assert_eq!(schema.partition_keys.len(), 1);
assert_eq!(schema.partition_keys[0].name, "id");
assert_eq!(schema.columns.len(), 2);
}
}