// laminar_db — handle.rs
//! Handle types for query results, source access, and subscriptions.

use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;

use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;

use laminar_core::streaming::{Record, Subscription};

use crate::catalog::{ArrowRecord, SourceEntry};
use crate::DbError;
/// Result of executing a SQL statement.
///
/// Each variant carries the output appropriate to the statement kind that ran.
#[derive(Debug)]
pub enum ExecuteResult {
    /// DDL statement completed (CREATE, DROP, ALTER).
    Ddl(DdlInfo),
    /// Query is running; subscribe to the handle to receive results.
    Query(QueryHandle),
    /// Number of rows affected (INSERT INTO).
    RowsAffected(u64),
    /// Metadata result (SHOW, DESCRIBE) as a single batch.
    Metadata(RecordBatch),
}
27
28impl ExecuteResult {
29    /// Convert to `QueryHandle`, returning an error if this is not a query result.
30    ///
31    /// # Errors
32    ///
33    /// Returns `DbError::InvalidOperation` if this is not a query result.
34    pub fn into_query(self) -> Result<QueryHandle, DbError> {
35        match self {
36            Self::Query(q) => Ok(q),
37            _ => Err(DbError::InvalidOperation(
38                "Expected a query result".to_string(),
39            )),
40        }
41    }
42}
43
/// Information about a completed DDL statement.
#[derive(Debug, Clone)]
pub struct DdlInfo {
    /// The statement type (e.g., "CREATE SOURCE").
    pub statement_type: String,
    /// The name of the object the statement affected.
    pub object_name: String,
}
52
/// Handle to a running streaming query.
///
/// Holds the query's identity, output schema, SQL text, and — until it is
/// consumed — the subscription delivering result batches.
#[derive(Debug)]
pub struct QueryHandle {
    /// Query identifier.
    pub(crate) id: u64,
    /// Output schema of result batches.
    pub(crate) schema: SchemaRef,
    /// The SQL text.
    pub(crate) sql: String,
    /// The subscription for receiving results; `None` once taken or cancelled.
    pub(crate) subscription: Option<Subscription<ArrowRecord>>,
    /// Whether the query is active (cleared by `cancel`).
    pub(crate) active: bool,
}
67
68impl QueryHandle {
69    /// Get the output schema.
70    #[must_use]
71    pub fn schema(&self) -> &SchemaRef {
72        &self.schema
73    }
74
75    /// Get the query SQL text.
76    #[must_use]
77    pub fn sql(&self) -> &str {
78        &self.sql
79    }
80
81    /// Get the query ID.
82    #[must_use]
83    pub fn id(&self) -> u64 {
84        self.id
85    }
86
87    /// Check if the query is still active.
88    #[must_use]
89    pub fn is_active(&self) -> bool {
90        self.active
91    }
92
93    /// Subscribe to raw `RecordBatch` results.
94    pub(crate) fn subscribe_raw(&mut self) -> Result<Subscription<ArrowRecord>, DbError> {
95        self.subscription
96            .take()
97            .ok_or_else(|| DbError::InvalidOperation("Subscription already consumed".to_string()))
98    }
99
100    /// Subscribe to typed results.
101    ///
102    /// The type `T` must implement `from_batch()` and `from_batch_all()` methods,
103    /// which are generated by `#[derive(FromRecordBatch)]`.
104    ///
105    /// # Errors
106    ///
107    /// Returns `DbError::InvalidOperation` if the subscription was already consumed.
108    pub fn subscribe<T: FromBatch>(&mut self) -> Result<TypedSubscription<T>, DbError> {
109        let sub = self.subscribe_raw()?;
110        Ok(TypedSubscription {
111            inner: sub,
112            _phantom: PhantomData,
113        })
114    }
115
116    /// Cancel this query.
117    pub fn cancel(&mut self) {
118        self.active = false;
119        self.subscription = None;
120    }
121}
122
/// Trait for types that can be deserialized from a `RecordBatch`.
///
/// Auto-generated by `#[derive(FromRecordBatch)]`.
pub trait FromBatch: Sized {
    /// Deserialize the single row at index `row` from a `RecordBatch`.
    fn from_batch(batch: &RecordBatch, row: usize) -> Self;
    /// Deserialize all rows from a `RecordBatch`, in row order.
    fn from_batch_all(batch: &RecordBatch) -> Vec<Self>;
}
132
/// Typed subscription that deserializes `RecordBatch` rows into `T`.
pub struct TypedSubscription<T: FromBatch> {
    // Underlying raw subscription delivering Arrow batches.
    inner: Subscription<ArrowRecord>,
    // Marks the target row type without storing one.
    _phantom: PhantomData<T>,
}
138
139impl<T: FromBatch> TypedSubscription<T> {
140    /// Create from a raw subscription.
141    pub(crate) fn from_raw(sub: Subscription<ArrowRecord>) -> Self {
142        Self {
143            inner: sub,
144            _phantom: PhantomData,
145        }
146    }
147
148    /// Poll for the next batch of typed results (non-blocking).
149    #[must_use]
150    pub fn poll(&self) -> Option<Vec<T>> {
151        self.inner.poll().map(|batch| T::from_batch_all(&batch))
152    }
153
154    /// Blocking receive.
155    ///
156    /// # Errors
157    ///
158    /// Returns `RecvError` if the channel is disconnected.
159    pub fn recv(&self) -> Result<Vec<T>, laminar_core::streaming::RecvError> {
160        self.inner.recv().map(|batch| T::from_batch_all(&batch))
161    }
162
163    /// Receive with timeout.
164    ///
165    /// # Errors
166    ///
167    /// Returns `RecvError` on timeout or if the channel is disconnected.
168    pub fn recv_timeout(
169        &self,
170        timeout: Duration,
171    ) -> Result<Vec<T>, laminar_core::streaming::RecvError> {
172        self.inner
173            .recv_timeout(timeout)
174            .map(|batch| T::from_batch_all(&batch))
175    }
176
177    /// Zero-allocation callback-based consumption.
178    ///
179    /// Calls `f` for each deserialized record. Return `false` to stop.
180    pub fn poll_each<F: FnMut(T) -> bool>(&self, max_batches: usize, mut f: F) -> usize {
181        let mut count = 0;
182        for _ in 0..max_batches {
183            match self.inner.poll() {
184                Some(batch) => {
185                    let items = T::from_batch_all(&batch);
186                    for item in items {
187                        count += 1;
188                        if !f(item) {
189                            return count;
190                        }
191                    }
192                }
193                None => break,
194            }
195        }
196        count
197    }
198
199    /// Get the underlying raw subscription.
200    #[allow(dead_code)]
201    pub(crate) fn into_raw(self) -> Subscription<ArrowRecord> {
202        self.inner
203    }
204}
205
206impl<T: FromBatch> std::fmt::Debug for TypedSubscription<T> {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        f.debug_struct("TypedSubscription").finish()
209    }
210}
211
/// Typed handle for pushing data into a registered source.
pub struct SourceHandle<T: Record> {
    // Shared catalog entry for the source (schema, channel, buffer).
    entry: Arc<SourceEntry>,
    // Marks the record type pushed through this handle without storing one.
    _phantom: PhantomData<T>,
}
217
218impl<T: Record> SourceHandle<T> {
219    /// Create a new source handle from a catalog entry.
220    ///
221    /// Validates that the Rust type's schema matches the source schema.
222    pub(crate) fn new(entry: Arc<SourceEntry>) -> Result<Self, DbError> {
223        let rust_schema = T::schema();
224        let sql_schema = &entry.schema;
225
226        // Validate field count matches
227        if rust_schema.fields().len() != sql_schema.fields().len() {
228            return Err(DbError::SchemaMismatch(format!(
229                "Rust type has {} fields but source '{}' has {} columns",
230                rust_schema.fields().len(),
231                entry.name,
232                sql_schema.fields().len()
233            )));
234        }
235
236        Ok(Self {
237            entry,
238            _phantom: PhantomData,
239        })
240    }
241
242    /// Push a single record.
243    ///
244    /// # Errors
245    ///
246    /// Returns `StreamingError` if the channel is full or closed.
247    #[allow(clippy::needless_pass_by_value)]
248    pub fn push(&self, record: T) -> Result<(), laminar_core::streaming::StreamingError> {
249        let batch = record.to_record_batch();
250        self.entry.push_and_buffer(batch)
251    }
252
253    /// Push a batch of records.
254    pub fn push_batch(&self, records: impl IntoIterator<Item = T>) -> usize {
255        const BATCH_SIZE: usize = 1024;
256        let mut count = 0;
257        let mut buffer = Vec::with_capacity(BATCH_SIZE);
258
259        for record in records {
260            buffer.push(record);
261            if buffer.len() >= BATCH_SIZE {
262                let batch = T::to_record_batch_from_iter(buffer.drain(..));
263                if self.push_arrow(batch).is_err() {
264                    return count;
265                }
266                count += BATCH_SIZE;
267            }
268        }
269
270        if !buffer.is_empty() {
271            let len = buffer.len();
272            let batch = T::to_record_batch_from_iter(buffer);
273            if self.push_arrow(batch).is_ok() {
274                count += len;
275            }
276        }
277        count
278    }
279
280    /// Push a raw `RecordBatch`.
281    ///
282    /// The batch is sent to the SPSC channel for pipeline processing and
283    /// also buffered for ad-hoc `SELECT` snapshot queries.
284    ///
285    /// # Errors
286    ///
287    /// Returns `StreamingError` if the channel is full or closed.
288    pub fn push_arrow(
289        &self,
290        batch: RecordBatch,
291    ) -> Result<(), laminar_core::streaming::StreamingError> {
292        self.entry.push_and_buffer(batch)
293    }
294
295    /// Emit a watermark.
296    pub fn watermark(&self, timestamp: i64) {
297        self.entry.source.watermark(timestamp);
298    }
299
300    /// Get current watermark.
301    #[must_use]
302    pub fn current_watermark(&self) -> i64 {
303        self.entry.source.current_watermark()
304    }
305
306    /// Number of buffered records.
307    #[must_use]
308    pub fn pending(&self) -> usize {
309        self.entry.source.pending()
310    }
311
312    /// Buffer capacity.
313    #[must_use]
314    pub fn capacity(&self) -> usize {
315        self.entry.source.capacity()
316    }
317
318    /// Whether the source buffer is experiencing backpressure (>80% full).
319    #[must_use]
320    pub fn is_backpressured(&self) -> bool {
321        crate::metrics::is_backpressured(self.pending(), self.capacity())
322    }
323
324    /// Get the source name.
325    #[must_use]
326    pub fn name(&self) -> &str {
327        &self.entry.name
328    }
329
330    /// Get the schema.
331    #[must_use]
332    pub fn schema(&self) -> &SchemaRef {
333        &self.entry.schema
334    }
335
336    /// Get the maximum out-of-orderness duration, if configured.
337    #[must_use]
338    pub fn max_out_of_orderness(&self) -> Option<Duration> {
339        self.entry.max_out_of_orderness
340    }
341
342    /// Declare which column in the source data represents event time.
343    ///
344    /// When set, `source.watermark()` enables late-row filtering
345    /// without a SQL `WATERMARK FOR` clause.
346    pub fn set_event_time_column(&self, column: &str) {
347        self.entry.source.set_event_time_column(column);
348    }
349}
350
351impl<T: Record> std::fmt::Debug for SourceHandle<T> {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        f.debug_struct("SourceHandle")
354            .field("name", &self.entry.name)
355            .field("pending", &self.pending())
356            .finish()
357    }
358}
359
/// Untyped handle for pushing raw `RecordBatch` data into a registered source.
pub struct UntypedSourceHandle {
    // Shared catalog entry for the source (schema, channel, buffer).
    entry: Arc<SourceEntry>,
}
364
365impl UntypedSourceHandle {
366    /// Create from a catalog entry.
367    pub(crate) fn new(entry: Arc<SourceEntry>) -> Self {
368        Self { entry }
369    }
370
371    /// Push a `RecordBatch`.
372    ///
373    /// The batch is sent to the SPSC channel for pipeline processing and
374    /// also buffered for ad-hoc `SELECT` snapshot queries.
375    ///
376    /// # Errors
377    ///
378    /// Returns `StreamingError` if the channel is full or closed.
379    pub fn push_arrow(
380        &self,
381        batch: RecordBatch,
382    ) -> Result<(), laminar_core::streaming::StreamingError> {
383        self.entry.push_and_buffer(batch)
384    }
385
386    /// Emit a watermark.
387    pub fn watermark(&self, timestamp: i64) {
388        self.entry.source.watermark(timestamp);
389    }
390
391    /// Get current watermark.
392    #[must_use]
393    pub fn current_watermark(&self) -> i64 {
394        self.entry.source.current_watermark()
395    }
396
397    /// Number of buffered records.
398    #[must_use]
399    pub fn pending(&self) -> usize {
400        self.entry.source.pending()
401    }
402
403    /// Buffer capacity.
404    #[must_use]
405    pub fn capacity(&self) -> usize {
406        self.entry.source.capacity()
407    }
408
409    /// Whether the source buffer is experiencing backpressure (>80% full).
410    #[must_use]
411    pub fn is_backpressured(&self) -> bool {
412        crate::metrics::is_backpressured(self.pending(), self.capacity())
413    }
414
415    /// Get the source name.
416    #[must_use]
417    pub fn name(&self) -> &str {
418        &self.entry.name
419    }
420
421    /// Get the schema.
422    #[must_use]
423    pub fn schema(&self) -> &SchemaRef {
424        &self.entry.schema
425    }
426
427    /// Get the maximum out-of-orderness duration, if configured.
428    #[must_use]
429    pub fn max_out_of_orderness(&self) -> Option<Duration> {
430        self.entry.max_out_of_orderness
431    }
432
433    /// Declare which column in the source data represents event time.
434    ///
435    /// When set, `source.watermark()` enables late-row filtering
436    /// without a SQL `WATERMARK FOR` clause.
437    pub fn set_event_time_column(&self, column: &str) {
438        self.entry.source.set_event_time_column(column);
439    }
440}
441
442impl std::fmt::Debug for UntypedSourceHandle {
443    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
444        f.debug_struct("UntypedSourceHandle")
445            .field("name", &self.entry.name)
446            .finish()
447    }
448}
449
/// Type of a node in the pipeline topology.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineNodeType {
    /// A data source (CREATE SOURCE).
    Source,
    /// A continuous stream (CREATE STREAM).
    Stream,
    /// A data sink (CREATE SINK).
    Sink,
}
460
/// A node in the pipeline topology graph.
#[derive(Debug, Clone)]
pub struct PipelineNode {
    /// Node name.
    pub name: String,
    /// Node type (source, stream, or sink).
    pub node_type: PipelineNodeType,
    /// Arrow schema, if available (sources have schemas).
    pub schema: Option<SchemaRef>,
    /// SQL definition, if applicable (streams have query SQL).
    pub sql: Option<String>,
}
473
/// A directed edge in the pipeline topology graph (data flows `from` → `to`).
#[derive(Debug, Clone)]
pub struct PipelineEdge {
    /// Source node name.
    pub from: String,
    /// Target node name.
    pub to: String,
}
482
/// The complete pipeline topology: nodes and edges.
#[derive(Debug, Clone)]
pub struct PipelineTopology {
    /// All nodes in the pipeline.
    pub nodes: Vec<PipelineNode>,
    /// All edges (data flow connections between named nodes).
    pub edges: Vec<PipelineEdge>,
}
491
/// Metadata about a registered stream.
#[derive(Debug, Clone)]
pub struct StreamInfo {
    /// Stream name.
    pub name: String,
    /// The SQL query that defines the stream, if known.
    pub sql: Option<String>,
}
500
/// Information about a registered source.
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// Source name.
    pub name: String,
    /// Arrow schema of the source.
    pub schema: SchemaRef,
    /// Watermark column, if configured.
    pub watermark_column: Option<String>,
}
511
/// Information about a registered sink.
#[derive(Debug, Clone)]
pub struct SinkInfo {
    /// Sink name.
    pub name: String,
}
518
/// Information about a running query.
#[derive(Debug, Clone)]
pub struct QueryInfo {
    /// Query identifier.
    pub id: u64,
    /// SQL text.
    pub sql: String,
    /// Whether the query is active.
    pub active: bool,
}