use std::time::Instant;
use hyperdb_api_core::client::client::CopyInWriter;
use hyperdb_api_core::protocol::copy;
use hyperdb_api_core::types::bytes::BytesMut;
use hyperdb_api_core::types::{Date, Interval, Numeric, OffsetTimestamp, Time, Timestamp};
use tracing::{debug, info};
use crate::catalog::Catalog;
use crate::connection::Connection;
use crate::error::{Error, Result};
use crate::table_definition::TableDefinition;
/// Initial capacity of a chunk's encode buffer (4 MiB), sized to avoid
/// repeated reallocation during typical bulk loads.
const INITIAL_BUFFER_SIZE: usize = 4 * 1024 * 1024;
/// A chunk is auto-flushed once its encoded size reaches this limit (16 MiB).
const CHUNK_SIZE_LIMIT: usize = 16 * 1024 * 1024;
/// A chunk is auto-flushed once it holds this many rows, regardless of size.
const CHUNK_ROW_LIMIT: usize = 64_000;
/// Streaming bulk loader: encodes rows into the HyperBinary COPY format and
/// ships them to the server over the connection's TCP client.
#[derive(Debug)]
pub struct Inserter<'conn> {
    // Connection owning the TCP client the COPY stream runs over.
    connection: &'conn Connection,
    // Definition of the target table (column names, nullability, types).
    table_def: TableDefinition,
    // In-memory encode buffer holding rows not yet sent.
    chunk: InsertChunk,
    // Rows completed via `end_row` since construction / last `execute`.
    row_count: u64,
    // Number of chunks flushed so far (diagnostics/logging only).
    chunk_count: usize,
    // COPY stream; opened lazily on first flush/execute.
    writer: Option<CopyInWriter<'conn>>,
    // Creation time, used for the duration logged by `execute`.
    start_time: Instant,
}
impl<'conn> Inserter<'conn> {
pub fn new(connection: &'conn Connection, table_def: &TableDefinition) -> Result<Self> {
if table_def.column_count() == 0 {
return Err(Error::InvalidTableDefinition(
"Table definition must have at least one column".into(),
));
}
if connection.tcp_client().is_none() {
return Err(Error::new(
"Inserter requires a TCP connection. \
gRPC connections do not support COPY operations.",
));
}
Ok(Inserter {
connection,
table_def: table_def.clone(),
chunk: InsertChunk::from_table_definition(table_def),
row_count: 0,
chunk_count: 0,
writer: None,
start_time: Instant::now(),
})
}
pub fn from_table<T>(connection: &'conn Connection, table_name: T) -> Result<Self>
where
T: TryInto<crate::TableName>,
crate::Error: From<T::Error>,
{
let catalog = Catalog::new(connection);
let table_def = catalog.get_table_definition(table_name)?;
Self::new(connection, &table_def)
}
pub fn with_column_mappings<T>(
connection: &'conn Connection,
inserter_def: &TableDefinition,
target_table: T,
mappings: &[ColumnMapping],
) -> Result<MappedInserter<'conn>>
where
T: TryInto<crate::TableName>,
crate::Error: From<T::Error>,
{
MappedInserter::new(connection, inserter_def, target_table, mappings)
}
pub fn table_definition(&self) -> &TableDefinition {
&self.table_def
}
#[must_use]
pub fn column_count(&self) -> usize {
self.table_def.column_count()
}
#[must_use]
pub fn row_count(&self) -> u64 {
self.row_count
}
#[inline]
pub fn add_null(&mut self) -> Result<()> {
self.chunk.add_null()
}
#[inline]
pub fn add_bool(&mut self, value: bool) -> Result<()> {
self.chunk.add_bool(value)
}
#[inline]
pub fn add_i16(&mut self, value: i16) -> Result<()> {
self.chunk.add_i16(value)
}
#[inline]
pub fn add_i32(&mut self, value: i32) -> Result<()> {
self.chunk.add_i32(value)
}
#[inline]
pub fn add_i64(&mut self, value: i64) -> Result<()> {
self.chunk.add_i64(value)
}
#[inline]
pub fn add_f32(&mut self, value: f32) -> Result<()> {
self.chunk.add_f32(value)
}
#[inline]
pub fn add_f64(&mut self, value: f64) -> Result<()> {
self.chunk.add_f64(value)
}
#[inline]
pub fn add_str(&mut self, value: &str) -> Result<()> {
self.chunk.add_str(value)
}
#[inline]
pub fn add_bytes(&mut self, value: &[u8]) -> Result<()> {
self.chunk.add_bytes(value)
}
#[inline]
pub fn add_data128(&mut self, value: &[u8; 16]) -> Result<()> {
self.chunk.add_data128(value)
}
pub fn add_optional<T, F>(&mut self, value: Option<T>, add_fn: F) -> Result<()>
where
F: FnOnce(&mut Self, T) -> Result<()>,
{
match value {
Some(v) => add_fn(self, v),
None => self.add_null(),
}
}
pub fn end_row(&mut self) -> Result<()> {
self.chunk.end_row()?;
self.row_count += 1;
if self.chunk.should_flush() {
self.flush()?;
}
Ok(())
}
pub fn flush(&mut self) -> Result<()> {
if self.chunk.is_empty() {
return Ok(());
}
let chunk_rows = self.chunk.row_count();
let Some(buffer) = self.chunk.take() else {
return Ok(());
};
if self.writer.is_none() {
let client = self.connection.tcp_client().ok_or_else(|| {
crate::Error::new("Inserter requires a TCP connection. gRPC connections do not support COPY operations.")
})?;
let columns: Vec<&str> = self
.table_def
.columns
.iter()
.map(|c| c.name.as_str())
.collect();
let table_name = self.table_def.qualified_name();
self.writer = Some(client.copy_in(&table_name, &columns)?);
}
if let Some(ref mut writer) = self.writer {
writer.send_direct(&buffer)?;
writer.flush_stream()?;
}
debug!(
target: "hyperdb_api",
chunk = self.chunk_count,
rows = chunk_rows,
bytes = buffer.len(),
"inserter-chunk"
);
self.chunk_count += 1;
Ok(())
}
pub fn add_row(&mut self, values: &[&dyn IntoValue]) -> Result<()> {
let column_count = self.table_def.column_count();
if values.len() != column_count {
return Err(Error::new(format!(
"Column count mismatch: expected {} columns but got {}",
column_count,
values.len()
)));
}
for value in values {
value.add_to_inserter(self)?;
}
self.end_row()?;
Ok(())
}
#[inline]
pub fn add_date(&mut self, value: Date) -> Result<()> {
self.chunk.add_date(value)
}
#[inline]
pub fn add_time(&mut self, value: Time) -> Result<()> {
self.chunk.add_time(value)
}
#[inline]
pub fn add_timestamp(&mut self, value: Timestamp) -> Result<()> {
self.chunk.add_timestamp(value)
}
#[inline]
pub fn add_offset_timestamp(&mut self, value: OffsetTimestamp) -> Result<()> {
self.chunk.add_offset_timestamp(value)
}
#[inline]
pub fn add_interval(&mut self, value: Interval) -> Result<()> {
self.chunk.add_interval(value)
}
pub fn add_numeric(&mut self, value: Numeric) -> Result<()> {
let column_index = self.chunk.column_index();
let precision = self
.table_def
.columns
.get(column_index)
.and_then(super::table_definition::ColumnDefinition::sql_type)
.and_then(|t| t.precision())
.ok_or_else(|| {
let col_name = self
.table_def
.columns
.get(column_index)
.map_or("<unknown>", |c| c.name.as_str());
Error::new(format!(
"Cannot determine numeric precision for column '{col_name}' at index {column_index}. \
Ensure the column is defined with explicit SqlType including precision.\n\n\
Example fix:\n \
table_def.add_column_with_type(\"{col_name}\", SqlType::Numeric {{ precision: 10, scale: 2 }}, true);"
))
})?;
if precision <= Numeric::SMALL_NUMERIC_MAX_PRECISION {
let unscaled = value.unscaled_value();
let narrowed = i64::try_from(unscaled).map_err(|_| {
Error::new(format!(
"Numeric value {unscaled} is out of range for i64 storage (precision {precision})"
))
})?;
self.chunk.add_i64(narrowed)
} else {
self.chunk.add_data128(&value.to_packed())
}
}
pub fn execute(&mut self) -> Result<u64> {
if self.chunk.column_index() != 0 {
return Err(Error::new("Incomplete row at execute time"));
}
if self.row_count == 0 {
return Ok(0);
}
if self.writer.is_none() {
let client = self.connection.tcp_client().ok_or_else(|| {
Error::new("Inserter requires a TCP connection. gRPC connections do not support COPY operations.")
})?;
let columns: Vec<&str> = self
.table_def
.columns
.iter()
.map(|c| c.name.as_str())
.collect();
let table_name = self.table_def.qualified_name();
self.writer = Some(client.copy_in(&table_name, &columns)?);
}
let writer = self
.writer
.as_mut()
.ok_or_else(|| Error::new("Failed to initialize COPY connection for inserter"))?;
if !self.chunk.is_empty() {
writer.send(self.chunk.buffer())?;
}
let mut trailer_buf = BytesMut::with_capacity(2);
copy::write_trailer(&mut trailer_buf);
writer.send(&trailer_buf)?;
let rows = self
.writer
.take()
.map(hyperdb_api_core::client::CopyInWriter::finish)
.transpose()?
.unwrap_or(0);
let duration_ms = u64::try_from(self.start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
info!(
target: "hyperdb_api",
rows,
chunks = self.chunk_count,
duration_ms,
table = %self.table_def.qualified_name(),
"inserter-end"
);
self.row_count = 0;
Ok(rows)
}
pub fn cancel(&mut self) {
self.writer = None;
self.row_count = 0;
}
}
/// Describes how one target-table column is populated from the staging
/// table: either copied directly by name or computed from a SQL expression.
#[derive(Debug, Clone)]
#[must_use = "ColumnMapping represents a column configuration that should not be discarded. Use it when defining inserter column mappings"]
pub struct ColumnMapping {
    pub column_name: String,
    pub expression: Option<String>,
}
impl ColumnMapping {
    /// Direct 1:1 mapping: the staging column named `column_name` is copied
    /// into the target column of the same name.
    pub fn new(column_name: impl Into<String>) -> Self {
        Self {
            column_name: column_name.into(),
            expression: None,
        }
    }
    /// Computed mapping: the target column is filled from the SQL
    /// `expression` evaluated against the staging table.
    pub fn with_expression(column_name: impl Into<String>, expression: impl Into<String>) -> Self {
        Self {
            column_name: column_name.into(),
            expression: Some(expression.into()),
        }
    }
    /// Name of the target column this mapping populates.
    #[must_use]
    pub fn column_name(&self) -> &str {
        &self.column_name
    }
    /// The SQL expression, if this is a computed mapping.
    #[must_use]
    pub fn expression(&self) -> Option<&str> {
        self.expression.as_deref()
    }
    /// True for a direct copy-by-name mapping (no expression).
    #[must_use]
    pub fn is_direct(&self) -> bool {
        self.expression.is_none()
    }
    /// Renders this mapping as one item of a SELECT list, quoting the
    /// column name (embedded quotes are doubled per SQL rules).
    fn to_select_item(&self) -> String {
        let quoted = self.column_name.replace('"', "\"\"");
        if let Some(expr) = &self.expression {
            format!("{expr} AS \"{quoted}\"")
        } else {
            format!("\"{quoted}\"")
        }
    }
}
/// A value that knows how to append itself as the next column of an
/// [`Inserter`] row; enables heterogeneous `&[&dyn IntoValue]` rows.
pub trait IntoValue {
    /// Appends `self` as the next column value of the current row.
    fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()>;
}
impl IntoValue for bool {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_bool(*self)
}
}
impl IntoValue for i16 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_i16(*self)
}
}
impl IntoValue for i32 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_i32(*self)
}
}
impl IntoValue for i64 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_i64(*self)
}
}
impl IntoValue for f32 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_f32(*self)
}
}
impl IntoValue for f64 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_f64(*self)
}
}
impl IntoValue for str {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_str(self)
}
}
impl IntoValue for String {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_str(self)
}
}
impl IntoValue for [u8] {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_bytes(self)
}
}
impl IntoValue for Vec<u8> {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_bytes(self)
}
}
impl IntoValue for Date {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_date(*self)
}
}
impl IntoValue for Time {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_time(*self)
}
}
impl IntoValue for Timestamp {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_timestamp(*self)
}
}
impl IntoValue for OffsetTimestamp {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_offset_timestamp(*self)
}
}
impl IntoValue for Interval {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_interval(*self)
}
}
impl IntoValue for Numeric {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_numeric(*self)
}
}
impl<T: IntoValue> IntoValue for Option<T> {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
match self {
Some(value) => value.add_to_inserter(inserter),
None => inserter.add_null(),
}
}
}
impl IntoValue for &bool {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_bool(**self)
}
}
impl IntoValue for &i16 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_i16(**self)
}
}
impl IntoValue for &i32 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_i32(**self)
}
}
impl IntoValue for &i64 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_i64(**self)
}
}
impl IntoValue for &f32 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_f32(**self)
}
}
impl IntoValue for &f64 {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_f64(**self)
}
}
impl IntoValue for &String {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_str(self)
}
}
impl IntoValue for &str {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_str(self)
}
}
impl IntoValue for &&str {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_str(self)
}
}
impl IntoValue for &[u8] {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_bytes(self)
}
}
impl IntoValue for &Date {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_date(**self)
}
}
impl IntoValue for &Time {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_time(**self)
}
}
impl IntoValue for &Timestamp {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_timestamp(**self)
}
}
impl IntoValue for &OffsetTimestamp {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_offset_timestamp(**self)
}
}
impl IntoValue for &Interval {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_interval(**self)
}
}
impl IntoValue for &Numeric {
fn add_to_inserter(&self, inserter: &mut Inserter<'_>) -> Result<()> {
inserter.add_numeric(**self)
}
}
/// Inserter variant that loads rows into a session-local staging table and
/// then copies them into the target table via per-column mappings.
#[derive(Debug)]
pub struct MappedInserter<'conn> {
    // Inserter writing rows into the staging table.
    inner: Inserter<'conn>,
    // Final destination table for the mapped INSERT .. SELECT.
    target_table: crate::TableName,
    // One mapping per target column (direct copy or SQL expression).
    mappings: Vec<ColumnMapping>,
    // Name of the temporary staging table created at construction.
    staging_table: String,
}
impl<'conn> MappedInserter<'conn> {
    /// Creates a temporary staging table shaped like `inserter_def` plus an
    /// inner [`Inserter`] that loads it.
    ///
    /// # Errors
    /// Fails on table-name conversion, staging-table creation, or inner
    /// inserter construction.
    fn new<T>(
        connection: &'conn Connection,
        inserter_def: &TableDefinition,
        target_table: T,
        mappings: &[ColumnMapping],
    ) -> Result<Self>
    where
        T: TryInto<crate::TableName>,
        crate::Error: From<T::Error>,
    {
        // Monotonic per-process suffix prevents two concurrent
        // MappedInserters in the same process from colliding on the name
        // (process id alone is not unique within a process).
        static STAGING_SEQ: AtomicU64 = AtomicU64::new(0);
        let target_table = target_table.try_into()?;
        let staging_table = format!(
            "_hyper_staging_{}_{}",
            std::process::id(),
            STAGING_SEQ.fetch_add(1, Ordering::Relaxed)
        );
        let mut staging_def = inserter_def.clone();
        staging_def.name.clone_from(&staging_table);
        let create_sql = staging_def.to_create_sql(true)?;
        // Rewrite only the leading keyword, not any later occurrence of the
        // phrase inside the generated SQL (e.g. in a quoted identifier).
        let create_temp = create_sql.replacen("CREATE TABLE", "CREATE TEMPORARY TABLE", 1);
        connection.execute_command(&create_temp)?;
        let inner = Inserter::new(connection, &staging_def)?;
        Ok(MappedInserter {
            inner,
            target_table,
            mappings: mappings.to_vec(),
            staging_table,
        })
    }
    /// Appends a complete row (see [`Inserter::add_row`]).
    pub fn add_row(&mut self, values: &[&dyn IntoValue]) -> Result<()> {
        self.inner.add_row(values)
    }
    /// Appends NULL as the next column value.
    pub fn add_null(&mut self) -> Result<()> {
        self.inner.add_null()
    }
    /// Appends a boolean as the next column value.
    pub fn add_bool(&mut self, value: bool) -> Result<()> {
        self.inner.add_bool(value)
    }
    /// Appends an i16 as the next column value.
    pub fn add_i16(&mut self, value: i16) -> Result<()> {
        self.inner.add_i16(value)
    }
    /// Appends an i32 as the next column value.
    pub fn add_i32(&mut self, value: i32) -> Result<()> {
        self.inner.add_i32(value)
    }
    /// Appends an i64 as the next column value.
    pub fn add_i64(&mut self, value: i64) -> Result<()> {
        self.inner.add_i64(value)
    }
    /// Appends an f32 as the next column value.
    pub fn add_f32(&mut self, value: f32) -> Result<()> {
        self.inner.add_f32(value)
    }
    /// Appends an f64 as the next column value.
    pub fn add_f64(&mut self, value: f64) -> Result<()> {
        self.inner.add_f64(value)
    }
    /// Appends a UTF-8 string as the next column value.
    pub fn add_str(&mut self, value: &str) -> Result<()> {
        self.inner.add_str(value)
    }
    /// Appends raw bytes as the next column value.
    pub fn add_bytes(&mut self, value: &[u8]) -> Result<()> {
        self.inner.add_bytes(value)
    }
    /// Marks the current row complete.
    pub fn end_row(&mut self) -> Result<()> {
        self.inner.end_row()
    }
    /// Finishes the staged load, copies the rows into the target table
    /// through the configured mappings, and drops the staging table.
    ///
    /// # Errors
    /// Returns the first failure among COPY completion, the
    /// INSERT..SELECT, or the staging-table drop. The drop is attempted
    /// even when the INSERT fails so the temporary table does not linger.
    pub fn execute(&mut self) -> Result<u64> {
        use hyperdb_api_core::protocol::escape::SqlIdentifier;
        let connection = self.inner.connection;
        let staging_table = self.staging_table.clone();
        let _staging_rows = self.inner.execute()?;
        let target_columns: Vec<String> = self
            .mappings
            .iter()
            .map(|m| format!("{}", SqlIdentifier(&m.column_name)))
            .collect();
        let select_items: Vec<String> = self
            .mappings
            .iter()
            .map(ColumnMapping::to_select_item)
            .collect();
        let sql = format!(
            "INSERT INTO {} ({}) SELECT {} FROM {}",
            self.target_table,
            target_columns.join(", "),
            select_items.join(", "),
            SqlIdentifier(&staging_table),
        );
        let insert_result = connection.execute_command(&sql);
        // Always attempt cleanup; the INSERT error (if any) wins.
        let drop_result = connection.execute_command(&format!(
            "DROP TABLE IF EXISTS {}",
            SqlIdentifier(&staging_table)
        ));
        let row_count = insert_result?;
        drop_result?;
        Ok(row_count)
    }
    /// Aborts the staged load: cancels the inner COPY stream, then
    /// best-effort drops the staging table (a failure is only logged).
    pub fn cancel(&mut self) {
        use hyperdb_api_core::protocol::escape::SqlIdentifier;
        // Drop the inner COPY writer first so the connection is free to run
        // the DROP command.
        self.inner.cancel();
        let connection = self.inner.connection;
        let staging_table = &self.staging_table;
        if let Err(e) = connection.execute_command(&format!(
            "DROP TABLE IF EXISTS {}",
            SqlIdentifier(staging_table)
        )) {
            eprintln!("Warning: Failed to drop staging table '{staging_table}' during cancel: {e}");
        }
    }
}
/// A buffered batch of rows encoded in the HyperBinary COPY wire format.
#[derive(Debug)]
pub struct InsertChunk {
    // Encoded bytes (header + row data) awaiting transmission.
    buffer: BytesMut,
    // Whether the HyperBinary header has already been written to `buffer`.
    header_written: bool,
    // Index of the next column expected in the current (partial) row.
    column_index: usize,
    // Total number of columns per row.
    column_count: usize,
    // Completed rows currently held in `buffer`.
    row_count: usize,
    // Per-column nullability, indexed by column position.
    column_nullable: Vec<bool>,
}
// SAFETY: InsertChunk owns all of its data (a byte buffer, plain counters,
// and a Vec<bool>) and holds no raw pointers, references, or interior
// mutability, so moving or sharing it across threads is sound.
// NOTE(review): if `BytesMut` here is (or re-exports) `bytes::BytesMut`,
// which is already Send + Sync, both impls are redundant and the compiler
// would derive them automatically — confirm and consider removing them.
unsafe impl Send for InsertChunk {}
unsafe impl Sync for InsertChunk {}
impl InsertChunk {
#[must_use]
pub fn new(column_count: usize, column_nullable: Vec<bool>) -> Self {
debug_assert_eq!(column_count, column_nullable.len());
InsertChunk {
buffer: BytesMut::with_capacity(INITIAL_BUFFER_SIZE),
header_written: false,
column_index: 0,
column_count,
row_count: 0,
column_nullable,
}
}
#[must_use]
pub fn from_table_definition(table_def: &TableDefinition) -> Self {
let column_nullable: Vec<bool> = table_def.columns.iter().map(|c| c.nullable).collect();
Self::new(table_def.column_count(), column_nullable)
}
#[must_use]
pub fn row_count(&self) -> usize {
self.row_count
}
#[must_use]
pub fn buffer_size(&self) -> usize {
self.buffer.len()
}
#[must_use]
pub fn should_flush(&self) -> bool {
self.row_count >= CHUNK_ROW_LIMIT || self.buffer.len() >= CHUNK_SIZE_LIMIT
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.row_count == 0
}
pub fn take(&mut self) -> Option<BytesMut> {
if self.row_count == 0 {
return None;
}
self.row_count = 0;
Some(std::mem::take(&mut self.buffer))
}
pub fn clear(&mut self) {
self.buffer.clear();
self.header_written = false;
self.column_index = 0;
self.row_count = 0;
}
#[allow(
clippy::inline_always,
reason = "hot-path numeric kernel; forced inlining measured to matter on this specific function"
)]
fn ensure_header(&mut self) {
if !self.header_written {
copy::write_header(&mut self.buffer);
self.header_written = true;
}
}
#[expect(
clippy::inline_always,
reason = "hot inner loop of the inserter; measured to matter for per-row throughput"
)]
#[inline(always)]
fn current_column_nullable(&self) -> bool {
*self.column_nullable.get(self.column_index).unwrap_or(&true)
}
pub fn add_null(&mut self) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
if !self.current_column_nullable() {
return Err(Error::new("Cannot add NULL to non-nullable column"));
}
self.ensure_header();
copy::write_null(&mut self.buffer);
self.column_index += 1;
Ok(())
}
pub fn add_bool(&mut self, value: bool) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
let int_value = i8::from(value);
if self.current_column_nullable() {
copy::write_i8(&mut self.buffer, int_value);
} else {
copy::write_i8_not_null(&mut self.buffer, int_value);
}
self.column_index += 1;
Ok(())
}
pub fn add_i16(&mut self, value: i16) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_i16(&mut self.buffer, value);
} else {
copy::write_i16_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_i32(&mut self, value: i32) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_i32(&mut self.buffer, value);
} else {
copy::write_i32_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_i64(&mut self, value: i64) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_i64(&mut self.buffer, value);
} else {
copy::write_i64_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_f32(&mut self, value: f32) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_f32(&mut self.buffer, value);
} else {
copy::write_f32_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_f64(&mut self, value: f64) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_f64(&mut self.buffer, value);
} else {
copy::write_f64_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_str(&mut self, value: &str) -> Result<()> {
self.add_bytes(value.as_bytes())
}
pub fn add_bytes(&mut self, value: &[u8]) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
if value.len() > u32::MAX as usize {
return Err(Error::new(format!(
"Value length {} exceeds HyperBinary 4-byte length limit ({})",
value.len(),
u32::MAX
)));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_varbinary(&mut self.buffer, value);
} else {
copy::write_varbinary_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_data128(&mut self, value: &[u8; 16]) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
if self.current_column_nullable() {
copy::write_data128(&mut self.buffer, value);
} else {
copy::write_data128_not_null(&mut self.buffer, value);
}
self.column_index += 1;
Ok(())
}
pub fn add_date(&mut self, value: Date) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
let julian_day = value.to_julian_day();
if self.current_column_nullable() {
copy::write_i32(&mut self.buffer, julian_day);
} else {
copy::write_i32_not_null(&mut self.buffer, julian_day);
}
self.column_index += 1;
Ok(())
}
pub fn add_time(&mut self, value: Time) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
let micros = value.to_microseconds();
if self.current_column_nullable() {
copy::write_i64(&mut self.buffer, micros);
} else {
copy::write_i64_not_null(&mut self.buffer, micros);
}
self.column_index += 1;
Ok(())
}
pub fn add_timestamp(&mut self, value: Timestamp) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
let micros = value.to_microseconds();
if self.current_column_nullable() {
copy::write_i64(&mut self.buffer, micros);
} else {
copy::write_i64_not_null(&mut self.buffer, micros);
}
self.column_index += 1;
Ok(())
}
pub fn add_offset_timestamp(&mut self, value: OffsetTimestamp) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
let micros = value.to_microseconds_utc();
if self.current_column_nullable() {
copy::write_i64(&mut self.buffer, micros);
} else {
copy::write_i64_not_null(&mut self.buffer, micros);
}
self.column_index += 1;
Ok(())
}
pub fn add_interval(&mut self, value: Interval) -> Result<()> {
if self.column_index >= self.column_count {
return Err(Error::new("Too many columns in row"));
}
self.ensure_header();
let packed = value.to_packed();
if self.current_column_nullable() {
copy::write_data128(&mut self.buffer, &packed);
} else {
copy::write_data128_not_null(&mut self.buffer, &packed);
}
self.column_index += 1;
Ok(())
}
pub fn end_row(&mut self) -> Result<()> {
if self.column_index != self.column_count {
return Err(Error::new(format!(
"Expected {} columns, got {}",
self.column_count, self.column_index
)));
}
self.column_index = 0;
self.row_count += 1;
Ok(())
}
#[must_use]
pub fn column_index(&self) -> usize {
self.column_index
}
pub(crate) fn buffer(&self) -> &BytesMut {
&self.buffer
}
}
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Mutex;
/// Shareable sender that streams independently-encoded [`InsertChunk`]s
/// over a single COPY connection, guarded by a mutex.
#[derive(Debug)]
pub struct ChunkSender<'conn> {
    connection: &'conn Connection,
    // Fully-qualified target table name used when opening COPY.
    table_name: String,
    // Column names passed to COPY.
    columns: Vec<String>,
    // Shared COPY stream; opened lazily while holding the lock.
    writer: Mutex<Option<CopyInWriter<'conn>>>,
    // Set once the first chunk (which carries the header) has been sent.
    header_sent: std::sync::atomic::AtomicBool,
    // Running totals, for diagnostics.
    total_rows: AtomicU64,
    chunks_sent: AtomicUsize,
}
impl<'conn> ChunkSender<'conn> {
    /// Creates a sender targeting the table described by `table_def`.
    ///
    /// # Errors
    /// Fails if the definition has no columns.
    pub fn new(connection: &'conn Connection, table_def: &TableDefinition) -> Result<Self> {
        if table_def.column_count() == 0 {
            return Err(Error::InvalidTableDefinition(
                "Table definition must have at least one column".into(),
            ));
        }
        let columns: Vec<String> = table_def.columns.iter().map(|c| c.name.clone()).collect();
        let table_name = table_def.qualified_name();
        Ok(ChunkSender {
            connection,
            table_name,
            columns,
            writer: Mutex::new(None),
            header_sent: std::sync::atomic::AtomicBool::new(false),
            total_rows: AtomicU64::new(0),
            chunks_sent: AtomicUsize::new(0),
        })
    }
    /// Sends one encoded chunk over the shared COPY stream; empty chunks
    /// are silently skipped.
    ///
    /// # Errors
    /// Fails on lock poisoning, a non-TCP connection, or a write failure.
    pub fn send_chunk(&self, mut chunk: InsertChunk) -> Result<()> {
        let row_count = chunk.row_count();
        let Some(buffer) = chunk.take() else {
            return Ok(());
        };
        let mut writer_guard = self
            .writer
            .lock()
            .map_err(|_| Error::new("ChunkSender mutex poisoned"))?;
        if writer_guard.is_none() {
            let client = self.connection.tcp_client().ok_or_else(|| {
                Error::new("ChunkSender requires a TCP connection. gRPC connections do not support COPY operations.")
            })?;
            let columns: Vec<&str> = self.columns.iter().map(String::as_str).collect();
            *writer_guard = Some(client.copy_in(&self.table_name, &columns)?);
        }
        // Only the first chunk on the stream may carry the HyperBinary
        // header; strip it from later chunks that encoded their own.
        // (The swap happens while the writer lock is held, so first-ness is
        // consistent with send order.)
        let is_first = !self.header_sent.swap(true, Ordering::SeqCst);
        let data_to_send = if !is_first
            && buffer.len() > copy::HYPER_BINARY_HEADER_SIZE
            && buffer.starts_with(copy::HYPER_BINARY_HEADER)
        {
            &buffer[copy::HYPER_BINARY_HEADER_SIZE..]
        } else {
            &buffer[..]
        };
        if let Some(ref mut writer) = *writer_guard {
            writer.send_direct(data_to_send)?;
            writer.flush_stream()?;
        }
        drop(writer_guard);
        self.total_rows
            .fetch_add(row_count as u64, Ordering::Relaxed);
        self.chunks_sent.fetch_add(1, Ordering::Relaxed);
        debug!(
            target: "hyperdb_api",
            chunk = self.chunks_sent.load(Ordering::Relaxed),
            rows = row_count,
            bytes = data_to_send.len(),
            "chunk-sender"
        );
        Ok(())
    }
    /// Total rows sent so far (client-side count).
    pub fn total_rows(&self) -> u64 {
        self.total_rows.load(Ordering::Relaxed)
    }
    /// Number of chunks sent so far.
    pub fn chunks_sent(&self) -> usize {
        self.chunks_sent.load(Ordering::Relaxed)
    }
    /// Terminates the COPY stream (trailer + finish) and returns the
    /// server-reported row count; returns 0 if nothing was ever sent.
    ///
    /// # Errors
    /// Fails on lock poisoning or a protocol write failure.
    pub fn finish(self) -> Result<u64> {
        let mut writer_guard = self
            .writer
            .lock()
            .map_err(|_| Error::new("ChunkSender mutex poisoned"))?;
        let Some(mut writer) = writer_guard.take() else {
            return Ok(0);
        };
        let mut trailer_buf = BytesMut::with_capacity(2);
        copy::write_trailer(&mut trailer_buf);
        writer.send(&trailer_buf)?;
        let rows = writer.finish()?;
        info!(
            target: "hyperdb_api",
            rows,
            chunks = self.chunks_sent.load(Ordering::Relaxed),
            table = %self.table_name,
            "chunk-sender-finish"
        );
        Ok(rows)
    }
}
#[cfg(test)]
mod tests {
    use crate::table_definition::TableDefinition;
    use hyperdb_api_core::types::SqlType;
    use super::InsertChunk;
    // Two-column fixture: required int `id`, nullable text `name`.
    fn create_test_table_def() -> TableDefinition {
        TableDefinition::new("test")
            .add_required_column("id", SqlType::int())
            .add_nullable_column("name", SqlType::text())
    }
    #[test]
    fn test_inserter_column_validation() {
        let table_def = create_test_table_def();
        assert_eq!(table_def.column_count(), 2);
    }
    // A chunk accepts rows, tracks counts, and yields its buffer exactly once.
    #[test]
    fn test_insert_chunk_encoding() {
        let table_def = create_test_table_def();
        let mut chunk = InsertChunk::from_table_definition(&table_def);
        chunk.add_i32(42).unwrap();
        chunk.add_str("hello").unwrap();
        chunk.end_row().unwrap();
        assert_eq!(chunk.row_count(), 1);
        assert!(!chunk.is_empty());
        // One row is far below both the row and byte flush thresholds.
        assert!(!chunk.should_flush());
        for i in 0..100 {
            chunk.add_i32(i).unwrap();
            chunk.add_str(&format!("item {i}")).unwrap();
            chunk.end_row().unwrap();
        }
        assert_eq!(chunk.row_count(), 101);
        let buffer = chunk.take().unwrap();
        assert!(!buffer.is_empty());
        // After take(), the chunk is empty and yields nothing further.
        assert!(chunk.take().is_none());
    }
    // NULL is rejected for the non-nullable first column but accepted for
    // the nullable second column.
    #[test]
    fn test_insert_chunk_null_handling() {
        let table_def = create_test_table_def();
        let mut chunk = InsertChunk::from_table_definition(&table_def);
        assert!(chunk.add_null().is_err());
        chunk.add_i32(1).unwrap();
        chunk.add_null().unwrap();
        chunk.end_row().unwrap();
        assert_eq!(chunk.row_count(), 1);
    }
    // end_row() fails while the row is incomplete, succeeds once full.
    #[test]
    fn test_insert_chunk_column_count_validation() {
        let table_def = create_test_table_def();
        let mut chunk = InsertChunk::from_table_definition(&table_def);
        chunk.add_i32(1).unwrap();
        assert!(chunk.end_row().is_err());
        chunk.add_str("test").unwrap();
        chunk.end_row().unwrap();
    }
    // Appending past the declared column count is rejected.
    #[test]
    fn test_insert_chunk_too_many_columns() {
        let table_def = create_test_table_def();
        let mut chunk = InsertChunk::from_table_definition(&table_def);
        chunk.add_i32(1).unwrap();
        chunk.add_str("test").unwrap();
        assert!(chunk.add_i32(2).is_err());
    }
    // clear() discards buffered rows and resets the chunk to empty.
    #[test]
    fn test_insert_chunk_clear() {
        let table_def = create_test_table_def();
        let mut chunk = InsertChunk::from_table_definition(&table_def);
        chunk.add_i32(1).unwrap();
        chunk.add_str("test").unwrap();
        chunk.end_row().unwrap();
        assert_eq!(chunk.row_count(), 1);
        chunk.clear();
        assert_eq!(chunk.row_count(), 0);
        assert!(chunk.is_empty());
    }
}