use crate::backend::job::{Data, Input, InputKind, InputSource};
use crate::backend::overlap_checking::OverlapChecker;
use crate::sync::once_lock::OnceLock;
use crate::sync::promise::{BoxPromise, IntoPromise, MapPromise, PollPromise, Promise};
use crate::{InputFlavor, JobError};
use parking_lot::Mutex;
use std::any::type_name;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
use std::ops::{Bound, RangeBounds};
use std::sync::Arc;
use thiserror::Error;
pub struct Disjointed<'lf, T, P: Promise<Output = Vec<T>> + 'lf = BoxPromise<'lf, Vec<T>>>
where
T: Send + 'lf,
{
inner: Arc<DisjointedInner<'lf, T, P>>,
}
impl<'lf, T, P: Promise<Output = Vec<T>> + 'lf> Disjointed<'lf, T, P>
where
T: Send + 'lf,
{
pub fn new(source: P) -> Self {
Self {
inner: Arc::new(DisjointedInner::new(source)),
}
}
pub fn get(&self, index: usize) -> Result<DisjointedIndex<'lf, T, P>, DisjointedError> {
if self.inner.checker.lock().insert(index..(index + 1)) {
Ok(DisjointedIndex {
index,
inner: self.inner.clone(),
})
} else {
Err(DisjointedError::IndexAlreadyUsed(index))
}
}
pub fn get_range<R: RangeBounds<usize> + Send + 'lf>(
&self,
index: R,
) -> Result<DisjointedRange<'lf, T, R, P>, DisjointedError> {
let (start, end) = (index.start_bound().cloned(), index.end_bound().cloned());
if self.inner.checker.lock().insert((start, end)) {
Ok(DisjointedRange {
range: index,
inner: self.inner.clone(),
})
} else {
Err(DisjointedError::RangeOverlaps(start, end))
}
}
}
impl<'lf, T> Debug for Disjointed<'lf, T>
where
T: Send + 'lf,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Disjointed").finish()
}
}
pub struct DisjointedIndex<'lf, T: Send + 'lf, P: Promise<Output = Vec<T>> + 'lf> {
index: usize,
inner: Arc<DisjointedInner<'lf, T, P>>,
}
impl<'lf, T: Send, P: Promise<Output = Vec<T>> + 'lf> Debug for DisjointedIndex<'lf, T, P> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DisjointedIndex")
.field("index", &self.index)
.finish()
}
}
impl<'lf, T: Send, P: Promise<Output = Vec<T>> + 'lf> Promise for DisjointedIndex<'lf, T, P> {
type Output = T;
fn poll(&mut self) -> PollPromise<Self::Output> {
match self.inner.get() {
None => PollPromise::Pending,
Some(ready) => match ready[self.index].lock().take() {
None => {
panic!("Tried to poll empty disjointed index");
}
Some(g) => PollPromise::Ready(g),
},
}
}
}
impl<P: Promise<Output = Vec<Data>> + 'static> InputSource<DisjointedIndex<'static, Data, P>>
for Input
{
type Data = Data;
fn use_as_input_source(
&mut self,
other: DisjointedIndex<'static, Data, P>,
) -> Result<(), JobError> {
let data_promise = other.into_promise();
match (self.flavor, &mut self.kind) {
(InputFlavor::None, _) => Err(JobError::UnexpectedInput),
(InputFlavor::Single, to_set @ InputKind::None) => {
*to_set = InputKind::Single(Box::new(data_promise));
Ok(())
}
(InputFlavor::Single, InputKind::Single(_)) => Err(JobError::InputAlreadySet),
(InputFlavor::Funnel, InputKind::Funnel(funnel)) => {
funnel.insert(data_promise);
Ok(())
}
_ => {
panic!("input kind and flavor mismatch")
}
}
}
}
pub struct DisjointedRange<
'lf,
T: Send + 'lf,
R: RangeBounds<usize> + Send + 'lf,
P: Promise<Output = Vec<T>> + 'lf,
> {
range: R,
inner: Arc<DisjointedInner<'lf, T, P>>,
}
impl<'lf, T: Send + 'lf, R: RangeBounds<usize> + Send + 'lf, P: Promise<Output = Vec<T>> + 'lf>
Debug for DisjointedRange<'lf, T, R, P>
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DisjointedRange")
.field("start", &self.range.start_bound())
.field("end", &self.range.end_bound())
.finish()
}
}
impl<'lf, T: Send, R: RangeBounds<usize> + Send + 'lf, P: Promise<Output = Vec<T>> + 'lf> Promise
for DisjointedRange<'lf, T, R, P>
{
type Output = Vec<T>;
fn poll(&mut self) -> PollPromise<Self::Output> {
match self.inner.get() {
None => PollPromise::Pending,
Some(ready) => {
let (start, end) = (
self.range.start_bound().cloned(),
self.range.end_bound().cloned(),
);
match ready[(start, end)]
.iter()
.try_fold(vec![], |mut accum, next| {
let t = next.lock().take()?;
accum.push(t);
Some(accum)
}) {
Some(vec) => PollPromise::Ready(vec),
None => panic!("trying to poll already executed promise"),
}
}
}
}
}
impl<R: RangeBounds<usize> + Send + 'static, P: Promise<Output = Vec<Data>> + 'static>
DisjointedRange<'static, Data, R, P>
{
pub fn downcast_elements<T: 'static>(self) -> CastedDisjointedRange<T, R, P> {
CastedDisjointedRange {
inner: self,
_marker: PhantomData,
}
}
}
pub struct CastedDisjointedRange<
T: 'static,
R: RangeBounds<usize> + Send + 'static,
P: Promise<Output = Vec<Data>> + 'static,
> {
inner: DisjointedRange<'static, Data, R, P>,
_marker: PhantomData<T>,
}
impl<
T: 'static + Send,
R: RangeBounds<usize> + Send + 'static,
P: Promise<Output = Vec<Data>> + 'static,
> InputSource<CastedDisjointedRange<T, R, P>> for Input
{
type Data = Vec<T>;
fn use_as_input_source(
&mut self,
other: CastedDisjointedRange<T, R, P>,
) -> Result<(), JobError> {
let inner = other
.inner
.map(|inner| {
inner
.into_iter()
.map(|data| match data.downcast::<T>() {
Ok(data) => *data,
Err(e) => {
panic!("failed to downcast {:?} to `{}`", e, type_name::<T>())
}
})
.collect::<Vec<T>>()
})
.map(|vec| Box::new(vec) as Data);
let promise: BoxPromise<'static, Data> = Box::new(inner);
match (self.flavor, &mut self.kind) {
(InputFlavor::None, _) => Err(JobError::UnexpectedInput),
(InputFlavor::Single, to_set @ InputKind::None) => {
*to_set = InputKind::Single(Box::new(promise));
Ok(())
}
(InputFlavor::Single, InputKind::Single(_)) => Err(JobError::InputAlreadySet),
(InputFlavor::Funnel, InputKind::Funnel(funnel)) => {
funnel.insert(promise);
Ok(())
}
_ => {
panic!("input kind and flavor mismatch")
}
}
}
}
type Buf<T> = Vec<Mutex<Option<T>>>;
struct DisjointedInner<'lf, T: Send + 'lf, P: Promise<Output = Vec<T>> + 'lf> {
checker: Mutex<OverlapChecker<usize>>,
promise: Arc<Mutex<P>>,
inner: OnceLock<Buf<T>>,
_marker: PhantomData<&'lf ()>,
}
impl<'lf, T: Send + 'lf, P: Promise<Output = Vec<T>> + 'lf> DisjointedInner<'lf, T, P> {
fn new(promise: P) -> Self {
Self {
checker: Default::default(),
promise: Arc::new(Mutex::new(promise)),
inner: OnceLock::new(),
_marker: Default::default(),
}
}
fn get(&self) -> Option<&Buf<T>> {
let promise = self.promise.clone();
let f = move || match promise.lock().poll() {
PollPromise::Ready(ready) => Some(into_buf(ready)),
PollPromise::Pending => None,
};
self.inner.get_or_try_init_opt(f)
}
}
fn into_buf<T: Send>(v: Vec<T>) -> Buf<T> {
v.into_iter().map(|t| Mutex::new(Some(t))).collect()
}
#[derive(Debug, Error)]
pub enum DisjointedError {
#[error("{0} is already is use")]
IndexAlreadyUsed(usize),
#[error("({0:?}, {1:?}) is already in use")]
RangeOverlaps(Bound<usize>, Bound<usize>),
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::disjointed::{Buf, Disjointed};
use crate::backend::job::Data;
use crate::backend::recv_promise::RecvPromise;
use crate::sync::promise::GetPromise;
use crossbeam::channel::bounded;
use static_assertions::assert_impl_all;
use std::sync::Arc;
assert_impl_all!(Disjointed<'static, Data>: Send);
assert_impl_all!(Buf<Data>: Sync);
#[test]
fn test_disjointed_inner() {
let (tx, rx) = bounded::<Vec<usize>>(1);
let receiver_promise = RecvPromise::new(rx);
let inner = Arc::new(DisjointedInner::new(receiver_promise));
assert!(inner.get().is_none());
tx.send(vec![1, 2, 3]).unwrap();
assert!(inner.get().is_some());
let Some(buf) = inner.get() else {
panic!(
"Should be ready, and calling get multiple times shouldn't cause problems once promise is ready"
)
};
assert_eq!(buf.len(), 3);
}
#[test]
fn test_disjointed_indices() {
let (tx, rx) = bounded::<Vec<usize>>(1);
let receiver_promise = RecvPromise::new(rx);
let inner = Disjointed::new(receiver_promise);
let a = inner.get(0).unwrap();
let b = inner.get(1).unwrap();
inner.get(0).expect_err("should result in error");
tx.send(vec![1, 2, 3]).unwrap();
let a = a.get();
let b = b.get();
assert_eq!(a, 1);
assert_eq!(b, 2);
}
#[test]
fn test_disjointed_ranges() {
let (tx, rx) = bounded::<Vec<usize>>(1);
let receiver_promise = RecvPromise::new(rx);
let inner = Disjointed::new(receiver_promise);
let a = inner.get_range(0..3).unwrap();
let b = inner.get_range(3..).unwrap();
inner.get(0).expect_err("should result in error");
inner.get_range(2..=4).expect_err("should result in error");
tx.send(vec![1, 2, 3, 4, 5, 6]).unwrap();
let a = a.get();
let b = b.get();
assert_eq!(a, [1, 2, 3]);
assert_eq!(b, [4, 5, 6]);
}
#[test]
fn test_disjointed_mixed() {
let (tx, rx) = bounded::<Vec<usize>>(1);
let receiver_promise = RecvPromise::new(rx);
let inner = Disjointed::new(receiver_promise);
let a = inner.get_range(0..3).unwrap();
let b = inner.get(3).unwrap();
let c = inner.get(4).unwrap();
let d = inner.get_range(5..).unwrap();
tx.send(vec![1, 2, 3, 4, 5, 6, 7]).unwrap();
let a = a.get();
let b = b.get();
let c = c.get();
let d = d.get();
assert_eq!(a, [1, 2, 3]);
assert_eq!(b, 4);
assert_eq!(c, 5);
assert_eq!(d, [6, 7]);
}
}