1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;

use futures::channel::oneshot::{self, Sender};
use tokio::sync::mpsc;
use tokio::task::{JoinSet, LocalSet};

use crate::shard::InternalJoinSetResult;
use crate::{LoadFromUpstream, ServiceData, UpstreamError};

/// Handle passed to a [`LoadFromUpstream::load`] handler for one key that must
/// be loaded.
///
/// The handler reports the outcome through this handle, either by resolving it
/// with data or by rejecting it with an error. The outcome is delivered to the
/// shard.
///
/// Forgetting to report an outcome does not leak the request. If the handle is
/// dropped while it still holds its key, the shard receives
/// `UpstreamError::OperationAborted` for that key. A code path that never uses
/// the request therefore still completes the load, just as a failure.
#[must_use = "the data load request must be resolved or rejected, otherwise the operation will be considered aborted."]
pub struct DataLoadRequest<'a, Key, Data>
where
    Key: Send + 'static,
    Data: ServiceData,
{
    /// The key being loaded. It is `None` once the key has been moved out to
    /// report an outcome (or via `take_key`).
    key: Option<Key>,
    /// Join set that outcomes are spawned onto. Presumably owned by the shard
    /// (its item type comes from `crate::shard`); confirm against the caller.
    internal_join_set: &'a mut JoinSet<InternalJoinSetResult<Key, Data>>,
}

impl<'a, Key: Send, Data: ServiceData> DataLoadRequest<'a, Key, Data> {
    /// Creates a pending request for `key`. Its outcome will be spawned onto
    /// `internal_join_set`.
    pub(crate) fn new(
        key: Key,
        internal_join_set: &'a mut JoinSet<InternalJoinSetResult<Key, Data>>,
    ) -> Self {
        Self {
            key: Some(key),
            internal_join_set,
        }
    }

    /// Returns a reference to the key that is currently being loaded.
    ///
    /// # Panics
    ///
    /// Panics if the key has already been moved out with
    /// [`take_key`](Self::take_key).
    pub fn key(&self) -> &Key {
        self.key
            .as_ref()
            .expect("invariant: key must be present, unless already moved out by `take_key`.")
    }

    /// Moves the key out of the request.
    ///
    /// This is a low-level escape hatch. Once the key has been taken, this
    /// request no longer reports anything for it. In particular, dropping the
    /// request afterwards does **not** send `UpstreamError::OperationAborted`,
    /// because the drop handler only reports while it still holds the key. The
    /// caller becomes responsible for delivering an outcome for the key some
    /// other way.
    ///
    /// # Panics
    ///
    /// Panics if the key has already been taken. After a successful call,
    /// [`key`](Self::key), [`resolve`](Self::resolve), [`reject`](Self::reject)
    /// and `spawn`/`spawn_default` on this request panic as well.
    pub fn take_key(&mut self) -> Key {
        self.key
            .take()
            .expect("invariant: key must be present, unless already moved out by `take_key`.")
    }

    /// Delivers `result` through the join set by spawning a future that
    /// completes immediately with it.
    ///
    /// Results produced synchronously here and results produced by `spawn`
    /// therefore arrive through the same join set. Note that tokio's
    /// `JoinSet::spawn` panics when called outside a Tokio runtime.
    fn emit_result_async(&mut self, result: InternalJoinSetResult<Key, Data>) {
        self.internal_join_set.spawn(async move { result });
    }

    /// Resolves the data load request with the given data.
    ///
    /// # Panics
    ///
    /// Panics if the key was already moved out with
    /// [`take_key`](Self::take_key).
    pub fn resolve(mut self, data: Data) {
        let key = self.take_key();
        self.emit_result_async(InternalJoinSetResult::DataLoadResult(key, Ok(data)));
    }

    /// Rejects the data load request with the given error.
    ///
    /// The error may be any type convertible into an [`UpstreamError`].
    ///
    /// # Panics
    ///
    /// Panics if the key was already moved out with
    /// [`take_key`](Self::take_key).
    pub fn reject<E: Into<UpstreamError>>(mut self, error: E) {
        let key = self.take_key();
        self.emit_result_async(InternalJoinSetResult::DataLoadResult(
            key,
            Err(error.into()),
        ));
    }
}

impl<'a, Data: ServiceData, Key: Send + 'static> DataLoadRequest<'a, Key, Data> {
    /// Spawns a task that drives `fut` to completion and reports its output as
    /// the outcome of this load.
    ///
    /// `Ok(data)` resolves the request and `Err(err)` rejects it with `err`.
    ///
    /// The key is moved into the spawned task. The request's own drop at the
    /// end of this call therefore reports nothing further.
    ///
    /// NOTE(review): if `fut` panics, the task ends without producing a
    /// `DataLoadResult` for the key. The panic only surfaces as a join error on
    /// the join set, which carries no key. Confirm the shard handles that case.
    ///
    /// # Panics
    ///
    /// Panics if the key was already moved out with
    /// [`take_key`](DataLoadRequest::take_key).
    pub fn spawn<F: Future<Output = Result<Data, UpstreamError>> + Send + 'static>(
        mut self,
        fut: F,
    ) {
        let key = self.take_key();
        // The future's `Result` already has the exact shape `DataLoadResult`
        // carries, so it is forwarded as-is instead of being re-wrapped per arm.
        self.internal_join_set.spawn(async move {
            InternalJoinSetResult::DataLoadResult(key, fut.await)
        });
    }
}

impl<'a, Key: Send, Data: ServiceData + Default> DataLoadRequest<'a, Key, Data> {
    /// Like [`spawn`], except for how [`UpstreamError::KeyNotFound`] is handled.
    ///
    /// If `fut` fails with `KeyNotFound`, the request resolves with
    /// `Data::default()` instead of being rejected. Every other outcome is
    /// reported unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the key was already moved out with
    /// [`take_key`](DataLoadRequest::take_key).
    ///
    /// [`spawn`]: DataLoadRequest::spawn
    pub fn spawn_default<F: Future<Output = Result<Data, UpstreamError>> + Send + 'static>(
        mut self,
        fut: F,
    ) {
        let key = self.take_key();
        self.internal_join_set.spawn(async move {
            let result = match fut.await {
                // A missing upstream key is treated as empty data, not a failure.
                Err(UpstreamError::KeyNotFound) => Ok(Data::default()),
                other => other,
            };
            InternalJoinSetResult::DataLoadResult(key, result)
        });
    }
}

impl<'a, Key: Send, Data: ServiceData> Drop for DataLoadRequest<'a, Key, Data> {
    /// Reports `UpstreamError::OperationAborted` for the key if no outcome was
    /// ever reported for it. This keeps an unused request from leaking.
    fn drop(&mut self) {
        match self.key.take() {
            // The key was already moved out (by resolve/reject/spawn or by
            // `take_key`), so there is nothing left to report from here.
            None => {}
            Some(unreported) => {
                let aborted = InternalJoinSetResult::DataLoadResult(
                    unreported,
                    Err(UpstreamError::OperationAborted),
                );
                self.emit_result_async(aborted);
            }
        }
    }
}