use std::time::Instant;
use hyperdb_api_core::client::AsyncCopyInWriter;
use hyperdb_api_core::protocol::copy;
use hyperdb_api_core::types::bytes::BytesMut;
use hyperdb_api_core::types::{Date, Interval, Numeric, OffsetTimestamp, Time, Timestamp};
use tracing::info;
use crate::async_connection::AsyncConnection;
use crate::error::{Error, Result};
use crate::inserter::InsertChunk;
use crate::table_definition::TableDefinition;
/// Streaming row inserter that feeds a table over an async TCP COPY session.
///
/// Rows are encoded into an in-memory `InsertChunk` and shipped to the server
/// in batches; `execute` finishes the COPY and returns the inserted row count.
#[derive(Debug)]
pub struct AsyncInserter<'conn> {
/// Borrowed connection; must be TCP-backed (validated in `new`).
connection: &'conn AsyncConnection,
/// Owned copy of the target table's definition (column names and types).
table_def: TableDefinition,
/// Buffer of encoded rows awaiting the next flush.
chunk: InsertChunk,
/// Rows completed via `end_row`; reset to 0 by `execute`.
row_count: u64,
/// Chunks sent to the server so far (used only in the end-of-insert log).
chunk_count: usize,
/// COPY stream writer, created lazily on the first flush.
writer: Option<AsyncCopyInWriter<'conn>>,
/// Construction time, used for the `duration_ms` log field.
start_time: Instant,
}
#[allow(
clippy::missing_errors_doc,
reason = "per-column add_* methods all return the same error shape \
documented on Inserter::add_bool — repeating the same `# Errors` \
block on 15 thin delegators adds noise without adding info"
)]
impl<'conn> AsyncInserter<'conn> {
/// Creates an inserter that will stream rows into the table described by
/// `table_def` over `connection`.
///
/// Validation happens up front so failures surface here rather than on the
/// first flush: the table must have at least one column, and the connection
/// must be TCP-backed because COPY is not available over gRPC.
pub fn new(connection: &'conn AsyncConnection, table_def: &TableDefinition) -> Result<Self> {
    // An empty schema can never accept a row; reject it immediately.
    if table_def.column_count() == 0 {
        return Err(Error::InvalidTableDefinition(
            "Table definition must have at least one column".into(),
        ));
    }
    // COPY is a TCP-only operation; fail fast for gRPC-backed connections.
    if connection.async_tcp_client().is_none() {
        return Err(Error::new(
            "AsyncInserter requires a TCP connection. \
            gRPC connections do not support COPY operations.",
        ));
    }
    let chunk = InsertChunk::from_table_definition(table_def);
    Ok(Self {
        connection,
        table_def: table_def.clone(),
        chunk,
        row_count: 0,
        chunk_count: 0,
        writer: None,
        start_time: Instant::now(),
    })
}
/// Appends a NULL value for the current column of the in-progress row.
pub fn add_null(&mut self) -> Result<()> {
self.chunk.add_null()
}
/// Appends a `bool` to the current row buffer.
pub fn add_bool(&mut self, value: bool) -> Result<()> {
self.chunk.add_bool(value)
}
/// Appends an `i16` to the current row buffer.
pub fn add_i16(&mut self, value: i16) -> Result<()> {
self.chunk.add_i16(value)
}
/// Appends an `i32` to the current row buffer.
pub fn add_i32(&mut self, value: i32) -> Result<()> {
self.chunk.add_i32(value)
}
/// Appends an `i64` to the current row buffer.
pub fn add_i64(&mut self, value: i64) -> Result<()> {
self.chunk.add_i64(value)
}
/// Appends an `f32` to the current row buffer.
pub fn add_f32(&mut self, value: f32) -> Result<()> {
self.chunk.add_f32(value)
}
/// Appends an `f64` to the current row buffer.
pub fn add_f64(&mut self, value: f64) -> Result<()> {
self.chunk.add_f64(value)
}
/// Appends a string value to the current row buffer.
pub fn add_str(&mut self, value: &str) -> Result<()> {
self.chunk.add_str(value)
}
/// Appends a raw byte-slice value to the current row buffer.
pub fn add_bytes(&mut self, value: &[u8]) -> Result<()> {
self.chunk.add_bytes(value)
}
/// Appends a `Date` to the current row buffer.
pub fn add_date(&mut self, value: Date) -> Result<()> {
self.chunk.add_date(value)
}
/// Appends a `Time` to the current row buffer.
pub fn add_time(&mut self, value: Time) -> Result<()> {
self.chunk.add_time(value)
}
/// Appends a `Timestamp` to the current row buffer.
pub fn add_timestamp(&mut self, value: Timestamp) -> Result<()> {
self.chunk.add_timestamp(value)
}
/// Appends an `OffsetTimestamp` (timestamp with offset) to the current row buffer.
pub fn add_offset_timestamp(&mut self, value: OffsetTimestamp) -> Result<()> {
self.chunk.add_offset_timestamp(value)
}
/// Appends an `Interval` to the current row buffer.
pub fn add_interval(&mut self, value: Interval) -> Result<()> {
self.chunk.add_interval(value)
}
/// Appends a `Numeric`, choosing the wire width from the column's declared
/// precision.
///
/// Columns at or below `Numeric::SMALL_NUMERIC_MAX_PRECISION` travel as the
/// unscaled value narrowed to `i64`; wider columns use the packed 128-bit
/// representation. Fails when the target column has no explicit precision or
/// when a small-precision value does not fit in 64 bits.
pub fn add_numeric(&mut self, value: Numeric) -> Result<()> {
    let column_index = self.chunk.column_index();
    // Look the column up once; both the precision probe and the error
    // message reuse it (the original code performed this lookup twice).
    let column = self.table_def.columns.get(column_index);
    let precision = column
        .and_then(|c| c.sql_type())
        .and_then(|t| t.precision())
        .ok_or_else(|| {
            let col_name = column.map_or("<unknown>", |c| c.name.as_str());
            Error::new(format!(
                "Cannot determine numeric precision for column '{col_name}' at index {column_index}. \
                Ensure the column is defined with explicit SqlType including precision."
            ))
        })?;
    if precision <= Numeric::SMALL_NUMERIC_MAX_PRECISION {
        // Small numerics are stored as their unscaled integer; reject values
        // that cannot be represented in 64 bits instead of truncating.
        let unscaled = value.unscaled_value();
        let narrowed = i64::try_from(unscaled).map_err(|_| {
            Error::new(format!(
                "Numeric value {unscaled} is out of range for i64 storage (precision {precision})"
            ))
        })?;
        self.chunk.add_i64(narrowed)
    } else {
        self.chunk.add_data128(&value.to_packed())
    }
}
/// Finalizes the in-progress row and flushes the buffer once it is full.
pub async fn end_row(&mut self) -> Result<()> {
    self.chunk.end_row()?;
    self.row_count += 1;
    // Below the flush threshold there is nothing more to do.
    if !self.chunk.should_flush() {
        return Ok(());
    }
    self.flush().await
}
async fn flush(&mut self) -> Result<()> {
if self.writer.is_none() {
let client = self.connection.async_tcp_client().ok_or_else(|| {
Error::new(
"AsyncInserter requires a TCP connection. \
gRPC connections do not support COPY operations.",
)
})?;
let columns: Vec<&str> = self
.table_def
.columns
.iter()
.map(|c| c.name.as_str())
.collect();
let table_name = self.table_def.qualified_name();
self.writer = Some(client.copy_in(&table_name, &columns).await?);
}
if let Some(buffer) = self.chunk.take() {
if let Some(writer) = self.writer.as_mut() {
writer.send(&buffer).await?;
self.chunk_count += 1;
}
}
Ok(())
}
pub async fn execute(&mut self) -> Result<u64> {
if self.chunk.column_index() != 0 {
return Err(Error::new("Incomplete row at execute time"));
}
if self.row_count == 0 {
return Ok(0);
}
self.flush().await?;
let mut trailer_buf = BytesMut::with_capacity(2);
copy::write_trailer(&mut trailer_buf);
if let Some(writer) = self.writer.as_mut() {
writer.send(&trailer_buf).await?;
}
let rows = if let Some(writer) = self.writer.take() {
writer.finish().await?
} else {
0
};
let duration_ms = u64::try_from(self.start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
info!(
target: "hyperdb_api",
rows,
chunks = self.chunk_count,
duration_ms,
table = %self.table_def.qualified_name(),
"async-inserter-end"
);
self.row_count = 0;
Ok(rows)
}
pub fn cancel(&mut self) {
self.writer = None;
self.row_count = 0;
}
}