use std::cmp::Ordering;
use std::collections::BTreeSet;
use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use futures::Stream;
use futures::StreamExt;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use vortex_array::ArrayRef;
use vortex_array::dtype::DType;
use vortex_array::stream::ArrayStream;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_utils::aliases::hash_map::HashMap;
/// Hierarchical identifier made of a path of indices.
///
/// Equality, ordering, hashing and `Debug` output are all defined on the `id` path alone; the
/// `universe` handle does not take part in them. Each value registers its path in the shared
/// universe when it is built by `SequenceId::new` and unregisters it again in its `Drop` impl.
pub struct SequenceId {
// Path of indices identifying this sequence; compared lexicographically by the `Ord` impl.
id: Vec<usize>,
// Shared registry of the active paths and of the wakers of pending `collapse` calls.
universe: Arc<Mutex<SequenceUniverse>>,
}
/// Two ids are equal when their index paths are equal; the universe handle is ignored.
impl PartialEq for SequenceId {
    fn eq(&self, other: &Self) -> bool {
        self.id.as_slice() == other.id.as_slice()
    }
}
// Marker impl: `PartialEq` compares the `Vec<usize>` path, which is a total equivalence.
impl Eq for SequenceId {}
/// Partial ordering that always defers to the total order defined by `Ord`.
impl PartialOrd for SequenceId {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Total order on ids: lexicographic comparison of the index paths.
impl Ord for SequenceId {
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.id, &other.id)
    }
}
/// Hashes only the index path, keeping `Hash` consistent with `PartialEq`.
impl Hash for SequenceId {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.id, state);
    }
}
/// Debug output shows only the index path; the shared universe is left out.
impl fmt::Debug for SequenceId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SequenceId");
        builder.field("id", &self.id);
        builder.finish()
    }
}
impl SequenceId {
pub fn root() -> SequencePointer {
SequencePointer(SequenceId::new(vec![0], Default::default()))
}
pub fn descend(self) -> SequencePointer {
let mut id = self.id.clone();
id.push(0);
SequencePointer(SequenceId::new(id, Arc::clone(&self.universe)))
}
pub async fn collapse(&mut self) {
WaitSequenceFuture(self).await;
}
fn new(id: Vec<usize>, universe: Arc<Mutex<SequenceUniverse>>) -> Self {
let res = Self { id, universe };
res.universe.lock().add(&res);
res
}
}
/// Unregisters the id from its universe and wakes the waiter, if any, registered for the path
/// that is first in the active set afterwards.
impl Drop for SequenceId {
    fn drop(&mut self) {
        let mut registry = self.universe.lock();
        let pending = registry.remove(self);
        // Release the lock before waking so the woken task can take it right away.
        drop(registry);
        if let Some(waker) = pending {
            waker.wake();
        }
    }
}
/// Wrapper around a [`SequenceId`] that hands out ids.
///
/// `advance` returns the id currently held and replaces it with its next sibling, while
/// `downgrade` gives up the pointer and returns the held id.
#[derive(Debug)]
pub struct SequencePointer(SequenceId);
impl SequencePointer {
pub fn split(mut self) -> (SequencePointer, SequencePointer) {
(self.split_off(), self)
}
pub fn split_off(&mut self) -> SequencePointer {
self.advance().descend()
}
pub fn advance(&mut self) -> SequenceId {
let mut next_id = self.0.id.clone();
let last = next_id.last_mut();
let last = last.vortex_expect("must have at least one element");
*last += 1;
let next_sibling = SequenceId::new(next_id, Arc::clone(&self.0.universe));
std::mem::replace(&mut self.0, next_sibling)
}
pub fn downgrade(self) -> SequenceId {
self.0
}
}
/// Bookkeeping shared by a group of `SequenceId`s: the set of paths that are currently alive
/// and the wakers of tasks waiting in `SequenceId::collapse`.
#[derive(Default)]
struct SequenceUniverse {
// Paths of all ids currently alive, kept sorted so the smallest one is `first()`.
active: BTreeSet<Vec<usize>>,
// Wakers registered by pending `WaitSequenceFuture`s, keyed by the waiting id's path.
wakers: HashMap<Vec<usize>, Waker>,
}
impl SequenceUniverse {
    /// Records the path of `sequence_id` as active.
    fn add(&mut self, sequence_id: &SequenceId) {
        let path = sequence_id.id.clone();
        self.active.insert(path);
    }

    /// Removes the path of `sequence_id` from the active set and returns the waker, if any,
    /// registered for the path that is first afterwards.
    ///
    /// # Panics
    ///
    /// Panics if the active set becomes empty while wakers are still registered.
    fn remove(&mut self, sequence_id: &SequenceId) -> Option<Waker> {
        self.active.remove(&sequence_id.id);
        match self.active.first() {
            Some(front) => self.wakers.remove(front),
            None => {
                assert!(self.wakers.is_empty(), "all wakers must have been removed");
                None
            }
        }
    }
}
// Future behind `SequenceId::collapse`: completes once the borrowed id's path is the first
// entry of its universe's active set.
struct WaitSequenceFuture<'a>(&'a mut SequenceId);
impl Future for WaitSequenceFuture<'_> {
    type Output = ();

    /// Completes once the awaited id is the first active path in its universe.
    ///
    /// While the id is not first, the task's waker is stored (replacing any earlier one) under
    /// the id's path, to be woken from `SequenceId`'s `Drop` impl.
    ///
    /// # Panics
    ///
    /// Panics if the active set is empty, which cannot happen while the awaited id is alive.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let own_path = &self.0.id;
        let mut universe = self.0.universe.lock();
        // Compare against the front path by reference; cloning it on every poll would
        // allocate while the lock is held.
        let at_front = universe
            .active
            .first()
            .vortex_expect("if we have a future, we must have at least one active sequence")
            == own_path;
        if at_front {
            // Clear any waker left over from an earlier pending poll.
            universe.wakers.remove(own_path);
            return Poll::Ready(());
        }
        universe.wakers.insert(own_path.clone(), cx.waker().clone());
        Poll::Pending
    }
}
/// Removes any waker this future registered, so a dropped or cancelled wait leaves nothing
/// behind in the universe.
impl Drop for WaitSequenceFuture<'_> {
    fn drop(&mut self) {
        let mut universe = self.0.universe.lock();
        universe.wakers.remove(&self.0.id);
    }
}
/// A [`Stream`] whose items are `VortexResult<(SequenceId, ArrayRef)>` pairs and which also
/// reports a [`DType`].
pub trait SequentialStream: Stream<Item = VortexResult<(SequenceId, ArrayRef)>> {
/// Returns the dtype associated with this stream.
// NOTE(review): presumably every yielded array has this dtype (`SequentialStreamAdapter`
// asserts it); confirm for other implementors.
fn dtype(&self) -> &DType;
}
/// A pinned, boxed, type-erased [`SequentialStream`] that is `Send`.
pub type SendableSequentialStream = Pin<Box<dyn SequentialStream + Send>>;
/// Forwards `dtype` to the boxed stream so the type-erased handle is itself a
/// `SequentialStream`.
impl SequentialStream for SendableSequentialStream {
    fn dtype(&self) -> &DType {
        // Deref through `Pin` and `Box` to reach the trait object.
        let boxed: &(dyn SequentialStream + Send) = &**self;
        boxed.dtype()
    }
}
/// Extension methods available on every [`SequentialStream`].
pub trait SequentialStreamExt: SequentialStream {
    /// Boxes and pins the stream, erasing its concrete type.
    fn sendable(self) -> SendableSequentialStream
    where
        Self: Sized + Send + 'static,
    {
        let erased: SendableSequentialStream = Box::pin(self);
        erased
    }
}
// Blanket impl: every `SequentialStream` gets the `SequentialStreamExt` methods.
impl<S: SequentialStream> SequentialStreamExt for S {}
// Pairs an inner stream with the `DType` reported through `SequentialStream::dtype`.
// `pin_project!` generates the projection used in `poll_next` to poll `inner` while pinned.
pin_project! {
pub struct SequentialStreamAdapter<S> {
// Dtype reported by the adapter and checked against every yielded array.
dtype: DType,
// Wrapped stream; structurally pinned.
#[pin]
inner: S,
}
}
impl<S> SequentialStreamAdapter<S> {
pub fn new(dtype: DType, inner: S) -> Self {
Self { dtype, inner }
}
}
/// Reports the dtype the adapter was constructed with.
impl<S> SequentialStream for SequentialStreamAdapter<S>
where
    S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
{
    fn dtype(&self) -> &DType {
        let Self { dtype, .. } = self;
        dtype
    }
}
/// Forwards items from the inner stream, checking that every successfully yielded array has
/// the dtype the adapter was constructed with.
impl<S> Stream for SequentialStreamAdapter<S>
where
    S: Stream<Item = VortexResult<(SequenceId, ArrayRef)>>,
{
    type Item = VortexResult<(SequenceId, ArrayRef)>;

    /// Polls the inner stream and passes its item through unchanged.
    ///
    /// # Panics
    ///
    /// Panics if an `Ok` item carries an array whose dtype differs from the adapter's dtype.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let item = futures::ready!(this.inner.poll_next(cx));
        if let Some(Ok((_, chunk))) = item.as_ref() {
            // Message arguments follow the wording: the stream's dtype first, then the
            // offending chunk's dtype.
            assert_eq!(
                chunk.dtype(),
                this.dtype,
                "Sequential stream of {} got chunk of {}.",
                this.dtype,
                chunk.dtype()
            );
        }
        Poll::Ready(item)
    }

    /// Delegates to the inner stream; the adapter neither adds nor drops items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Extension methods for turning an [`ArrayStream`] into a sequential stream.
pub trait SequentialArrayStreamExt: ArrayStream {
    /// Tags every successfully yielded array with the next id taken from `pointer`.
    ///
    /// Errors pass through untouched and do not advance the pointer. The resulting stream
    /// reports the same dtype as `self`.
    fn sequenced(self, mut pointer: SequencePointer) -> SendableSequentialStream
    where
        Self: Sized + Send + 'static,
    {
        // Capture the dtype before `self` is moved into the mapped stream.
        let dtype = self.dtype().clone();
        let tagged = StreamExt::map(self, move |item| match item {
            Ok(array) => Ok((pointer.advance(), array)),
            Err(err) => Err(err),
        });
        Box::pin(SequentialStreamAdapter::new(dtype, tagged))
    }
}
// Blanket impl: every `ArrayStream` gets the `SequentialArrayStreamExt` methods.
impl<S: ArrayStream> SequentialArrayStreamExt for S {}