1use crate::multiplex::MultiplexedConnKey;
50use crate::reliable_conn::{ReliableOrderedStreamToTarget, ReliableOrderedStreamToTargetExt};
51use crate::sync::subscription::{Subscribable, SubscriptionBiStream};
52use crate::sync::RelativeNodeType;
53use crate::ScopedFutureResult;
54use citadel_io::tokio::sync::{Mutex, MutexGuard};
55use serde::{Deserialize, Serialize};
56use std::future::Future;
57use std::pin::Pin;
58use std::task::{Context, Poll};
59
/// A future that runs a local fallible future while coordinating with a remote
/// peer, resolving to a [`NetTryJoinResult`] once the two sides have agreed on the
/// joint outcome (both succeeded, or at least one failed).
pub struct NetTryJoin<'a, T, E> {
    /// The boxed coordination future built by `resolve`; `poll` simply delegates to it.
    future: ScopedFutureResult<'a, NetTryJoinResult<T, E>>,
}
64
65impl<'a, T: Send + 'a, E: Send + 'a> NetTryJoin<'a, T, E> {
66 pub fn new<
67 S: Subscribable<ID = K, UnderlyingConn = Conn>,
68 K: MultiplexedConnKey + 'a,
69 Conn: ReliableOrderedStreamToTarget + 'static,
70 F: Future<Output = Result<T, E>> + Send + 'a,
71 >(
72 conn: &'a S,
73 local_node_type: RelativeNodeType,
74 future: F,
75 ) -> NetTryJoin<'a, T, E> {
76 Self {
77 future: Box::pin(resolve(conn, local_node_type, future)),
78 }
79 }
80}
81
82impl<T, E> Future for NetTryJoin<'_, T, E> {
83 type Output = Result<NetTryJoinResult<T, E>, anyhow::Error>;
84
85 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86 self.future.as_mut().poll(cx)
87 }
88}
89
/// Outcome of a [`NetTryJoin`] as observed on this node.
#[derive(Debug)]
pub struct NetTryJoinResult<T, E> {
    /// This node's own future output if it completed before the join was resolved,
    /// or `None` if resolution happened first (e.g. the peer reported a failure
    /// early). This is the *local* result only: it can be `Some(Ok(_))` even when
    /// the joint outcome was a failure.
    pub value: Option<Result<T, E>>,
}
94
/// Protocol messages exchanged by the two nodes during a `NetTryJoin`; the same
/// type also tracks each node's local progress.
///
/// NOTE(review): this enum is serialized on the wire, so both peers must agree on
/// its definition; reordering variants may change the encoding for index-based
/// serde formats.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
enum State {
    /// Initial local state, before the local future has completed.
    Pending,
    /// Announced by a node whose future completed with `Ok`.
    ObtainedValidResult,
    /// Terminal verdict sent by the initiator: both sides succeeded.
    Resolved,
    /// Terminal verdict sent by the initiator: at least one side failed.
    ResolvedBothFail,
    /// Not used anywhere in the code visible here — NOTE(review): confirm whether
    /// it is still needed.
    NonPreferredFinished,
    /// Announced by a node whose future completed with `Err`.
    Error,
    /// Status ping: `None` while the sender's future is still pending, otherwise
    /// `Some(is_ok)` for its completed result.
    Pinging(Option<bool>),
}
106
107impl State {
108 fn implies_success(&self) -> bool {
110 matches!(self, Self::ObtainedValidResult | Self::Pinging(Some(true)))
111 }
112
113 fn implies_failure(&self) -> bool {
114 matches!(self, Self::Error | Self::Pinging(Some(false)))
115 }
116}
117
/// Runs `future` locally while coordinating with the peer on the other end of
/// `conn`, so that both sides agree on whether the joint operation succeeded.
///
/// All protocol messages travel over a fresh stream obtained from
/// `conn.initiate_subscription()`. Two futures are then raced with `select!`:
///
/// * the **completer** awaits `future`, stores its result in the shared local
///   state, marks the local state `ObtainedValidResult` or `Error`, and announces
///   that state to the peer;
/// * the **evaluator** reads peer states in a loop. Only the `Initiator` side
///   (`has_preference`) decides the outcome. It answers `ResolvedBothFail` as soon
///   as either side reports failure, answers `Resolved` once both report success,
///   and otherwise pings its own status. The other side keeps echoing its status
///   until it receives one of those two verdicts.
///
/// # Returns
/// `NetTryJoinResult { value }`. Here `value` is this node's own result if its
/// future finished before the join was resolved, otherwise `None`. It is returned
/// regardless of the agreed verdict.
///
/// # Errors
/// Propagates errors from opening the subscription and from sending or receiving
/// protocol messages.
async fn resolve<
    S: Subscribable<ID = K, UnderlyingConn = Conn>,
    K: MultiplexedConnKey,
    Conn: ReliableOrderedStreamToTarget + 'static,
    F,
    T,
    E,
>(
    conn: &S,
    local_node_type: RelativeNodeType,
    future: F,
) -> Result<NetTryJoinResult<T, E>, anyhow::Error>
where
    F: Future<Output = Result<T, E>>,
{
    // From here on `conn` is the dedicated subscription stream, not the caller's connection.
    let conn = &(conn.initiate_subscription().await?);
    log::trace!(target: "citadel", "NET_TRY_JOIN started conv={:?} for {:?}", conn.id(), local_node_type);
    // `stopper_tx` is moved into the evaluator and nothing ever sends on it; the
    // completer's `stopper_rx` can only observe it being dropped.
    let (stopper_tx, stopper_rx) = citadel_io::tokio::sync::oneshot::channel::<()>();

    // State shared (behind a mutex) between the evaluator and the completer.
    struct LocalState<T, E> {
        // This node's latest protocol state.
        local_state: State,
        // The local future's output once it has completed; taken on resolution.
        ret_value: Option<Result<T, E>>,
    }

    let local_state = LocalState {
        local_state: State::Pending,
        ret_value: None,
    };
    let local_state_ref = &Mutex::new(local_state);

    // Only the initiator decides the joint outcome; the other side follows its verdict.
    let has_preference = local_node_type == RelativeNodeType::Initiator;

    let evaluator = async move {
        // Keep the sender alive for exactly as long as the evaluator runs.
        let _stopper_tx = stopper_tx;

        // Records `new_state` locally, sends it to the peer, and hands back this
        // node's result (`None` if the local future has not finished yet).
        async fn return_sequence<Conn: ReliableOrderedStreamToTarget, T, E>(
            conn: &Conn,
            new_state: State,
            mut state: MutexGuard<'_, LocalState<T, E>>,
        ) -> Result<Option<Result<T, E>>, anyhow::Error> {
            state.local_state = new_state.clone();
            conn.send_serialized(new_state.clone()).await?;
            Ok(state.ret_value.take())
        }

        loop {
            let received_remote_state = conn.recv_serialized::<State>().await?;
            let mut lock = local_state_ref.lock().await;
            // `Some(is_ok)` once the local future has finished, `None` while it is pending.
            let local_state_info = lock.ret_value.as_ref().map(|r| r.is_ok());
            log::trace!(target: "citadel", "[conv={:?} Node {:?} recv {:?} || Local state: {:?}", conn.id(), local_node_type, received_remote_state, lock.local_state);
            if has_preference {
                // A failure on either side fails the whole join.
                if received_remote_state.implies_failure() || lock.local_state.implies_failure() {
                    return return_sequence(conn, State::ResolvedBothFail, lock).await;
                }

                // Both sides have reported success.
                if received_remote_state.implies_success() && lock.local_state.implies_success() {
                    return return_sequence(conn, State::Resolved, lock).await;
                }

                // Still undecided: report our status; the peer answers with its own.
                conn.send_serialized(State::Pinging(local_state_info))
                    .await?;
            } else {
                match received_remote_state {
                    State::Resolved => {
                        lock.local_state = State::Resolved;
                        return Ok(lock.ret_value.take());
                    }

                    State::ResolvedBothFail => {
                        return Ok(lock.ret_value.take());
                    }

                    // Any non-terminal message: echo our status back to the initiator.
                    _ => {
                        conn.send_serialized(State::Pinging(local_state_info))
                            .await?;
                    }
                }
            }
        }
    };

    let completer = async move {
        let res = future.await;

        let mut local_state = local_state_ref.lock().await;

        let state = res
            .as_ref()
            .map(|_| State::ObtainedValidResult)
            .unwrap_or(State::Error);

        local_state.local_state = state.clone();
        local_state.ret_value = Some(res);

        // The lock is still held across this send, so the evaluator cannot observe
        // the new local state until the announcement has been sent.
        conn.send_serialized(state).await?;

        std::mem::drop(local_state);
        // Nothing sends on the oneshot, so this only completes (with an error) once
        // the evaluator drops `_stopper_tx`. In practice the completer parks here
        // while the evaluator's `select!` branch finishes the join.
        stopper_rx.await?;
        Err(anyhow::Error::msg("Stopped before the resolver"))
    };

    // Every return path of the completer is an `Err`, so a successful join always
    // comes through the evaluator branch.
    citadel_io::tokio::select! {
        res0 = evaluator => {
            log::trace!(target: "citadel", "NET_TRY_JOIN ending for {:?} (conv={:?})", local_node_type, conn.id());
            let ret = res0?;
            wrap_return(ret)
        },

        res1 = completer => res1
    }
}
247
248fn wrap_return<T, E>(value: Option<Result<T, E>>) -> Result<NetTryJoinResult<T, E>, anyhow::Error> {
249 Ok(NetTryJoinResult { value })
250}
251
#[cfg(test)]
mod tests {
    use crate::sync::network_application::NetworkApplication;
    use crate::sync::test_utils::create_streams;
    use citadel_io::tokio;
    use std::fmt::Debug;
    use std::future::Future;
    use std::time::Duration;

    /// Repeats every ok/err combination of the two sides `COUNT` times over the
    /// same pair of streams, checking the joint outcome of each run.
    #[tokio::test]
    async fn racer() {
        citadel_logging::setup_log();

        let (server_stream, client_stream) = create_streams().await;
        const COUNT: i32 = 10;

        // Both sides fail.
        for idx in 0..COUNT {
            log::trace!(target: "citadel", "[Meta] ERR:ERR ({}/{})", idx, COUNT);
            let (server, client) = (server_stream.clone(), client_stream.clone());
            inner(server, client, dummy_function_err(), dummy_function_err(), false).await;
        }

        // Both sides succeed.
        for idx in 0..COUNT {
            log::trace!(target: "citadel", "[Meta] OK:OK ({}/{})", idx, COUNT);
            let (server, client) = (server_stream.clone(), client_stream.clone());
            inner(server, client, dummy_function(), dummy_function(), true).await;
        }

        // Only the server side fails.
        for idx in 0..COUNT {
            log::trace!(target: "citadel", "[Meta] ERR:OK ({}/{})", idx, COUNT);
            let (server, client) = (server_stream.clone(), client_stream.clone());
            inner(server, client, dummy_function_err(), dummy_function(), false).await;
        }

        // Only the client side fails.
        for idx in 0..COUNT {
            log::trace!(target: "citadel", "[Meta] OK:ERR ({}/{})", idx, COUNT);
            let (server, client) = (server_stream.clone(), client_stream.clone());
            inner(server, client, dummy_function(), dummy_function_err(), false).await;
        }
    }

    /// Runs `net_try_join` concurrently on both ends (`fx_1` over `conn0`, `fx_2`
    /// over `conn1`) and checks the outcome. When `success` is set, both sides must
    /// return `Some(Ok(_))`. Otherwise at least one side must return an error or no
    /// value.
    async fn inner<
        R: Send + Debug + 'static,
        F: Future<Output = Result<R, &'static str>> + Send + 'static,
        Y: Future<Output = Result<R, &'static str>> + Send + 'static,
    >(
        conn0: NetworkApplication,
        conn1: NetworkApplication,
        fx_1: F,
        fx_2: Y,
        success: bool,
    ) {
        let server_task = citadel_io::tokio::spawn(async move {
            let outcome = conn0.net_try_join(fx_1).await.unwrap();
            log::trace!(target: "citadel", "Server res: {:?}", outcome.value);
            outcome
        });

        let client_task = citadel_io::tokio::spawn(async move {
            let outcome = conn1.net_try_join(fx_2).await.unwrap();
            log::trace!(target: "citadel", "Client res: {:?}", outcome);
            outcome
        });

        let (server_joined, client_joined) = citadel_io::tokio::join!(server_task, client_task);

        log::trace!(target: "citadel", "Unwrapping ....");

        let server_res = server_joined.unwrap();
        let client_res = client_joined.unwrap();

        log::trace!(target: "citadel", "Done unwrapping");
        if success {
            let server_ok = server_res.value.unwrap().is_ok();
            assert!(server_ok && client_res.value.unwrap().is_ok());
        } else {
            let server_failed = !matches!(server_res.value, Some(Ok(_)));
            assert!(server_failed || !matches!(client_res.value, Some(Ok(_))));
        }

        log::trace!(target: "citadel", "DONE executing");
    }

    /// Succeeds after a 50 ms delay.
    async fn dummy_function() -> Result<(), &'static str> {
        citadel_io::tokio::time::sleep(Duration::from_millis(50)).await;
        Ok(())
    }

    /// Fails immediately.
    async fn dummy_function_err() -> Result<(), &'static str> {
        Err("Error")
    }
}