use async_trait::async_trait;
use core_affinity::CoreId;
use datafusion::arrow::util::display::array_value_to_string;
use futures::Sink;
use futures::stream as futures_stream;
use mimalloc::MiMalloc;
use pgwire::api::auth::{self, StartupHandler};
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
QueryResponse, Response, Tag,
};
use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type as PgType};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use pgwire::tokio::process_socket;
use socket2::{Domain, Protocol, Socket, Type as SockType};
use spire_proto::spiredb::{
cluster::cluster_service_client::ClusterServiceClient,
cluster::schema_service_client::SchemaServiceClient,
};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tonic::transport::Channel;
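/// SQL engine internals: session context, configuration, DDL/DML handling,
/// distributed planning and execution, routing, pruning, caching, statistics,
/// and cluster topology.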
mod sql {
pub mod cache;
pub mod config;
pub mod context;
pub mod ddl;
pub mod distributed;
pub mod distributed_exec;
pub mod dml;
pub mod exec;
pub mod filter;
pub mod pool;
pub mod provider;
pub mod pruning;
pub mod routing;
pub mod statistics;
pub mod topology;
}
mod stream;
use sql::config::{Config, load_config, print_banner};
use sql::context::SpireContext;
use sql::ddl;
use sql::dml;
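// Use mimalloc as the process-wide allocator; a performance-oriented
// allocator helps under the allocation-heavy query execution path.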
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
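/// Entry point: spawns one OS thread per worker, pins each thread to a CPU
/// core, and runs a dedicated single-threaded tokio runtime on it
/// (thread-per-core).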
fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = load_config();
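    // Mutating the environment is only sound before other threads are
    // running, which is why `set_var` requires an unsafe block here.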
unsafe {
std::env::set_var("RUST_LOG", &config.log_level);
}
spire_common::init_logging();
print_banner();
let num_workers = if config.num_workers == 0 {
num_cpus::get()
} else {
config.num_workers
};
log::info!(
"Starting SpireSQL with {} worker threads (thread-per-core mode)",
num_workers
);
let config = Arc::new(config);
let mut handles = Vec::with_capacity(num_workers);
for worker_id in 0..num_workers {
let config = config.clone();
let handle = std::thread::Builder::new()
.name(format!("spiresql-worker-{}", worker_id))
.spawn(move || {
let pinned = core_affinity::set_for_current(CoreId { id: worker_id });
if !pinned {
log::warn!("Worker {} failed to pin to core {}", worker_id, worker_id);
} else {
log::debug!("Worker {} pinned to core {}", worker_id, worker_id);
}
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create tokio runtime");
rt.block_on(run_worker(worker_id, config));
})
.expect("Failed to spawn worker thread");
handles.push(handle);
}
for handle in handles {
if let Err(e) = handle.join() {
log::error!("Worker thread panicked: {:?}", e);
}
}
Ok(())
}
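/// Per-worker event loop: binds its own SO_REUSEPORT listener, configures
/// gRPC clients for the schema and cluster services, registers tables, and
/// then accepts and serves PostgreSQL wire-protocol connections.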
async fn run_worker(worker_id: usize, config: Arc<Config>) {
let addr: SocketAddr = config.listen_addr.parse().unwrap_or_else(|_| {
log::error!("Invalid listen address: {}", config.listen_addr);
"0.0.0.0:5432".parse().unwrap()
});
let listener = match create_reuseport_listener(&addr) {
Ok(l) => l,
Err(e) => {
log::error!("Worker {} failed to bind to {}: {}", worker_id, addr, e);
return;
}
};
if worker_id == 0 {
log::info!("SpireSQL listening on {} (SO_REUSEPORT)", addr);
log::info!(
"Query cache: {} (capacity: {})",
if config.enable_cache {
"enabled"
} else {
"disabled"
},
config.query_cache_capacity
);
}
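    // HTTP/2 tuning for the gRPC channel to the cluster: keep-alives detect
    // dead peers quickly, and larger flow-control windows favor bulk result
    // transfers.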
let connect_timeout = std::time::Duration::from_secs(5);
let request_timeout = std::time::Duration::from_secs(30);
let keepalive_interval = std::time::Duration::from_secs(10);
let keepalive_timeout = std::time::Duration::from_secs(20);
let stream_window_size: u32 = 16 * 1024 * 1024;
let connection_window_size: u32 = 32 * 1024 * 1024;
let cluster_channel = match Channel::from_shared(config.cluster_addr.clone()) {
Ok(c) => c
.connect_timeout(connect_timeout)
.timeout(request_timeout)
.http2_keep_alive_interval(keepalive_interval)
.keep_alive_timeout(keepalive_timeout)
.keep_alive_while_idle(true)
.initial_stream_window_size(stream_window_size)
.initial_connection_window_size(connection_window_size)
.connect_lazy(),
Err(e) => {
log::error!("Worker {} invalid cluster addr: {}", worker_id, e);
return;
}
};
let schema_client = SchemaServiceClient::new(cluster_channel.clone());
let cluster_client = ClusterServiceClient::new(cluster_channel);
if worker_id == 0 {
log::info!(
"gRPC channels configured (cluster: {})",
config.cluster_addr
);
}
let ctx = Arc::new(SpireContext::new(schema_client, cluster_client, &config));
if let Err(e) = ctx.register_tables().await {
log::error!(
"Worker {} failed to register tables at startup: {}",
worker_id,
e
);
}
ctx.clone().start_table_refresh_task();
let processor = Arc::new(SpireSqlProcessor {
ctx,
query_parser: Arc::new(NoopQueryParser::new()),
});
let factory = Arc::new(SpireSqlProcessorFactory { handler: processor });
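    // Accept loop: each accepted connection is served as a task on this
    // worker's single-threaded runtime.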
loop {
match listener.accept().await {
Ok((stream, peer)) => {
log::debug!("Worker {} accepted connection from {}", worker_id, peer);
let factory = factory.clone();
tokio::spawn(async move {
if let Err(e) = process_socket(stream, None, factory).await {
log::error!("Client error: {}", e);
}
});
}
Err(e) => {
log::error!("Worker {} accept error: {}", worker_id, e);
}
}
}
}
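/// Creates a non-blocking TCP listener with SO_REUSEADDR and SO_REUSEPORT set,
/// so every worker can bind the same address and the kernel distributes
/// incoming connections across workers.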
fn create_reuseport_listener(addr: &SocketAddr) -> std::io::Result<TcpListener> {
let socket = Socket::new(
Domain::for_address(*addr),
SockType::STREAM,
Some(Protocol::TCP),
)?;
socket.set_reuse_address(true)?;
socket.set_reuse_port(true)?;
socket.set_nonblocking(true)?;
socket.bind(&(*addr).into())?;
socket.listen(1024)?;
let std_listener: std::net::TcpListener = socket.into();
TcpListener::from_std(std_listener)
}
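/// Serves both the simple and the extended PostgreSQL query protocols on top
/// of a shared `SpireContext` (DataFusion session, routing, caching).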
pub struct SpireSqlProcessor {
ctx: Arc<SpireContext>,
query_parser: Arc<NoopQueryParser>,
}
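// Simple query protocol: parse the SQL, intercept SET/SHOW, DDL, and DML,
// and fall back to DataFusion for everything else.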
#[async_trait]
impl SimpleQueryHandler for SpireSqlProcessor {
async fn do_query<C>(&self, _client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
where
C: ClientInfo + Unpin + Send + Sync,
{
let ctx = &self.ctx;
use sqlparser::ast::Statement;
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
let dialect = PostgreSqlDialect {};
let statements = Parser::parse_sql(&dialect, query).map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
format!("SQL parse error: {}", e),
)))
})?;
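        // Statements handled outside DataFusion: SET/SHOW are acknowledged as
        // no-ops, DDL is forwarded to the schema service, and DML is routed to
        // the owning region. Successful DDL/DML invalidates the query cache.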
for stmt in &statements {
if let Statement::SetVariable { .. } = stmt {
return Ok(vec![Response::Execution(Tag::new("SET"))]);
}
if let Statement::ShowVariable { .. } = stmt {
return Ok(vec![Response::Execution(Tag::new("SHOW"))]);
}
let mut ddl_handler =
ddl::DdlHandler::new(ctx.schema_service.clone(), Some(ctx.topology.clone()));
if let Some(response) = ddl_handler.try_execute(stmt).await? {
ctx.invalidate_query_cache();
if let Err(e) = ctx.register_tables().await {
log::warn!("Failed to refresh tables after DDL: {}", e);
}
return Ok(response);
}
let mut dml_handler = dml::DmlHandler::new(
ctx.region_router.clone(),
ctx.connection_pool.clone(),
ctx.topology.clone(),
ctx.schema_service.clone(),
);
if let Some(response) = dml_handler.try_execute(stmt).await? {
ctx.invalidate_query_cache();
return Ok(response);
}
}
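        // Read path: serve from the query result cache when possible,
        // otherwise execute through DataFusion and cache the batches.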
if let Some(cached_batches) = ctx.get_cached_query(query) {
log::debug!("Query cache hit for: {}", query);
return batches_to_pgwire_response(&cached_batches);
}
let session_ctx = &ctx.session_context;
match session_ctx.sql(query).await {
Ok(df) => {
let batches = df.collect().await.map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
"XX000".to_string(),
format!("Execution failed: {}", e),
)))
})?;
ctx.cache_query_result(query, batches.clone());
batches_to_pgwire_response(&batches)
}
Err(e) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"42000".to_string(),
format!("SQL Error: {}", e),
)))),
}
}
}
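// Extended query protocol (Parse/Bind/Execute). The NoopQueryParser passes
// statements through as raw SQL strings; execution mirrors the simple-query
// path but yields a single Response per Execute.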
#[async_trait]
impl ExtendedQueryHandler for SpireSqlProcessor {
type Statement = String;
type QueryParser = NoopQueryParser;
fn query_parser(&self) -> Arc<Self::QueryParser> {
self.query_parser.clone()
}
async fn do_query<C>(
&self,
_client: &mut C,
portal: &Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response>
where
C: ClientInfo + Unpin + Send + Sync,
{
let query = &portal.statement.statement;
let ctx = &self.ctx;
use sqlparser::ast::Statement;
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
let dialect = PostgreSqlDialect {};
let statements = Parser::parse_sql(&dialect, query).map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
format!("SQL parse error: {}", e),
)))
})?;
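        // Same interception order as the simple-query path: SET/SHOW, then
        // DDL, then DML.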
for stmt in &statements {
if let Statement::SetVariable { .. } = stmt {
return Ok(Response::Execution(Tag::new("SET")));
}
if let Statement::ShowVariable { .. } = stmt {
return Ok(Response::Execution(Tag::new("SHOW")));
}
let mut ddl_handler =
ddl::DdlHandler::new(ctx.schema_service.clone(), Some(ctx.topology.clone()));
if let Some(response) = ddl_handler.try_execute(stmt).await? {
ctx.invalidate_query_cache();
if let Err(e) = ctx.register_tables().await {
log::warn!("Failed to refresh tables after DDL: {}", e);
}
return Ok(response
.into_iter()
.next()
.unwrap_or(Response::Execution(Tag::new("OK"))));
}
let mut dml_handler = dml::DmlHandler::new(
ctx.region_router.clone(),
ctx.connection_pool.clone(),
ctx.topology.clone(),
ctx.schema_service.clone(),
);
if let Some(response) = dml_handler.try_execute(stmt).await? {
ctx.invalidate_query_cache();
return Ok(response
.into_iter()
.next()
.unwrap_or(Response::Execution(Tag::new("OK"))));
}
}
let session_ctx = &ctx.session_context;
match session_ctx.sql(query).await {
Ok(df) => {
let batches = df.collect().await.map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
"XX000".to_string(),
format!("Execution failed: {}", e),
)))
})?;
let responses = batches_to_pgwire_response(&batches)?;
Ok(responses
.into_iter()
.next()
.unwrap_or(Response::Execution(Tag::new("SELECT 0"))))
}
Err(e) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"42000".to_string(),
format!("SQL Error: {}", e),
)))),
}
}
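    // Minimal Describe support: echo back known parameter types and defer the
    // row description until execution.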
async fn do_describe_statement<C>(
&self,
_client: &mut C,
stmt: &StoredStatement<Self::Statement>,
) -> PgWireResult<DescribeStatementResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
let param_types = stmt
.parameter_types
.iter()
.map(|t| t.clone().unwrap_or(PgType::UNKNOWN))
.collect();
Ok(DescribeStatementResponse::new(param_types, vec![]))
}
async fn do_describe_portal<C>(
&self,
_client: &mut C,
_portal: &Portal<Self::Statement>,
) -> PgWireResult<DescribePortalResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
Ok(DescribePortalResponse::new(vec![]))
}
}
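/// Converts Arrow record batches into a pgwire response, encoding every value
/// as text. All batches are assumed to share the first batch's schema; with no
/// batches at all, a bare `OK` command completion is returned.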
fn batches_to_pgwire_response(
batches: &[datafusion::arrow::record_batch::RecordBatch],
) -> PgWireResult<Vec<Response>> {
    // A single DataFusion query yields batches with one shared schema, so the
    // row description is derived once from the first batch.
    let Some(first) = batches.first() else {
        return Ok(vec![Response::Execution(Tag::new("OK"))]);
    };
    let fields = first
        .schema()
        .fields()
        .iter()
        .map(|f| {
            FieldInfo::new(
                f.name().clone(),
                None,
                None,
                map_arrow_type_to_pg_type(f.data_type()),
                FieldFormat::Text,
            )
        })
        .collect::<Vec<_>>();
    let headers = Arc::new(fields);
    let mut rows_data = Vec::new();
    for batch in batches {
        for i in 0..batch.num_rows() {
            let mut encoder = DataRowEncoder::new(headers.clone());
            for col in 0..batch.num_columns() {
                let array = batch.column(col);
                if array.is_null(i) {
                    encoder.encode_field(&None::<String>).map_err(|e| {
                        PgWireError::UserError(Box::new(ErrorInfo::new(
                            "FATAL".to_string(),
                            "XX000".to_string(),
                            e.to_string(),
                        )))
                    })?;
                } else {
                    let val_str = array_value_to_string(array, i).unwrap_or_default();
                    encoder.encode_field(&val_str).map_err(|e| {
                        PgWireError::UserError(Box::new(ErrorInfo::new(
                            "FATAL".to_string(),
                            "XX000".to_string(),
                            e.to_string(),
                        )))
                    })?;
                }
            }
            rows_data.push(encoder.take_row());
        }
    }
    let row_stream = futures_stream::iter(rows_data.into_iter().map(Ok));
    Ok(vec![Response::Query(QueryResponse::new(
        headers, row_stream,
    ))])
}
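/// Handler factory passed to `process_socket`; every connection shares the
/// same processor (and therefore the same `SpireContext`).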
struct SpireSqlProcessorFactory {
handler: Arc<SpireSqlProcessor>,
}
impl PgWireServerHandlers for SpireSqlProcessorFactory {
fn simple_query_handler(&self) -> Arc<impl SimpleQueryHandler> {
self.handler.clone()
}
fn extended_query_handler(&self) -> Arc<impl ExtendedQueryHandler> {
self.handler.clone()
}
fn startup_handler(&self) -> Arc<impl StartupHandler> {
Arc::new(SpireStartupHandler)
}
}
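/// Startup handler that skips authentication entirely: it stores the client's
/// startup parameters and immediately completes the handshake.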
pub struct SpireStartupHandler;
#[async_trait]
impl StartupHandler for SpireStartupHandler {
async fn on_startup<C>(
&self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: std::fmt::Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
if let PgWireFrontendMessage::Startup(ref startup) = message {
auth::save_startup_parameters_to_metadata(client, startup);
let params = auth::DefaultServerParameterProvider::default();
        auth::finish_authentication(client, &params).await?;
}
Ok(())
}
}
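/// Maps Arrow data types to approximate PostgreSQL wire types for the row
/// description; unsigned integers map to the signed type of the same width
/// (values are sent as text anyway) and anything unrecognized becomes UNKNOWN.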
fn map_arrow_type_to_pg_type(dt: &datafusion::arrow::datatypes::DataType) -> pgwire::api::Type {
use datafusion::arrow::datatypes::DataType;
use pgwire::api::Type;
match dt {
DataType::Boolean => Type::BOOL,
DataType::Int8 => Type::CHAR,
DataType::Int16 => Type::INT2,
DataType::Int32 => Type::INT4,
DataType::Int64 => Type::INT8,
DataType::UInt8 => Type::CHAR,
DataType::UInt16 => Type::INT2,
DataType::UInt32 => Type::INT4,
DataType::UInt64 => Type::INT8,
DataType::Float16 => Type::FLOAT4,
DataType::Float32 => Type::FLOAT4,
DataType::Float64 => Type::FLOAT8,
DataType::Utf8 | DataType::LargeUtf8 => Type::VARCHAR,
DataType::Binary | DataType::LargeBinary => Type::BYTEA,
DataType::Date32 => Type::DATE,
DataType::Date64 => Type::DATE,
DataType::Timestamp(_, _) => Type::TIMESTAMP,
_ => Type::UNKNOWN,
}
}