use std::sync::Arc;
use futures::FutureExt;
use futures::TryFutureExt;
use futures::future::BoxFuture;
use futures::future::WeakShared;
use vortex_array::buffer::BufferHandle;
use vortex_error::SharedVortexResult;
use vortex_error::VortexError;
use vortex_error::VortexExpect;
use vortex_utils::aliases::dash_map::DashMap;
use vortex_utils::aliases::dash_map::Entry;
use crate::segments::SegmentFuture;
use crate::segments::SegmentId;
use crate::segments::SegmentSource;
pub struct SharedSegmentSource<S> {
inner: S,
in_flight: DashMap<SegmentId, WeakShared<SharedSegmentFuture>>,
}
type SharedSegmentFuture = BoxFuture<'static, SharedVortexResult<BufferHandle>>;
impl<S: SegmentSource> SharedSegmentSource<S> {
pub fn new(inner: S) -> Self {
Self {
inner,
in_flight: DashMap::default(),
}
}
}
impl<S: SegmentSource> SegmentSource for SharedSegmentSource<S> {
fn request(&self, id: SegmentId) -> SegmentFuture {
loop {
match self.in_flight.entry(id) {
Entry::Occupied(e) => {
if let Some(shared_future) = e.get().upgrade() {
return shared_future.map_err(VortexError::from).boxed();
} else {
e.remove();
}
}
Entry::Vacant(e) => {
let future = self.inner.request(id).map_err(Arc::new).boxed().shared();
e.insert(
future
.downgrade()
.vortex_expect("just created, cannot be polled to completion"),
);
return future.map_err(VortexError::from).boxed();
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use vortex_buffer::ByteBuffer;
use super::*;
use crate::segments::SegmentSink;
use crate::segments::TestSegments;
use crate::sequence::SequenceId;
#[derive(Default, Clone)]
struct CountingSegmentSource {
segments: TestSegments,
request_count: Arc<AtomicUsize>,
}
impl SegmentSource for CountingSegmentSource {
fn request(&self, id: SegmentId) -> SegmentFuture {
self.request_count.fetch_add(1, Ordering::SeqCst);
self.segments.request(id)
}
}
#[tokio::test]
async fn test_shared_source_deduplicates_concurrent_requests() {
let source = CountingSegmentSource::default();
let data = ByteBuffer::from(vec![1, 2, 3, 4]);
let seq_id = SequenceId::root().downgrade();
source
.segments
.write(seq_id, vec![data.clone()])
.await
.unwrap();
let shared_source = SharedSegmentSource::new(source.clone());
let id = SegmentId::from(0);
let future1 = shared_source.request(id);
let future2 = shared_source.request(id);
let (result1, result2) = futures::join!(future1, future2);
assert_eq!(result1.unwrap().unwrap_host(), data);
assert_eq!(result2.unwrap().unwrap_host(), data);
assert_eq!(source.request_count.load(Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_shared_source_handles_dropped_futures() {
let source = CountingSegmentSource::default();
let data = ByteBuffer::from(vec![5, 6, 7, 8]);
let seq_id = SequenceId::root().downgrade();
source
.segments
.write(seq_id, vec![data.clone()])
.await
.unwrap();
let shared_source = SharedSegmentSource::new(source.clone());
let id = SegmentId::from(0);
{
let _future = shared_source.request(id);
}
let result = shared_source.request(id).await;
assert_eq!(result.unwrap().unwrap_host(), data);
assert_eq!(source.request_count.load(Ordering::Relaxed), 2);
}
}