use ahash::AHashSet;
use datafusion::arrow::datatypes::TimeUnit;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::prelude::SessionContext;
use spire_proto::spiredb::{
cluster::cluster_service_client::ClusterServiceClient,
cluster::schema_service_client::SchemaServiceClient,
};
use std::sync::Arc;
use tonic::transport::Channel;
use crate::cache::{SharedLruCache, new_shared_cache};
use crate::distributed::{DistributedConfig, DistributedExecutor};
use crate::pool::{ConnectionPool, PoolConfig};
use crate::provider::SpireProvider;
use crate::routing::RegionRouter;
use crate::statistics::StatisticsProvider;
use crate::topology::ClusterTopology;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use spire_proto::spiredb::cluster::{ColumnType, Empty};
use std::fmt;
use crate::config::Config;
/// Fallback LRU query-cache capacity used when `Config::query_cache_capacity` is 0.
pub const DEFAULT_QUERY_CACHE_CAPACITY: usize = 1024;
/// Name of the default DataFusion catalog that tables are registered under.
pub const DEFAULT_CATALOG: &str = "spire";
/// Shared per-process state for executing queries against a SpireDB cluster.
///
/// Bundles the gRPC schema client, the DataFusion session, region routing and
/// connection infrastructure, and an optional LRU cache of query results.
#[allow(dead_code)]
pub struct SpireContext {
    // gRPC client for table metadata (list_tables etc.).
    pub schema_service: SchemaServiceClient<Channel>,
    // DataFusion session that table providers are registered into.
    pub session_context: SessionContext,
    // Maps tables/keys to cluster regions.
    pub region_router: Arc<RegionRouter>,
    // Pooled gRPC connections to data nodes.
    pub connection_pool: Arc<ConnectionPool>,
    // Fans queries out across regions.
    pub distributed_executor: Arc<DistributedExecutor>,
    // Table statistics source used for planning.
    pub stats_provider: Arc<StatisticsProvider>,
    // Cluster membership/leader view, refreshed in the background.
    pub topology: Arc<ClusterTopology>,
    // LRU cache keyed by query-string hash (see `hash_query`).
    pub query_cache: SharedLruCache<Arc<Vec<RecordBatch>>>,
    // When false, the cache methods are no-ops.
    pub cache_enabled: bool,
}
impl SpireContext {
    /// Builds a fully wired `SpireContext` from the two cluster gRPC clients
    /// and the process configuration.
    ///
    /// Side effect: immediately spawns the topology background refresh task.
    pub fn new(
        schema_service: SchemaServiceClient<Channel>,
        cluster_service: ClusterServiceClient<Channel>,
        config: &Config,
    ) -> Self {
        // The topology view is shared with the router and kept fresh by a
        // background task.
        let topology = Arc::new(ClusterTopology::new(cluster_service.clone()));
        topology.clone().start_refresh_task();

        let region_router = Arc::new(RegionRouter::new(cluster_service, topology.clone()));
        let connection_pool = Arc::new(ConnectionPool::new(PoolConfig::default()));
        let distributed_executor = Arc::new(DistributedExecutor::new(
            region_router.clone(),
            connection_pool.clone(),
            DistributedConfig::default(),
        ));
        let stats_provider = Arc::new(StatisticsProvider::new(schema_service.clone()));

        // A configured capacity of zero means "use the default".
        let capacity = match config.query_cache_capacity {
            0 => DEFAULT_QUERY_CACHE_CAPACITY,
            n => n,
        };
        let query_cache = new_shared_cache(capacity);

        let session_context = SessionContext::new_with_config(
            datafusion::prelude::SessionConfig::new()
                .with_default_catalog_and_schema(DEFAULT_CATALOG, "public")
                .with_information_schema(true),
        );

        Self {
            schema_service,
            session_context,
            region_router,
            connection_pool,
            distributed_executor,
            stats_provider,
            topology,
            query_cache,
            cache_enabled: config.enable_cache,
        }
    }
    /// Synchronizes the local DataFusion catalog with the cluster's table list.
    ///
    /// Steps:
    /// 1. Pick a schema client — preferring a fresh connection to the PD
    ///    leader, falling back to `self.schema_service` on any failure.
    /// 2. Register a `SpireProvider` for every remote table and pre-warm the
    ///    region and statistics caches for it (best-effort; warns on failure).
    /// 3. Deregister any locally registered table that no longer exists
    ///    remotely.
    ///
    /// # Errors
    /// Propagates failures from the `list_tables` RPC and from DataFusion
    /// table (de)registration.
    pub async fn register_tables(&self) -> Result<(), Box<dyn std::error::Error>> {
        let mut client = if let Some(leader) = self.topology.get_leader_address() {
            let leader_uri = leader.address.parse::<tonic::transport::Uri>().ok();
            let pd_addr = if let Some(uri) = leader_uri {
                let host = uri.host().unwrap_or("spiredb");
                // NOTE(review): PD port 50051 is hard-coded here — confirm it
                // matches the cluster deployment.
                format!("http://{}:50051", host)
            } else {
                // Address didn't parse as a URI; assume a "host:50052" form
                // and swap the data port for the PD port.
                leader.address.replace(":50052", ":50051")
            };
            log::debug!("Connecting to PD leader for registration: {}", pd_addr);
            // Any connection problem degrades gracefully to the default client.
            match Channel::from_shared(pd_addr) {
                Ok(endpoint) => match endpoint.connect().await {
                    Ok(channel) => SchemaServiceClient::new(channel),
                    Err(_) => self.schema_service.clone(),
                },
                Err(_) => self.schema_service.clone(),
            }
        } else {
            self.schema_service.clone()
        };
        let response = client.list_tables(Empty {}).await?;
        let table_list = response.into_inner();
        // Remember remote names so stale local registrations can be pruned below.
        let remote_tables: AHashSet<String> =
            table_list.tables.iter().map(|t| t.name.clone()).collect();
        for table in table_list.tables {
            let table_name = table.name.clone();
            let mut fields = Vec::new();
            for col in table.columns {
                // Unknown/invalid column type codes degrade to raw bytes.
                let dt = map_column_type(
                    ColumnType::try_from(col.r#type).unwrap_or(ColumnType::TypeBytes),
                );
                fields.push(Field::new(col.name, dt, col.nullable));
            }
            let schema = Arc::new(Schema::new(fields));
            // Only the first primary-key column is used; "id" is the fallback
            // when the table declares no primary key.
            let pk_column = table
                .primary_key
                .first()
                .cloned()
                .unwrap_or_else(|| "id".to_string());
            let provider = SpireProvider::with_distributed(
                table_name.clone(),
                schema,
                self.distributed_executor.clone(),
                pk_column,
                self.stats_provider.clone(),
                self.connection_pool.clone(),
                self.region_router.clone(),
                self.topology.clone(),
            );
            self.session_context
                .register_table(table_name.as_str(), Arc::new(provider))?;
            // Best-effort cache warm-up; failures are logged, not fatal.
            if let Err(e) = self.region_router.get_table_regions(&table_name).await {
                log::warn!("Failed to pre-warm region cache for {}: {}", table_name, e);
            }
            if let Err(e) = self.stats_provider.get_table_stats(&table_name).await {
                log::warn!("Failed to pre-warm stats for {}: {}", table_name, e);
            }
        }
        // Drop tables that were removed on the cluster side.
        if let Some(catalog) = self.session_context.catalog(DEFAULT_CATALOG)
            && let Some(schema) = catalog.schema("public")
        {
            let local_tables = schema.table_names();
            for local_table in local_tables {
                if !remote_tables.contains(&local_table) {
                    log::info!("Deregistering stale table: {}", local_table);
                    self.session_context.deregister_table(&local_table)?;
                }
            }
        }
        Ok(())
    }
pub fn start_table_refresh_task(self: Arc<Self>) {
tokio::spawn(async move {
let refresh_interval = std::time::Duration::from_secs(2);
loop {
tokio::time::sleep(refresh_interval).await;
if let Err(e) = self.register_tables().await {
log::debug!("Table refresh failed: {}", e);
}
}
});
}
fn hash_query(query: &str) -> u64 {
use ahash::AHasher;
use std::hash::{Hash, Hasher};
let mut hasher = AHasher::default();
query.hash(&mut hasher);
hasher.finish()
}
pub fn get_cached_query(&self, query: &str) -> Option<Arc<Vec<RecordBatch>>> {
if !self.cache_enabled {
return None;
}
let hash = Self::hash_query(query);
self.query_cache.get_and_touch(hash)
}
pub fn cache_query_result(&self, query: &str, batches: Vec<RecordBatch>) {
if !self.cache_enabled {
return;
}
let hash = Self::hash_query(query);
self.query_cache.insert(hash, Arc::new(batches));
}
    /// Drops every cached query result (e.g. after a write makes them stale).
    pub fn invalidate_query_cache(&self) {
        self.query_cache.clear();
        log::debug!("Query cache invalidated");
    }
    /// Borrows the distributed executor.
    #[allow(dead_code)]
    pub fn executor(&self) -> &DistributedExecutor {
        &self.distributed_executor
    }
    /// Borrows the region router.
    #[allow(dead_code)]
    pub fn router(&self) -> &RegionRouter {
        &self.region_router
    }
    /// Borrows the statistics provider.
    #[allow(dead_code)]
    pub fn stats(&self) -> &StatisticsProvider {
        &self.stats_provider
    }
}
/// Maps a SpireDB wire `ColumnType` to the Arrow `DataType` used by DataFusion.
fn map_column_type(ct: ColumnType) -> DataType {
    match ct {
        ColumnType::TypeInt8 => DataType::Int8,
        ColumnType::TypeInt16 => DataType::Int16,
        ColumnType::TypeInt32 => DataType::Int32,
        ColumnType::TypeInt64 => DataType::Int64,
        ColumnType::TypeUint8 => DataType::UInt8,
        ColumnType::TypeUint16 => DataType::UInt16,
        ColumnType::TypeUint32 => DataType::UInt32,
        ColumnType::TypeUint64 => DataType::UInt64,
        ColumnType::TypeFloat32 => DataType::Float32,
        ColumnType::TypeFloat64 => DataType::Float64,
        ColumnType::TypeBool => DataType::Boolean,
        ColumnType::TypeString => DataType::Utf8,
        ColumnType::TypeBytes => DataType::Binary,
        ColumnType::TypeDate => DataType::Date32,
        // Timestamps carried at microsecond precision with no timezone.
        ColumnType::TypeTimestamp => DataType::Timestamp(TimeUnit::Microsecond, None),
        // NOTE(review): fixed (38, 10) precision/scale — confirm the server
        // never sends decimals with other shapes.
        ColumnType::TypeDecimal => DataType::Decimal128(38, 10),
        // Lists are surfaced as strings — presumably serialized; verify.
        ColumnType::TypeList => DataType::Utf8,
        // Vectors travel as opaque binary payloads.
        ColumnType::TypeVector => DataType::Binary,
    }
}
impl fmt::Debug for SpireContext {
    /// Manual `Debug`: most fields don't implement `Debug`, so they are
    /// rendered as type-name placeholder strings.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SpireContext")
            .field("schema_service", &self.schema_service)
            .field("session_context", &"SessionContext")
            .field("region_router", &"RegionRouter")
            .field("connection_pool", &"ConnectionPool")
            .field("distributed_executor", &"DistributedExecutor")
            .field("stats_provider", &"StatisticsProvider")
            // Fix: `topology` exists on the struct but was missing from the
            // Debug output, unlike every other field.
            .field("topology", &"ClusterTopology")
            .field("query_cache", &"LruCache")
            .field("cache_enabled", &self.cache_enabled)
            .finish()
    }
}