1#![allow(refining_impl_trait)]
17
18use std::collections::HashMap;
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22
23use crate::proto::sql::v1::{
24 cell::Kind as ProtoCellKind, Cell as ProtoCell, Column as ProtoColumn, Index as ProtoIndex,
25 IndexLayout as ProtoIndexLayout, ListValue as ProtoListValue, Null as ProtoNull,
26 QueryRequestView, QueryResponse, Row as ProtoRow, Service, ServiceServer, SubscribeRequestView,
27 SubscribeResponse, Table as ProtoTable, TablesRequestView, TablesResponse,
28};
29use bytes::Bytes;
30use connectrpc::{ConnectError, ConnectRpcService, RequestContext as Context};
31use datafusion::arrow::array::{
32 Array, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
33 FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array, LargeListArray,
34 LargeStringArray, ListArray, StringArray, StringViewArray, TimestampMicrosecondArray,
35 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt32Array,
36 UInt64Array,
37};
38use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
39use datafusion::arrow::record_batch::RecordBatch;
40use datafusion::common::{DataFusionError, Result as DataFusionResult};
41use datafusion::datasource::MemTable;
42use datafusion::prelude::SessionContext;
43use exoware_sdk::keys::Key;
44use exoware_sdk::kv_codec::{decode_stored_row, Utf8};
45use exoware_sdk::match_key::MatchKey;
46use exoware_sdk::stream_filter::StreamFilter;
47use exoware_sdk::{StoreClient, StreamSubscription};
48use futures::future::BoxFuture;
49use futures::stream::Stream;
50use futures::FutureExt;
51
52use crate::builder::{projected_column_indices, ProjectedBatchBuilder};
53use crate::codec::decode_primary_key_selected;
54use crate::filter::ScanAccessPlan;
55use crate::predicate::QueryPredicate;
56use crate::schema::KvSchema;
57use crate::types::{
58 IndexLayout, ResolvedIndexSpec, TableModel, KEY_KIND_BITS, PRIMARY_RESERVED_BITS,
59};
60
/// Size cap of 256 MiB, applied by `sql_connect_stack` both to the request
/// body size and to the per-message size of the ConnectRPC service.
const MAX_CONNECTRPC_BODY_BYTES: usize = 256 * 1024 * 1024;

/// Pinned, boxed, `Send` stream of `SubscribeResponse` items (or
/// `ConnectError`s); the response body type of the `subscribe` RPC.
type SubscribeStream = Pin<Box<dyn Stream<Item = Result<SubscribeResponse, ConnectError>> + Send>>;
64
65#[derive(Clone)]
67struct TableStream {
68 model: Arc<TableModel>,
69 schema: SchemaRef,
70 access_plan: Arc<ScanAccessPlan>,
71 match_key: MatchKey,
72 indexes: Arc<Vec<ResolvedIndexSpec>>,
73}
74
75impl TableStream {
76 fn new(model: Arc<TableModel>, indexes: Vec<ResolvedIndexSpec>) -> Self {
77 let projection: Option<Vec<usize>> = Some((0..model.columns.len()).collect());
78 let access_plan = Arc::new(ScanAccessPlan::new(
79 &model,
80 &projection,
81 &QueryPredicate::default(),
82 ));
83 let prefix = u16::from(model.table_prefix) << KEY_KIND_BITS;
84 let match_key = MatchKey {
85 reserved_bits: PRIMARY_RESERVED_BITS,
86 prefix,
87 payload_regex: Utf8::from("(?s-u).*"),
88 };
89 Self {
90 schema: model.schema.clone(),
91 access_plan,
92 model,
93 match_key,
94 indexes: Arc::new(indexes),
95 }
96 }
97
98 fn decode_batch(&self, entries: &[(Key, Bytes)]) -> DataFusionResult<RecordBatch> {
99 let mut builder = ProjectedBatchBuilder::from_access_plan(&self.model, &self.access_plan);
100 for (key, value) in entries {
101 if !self.model.primary_key_codec.matches(key) {
102 continue;
103 }
104 let Some(pk_values) = decode_primary_key_selected(
105 self.model.table_prefix,
106 key,
107 &self.model,
108 &self.access_plan.required_pk_mask,
109 ) else {
110 continue;
111 };
112 let Ok(archived) = decode_stored_row(value) else {
113 continue;
114 };
115 if archived.values.len() != self.model.columns.len() {
116 continue;
117 }
118 let _ = builder.append_archived_row(&pk_values, &archived)?;
119 }
120 builder.finish(&self.schema)
121 }
122}
123
124pub struct SqlServer {
129 ctx: Arc<SessionContext>,
130 streams: HashMap<String, TableStream>,
131 table_names: Vec<String>,
134 store: StoreClient,
135}
136
137impl SqlServer {
138 pub fn new(schema: KvSchema) -> DataFusionResult<Self> {
142 let store = schema.client().clone();
143 let mut streams = HashMap::with_capacity(schema.tables().len());
144 let mut table_names = Vec::with_capacity(schema.tables().len());
145 for (name, config) in schema.tables() {
146 let model =
147 Arc::new(TableModel::from_config(config).map_err(|e| {
148 DataFusionError::Execution(format!("invalid table config: {e}"))
149 })?);
150 let indexes = model
151 .resolve_index_specs(&config.index_specs)
152 .map_err(|e| DataFusionError::Execution(format!("invalid index specs: {e}")))?;
153 streams.insert(name.clone(), TableStream::new(model, indexes));
154 table_names.push(name.clone());
155 }
156 let ctx = SessionContext::new();
157 schema.register_all(&ctx)?;
158 Ok(Self {
159 ctx: Arc::new(ctx),
160 streams,
161 table_names,
162 store,
163 })
164 }
165
166 pub fn session(&self) -> &SessionContext {
169 &self.ctx
170 }
171
172 #[allow(clippy::result_large_err)]
173 fn stream(&self, table: &str) -> Result<&TableStream, ConnectError> {
174 self.streams
175 .get(table)
176 .ok_or_else(|| ConnectError::not_found(format!("unknown table '{table}'")))
177 }
178
179 fn describe_tables(&self) -> Vec<ProtoTable> {
180 self.table_names
181 .iter()
182 .filter_map(|name| {
183 let stream = self.streams.get(name)?;
184 let columns = stream
185 .schema
186 .fields()
187 .iter()
188 .map(|field| ProtoColumn {
189 name: field.name().clone(),
190 data_type: format!("{}", field.data_type()),
191 nullable: field.is_nullable(),
192 ..Default::default()
193 })
194 .collect();
195 let primary_key_columns = stream
196 .model
197 .primary_key_indices
198 .iter()
199 .map(|&idx| idx as u32)
200 .collect();
201 let indexes = stream
202 .indexes
203 .iter()
204 .map(|spec| {
205 let key_set: std::collections::HashSet<usize> =
206 spec.key_columns.iter().copied().collect();
207 ProtoIndex {
208 name: spec.name.clone(),
209 layout: proto_index_layout(spec.layout).into(),
210 key_columns: spec.key_columns.iter().map(|&idx| idx as u32).collect(),
211 cover_columns: spec
215 .value_column_mask
216 .iter()
217 .enumerate()
218 .filter_map(|(idx, covered)| {
219 (*covered && !key_set.contains(&idx)).then_some(idx as u32)
220 })
221 .collect(),
222 ..Default::default()
223 }
224 })
225 .collect();
226 Some(ProtoTable {
227 name: name.clone(),
228 columns,
229 primary_key_columns,
230 indexes,
231 ..Default::default()
232 })
233 })
234 .collect()
235 }
236}
237
238fn proto_index_layout(layout: IndexLayout) -> ProtoIndexLayout {
239 match layout {
240 IndexLayout::Lexicographic => ProtoIndexLayout::INDEX_LAYOUT_LEXICOGRAPHIC,
241 IndexLayout::ZOrder => ProtoIndexLayout::INDEX_LAYOUT_Z_ORDER,
242 }
243}
244
245pub fn sql_connect_stack(server: Arc<SqlServer>) -> ConnectRpcService<ServiceServer<SqlConnect>> {
248 ConnectRpcService::new(ServiceServer::new(SqlConnect::new(server)))
249 .with_limits(
250 connectrpc::Limits::default()
251 .max_request_body_size(MAX_CONNECTRPC_BODY_BYTES)
252 .max_message_size(MAX_CONNECTRPC_BODY_BYTES),
253 )
254 .with_compression(exoware_sdk::connect_compression_registry())
255}
256
257#[derive(Clone)]
259pub struct SqlConnect {
260 server: Arc<SqlServer>,
261}
262
263impl SqlConnect {
264 pub fn new(server: Arc<SqlServer>) -> Self {
265 Self { server }
266 }
267}
268
269impl Service for SqlConnect {
270 fn subscribe(
271 &self,
272 _ctx: Context,
273 request: buffa::view::OwnedView<SubscribeRequestView<'static>>,
274 ) -> impl Future<Output = connectrpc::ServiceResult<SubscribeStream>> + Send {
275 let server = self.server.clone();
276 async move {
277 let table_name = request.table.to_string();
278 let where_sql = request.where_sql.trim().to_string();
279 let since = request.since_sequence_number.filter(|seq| *seq != 0);
280 let stream = server.stream(&table_name)?.clone();
281
282 let filter = StreamFilter {
283 match_keys: vec![stream.match_key.clone()],
284 value_filters: vec![],
285 };
286 let sub = server
287 .store
288 .stream()
289 .subscribe(filter, since)
290 .await
291 .map_err(client_error_to_connect)?;
292
293 let output = Box::pin(BatchPredicateStream::new(
294 sub, stream, table_name, where_sql,
295 ));
296 Ok(connectrpc::Response::stream(output as SubscribeStream))
297 }
298 }
299
300 fn tables(
301 &self,
302 _ctx: Context,
303 _request: buffa::view::OwnedView<TablesRequestView<'static>>,
304 ) -> impl Future<Output = connectrpc::ServiceResult<TablesResponse>> + Send {
305 let server = self.server.clone();
306 async move {
307 connectrpc::Response::ok(TablesResponse {
308 tables: server.describe_tables(),
309 ..Default::default()
310 })
311 }
312 }
313
314 fn query(
315 &self,
316 _ctx: Context,
317 request: buffa::view::OwnedView<QueryRequestView<'static>>,
318 ) -> impl Future<Output = connectrpc::ServiceResult<QueryResponse>> + Send {
319 let server = self.server.clone();
320 async move {
321 let sql = request.sql.to_string();
322 let df = server
323 .ctx
324 .sql(&sql)
325 .await
326 .map_err(datafusion_error_to_connect)?;
327 let schema = df.schema().clone();
328 let batches = df.collect().await.map_err(datafusion_error_to_connect)?;
329 let columns: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
330 let rows =
331 record_batches_to_proto_rows(&batches).map_err(datafusion_error_to_connect)?;
332 connectrpc::Response::ok(QueryResponse {
333 column: columns,
334 rows,
335 ..Default::default()
336 })
337 }
338 }
339}
340
341struct BatchPredicateStream {
342 sub: StreamSubscription,
343 state: TableStream,
344 table_name: String,
345 where_sql: String,
346 building: Option<BoxFuture<'static, Result<Option<SubscribeResponse>, ConnectError>>>,
347}
348
349impl BatchPredicateStream {
350 fn new(
351 sub: StreamSubscription,
352 state: TableStream,
353 table_name: String,
354 where_sql: String,
355 ) -> Self {
356 Self {
357 sub,
358 state,
359 table_name,
360 where_sql,
361 building: None,
362 }
363 }
364}
365
366impl Stream for BatchPredicateStream {
367 type Item = Result<SubscribeResponse, ConnectError>;
368
369 fn poll_next(
370 self: Pin<&mut Self>,
371 cx: &mut std::task::Context<'_>,
372 ) -> std::task::Poll<Option<Self::Item>> {
373 let this = self.get_mut();
374 loop {
375 if let Some(fut) = this.building.as_mut() {
376 match fut.as_mut().poll(cx) {
377 std::task::Poll::Pending => return std::task::Poll::Pending,
378 std::task::Poll::Ready(Ok(Some(resp))) => {
379 this.building = None;
380 return std::task::Poll::Ready(Some(Ok(resp)));
381 }
382 std::task::Poll::Ready(Ok(None)) => {
383 this.building = None;
384 }
385 std::task::Poll::Ready(Err(err)) => {
386 this.building = None;
387 return std::task::Poll::Ready(Some(Err(err)));
388 }
389 }
390 }
391
392 let frame = {
393 let next_fut = this.sub.next();
394 tokio::pin!(next_fut);
395 match next_fut.as_mut().poll(cx) {
396 std::task::Poll::Ready(Ok(Some(frame))) => frame,
397 std::task::Poll::Ready(Ok(None)) => return std::task::Poll::Ready(None),
398 std::task::Poll::Ready(Err(err)) => {
399 return std::task::Poll::Ready(Some(Err(client_error_to_connect(err))));
400 }
401 std::task::Poll::Pending => return std::task::Poll::Pending,
402 }
403 };
404
405 let sequence_number = frame.sequence_number;
406 let entries: Vec<(Key, Bytes)> = frame
407 .entries
408 .into_iter()
409 .map(|entry| (entry.key, entry.value))
410 .collect();
411 let state = this.state.clone();
412 let table_name = this.table_name.clone();
413 let where_sql = this.where_sql.clone();
414 this.building = Some(
415 async move {
416 evaluate_batch(state, table_name, where_sql, sequence_number, entries).await
417 }
418 .boxed(),
419 );
420 }
421 }
422}
423
424async fn evaluate_batch(
425 state: TableStream,
426 table_name: String,
427 where_sql: String,
428 sequence_number: u64,
429 entries: Vec<(Key, Bytes)>,
430) -> Result<Option<SubscribeResponse>, ConnectError> {
431 let batch = state
432 .decode_batch(&entries)
433 .map_err(datafusion_error_to_connect)?;
434 if batch.num_rows() == 0 {
435 return Ok(None);
436 }
437
438 let filtered = if where_sql.is_empty() {
439 batch
440 } else {
441 apply_where(state.schema.clone(), batch, &table_name, &where_sql)
442 .await
443 .map_err(datafusion_error_to_connect)?
444 };
445 if filtered.num_rows() == 0 {
446 return Ok(None);
447 }
448
449 let columns: Vec<String> = filtered
450 .schema()
451 .fields()
452 .iter()
453 .map(|f| f.name().clone())
454 .collect();
455 let rows = record_batches_to_proto_rows(std::slice::from_ref(&filtered))
456 .map_err(datafusion_error_to_connect)?;
457 Ok(Some(SubscribeResponse {
458 sequence_number,
459 column: columns,
460 rows,
461 ..Default::default()
462 }))
463}
464
465async fn apply_where(
466 schema: SchemaRef,
467 batch: RecordBatch,
468 table_name: &str,
469 where_sql: &str,
470) -> DataFusionResult<RecordBatch> {
471 let ctx = SessionContext::new();
472 let mem = MemTable::try_new(schema.clone(), vec![vec![batch]])?;
473 ctx.register_table(table_name, Arc::new(mem))?;
474 let sql = format!("SELECT * FROM {table_name} WHERE {where_sql}");
475 let df = ctx.sql(&sql).await?;
476 let batches = df.collect().await?;
477 if batches.is_empty() {
478 return Ok(RecordBatch::new_empty(schema));
479 }
480 datafusion::arrow::compute::concat_batches(&schema, batches.iter())
481 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
482}
483
484fn record_batches_to_proto_rows(batches: &[RecordBatch]) -> DataFusionResult<Vec<ProtoRow>> {
485 let mut out = Vec::with_capacity(batches.iter().map(|b| b.num_rows()).sum());
486 for batch in batches {
487 for row_idx in 0..batch.num_rows() {
488 let mut cells = Vec::with_capacity(batch.num_columns());
489 for col_idx in 0..batch.num_columns() {
490 cells.push(arrow_value_to_cell(batch.column(col_idx), row_idx)?);
491 }
492 out.push(ProtoRow {
493 cells,
494 ..Default::default()
495 });
496 }
497 }
498 Ok(out)
499}
500
501fn arrow_value_to_cell(array: &ArrayRef, row: usize) -> DataFusionResult<ProtoCell> {
502 let kind = if array.is_null(row) {
503 ProtoCellKind::NullValue(Box::<ProtoNull>::default())
504 } else {
505 arrow_value_to_kind(array, row)?
506 };
507 Ok(ProtoCell {
508 kind: Some(kind),
509 ..Default::default()
510 })
511}
512
513fn arrow_value_to_kind(array: &ArrayRef, row: usize) -> DataFusionResult<ProtoCellKind> {
514 match array.data_type() {
515 DataType::Int64 => Ok(ProtoCellKind::Int64Value(
516 array
517 .as_any()
518 .downcast_ref::<Int64Array>()
519 .unwrap()
520 .value(row),
521 )),
522 DataType::Int32 => Ok(ProtoCellKind::Int64Value(
523 array
524 .as_any()
525 .downcast_ref::<Int32Array>()
526 .unwrap()
527 .value(row) as i64,
528 )),
529 DataType::UInt64 => Ok(ProtoCellKind::Uint64Value(
530 array
531 .as_any()
532 .downcast_ref::<UInt64Array>()
533 .unwrap()
534 .value(row),
535 )),
536 DataType::UInt32 => Ok(ProtoCellKind::Uint64Value(
537 array
538 .as_any()
539 .downcast_ref::<UInt32Array>()
540 .unwrap()
541 .value(row) as u64,
542 )),
543 DataType::Float64 => Ok(ProtoCellKind::Float64Value(
544 array
545 .as_any()
546 .downcast_ref::<Float64Array>()
547 .unwrap()
548 .value(row),
549 )),
550 DataType::Float32 => Ok(ProtoCellKind::Float64Value(
551 array
552 .as_any()
553 .downcast_ref::<Float32Array>()
554 .unwrap()
555 .value(row) as f64,
556 )),
557 DataType::Boolean => Ok(ProtoCellKind::BooleanValue(
558 array
559 .as_any()
560 .downcast_ref::<BooleanArray>()
561 .unwrap()
562 .value(row),
563 )),
564 DataType::Utf8 => Ok(ProtoCellKind::Utf8Value(
565 array
566 .as_any()
567 .downcast_ref::<StringArray>()
568 .unwrap()
569 .value(row)
570 .to_string(),
571 )),
572 DataType::LargeUtf8 => Ok(ProtoCellKind::Utf8Value(
573 array
574 .as_any()
575 .downcast_ref::<LargeStringArray>()
576 .unwrap()
577 .value(row)
578 .to_string(),
579 )),
580 DataType::Utf8View => Ok(ProtoCellKind::Utf8Value(
581 array
582 .as_any()
583 .downcast_ref::<StringViewArray>()
584 .unwrap()
585 .value(row)
586 .to_string(),
587 )),
588 DataType::FixedSizeBinary(_) => {
589 Ok(ProtoCellKind::FixedSizeBinaryValue(Bytes::copy_from_slice(
590 array
591 .as_any()
592 .downcast_ref::<FixedSizeBinaryArray>()
593 .unwrap()
594 .value(row),
595 )))
596 }
597 DataType::Date32 => Ok(ProtoCellKind::Date32Value(
598 array
599 .as_any()
600 .downcast_ref::<Date32Array>()
601 .unwrap()
602 .value(row),
603 )),
604 DataType::Date64 => Ok(ProtoCellKind::Date64Value(
605 array
606 .as_any()
607 .downcast_ref::<Date64Array>()
608 .unwrap()
609 .value(row),
610 )),
611 DataType::Timestamp(unit, _) => {
612 let v = match unit {
613 TimeUnit::Second => array
614 .as_any()
615 .downcast_ref::<TimestampSecondArray>()
616 .unwrap()
617 .value(row),
618 TimeUnit::Millisecond => array
619 .as_any()
620 .downcast_ref::<TimestampMillisecondArray>()
621 .unwrap()
622 .value(row),
623 TimeUnit::Microsecond => array
624 .as_any()
625 .downcast_ref::<TimestampMicrosecondArray>()
626 .unwrap()
627 .value(row),
628 TimeUnit::Nanosecond => array
629 .as_any()
630 .downcast_ref::<TimestampNanosecondArray>()
631 .unwrap()
632 .value(row),
633 };
634 Ok(ProtoCellKind::TimestampValue(v))
635 }
636 DataType::Decimal128(_, _) => {
637 let v = array
638 .as_any()
639 .downcast_ref::<Decimal128Array>()
640 .unwrap()
641 .value(row);
642 Ok(ProtoCellKind::Decimal128Value(Bytes::copy_from_slice(
643 &v.to_be_bytes(),
644 )))
645 }
646 DataType::Decimal256(_, _) => {
647 let v = array
648 .as_any()
649 .downcast_ref::<Decimal256Array>()
650 .unwrap()
651 .value(row);
652 Ok(ProtoCellKind::Decimal256Value(Bytes::copy_from_slice(
653 &v.to_be_bytes(),
654 )))
655 }
656 DataType::List(_) => {
657 let list = array.as_any().downcast_ref::<ListArray>().unwrap();
658 Ok(ProtoCellKind::ListValue(Box::new(list_array_to_proto(
659 &list.value(row),
660 )?)))
661 }
662 DataType::LargeList(_) => {
663 let list = array.as_any().downcast_ref::<LargeListArray>().unwrap();
664 Ok(ProtoCellKind::ListValue(Box::new(list_array_to_proto(
665 &list.value(row),
666 )?)))
667 }
668 other => Err(DataFusionError::NotImplemented(format!(
669 "cell conversion for arrow type {other:?}"
670 ))),
671 }
672}
673
674fn list_array_to_proto(elements: &ArrayRef) -> DataFusionResult<ProtoListValue> {
675 let mut cells = Vec::with_capacity(elements.len());
676 for idx in 0..elements.len() {
677 cells.push(arrow_value_to_cell(elements, idx)?);
678 }
679 Ok(ProtoListValue {
680 elements: cells,
681 ..Default::default()
682 })
683}
684
685fn datafusion_error_to_connect(err: DataFusionError) -> ConnectError {
686 match err {
687 DataFusionError::Plan(msg)
688 | DataFusionError::SQL(_, Some(msg))
689 | DataFusionError::Configuration(msg)
690 | DataFusionError::NotImplemented(msg) => ConnectError::invalid_argument(msg),
691 DataFusionError::SchemaError(err, _) => ConnectError::invalid_argument(err.to_string()),
692 other => ConnectError::internal(other.to_string()),
693 }
694}
695
696fn client_error_to_connect(err: exoware_sdk::ClientError) -> ConnectError {
697 if let Some(rpc) = err.rpc_error() {
698 ConnectError::new(rpc.code, rpc.message.clone().unwrap_or_default())
699 } else {
700 ConnectError::internal(err.to_string())
701 }
702}
703
704#[allow(dead_code)]
706fn _assert_projected_column_indices_visible() {
707 let _ = projected_column_indices;
708}