use arrow_flight::FlightData;
use datafusion::arrow::array::RecordBatch;
use datafusion::common::runtime::SpawnedTask;
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryPool};
use futures::{Stream, StreamExt};
use std::sync::Arc;
use tokio_stream::wrappers::ReceiverStream;
pub(crate) fn spawn_select_all<T, El, Err>(
inner: Vec<T>,
pool: Arc<dyn MemoryPool>,
queue_size: usize,
) -> impl Stream<Item = Result<El, Err>>
where
T: Stream<Item = Result<El, Err>> + Send + Unpin + 'static,
El: MemoryFootPrint + Send + 'static,
Err: Send + 'static,
{
let reservation = Arc::new(MemoryConsumer::new("NetworkBoundary").register(&pool));
let mut tasks = Vec::with_capacity(inner.len());
let mut in_rxs = Vec::with_capacity(inner.len());
for mut t in inner {
let (in_tx, in_rx) = tokio::sync::mpsc::channel(queue_size);
in_rxs.push(ReceiverStream::new(in_rx));
let reservation = Arc::clone(&reservation);
tasks.push(SpawnedTask::spawn(async move {
loop {
let msg = tokio::select! {
biased;
_ = in_tx.closed() => return,
msg = t.next() => msg
};
let Some(msg) = msg else { return };
if let Ok(msg) = &msg {
reservation.grow(msg.get_memory_size());
}
if in_tx.send(msg).await.is_err() {
return;
};
}
}))
}
futures::stream::select_all(in_rxs).map(move |msg| {
if let Ok(msg) = &msg {
reservation.shrink(msg.get_memory_size());
}
let _ = &tasks;
msg
})
}
pub(crate) trait MemoryFootPrint {
fn get_memory_size(&self) -> usize;
}
impl MemoryFootPrint for RecordBatch {
fn get_memory_size(&self) -> usize {
self.get_array_memory_size()
}
}
impl MemoryFootPrint for FlightData {
fn get_memory_size(&self) -> usize {
self.data_header.len() + self.data_body.len() + self.app_metadata.len()
}
}
#[cfg(test)]
mod tests {
use super::{MemoryFootPrint, spawn_select_all};
use datafusion::execution::memory_pool::{MemoryPool, UnboundedMemoryPool};
use std::error::Error;
use std::sync::Arc;
use tokio_stream::StreamExt;
#[tokio::test]
async fn memory_reservation() -> Result<(), Box<dyn Error>> {
let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
let mut stream = spawn_select_all(
vec![
futures::stream::iter(vec![Ok::<_, String>(1), Ok(2), Ok(3)]),
futures::stream::iter(vec![Ok(4), Ok(5)]),
],
Arc::clone(&pool),
5,
);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
let reserved = pool.reserved();
assert_eq!(reserved, 15);
let mut consumed = 0;
for _ in 0..3 {
consumed += stream.next().await.unwrap()?;
}
let reserved = pool.reserved();
assert_eq!(reserved, 15 - consumed);
drop(stream);
let reserved = pool.reserved();
assert_eq!(reserved, 0);
Ok(())
}
#[tokio::test]
async fn memory_reservation_backpressure() -> Result<(), Box<dyn Error>> {
let pool: Arc<dyn MemoryPool> = Arc::new(UnboundedMemoryPool::default());
let mut stream = spawn_select_all(
vec![futures::stream::iter(vec![
Ok::<_, String>(1),
Ok(2),
Ok(3),
])],
Arc::clone(&pool),
1,
);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
let reserved = pool.reserved();
assert_eq!(reserved, 3);
let n = stream.next().await.unwrap()?;
assert_eq!(n, 1);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
let reserved = pool.reserved();
assert_eq!(reserved, 5);
let n = stream.next().await.unwrap()?;
assert_eq!(n, 2);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
let reserved = pool.reserved();
assert_eq!(reserved, 3);
let n = stream.next().await.unwrap()?;
assert_eq!(n, 3);
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
let reserved = pool.reserved();
assert_eq!(reserved, 0);
Ok(())
}
impl MemoryFootPrint for usize {
fn get_memory_size(&self) -> usize {
*self
}
}
}