use std::time::Instant;

use arrow::ipc::writer::StreamWriter;
use hyperdb_api_core::client::client::CopyInWriter;
use tracing::{debug, info};

use crate::catalog::Catalog;
use crate::connection::Connection;
use crate::data_format::DataFormat;
use crate::error::{Error, Result};
use crate::table_definition::TableDefinition;

/// Default flush threshold for buffered COPY data: 16 MiB.
const DEFAULT_FLUSH_THRESHOLD: usize = 16 * 1024 * 1024;
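/// Which API family has been used on an inserter; the raw IPC methods and
/// the `RecordBatch` methods cannot be mixed on the same inserter.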
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InsertMode {
RawIpc,
BatchIpc,
}
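/// Streams Arrow data into a table over a TCP COPY channel.
///
/// Data can be supplied in one of two mutually exclusive ways: pre-encoded
/// Arrow IPC bytes (`insert_data`, `insert_record_batches`, `insert_raw`) or
/// in-memory `RecordBatch`es (`insert_batch`, `insert_batches`). Outgoing
/// bytes are buffered and flushed once `flush_threshold` bytes have
/// accumulated. Call `execute()` to finish the COPY and obtain the inserted
/// row count; dropping the inserter without `execute()` cancels the
/// operation.
///
/// A minimal usage sketch (assumes an established TCP [`Connection`]; the
/// table name `"events"` is illustrative):
///
/// ```ignore
/// let mut inserter = ArrowInserter::from_table(&connection, "events")?;
/// inserter.insert_batch(&batch)?; // any arrow RecordBatch
/// let rows = inserter.execute()?; // finish the COPY, get the row count
/// ```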
pub struct ArrowInserter<'conn> {
connection: &'conn Connection,
table_name: String,
columns: Vec<String>,
writer: Option<CopyInWriter<'conn>>,
schema_sent: bool,
total_bytes: usize,
chunk_count: usize,
start_time: Instant,
flush_threshold: usize,
buffered_bytes: usize,
batch_ipc_writer: Option<StreamWriter<Vec<u8>>>,
insert_mode: Option<InsertMode>,
}
impl std::fmt::Debug for ArrowInserter<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ArrowInserter")
.field("table_name", &self.table_name)
.field("columns", &self.columns)
.field("schema_sent", &self.schema_sent)
.field("chunk_count", &self.chunk_count)
.field("total_bytes", &self.total_bytes)
.finish_non_exhaustive()
}
}
impl<'conn> ArrowInserter<'conn> {
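    /// Creates an inserter targeting the table described by `table_def`.
    ///
    /// # Errors
    ///
    /// Returns an error if the definition has no columns or if the
    /// connection is not TCP-based (COPY is unavailable over gRPC).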
pub fn new(connection: &'conn Connection, table_def: &TableDefinition) -> Result<Self> {
let column_count = table_def.column_count();
if column_count == 0 {
return Err(Error::new("Table definition must have at least one column"));
}
if connection.tcp_client().is_none() {
return Err(Error::new(
"ArrowInserter requires a TCP connection. \
gRPC connections do not support COPY operations.",
));
}
let columns: Vec<String> = table_def.columns.iter().map(|c| c.name.clone()).collect();
let table_name = table_def.qualified_name();
Ok(ArrowInserter {
connection,
table_name,
columns,
writer: None,
schema_sent: false,
total_bytes: 0,
chunk_count: 0,
start_time: Instant::now(),
flush_threshold: DEFAULT_FLUSH_THRESHOLD,
buffered_bytes: 0,
batch_ipc_writer: None,
insert_mode: None,
})
}
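    /// Creates an inserter by resolving `table_name` through the catalog.
    ///
    /// # Errors
    ///
    /// Returns an error if the table cannot be resolved or if it fails the
    /// checks in [`ArrowInserter::new`].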
pub fn from_table<T>(connection: &'conn Connection, table_name: T) -> Result<Self>
where
T: TryInto<crate::TableName>,
crate::Error: From<T::Error>,
{
let catalog = Catalog::new(connection);
let table_def = catalog.get_table_definition(table_name)?;
Self::new(connection, &table_def)
}
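    /// Overrides the buffered-byte threshold at which the COPY stream is
    /// flushed. Defaults to [`DEFAULT_FLUSH_THRESHOLD`] (16 MiB).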
#[must_use]
pub fn with_flush_threshold(mut self, threshold: usize) -> Self {
self.flush_threshold = threshold;
self
}
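    /// Sends a complete Arrow IPC stream (schema followed by record batches).
    ///
    /// May be called at most once per inserter: the schema travels with the
    /// stream, and subsequent schema-less chunks must go through
    /// [`insert_record_batches`](Self::insert_record_batches).
    ///
    /// A sketch of producing such a stream with the `arrow` crate (`batch`
    /// is any `RecordBatch`):
    ///
    /// ```ignore
    /// let mut buf = Vec::new();
    /// {
    ///     let mut w = StreamWriter::try_new(&mut buf, &batch.schema())?;
    ///     w.write(&batch)?;
    ///     w.finish()?; // writes the end-of-stream marker
    /// }
    /// inserter.insert_data(&buf)?;
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if a schema was already sent, if this inserter has
    /// been used with `insert_batch()`, or if the underlying write fails.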
pub fn insert_data(&mut self, arrow_ipc_data: &[u8]) -> Result<()> {
if arrow_ipc_data.is_empty() {
return Ok(());
}
if self.insert_mode == Some(InsertMode::BatchIpc) {
return Err(Error::new(
"Cannot mix insert_data() with insert_batch(). \
Use either raw IPC methods (insert_data/insert_record_batches) \
or RecordBatch methods (insert_batch), not both.",
));
}
if self.schema_sent {
return Err(Error::new(
"Arrow schema was already sent. Use insert_record_batches() for subsequent chunks without schema, \
or use insert_data() only once with the complete Arrow IPC stream.",
));
}
self.ensure_writer()?;
if let Some(ref mut writer) = self.writer {
writer.send_direct(arrow_ipc_data)?;
}
self.buffered_bytes += arrow_ipc_data.len();
self.maybe_flush()?;
self.insert_mode = Some(InsertMode::RawIpc);
self.schema_sent = true;
self.total_bytes += arrow_ipc_data.len();
self.chunk_count += 1;
debug!(
target: "hyperdb_api",
chunk = self.chunk_count,
bytes = arrow_ipc_data.len(),
total_bytes = self.total_bytes,
buffered_bytes = self.buffered_bytes,
"arrow-inserter-chunk"
);
Ok(())
}
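    /// Sends additional Arrow IPC record-batch chunks after the schema has
    /// been delivered via [`insert_data`](Self::insert_data). The chunks
    /// must not repeat the schema message.
    ///
    /// # Errors
    ///
    /// Returns an error if no schema has been sent yet, if RecordBatch-based
    /// insertion was used on this inserter, or if the write fails.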
pub fn insert_record_batches(&mut self, arrow_batch_data: &[u8]) -> Result<()> {
if arrow_batch_data.is_empty() {
return Ok(());
}
if self.insert_mode == Some(InsertMode::BatchIpc) {
return Err(Error::new(
"Cannot mix insert_record_batches() with insert_batch(). \
Use either raw IPC methods (insert_data/insert_record_batches) \
or RecordBatch methods (insert_batch), not both.",
));
}
if !self.schema_sent {
return Err(Error::new(
"No Arrow schema has been sent yet. Call insert_data() first with a complete \
Arrow IPC stream that includes the schema.",
));
}
if let Some(ref mut writer) = self.writer {
writer.send_direct(arrow_batch_data)?;
}
self.buffered_bytes += arrow_batch_data.len();
self.maybe_flush()?;
self.total_bytes += arrow_batch_data.len();
self.chunk_count += 1;
debug!(
target: "hyperdb_api",
chunk = self.chunk_count,
bytes = arrow_batch_data.len(),
total_bytes = self.total_bytes,
buffered_bytes = self.buffered_bytes,
"arrow-inserter-batch-chunk"
);
Ok(())
}
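    /// Sends raw bytes to the COPY stream with no schema bookkeeping; the
    /// caller is responsible for emitting a well-formed Arrow IPC stream.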
pub fn insert_raw(&mut self, data: &[u8]) -> Result<()> {
if data.is_empty() {
return Ok(());
}
if self.insert_mode == Some(InsertMode::BatchIpc) {
return Err(Error::new(
"Cannot mix insert_raw() with insert_batch(). \
Use either raw IPC methods (insert_data/insert_record_batches/insert_raw) \
or RecordBatch methods (insert_batch), not both.",
));
}
        self.ensure_writer()?;
        // Record the mode so `insert_batch()` cannot be mixed in afterwards.
        self.insert_mode = Some(InsertMode::RawIpc);
if let Some(ref mut writer) = self.writer {
writer.send_direct(data)?;
}
self.buffered_bytes += data.len();
self.maybe_flush()?;
self.total_bytes += data.len();
self.chunk_count += 1;
Ok(())
}
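    /// Finalizes the Arrow stream, completes the COPY operation, and returns
    /// the number of rows reported by the server. Returns `Ok(0)` when no
    /// data was written.
    ///
    /// # Errors
    ///
    /// Returns an error if finalizing the IPC stream or finishing the COPY
    /// fails.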
pub fn execute(mut self) -> Result<u64> {
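        // Finalize the in-memory IPC stream, if any; `into_inner` also writes
        // the Arrow end-of-stream marker before handing back the buffer.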
if let Some(ipc) = self.batch_ipc_writer.take() {
let buf = ipc
.into_inner()
.map_err(|e| Error::new(format!("Failed to finalize Arrow IPC stream: {e}")))?;
if !buf.is_empty() {
if let Some(ref mut writer) = self.writer {
writer.send_direct(&buf)?;
}
}
}
if self.writer.is_none() {
return Ok(0);
}
        let rows = self
            .writer
            .take()
            .map(CopyInWriter::finish)
            .transpose()?
            .unwrap_or(0);
let duration_ms = u64::try_from(self.start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
info!(
target: "hyperdb_api",
rows,
chunks = self.chunk_count,
total_bytes = self.total_bytes,
duration_ms,
table = %self.table_name,
"arrow-inserter-end"
);
Ok(rows)
}
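    /// Aborts the COPY operation, discarding any data sent so far. Errors
    /// from the cancellation itself are ignored.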
pub fn cancel(mut self) {
if let Some(writer) = self.writer.take() {
let _ = writer.cancel("Arrow insert cancelled");
}
}
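    /// Returns `true` once at least one chunk has been written.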
#[must_use]
pub fn has_data(&self) -> bool {
self.chunk_count > 0
}
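    /// Total payload bytes written so far.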
#[must_use]
pub fn total_bytes(&self) -> usize {
self.total_bytes
}
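    /// Number of chunks written so far.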
#[must_use]
pub fn chunk_count(&self) -> usize {
self.chunk_count
}
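    /// Appends a `RecordBatch`, encoding it to Arrow IPC on the fly. The
    /// first batch establishes the stream schema; subsequent batches must
    /// match it. Cannot be mixed with the raw IPC methods.
    ///
    /// # Errors
    ///
    /// Returns an error if raw IPC methods were already used on this
    /// inserter, or if encoding or writing fails.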
pub fn insert_batch(&mut self, batch: &arrow::record_batch::RecordBatch) -> Result<()> {
if self.insert_mode == Some(InsertMode::RawIpc) {
return Err(Error::new(
"Cannot mix insert_batch() with raw IPC methods. \
Use either RecordBatch methods (insert_batch) \
or raw IPC methods (insert_data/insert_record_batches/insert_raw), not both.",
));
}
self.ensure_writer()?;
self.insert_mode = Some(InsertMode::BatchIpc);
if self.batch_ipc_writer.is_none() {
let ipc_writer = StreamWriter::try_new(Vec::new(), &batch.schema())
.map_err(|e| Error::new(format!("Failed to create Arrow IPC writer: {e}")))?;
self.batch_ipc_writer = Some(ipc_writer);
self.drain_ipc_buffer()?;
self.schema_sent = true;
}
self.batch_ipc_writer
.as_mut()
.expect("IPC writer must exist")
.write(batch)
.map_err(|e| Error::new(format!("Failed to write Arrow batch: {e}")))?;
self.drain_ipc_buffer()?;
self.chunk_count += 1;
debug!(
target: "hyperdb_api",
chunk = self.chunk_count,
total_bytes = self.total_bytes,
buffered_bytes = self.buffered_bytes,
"arrow-inserter-batch"
);
Ok(())
}
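    /// Forwards whatever bytes the in-memory IPC writer has produced to the
    /// COPY stream, then clears the intermediate buffer.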
fn drain_ipc_buffer(&mut self) -> Result<()> {
let ipc = self
.batch_ipc_writer
.as_mut()
.expect("IPC writer must exist");
let buf = ipc.get_mut();
if buf.is_empty() {
return Ok(());
}
let len = buf.len();
if let Some(ref mut writer) = self.writer {
writer.send_direct(buf)?;
}
buf.clear();
self.buffered_bytes += len;
self.total_bytes += len;
self.maybe_flush()?;
Ok(())
}
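    /// Appends every batch from the iterator via
    /// [`insert_batch`](Self::insert_batch), stopping at the first error.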
pub fn insert_batches<'b>(
&mut self,
batches: impl IntoIterator<Item = &'b arrow::record_batch::RecordBatch>,
) -> Result<()> {
for batch in batches {
self.insert_batch(batch)?;
}
Ok(())
}
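    /// Lazily opens the COPY stream on first write, requesting the Arrow
    /// stream format and pre-sizing the send buffer.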
fn ensure_writer(&mut self) -> Result<()> {
if self.writer.is_none() {
let client = self.connection.tcp_client().ok_or_else(|| {
crate::Error::new("ArrowInserter requires a TCP connection. gRPC connections do not support COPY operations.")
})?;
let columns: Vec<&str> = self
.columns
.iter()
.map(std::string::String::as_str)
.collect();
let mut writer = client.copy_in_with_format(
&self.table_name,
&columns,
DataFormat::ArrowStream.as_sql_str(),
)?;
writer.reserve_buffer(self.flush_threshold + 1024 * 1024);
self.writer = Some(writer);
}
Ok(())
}
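    /// Flushes the COPY stream once the buffered byte count reaches the
    /// configured threshold.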
fn maybe_flush(&mut self) -> Result<()> {
if self.buffered_bytes >= self.flush_threshold {
if let Some(ref mut writer) = self.writer {
writer.flush_stream()?;
}
debug!(
target: "hyperdb_api",
flushed_bytes = self.buffered_bytes,
threshold = self.flush_threshold,
"arrow-inserter-flush"
);
self.buffered_bytes = 0;
}
Ok(())
}
}
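/// Safety net: an inserter dropped without `execute()` cancels the COPY
/// stream so the connection is left in a usable state.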
impl Drop for ArrowInserter<'_> {
fn drop(&mut self) {
if let Some(writer) = self.writer.take() {
let _ = writer.cancel("Arrow inserter dropped without execute");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_data_format_sql_str() {
assert_eq!(DataFormat::ArrowStream.as_sql_str(), "ARROWSTREAM");
}
}