use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use futures::channel::oneshot::{self, Sender};
use tokio::sync::mpsc;
use tokio::task::{JoinSet, LocalSet};
use crate::shard::InternalJoinSetResult;
use crate::{LoadFromUpstream, ServiceData, UpstreamError};
/// A pending request to load the data for a single key.
///
/// The request borrows the internal join set it reports into, and it must be
/// completed exactly once with one of these methods:
/// - [`DataLoadRequest::resolve`]
/// - [`DataLoadRequest::reject`]
/// - [`DataLoadRequest::spawn`]
/// - [`DataLoadRequest::spawn_default`]
///
/// Each of them pushes a `DataLoadResult` for the key into that join set. If the
/// request is dropped while it still holds its key, a result carrying
/// [`UpstreamError::OperationAborted`] is pushed instead.
#[must_use = "the data load request must be resolved or rejected, otherwise the operation will be considered aborted."]
pub struct DataLoadRequest<'a, Key, Data>
where
    Key: Send + 'static,
    Data: ServiceData,
{
    /// The key being loaded. `None` once it has been moved out, either by a
    /// completing method or by `take_key`.
    key: Option<Key>,
    /// Join set the outcome for `key` is reported into.
    internal_join_set: &'a mut JoinSet<InternalJoinSetResult<Key, Data>>,
}
impl<'a, Key: Send, Data: ServiceData> DataLoadRequest<'a, Key, Data> {
pub(crate) fn new(
key: Key,
internal_join_set: &'a mut JoinSet<InternalJoinSetResult<Key, Data>>,
) -> Self {
Self {
key: Some(key),
internal_join_set,
}
}
pub fn key(&self) -> &Key {
self.key
.as_ref()
.expect("invariant: key must be present, unless dropped.")
}
pub fn take_key(&mut self) -> Key {
self.key
.take()
.expect("invariant: key must be present, unless dropped.")
}
fn emit_result_async(&mut self, result: InternalJoinSetResult<Key, Data>) {
self.internal_join_set.spawn(async move { result });
}
pub fn resolve(mut self, data: Data) {
let key = self.take_key();
self.emit_result_async(InternalJoinSetResult::DataLoadResult(key, Ok(data)));
}
pub fn reject<E: Into<UpstreamError>>(mut self, error: E) {
let key = self.take_key();
self.emit_result_async(InternalJoinSetResult::DataLoadResult(
key,
Err(error.into()),
));
}
}
impl<'a, Data: ServiceData, Key: Send + 'static> DataLoadRequest<'a, Key, Data> {
    /// Completes the request with the outcome of `fut`, which is run as a task
    /// on the internal join set.
    ///
    /// The key is moved into that task and paired with whatever `fut` resolves
    /// to. Both `Ok` and `Err` are forwarded unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the key has already been taken with
    /// [`DataLoadRequest::take_key`].
    pub fn spawn<F: Future<Output = Result<Data, UpstreamError>> + Send + 'static>(
        mut self,
        fut: F,
    ) {
        let key = self.take_key();
        // `fut` already yields the exact `Result` the join set expects. It is
        // passed through as-is instead of being unpacked and rebuilt arm by arm.
        self.internal_join_set
            .spawn(async move { InternalJoinSetResult::DataLoadResult(key, fut.await) });
    }
}
impl<'a, Key: Send, Data: ServiceData + Default> DataLoadRequest<'a, Key, Data> {
    /// Like [`DataLoadRequest::spawn`], but treats
    /// [`UpstreamError::KeyNotFound`] as a successful load of `Data::default()`.
    ///
    /// Every other outcome of `fut`, `Ok` or `Err`, is forwarded unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the key has already been taken with
    /// [`DataLoadRequest::take_key`].
    pub fn spawn_default<F: Future<Output = Result<Data, UpstreamError>> + Send + 'static>(
        self,
        fut: F,
    ) {
        // Delegate to `spawn` so the key hand-off and join-set plumbing live in one
        // place. This wrapper only rewrites the "missing key" outcome.
        self.spawn(async move {
            match fut.await {
                Err(UpstreamError::KeyNotFound) => Ok(Data::default()),
                other => other,
            }
        });
    }
}
impl<Key, Data> Drop for DataLoadRequest<'_, Key, Data>
where
    Key: Send,
    Data: ServiceData,
{
    /// Reports [`UpstreamError::OperationAborted`] for a request that still
    /// holds its key, i.e. one that was never completed.
    ///
    /// A request that was completed, or whose key was taken with `take_key`,
    /// no longer holds a key. Dropping it emits nothing.
    ///
    /// NOTE(review): this pushes into the join set via `JoinSet::spawn`. Tokio
    /// documents that method as panicking outside a runtime, so confirm these
    /// requests are never dropped outside one.
    fn drop(&mut self) {
        let pending_key = self.key.take();
        if let Some(key) = pending_key {
            let aborted = InternalJoinSetResult::DataLoadResult(
                key,
                Err(UpstreamError::OperationAborted),
            );
            self.emit_result_async(aborted);
        }
    }
}