use arrow_flight::FlightData;
use datafusion::arrow::array::RecordBatch;
use datafusion::common::runtime::SpawnedTask;
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool};
use futures::{Stream, StreamExt};
use std::sync::Arc;
use tokio_stream::wrappers::ReceiverStream;
/// Drives every stream in `inner` on its own spawned Tokio task and merges their items into a
/// single output stream, fed through a bounded channel with `queue_size` slots.
///
/// Every element pulled from an input gets its own memory reservation registered on `pool`
/// (under a [`MemoryConsumer`] named `"NetworkBoundary"`). The reservation is sized with
/// [`MemoryFootPrint::get_memory_size`] for `Ok` elements and left empty for `Err` elements.
/// It travels with the element through the channel and is released when the merged stream
/// yields that element. Elements a task is still holding while it waits for channel capacity
/// therefore remain accounted for. Reservations are taken with the infallible `grow`, so they
/// are tracked by the pool but never rejected for exceeding a pool limit.
///
/// Items from different inputs are interleaved in whatever order the tasks deliver them. The
/// relative order of items coming from the same input is preserved.
///
/// Each task stops when its input is exhausted or when the receiving side is gone. The task
/// handles are owned by the returned stream, so dropping the stream also drops the tasks and
/// the channel, along with any reservations still queued in it.
///
/// A `queue_size` of `0` is treated as `1`, because Tokio's bounded channel panics on a zero
/// capacity.
pub(crate) fn spawn_select_all<T, El, Err>(
    inner: Vec<T>,
    pool: Arc<dyn MemoryPool>,
    queue_size: usize,
) -> impl Stream<Item = Result<El, Err>>
where
    T: Stream<Item = Result<El, Err>> + Send + Unpin + 'static,
    El: MemoryFootPrint + Send + 'static,
    Err: Send + 'static,
{
    // `tokio::sync::mpsc::channel` panics when asked for a zero-capacity buffer.
    let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
    let mut tasks = Vec::with_capacity(inner.len());
    for mut t in inner {
        let tx = tx.clone();
        let pool = Arc::clone(&pool);
        let consumer = MemoryConsumer::new("NetworkBoundary");
        tasks.push(SpawnedTask::spawn(async move {
            loop {
                // `biased` polls `tx.closed()` first, so the task stops pulling from its
                // input as soon as the receiver has gone away.
                let msg = tokio::select! {
                    biased;
                    _ = tx.closed() => return,
                    msg = t.next() => msg
                };
                let Some(msg) = msg else { return };
                // One reservation per element. It is sent alongside the element and dropped
                // when the element is yielded downstream (or with the channel).
                let reservation = consumer.clone_with_new_id().register(&pool);
                if let Ok(msg) = &msg {
                    reservation.grow(msg.get_memory_size());
                }
                // While waiting for channel capacity the reservation is still held, so
                // back-pressured elements stay accounted for.
                if tx.send((msg, reservation)).await.is_err() {
                    return;
                }
            }
        }));
    }
    // Release this function's own sender so the channel closes once every task has
    // finished; otherwise the merged stream would never terminate.
    drop(tx);
    ReceiverStream::new(rx).map(move |(msg, _reservation)| {
        // Capturing the task handles ties their lifetime to the returned stream.
        let _ = &tasks;
        msg
    })
}
/// Estimates how many bytes a value occupies.
///
/// [`spawn_select_all`] uses this estimate to size the memory reservation it takes for each
/// element it forwards.
pub(crate) trait MemoryFootPrint {
    /// Returns the estimated size of `self`, in bytes.
    fn get_memory_size(&self) -> usize;
}
impl MemoryFootPrint for RecordBatch {
fn get_memory_size(&self) -> usize {
self.get_array_memory_size()
}
}
// A Flight message is measured by the combined length of its three byte buffers
// (header, body and application metadata).
impl MemoryFootPrint for FlightData {
    fn get_memory_size(&self) -> usize {
        let header_len = self.data_header.len();
        let body_len = self.data_body.len();
        let metadata_len = self.app_metadata.len();
        header_len + body_len + metadata_len
    }
}
#[cfg(test)]
mod tests {
    use super::{MemoryFootPrint, spawn_select_all};
    use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool};
    use std::error::Error;
    use std::sync::Arc;
    use tokio_stream::StreamExt;

    /// Two inputs carrying 1 + 2 + 3 + 4 + 5 = 15 "bytes" fit entirely in a 5-slot queue.
    /// The test checks that:
    /// - all 15 bytes are reserved while the elements are queued;
    /// - each yielded element releases its own reservation;
    /// - dropping the stream releases whatever is still queued.
    #[tokio::test]
    async fn memory_reservation() -> Result<(), Box<dyn Error>> {
        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
        let mut stream = spawn_select_all(
            vec![
                futures::stream::iter(vec![Ok::<_, String>(1), Ok(2), Ok(3)]),
                futures::stream::iter(vec![Ok(4), Ok(5)]),
            ],
            Arc::clone(&pool),
            5,
        );
        // `#[tokio::test]` runs on a current-thread runtime. The short sleep yields so the
        // spawned producer tasks can run until they finish or block.
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let reserved = pool.reserved();
        // Every element is sitting in the channel: 1 + 2 + 3 + 4 + 5.
        assert_eq!(reserved, 15);
        // The test expects the first input's elements to arrive first (presumably because its
        // task was spawned, and therefore ran, first on the single-threaded runtime).
        for i in [1, 2, 3] {
            let n = stream.next().await.unwrap()?;
            assert_eq!(i, n)
        }
        let reserved = pool.reserved();
        // Only 4 and 5 are still queued.
        assert_eq!(reserved, 9);
        // Dropping the stream drops the channel together with the reservations queued in it.
        drop(stream);
        let reserved = pool.reserved();
        assert_eq!(reserved, 0);
        Ok(())
    }

    /// With a single-slot queue, the producer task holds one extra element (and its
    /// reservation) while blocked on `send`. The expected totals therefore count the element
    /// in the channel plus the one waiting to enter it.
    #[tokio::test]
    async fn memory_reservation_backpressure() -> Result<(), Box<dyn Error>> {
        let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
        let mut stream = spawn_select_all(
            vec![futures::stream::iter(vec![
                Ok::<_, String>(1),
                Ok(2),
                Ok(3),
            ])],
            Arc::clone(&pool),
            1,
        );
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let reserved = pool.reserved();
        // 1 is in the channel and 2 is held by the task waiting to send: 1 + 2.
        assert_eq!(reserved, 3);
        let n = stream.next().await.unwrap()?;
        assert_eq!(n, 1);
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let reserved = pool.reserved();
        // 2 moved into the channel and the task now holds 3: 2 + 3.
        assert_eq!(reserved, 5);
        let n = stream.next().await.unwrap()?;
        assert_eq!(n, 2);
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let reserved = pool.reserved();
        // Only 3 remains (in the channel); the input is exhausted.
        assert_eq!(reserved, 3);
        let n = stream.next().await.unwrap()?;
        assert_eq!(n, 3);
        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
        let reserved = pool.reserved();
        // Everything has been yielded, so nothing is reserved.
        assert_eq!(reserved, 0);
        Ok(())
    }

    // Test-only: lets plain integers act as elements whose footprint equals their own value.
    impl MemoryFootPrint for usize {
        fn get_memory_size(&self) -> usize {
            *self
        }
    }
}