use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use super::error::ApiError;
use super::ingestion::Writer;
use super::query::{QueryResult, QueryStream};
use crate::{LaminarConfig, LaminarDB};
/// A synchronous, thread-safe handle to an embedded [`LaminarDB`] instance.
///
/// All methods take `&self`; the database is shared behind an [`Arc`] so
/// blocking helpers can move a clone of the handle onto worker threads.
pub struct Connection {
    // Shared database handle; cloned (cheaply, refcount bump) whenever work
    // must be driven on a separate thread.
    inner: Arc<LaminarDB>,
}
impl Connection {
pub fn open() -> Result<Self, ApiError> {
let db = LaminarDB::open().map_err(ApiError::from)?;
Ok(Self {
inner: Arc::new(db),
})
}
pub fn open_with_config(config: LaminarConfig) -> Result<Self, ApiError> {
let db = LaminarDB::open_with_config(config).map_err(ApiError::from)?;
Ok(Self {
inner: Arc::new(db),
})
}
pub fn execute(&self, sql: &str) -> Result<ExecuteResult, ApiError> {
if self.inner.is_closed() {
return Err(ApiError::shutdown());
}
let result = if let Ok(handle) = tokio::runtime::Handle::try_current() {
std::thread::scope(|s| {
s.spawn(|| {
let inner = Arc::clone(&self.inner);
let sql = sql.to_string();
handle.block_on(async move { inner.execute(&sql).await })
})
.join()
.unwrap()
})
} else {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| ApiError::internal(format!("Runtime error: {e}")))?;
rt.block_on(self.inner.execute(sql))
};
result.map(ExecuteResult::from).map_err(ApiError::from)
}
pub fn query(&self, sql: &str) -> Result<QueryResult, ApiError> {
let result = self.execute(sql)?;
match result {
ExecuteResult::Query(stream) => stream.collect(),
ExecuteResult::Metadata(batch) => Ok(QueryResult::from_batch(batch)),
ExecuteResult::RowsAffected(n) => Err(ApiError::Query {
code: super::error::codes::QUERY_FAILED,
message: format!("Expected query result, got {n} rows affected"),
}),
ExecuteResult::Ddl(info) => Err(ApiError::Query {
code: super::error::codes::QUERY_FAILED,
message: format!("Expected query result, got DDL: {}", info.statement_type),
}),
}
}
pub fn query_stream(&self, sql: &str) -> Result<QueryStream, ApiError> {
let result = self.execute(sql)?;
match result {
ExecuteResult::Query(stream) => Ok(stream),
_ => Err(ApiError::Query {
code: super::error::codes::QUERY_FAILED,
message: "Expected streaming query result".into(),
}),
}
}
pub fn writer(&self, source_name: &str) -> Result<Writer, ApiError> {
let handle = self
.inner
.source_untyped(source_name)
.map_err(ApiError::from)?;
Ok(Writer::new(handle))
}
pub fn insert(&self, source_name: &str, batch: RecordBatch) -> Result<u64, ApiError> {
let handle = self
.inner
.source_untyped(source_name)
.map_err(ApiError::from)?;
let num_rows = batch.num_rows() as u64;
handle
.push_arrow(batch)
.map_err(|e| ApiError::ingestion(e.to_string()))?;
Ok(num_rows)
}
pub fn get_schema(&self, name: &str) -> Result<SchemaRef, ApiError> {
for source in self.inner.sources() {
if source.name == name {
return Ok(source.schema);
}
}
Err(ApiError::table_not_found(name))
}
#[must_use]
pub fn list_sources(&self) -> Vec<String> {
self.inner.sources().into_iter().map(|s| s.name).collect()
}
#[must_use]
pub fn list_streams(&self) -> Vec<String> {
self.inner.streams().into_iter().map(|s| s.name).collect()
}
#[must_use]
pub fn list_sinks(&self) -> Vec<String> {
self.inner.sinks().into_iter().map(|s| s.name).collect()
}
pub fn start(&self) -> Result<(), ApiError> {
let result = if let Ok(handle) = tokio::runtime::Handle::try_current() {
std::thread::scope(|s| {
s.spawn(|| {
let inner = Arc::clone(&self.inner);
handle.block_on(async move { inner.start().await })
})
.join()
.unwrap()
})
} else {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| ApiError::internal(format!("Runtime error: {e}")))?;
rt.block_on(self.inner.start())
};
result.map_err(ApiError::from)
}
#[allow(clippy::unnecessary_wraps)]
pub fn close(self) -> Result<(), ApiError> {
self.inner.close();
match Arc::try_unwrap(self.inner) {
Ok(_db) => {
Ok(())
}
Err(_arc) => {
Ok(())
}
}
}
#[must_use]
pub fn is_closed(&self) -> bool {
self.inner.is_closed()
}
pub fn checkpoint(&self) -> Result<u64, ApiError> {
let result = if let Ok(handle) = tokio::runtime::Handle::try_current() {
std::thread::scope(|s| {
s.spawn(|| {
let inner = Arc::clone(&self.inner);
handle.block_on(async move { inner.checkpoint().await })
})
.join()
.unwrap()
})
} else {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| ApiError::internal(format!("Runtime error: {e}")))?;
rt.block_on(self.inner.checkpoint())
};
result.map(|r| r.checkpoint_id).map_err(ApiError::from)
}
#[must_use]
pub fn is_checkpoint_enabled(&self) -> bool {
self.inner.is_checkpoint_enabled()
}
#[must_use]
pub fn source_info(&self) -> Vec<crate::SourceInfo> {
self.inner.sources()
}
#[must_use]
pub fn sink_info(&self) -> Vec<crate::SinkInfo> {
self.inner.sinks()
}
#[must_use]
pub fn stream_info(&self) -> Vec<crate::StreamInfo> {
self.inner.streams()
}
#[must_use]
pub fn query_info(&self) -> Vec<crate::QueryInfo> {
self.inner.queries()
}
#[must_use]
pub fn pipeline_topology(&self) -> crate::PipelineTopology {
self.inner.pipeline_topology()
}
#[must_use]
pub fn pipeline_state(&self) -> String {
self.inner.pipeline_state().to_string()
}
#[must_use]
pub fn pipeline_watermark(&self) -> i64 {
self.inner.pipeline_watermark()
}
#[must_use]
pub fn total_events_processed(&self) -> u64 {
self.inner.total_events_processed()
}
#[must_use]
pub fn source_count(&self) -> usize {
self.inner.source_count()
}
#[must_use]
pub fn sink_count(&self) -> usize {
self.inner.sink_count()
}
#[must_use]
pub fn active_query_count(&self) -> usize {
self.inner.active_query_count()
}
#[must_use]
pub fn metrics(&self) -> crate::PipelineMetrics {
self.inner.metrics()
}
#[must_use]
pub fn source_metrics(&self, name: &str) -> Option<crate::SourceMetrics> {
self.inner.source_metrics(name)
}
#[must_use]
pub fn all_source_metrics(&self) -> Vec<crate::SourceMetrics> {
self.inner.all_source_metrics()
}
#[must_use]
pub fn stream_metrics(&self, name: &str) -> Option<crate::StreamMetrics> {
self.inner.stream_metrics(name)
}
#[must_use]
pub fn all_stream_metrics(&self) -> Vec<crate::StreamMetrics> {
self.inner.all_stream_metrics()
}
pub fn cancel_query(&self, query_id: u64) -> Result<(), ApiError> {
self.inner.cancel_query(query_id).map_err(ApiError::from)
}
pub fn shutdown(&self) -> Result<(), ApiError> {
let result = if let Ok(handle) = tokio::runtime::Handle::try_current() {
std::thread::scope(|s| {
s.spawn(|| {
let inner = Arc::clone(&self.inner);
handle.block_on(async move { inner.shutdown().await })
})
.join()
.unwrap()
})
} else {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|e| ApiError::internal(format!("Runtime error: {e}")))?;
rt.block_on(self.inner.shutdown())
};
result.map_err(ApiError::from)
}
pub fn subscribe(
&self,
stream_name: &str,
) -> Result<super::subscription::ArrowSubscription, ApiError> {
let sub = self
.inner
.subscribe_raw(stream_name)
.map_err(ApiError::from)?;
Ok(super::subscription::ArrowSubscription::new(
sub,
std::sync::Arc::new(arrow::datatypes::Schema::empty()),
))
}
}
// SAFETY: `Connection` only holds an `Arc<LaminarDB>`, so these impls are
// sound exactly when `LaminarDB` itself may be shared/sent across threads.
// NOTE(review): if `LaminarDB: Send + Sync` already holds, both impls are
// redundant (the compiler derives them automatically) and should be deleted;
// if it does NOT hold, these impls are unsound — TODO confirm which case
// applies and act accordingly.
unsafe impl Send for Connection {}
unsafe impl Sync for Connection {}
/// Outcome of executing a single SQL statement via [`Connection::execute`].
#[derive(Debug)]
pub enum ExecuteResult {
    /// A DDL statement completed; carries information about the statement.
    Ddl(crate::DdlInfo),
    /// A query producing a stream of record batches.
    Query(QueryStream),
    /// A DML statement completed, affecting this many rows.
    RowsAffected(u64),
    /// A metadata/introspection statement returning a single batch.
    Metadata(RecordBatch),
}
impl From<crate::ExecuteResult> for ExecuteResult {
fn from(result: crate::ExecuteResult) -> Self {
match result {
crate::ExecuteResult::Ddl(info) => Self::Ddl(info),
crate::ExecuteResult::Query(handle) => Self::Query(QueryStream::from_handle(handle)),
crate::ExecuteResult::RowsAffected(n) => Self::RowsAffected(n),
crate::ExecuteResult::Metadata(batch) => Self::Metadata(batch),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time proof that `Connection` may cross thread boundaries.
    #[test]
    fn test_connection_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Connection>();
    }

    #[test]
    fn test_connection_open_close() {
        let conn = Connection::open().unwrap();
        assert!(!conn.is_closed());
        conn.close().unwrap();
    }

    // Several threads listing sources concurrently must not panic.
    #[test]
    fn test_connection_thread_safe() {
        let conn = Arc::new(Connection::open().unwrap());
        let mut workers = Vec::with_capacity(4);
        for _ in 0..4 {
            let shared = Arc::clone(&conn);
            workers.push(std::thread::spawn(move || {
                let _ = shared.list_sources();
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }
    }

    #[test]
    fn test_execute_create_source() {
        let conn = Connection::open().unwrap();
        let created = conn.execute("CREATE SOURCE test_api (id BIGINT, name VARCHAR)");
        assert!(created.is_ok());
        assert!(conn.list_sources().iter().any(|name| name == "test_api"));
    }

    #[test]
    fn test_get_schema() {
        let conn = Connection::open().unwrap();
        conn.execute("CREATE SOURCE schema_test (id BIGINT, value DOUBLE)")
            .unwrap();
        let schema = conn.get_schema("schema_test").unwrap();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "id");
    }

    #[test]
    fn test_get_schema_not_found() {
        let conn = Connection::open().unwrap();
        let err = conn.get_schema("nonexistent").unwrap_err();
        assert_eq!(err.code(), super::super::error::codes::TABLE_NOT_FOUND);
    }

    #[test]
    fn test_source_info() {
        let conn = Connection::open().unwrap();
        conn.execute("CREATE SOURCE test_info (id BIGINT, name VARCHAR)")
            .unwrap();
        let infos = conn.source_info();
        assert_eq!(infos.len(), 1);
        assert_eq!(infos[0].name, "test_info");
        assert_eq!(infos[0].schema.fields().len(), 2);
    }

    #[test]
    fn test_pipeline_state() {
        let conn = Connection::open().unwrap();
        assert!(!conn.pipeline_state().is_empty());
    }

    #[test]
    fn test_metrics() {
        let conn = Connection::open().unwrap();
        assert_eq!(conn.metrics().total_events_ingested, 0);
    }

    #[test]
    fn test_source_count() {
        let conn = Connection::open().unwrap();
        assert_eq!(conn.source_count(), 0);
        conn.execute("CREATE SOURCE cnt_test (x BIGINT)").unwrap();
        assert_eq!(conn.source_count(), 1);
    }

    #[test]
    fn test_cancel_query_invalid() {
        let conn = Connection::open().unwrap();
        assert!(conn.cancel_query(999).is_err());
    }

    #[test]
    fn test_shutdown() {
        let conn = Connection::open().unwrap();
        assert!(conn.shutdown().is_ok());
    }
}