use async_trait::async_trait;
use futures::TryStreamExt;
use mongodb::bson::Document;
use mongodb::Cursor;
use tracing::{debug, info};
use crate::error::Result;
/// Asynchronous, pull-based source of BSON documents delivered in batches.
///
/// The `Send` supertrait means every implementor can be sent across threads,
/// which also allows it to be used as `Box<dyn StreamingQuery>` in
/// multi-threaded async code.
#[async_trait]
pub trait StreamingQuery: Send {
    /// Produces the next batch of documents.
    ///
    /// `Ok(None)` signals that no further batches are available (this is the
    /// convention followed by `CursorStreamingQuery`).
    async fn next_batch(&mut self) -> Result<Option<Vec<Document>>>;

    /// Releases whatever resources back the query.
    async fn close(&mut self) -> Result<()>;
}
/// [`StreamingQuery`] implementation backed by a MongoDB [`Cursor`], handing
/// documents out in groups sized by `batch_size`.
pub struct CursorStreamingQuery {
    /// Human-readable label used in log messages (e.g. `"Find"`, `"Aggregate"`).
    query_type: &'static str,
    /// The live cursor; becomes `None` once the query is exhausted, has
    /// failed, or has been closed.
    cursor: Option<Cursor<Document>>,
    /// Requested number of documents per `next_batch` call.
    batch_size: u32,
    /// Running count of documents returned to the caller so far.
    total_fetched: u64,
    /// Set once the query is exhausted, has failed, or was explicitly closed.
    closed: bool,
}
impl CursorStreamingQuery {
    /// Wraps `cursor` in a streaming query.
    ///
    /// `batch_size` controls how many documents each `next_batch` call
    /// collects, and `query_type` labels this query in log output. The query
    /// starts open with zero documents fetched.
    pub fn new(cursor: Cursor<Document>, batch_size: u32, query_type: &'static str) -> Self {
        let cursor = Some(cursor);
        Self {
            query_type,
            batch_size,
            cursor,
            closed: false,
            total_fetched: 0,
        }
    }
}
#[async_trait]
impl StreamingQuery for CursorStreamingQuery {
    /// Pulls up to `batch_size` documents from the underlying cursor.
    ///
    /// Returns `Ok(Some(batch))` with a non-empty batch, or `Ok(None)` once
    /// the cursor is exhausted or the query has been closed. When `None` is
    /// first observed, the cursor is released and the query is marked closed.
    ///
    /// A `batch_size` of `0` is treated as `1`. Otherwise the empty batch
    /// would be indistinguishable from an exhausted cursor, and the query
    /// would be closed without ever yielding its documents.
    ///
    /// # Errors
    ///
    /// Propagates any error from the cursor. The cursor is dropped and the
    /// query is marked closed first, so subsequent calls return `Ok(None)`.
    async fn next_batch(&mut self) -> Result<Option<Vec<Document>>> {
        // Cap on up-front allocation: a very large `batch_size` should not
        // reserve memory for documents that may never arrive; the Vec grows
        // on demand past this point.
        const MAX_PREALLOCATED: usize = 1024;

        if self.closed {
            return Ok(None);
        }
        let cursor = match self.cursor.as_mut() {
            Some(c) => c,
            None => return Ok(None),
        };

        // Never let a zero batch size be mistaken for exhaustion.
        let limit = self.batch_size.max(1) as usize;
        let mut batch = Vec::with_capacity(limit.min(MAX_PREALLOCATED));
        for _ in 0..limit {
            match cursor.try_next().await {
                Ok(Some(doc)) => batch.push(doc),
                Ok(None) => break,
                Err(e) => {
                    // Release the cursor so a failed query cannot be resumed.
                    self.cursor = None;
                    self.closed = true;
                    return Err(e.into());
                }
            }
        }

        if batch.is_empty() {
            debug!(
                "{} streaming query exhausted after {} documents",
                self.query_type, self.total_fetched
            );
            self.cursor = None;
            self.closed = true;
            Ok(None)
        } else {
            self.total_fetched += batch.len() as u64;
            debug!(
                "Fetched batch of {} documents (total: {})",
                batch.len(),
                self.total_fetched
            );
            Ok(Some(batch))
        }
    }

    /// Drops the cursor and marks the query closed.
    ///
    /// Calling this more than once is a no-op; the summary is logged only on
    /// the first call that actually closes the query.
    async fn close(&mut self) -> Result<()> {
        if !self.closed {
            self.cursor = None;
            self.closed = true;
            info!(
                "Closed {} streaming query after fetching {} documents",
                self.query_type, self.total_fetched
            );
        }
        Ok(())
    }
}
impl Drop for CursorStreamingQuery {
    /// Emits a debug note when the query is dropped without `close()`, having
    /// not already been exhausted or failed, and releases the cursor.
    fn drop(&mut self) {
        if self.closed {
            return;
        }
        debug!("CursorStreamingQuery dropped without explicit close");
        drop(self.cursor.take());
    }
}
pub type FindStreamingQuery = CursorStreamingQuery;
pub type AggregateStreamingQuery = CursorStreamingQuery;
impl FindStreamingQuery {
pub fn new_find(cursor: Cursor<Document>, batch_size: u32) -> Self {
Self::new(cursor, batch_size, "Find")
}
}
impl AggregateStreamingQuery {
pub fn new_aggregate(cursor: Cursor<Document>, batch_size: u32) -> Self {
Self::new(cursor, batch_size, "Aggregate")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks. `StreamingQuery` must stay object-safe, and
    /// `CursorStreamingQuery` must implement it and coerce to
    /// `Box<dyn StreamingQuery>`. A regression in either fails the build.
    #[test]
    fn test_streaming_query_trait_object() {
        fn _accepts_streaming_query(_query: Box<dyn StreamingQuery>) {}

        fn assert_streaming_query<T: StreamingQuery>() {}
        assert_streaming_query::<CursorStreamingQuery>();

        // Unsizing coercion from the concrete type to the trait object.
        fn _boxes(query: CursorStreamingQuery) -> Box<dyn StreamingQuery> {
            Box::new(query)
        }
    }
}