netbeam/sync/operations/
net_try_join.rs

/*!
 * # Network Try-Join Operation
 *
 * Implements a network-aware try-join operation that synchronizes fallible futures
 * across two network endpoints. Similar to `futures::try_join`, but operates over
 * a network connection.
 *
 * ## Features
 * - Synchronizes fallible futures between network endpoints
 * - Resolves once both endpoints have succeeded, or as soon as either one fails
 * - Early termination on the first error reported by either endpoint
 * - Type-safe with generic value and error types
 * - State synchronization between nodes
 * - Network-aware relative node types
 *
 * ## Usage Example
 * ```rust
 * use netbeam::sync::operations::net_try_join::NetTryJoin;
 * use netbeam::sync::RelativeNodeType;
 * use netbeam::sync::subscription::Subscribable;
 * use anyhow::Result;
 *
 * async fn example<S: Subscribable>(connection: &S) -> Result<()> {
 *     // Create a try-join operation
 *     let join = NetTryJoin::new(
 *         connection,
 *         RelativeNodeType::Initiator,
 *         async { Ok::<_, anyhow::Error>(42) }
 *     );
 *
 *     // Wait for both endpoints
 *     let result = join.await?;
 *     println!("Got result: {:?}", result);
 *     Ok(())
 * }
 * ```
 *
 * ## Important Notes
 * - The join as a whole succeeds only if both endpoints' futures succeed; as soon as
 *   either one fails, both endpoints resolve
 * - `NetTryJoinResult::value` holds only the local future's result: it is `None` if the
 *   outcome was decided before the local future finished, and a local `Ok` does not by
 *   itself mean that the remote endpoint succeeded
 * - The `Initiator` node decides the outcome; the other node waits for its decision
 * - Runs over a dedicated subscription of a multiplexed connection
 *
 * ## Related Components
 * - `net_join.rs`: Basic join operation without error handling
 * - `net_select.rs`: Select operation for multiple futures
 */
48
49use crate::multiplex::MultiplexedConnKey;
50use crate::reliable_conn::{ReliableOrderedStreamToTarget, ReliableOrderedStreamToTargetExt};
51use crate::sync::subscription::{Subscribable, SubscriptionBiStream};
52use crate::sync::RelativeNodeType;
53use crate::ScopedFutureResult;
54use citadel_io::tokio::sync::{Mutex, MutexGuard};
55use serde::{Deserialize, Serialize};
56use std::future::Future;
57use std::pin::Pin;
58use std::task::{Context, Poll};
59
60/// Two endpoints produce Ok(T). Returns when both endpoints produce Ok(T), or, when the first error occurs
61pub struct NetTryJoin<'a, T, E> {
62    future: ScopedFutureResult<'a, NetTryJoinResult<T, E>>,
63}
64
65impl<'a, T: Send + 'a, E: Send + 'a> NetTryJoin<'a, T, E> {
66    pub fn new<
67        S: Subscribable<ID = K, UnderlyingConn = Conn>,
68        K: MultiplexedConnKey + 'a,
69        Conn: ReliableOrderedStreamToTarget + 'static,
70        F: Future<Output = Result<T, E>> + Send + 'a,
71    >(
72        conn: &'a S,
73        local_node_type: RelativeNodeType,
74        future: F,
75    ) -> NetTryJoin<'a, T, E> {
76        Self {
77            future: Box::pin(resolve(conn, local_node_type, future)),
78        }
79    }
80}
81
82impl<T, E> Future for NetTryJoin<'_, T, E> {
83    type Output = Result<NetTryJoinResult<T, E>, anyhow::Error>;
84
85    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86        self.future.as_mut().poll(cx)
87    }
88}
89
/// Outcome of a `NetTryJoin`, as seen by one node.
#[derive(Debug)]
pub struct NetTryJoinResult<T, E> {
    /// This node's *own* future result:
    /// - `Some(Ok(_))`: the local future succeeded. A node whose future succeeded also gets
    ///   this when the remote side failed, so on its own it does not prove remote success.
    /// - `Some(Err(_))`: the local future failed.
    /// - `None`: the exchange was resolved as a failure before the local future produced a
    ///   value.
    pub value: Option<Result<T, E>>,
}
94
95#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
96enum State {
97    Pending,
98    ObtainedValidResult,
99    Resolved,
100    ResolvedBothFail,
101    NonPreferredFinished,
102    Error,
103    // None if not finished, false if errored, true if success
104    Pinging(Option<bool>),
105}
106
107impl State {
108    /// assumes this is called by the receiving node, not the node that creates the state
109    fn implies_success(&self) -> bool {
110        matches!(self, Self::ObtainedValidResult | Self::Pinging(Some(true)))
111    }
112
113    fn implies_failure(&self) -> bool {
114        matches!(self, Self::Error | Self::Pinging(Some(false)))
115    }
116}
117
118async fn resolve<
119    S: Subscribable<ID = K, UnderlyingConn = Conn>,
120    K: MultiplexedConnKey,
121    Conn: ReliableOrderedStreamToTarget + 'static,
122    F,
123    T,
124    E,
125>(
126    conn: &S,
127    local_node_type: RelativeNodeType,
128    future: F,
129) -> Result<NetTryJoinResult<T, E>, anyhow::Error>
130where
131    F: Future<Output = Result<T, E>>,
132{
133    let conn = &(conn.initiate_subscription().await?);
134    log::trace!(target: "citadel", "NET_TRY_JOIN started conv={:?} for {:?}", conn.id(), local_node_type);
135    let (stopper_tx, stopper_rx) = citadel_io::tokio::sync::oneshot::channel::<()>();
136
137    struct LocalState<T, E> {
138        local_state: State,
139        ret_value: Option<Result<T, E>>,
140    }
141
142    let local_state = LocalState {
143        local_state: State::Pending,
144        ret_value: None,
145    };
146    let local_state_ref = &Mutex::new(local_state);
147
148    let has_preference = local_node_type == RelativeNodeType::Initiator;
149
150    // the evaluator finishes before the "completer" if this goes successfully
151    let evaluator = async move {
152        let _stopper_tx = stopper_tx;
153
154        async fn return_sequence<Conn: ReliableOrderedStreamToTarget, T, E>(
155            conn: &Conn,
156            new_state: State,
157            mut state: MutexGuard<'_, LocalState<T, E>>,
158        ) -> Result<Option<Result<T, E>>, anyhow::Error> {
159            state.local_state = new_state.clone();
160            conn.send_serialized(new_state.clone()).await?;
161            Ok(state.ret_value.take())
162        }
163
164        loop {
165            let received_remote_state = conn.recv_serialized::<State>().await?;
166            //log::trace!(target: "citadel", "{:?} RECV'd {:?}", local_node_type, &received_remote_state);
167            let mut lock = local_state_ref.lock().await;
168            let local_state_info = lock.ret_value.as_ref().map(|r| r.is_ok());
169            log::trace!(target: "citadel", "[conv={:?} Node {:?} recv {:?} || Local state: {:?}", conn.id(), local_node_type, received_remote_state, lock.local_state);
170            if has_preference {
171                // if local has preference, we have the permission to evaluate
172                // first, check to make sure local hasn't already obtained a value
173                if received_remote_state.implies_failure() || lock.local_state.implies_failure() {
174                    // If ANY node fails in a TryJoin, we have a global failure
175                    return return_sequence(conn, State::ResolvedBothFail, lock).await;
176                }
177
178                // at this point, neither imply failure
179                if received_remote_state.implies_success() && lock.local_state.implies_success() {
180                    return return_sequence(conn, State::Resolved, lock).await;
181                }
182
183                // neither imply failure, AND, neither imply success. This means we need to ping until either one of those conditions becomes true
184                conn.send_serialized(State::Pinging(local_state_info))
185                    .await?;
186            } else {
187                // if not, we cannot evaluate UNLESS we are being told that we resolved
188                match received_remote_state {
189                    State::Resolved => {
190                        // remote is telling us we both won
191                        lock.local_state = State::Resolved;
192                        return Ok(lock.ret_value.take());
193                    }
194
195                    State::ResolvedBothFail => {
196                        // both nodes failed
197                        return Ok(lock.ret_value.take());
198                    }
199
200                    _ => {
201                        // even in the case of an error, or simply an acknowledgement that the adjacent side succeeded, we need to let remote determine what to do. Just ping
202                        //std::mem::drop(lock);
203                        conn.send_serialized(State::Pinging(local_state_info))
204                            .await?;
205                    }
206                }
207            }
208        }
209    };
210
211    // racer should never finish first
212    let completer = async move {
213        // both sides start this function
214        let res = future.await;
215        let mut local_state = local_state_ref.lock().await;
216
217        let state = res
218            .as_ref()
219            .map(|_| State::ObtainedValidResult)
220            .unwrap_or(State::Error);
221
222        // we don't check the local state because the resolution would terminate this task anyways
223        //log::trace!(target: "citadel", "[NetRacer] {:?} Old state: {:?} | New state: {:?}", local_node_type, &local_state.local_state, &state);
224
225        local_state.local_state = state.clone();
226        local_state.ret_value = Some(res);
227
228        // now, send a packet to the other side
229        conn.send_serialized(state).await?;
230        std::mem::drop(local_state);
231        //log::trace!(target: "citadel", "[NetRacer] {:?} completer done", local_node_type);
232
233        stopper_rx.await?;
234        Err(anyhow::Error::msg("Stopped before the resolver"))
235    };
236
237    citadel_io::tokio::select! {
238        res0 = evaluator => {
239            log::trace!(target: "citadel", "NET_TRY_JOIN ending for {:?} (conv={:?})", local_node_type, conn.id());
240            let ret = res0?;
241            wrap_return(ret)
242        },
243
244        res1 = completer => res1
245    }
246}
247
248fn wrap_return<T, E>(value: Option<Result<T, E>>) -> Result<NetTryJoinResult<T, E>, anyhow::Error> {
249    Ok(NetTryJoinResult { value })
250}
251
#[cfg(test)]
mod tests {
    use crate::sync::network_application::NetworkApplication;
    use crate::sync::test_utils::create_streams;
    use citadel_io::tokio;
    use std::fmt::Debug;
    use std::future::Future;
    use std::time::Duration;

    /// How many times each outcome combination is run over the same pair of streams.
    const COUNT: i32 = 10;

    /// Exercises every combination of local success/failure on the server and client.
    #[tokio::test]
    async fn try_join_outcomes() {
        citadel_logging::setup_log();

        let (server_stream, client_stream) = create_streams().await;

        // (label, server future succeeds, client future succeeds)
        let cases = [
            ("ERR:ERR", false, false),
            ("OK:OK", true, true),
            ("ERR:OK", false, true),
            ("OK:ERR", true, false),
        ];

        for &(label, server_ok, client_ok) in cases.iter() {
            for idx in 0..COUNT {
                log::trace!(target: "citadel", "[Meta] {} ({}/{})", label, idx, COUNT);
                inner(
                    server_stream.clone(),
                    client_stream.clone(),
                    dummy_function(server_ok),
                    dummy_function(client_ok),
                    server_ok,
                    client_ok,
                )
                .await;
            }
        }
    }

    /// Runs one `net_try_join` round (server on `conn0`, client on `conn1`) and checks the
    /// per-node results. `server_ok` / `client_ok` state whether `fx_1` / `fx_2` succeed.
    async fn inner<
        R: Send + Debug + 'static,
        F: Future<Output = Result<R, &'static str>> + Send + 'static,
        Y: Future<Output = Result<R, &'static str>> + Send + 'static,
    >(
        conn0: NetworkApplication,
        conn1: NetworkApplication,
        fx_1: F,
        fx_2: Y,
        server_ok: bool,
        client_ok: bool,
    ) {
        let server = citadel_io::tokio::spawn(async move {
            let res = conn0.net_try_join(fx_1).await.unwrap();
            log::trace!(target: "citadel", "Server res: {:?}", res.value);
            res
        });

        let client = citadel_io::tokio::spawn(async move {
            let res = conn1.net_try_join(fx_2).await.unwrap();
            log::trace!(target: "citadel", "Client res: {:?}", res.value);
            res
        });

        let (res0, res1) = citadel_io::tokio::join!(server, client);
        let (res0, res1) = (res0.unwrap(), res1.unwrap());

        if server_ok && client_ok {
            // Joint success: both nodes must hand back their own `Ok` value.
            assert!(matches!(res0.value, Some(Ok(_))), "server: {:?}", res0.value);
            assert!(matches!(res1.value, Some(Ok(_))), "client: {:?}", res1.value);
        } else {
            // A node whose own future failed must never report `Ok`: it reports its error,
            // or `None` if the exchange was resolved before its future finished.
            if !server_ok {
                assert!(!matches!(res0.value, Some(Ok(_))), "server: {:?}", res0.value);
            }
            if !client_ok {
                assert!(!matches!(res1.value, Some(Ok(_))), "client: {:?}", res1.value);
            }
        }

        log::trace!(target: "citadel", "DONE executing")
    }

    /// Resolves to `Ok(())` after a short delay when `succeed` is true, otherwise fails
    /// immediately with `Err("Error")`.
    async fn dummy_function(succeed: bool) -> Result<(), &'static str> {
        if succeed {
            citadel_io::tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(())
        } else {
            Err("Error")
        }
    }
}