use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use super::error::ApiError;
use crate::catalog::ArrowRecord;
/// A fully materialized query result: an Arrow schema together with the
/// record batches produced under that schema.
#[derive(Debug, Clone)]
pub struct QueryResult {
    // Schema shared by all batches in `batches`.
    schema: SchemaRef,
    // Result data; may be empty when the query produced no rows.
    batches: Vec<RecordBatch>,
}
impl QueryResult {
#[must_use]
pub fn from_batch(batch: RecordBatch) -> Self {
let schema = batch.schema();
Self {
schema,
batches: vec![batch],
}
}
#[must_use]
pub fn from_batches(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
Self { schema, batches }
}
#[must_use]
pub fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
#[must_use]
pub fn batches(&self) -> &[RecordBatch] {
&self.batches
}
#[must_use]
pub fn into_batches(self) -> Vec<RecordBatch> {
self.batches
}
#[must_use]
pub fn num_rows(&self) -> usize {
self.batches.iter().map(RecordBatch::num_rows).sum()
}
#[must_use]
pub fn num_batches(&self) -> usize {
self.batches.len()
}
#[must_use]
pub fn batch(&self, index: usize) -> Option<&RecordBatch> {
self.batches.get(index)
}
#[must_use]
pub fn num_columns(&self) -> usize {
self.schema.fields().len()
}
}
/// A pull-based stream of record batches produced by a running query.
#[derive(Debug)]
pub struct QueryStream {
    // Schema of the batches this stream yields; captured at construction.
    schema: SchemaRef,
    // Handle to the running query; used for `is_active`/`cancel`.
    // Always `Some` after `from_handle` — the Option exists so the handle
    // could be released independently of the stream.
    handle: Option<crate::QueryHandle>,
    // Live subscription to the query's raw output. `None` when the initial
    // `subscribe_raw()` failed or after `cancel()`.
    subscription: Option<laminar_core::streaming::Subscription<ArrowRecord>>,
}
impl QueryStream {
pub(crate) fn from_handle(mut handle: crate::QueryHandle) -> Self {
let schema = handle.schema().clone();
let subscription = handle.subscribe_raw().ok();
Self {
schema,
handle: Some(handle),
subscription,
}
}
#[must_use]
pub fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
#[allow(clippy::should_implement_trait)] pub fn next(&mut self) -> Result<Option<RecordBatch>, ApiError> {
match &self.subscription {
Some(sub) => match sub.recv() {
Ok(batch) => Ok(Some(batch)),
Err(laminar_core::streaming::RecvError::Disconnected) => Ok(None),
Err(e) => Err(ApiError::subscription(e.to_string())),
},
None => Ok(None),
}
}
pub fn try_next(&mut self) -> Result<Option<RecordBatch>, ApiError> {
match &self.subscription {
Some(sub) => Ok(sub.poll()),
None => Ok(None),
}
}
pub fn collect(mut self) -> Result<QueryResult, ApiError> {
let mut batches = Vec::new();
while let Some(batch) = self.try_next()? {
batches.push(batch);
}
Ok(QueryResult::from_batches(self.schema, batches))
}
#[must_use]
pub fn is_active(&self) -> bool {
self.handle
.as_ref()
.is_some_and(crate::QueryHandle::is_active)
}
pub fn cancel(&mut self) {
if let Some(ref mut handle) = self.handle {
handle.cancel();
}
self.subscription = None;
}
}
// SAFETY: manually asserts that `QueryStream` may be moved across threads.
// NOTE(review): this bypasses the compiler's auto-trait check. If
// `QueryHandle` and `Subscription<ArrowRecord>` are both `Send`, this impl
// is redundant and should be removed; if one of them is deliberately
// `!Send` (e.g. holds a raw pointer or thread-affine resource), this impl
// is unsound. Confirm against those types' definitions — TODO verify.
unsafe impl Send for QueryStream {}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Two-column schema used throughout: non-nullable i64 `id`,
    /// nullable utf8 `name`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Three-row batch matching `test_schema()`.
    fn test_batch() -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_query_result_from_batch() {
        let batch = test_batch();
        let result = QueryResult::from_batch(batch);
        assert_eq!(result.num_rows(), 3);
        assert_eq!(result.num_columns(), 2);
        assert_eq!(result.batches().len(), 1);
    }

    #[test]
    fn test_query_result_schema() {
        let result = QueryResult::from_batch(test_batch());
        let schema = result.schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "id");
    }

    #[test]
    fn test_query_result_into_batches() {
        let result = QueryResult::from_batch(test_batch());
        let batches = result.into_batches();
        assert_eq!(batches.len(), 1);
    }

    // Previously untested: `from_batches` with no data at all.
    #[test]
    fn test_query_result_from_batches_empty() {
        let result = QueryResult::from_batches(test_schema(), vec![]);
        assert_eq!(result.num_rows(), 0);
        assert_eq!(result.num_batches(), 0);
        assert_eq!(result.num_columns(), 2);
        assert!(result.batch(0).is_none());
    }

    // Previously untested: multi-batch counting and indexed access,
    // including the out-of-range case.
    #[test]
    fn test_query_result_batch_accessors() {
        let result =
            QueryResult::from_batches(test_schema(), vec![test_batch(), test_batch()]);
        assert_eq!(result.num_batches(), 2);
        assert_eq!(result.num_rows(), 6);
        assert_eq!(result.batch(1).map(RecordBatch::num_rows), Some(3));
        assert!(result.batch(2).is_none());
    }
}