use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion::catalog::TableProvider;
use datafusion::catalog::streaming::StreamingTable;
use std::sync::Arc;
use datafusion::error::{DataFusionError, Result as DataFusionResult};
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::streaming::PartitionStream;
use krishiv_connectors::Source;
use krishiv_connectors::kafka::{KafkaConfig, KafkaSource};
/// Auto-commit interval (milliseconds) handed to the Kafka consumer when the
/// active durability profile does not require manual offset commits.
const STREAMING_AUTO_COMMIT_MS: u64 = 1_000;

/// Decides whether Kafka offsets are auto-committed, based on the
/// `KRISHIV_DURABILITY_PROFILE` environment variable.
///
/// An unset or unparseable variable falls back to
/// `DurabilityProfile::DevLocal`.
///
/// Returns `Some(STREAMING_AUTO_COMMIT_MS)` when the resolved profile permits
/// auto-commit. Returns `None` when `requires_manual_kafka_commit` says offsets
/// must be committed explicitly.
pub(crate) fn kafka_auto_commit_interval_ms() -> Option<u64> {
    let profile = match std::env::var("KRISHIV_DURABILITY_PROFILE") {
        Ok(raw) => raw
            .parse()
            .unwrap_or(krishiv_common::DurabilityProfile::DevLocal),
        Err(_) => krishiv_common::DurabilityProfile::DevLocal,
    };
    let manual = krishiv_common::requires_manual_kafka_commit(profile);
    (!manual).then_some(STREAMING_AUTO_COMMIT_MS)
}
/// A single DataFusion partition backed by a shared [`KafkaSource`].
///
/// The source sits behind an async (`tokio`) mutex because the reader task
/// spawned by `execute` holds the guard across `.await` points. The handle of
/// the most recently spawned reader task is stored in `consumer_task`.
pub(crate) struct KafkaPartitionStream {
    /// Schema every emitted batch is projected to.
    schema: SchemaRef,
    /// Kafka source shared with the background reader task(s).
    source: Arc<tokio::sync::Mutex<KafkaSource>>,
    /// Handle of the most recently spawned reader task, if any.
    consumer_task: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}

impl KafkaPartitionStream {
    /// Wraps `source` so it can be driven as a streaming partition that
    /// produces batches conforming to `schema`. No reader task is started
    /// until `execute` is called.
    pub fn new(schema: SchemaRef, source: KafkaSource) -> Self {
        let shared_source = Arc::new(tokio::sync::Mutex::new(source));
        let no_task_yet: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>> =
            std::sync::Mutex::default();
        Self {
            schema,
            source: shared_source,
            consumer_task: no_task_yet,
        }
    }
}

impl std::fmt::Debug for KafkaPartitionStream {
    /// Emits only the type name, without any field values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("KafkaPartitionStream");
        builder.finish()
    }
}
impl PartitionStream for KafkaPartitionStream {
    /// Returns the schema every emitted batch is projected to.
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    /// Spawns a tokio task that reads from the shared Kafka source and returns
    /// a stream over the batches it produces.
    ///
    /// Behavior of the reader task:
    /// - Each non-empty batch is projected with [`project_batch`] and sent
    ///   through a bounded channel of capacity 64.
    /// - When auto-commit is disabled (`kafka_auto_commit_interval_ms()` is
    ///   `None`), the source's current offset is committed only after a
    ///   successfully projected batch has been accepted by the channel.
    /// - The task stops when the returned stream is dropped, after a read
    ///   error, or after a projection error. Each error is forwarded to the
    ///   stream first, on a best-effort basis.
    ///
    /// Calling `execute` again replaces the stored task handle. The previous
    /// task is not aborted; it exits on its own once its receiver is dropped.
    fn execute(&self, _ctx: Arc<datafusion::execution::TaskContext>) -> SendableRecordBatchStream {
        let source = self.source.clone();
        let schema = self.schema.clone();
        let manual_commit = kafka_auto_commit_interval_ms().is_none();
        let (tx, rx) = tokio::sync::mpsc::channel::<Result<RecordBatch, DataFusionError>>(64);
        let task = tokio::spawn(async move {
            // Stop as soon as the consuming side of the stream has gone away.
            while !tx.is_closed() {
                // Release the source lock before handling the result so the
                // commit below can re-acquire it.
                let res = {
                    let mut guard = source.lock().await;
                    guard.read_batch().await
                };
                match res {
                    Ok(Some(batch)) if batch.num_rows() == 0 => {
                        // Nothing to forward or commit. Yield so a source that
                        // keeps returning empty batches cannot monopolise the
                        // worker thread.
                        tokio::task::yield_now().await;
                    }
                    Ok(Some(batch)) => {
                        let projected = match project_batch(&batch, &schema) {
                            Ok(projected) => projected,
                            Err(e) => {
                                // These rows were never delivered, so their
                                // offset must not be committed. Stop reading
                                // as well: committing any later batch would
                                // move the offset past this one.
                                let _ = tx
                                    .send(Err(DataFusionError::ArrowError(Box::new(e), None)))
                                    .await;
                                break;
                            }
                        };
                        if tx.send(Ok(projected)).await.is_err() {
                            // Receiver dropped: the batch was not delivered.
                            break;
                        }
                        if manual_commit {
                            let guard = source.lock().await;
                            guard.commit_current_offset();
                        }
                    }
                    Ok(None) => {
                        // No data available right now; back off briefly before polling again.
                        tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
                    }
                    Err(e) => {
                        let _ = tx.send(Err(DataFusionError::External(Box::new(e)))).await;
                        break;
                    }
                }
            }
        });
        // The mutex only guards a task handle, so a poisoned lock is safe to reuse.
        *self.consumer_task.lock().unwrap_or_else(|p| p.into_inner()) = Some(task);
        let recv_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
        Box::pin(RecordBatchStreamAdapter::new(
            self.schema.clone(),
            recv_stream,
        ))
    }
}
/// Conforms a decoded Kafka batch to the table's declared `schema`.
///
/// For each field of `schema`:
/// - If `batch` has a column with the same name, that column is cast to the
///   field's data type.
/// - Otherwise the field is filled with nulls.
///
/// Columns of `batch` that do not appear in `schema` are dropped. The result
/// always has `batch.num_rows()` rows, even when `schema` has no fields.
///
/// # Errors
///
/// - Returns [`arrow::error::ArrowError::CastError`], naming the column, when
///   a cast fails.
/// - Returns the error from [`RecordBatch::try_new_with_options`] when the
///   projected columns do not satisfy `schema`.
pub(crate) fn project_batch(
    batch: &RecordBatch,
    schema: &SchemaRef,
) -> Result<RecordBatch, arrow::error::ArrowError> {
    let num_rows = batch.num_rows();
    // Fetch the source schema once instead of once per target field.
    let batch_schema = batch.schema();
    let cols = schema
        .fields()
        .iter()
        .map(|field| match batch_schema.index_of(field.name()) {
            Ok(idx) => {
                let src = batch.column(idx);
                arrow::compute::cast(src, field.data_type()).map_err(|e| {
                    arrow::error::ArrowError::CastError(format!(
                        "Kafka column '{}': cast from {} to {} failed: {e}",
                        field.name(),
                        src.data_type(),
                        field.data_type(),
                    ))
                })
            }
            Err(_) => Ok(arrow::array::new_null_array(field.data_type(), num_rows)),
        })
        .collect::<Result<Vec<_>, _>>()?;
    // State the row count explicitly. `RecordBatch::try_new` rejects a batch
    // with no columns, which would make a zero-field schema fail even though
    // `batch` has rows.
    let options = arrow::record_batch::RecordBatchOptions::new().with_row_count(Some(num_rows));
    RecordBatch::try_new_with_options(schema.clone(), cols, &options)
}
pub fn create_kafka_streaming_table(
schema: SchemaRef,
config: KafkaConfig,
) -> DataFusionResult<Arc<dyn TableProvider>> {
let config = match kafka_auto_commit_interval_ms() {
Some(ms) => config.with_auto_commit(ms),
None => config,
};
let source = KafkaSource::new(config).map_err(|e| DataFusionError::External(Box::new(e)))?;
let partition = Arc::new(KafkaPartitionStream::new(schema.clone(), source));
let table = StreamingTable::try_new(schema, vec![partition])?;
Ok(Arc::new(table))
}