use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
};
use datafusion::execution::memory_pool::GreedyMemoryPool;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};
use futures::stream::{self, BoxStream, StreamExt};
use kyma_core::catalog::Catalog;
use kyma_core::segment_format::SegmentFormat;
use kyma_exec::KymaTable;
use std::sync::Arc;
use tonic::{Request, Response, Status, Streaming};
use tracing::debug;
/// Dependencies shared by every Flight request handler.
///
/// The catalog and segment format are held behind `Arc`, so clones share them.
#[derive(Clone)]
pub struct FlightState {
    /// Catalog used to resolve databases and their tables.
    pub catalog: Arc<dyn Catalog>,
    /// Segment format used to open extents and to build `KymaTable`s.
    pub format: Arc<dyn SegmentFormat>,
    /// When present, query tables are built with `KymaTable::with_node_id`;
    /// otherwise with `KymaTable::new`.
    pub node_id: Option<kyma_core::types::NodeId>,
}
/// Arrow Flight service answering `do_get` with query results or raw extent data.
pub struct FlightQueryService {
    // Shared dependencies handed in at construction time.
    state: FlightState,
}
impl FlightQueryService {
pub fn new(state: FlightState) -> Self {
Self { state }
}
async fn serve_extent(
&self,
ticket: &FlightTicket,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let table = self
.state
.catalog
.lookup_table(&ticket.database, &ticket.table)
.await
.map_err(|e| Status::not_found(format!("lookup_table: {e}")))?;
let reader = self
.state
.format
.open_extent(kyma_core::segment_format::OpenExtentInput {
extent_id: kyma_core::types::ExtentId::new(),
table_id: table.id,
schema: table.schema.clone(),
object_path: ticket.object_path.clone(),
byte_size: ticket.byte_size,
})
.await
.map_err(|e| Status::internal(format!("open_extent: {e}")))?;
let block_ids = reader
.pruned_blocks(&kyma_core::segment_format::BlockPredicate::All)
.await
.map_err(|e| Status::internal(format!("pruned_blocks: {e}")))?;
let mut batches = Vec::with_capacity(block_ids.len());
for bid in block_ids {
let b = reader
.read_block(bid, &[])
.await
.map_err(|e| Status::internal(format!("read_block: {e}")))?;
batches.push(b);
}
::metrics::counter!("kyma_flight_serve_extent_total").increment(1);
let s = stream::iter(
batches
.into_iter()
.map(|b| Ok::<_, arrow_flight::error::FlightError>(b)),
);
let encoder = FlightDataEncoderBuilder::new()
.build(s)
.map(|r| r.map_err(|e| Status::internal(format!("encode: {e}"))))
.boxed();
Ok(Response::new(encoder))
}
}
/// JSON payload carried inside a Flight `Ticket` for `do_get`.
///
/// Every field may be omitted from the JSON; missing fields take the defaults
/// named in their `serde(default …)` attributes.
#[derive(serde::Deserialize, Debug)]
struct FlightTicket {
    /// Request kind: `"extent"` asks for a raw extent read; the default is
    /// `"query"`.
    #[serde(default = "default_kind")]
    kind: String,
    /// Database the request targets; defaults to `"default"`.
    #[serde(default = "default_database")]
    database: String,
    /// Query text, interpreted according to `language`; empty when absent.
    #[serde(default)]
    query: String,
    /// Query language (`"sql"` or `"kql"`); defaults to `"sql"`.
    #[serde(default = "default_language")]
    language: String,
    /// Table name used by extent tickets; empty when absent.
    #[serde(default)]
    table: String,
    /// Path of the extent object handed to `open_extent`; empty when absent.
    #[serde(default)]
    object_path: String,
    /// Size in bytes of the extent object handed to `open_extent`; 0 when absent.
    #[serde(default)]
    byte_size: u64,
}
/// Serde default for `FlightTicket::kind`.
fn default_kind() -> String {
    String::from("query")
}
/// Serde default for `FlightTicket::database`.
fn default_database() -> String {
    String::from("default")
}
/// Serde default for `FlightTicket::language`.
fn default_language() -> String {
    String::from("sql")
}
#[tonic::async_trait]
impl FlightService for FlightQueryService {
    type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
    type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
    type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
    type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
    type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
    type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
    type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;

    /// Accepts any handshake and replies with an empty stream; no
    /// authentication is performed here.
    async fn handshake(
        &self,
        _req: Request<Streaming<HandshakeRequest>>,
    ) -> Result<Response<Self::HandshakeStream>, Status> {
        let s = stream::empty::<Result<HandshakeResponse, Status>>().boxed();
        Ok(Response::new(s))
    }

    /// Not supported; clients should call `do_get` with a JSON ticket.
    async fn list_flights(
        &self,
        _req: Request<Criteria>,
    ) -> Result<Response<Self::ListFlightsStream>, Status> {
        Err(Status::unimplemented(
            "list_flights not supported; issue do_get with a JSON ticket",
        ))
    }

    /// Not supported; clients should call `do_get` directly.
    async fn get_flight_info(
        &self,
        _req: Request<FlightDescriptor>,
    ) -> Result<Response<FlightInfo>, Status> {
        Err(Status::unimplemented(
            "get_flight_info not supported; issue do_get directly",
        ))
    }

    /// Not supported.
    async fn get_schema(
        &self,
        _req: Request<FlightDescriptor>,
    ) -> Result<Response<SchemaResult>, Status> {
        Err(Status::unimplemented("get_schema not supported"))
    }

    /// Executes a JSON-encoded `FlightTicket`.
    ///
    /// A ticket with `kind == "extent"` is served by `serve_extent`. A ticket with
    /// `kind == "query"` goes through these steps:
    /// 1. The query is translated from KQL to SQL when `language == "kql"`.
    /// 2. Every table in `ticket.database` is registered in a fresh DataFusion
    ///    session.
    /// 3. The result batches are streamed back.
    ///
    /// # Errors
    ///
    /// - `InvalidArgument` for malformed ticket JSON, an unknown `kind` or
    ///   `language`, KQL parse failures, and SQL planning failures.
    /// - `NotFound` when listing tables fails or the database has none.
    /// - `Internal` for runtime, registration, execution, or encoding failures.
    async fn do_get(
        &self,
        request: Request<Ticket>,
    ) -> Result<Response<Self::DoGetStream>, Status> {
        // Memory budget handed to each query's own GreedyMemoryPool (4 GiB).
        const QUERY_MEMORY_POOL_BYTES: usize = 4 * 1024 * 1024 * 1024;

        let ticket = request.into_inner();
        let ticket: FlightTicket = serde_json::from_slice(&ticket.ticket)
            .map_err(|e| Status::invalid_argument(format!("bad ticket JSON: {e}")))?;
        // Reject unknown kinds explicitly instead of silently treating them as
        // queries (a typo'd kind would otherwise run an empty query).
        match ticket.kind.as_str() {
            "extent" => return self.serve_extent(&ticket).await,
            "query" => {}
            other => {
                return Err(Status::invalid_argument(format!(
                    "unknown ticket kind `{other}`; use `query` or `extent`"
                )))
            }
        }
        debug!(db = %ticket.database, lang = %ticket.language, sql_len = ticket.query.len(), "flight do_get");
        let sql = match ticket.language.as_str() {
            "kql" => kyma_kql::kql_to_sql(&ticket.query)
                .map_err(|e| Status::invalid_argument(format!("KQL parse: {e}")))?,
            "sql" => ticket.query,
            other => {
                return Err(Status::invalid_argument(format!(
                    "unknown language `{other}`; use `sql` or `kql`"
                )))
            }
        };
        let tables = self
            .state
            .catalog
            .list_tables_in_database(&ticket.database)
            .await
            .map_err(|e| Status::not_found(format!("list_tables: {e}")))?;
        if tables.is_empty() {
            return Err(Status::not_found(format!(
                "no tables in database {}",
                ticket.database
            )));
        }
        let runtime = Arc::new(
            RuntimeEnvBuilder::new()
                .with_memory_pool(Arc::new(GreedyMemoryPool::new(QUERY_MEMORY_POOL_BYTES)))
                .build()
                .map_err(|e| Status::internal(format!("runtime: {e}")))?,
        );
        let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime);
        kyma_exec::register_vector_udfs(&ctx);
        for t in tables {
            let name = t.name.clone();
            let tbl: Arc<KymaTable> = match self.state.node_id {
                Some(nid) => Arc::new(KymaTable::with_node_id(
                    t,
                    self.state.catalog.clone(),
                    self.state.format.clone(),
                    nid,
                    ticket.database.clone(),
                )),
                None => Arc::new(KymaTable::new(
                    t,
                    self.state.catalog.clone(),
                    self.state.format.clone(),
                )),
            };
            ctx.register_table(&name, tbl)
                .map_err(|e| Status::internal(format!("register_table {name}: {e}")))?;
        }
        let df = ctx
            .sql(&sql)
            .await
            .map_err(|e| Status::invalid_argument(format!("sql plan: {e}")))?;
        let stream = df
            .execute_stream()
            .await
            .map_err(|e| Status::internal(format!("execute: {e}")))?;
        // Capture the schema before the stream is consumed. Without it, a query
        // that yields zero batches would send no Schema message to the client.
        let schema = stream.schema();
        let mapped = stream
            .map(|r| r.map_err(|e| arrow_flight::error::FlightError::ExternalError(Box::new(e))));
        let encoder = FlightDataEncoderBuilder::new()
            .with_schema(schema)
            .build(mapped)
            .map(|r| r.map_err(|e| Status::internal(format!("encode: {e}"))))
            .boxed();
        // `language` was validated above, so this label has bounded cardinality.
        ::metrics::counter!("kyma_flight_do_get_total", "lang" => ticket.language).increment(1);
        Ok(Response::new(encoder))
    }

    /// Not supported; ingestion goes through `POST /v1/ingest`.
    async fn do_put(
        &self,
        _req: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoPutStream>, Status> {
        Err(Status::unimplemented(
            "do_put not supported; use POST /v1/ingest for now",
        ))
    }

    /// Not supported.
    async fn do_action(
        &self,
        _req: Request<Action>,
    ) -> Result<Response<Self::DoActionStream>, Status> {
        Err(Status::unimplemented("do_action not supported"))
    }

    /// Not supported.
    async fn list_actions(
        &self,
        _req: Request<Empty>,
    ) -> Result<Response<Self::ListActionsStream>, Status> {
        Err(Status::unimplemented("list_actions not supported"))
    }

    /// Not supported.
    async fn do_exchange(
        &self,
        _req: Request<Streaming<FlightData>>,
    ) -> Result<Response<Self::DoExchangeStream>, Status> {
        Err(Status::unimplemented("do_exchange not supported"))
    }

    /// Not supported.
    async fn poll_flight_info(
        &self,
        _req: Request<FlightDescriptor>,
    ) -> Result<Response<PollInfo>, Status> {
        Err(Status::unimplemented("poll_flight_info not supported"))
    }
}
/// Wraps a [`FlightQueryService`] built from `state` in tonic's Flight gRPC server.
pub fn flight_server(state: FlightState) -> FlightServiceServer<FlightQueryService> {
    let service = FlightQueryService::new(state);
    FlightServiceServer::new(service)
}
/// Builds the Flight gRPC server behind a gRPC-Web translation layer, packaged
/// as a service that can be mounted in axum.
#[cfg(feature = "web-ui")]
pub fn flight_grpc_web_service(state: FlightState) -> FlightGrpcWebService {
    let grpc = FlightServiceServer::new(FlightQueryService::new(state));
    FlightGrpcWebService {
        inner: tower::ServiceBuilder::new()
            .layer(tonic_web::GrpcWebLayer::new())
            .service(grpc),
    }
}
/// gRPC-Web wrapper around the Flight server, adapted to axum's body types by
/// its `tower::Service` implementation.
#[cfg(feature = "web-ui")]
#[derive(Clone)]
pub struct FlightGrpcWebService {
    // The tonic-web-wrapped Flight server that actually handles requests.
    inner: tonic_web::GrpcWebService<FlightServiceServer<FlightQueryService>>,
}
/// Bridges axum requests/responses to the tonic-web service.
///
/// Bodies are converted in both directions. The wrapped service's error type is
/// uninhabited, so this service never fails.
#[cfg(feature = "web-ui")]
impl tower::Service<axum::http::Request<axum::body::Body>> for FlightGrpcWebService {
    type Response = axum::http::Response<axum::body::Body>;
    type Error = std::convert::Infallible;
    type Future = std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>,
    >;

    /// Delegates readiness to the wrapped service; its error type is
    /// uninhabited, which the empty `match` proves at compile time.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        tower::Service::poll_ready(&mut self.inner, cx).map_err(|e| match e {})
    }

    /// Converts the axum body into tonic's boxed body and dispatches the request.
    /// The response body is then converted back into an axum body.
    fn call(&mut self, req: axum::http::Request<axum::body::Body>) -> Self::Future {
        use http_body_util::BodyExt as _;
        let (parts, body) = req.into_parts();
        // Body read errors are surfaced to tonic as `Status::internal`.
        let tonic_body: tonic::body::BoxBody = body
            .map_err(|e| tonic::Status::internal(e.to_string()))
            .boxed_unsync();
        let tonic_req = axum::http::Request::from_parts(parts, tonic_body);
        let fut = tower::Service::call(&mut self.inner, tonic_req);
        Box::pin(async move {
            // The error type is uninhabited (same one `poll_ready` matches on), so
            // the error branch is statically unreachable. This avoids an
            // `expect` and its clippy allowance.
            let resp = fut.await.unwrap_or_else(|never| match never {});
            let (parts, body) = resp.into_parts();
            let axum_body = axum::body::Body::new(body.map_err(axum::Error::new));
            Ok(axum::http::Response::from_parts(parts, axum_body))
        })
    }
}