use std::{
    ops::RangeInclusive,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
};

use anyhow::Result;
use async_trait::async_trait;
use crossbeam::atomic::AtomicCell;
use futures::{
    channel::mpsc::{self, Sender},
    try_join, SinkExt, StreamExt, TryFutureExt, TryStreamExt,
};
use serde::{Deserialize, Serialize};
use tokio::{select, sync::Notify};

use super::IndexedStream;
use crate::{
    contiguous::{Contiguous, ContiguousQueue},
    directive::Foldable,
    operation::{Monoid, Operation},
    queue::PublisherExt,
    runtime::Runtime,
    task::{Task, TaskOutput},
};
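
/// Metadata carried by every task and task output: the inclusive range of
/// input indices that the associated value has been folded over.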
#[derive(Serialize, Deserialize, Debug)]
struct Metadata {
    range: RangeInclusive<usize>,
}
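
/// Finality check: an output is the fold's final result once its range
/// spans the entire job.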
impl<Op: Operation> TaskOutput<Op, Metadata> {
    fn is_final(&self, job_size: usize) -> bool {
        self.metadata.range.end() - self.metadata.range.start() + 1 == job_size
    }
}
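
/// Two outputs are `Contiguous` when their index ranges are adjacent, which
/// lets the assembler pair them up for the next round of merging.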
impl<Op: Operation> Contiguous for TaskOutput<Op, Metadata> {
    type Key = usize;

    fn is_contiguous(&self, other: &Self) -> bool {
        self.metadata.range.end() + 1 == *other.metadata.range.start()
            || *other.metadata.range.end() + 1 == *self.metadata.range.start()
    }

    fn key(&self) -> &Self::Key {
        self.metadata.range.start()
    }
}
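
/// Coordinates the fold: completed outputs are parked in the contiguous
/// `assembler` queue, and adjacent pairs are re-emitted as merge tasks over
/// `tx`, addressed to this job's coordinated channel.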
struct Dispatcher<'a, M: Monoid> {
    m: &'a M,
    assembler: Arc<ContiguousQueue<TaskOutput<M, Metadata>>>,
    tx: Sender<Task<'a, M, Metadata>>,
    channel_identifier: String,
}
impl<'a, M: Monoid + 'static> Dispatcher<'a, M> {
    async fn queue(&self, result: TaskOutput<M, Metadata>) {
        self.assembler.queue(result);
    }

    async fn dequeue(&self, idx: &usize) -> Option<TaskOutput<M, Metadata>> {
        self.assembler.dequeue(idx)
    }
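
    /// Attempts to pair `result` with an already-queued neighbour. If a
    /// contiguous pair is found, a merge task covering the combined range is
    /// published; otherwise `result` is parked in the assembler.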
    async fn try_dispatch(&self, result: TaskOutput<M, Metadata>) -> Result<()> {
        if let Some((lhs, rhs)) = self.assembler.acquire_contiguous_pair_or_queue(result) {
            let task = Task {
                routing_key: self.channel_identifier.clone(),
                metadata: Metadata {
                    range: *lhs.metadata.range.start()..=*rhs.metadata.range.end(),
                },
                op: self.m,
                input: (lhs.output, rhs.output),
            };
            let mut tx = self.tx.clone();
            tx.send(task).await?;
        }
        Ok(())
    }
}
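
/// Caps concurrent item processing, in-flight merge tasks, and publishing.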
const MAX_CONCURRENCY_PER_TASK: usize = 10;
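
/// Folds an indexed stream by pairing contiguous partial results and
/// re-submitting them as merge tasks until a single output remains that
/// covers the whole input range.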
#[async_trait]
impl<'a, A: Send + Sync + 'a, B: Send + 'a> Foldable<'a, B> for IndexedStream<'a, A> {
    async fn f_fold<M: Monoid<Elem = A>>(self, m: &'a M, runtime: &Runtime) -> Result<A>
    where
        M: 'static,
    {
        let (channel_identifier, sender, receiver) =
            runtime.lease_coordinated_task_channel().await?;
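
        // Local channel through which the dispatcher hands merge tasks to the
        // publisher below.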
        let (mut tx, rx) = mpsc::channel::<Task<'a, M, Metadata>>(MAX_CONCURRENCY_PER_TASK);
        let assembler = Arc::new(ContiguousQueue::new());
        let sender = Arc::new(sender);
        let dispatcher = Arc::new(Dispatcher {
            m,
            assembler: assembler.clone(),
            tx: tx.clone(),
            channel_identifier,
        });
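
        // Coordination state: `should_dispatch` wakes the result processor
        // once at least two items exist, `count` tracks how many inputs have
        // been seen, and `resolved_input_size` holds the final input size
        // (with `usize::MAX` as a "not yet known" sentinel).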
        let should_dispatch = Arc::new(Notify::new());
        let count = Arc::new(AtomicUsize::new(0));
        let resolved_input_size = Arc::new(AtomicUsize::new(usize::MAX));
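
        // Phase 1: wrap each input item as a single-index task output. The
        // first item is only parked; every later item wakes the result
        // processor and is dispatched so contiguous neighbours merge eagerly.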
        let init = self
            .try_for_each_concurrent(MAX_CONCURRENCY_PER_TASK, |(idx, item)| {
                let dispatcher = dispatcher.clone();
                let should_dispatch = should_dispatch.clone();
                let count = count.clone();
                async move {
                    let item_result = TaskOutput {
                        metadata: Metadata { range: idx..=idx },
                        output: item,
                    };
                    let next_sum = count.fetch_add(1, Ordering::Relaxed) + 1;
                    if next_sum < 2 {
                        dispatcher.queue(item_result).await;
                    } else {
                        should_dispatch.notify_one();
                        dispatcher.try_dispatch(item_result).await?;
                    }
                    Ok::<_, anyhow::Error>(())
                }
            })
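            // Once the stream is exhausted, record the true input size and
            // short-circuit the trivial cases: zero items fold to the monoid's
            // identity, a single item is already the answer, and anything
            // larger defers to the result processor by pending forever.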
            .and_then({
                let count = count.clone();
                let dispatcher = dispatcher.clone();
                let resolved_input_size = resolved_input_size.clone();
                |_| {
                    async move {
                        let size = count.load(Ordering::Relaxed);
                        resolved_input_size.store(size, Ordering::Release);
                        match size {
                            0 => Ok(m.empty()),
                            1 => Ok(dispatcher
                                .dequeue(&0)
                                .await
                                .expect("Expected dispatcher to have a single element at index 0")
                                .output),
                            _ => futures::future::pending().await,
                        }
                    }
                }
            });
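
        // One-shot channel that carries the fold's final value out of the
        // result processor and into the select! below.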
        let (final_result_tx, final_result_rx) = futures::channel::oneshot::channel::<M::Elem>();
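
        // Phase 2: consume completed merge results. A result whose range
        // spans the whole (now known) input is the final answer; everything
        // else is fed back to the dispatcher for further pairing. Each
        // message is acked so the coordinated channel can make progress.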
        let result_processor = {
            let dispatcher = dispatcher.clone();
            let should_dispatch = should_dispatch.clone();
            let final_result_tx = Arc::new(AtomicCell::new(Some(final_result_tx)));
            async move {
                should_dispatch.notified().await;
                receiver
                    .map(Ok)
                    .try_for_each_concurrent(MAX_CONCURRENCY_PER_TASK, |(result, acker)| {
                        let resolved_input_size = resolved_input_size.clone();
                        let dispatcher = dispatcher.clone();
                        let final_result_tx = final_result_tx.clone();
                        async move {
                            let result = result?;
                            let resolved_size = resolved_input_size.load(Ordering::Acquire);
                            if resolved_size != usize::MAX && result.is_final(resolved_size) {
                                acker.ack().await?;
                                final_result_tx
                                    .take()
                                    .ok_or_else(|| anyhow::anyhow!("final result tx taken"))?
                                    .send(result.output)
                                    .map_err(|_| anyhow::anyhow!("final result already sent"))?;
                                return Ok::<_, anyhow::Error>(());
                            }
                            try_join!(dispatcher.try_dispatch(result), acker.ack())?;
                            Ok(())
                        }
                    })
                    .await?;
                unreachable!("Result stream should never complete")
            }
        };
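
        // Publishes locally queued merge tasks to the coordinated channel;
        // the trailing pending() keeps this branch alive for the whole fold.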
        let task_handler = {
            let sender = sender.clone();
            async move {
                sender
                    .publish_all(rx.map(Ok), MAX_CONCURRENCY_PER_TASK)
                    .await?;
                futures::future::pending().await
            }
        };
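
        // Run all phases until one resolves: an early exit from `init` (empty
        // or single-item input), an error from either long-running branch, or
        // the final folded value, in which case both channels are closed.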
        select! {
            exit_early = init => exit_early,
            task_handler = task_handler => task_handler,
            result_processor = result_processor => result_processor,
            final_result = final_result_rx => {
                tx.close().await?;
                sender.close().await?;
                Ok(final_result?)
            },
        }
    }
}