1use std::collections::HashMap;
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::Arc;
20
21use bytes::Bytes;
22use connectrpc::{ConnectError, ConnectRpcService, Context};
23use datafusion::arrow::array::{
24 Array, ArrayRef, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
25 FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array, LargeListArray,
26 LargeStringArray, ListArray, StringArray, StringViewArray, TimestampMicrosecondArray,
27 TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt32Array,
28 UInt64Array,
29};
30use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
31use datafusion::arrow::record_batch::RecordBatch;
32use datafusion::common::{DataFusionError, Result as DataFusionResult};
33use datafusion::datasource::MemTable;
34use datafusion::prelude::SessionContext;
35use exoware_sdk::keys::Key;
36use exoware_sdk::kv_codec::{decode_stored_row, Utf8};
37use exoware_sdk::match_key::MatchKey;
38use exoware_sdk::store::sql::v1::{
39 cell::Kind as ProtoCellKind, Cell as ProtoCell, Column as ProtoColumn, Index as ProtoIndex,
40 IndexLayout as ProtoIndexLayout, ListValue as ProtoListValue, Null as ProtoNull,
41 QueryRequestView, QueryResponse, Row as ProtoRow, Service, ServiceServer, SubscribeRequestView,
42 SubscribeResponse, Table as ProtoTable, TablesRequestView, TablesResponse,
43};
44use exoware_sdk::stream_filter::StreamFilter;
45use exoware_sdk::{StoreClient, StreamSubscription};
46use futures::future::BoxFuture;
47use futures::stream::Stream;
48use futures::FutureExt;
49
50use crate::builder::{projected_column_indices, ProjectedBatchBuilder};
51use crate::codec::decode_primary_key_selected;
52use crate::filter::ScanAccessPlan;
53use crate::predicate::QueryPredicate;
54use crate::schema::KvSchema;
55use crate::types::{
56 IndexLayout, ResolvedIndexSpec, TableModel, KEY_KIND_BITS, PRIMARY_RESERVED_BITS,
57};
58
// Upper bound applied to both request bodies and individual messages on the
// ConnectRPC stack (256 MiB) — query results can be large.
const MAX_CONNECTRPC_BODY_BYTES: usize = 256 * 1024 * 1024;

// Boxed item stream returned by the `subscribe` RPC.
type SubscribeStream = Pin<Box<dyn Stream<Item = Result<SubscribeResponse, ConnectError>> + Send>>;
62
/// Per-table state needed to decode and filter store-subscription frames.
#[derive(Clone)]
struct TableStream {
    // Resolved table layout (columns, primary-key codec, table prefix).
    model: Arc<TableModel>,
    // Arrow schema for full-row batches of this table.
    schema: SchemaRef,
    // Full-projection access plan, built once in `new`.
    access_plan: Arc<ScanAccessPlan>,
    // Key matcher covering this table's primary-key space.
    match_key: MatchKey,
    // Index specs resolved against the table model.
    indexes: Arc<Vec<ResolvedIndexSpec>>,
}
72
73impl TableStream {
74 fn new(model: Arc<TableModel>, indexes: Vec<ResolvedIndexSpec>) -> Self {
75 let projection: Option<Vec<usize>> = Some((0..model.columns.len()).collect());
76 let access_plan = Arc::new(ScanAccessPlan::new(
77 &model,
78 &projection,
79 &QueryPredicate::default(),
80 ));
81 let prefix = u16::from(model.table_prefix) << KEY_KIND_BITS;
82 let match_key = MatchKey {
83 reserved_bits: PRIMARY_RESERVED_BITS,
84 prefix,
85 payload_regex: Utf8::from("(?s-u).*"),
86 };
87 Self {
88 schema: model.schema.clone(),
89 access_plan,
90 model,
91 match_key,
92 indexes: Arc::new(indexes),
93 }
94 }
95
96 fn decode_batch(&self, entries: &[(Key, Bytes)]) -> DataFusionResult<RecordBatch> {
97 let mut builder = ProjectedBatchBuilder::from_access_plan(&self.model, &self.access_plan);
98 for (key, value) in entries {
99 if !self.model.primary_key_codec.matches(key) {
100 continue;
101 }
102 let Some(pk_values) = decode_primary_key_selected(
103 self.model.table_prefix,
104 key,
105 &self.model,
106 &self.access_plan.required_pk_mask,
107 ) else {
108 continue;
109 };
110 let Ok(archived) = decode_stored_row(value) else {
111 continue;
112 };
113 if archived.values.len() != self.model.columns.len() {
114 continue;
115 }
116 let _ = builder.append_archived_row(&pk_values, &archived)?;
117 }
118 builder.finish(&self.schema)
119 }
120}
121
/// SQL server backing the Connect service: a shared DataFusion session for
/// one-shot queries plus per-table streaming state for subscriptions.
pub struct SqlServer {
    // Session with every schema table registered (see `new`).
    ctx: Arc<SessionContext>,
    // Streaming state keyed by table name.
    streams: HashMap<String, TableStream>,
    // Table names in registration order, for stable `tables` output.
    table_names: Vec<String>,
    // Store client used to open stream subscriptions.
    store: StoreClient,
}
134
135impl SqlServer {
136 pub fn new(schema: KvSchema) -> DataFusionResult<Self> {
140 let store = schema.client().clone();
141 let mut streams = HashMap::with_capacity(schema.tables().len());
142 let mut table_names = Vec::with_capacity(schema.tables().len());
143 for (name, config) in schema.tables() {
144 let model =
145 Arc::new(TableModel::from_config(config).map_err(|e| {
146 DataFusionError::Execution(format!("invalid table config: {e}"))
147 })?);
148 let indexes = model
149 .resolve_index_specs(&config.index_specs)
150 .map_err(|e| DataFusionError::Execution(format!("invalid index specs: {e}")))?;
151 streams.insert(name.clone(), TableStream::new(model, indexes));
152 table_names.push(name.clone());
153 }
154 let ctx = SessionContext::new();
155 schema.register_all(&ctx)?;
156 Ok(Self {
157 ctx: Arc::new(ctx),
158 streams,
159 table_names,
160 store,
161 })
162 }
163
164 pub fn session(&self) -> &SessionContext {
167 &self.ctx
168 }
169
170 #[allow(clippy::result_large_err)]
171 fn stream(&self, table: &str) -> Result<&TableStream, ConnectError> {
172 self.streams
173 .get(table)
174 .ok_or_else(|| ConnectError::not_found(format!("unknown table '{table}'")))
175 }
176
177 fn describe_tables(&self) -> Vec<ProtoTable> {
178 self.table_names
179 .iter()
180 .filter_map(|name| {
181 let stream = self.streams.get(name)?;
182 let columns = stream
183 .schema
184 .fields()
185 .iter()
186 .map(|field| ProtoColumn {
187 name: field.name().clone(),
188 data_type: format!("{}", field.data_type()),
189 nullable: field.is_nullable(),
190 ..Default::default()
191 })
192 .collect();
193 let primary_key_columns = stream
194 .model
195 .primary_key_indices
196 .iter()
197 .map(|&idx| idx as u32)
198 .collect();
199 let indexes = stream
200 .indexes
201 .iter()
202 .map(|spec| {
203 let key_set: std::collections::HashSet<usize> =
204 spec.key_columns.iter().copied().collect();
205 ProtoIndex {
206 name: spec.name.clone(),
207 layout: proto_index_layout(spec.layout).into(),
208 key_columns: spec.key_columns.iter().map(|&idx| idx as u32).collect(),
209 cover_columns: spec
213 .value_column_mask
214 .iter()
215 .enumerate()
216 .filter_map(|(idx, covered)| {
217 (*covered && !key_set.contains(&idx)).then_some(idx as u32)
218 })
219 .collect(),
220 ..Default::default()
221 }
222 })
223 .collect();
224 Some(ProtoTable {
225 name: name.clone(),
226 columns,
227 primary_key_columns,
228 indexes,
229 ..Default::default()
230 })
231 })
232 .collect()
233 }
234}
235
/// Maps the internal index-layout enum onto its proto wire counterpart.
fn proto_index_layout(layout: IndexLayout) -> ProtoIndexLayout {
    match layout {
        IndexLayout::Lexicographic => ProtoIndexLayout::INDEX_LAYOUT_LEXICOGRAPHIC,
        IndexLayout::ZOrder => ProtoIndexLayout::INDEX_LAYOUT_Z_ORDER,
    }
}
242
243pub fn sql_connect_stack(server: Arc<SqlServer>) -> ConnectRpcService<ServiceServer<SqlConnect>> {
246 ConnectRpcService::new(ServiceServer::new(SqlConnect::new(server)))
247 .with_limits(
248 connectrpc::Limits::default()
249 .max_request_body_size(MAX_CONNECTRPC_BODY_BYTES)
250 .max_message_size(MAX_CONNECTRPC_BODY_BYTES),
251 )
252 .with_compression(exoware_sdk::connect_compression_registry())
253}
254
/// Connect service handler; cheap to clone, shares one `SqlServer`.
#[derive(Clone)]
pub struct SqlConnect {
    server: Arc<SqlServer>,
}

impl SqlConnect {
    /// Wraps a shared `SqlServer` for use as the RPC handler.
    pub fn new(server: Arc<SqlServer>) -> Self {
        Self { server }
    }
}
266
impl Service for SqlConnect {
    /// `subscribe` RPC: streams decoded, optionally WHERE-filtered row
    /// batches for one table.
    ///
    /// Opens a store subscription restricted to the table's primary-key
    /// space, then wraps it in `BatchPredicateStream` which decodes each
    /// frame and applies the trimmed WHERE clause before emitting.
    fn subscribe(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<SubscribeRequestView<'static>>,
    ) -> impl Future<Output = Result<(SubscribeStream, Context), ConnectError>> + Send {
        let server = self.server.clone();
        async move {
            let table_name = request.table.to_string();
            let where_sql = request.where_sql.trim().to_string();
            // A sequence number of 0 is treated as unset (no replay point);
            // NOTE(review): presumably `None` means "subscribe from now" —
            // confirm against the store's subscribe contract.
            let since = request.since_sequence_number.filter(|seq| *seq != 0);
            let stream = server.stream(&table_name)?.clone();

            // Server-side filter: only this table's primary-key entries.
            let filter = StreamFilter {
                match_keys: vec![stream.match_key.clone()],
                value_filters: vec![],
            };
            let sub = server
                .store
                .stream()
                .subscribe(filter, since)
                .await
                .map_err(client_error_to_connect)?;

            let output = Box::pin(BatchPredicateStream::new(
                sub, stream, table_name, where_sql,
            ));
            Ok((output as SubscribeStream, ctx))
        }
    }

    /// `tables` RPC: returns the schema description of every registered table.
    fn tables(
        &self,
        ctx: Context,
        _request: buffa::view::OwnedView<TablesRequestView<'static>>,
    ) -> impl Future<Output = Result<(TablesResponse, Context), ConnectError>> + Send {
        let server = self.server.clone();
        async move {
            Ok((
                TablesResponse {
                    tables: server.describe_tables(),
                    ..Default::default()
                },
                ctx,
            ))
        }
    }

    /// `query` RPC: runs one SQL statement on the shared session and returns
    /// the full result set (column names + proto-encoded rows) in a single
    /// response.
    fn query(
        &self,
        ctx: Context,
        request: buffa::view::OwnedView<QueryRequestView<'static>>,
    ) -> impl Future<Output = Result<(QueryResponse, Context), ConnectError>> + Send {
        let server = self.server.clone();
        async move {
            let sql = request.sql.to_string();
            let df = server
                .ctx
                .sql(&sql)
                .await
                .map_err(datafusion_error_to_connect)?;
            // Capture the logical schema before `collect` consumes the frame.
            let schema = df.schema().clone();
            let batches = df.collect().await.map_err(datafusion_error_to_connect)?;
            let columns: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
            let rows =
                record_batches_to_proto_rows(&batches).map_err(datafusion_error_to_connect)?;
            Ok((
                QueryResponse {
                    column: columns,
                    rows,
                    ..Default::default()
                },
                ctx,
            ))
        }
    }
}
344
/// Stream adapter: raw store frames in, filtered `SubscribeResponse`s out.
struct BatchPredicateStream {
    // Underlying store subscription producing key/value frames.
    sub: StreamSubscription,
    // Decode state for the subscribed table.
    state: TableStream,
    // Table name under which each batch is registered for WHERE evaluation.
    table_name: String,
    // Trimmed WHERE clause; empty string means no filtering.
    where_sql: String,
    // In-flight evaluation of the most recently received frame, if any.
    building: Option<BoxFuture<'static, Result<Option<SubscribeResponse>, ConnectError>>>,
}
352
353impl BatchPredicateStream {
354 fn new(
355 sub: StreamSubscription,
356 state: TableStream,
357 table_name: String,
358 where_sql: String,
359 ) -> Self {
360 Self {
361 sub,
362 state,
363 table_name,
364 where_sql,
365 building: None,
366 }
367 }
368}
369
impl Stream for BatchPredicateStream {
    type Item = Result<SubscribeResponse, ConnectError>;

    /// Drives the pipeline: finish any in-flight batch evaluation first;
    /// otherwise pull the next frame from the subscription and start
    /// evaluating it. Frames whose rows all filter out resolve to `Ok(None)`
    /// and yield no item, so the loop continues to the next frame.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            // Phase 1: poll the in-flight evaluate_batch future, if any.
            if let Some(fut) = this.building.as_mut() {
                match fut.as_mut().poll(cx) {
                    std::task::Poll::Pending => return std::task::Poll::Pending,
                    std::task::Poll::Ready(Ok(Some(resp))) => {
                        this.building = None;
                        return std::task::Poll::Ready(Some(Ok(resp)));
                    }
                    std::task::Poll::Ready(Ok(None)) => {
                        // Batch produced no surviving rows; fetch more input.
                        this.building = None;
                    }
                    std::task::Poll::Ready(Err(err)) => {
                        this.building = None;
                        return std::task::Poll::Ready(Some(Err(err)));
                    }
                }
            }

            // Phase 2: pull the next frame. NOTE(review): a fresh `sub.next()`
            // future is created, polled once, and dropped when Pending — this
            // is only correct if `StreamSubscription::next` is cancel-safe
            // (no partially consumed frame is lost on drop). Confirm.
            let frame = {
                let next_fut = this.sub.next();
                tokio::pin!(next_fut);
                match next_fut.as_mut().poll(cx) {
                    std::task::Poll::Ready(Ok(Some(frame))) => frame,
                    std::task::Poll::Ready(Ok(None)) => return std::task::Poll::Ready(None),
                    std::task::Poll::Ready(Err(err)) => {
                        return std::task::Poll::Ready(Some(Err(client_error_to_connect(err))));
                    }
                    std::task::Poll::Pending => return std::task::Poll::Pending,
                }
            };

            // Phase 3: kick off decode + WHERE evaluation on an owned
            // snapshot of the state so the boxed future can be 'static.
            let sequence_number = frame.sequence_number;
            let entries: Vec<(Key, Bytes)> = frame
                .entries
                .into_iter()
                .map(|entry| (entry.key, entry.value))
                .collect();
            let state = this.state.clone();
            let table_name = this.table_name.clone();
            let where_sql = this.where_sql.clone();
            this.building = Some(
                async move {
                    evaluate_batch(state, table_name, where_sql, sequence_number, entries).await
                }
                .boxed(),
            );
        }
    }
}
427
428async fn evaluate_batch(
429 state: TableStream,
430 table_name: String,
431 where_sql: String,
432 sequence_number: u64,
433 entries: Vec<(Key, Bytes)>,
434) -> Result<Option<SubscribeResponse>, ConnectError> {
435 let batch = state
436 .decode_batch(&entries)
437 .map_err(datafusion_error_to_connect)?;
438 if batch.num_rows() == 0 {
439 return Ok(None);
440 }
441
442 let filtered = if where_sql.is_empty() {
443 batch
444 } else {
445 apply_where(state.schema.clone(), batch, &table_name, &where_sql)
446 .await
447 .map_err(datafusion_error_to_connect)?
448 };
449 if filtered.num_rows() == 0 {
450 return Ok(None);
451 }
452
453 let columns: Vec<String> = filtered
454 .schema()
455 .fields()
456 .iter()
457 .map(|f| f.name().clone())
458 .collect();
459 let rows = record_batches_to_proto_rows(std::slice::from_ref(&filtered))
460 .map_err(datafusion_error_to_connect)?;
461 Ok(Some(SubscribeResponse {
462 sequence_number,
463 column: columns,
464 rows,
465 ..Default::default()
466 }))
467}
468
469async fn apply_where(
470 schema: SchemaRef,
471 batch: RecordBatch,
472 table_name: &str,
473 where_sql: &str,
474) -> DataFusionResult<RecordBatch> {
475 let ctx = SessionContext::new();
476 let mem = MemTable::try_new(schema.clone(), vec![vec![batch]])?;
477 ctx.register_table(table_name, Arc::new(mem))?;
478 let sql = format!("SELECT * FROM {table_name} WHERE {where_sql}");
479 let df = ctx.sql(&sql).await?;
480 let batches = df.collect().await?;
481 if batches.is_empty() {
482 return Ok(RecordBatch::new_empty(schema));
483 }
484 datafusion::arrow::compute::concat_batches(&schema, batches.iter())
485 .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
486}
487
488fn record_batches_to_proto_rows(batches: &[RecordBatch]) -> DataFusionResult<Vec<ProtoRow>> {
489 let mut out = Vec::with_capacity(batches.iter().map(|b| b.num_rows()).sum());
490 for batch in batches {
491 for row_idx in 0..batch.num_rows() {
492 let mut cells = Vec::with_capacity(batch.num_columns());
493 for col_idx in 0..batch.num_columns() {
494 cells.push(arrow_value_to_cell(batch.column(col_idx), row_idx)?);
495 }
496 out.push(ProtoRow {
497 cells,
498 ..Default::default()
499 });
500 }
501 }
502 Ok(out)
503}
504
505fn arrow_value_to_cell(array: &ArrayRef, row: usize) -> DataFusionResult<ProtoCell> {
506 let kind = if array.is_null(row) {
507 ProtoCellKind::NullValue(Box::<ProtoNull>::default())
508 } else {
509 arrow_value_to_kind(array, row)?
510 };
511 Ok(ProtoCell {
512 kind: Some(kind),
513 ..Default::default()
514 })
515}
516
/// Converts a non-null arrow value at `row` into its proto cell kind.
///
/// Narrow numerics are widened losslessly (Int32 → i64, UInt32 → u64,
/// Float32 → f64); all string flavors become `Utf8Value`; decimals are
/// serialized as big-endian bytes; lists recurse via `list_array_to_proto`.
/// Unsupported arrow types surface as `NotImplemented`.
///
/// The caller (`arrow_value_to_cell`) has already handled nulls, so the
/// `unwrap()`s here can only fail on a downcast mismatch — which would be a
/// bug in the match arms themselves, not a data error.
fn arrow_value_to_kind(array: &ArrayRef, row: usize) -> DataFusionResult<ProtoCellKind> {
    match array.data_type() {
        DataType::Int64 => Ok(ProtoCellKind::Int64Value(
            array
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap()
                .value(row),
        )),
        DataType::Int32 => Ok(ProtoCellKind::Int64Value(
            array
                .as_any()
                .downcast_ref::<Int32Array>()
                .unwrap()
                .value(row) as i64,
        )),
        DataType::UInt64 => Ok(ProtoCellKind::Uint64Value(
            array
                .as_any()
                .downcast_ref::<UInt64Array>()
                .unwrap()
                .value(row),
        )),
        DataType::UInt32 => Ok(ProtoCellKind::Uint64Value(
            array
                .as_any()
                .downcast_ref::<UInt32Array>()
                .unwrap()
                .value(row) as u64,
        )),
        DataType::Float64 => Ok(ProtoCellKind::Float64Value(
            array
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap()
                .value(row),
        )),
        DataType::Float32 => Ok(ProtoCellKind::Float64Value(
            array
                .as_any()
                .downcast_ref::<Float32Array>()
                .unwrap()
                .value(row) as f64,
        )),
        DataType::Boolean => Ok(ProtoCellKind::BooleanValue(
            array
                .as_any()
                .downcast_ref::<BooleanArray>()
                .unwrap()
                .value(row),
        )),
        // All three string layouts collapse to the same owned Utf8 cell.
        DataType::Utf8 => Ok(ProtoCellKind::Utf8Value(
            array
                .as_any()
                .downcast_ref::<StringArray>()
                .unwrap()
                .value(row)
                .to_string(),
        )),
        DataType::LargeUtf8 => Ok(ProtoCellKind::Utf8Value(
            array
                .as_any()
                .downcast_ref::<LargeStringArray>()
                .unwrap()
                .value(row)
                .to_string(),
        )),
        DataType::Utf8View => Ok(ProtoCellKind::Utf8Value(
            array
                .as_any()
                .downcast_ref::<StringViewArray>()
                .unwrap()
                .value(row)
                .to_string(),
        )),
        DataType::FixedSizeBinary(_) => Ok(ProtoCellKind::FixedSizeBinaryValue(
            array
                .as_any()
                .downcast_ref::<FixedSizeBinaryArray>()
                .unwrap()
                .value(row)
                .to_vec(),
        )),
        DataType::Date32 => Ok(ProtoCellKind::Date32Value(
            array
                .as_any()
                .downcast_ref::<Date32Array>()
                .unwrap()
                .value(row),
        )),
        DataType::Date64 => Ok(ProtoCellKind::Date64Value(
            array
                .as_any()
                .downcast_ref::<Date64Array>()
                .unwrap()
                .value(row),
        )),
        // NOTE(review): only the raw i64 is sent; the time unit (and zone)
        // are NOT carried in the cell — presumably the receiver recovers
        // them from column metadata. Confirm.
        DataType::Timestamp(unit, _) => {
            let v = match unit {
                TimeUnit::Second => array
                    .as_any()
                    .downcast_ref::<TimestampSecondArray>()
                    .unwrap()
                    .value(row),
                TimeUnit::Millisecond => array
                    .as_any()
                    .downcast_ref::<TimestampMillisecondArray>()
                    .unwrap()
                    .value(row),
                TimeUnit::Microsecond => array
                    .as_any()
                    .downcast_ref::<TimestampMicrosecondArray>()
                    .unwrap()
                    .value(row),
                TimeUnit::Nanosecond => array
                    .as_any()
                    .downcast_ref::<TimestampNanosecondArray>()
                    .unwrap()
                    .value(row),
            };
            Ok(ProtoCellKind::TimestampValue(v))
        }
        // Decimals travel as big-endian fixed-width byte strings.
        DataType::Decimal128(_, _) => {
            let v = array
                .as_any()
                .downcast_ref::<Decimal128Array>()
                .unwrap()
                .value(row);
            Ok(ProtoCellKind::Decimal128Value(v.to_be_bytes().to_vec()))
        }
        DataType::Decimal256(_, _) => {
            let v = array
                .as_any()
                .downcast_ref::<Decimal256Array>()
                .unwrap()
                .value(row);
            Ok(ProtoCellKind::Decimal256Value(v.to_be_bytes().to_vec()))
        }
        // `list.value(row)` yields the element sub-array for this slot.
        DataType::List(_) => {
            let list = array.as_any().downcast_ref::<ListArray>().unwrap();
            Ok(ProtoCellKind::ListValue(Box::new(list_array_to_proto(
                &list.value(row),
            )?)))
        }
        DataType::LargeList(_) => {
            let list = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            Ok(ProtoCellKind::ListValue(Box::new(list_array_to_proto(
                &list.value(row),
            )?)))
        }
        other => Err(DataFusionError::NotImplemented(format!(
            "cell conversion for arrow type {other:?}"
        ))),
    }
}
672
673fn list_array_to_proto(elements: &ArrayRef) -> DataFusionResult<ProtoListValue> {
674 let mut cells = Vec::with_capacity(elements.len());
675 for idx in 0..elements.len() {
676 cells.push(arrow_value_to_cell(elements, idx)?);
677 }
678 Ok(ProtoListValue {
679 elements: cells,
680 ..Default::default()
681 })
682}
683
/// Maps DataFusion failures onto Connect status codes: planning, SQL,
/// configuration, and not-implemented errors are the caller's fault
/// (INVALID_ARGUMENT); everything else is reported as INTERNAL.
fn datafusion_error_to_connect(err: DataFusionError) -> ConnectError {
    match err {
        DataFusionError::Plan(msg)
        | DataFusionError::SQL(_, Some(msg))
        | DataFusionError::Configuration(msg)
        | DataFusionError::NotImplemented(msg) => ConnectError::invalid_argument(msg),
        DataFusionError::SchemaError(err, _) => ConnectError::invalid_argument(err.to_string()),
        other => ConnectError::internal(other.to_string()),
    }
}
694
695fn client_error_to_connect(err: exoware_sdk::ClientError) -> ConnectError {
696 if let Some(rpc) = err.rpc_error() {
697 ConnectError::new(rpc.code, rpc.message.clone().unwrap_or_default())
698 } else {
699 ConnectError::internal(err.to_string())
700 }
701}
702
// Compile-time check that `projected_column_indices` remains importable from
// `crate::builder`; never called at runtime, hence the dead_code allow.
#[allow(dead_code)]
fn _assert_projected_column_indices_visible() {
    let _ = projected_column_indices;
}