use std::sync::Arc;
use std::time::Instant;
use hyperdb_api_core::client::{AsyncClient, AsyncCopyInWriter, AsyncCopyInWriterOwned};
use tracing::{debug, info};
use crate::async_connection::AsyncConnection;
use crate::data_format::DataFormat;
use crate::error::{Error, Result};
use crate::table_definition::TableDefinition;
/// Default buffered-byte threshold (16 MiB) at which the COPY stream is flushed.
const DEFAULT_FLUSH_THRESHOLD: usize = 16 * 1024 * 1024;
/// Streams Arrow IPC data into a table via an async COPY operation,
/// borrowing the connection for lifetime `'conn`.
#[derive(Debug)]
pub struct AsyncArrowInserter<'conn> {
    // Borrowed connection; must expose an async TCP client (COPY needs TCP).
    connection: &'conn AsyncConnection,
    // Fully qualified target table name (from TableDefinition::qualified_name).
    table_name: String,
    // Target column names, in insertion order.
    columns: Vec<String>,
    // Lazily created COPY writer; None until the first insert call.
    writer: Option<AsyncCopyInWriter<'conn>>,
    // Whether an Arrow IPC schema has already been sent on this stream.
    schema_sent: bool,
    // Total payload bytes sent so far, across all chunks.
    total_bytes: usize,
    // Number of non-empty chunks sent so far.
    chunk_count: usize,
    // Creation time, used for duration logging in execute().
    start_time: Instant,
    // Buffered-byte threshold that triggers an automatic stream flush.
    flush_threshold: usize,
    // Bytes written since the last flush.
    buffered_bytes: usize,
}
impl<'conn> AsyncArrowInserter<'conn> {
pub fn new(connection: &'conn AsyncConnection, table_def: &TableDefinition) -> Result<Self> {
let column_count = table_def.column_count();
if column_count == 0 {
return Err(Error::new("Table definition must have at least one column"));
}
if connection.async_tcp_client().is_none() {
return Err(Error::new(
"AsyncArrowInserter requires a TCP connection. \
gRPC connections do not support COPY operations.",
));
}
let columns: Vec<String> = table_def.columns.iter().map(|c| c.name.clone()).collect();
let table_name = table_def.qualified_name();
Ok(AsyncArrowInserter {
connection,
table_name,
columns,
writer: None,
schema_sent: false,
total_bytes: 0,
chunk_count: 0,
start_time: Instant::now(),
flush_threshold: DEFAULT_FLUSH_THRESHOLD,
buffered_bytes: 0,
})
}
#[must_use]
pub fn with_flush_threshold(mut self, threshold: usize) -> Self {
self.flush_threshold = threshold;
self
}
pub async fn insert_data(&mut self, arrow_ipc_data: &[u8]) -> Result<()> {
if arrow_ipc_data.is_empty() {
return Ok(());
}
if self.schema_sent {
return Err(Error::new(
"Arrow schema was already sent. Use insert_record_batches() for subsequent chunks without schema, \
or use insert_data() only once with the complete Arrow IPC stream.",
));
}
self.ensure_writer().await?;
if let Some(ref mut writer) = self.writer {
writer.send_direct(arrow_ipc_data).await?;
}
self.buffered_bytes += arrow_ipc_data.len();
self.maybe_flush().await?;
self.schema_sent = true;
self.total_bytes += arrow_ipc_data.len();
self.chunk_count += 1;
debug!(
target: "hyperdb_api",
chunk = self.chunk_count,
bytes = arrow_ipc_data.len(),
total_bytes = self.total_bytes,
buffered_bytes = self.buffered_bytes,
"async-arrow-inserter-chunk"
);
Ok(())
}
pub async fn insert_record_batches(&mut self, arrow_batch_data: &[u8]) -> Result<()> {
if arrow_batch_data.is_empty() {
return Ok(());
}
if !self.schema_sent {
return Err(Error::new(
"No Arrow schema has been sent yet. Call insert_data() first with a complete \
Arrow IPC stream that includes the schema.",
));
}
if let Some(ref mut writer) = self.writer {
writer.send_direct(arrow_batch_data).await?;
}
self.buffered_bytes += arrow_batch_data.len();
self.maybe_flush().await?;
self.total_bytes += arrow_batch_data.len();
self.chunk_count += 1;
debug!(
target: "hyperdb_api",
chunk = self.chunk_count,
bytes = arrow_batch_data.len(),
total_bytes = self.total_bytes,
buffered_bytes = self.buffered_bytes,
"async-arrow-inserter-batch-chunk"
);
Ok(())
}
pub async fn insert_raw(&mut self, data: &[u8]) -> Result<()> {
if data.is_empty() {
return Ok(());
}
self.ensure_writer().await?;
if let Some(ref mut writer) = self.writer {
writer.send_direct(data).await?;
}
self.buffered_bytes += data.len();
self.maybe_flush().await?;
self.total_bytes += data.len();
self.chunk_count += 1;
Ok(())
}
pub async fn execute(mut self) -> Result<u64> {
if self.writer.is_none() {
return Ok(0);
}
let rows = match self.writer.take() {
Some(w) => w.finish().await?,
None => 0,
};
let duration_ms = u64::try_from(self.start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
info!(
target: "hyperdb_api",
rows,
chunks = self.chunk_count,
total_bytes = self.total_bytes,
duration_ms,
table = %self.table_name,
"async-arrow-inserter-end"
);
Ok(rows)
}
pub async fn cancel(mut self) {
if let Some(writer) = self.writer.take() {
let _ = writer.cancel("Arrow insert cancelled").await;
}
}
#[must_use]
pub fn has_data(&self) -> bool {
self.chunk_count > 0
}
#[must_use]
pub fn total_bytes(&self) -> usize {
self.total_bytes
}
#[must_use]
pub fn chunk_count(&self) -> usize {
self.chunk_count
}
async fn ensure_writer(&mut self) -> Result<()> {
if self.writer.is_none() {
let client = self.connection.async_tcp_client().ok_or_else(|| {
crate::Error::new("AsyncArrowInserter requires a TCP connection. gRPC connections do not support COPY operations.")
})?;
let columns: Vec<&str> = self
.columns
.iter()
.map(std::string::String::as_str)
.collect();
self.writer = Some(
client
.copy_in_with_format(
&self.table_name,
&columns,
DataFormat::ArrowStream.as_sql_str(),
)
.await?,
);
}
Ok(())
}
async fn maybe_flush(&mut self) -> Result<()> {
if self.buffered_bytes >= self.flush_threshold {
if let Some(ref mut writer) = self.writer {
writer.flush_stream().await?;
}
debug!(
target: "hyperdb_api",
flushed_bytes = self.buffered_bytes,
threshold = self.flush_threshold,
"async-arrow-inserter-flush"
);
self.buffered_bytes = 0;
}
Ok(())
}
}
impl Drop for AsyncArrowInserter<'_> {
    fn drop(&mut self) {
        // A live writer at drop time means neither execute() nor cancel()
        // was called; warn so the potential data loss is visible in logs.
        let Some(writer) = self.writer.take() else {
            return;
        };
        tracing::warn!(
            target: "hyperdb_api",
            chunks = self.chunk_count,
            total_bytes = self.total_bytes,
            table = %self.table_name,
            "AsyncArrowInserter dropped without calling execute() or cancel(). \
            Data may be lost. The underlying AsyncCopyInWriter will \
            attempt a best-effort cancel to restore the connection."
        );
        drop(writer);
    }
}
/// Owned variant of the Arrow inserter: holds an `Arc` to the connection so
/// it can outlive the scope that created it (e.g. moved into a spawned task).
#[derive(Debug)]
pub struct AsyncArrowInserterOwned {
    #[allow(
        dead_code,
        reason = "kept alive to anchor the client's Mutex Arc for the writer's lifetime"
    )]
    connection: Arc<AsyncConnection>,
    // Fully qualified target table name (from TableDefinition::qualified_name).
    table_name: String,
    // Target column names, in insertion order.
    columns: Vec<String>,
    // Lazily created owned COPY writer; None until the first insert call.
    writer: Option<AsyncCopyInWriterOwned>,
    // Whether an Arrow IPC schema has already been sent on this stream.
    schema_sent: bool,
    // Total payload bytes sent so far, across all chunks.
    total_bytes: usize,
    // Number of non-empty chunks sent so far.
    chunk_count: usize,
    // Creation time, used for duration logging in execute().
    start_time: Instant,
    // Buffered-byte threshold that triggers an automatic stream flush.
    flush_threshold: usize,
    // Bytes written since the last flush.
    buffered_bytes: usize,
}
impl AsyncArrowInserterOwned {
    /// Creates an owned inserter that streams Arrow IPC data into the table
    /// described by `table_def` over an async TCP COPY.
    ///
    /// # Errors
    /// Fails when the table definition has no columns, or when the connection
    /// has no async TCP client (gRPC connections do not support COPY).
    pub fn new(connection: Arc<AsyncConnection>, table_def: &TableDefinition) -> Result<Self> {
        if table_def.column_count() == 0 {
            return Err(Error::new("Table definition must have at least one column"));
        }
        if connection.async_tcp_client().is_none() {
            return Err(Error::new(
                "AsyncArrowInserterOwned requires a TCP connection. \
                gRPC connections do not support COPY operations.",
            ));
        }
        let columns: Vec<String> = table_def.columns.iter().map(|c| c.name.clone()).collect();
        let table_name = table_def.qualified_name();
        Ok(AsyncArrowInserterOwned {
            connection,
            table_name,
            columns,
            writer: None,
            schema_sent: false,
            total_bytes: 0,
            chunk_count: 0,
            start_time: Instant::now(),
            flush_threshold: DEFAULT_FLUSH_THRESHOLD,
            buffered_bytes: 0,
        })
    }

    /// Overrides the buffered-byte threshold that triggers an automatic
    /// stream flush (default: 16 MiB).
    #[must_use]
    pub fn with_flush_threshold(mut self, threshold: usize) -> Self {
        self.flush_threshold = threshold;
        self
    }

    /// Sends a complete Arrow IPC stream (schema plus record batches).
    ///
    /// May be called at most once per inserter; send further schema-less
    /// chunks with `insert_record_batches`. Empty input is a no-op.
    ///
    /// # Errors
    /// Fails if the schema was already sent, the writer cannot be created,
    /// or the underlying send/flush fails.
    pub async fn insert_data(&mut self, arrow_ipc_data: &[u8]) -> Result<()> {
        if arrow_ipc_data.is_empty() {
            return Ok(());
        }
        if self.schema_sent {
            return Err(Error::new(
                "Arrow schema was already sent. Use insert_record_batches() for subsequent chunks.",
            ));
        }
        self.ensure_writer().await?;
        if let Some(ref mut w) = self.writer {
            w.send_direct(arrow_ipc_data).await?;
        }
        // Flag the schema before flushing so a flush error cannot lead to the
        // schema being resent on a retried insert_data() call.
        self.schema_sent = true;
        self.buffered_bytes += arrow_ipc_data.len();
        self.maybe_flush().await?;
        self.total_bytes += arrow_ipc_data.len();
        self.chunk_count += 1;
        Ok(())
    }

    /// Sends additional Arrow record batches (without a schema) after a prior
    /// successful `insert_data` call. Empty input is a no-op.
    ///
    /// # Errors
    /// Fails if no schema has been sent yet or the underlying send/flush fails.
    pub async fn insert_record_batches(&mut self, arrow_batch_data: &[u8]) -> Result<()> {
        if arrow_batch_data.is_empty() {
            return Ok(());
        }
        if !self.schema_sent {
            return Err(Error::new(
                "No Arrow schema has been sent yet. Call insert_data() first.",
            ));
        }
        if let Some(ref mut w) = self.writer {
            w.send_direct(arrow_batch_data).await?;
        }
        self.buffered_bytes += arrow_batch_data.len();
        self.maybe_flush().await?;
        self.total_bytes += arrow_batch_data.len();
        self.chunk_count += 1;
        Ok(())
    }

    /// Sends raw bytes on the COPY stream and marks the schema as sent.
    /// Empty input is a no-op.
    ///
    /// NOTE(review): this sets `schema_sent`, unlike the borrowing
    /// `AsyncArrowInserter::insert_raw` — confirm which behavior callers
    /// expect before unifying the two implementations.
    ///
    /// # Errors
    /// Fails if the writer cannot be created or the send/flush fails.
    pub async fn insert_raw(&mut self, data: &[u8]) -> Result<()> {
        if data.is_empty() {
            return Ok(());
        }
        self.ensure_writer().await?;
        if let Some(ref mut w) = self.writer {
            w.send_direct(data).await?;
        }
        self.schema_sent = true;
        self.buffered_bytes += data.len();
        self.maybe_flush().await?;
        self.total_bytes += data.len();
        self.chunk_count += 1;
        Ok(())
    }

    /// Completes the COPY operation and returns the number of rows inserted.
    ///
    /// NOTE(review): unlike `AsyncArrowInserter::execute` (which returns
    /// `Ok(0)` when nothing was written), this errors if no data was ever
    /// inserted — kept as-is to avoid changing caller-visible behavior.
    ///
    /// # Errors
    /// Fails if nothing was inserted or finalizing the COPY writer fails.
    pub async fn execute(mut self) -> Result<u64> {
        let elapsed = self.start_time.elapsed();
        info!(
            target: "hyperdb_api",
            chunks = self.chunk_count,
            total_bytes = self.total_bytes,
            elapsed_ms = u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
            "async-arrow-inserter-execute"
        );
        let writer = self
            .writer
            .take()
            .ok_or_else(|| Error::new("No data was inserted before execute()"))?;
        writer.finish().await.map_err(Into::into)
    }

    /// Aborts the COPY operation, discarding streamed data (best effort —
    /// cancellation errors are ignored).
    pub async fn cancel(mut self) {
        if let Some(writer) = self.writer.take() {
            let _ = writer.cancel("AsyncArrowInserterOwned::cancel").await;
        }
    }

    /// Returns `true` if at least one non-empty chunk has been sent.
    #[must_use]
    pub fn has_data(&self) -> bool {
        // Use chunk_count rather than schema_sent so this agrees with
        // AsyncArrowInserter::has_data and does not report data when a chunk
        // failed mid-way (schema flag set, chunk never counted).
        self.chunk_count > 0
    }

    /// Total payload bytes sent so far.
    #[must_use]
    pub fn total_bytes(&self) -> usize {
        self.total_bytes
    }

    /// Number of non-empty chunks sent so far.
    #[must_use]
    pub fn chunk_count(&self) -> usize {
        self.chunk_count
    }

    /// Lazily creates the owned COPY writer on first use; no-op once created.
    async fn ensure_writer(&mut self) -> Result<()> {
        if self.writer.is_none() {
            let client: &AsyncClient = self.connection.async_tcp_client().ok_or_else(|| {
                Error::new(
                    "AsyncArrowInserterOwned requires a TCP connection. \
                    gRPC connections do not support COPY operations.",
                )
            })?;
            let columns: Vec<&str> = self.columns.iter().map(String::as_str).collect();
            self.writer = Some(
                client
                    .copy_in_arc_with_format(
                        &self.table_name,
                        &columns,
                        DataFormat::ArrowStream.as_sql_str(),
                    )
                    .await?,
            );
        }
        Ok(())
    }

    /// Flushes the stream once buffered bytes reach the configured threshold,
    /// then resets the buffered-byte counter.
    async fn maybe_flush(&mut self) -> Result<()> {
        if self.buffered_bytes >= self.flush_threshold {
            if let Some(ref mut w) = self.writer {
                w.flush_stream().await?;
            }
            debug!(
                target: "hyperdb_api",
                flushed_bytes = self.buffered_bytes,
                threshold = self.flush_threshold,
                "async-arrow-inserter-owned-flush"
            );
            self.buffered_bytes = 0;
        }
        Ok(())
    }
}
impl Drop for AsyncArrowInserterOwned {
    fn drop(&mut self) {
        // A live writer at drop time means neither execute() nor cancel()
        // was called; warn so the potential data loss is visible in logs.
        let Some(writer) = self.writer.take() else {
            return;
        };
        tracing::warn!(
            target: "hyperdb_api",
            chunks = self.chunk_count,
            total_bytes = self.total_bytes,
            table = %self.table_name,
            "AsyncArrowInserterOwned dropped without calling execute() or cancel(). \
            Data may be lost."
        );
        drop(writer);
    }
}