use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use arrow_flight::error::FlightError;
use arrow_flight::{FlightClient, FlightData, PutResult, SchemaAsIpc};
use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
use bytes::Bytes;
use futures::{Stream, StreamExt};
use tokio::sync::{mpsc, watch, Mutex};
use tokio::time::{sleep, Duration};
use tokio_retry::strategy::FixedInterval;
use tokio_retry::RetryIf;
use tonic::transport::Channel;
use tracing::{debug, error, info, instrument, warn};
pub use arrow_array::RecordBatch;
pub use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use crate::arrow_configuration::ArrowStreamConfigurationOptions;
use crate::arrow_metadata::{FlightAckMetadata, FlightBatchMetadata};
use crate::errors::ZerobusError;
use crate::headers_provider::HeadersProvider;
use crate::offset_generator::{OffsetId, OffsetIdGenerator};
use crate::tls_config::TlsConfig;
use crate::ZerobusResult;
// Shared, lockable handle to the sender half of the in-flight `do_put` channel.
// Holds `None` whenever there is no live connection (initially and while reconnecting).
type BatchSender = Arc<Mutex<Option<mpsc::Sender<Result<FlightData, FlightError>>>>>;

/// A batch payload kept for acknowledgement tracking and replay, in whichever
/// form the caller originally supplied it.
#[derive(Clone)]
enum ArrowPayload {
    /// Raw Arrow IPC stream bytes, as passed to `ingest_ipc_batch`.
    Ipc(Bytes),
    /// An in-memory batch, as passed to `ingest_batch`.
    Batch(RecordBatch),
}
impl ArrowPayload {
    /// Produce an owned `RecordBatch` from this payload, decoding the IPC
    /// bytes on demand when necessary.
    #[allow(clippy::result_large_err)]
    fn materialize(&self) -> ZerobusResult<RecordBatch> {
        match self {
            Self::Ipc(raw) => materialize_ipc(raw),
            Self::Batch(batch) => Ok(batch.clone()),
        }
    }
}
/// Identifies the destination table and the Arrow schema every ingested batch
/// must match.
#[derive(Debug, Clone)]
pub struct ArrowTableProperties {
    // Fully-qualified table name; sent to the server as a gRPC header on connect.
    pub table_name: String,
    // Expected Arrow schema; batches with a different schema are rejected.
    pub schema: Arc<ArrowSchema>,
}
/// A batch that has been sent but not yet fully acknowledged by the server.
#[derive(Clone)]
struct PendingBatch {
    // Original payload, retained so the batch can be replayed after a reconnect.
    payload: ArrowPayload,
    // Client-assigned offset id, returned to callers and matched against acks.
    offset_id: OffsetId,
    // Cumulative record index (on the current connection) of this batch's first row.
    start_record: u64,
    // Cumulative record index one past this batch's last row.
    end_record: u64,
}
/// Compute what (if anything) from `pb` must be resent after a reconnect,
/// given the cumulative record count the server had acknowledged before the
/// disconnect. Returns `None` when the batch was fully acked, the original
/// payload when none of it was acked, and a sliced batch otherwise.
#[allow(clippy::result_large_err)]
fn slice_batch_for_recovery(
    pb: &PendingBatch,
    acked_before_disconnect: u64,
) -> ZerobusResult<Option<ArrowPayload>> {
    // Nothing in this batch was acknowledged: replay it unchanged.
    if pb.start_record >= acked_before_disconnect {
        return Ok(Some(pb.payload.clone()));
    }
    let total_rows = pb.end_record - pb.start_record;
    let records_already_acked = (acked_before_disconnect - pb.start_record).min(total_rows);
    let remaining_rows = total_rows.saturating_sub(records_already_acked);
    if remaining_rows == 0 {
        return Ok(None);
    }
    if records_already_acked == 0 {
        return Ok(Some(pb.payload.clone()));
    }
    debug!(
        offset_id = pb.offset_id,
        total_rows = total_rows,
        records_already_acked = records_already_acked,
        remaining_rows = remaining_rows,
        "Slicing partially-acked batch for recovery"
    );
    // Slicing needs an in-memory batch; decode IPC payloads first.
    let batch = match &pb.payload {
        ArrowPayload::Batch(b) => b.clone(),
        ArrowPayload::Ipc(bytes) => materialize_ipc(bytes).map_err(|e| {
            ZerobusError::InvalidArgument(format!(
                "IPC batch could not be deserialised for partial recovery (offset_id={}): {e}",
                pb.offset_id
            ))
        })?,
    };
    Ok(Some(ArrowPayload::Batch(batch.slice(
        records_already_acked as usize,
        remaining_rows as usize,
    ))))
}
/// Decode an Arrow IPC stream that must contain exactly one `RecordBatch`.
#[allow(clippy::result_large_err)]
fn materialize_ipc(bytes: &Bytes) -> ZerobusResult<RecordBatch> {
    use std::io::Cursor;
    let cursor = Cursor::new(bytes.as_ref());
    let mut reader = arrow_ipc::reader::StreamReader::try_new(cursor, None)
        .map_err(|e| ZerobusError::InvalidArgument(format!("IPC: invalid Arrow IPC stream: {e}")))?;
    let first = reader
        .next()
        .ok_or_else(|| ZerobusError::InvalidArgument("IPC stream contains no RecordBatch".into()))?
        .map_err(|e| {
            ZerobusError::InvalidArgument(format!("IPC: record batch read failed: {e}"))
        })?;
    // The stream must hold exactly one batch; anything after it is an error.
    match reader.next() {
        None => Ok(first),
        Some(Ok(_)) => Err(ZerobusError::InvalidArgument(
            "IPC stream must contain exactly one RecordBatch (found extra batch)".into(),
        )),
        Some(Err(e)) => Err(ZerobusError::InvalidArgument(format!(
            "IPC: trailing message read failed: {e}"
        ))),
    }
}
/// Result of decoding a caller-supplied Arrow IPC byte stream.
struct ParsedIpcBatch {
    // Schema decoded from the stream's leading Schema message.
    schema: ArrowSchema,
    // Row count of the single RecordBatch contained in the stream.
    num_rows: u64,
    // Dictionary and record-batch messages, ready to send as Flight frames.
    flight_data: Vec<FlightData>,
}
/// Split a raw Arrow IPC stream into `FlightData` frames without re-encoding.
///
/// The stream must consist of a Schema message, then zero or more
/// DictionaryBatch messages, and exactly one RecordBatch message. Message
/// headers and bodies are sliced (zero-copy) out of `ipc_bytes`.
///
/// Returns the decoded schema, the batch's row count, and the frames to send.
#[allow(clippy::result_large_err)]
fn ipc_bytes_to_flight_data(ipc_bytes: &Bytes) -> ZerobusResult<ParsedIpcBatch> {
    let bytes = &ipc_bytes[..];
    // IPC message bodies are padded to 8-byte boundaries.
    fn align8(n: usize) -> usize {
        (n + 7) & !7
    }
    // Returns the byte range of the flatbuffer metadata for the message that
    // starts at `p`. A leading 0xFFFFFFFF continuation marker (the current
    // encapsulated format) is skipped when present, so legacy framing without
    // the marker is accepted too.
    #[allow(clippy::result_large_err)]
    fn read_meta_range(bytes: &[u8], mut p: usize) -> ZerobusResult<(usize, usize)> {
        if p + 4 <= bytes.len() && bytes[p..p + 4] == [0xFF, 0xFF, 0xFF, 0xFF] {
            p += 4;
        }
        if p + 4 > bytes.len() {
            return Err(ZerobusError::InvalidArgument(
                "IPC: truncated at length field".into(),
            ));
        }
        // Metadata length is a little-endian i32 immediately before the flatbuffer.
        let meta_len = i32::from_le_bytes([bytes[p], bytes[p + 1], bytes[p + 2], bytes[p + 3]]);
        if meta_len <= 0 {
            return Err(ZerobusError::InvalidArgument(
                "IPC: invalid metadata length".into(),
            ));
        }
        let meta_start = p + 4;
        let meta_end = meta_start + meta_len as usize;
        if meta_end > bytes.len() {
            return Err(ZerobusError::InvalidArgument(
                "IPC: truncated metadata".into(),
            ));
        }
        Ok((meta_start, meta_end))
    }
    // The first message must be the schema.
    let (ms, me) = read_meta_range(bytes, 0)?;
    let schema_msg = arrow_ipc::root_as_message(&bytes[ms..me])
        .map_err(|e| ZerobusError::InvalidArgument(format!("IPC flatbuffer: {e}")))?;
    let fb_schema = schema_msg.header_as_schema().ok_or_else(|| {
        ZerobusError::InvalidArgument("IPC: first message is not a Schema".into())
    })?;
    let schema = arrow_ipc::convert::fb_to_schema(fb_schema);
    // Schema messages normally carry no body, but honour bodyLength regardless.
    let after_schema = align8(me + schema_msg.bodyLength().max(0) as usize);
    if after_schema > bytes.len() {
        return Err(ZerobusError::InvalidArgument(
            "IPC: truncated schema body".into(),
        ));
    }
    let mut pos = after_schema;
    let mut flight_data_messages: Vec<FlightData> = Vec::new();
    let mut num_rows: Option<u64> = None;
    while pos < bytes.len() {
        // Continuation marker followed by a zero length is the end-of-stream marker.
        if pos + 8 <= bytes.len()
            && bytes[pos..pos + 4] == [0xFF, 0xFF, 0xFF, 0xFF]
            && bytes[pos + 4..pos + 8] == [0x00, 0x00, 0x00, 0x00]
        {
            break;
        }
        // Unparseable trailing bytes are tolerated (logged and ignored), not rejected.
        let (msg_ms, msg_me) = match read_meta_range(bytes, pos) {
            Ok(r) => r,
            Err(_) => {
                debug!(pos, "IPC: ignoring trailing bytes");
                break;
            }
        };
        let msg = arrow_ipc::root_as_message(&bytes[msg_ms..msg_me])
            .map_err(|e| ZerobusError::InvalidArgument(format!("IPC flatbuffer: {e}")))?;
        let body_end = align8(msg_me + msg.bodyLength().max(0) as usize);
        if body_end > bytes.len() {
            return Err(ZerobusError::InvalidArgument(
                "IPC: truncated message body".into(),
            ));
        }
        match msg.header_type() {
            arrow_ipc::MessageHeader::DictionaryBatch => {
                // Zero-copy: header and body are Bytes slices into the input buffer.
                flight_data_messages.push(FlightData {
                    data_header: ipc_bytes.slice(msg_ms..msg_me),
                    data_body: ipc_bytes.slice(msg_me..body_end),
                    ..Default::default()
                });
            }
            arrow_ipc::MessageHeader::RecordBatch => {
                // Enforce the exactly-one-batch contract.
                if num_rows.is_some() {
                    return Err(ZerobusError::InvalidArgument(
                        "IPC stream must contain exactly one RecordBatch (found extra batch)"
                            .into(),
                    ));
                }
                let rb = msg.header_as_record_batch().ok_or_else(|| {
                    ZerobusError::InvalidArgument(
                        "IPC: RecordBatch header could not be parsed".into(),
                    )
                })?;
                num_rows = Some(rb.length().max(0) as u64);
                flight_data_messages.push(FlightData {
                    data_header: ipc_bytes.slice(msg_ms..msg_me),
                    data_body: ipc_bytes.slice(msg_me..body_end),
                    ..Default::default()
                });
            }
            _ => {
                return Err(ZerobusError::InvalidArgument(format!(
                    "IPC: unexpected message type {:?}",
                    msg.header_type()
                )));
            }
        }
        pos = body_end;
    }
    let num_rows = num_rows.ok_or_else(|| {
        ZerobusError::InvalidArgument("IPC stream contains no RecordBatch".into())
    })?;
    Ok(ParsedIpcBatch {
        schema,
        num_rows,
        flight_data: flight_data_messages,
    })
}
/// Build IPC write options, enabling the requested compression codec when one
/// is configured.
#[allow(clippy::result_large_err)]
fn make_ipc_write_options(
    compression: Option<arrow_ipc::CompressionType>,
) -> ZerobusResult<IpcWriteOptions> {
    let Some(codec) = compression else {
        return Ok(IpcWriteOptions::default());
    };
    IpcWriteOptions::default()
        .try_with_compression(Some(codec))
        .map_err(|e| {
            ZerobusError::InvalidArgument(format!(
                "Failed to enable Arrow IPC compression: {e}"
            ))
        })
}
/// Encode a schema as the `FlightData` frame that must open a Flight stream.
fn schema_to_flight_data(schema: &ArrowSchema, opts: &IpcWriteOptions) -> FlightData {
    let schema_ipc = SchemaAsIpc::new(schema, opts);
    schema_ipc.into()
}
/// Encode a `RecordBatch` (plus any dictionary batches it needs) into
/// `FlightData` frames using the given write options.
#[allow(clippy::result_large_err)]
fn record_batch_to_flight_data(
    batch: &RecordBatch,
    opts: &IpcWriteOptions,
) -> ZerobusResult<Vec<FlightData>> {
    let generator = IpcDataGenerator::default();
    let mut tracker = DictionaryTracker::new(true);
    // Prime the dictionary tracker with the schema; the encoded schema itself
    // is discarded because the connection sends its own schema frame.
    let _ = generator.schema_to_bytes_with_dictionary_tracker(batch.schema_ref(), &mut tracker, opts);
    let (dictionaries, batch_message) = generator
        .encoded_batch(batch, &mut tracker, opts)
        .map_err(|e| ZerobusError::InvalidArgument(format!("Failed to encode RecordBatch: {e}")))?;
    let mut frames: Vec<FlightData> = dictionaries.into_iter().map(Into::into).collect();
    frames.push(batch_message.into());
    Ok(frames)
}
/// Client-side Arrow Flight ingestion stream with acknowledgement tracking and
/// optional automatic recovery (reconnect + replay of unacknowledged batches).
pub struct ZerobusArrowStream {
    // Destination table and its expected schema.
    pub(crate) table_properties: ArrowTableProperties,
    // Stream tuning knobs (timeouts, recovery policy, compression, ...).
    pub(crate) options: ArrowStreamConfigurationOptions,
    // Sender feeding FlightData into the live connection; `None` while disconnected.
    batch_tx: BatchSender,
    // Generator for client-assigned batch offset ids.
    offset_generator: OffsetIdGenerator,
    // Publishes the highest fully-acknowledged offset id to waiters.
    last_ack_tx: tokio::sync::watch::Sender<Option<OffsetId>>,
    // Kept so the ack watch channel stays open even with no external subscribers.
    _last_ack_rx: tokio::sync::watch::Receiver<Option<OffsetId>>,
    // Set once the stream is closed (explicitly or after a fatal error).
    is_closed: Arc<AtomicBool>,
    // Handle to the supervisor task spawned in `new`.
    receiver_task: Arc<Mutex<Option<tokio::task::JoinHandle<ZerobusResult<()>>>>>,
    // Batches sent but not yet fully acknowledged (replayed on recovery).
    pending_batches: Arc<Mutex<Vec<PendingBatch>>>,
    // Payloads of batches given up on after a fatal error.
    failed_batches: Arc<Mutex<Vec<ArrowPayload>>>,
    // Consecutive recovery attempts since the last successful reconnect.
    recovery_attempts: Arc<AtomicU32>,
    // Connection target, retained for reconnects.
    endpoint: String,
    // TLS settings applied to each new channel.
    tls_config: Arc<dyn TlsConfig>,
    // Supplies per-connection gRPC headers (e.g. auth).
    headers_provider: Arc<dyn HeadersProvider>,
    // Serializes ingest calls so offset/record accounting stays ordered.
    ingest_mutex: Arc<Mutex<()>>,
    // Broadcasts the most recent server-side error (`None` when healthy).
    server_error_tx: watch::Sender<Option<ZerobusError>>,
    server_error_rx: watch::Receiver<Option<ZerobusError>>,
    // Total records sent on the current connection (rebased on reconnect).
    cumulative_records_sent: Arc<AtomicU64>,
    // Record count last acknowledged by the server on the current connection.
    last_acked_records: Arc<AtomicU64>,
}
impl ZerobusArrowStream {
    /// Create a new stream: establish the initial connection (bounded by
    /// `recovery_timeout_ms` and, when recovery is enabled, retried with a
    /// fixed backoff), then spawn the supervisor task that consumes server
    /// acknowledgements and drives reconnection.
    ///
    /// Returns the ready-to-use stream, or the last connection error once all
    /// retries are exhausted.
    #[instrument(level = "debug", skip_all, fields(table_name = %table_properties.table_name))]
    pub(crate) async fn new(
        endpoint: &str,
        tls_config: Arc<dyn TlsConfig>,
        table_properties: ArrowTableProperties,
        headers_provider: Arc<dyn HeadersProvider>,
        options: ArrowStreamConfigurationOptions,
    ) -> ZerobusResult<Self> {
        // Watch channel carrying the highest acked offset; a receiver is kept
        // on the struct so the channel never closes while the stream lives.
        let (last_ack_tx, _last_ack_rx) = tokio::sync::watch::channel(None);
        let is_closed = Arc::new(AtomicBool::new(false));
        let pending_batches = Arc::new(Mutex::new(Vec::new()));
        let failed_batches = Arc::new(Mutex::new(Vec::new()));
        let recovery_attempts = Arc::new(AtomicU32::new(0));
        let batch_tx = Arc::new(Mutex::new(None));
        let receiver_task = Arc::new(Mutex::new(None));
        let cumulative_records_sent = Arc::new(AtomicU64::new(0));
        let last_acked_records = Arc::new(AtomicU64::new(0));
        let (server_error_tx, server_error_rx) = watch::channel(None);
        let stream = Self {
            table_properties,
            options,
            batch_tx,
            offset_generator: OffsetIdGenerator::default(),
            last_ack_tx,
            _last_ack_rx,
            is_closed,
            receiver_task,
            pending_batches,
            failed_batches,
            recovery_attempts,
            endpoint: endpoint.to_string(),
            tls_config,
            headers_provider,
            ingest_mutex: Arc::new(Mutex::new(())),
            server_error_tx,
            server_error_rx,
            cumulative_records_sent,
            last_acked_records,
        };
        let endpoint = stream.endpoint.clone();
        let tls_config = Arc::clone(&stream.tls_config);
        let table_properties = stream.table_properties.clone();
        let options = stream.options.clone();
        let headers_provider = Arc::clone(&stream.headers_provider);
        // Retry policy for the initial connection: fixed backoff, bounded attempts.
        let strategy = FixedInterval::from_millis(options.recovery_backoff_ms)
            .take(options.recovery_retries as usize);
        // Each attempt is a fresh connect bounded by recovery_timeout_ms.
        let create_attempt = || {
            let endpoint = endpoint.clone();
            let tls_config = Arc::clone(&tls_config);
            let table_properties = table_properties.clone();
            let options = options.clone();
            let headers_provider = Arc::clone(&headers_provider);
            async move {
                tokio::time::timeout(
                    Duration::from_millis(options.recovery_timeout_ms),
                    Self::try_connect(
                        &endpoint,
                        &tls_config,
                        &table_properties,
                        &options,
                        &headers_provider,
                    ),
                )
                .await
                .map_err(|_| {
                    ZerobusError::CreateStreamError(tonic::Status::deadline_exceeded(
                        "Stream creation timed out",
                    ))
                })?
            }
        };
        // Only retry when recovery is enabled and the error class is retryable.
        let should_retry = |e: &ZerobusError| options.recovery && e.is_retryable();
        let creation = RetryIf::spawn(strategy, create_attempt, should_retry).await;
        let (response_stream, tx) = match creation {
            Ok(result) => result,
            Err(e) => {
                error!("Arrow Flight stream creation failed after retries: {}", e);
                return Err(e);
            }
        };
        // Publish the live sender so ingest calls can start flowing.
        {
            let mut batch_tx = stream.batch_tx.lock().await;
            *batch_tx = Some(tx);
        }
        // The supervisor owns the response stream: it processes acks and reconnects.
        let task = Self::spawn_supervisor_task(
            stream.endpoint.clone(),
            Arc::clone(&stream.tls_config),
            stream.table_properties.clone(),
            stream.options.clone(),
            Arc::clone(&stream.headers_provider),
            Arc::clone(&stream.batch_tx),
            Arc::clone(&stream.is_closed),
            stream.last_ack_tx.clone(),
            Arc::clone(&stream.pending_batches),
            Arc::clone(&stream.failed_batches),
            Arc::clone(&stream.recovery_attempts),
            stream.server_error_tx.clone(),
            Arc::clone(&stream.cumulative_records_sent),
            Arc::clone(&stream.last_acked_records),
            response_stream,
        );
        {
            let mut receiver_task = stream.receiver_task.lock().await;
            *receiver_task = Some(task);
        }
        info!(
            table_name = %stream.table_properties.table_name,
            "Arrow Flight stream created successfully"
        );
        Ok(stream)
    }
async fn try_connect(
endpoint: &str,
tls_config: &Arc<dyn TlsConfig>,
table_properties: &ArrowTableProperties,
options: &ArrowStreamConfigurationOptions,
headers_provider: &Arc<dyn HeadersProvider>,
) -> ZerobusResult<(
Pin<Box<dyn Stream<Item = Result<PutResult, FlightError>> + Send>>,
mpsc::Sender<Result<FlightData, FlightError>>,
)> {
let client = Self::create_flight_client(
endpoint,
tls_config,
table_properties,
options,
headers_provider,
)
.await?;
Self::start_stream_connection(client, table_properties, options).await
}
async fn create_flight_client(
endpoint: &str,
tls_config: &Arc<dyn TlsConfig>,
table_properties: &ArrowTableProperties,
options: &ArrowStreamConfigurationOptions,
headers_provider: &Arc<dyn HeadersProvider>,
) -> ZerobusResult<FlightClient> {
let connection_timeout = Duration::from_millis(options.connection_timeout_ms);
let base_endpoint = Channel::from_shared(endpoint.to_string())
.map_err(|e| ZerobusError::ChannelCreationError(e.to_string()))?
.connect_timeout(connection_timeout)
.timeout(connection_timeout);
let channel = tls_config.configure_endpoint(base_endpoint)?.connect_lazy();
let mut client = FlightClient::new(channel);
const TABLE_NAME_HEADER: &str = "x-databricks-zerobus-table-name";
let headers = headers_provider.get_headers().await?;
for (key, value) in headers {
if key.eq_ignore_ascii_case(TABLE_NAME_HEADER) {
warn!(
"HeadersProvider attempted to set reserved header '{}', ignoring",
TABLE_NAME_HEADER
);
continue;
}
client.add_header(key, &value).map_err(|e| {
ZerobusError::InvalidArgument(format!("Failed to add header '{}': {}", key, e))
})?;
}
client
.add_header(TABLE_NAME_HEADER, &table_properties.table_name)
.map_err(|e| {
ZerobusError::InvalidArgument(format!("Failed to add table name header: {}", e))
})?;
Ok(client)
}
    /// Open the `do_put` call on `client` and wait for the server's "stream
    /// ready" acknowledgement before returning.
    ///
    /// Returns the server's response (ack) stream and the sender half used to
    /// push `FlightData` frames; the schema frame is always sent first.
    async fn start_stream_connection(
        mut client: FlightClient,
        table_properties: &ArrowTableProperties,
        options: &ArrowStreamConfigurationOptions,
    ) -> ZerobusResult<(
        Pin<Box<dyn Stream<Item = Result<PutResult, FlightError>> + Send>>,
        mpsc::Sender<Result<FlightData, FlightError>>,
    )> {
        // Channel capacity bounds the number of batches in flight at once.
        let (batch_tx, batch_rx) =
            mpsc::channel::<Result<FlightData, FlightError>>(options.max_inflight_batches);
        let ipc_write_options = make_ipc_write_options(options.ipc_compression)?;
        let schema_fd = schema_to_flight_data(&table_properties.schema, &ipc_write_options);
        let data_stream = tokio_stream::wrappers::ReceiverStream::new(batch_rx);
        // Prepend the schema message, as the Flight data stream requires.
        let flight_data_stream =
            futures::stream::once(futures::future::ready(Ok(schema_fd))).chain(data_stream);
        let mut response_stream = client
            .do_put(flight_data_stream)
            .await
            .map_err(|e| ZerobusError::CreateStreamError(tonic::Status::from_error(Box::new(e))))?;
        // The first server response must be the ready signal, not a data ack.
        let setup_timeout = Duration::from_millis(options.connection_timeout_ms);
        match tokio::time::timeout(setup_timeout, response_stream.next()).await {
            Ok(Some(Ok(put_result))) => {
                match FlightAckMetadata::from_bytes(&put_result.app_metadata) {
                    Ok(metadata) if metadata.is_stream_ready() => {
                        info!("Stream setup confirmed by server (ready signal received)");
                    }
                    Ok(metadata) => {
                        error!(
                            "Unexpected ack during setup (offset {}), expected ready signal",
                            metadata.ack_up_to_offset
                        );
                        return Err(ZerobusError::UnexpectedStreamResponseError(format!(
                            "Expected ready signal, got ack for offset {}",
                            metadata.ack_up_to_offset
                        )));
                    }
                    Err(e) => {
                        error!("Failed to parse setup response metadata: {}", e);
                        return Err(ZerobusError::UnexpectedStreamResponseError(format!(
                            "Malformed setup response metadata: {}",
                            e
                        )));
                    }
                }
            }
            Ok(Some(Err(flight_error))) => {
                error!("Stream setup failed: {:?}", flight_error);
                return Err(ZerobusError::CreateStreamError(tonic::Status::from_error(
                    Box::new(flight_error),
                )));
            }
            Ok(None) => {
                error!("Server closed stream during setup without response");
                return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                    "Server closed stream during setup",
                )));
            }
            Err(_timeout) => {
                error!(
                    "Timed out waiting for server setup confirmation ({}ms)",
                    options.connection_timeout_ms
                );
                return Err(ZerobusError::ConnectionTimeout(format!(
                    "Timed out waiting for server setup confirmation ({}ms)",
                    options.connection_timeout_ms
                )));
            }
        }
        Ok((response_stream, batch_tx))
    }
    /// Spawn the background task that consumes server acks via `process_acks`
    /// and, when recovery is enabled, reconnects after retryable failures.
    /// Non-retryable failures (or exhausted retries) close the stream and move
    /// all pending batches to `failed_batches`.
    #[allow(clippy::too_many_arguments)]
    fn spawn_supervisor_task(
        endpoint: String,
        tls_config: Arc<dyn TlsConfig>,
        table_properties: ArrowTableProperties,
        options: ArrowStreamConfigurationOptions,
        headers_provider: Arc<dyn HeadersProvider>,
        batch_tx: BatchSender,
        is_closed: Arc<AtomicBool>,
        last_ack_tx: tokio::sync::watch::Sender<Option<OffsetId>>,
        pending_batches: Arc<Mutex<Vec<PendingBatch>>>,
        failed_batches: Arc<Mutex<Vec<ArrowPayload>>>,
        recovery_attempts: Arc<AtomicU32>,
        server_error_tx: watch::Sender<Option<ZerobusError>>,
        cumulative_records_sent: Arc<AtomicU64>,
        last_acked_records: Arc<AtomicU64>,
        initial_response_stream: Pin<Box<dyn Stream<Item = Result<PutResult, FlightError>> + Send>>,
    ) -> tokio::task::JoinHandle<ZerobusResult<()>> {
        tokio::spawn(async move {
            let ack_timeout = Duration::from_millis(options.server_lack_of_ack_timeout_ms);
            let mut response_stream = initial_response_stream;
            loop {
                if is_closed.load(Ordering::Relaxed) {
                    debug!("Supervisor: Stream closed, exiting");
                    return Ok(());
                }
                // Runs until the stream ends, errors, or the ack timeout fires.
                let result = Self::process_acks(
                    response_stream,
                    Arc::clone(&is_closed),
                    last_ack_tx.clone(),
                    Arc::clone(&pending_batches),
                    ack_timeout,
                    server_error_tx.clone(),
                    Arc::clone(&last_acked_records),
                )
                .await;
                // A close requested while processing acks ends the task whatever
                // the outcome was.
                if is_closed.load(Ordering::Relaxed) {
                    debug!("Supervisor: Stream closed after process_acks, exiting");
                    return result;
                }
                match result {
                    Ok(()) => {
                        debug!("Supervisor: process_acks completed successfully");
                        return Ok(());
                    }
                    Err(ref error) if error.is_retryable() && options.recovery => {
                        // fetch_add returns the pre-increment attempt count.
                        let attempts = recovery_attempts.fetch_add(1, Ordering::Relaxed);
                        if attempts >= options.recovery_retries {
                            error!(
                                attempts = attempts,
                                max_retries = options.recovery_retries,
                                "Supervisor: Max recovery retries exceeded"
                            );
                            is_closed.store(true, Ordering::Relaxed);
                            Self::move_pending_to_failed(&pending_batches, &failed_batches).await;
                            return result;
                        }
                        info!(
                            attempt = attempts + 1,
                            max_retries = options.recovery_retries,
                            error = %error,
                            "Supervisor: Attempting recovery after retriable error"
                        );
                        sleep(Duration::from_millis(options.recovery_backoff_ms)).await;
                        // Clear the published server error before retrying.
                        let _ = server_error_tx.send(None);
                        // Drop the dead sender so concurrent ingest calls fail fast
                        // instead of queueing onto a broken connection.
                        {
                            let mut tx_guard = batch_tx.lock().await;
                            *tx_guard = None;
                        }
                        let reconnect_result = tokio::time::timeout(
                            Duration::from_millis(options.recovery_timeout_ms),
                            Self::reconnect(
                                &endpoint,
                                &tls_config,
                                &table_properties,
                                &options,
                                &headers_provider,
                                &batch_tx,
                                &pending_batches,
                                &cumulative_records_sent,
                                &last_acked_records,
                            ),
                        )
                        .await;
                        match reconnect_result {
                            Ok(Ok(new_response_stream)) => {
                                info!("Supervisor: Recovery successful, resuming");
                                recovery_attempts.store(0, Ordering::Relaxed);
                                response_stream = new_response_stream;
                            }
                            Ok(Err(e)) => {
                                // Feed a synthetic error stream back into the loop so
                                // the failure counts as another recovery attempt.
                                warn!("Supervisor: Reconnection failed: {}", e);
                                response_stream = Box::pin(futures::stream::once(async move {
                                    Err(FlightError::Tonic(Box::new(tonic::Status::unavailable(
                                        "Reconnection failed",
                                    ))))
                                }));
                            }
                            Err(_timeout) => {
                                warn!("Supervisor: Reconnection timed out");
                                response_stream = Box::pin(futures::stream::once(async move {
                                    Err(FlightError::Tonic(Box::new(
                                        tonic::Status::deadline_exceeded("Reconnection timed out"),
                                    )))
                                }));
                            }
                        }
                    }
                    Err(error) => {
                        error!("Supervisor: Non-retriable error, closing stream: {}", error);
                        is_closed.store(true, Ordering::Relaxed);
                        Self::move_pending_to_failed(&pending_batches, &failed_batches).await;
                        return Err(error);
                    }
                }
            }
        })
    }
    /// Rebuild the Flight connection after a failure and replay every batch
    /// that was not fully acknowledged before the disconnect.
    ///
    /// Per-connection record counters are reset to zero and rebased while
    /// replaying; partially-acked batches are sliced so only their unacked
    /// rows are resent.
    #[allow(clippy::too_many_arguments)]
    async fn reconnect(
        endpoint: &str,
        tls_config: &Arc<dyn TlsConfig>,
        table_properties: &ArrowTableProperties,
        options: &ArrowStreamConfigurationOptions,
        headers_provider: &Arc<dyn HeadersProvider>,
        batch_tx: &BatchSender,
        pending_batches: &Arc<Mutex<Vec<PendingBatch>>>,
        cumulative_records_sent: &Arc<AtomicU64>,
        last_acked_records: &Arc<AtomicU64>,
    ) -> ZerobusResult<Pin<Box<dyn Stream<Item = Result<PutResult, FlightError>> + Send>>> {
        let client = Self::create_flight_client(
            endpoint,
            tls_config,
            table_properties,
            options,
            headers_provider,
        )
        .await?;
        let (tx, batch_rx) =
            mpsc::channel::<Result<FlightData, FlightError>>(options.max_inflight_batches);
        let ipc_write_options = make_ipc_write_options(options.ipc_compression)?;
        let schema_fd = schema_to_flight_data(&table_properties.schema, &ipc_write_options);
        let data_stream = tokio_stream::wrappers::ReceiverStream::new(batch_rx);
        // Schema frame first, as on the initial connection.
        let flight_data_stream =
            futures::stream::once(futures::future::ready(Ok(schema_fd))).chain(data_stream);
        let mut flight_client = client;
        let mut response_stream = flight_client
            .do_put(flight_data_stream)
            .await
            .map_err(|e| ZerobusError::CreateStreamError(tonic::Status::from_error(Box::new(e))))?;
        // Wait for the server's ready signal before exposing the new sender.
        let setup_timeout = Duration::from_millis(options.connection_timeout_ms);
        match tokio::time::timeout(setup_timeout, response_stream.next()).await {
            Ok(Some(Ok(put_result))) => {
                match FlightAckMetadata::from_bytes(&put_result.app_metadata) {
                    Ok(metadata) if metadata.is_stream_ready() => {
                        info!("Reconnection confirmed by server (ready signal received)");
                    }
                    Ok(metadata) => {
                        error!(
                            "Unexpected ack during reconnect (offset {}), expected ready signal",
                            metadata.ack_up_to_offset
                        );
                        return Err(ZerobusError::UnexpectedStreamResponseError(format!(
                            "Expected ready signal, got ack for offset {}",
                            metadata.ack_up_to_offset
                        )));
                    }
                    Err(e) => {
                        error!("Failed to parse reconnect response metadata: {}", e);
                        return Err(ZerobusError::UnexpectedStreamResponseError(format!(
                            "Malformed reconnect response metadata: {}",
                            e
                        )));
                    }
                }
            }
            Ok(Some(Err(flight_error))) => {
                error!("Reconnection setup failed: {:?}", flight_error);
                return Err(ZerobusError::CreateStreamError(tonic::Status::from_error(
                    Box::new(flight_error),
                )));
            }
            Ok(None) => {
                error!("Server closed stream during reconnect without response");
                return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                    "Server closed stream during reconnect",
                )));
            }
            Err(_timeout) => {
                error!(
                    "Timed out waiting for server reconnect confirmation ({}ms)",
                    options.connection_timeout_ms
                );
                return Err(ZerobusError::ConnectionTimeout(format!(
                    "Timed out waiting for server reconnect confirmation ({}ms)",
                    options.connection_timeout_ms
                )));
            }
        }
        // Publish the new sender so ingest calls can resume.
        {
            let mut tx_guard = batch_tx.lock().await;
            *tx_guard = Some(tx.clone());
        }
        // Snapshot what the old connection had acked, then reset the
        // per-connection counters for the new one.
        let acked_before_disconnect = last_acked_records.load(Ordering::Relaxed);
        last_acked_records.store(0, Ordering::Relaxed);
        cumulative_records_sent.store(0, Ordering::Relaxed);
        {
            let mut pending = pending_batches.lock().await;
            if !pending.is_empty() {
                info!(
                    batch_count = pending.len(),
                    acked_records = acked_before_disconnect,
                    "Replaying pending batches after recovery"
                );
                let mut new_pending = Vec::with_capacity(pending.len());
                let mut new_cumulative: u64 = 0;
                let mut replay_offset: i64 = 0;
                for pb in pending.drain(..) {
                    // Skip fully-acked batches; trim partially-acked ones.
                    let payload = match slice_batch_for_recovery(&pb, acked_before_disconnect)? {
                        None => {
                            debug!(offset_id = pb.offset_id, "Skipping fully-acked batch");
                            continue;
                        }
                        Some(p) => p,
                    };
                    let (flight_data_messages, num_records) = match &payload {
                        ArrowPayload::Batch(b) => (
                            record_batch_to_flight_data(b, &ipc_write_options).map_err(|e| {
                                ZerobusError::InvalidArgument(format!(
                                    "Failed to encode batch for replay: {e}"
                                ))
                            })?,
                            b.num_rows() as u64,
                        ),
                        ArrowPayload::Ipc(bytes) => {
                            let parsed = ipc_bytes_to_flight_data(bytes).map_err(|e| {
                                ZerobusError::InvalidArgument(format!(
                                    "Failed to encode batch for replay: {e}"
                                ))
                            })?;
                            (parsed.flight_data, parsed.num_rows)
                        }
                    };
                    let fd_count = flight_data_messages.len();
                    for (i, mut fd) in flight_data_messages.into_iter().enumerate() {
                        // Offset metadata goes on the final frame of the batch only.
                        // NOTE(review): replayed batches are tagged with a fresh
                        // 0-based counter rather than their original offset_id —
                        // presumably ack offsets restart on a new connection;
                        // confirm against the server protocol.
                        if i == fd_count - 1 {
                            let metadata = FlightBatchMetadata::new(replay_offset);
                            replay_offset += 1;
                            if let Ok(bytes) = metadata.to_bytes() {
                                fd.app_metadata = bytes.into();
                            }
                        }
                        if tx.send(Ok(fd)).await.is_err() {
                            return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                                "Failed to replay batch during recovery",
                            )));
                        }
                    }
                    // Rebase this batch's record range onto the new connection.
                    let start_record = new_cumulative;
                    let end_record = new_cumulative + num_records;
                    new_cumulative = end_record;
                    new_pending.push(PendingBatch {
                        payload,
                        offset_id: pb.offset_id,
                        start_record,
                        end_record,
                    });
                }
                *pending = new_pending;
                cumulative_records_sent.store(new_cumulative, Ordering::Relaxed);
            }
        }
        Ok(response_stream)
    }
async fn move_pending_to_failed(
pending_batches: &Arc<Mutex<Vec<PendingBatch>>>,
failed_batches: &Arc<Mutex<Vec<ArrowPayload>>>,
) {
let pending: Vec<PendingBatch> = {
let mut pending_guard = pending_batches.lock().await;
std::mem::take(&mut *pending_guard)
};
let mut failed = failed_batches.lock().await;
for pb in pending {
failed.push(pb.payload);
}
}
    /// Consume server acknowledgements until the stream ends, errors, or goes
    /// silent with work outstanding.
    ///
    /// Each ack carries a cumulative record count; pending batches fully
    /// covered by it are dropped and the highest completed offset id is
    /// published on `last_ack_tx`. Going `ack_timeout` without any server
    /// message while batches are pending is treated as a dead connection.
    #[allow(clippy::too_many_arguments)]
    async fn process_acks(
        mut response_stream: Pin<Box<dyn Stream<Item = Result<PutResult, FlightError>> + Send>>,
        is_closed: Arc<AtomicBool>,
        last_ack_tx: tokio::sync::watch::Sender<Option<OffsetId>>,
        pending_batches: Arc<Mutex<Vec<PendingBatch>>>,
        ack_timeout: Duration,
        server_error_tx: watch::Sender<Option<ZerobusError>>,
        last_acked_records: Arc<AtomicU64>,
    ) -> ZerobusResult<()> {
        loop {
            if is_closed.load(Ordering::Relaxed) {
                debug!("Stream closed, stopping ack processor");
                return Ok(());
            }
            let result = tokio::time::timeout(ack_timeout, response_stream.next()).await;
            match result {
                Ok(Some(Ok(put_result))) => {
                    match FlightAckMetadata::from_bytes(&put_result.app_metadata) {
                        Ok(ack) => {
                            let acked_records = ack.ack_up_to_records;
                            debug!(
                                ack_up_to_offset = ack.ack_up_to_offset,
                                ack_up_to_records = acked_records,
                                "Received acknowledgment"
                            );
                            last_acked_records.store(acked_records, Ordering::Relaxed);
                            // Drop every pending batch whose record range the ack
                            // fully covers, tracking the largest offset id dropped.
                            let mut max_acked_offset: Option<OffsetId> = None;
                            {
                                let mut pending = pending_batches.lock().await;
                                pending.retain(|pb| {
                                    if acked_records >= pb.end_record {
                                        max_acked_offset = Some(
                                            max_acked_offset
                                                .map_or(pb.offset_id, |o| o.max(pb.offset_id)),
                                        );
                                        false
                                    } else {
                                        true
                                    }
                                });
                            }
                            if let Some(offset) = max_acked_offset {
                                let _ = last_ack_tx.send(Some(offset));
                            }
                        }
                        Err(e) => {
                            // A malformed ack is logged and skipped, not fatal.
                            warn!("Failed to parse ack metadata: {}", e);
                        }
                    }
                }
                Ok(Some(Err(e))) => {
                    error!("Flight stream error: {}", e);
                    let status: tonic::Status = e.into();
                    let error = ZerobusError::StreamClosedError(status);
                    // Publish the error so blocked ingest/flush callers can see it.
                    let _ = server_error_tx.send(Some(error.clone()));
                    return Err(error);
                }
                Ok(None) => {
                    debug!("Server closed the stream");
                    // NOTE(review): unlike the error branch above, a silent server
                    // close is not published on server_error_tx — confirm whether
                    // that asymmetry is intentional.
                    let error = ZerobusError::StreamClosedError(tonic::Status::unknown(
                        "Server closed the stream",
                    ));
                    return Err(error);
                }
                Err(_timeout) => {
                    // The timeout is fatal only when acks are outstanding; an idle
                    // stream with nothing pending keeps waiting.
                    let pending = pending_batches.lock().await;
                    if !pending.is_empty() {
                        error!(
                            pending_count = pending.len(),
                            "Server ack timeout with pending batches"
                        );
                        let error = ZerobusError::StreamClosedError(
                            tonic::Status::deadline_exceeded("Server ack timeout"),
                        );
                        return Err(error);
                    }
                }
            }
        }
    }
    /// Record the batch as pending, then push its frames onto the live sender.
    ///
    /// The pending entry is registered *before* sending so an acknowledgement
    /// arriving immediately can always find it. Offset metadata is attached to
    /// the final frame of the batch only. On send failure: with recovery
    /// enabled the call still returns `Ok` (the supervisor replays the pending
    /// batch after reconnecting); otherwise the pending entry is rolled back
    /// and the most recent server error, if any, is returned instead of a
    /// generic failure.
    async fn send_flight_data_internal(
        &self,
        payload: ArrowPayload,
        flight_data_messages: Vec<FlightData>,
        offset_id: OffsetId,
        start_record: u64,
        end_record: u64,
    ) -> ZerobusResult<OffsetId> {
        {
            let mut pending = self.pending_batches.lock().await;
            pending.push(PendingBatch {
                payload,
                offset_id,
                start_record,
                end_record,
            });
        }
        // Clone the sender out of the lock so the lock is not held across awaits.
        let sender = {
            let guard = self.batch_tx.lock().await;
            guard.clone()
        };
        let sender = match sender {
            Some(s) => s,
            None => {
                // No live connection: surface the stored server error when present.
                if let Some(server_error) = self.server_error_rx.borrow().clone() {
                    return Err(server_error);
                }
                return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                    "Stream sender is closed",
                )));
            }
        };
        let msg_count = flight_data_messages.len();
        for (i, mut flight_data) in flight_data_messages.into_iter().enumerate() {
            // Only the last frame of the batch carries the offset metadata.
            if i == msg_count - 1 {
                let metadata = FlightBatchMetadata::new(offset_id);
                if let Ok(bytes) = metadata.to_bytes() {
                    flight_data.app_metadata = bytes.into();
                }
            }
            if let Err(e) = sender.send(Ok(flight_data)).await {
                warn!("Send failed: {}", e);
                if self.options.recovery {
                    // Leave the batch pending: the supervisor will replay it.
                    debug!(
                        offset_id = offset_id,
                        "Send failed but recovery enabled - supervisor will handle recovery"
                    );
                    return Ok(offset_id);
                } else {
                    // Roll back the pending entry; this batch will never be acked.
                    {
                        let mut pending = self.pending_batches.lock().await;
                        pending.retain(|pb| pb.offset_id != offset_id);
                    }
                    // Give the ack processor a brief window to publish the real
                    // cause of the failure before reporting a generic error.
                    let _ = tokio::time::timeout(
                        Duration::from_millis(100),
                        self.server_error_rx.clone().changed(),
                    )
                    .await;
                    if let Some(server_error) = self.server_error_rx.borrow().clone() {
                        return Err(server_error);
                    }
                    return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                        "Failed to send batch",
                    )));
                }
            }
        }
        Ok(offset_id)
    }
#[instrument(level = "debug", skip_all, fields(table_name = %self.table_properties.table_name))]
pub async fn ingest_batch(&self, batch: RecordBatch) -> ZerobusResult<OffsetId> {
if self.is_closed.load(Ordering::Relaxed) {
return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
"Stream is closed",
)));
}
if batch.schema() != self.table_properties.schema {
return Err(ZerobusError::InvalidArgument(format!(
"RecordBatch schema does not match stream schema. Expected: {:?}, Got: {:?}",
self.table_properties.schema,
batch.schema()
)));
}
let _guard = self.ingest_mutex.lock().await;
let record_count = batch.num_rows() as u64;
let offset_id = self.offset_generator.next();
let start_record = self
.cumulative_records_sent
.fetch_add(record_count, Ordering::Relaxed);
let end_record = start_record + record_count;
let flight_data_messages = record_batch_to_flight_data(
&batch,
&make_ipc_write_options(self.options.ipc_compression)?,
)?;
debug!(offset_id = offset_id, "Batch queued for ingestion");
self.send_flight_data_internal(
ArrowPayload::Batch(batch),
flight_data_messages,
offset_id,
start_record,
end_record,
)
.await
}
#[instrument(level = "debug", skip_all, fields(table_name = %self.table_properties.table_name))]
pub async fn ingest_ipc_batch(&self, ipc_bytes: Bytes) -> ZerobusResult<OffsetId> {
if self.is_closed.load(Ordering::Relaxed) {
return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
"Stream is closed",
)));
}
if let Some(codec) = self.options.ipc_compression {
return Err(ZerobusError::InvalidArgument(format!(
"ingest_ipc_batch cannot be used when ipc_compression is enabled ({codec:?}). \
Use ingest_batch instead, or disable compression for this stream."
)));
}
let parsed = ipc_bytes_to_flight_data(&ipc_bytes)
.map_err(|e| ZerobusError::InvalidArgument(format!("Invalid Arrow IPC bytes: {e}")))?;
if parsed.schema != *self.table_properties.schema {
return Err(ZerobusError::InvalidArgument(format!(
"IPC batch schema does not match stream schema. Expected: {:?}, Got: {:?}",
self.table_properties.schema, parsed.schema
)));
}
let _guard = self.ingest_mutex.lock().await;
let offset_id = self.offset_generator.next();
let start_record = self
.cumulative_records_sent
.fetch_add(parsed.num_rows, Ordering::Relaxed);
let end_record = start_record + parsed.num_rows;
debug!(offset_id = offset_id, "IPC batch queued for ingestion");
self.send_flight_data_internal(
ArrowPayload::Ipc(ipc_bytes),
parsed.flight_data,
offset_id,
start_record,
end_record,
)
.await
}
    /// Wait until the acknowledged offset reaches `offset_to_wait`, bounded by
    /// `flush_timeout_ms`. Fails if the stream closes, the ack channel drops,
    /// a server error is reported while the stream is closed, or the timeout
    /// elapses. `operation_name` is used only in log and error messages.
    async fn wait_for_offset_internal(
        &self,
        offset_to_wait: OffsetId,
        operation_name: &str,
    ) -> ZerobusResult<()> {
        let flush_timeout = Duration::from_millis(self.options.flush_timeout_ms);
        let mut offset_rx = self.last_ack_tx.subscribe();
        let mut error_rx = self.server_error_rx.clone();
        let wait_future = async {
            loop {
                if self.is_closed.load(Ordering::Relaxed) {
                    return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                        format!("Stream closed during {}", operation_name.to_lowercase()),
                    )));
                }
                // Check the current ack level before (re-)arming the watches.
                let current_ack = *offset_rx.borrow_and_update();
                if let Some(ack_offset) = current_ack {
                    if ack_offset >= offset_to_wait {
                        info!(
                            ack_offset = ack_offset,
                            target_offset = offset_to_wait,
                            "{} completed",
                            operation_name
                        );
                        return Ok(());
                    }
                    debug!(
                        current_ack = ack_offset,
                        target_offset = offset_to_wait,
                        "Waiting for more acks"
                    );
                }
                tokio::select! {
                    result = offset_rx.changed() => {
                        if result.is_err() {
                            return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
                                format!(
                                    "Ack channel closed during {}",
                                    operation_name.to_lowercase()
                                ),
                            )));
                        }
                    }
                    _ = error_rx.changed() => {
                        // Surface a server error only once the stream is actually
                        // closed; errors seen mid-recovery leave the loop waiting
                        // for further acks.
                        if let Some(server_error) = error_rx.borrow().clone() {
                            if self.is_closed.load(Ordering::Relaxed) {
                                return Err(server_error);
                            }
                        }
                    }
                }
            }
        };
        tokio::time::timeout(flush_timeout, wait_future)
            .await
            .map_err(|_| {
                error!("{} timed out", operation_name);
                ZerobusError::StreamClosedError(tonic::Status::deadline_exceeded(format!(
                    "{} timed out",
                    operation_name
                )))
            })?
    }
/// Blocks until every batch queued so far has been acknowledged by the
/// server (bounded by the configured flush timeout).
///
/// Returns `Ok(())` immediately when nothing has been ingested yet.
///
/// # Errors
/// `StreamClosedError` — the stream is closed, closes while waiting, or
/// the wait times out.
#[instrument(level = "debug", skip_all, fields(table_name = %self.table_properties.table_name))]
pub async fn flush(&self) -> ZerobusResult<()> {
    if self.is_closed.load(Ordering::Relaxed) {
        return Err(ZerobusError::StreamClosedError(tonic::Status::internal(
            "Cannot flush: stream is closed",
        )));
    }
    // No offsets handed out yet means there is nothing to wait for.
    let Some(target_offset) = self.offset_generator.last() else {
        debug!("No batches to flush");
        return Ok(());
    };
    self.wait_for_offset_internal(target_offset, "Flush").await
}
/// Waits until the server has acknowledged the batch identified by
/// `offset` (bounded by the configured flush timeout).
///
/// Thin wrapper over the internal wait loop; the operation label below
/// only affects log and error messages.
pub async fn wait_for_offset(&self, offset: OffsetId) -> ZerobusResult<()> {
    let operation = "Waiting for acknowledgement";
    self.wait_for_offset_internal(offset, operation).await
}
/// Closes the stream: flushes outstanding batches, then tears down the
/// send channel and the background receiver task. Idempotent — returns
/// `Ok(())` immediately if already closed.
///
/// Close is best-effort: if the flush fails, the still-unacked batches
/// are moved to the failed set (retrievable via `get_unacked_batches`)
/// and the flush error is NOT propagated to the caller.
#[instrument(level = "debug", skip_all, fields(table_name = %self.table_properties.table_name))]
pub async fn close(&mut self) -> ZerobusResult<()> {
    if self.is_closed.load(Ordering::Relaxed) {
        return Ok(());
    }
    info!(
        table_name = %self.table_properties.table_name,
        "Closing Arrow Flight stream"
    );
    // Flush while the stream is still open; only afterwards flip the
    // closed flag, so the flush's internal wait is not cut short.
    if let Err(e) = self.flush().await {
        warn!(
            "Flush failed during close: {}. Moving pending batches to failed.",
            e
        );
        Self::move_pending_to_failed(&self.pending_batches, &self.failed_batches).await;
    }
    self.is_closed.store(true, Ordering::Relaxed);
    {
        // Dropping the sender closes the outgoing flight-data channel.
        let mut tx = self.batch_tx.lock().await;
        *tx = None;
    }
    {
        // Stop the background ack-receiver task.
        let mut task = self.receiver_task.lock().await;
        if let Some(t) = task.take() {
            t.abort();
        }
    }
    Ok(())
}
/// Returns every batch that was never acknowledged by the server —
/// first the still-pending batches, then the failed ones — materialized
/// back into `RecordBatch`es so the caller can retry them elsewhere.
///
/// # Errors
/// * `InvalidStateError` — the stream is still open (close it first).
/// * `InvalidArgument` — a stored payload could not be materialized.
pub async fn get_unacked_batches(&self) -> ZerobusResult<Vec<RecordBatch>> {
    if !self.is_closed.load(Ordering::Relaxed) {
        error!(
            table_name = %self.table_properties.table_name,
            "Cannot get unacked batches from an active stream. Stream must be closed first."
        );
        return Err(ZerobusError::InvalidStateError(
            "Cannot get unacked batches from an active stream. Stream must be closed first."
                .to_string(),
        ));
    }
    let mut batches = Vec::new();
    // Pending batches first, keeping their original ingestion order.
    let pending = self.pending_batches.lock().await;
    for pb in pending.iter() {
        let batch = pb.payload.materialize().map_err(|e| {
            ZerobusError::InvalidArgument(format!(
                "unacked batch at offset_id {} could not be materialised: {e}",
                pb.offset_id
            ))
        })?;
        batches.push(batch);
    }
    drop(pending);
    // Then anything that was explicitly moved to the failed set.
    let failed = self.failed_batches.lock().await;
    for (i, payload) in failed.iter().enumerate() {
        let batch = payload.materialize().map_err(|e| {
            ZerobusError::InvalidArgument(format!(
                "failed batch at index {i} could not be materialised: {e}"
            ))
        })?;
        batches.push(batch);
    }
    drop(failed);
    Ok(batches)
}
/// Whether the stream has been closed (explicitly or by drop).
pub fn is_closed(&self) -> bool {
    self.is_closed.load(Ordering::Relaxed)
}
/// Fully-qualified name of the destination table.
pub fn table_name(&self) -> &str {
    self.table_properties.table_name.as_str()
}
/// Arrow schema this stream was opened with.
pub fn schema(&self) -> &Arc<ArrowSchema> {
    &self.table_properties.schema
}
/// Table name and schema bundled together.
pub fn table_properties(&self) -> &ArrowTableProperties {
    &self.table_properties
}
/// Configuration options the stream was created with.
pub fn options(&self) -> &ArrowStreamConfigurationOptions {
    &self.options
}
/// Shared handle to the provider of per-request headers.
pub(crate) fn headers_provider(&self) -> Arc<dyn HeadersProvider> {
    self.headers_provider.clone()
}
}
impl Drop for ZerobusArrowStream {
    /// Best-effort synchronous cleanup: mark the stream closed and abort
    /// the background receiver task. Cannot `.await` here, so the task
    /// lock is only tried; if it is contended the abort is skipped.
    fn drop(&mut self) {
        self.is_closed.store(true, Ordering::Relaxed);
        let Ok(mut guard) = self.receiver_task.try_lock() else {
            return;
        };
        if let Some(handle) = guard.take() {
            handle.abort();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field};

    /// Sanity-checks construction and field access of `ArrowTableProperties`.
    #[test]
    fn test_arrow_table_properties() {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ];
        let props = ArrowTableProperties {
            table_name: String::from("catalog.schema.table"),
            schema: Arc::new(ArrowSchema::new(fields)),
        };
        assert_eq!(props.table_name, "catalog.schema.table");
        assert_eq!(props.schema.fields().len(), 2);
    }
}