use crate::collection::types::Collection;
use crate::point::Point;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Notify;
/// Tuning knobs for a [`StreamIngester`].
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Capacity of the bounded channel between producers and the drain task.
    pub buffer_size: usize,
    /// Number of buffered points that triggers an immediate flush.
    pub batch_size: usize,
    /// Period, in milliseconds, of the timer that flushes a partially filled batch.
    pub flush_interval_ms: u64,
}

impl Default for StreamingConfig {
    /// A 10 000-point buffer, flushed in batches of 128 or every 50 ms.
    fn default() -> Self {
        let buffer_size = 10_000;
        let batch_size = 128;
        let flush_interval_ms = 50;
        Self {
            buffer_size,
            batch_size,
            flush_interval_ms,
        }
    }
}
/// Identifies the path a write took into a collection.
///
/// NOTE(review): the `dead_code` allow suggests this enum is not (yet) used in every
/// build configuration; confirm it is still needed before keeping the allow.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WriteMode {
    /// Presumably a write submitted through the regular, non-streaming API (confirm).
    Api,
    /// Presumably a write that arrived through the streaming ingester (confirm).
    Streaming,
}
/// Reasons a non-blocking send into the streaming pipeline can be refused.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum BackpressureError {
    /// The bounded buffer has no free slot for the point(s) being sent.
    #[error("streaming buffer is full (backpressure)")]
    BufferFull,
    /// Streaming has not been set up for the target collection.
    /// NOTE(review): not produced in this module; presumably raised by callers.
    #[error("streaming is not configured on this collection")]
    NotConfigured,
    /// The receiving side of the buffer is closed, so no further points can be accepted.
    #[error("streaming drain task has exited; the ingestion pipeline is dead")]
    DrainTaskDead,
}
/// Non-blocking front end for streaming point ingestion into a collection.
///
/// Points pushed through `try_send` / `try_send_batch` land in a bounded Tokio channel
/// that a background task drains in batches.
pub struct StreamIngester {
    /// Producer side of the bounded channel feeding the drain task.
    sender: mpsc::Sender<Point>,
    /// Configuration this ingester was created with.
    config: StreamingConfig,
    /// Handle to the drain task; `None` once `shutdown` has taken it.
    drain_handle: Option<tokio::task::JoinHandle<()>>,
    /// Signal that asks the drain task to flush what is queued and exit.
    shutdown: Arc<Notify>,
}
impl StreamIngester {
#[must_use]
pub(crate) fn new(collection: Collection, config: StreamingConfig) -> Self {
let (tx, rx) = mpsc::channel(config.buffer_size);
let shutdown = Arc::new(Notify::new());
let drain_handle = tokio::spawn(drain_loop(
collection,
rx,
config.batch_size,
config.flush_interval_ms,
Arc::clone(&shutdown),
));
Self {
sender: tx,
config,
drain_handle: Some(drain_handle),
shutdown,
}
}
pub fn try_send(&self, point: Point) -> Result<(), BackpressureError> {
self.sender.try_send(point).map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => BackpressureError::BufferFull,
mpsc::error::TrySendError::Closed(_) => BackpressureError::DrainTaskDead,
})
}
pub fn try_send_batch(
&self,
points: Vec<crate::point::Point>,
) -> Result<usize, BackpressureError> {
let count = points.len();
for point in points {
self.sender.try_send(point).map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => BackpressureError::BufferFull,
mpsc::error::TrySendError::Closed(_) => BackpressureError::DrainTaskDead,
})?;
}
Ok(count)
}
#[must_use]
pub fn config(&self) -> &StreamingConfig {
&self.config
}
pub async fn shutdown(mut self) {
self.shutdown.notify_one();
if let Some(handle) = self.drain_handle.take() {
let _ = handle.await;
}
}
}
impl Drop for StreamIngester {
    /// Aborts the drain task unless `shutdown` already took ownership of it.
    ///
    /// Unlike `shutdown`, this does not flush. Points still sitting in the channel are
    /// discarded when the aborted task drops its receiver.
    ///
    /// NOTE(review): per Tokio's docs, `spawn_blocking` work cannot be aborted. An
    /// upsert already running on the blocking pool will therefore still complete, but
    /// the delta-buffer update that `flush_batch` performs afterwards will not run.
    fn drop(&mut self) {
        if let Some(task) = &self.drain_handle {
            task.abort();
        }
    }
}
/// Body of the background drain task: batches points from `rx` and upserts them into
/// `collection`.
///
/// The current batch is flushed when any of the following happens:
/// - it reaches `batch_size` points;
/// - the flush interval elapses;
/// - `shutdown` is notified (after draining whatever is still queued);
/// - the channel reports closed-and-empty.
///
/// A `flush_interval_ms` of 0 is treated as 1 ms. `tokio::time::interval` panics on a
/// zero period, and a panic here would kill the task, after which every send fails
/// with `DrainTaskDead`.
// NOTE(review): the lint allow is presumably needed because the `tokio::select!`
// expansion inflates the cognitive-complexity score; confirm before removing it.
#[allow(clippy::cognitive_complexity)]
async fn drain_loop(
    collection: Collection,
    mut rx: mpsc::Receiver<Point>,
    batch_size: usize,
    flush_interval_ms: u64,
    shutdown: Arc<Notify>,
) {
    let mut batch: Vec<Point> = Vec::with_capacity(batch_size);
    let period = std::time::Duration::from_millis(flush_interval_ms.max(1));
    let mut interval = tokio::time::interval(period);
    // After a slow flush, schedule the next tick one full period later instead of
    // bursting the missed ticks back-to-back, which would flush undersized batches.
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    // A Tokio interval's first tick completes immediately; consume it so the first
    // timed flush happens one full period after start-up.
    interval.tick().await;
    loop {
        tokio::select! {
            () = shutdown.notified() => {
                drain_on_shutdown(&collection, &mut rx, &mut batch, batch_size).await;
                break;
            }
            _ = interval.tick() => {
                flush_if_non_empty(&collection, &mut batch).await;
            }
            msg = rx.recv() => {
                if !handle_received_point(&collection, &mut batch, batch_size, &mut interval, msg).await {
                    break;
                }
            }
        }
    }
}
/// Shutdown path: moves every point currently queued in `rx` into `batch` and flushes.
///
/// Full batches are flushed as soon as they reach `batch_size`, and any remainder is
/// flushed at the end. Draining stops at the first `try_recv` error, i.e. when the
/// channel is empty (or closed).
async fn drain_on_shutdown(
    collection: &Collection,
    rx: &mut mpsc::Receiver<Point>,
    batch: &mut Vec<Point>,
    batch_size: usize,
) {
    let queued = std::iter::from_fn(|| rx.try_recv().ok());
    for point in queued {
        batch.push(point);
        if batch.len() >= batch_size {
            flush_batch(collection, batch).await;
        }
    }
    flush_if_non_empty(collection, batch).await;
}
/// Flushes `batch` into `collection` unless it is empty, avoiding a pointless upsert.
async fn flush_if_non_empty(collection: &Collection, batch: &mut Vec<Point>) {
    if batch.is_empty() {
        return;
    }
    flush_batch(collection, batch).await;
}
/// Processes one result of `rx.recv()` inside the drain loop.
///
/// `Some(point)` is appended to `batch`. When the batch reaches `batch_size` it is
/// flushed and `interval` is reset, so the timer does not fire again right after a
/// size-triggered flush. Returns `true` so the loop keeps running.
///
/// `None` means the channel is closed and empty. Any partial batch is flushed and
/// `false` is returned so the loop exits.
async fn handle_received_point(
    collection: &Collection,
    batch: &mut Vec<Point>,
    batch_size: usize,
    interval: &mut tokio::time::Interval,
    msg: Option<Point>,
) -> bool {
    match msg {
        None => {
            flush_if_non_empty(collection, batch).await;
            false
        }
        Some(point) => {
            batch.push(point);
            let batch_is_full = batch.len() >= batch_size;
            if batch_is_full {
                flush_batch(collection, batch).await;
                interval.reset();
            }
            true
        }
    }
}
/// Upserts the accumulated `batch` into `collection`, leaving `batch` empty.
///
/// The upsert runs via `tokio::task::spawn_blocking`, presumably because
/// `Collection::upsert` is synchronous (NOTE(review): confirm).
///
/// When `collection.delta_buffer.is_active()` is true, the `(id, vector)` pairs of the
/// flushed points are added to the delta buffer, but only after the upsert succeeds.
///
/// Failures are logged and the flushed points are dropped; there is no retry.
async fn flush_batch(collection: &Collection, batch: &mut Vec<Point>) {
    // `drain` keeps `batch`'s allocation (preallocated to `batch_size` by the drain
    // loop). The previous `mem::take` left a zero-capacity Vec that had to regrow
    // on every batch.
    let points: Vec<Point> = batch.drain(..).collect();
    // Snapshot (id, vector) pairs before `points` is moved into the blocking task.
    let delta_entries: Vec<(u64, Vec<f32>)> = if collection.delta_buffer.is_active() {
        points.iter().map(|p| (p.id, p.vector.clone())).collect()
    } else {
        Vec::new()
    };
    let coll = collection.clone();
    let result = tokio::task::spawn_blocking(move || coll.upsert(points)).await;
    match result {
        Ok(Ok(())) => {
            if !delta_entries.is_empty() {
                collection.delta_buffer.extend(delta_entries);
            }
        }
        Ok(Err(e)) => {
            tracing::error!("Streaming drain flush failed: {e}");
        }
        Err(e) => {
            tracing::error!("Streaming drain task panicked: {e}");
        }
    }
}