rocketmq_common/common/future.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17
18use std::pin::Pin;
19use std::sync::Arc;
20use std::task::Context;
21use std::task::Poll;
22use std::task::Waker;
23
24use tokio::sync::mpsc;
25use tokio::sync::mpsc::Sender;
26
/// Completion status of a `CompletableFuture`.
///
/// `Eq` is derived because equality on this fieldless enum is total, and `Debug`
/// so the status can be printed in diagnostics and assertions.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
    /// No value or error has been supplied yet; polling registers a waker.
    Pending,
    /// The future has been completed, either with a value or with an error.
    Ready,
}
35
36/// The internal state of a CompletableFuture.
37struct CompletableFutureState<T> {
38    /// The current completion status.
39    completed: State,
40    /// An optional waker to be notified upon completion.
41    waker: Option<Waker>,
42    /// The data value contained within the CompletableFuture upon completion.
43    data: Option<T>,
44
45    /// An optional error value contained within the CompletableFuture upon completion.
46    error: Option<Box<dyn std::error::Error + Send + Sync>>,
47}
48
49/// A CompletableFuture represents a future value that may be completed or pending.
50pub struct CompletableFuture<T> {
51    /// The shared state of the CompletableFuture.
52    state: Arc<std::sync::Mutex<CompletableFutureState<T>>>,
53    /// The sender part of a channel used for communication with the CompletableFuture.
54    tx_rx: Sender<T>,
55}
56
57impl<T: Send + 'static> Default for CompletableFuture<T> {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl<T: Send + 'static> CompletableFuture<T> {
64    /// Constructs a new CompletableFuture.
65    pub fn new() -> Self {
66        // Create the shared state.
67        let status = Arc::new(std::sync::Mutex::new(CompletableFutureState {
68            completed: State::Pending,
69            waker: None,
70            data: None,
71            error: None,
72        }));
73        let arc = status.clone();
74
75        // Spawn a Tokio task to handle completion.
76        let (tx, mut rx) = mpsc::channel::<T>(1);
77        tokio::spawn(async move {
78            if let Some(data) = rx.recv().await {
79                let mut state = arc.lock().unwrap();
80                state.data = Some(data);
81                state.completed = State::Ready;
82                if let Some(waker) = state.waker.take() {
83                    waker.wake();
84                }
85                rx.close();
86            }
87        });
88
89        Self {
90            state: status,
91            tx_rx: tx,
92        }
93    }
94
95    /// Returns the sender part of the channel used for communication.
96    pub fn get_sender(&self) -> Sender<T> {
97        self.tx_rx.clone()
98    }
99
100    // Rust code to complete a future task by updating the state and waking up the associated waker
101    pub fn complete(&mut self, result: T) {
102        let mut state = self.state.lock().unwrap();
103        state.completed = State::Ready;
104        state.data = Some(result);
105        if let Some(waker) = state.waker.take() {
106            waker.wake();
107        }
108    }
109
110    pub fn complete_exceptionally(&mut self, error: Box<dyn std::error::Error + Send + Sync>) {
111        let mut state = self.state.lock().unwrap();
112        state.completed = State::Ready;
113        state.error = Some(error);
114        if let Some(waker) = state.waker.take() {
115            waker.wake();
116        }
117    }
118}
119
120impl<T> std::future::Future for CompletableFuture<T> {
121    type Output = Option<T>;
122
123    /// Polls the CompletableFuture to determine if the future value is ready.
124    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
125        let mut shared_state = self.state.lock().unwrap();
126        if shared_state.completed == State::Ready {
127            // If the future value is ready, return it.
128            Poll::Ready(shared_state.data.take())
129        } else {
130            // Otherwise, set the waker to be notified upon completion and return pending.
131            shared_state.waker = Some(cx.waker().clone());
132            Poll::Pending
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use tokio::runtime::Runtime;

    use super::CompletableFuture;

    /// A value sent through the channel completes the future.
    #[test]
    fn test_completable_future() {
        Runtime::new().unwrap().block_on(async move {
            let cf = CompletableFuture::new();
            let sender = cf.get_sender();

            // Send data to the CompletableFuture
            sender.send(42).await.expect("Failed to send data");

            // Wait for the CompletableFuture to complete
            let result = cf.await;

            // Ensure that the result is Some(42)
            assert_eq!(result, Some(42));
        });
    }

    /// `complete` resolves the future directly, without going through the channel.
    #[test]
    fn test_complete_directly() {
        Runtime::new().unwrap().block_on(async {
            let mut cf = CompletableFuture::new();
            cf.complete("done");
            assert_eq!(cf.await, Some("done"));
        });
    }

    /// An exceptionally completed future resolves to `None`.
    #[test]
    fn test_complete_exceptionally() {
        Runtime::new().unwrap().block_on(async {
            let mut cf = CompletableFuture::<i32>::new();
            cf.complete_exceptionally("boom".into());
            assert_eq!(cf.await, None);
        });
    }

    /// A future that is awaited before its value arrives is woken once it is sent.
    #[test]
    fn test_wakes_pending_waiter() {
        Runtime::new().unwrap().block_on(async {
            let cf = CompletableFuture::new();
            let sender = cf.get_sender();
            let producer = tokio::spawn(async move {
                tokio::task::yield_now().await;
                sender.send(7).await.expect("Failed to send data");
            });
            assert_eq!(cf.await, Some(7));
            producer.await.expect("producer task panicked");
        });
    }
}