use std::pin::Pin;
use arrow::ipc::writer::IpcWriteOptions;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::sql::server::FlightSqlService;
use arrow_flight::sql::server::PeekableFlightDataStream;
use arrow_flight::sql::{
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, CommandPreparedStatementQuery, CommandStatementQuery,
CommandStatementUpdate, ProstMessageExt, TicketStatementQuery,
};
use arrow_flight::{
Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
Ticket,
};
use futures::stream;
use futures::TryStreamExt;
use prost::Message;
use rhei_core::OlapEngine;
use rhei_olap::OlapBackend;
use tonic::{Request, Response, Status, Streaming};
use tracing::{debug, warn};
/// Arrow Flight SQL front end that executes client SQL against the OLAP backend.
///
/// Construct it with [`RheiFlightSqlService::new`] or
/// [`RheiFlightSqlService::with_compression`], then optionally require a bearer
/// token via [`RheiFlightSqlService::with_auth_token`].
pub struct RheiFlightSqlService {
    /// Backend that runs the SQL carried by tickets and prepared-statement handles.
    olap: OlapBackend,
    /// Bearer token clients must present; `None` disables authentication.
    auth_token: Option<String>,
    /// Codec applied when encoding result batches as Arrow IPC.
    compression: CompressionType,
}
/// Compression codec applied to Arrow IPC record batches streamed to clients.
///
/// `Zstd` is the default, matching the codec chosen by `RheiFlightSqlService::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CompressionType {
    /// Send IPC batches uncompressed.
    None,
    /// Zstandard compression (Arrow IPC `ZSTD`).
    #[default]
    Zstd,
    /// LZ4 frame compression (Arrow IPC `LZ4_FRAME`).
    Lz4,
}
impl RheiFlightSqlService {
pub fn new(olap: OlapBackend) -> Self {
Self {
olap,
compression: CompressionType::Zstd,
auth_token: None,
}
}
pub fn with_compression(olap: OlapBackend, compression: CompressionType) -> Self {
Self {
olap,
compression,
auth_token: None,
}
}
pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
let token = token.into();
if token.is_empty() {
tracing::warn!(
"FlightSQL auth token is empty; ignoring configuration and running without auth"
);
self.auth_token = None;
} else {
self.auth_token = Some(token);
}
self
}
fn check_auth<T>(&self, request: &Request<T>) -> Result<(), Status> {
let Some(expected) = &self.auth_token else {
return Ok(());
};
if expected.is_empty() {
return Err(Status::unauthenticated(
"server auth misconfigured (empty token)",
));
}
let got = request
.metadata()
.get("authorization")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let token = got.strip_prefix("Bearer ").unwrap_or("");
if token.is_empty() || token != expected {
return Err(Status::unauthenticated("invalid or missing bearer token"));
}
Ok(())
}
fn ipc_options(&self) -> IpcWriteOptions {
let options = IpcWriteOptions::default();
match self.compression {
CompressionType::None => options,
CompressionType::Zstd => options
.try_with_compression(Some(arrow::ipc::CompressionType::ZSTD))
.unwrap_or_default(),
CompressionType::Lz4 => options
.try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
.unwrap_or_default(),
}
}
async fn execute_streaming(
&self,
sql: &str,
) -> Result<
Pin<Box<dyn futures::Stream<Item = Result<arrow_flight::FlightData, Status>> + Send>>,
Status,
> {
let batch_stream = self.olap.query_stream(sql).await.map_err(|e| {
warn!(error = %e, sql, "OLAP query failed");
Status::internal(format!("query error: {e}"))
})?;
let mapped = batch_stream.map_err(arrow_flight::error::FlightError::ExternalError);
let flight_stream = FlightDataEncoderBuilder::new()
.with_options(self.ipc_options())
.build(mapped)
.map_err(|e| Status::internal(e.to_string()));
Ok(Box::pin(flight_stream))
}
}
/// Decodes an opaque statement handle back into the SQL text it carries.
///
/// Handles minted by this service are the raw UTF-8 bytes of the query. A handle that
/// is not valid UTF-8 therefore came malformed from the client and is reported as
/// `InvalidArgument`, not as an internal server error.
fn sql_from_handle(handle: &[u8], error_message: &'static str) -> Result<String, Status> {
    std::str::from_utf8(handle)
        .map(str::to_owned)
        .map_err(|_| Status::invalid_argument(error_message))
}

#[tonic::async_trait]
impl FlightSqlService for RheiFlightSqlService {
    type FlightService = RheiFlightSqlService;

    /// Authenticates the handshake and replies with a single empty `HandshakeResponse`.
    ///
    /// When a token is configured, it is echoed back in the response's `authorization`
    /// metadata as `Bearer <token>`.
    async fn do_handshake(
        &self,
        request: Request<Streaming<HandshakeRequest>>,
    ) -> Result<
        Response<Pin<Box<dyn futures::Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
        Status,
    > {
        self.check_auth(&request)?;
        let mut response = Response::new(Box::pin(stream::once(async {
            Ok(HandshakeResponse {
                protocol_version: 0,
                payload: bytes::Bytes::new(),
            })
        }))
            as Pin<Box<dyn futures::Stream<Item = Result<HandshakeResponse, Status>> + Send>>);
        if let Some(token) = &self.auth_token {
            if let Ok(val) = format!("Bearer {token}").parse() {
                response.metadata_mut().insert("authorization", val);
            }
        }
        Ok(response)
    }

    /// Returns a `FlightInfo` with one endpoint whose ticket wraps the SQL text.
    ///
    /// The query is not executed here; it runs later in `do_get_statement`.
    async fn get_flight_info_statement(
        &self,
        query: CommandStatementQuery,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        self.check_auth(&request)?;
        let sql = &query.query;
        debug!(sql, "get_flight_info_statement");
        let ticket = TicketStatementQuery {
            statement_handle: sql.as_bytes().to_vec().into(),
        };
        let any = ticket.as_any();
        let ticket_bytes = any.encode_to_vec();
        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let info = FlightInfo::new().with_endpoint(endpoint);
        Ok(Response::new(info))
    }

    /// Executes the SQL carried in the ticket's statement handle and streams the result.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` if the handle is not valid UTF-8, and propagates query
    /// failures from `execute_streaming`.
    async fn do_get_statement(
        &self,
        ticket: TicketStatementQuery,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        self.check_auth(&request)?;
        let sql = sql_from_handle(&ticket.statement_handle, "invalid statement handle")?;
        debug!(sql, "do_get_statement (streaming)");
        let stream = self.execute_streaming(&sql).await?;
        Ok(Response::new(stream))
    }

    /// Returns a `FlightInfo` whose single ticket re-encodes the prepared-statement command.
    async fn get_flight_info_prepared_statement(
        &self,
        cmd: CommandPreparedStatementQuery,
        request: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        self.check_auth(&request)?;
        debug!("get_flight_info_prepared_statement");
        let any = cmd.as_any();
        let ticket_bytes = any.encode_to_vec();
        let endpoint = FlightEndpoint::new().with_ticket(Ticket::new(ticket_bytes));
        let info = FlightInfo::new().with_endpoint(endpoint);
        Ok(Response::new(info))
    }

    /// "Prepares" a statement by using the SQL bytes themselves as the handle.
    ///
    /// No server-side state is kept. Dataset and parameter schemas are returned empty.
    async fn do_action_create_prepared_statement(
        &self,
        query: ActionCreatePreparedStatementRequest,
        request: Request<Action>,
    ) -> Result<ActionCreatePreparedStatementResult, Status> {
        self.check_auth(&request)?;
        debug!(sql = query.query, "create_prepared_statement");
        Ok(ActionCreatePreparedStatementResult {
            prepared_statement_handle: query.query.into_bytes().into(),
            dataset_schema: bytes::Bytes::new(),
            parameter_schema: bytes::Bytes::new(),
        })
    }

    /// Closing is a no-op because prepared statements hold no server-side state.
    async fn do_action_close_prepared_statement(
        &self,
        _query: ActionClosePreparedStatementRequest,
        request: Request<Action>,
    ) -> Result<(), Status> {
        self.check_auth(&request)?;
        Ok(())
    }

    /// Executes the SQL carried in the prepared-statement handle and streams the result.
    ///
    /// # Errors
    ///
    /// Returns `InvalidArgument` if the handle is not valid UTF-8, and propagates query
    /// failures from `execute_streaming`.
    async fn do_get_prepared_statement(
        &self,
        query: CommandPreparedStatementQuery,
        request: Request<Ticket>,
    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
        self.check_auth(&request)?;
        let sql = sql_from_handle(
            &query.prepared_statement_handle,
            "invalid prepared statement handle",
        )?;
        debug!(sql, "do_get_prepared_statement (streaming)");
        let stream = self.execute_streaming(&sql).await?;
        Ok(Response::new(stream))
    }

    /// Rejects updates: after authentication this always returns `Unimplemented`.
    async fn do_put_statement_update(
        &self,
        _ticket: CommandStatementUpdate,
        request: Request<PeekableFlightDataStream>,
    ) -> Result<i64, Status> {
        self.check_auth(&request)?;
        Err(Status::unimplemented(
            "write operations not supported — OLAP is read-only",
        ))
    }

    /// SQL info registration is ignored; this service does not record any `SqlInfo`.
    async fn register_sql_info(&self, _id: i32, _result: &arrow_flight::sql::SqlInfo) {}
}