1use crate::reliable_conn::{ReliableOrderedStreamToTarget, ReliableOrderedStreamToTargetExt};
36use crate::sync::network_application::{PostActionChannel, PreActionChannel, INITIAL_CAPACITY};
37use crate::sync::subscription::{
38 close_sequence_for_multiplexed_bistream, Subscribable, SubscriptionBiStream,
39};
40use crate::sync::{RelativeNodeType, SymmetricConvID};
41use anyhow::Error;
42use async_trait::async_trait;
43use citadel_io::tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
44use citadel_io::tokio::sync::Mutex;
45use citadel_io::RwLock;
46use serde::de::DeserializeOwned;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use std::fmt::Debug;
50use std::hash::Hash;
51use std::ops::Deref;
52use std::sync::atomic::{AtomicU64, Ordering};
53use std::sync::Arc;
54
55pub trait MultiplexedConnKey:
57 Debug + Eq + Hash + Copy + Send + Sync + Serialize + DeserializeOwned + IDGen<Self>
58{
59}
60impl<T: Debug + Eq + Hash + Copy + Send + Sync + Serialize + DeserializeOwned + IDGen<Self>>
61 MultiplexedConnKey for T
62{
63}
64
65pub trait IDGen<Key: MultiplexedConnKey> {
67 type Container: Send + Sync;
69 fn generate_container() -> Self::Container;
71 fn generate_next(container: &Self::Container) -> Self;
73 fn get_proposed_next(container: &Self::Container) -> Key;
75}
76
77impl IDGen<SymmetricConvID> for SymmetricConvID {
78 type Container = Arc<AtomicU64>;
79
80 fn generate_container() -> Self::Container {
81 Arc::new(AtomicU64::new(0))
82 }
83
84 fn generate_next(container: &Self::Container) -> SymmetricConvID {
85 (1 + container.fetch_add(1, Ordering::Relaxed)).into()
86 }
87
88 fn get_proposed_next(container: &Self::Container) -> SymmetricConvID {
89 (1 + container.load(Ordering::Relaxed)).into()
90 }
91}
92
93pub struct MultiplexedConn<K: MultiplexedConnKey = SymmetricConvID> {
95 inner: Arc<MultiplexedConnInner<K>>,
96}
97
98pub struct MultiplexedConnInner<K: MultiplexedConnKey> {
100 pub(crate) conn: Arc<dyn ReliableOrderedStreamToTarget>,
102 subscribers: RwLock<HashMap<K, MemorySender>>,
104 pre_open_container: PreActionChannel<K>,
106 post_close_container: PostActionChannel<K>,
108 id_gen: K::Container,
110 current_latest_subscribed: K::Container,
112 node_type: RelativeNodeType,
114}
115
116pub struct MemorySender {
118 tx: UnboundedSender<Vec<u8>>,
120 pre_reserved_rx: Option<UnboundedReceiver<Vec<u8>>>,
122}
123
124impl Deref for MemorySender {
125 type Target = UnboundedSender<Vec<u8>>;
126
127 fn deref(&self) -> &Self::Target {
128 &self.tx
129 }
130}
131
132impl<K: MultiplexedConnKey> Deref for MultiplexedConn<K> {
133 type Target = MultiplexedConnInner<K>;
134
135 fn deref(&self) -> &Self::Target {
136 self.inner.deref()
137 }
138}
139
140#[derive(Serialize, Deserialize)]
142#[serde(bound = "")]
143pub(crate) enum MultiplexedPacket<K: MultiplexedConnKey> {
144 ApplicationLayer { id: K, payload: Vec<u8> },
146 PostDrop { id: K },
148 PreCreate { id: K },
150 Greeter,
152}
153
154impl<K: MultiplexedConnKey> MultiplexedConn<K> {
155 pub fn new<T: ReliableOrderedStreamToTarget + 'static>(
157 node_type: RelativeNodeType,
158 conn: T,
159 ) -> Self {
160 let id_gen = K::generate_container();
161 let ids: Vec<K> = (0..INITIAL_CAPACITY)
162 .map(|_| <K as IDGen<K>>::generate_next(&id_gen))
163 .collect();
164 let post_close_container = PostActionChannel::new(&ids);
166 let mut subscribers = HashMap::new();
167
168 for id in ids {
169 let (tx, pre_reserved_rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
170 subscribers.insert(
171 id,
172 MemorySender {
173 tx,
174 pre_reserved_rx: Some(pre_reserved_rx),
175 },
176 );
177 }
178
179 let current_latest_subscribed = K::generate_container();
180
181 Self {
182 inner: Arc::new(MultiplexedConnInner {
183 conn: Arc::new(conn),
184 subscribers: RwLock::new(subscribers),
185 pre_open_container: PreActionChannel::new(),
186 post_close_container,
187 current_latest_subscribed,
188 id_gen,
189 node_type,
190 }),
191 }
192 }
193}
194
195impl<K: MultiplexedConnKey> Clone for MultiplexedConn<K> {
196 fn clone(&self) -> Self {
197 Self {
198 inner: self.inner.clone(),
199 }
200 }
201}
202
203pub struct MultiplexedSubscription<'a, K: MultiplexedConnKey = SymmetricConvID> {
205 ptr: &'a MultiplexedConn<K>,
207 receiver: Option<Mutex<UnboundedReceiver<Vec<u8>>>>,
209 id: K,
211}
212
213impl<K: MultiplexedConnKey> SubscriptionBiStream for MultiplexedSubscription<'_, K> {
214 type Conn = Arc<dyn ReliableOrderedStreamToTarget + 'static>;
215 type ID = K;
216
217 fn conn(&self) -> &Self::Conn {
218 &self.ptr.conn
219 }
220
221 fn receiver(&self) -> &Mutex<UnboundedReceiver<Vec<u8>>> {
222 self.receiver.as_ref().unwrap()
223 }
224
225 fn id(&self) -> Self::ID {
226 self.id
227 }
228
229 fn node_type(&self) -> RelativeNodeType {
230 self.ptr.node_type
231 }
232}
233
234impl<K: MultiplexedConnKey> From<MultiplexedSubscription<'_, K>>
235 for OwnedMultiplexedSubscription<K>
236{
237 fn from(mut this: MultiplexedSubscription<'_, K>) -> Self {
238 let ret = Self {
239 ptr: this.ptr.clone(),
240 receiver: this.receiver.take().unwrap(),
241 id: this.id,
242 };
243
244 std::mem::forget(this);
246 ret
247 }
248}
249
250pub struct OwnedMultiplexedSubscription<K: MultiplexedConnKey + 'static = SymmetricConvID> {
252 ptr: MultiplexedConn<K>,
254 receiver: Mutex<UnboundedReceiver<Vec<u8>>>,
256 id: K,
258}
259
260impl<K: MultiplexedConnKey> SubscriptionBiStream for OwnedMultiplexedSubscription<K> {
261 type Conn = Arc<dyn ReliableOrderedStreamToTarget + 'static>;
262 type ID = K;
263
264 fn conn(&self) -> &Self::Conn {
265 &self.ptr.conn
266 }
267
268 fn receiver(&self) -> &Mutex<UnboundedReceiver<Vec<u8>>> {
269 &self.receiver
270 }
271
272 fn id(&self) -> Self::ID {
273 self.id
274 }
275
276 fn node_type(&self) -> RelativeNodeType {
277 self.ptr.node_type
278 }
279}
280
281#[async_trait]
282impl<K: MultiplexedConnKey + 'static> Subscribable for MultiplexedConn<K> {
283 type ID = K;
284 type UnderlyingConn = Arc<dyn ReliableOrderedStreamToTarget + 'static>;
285 type SubscriptionType = OwnedMultiplexedSubscription<K>;
286 type BorrowedSubscriptionType = OwnedMultiplexedSubscription<K>;
288
289 fn underlying_conn(&self) -> &Self::UnderlyingConn {
290 &self.conn
291 }
292
293 fn subscriptions(&self) -> &RwLock<HashMap<Self::ID, MemorySender>> {
294 &self.subscribers
295 }
296
297 fn post_close_container(&self) -> &PostActionChannel<Self::ID> {
298 &self.post_close_container
299 }
300
301 fn pre_action_container(&self) -> &PreActionChannel<Self::ID> {
302 &self.pre_open_container
303 }
304
305 async fn recv_post_close_signal_from_stream(&self, id: Self::ID) -> Result<(), Error> {
306 self.post_close_container.recv(id).await
307 }
308
309 async fn send_post_close_signal(&self, id: Self::ID) -> Result<(), Error> {
310 Ok(self
311 .conn
312 .send_serialized(MultiplexedPacket::PostDrop { id })
313 .await?)
314 }
315
316 async fn send_pre_open_signal(&self, id: Self::ID) -> Result<(), Error> {
317 Ok(self
318 .conn
319 .send_serialized(MultiplexedPacket::PreCreate { id })
320 .await?)
321 }
322
323 fn node_type(&self) -> RelativeNodeType {
324 self.node_type
325 }
326
327 fn get_next_prereserved(&self) -> Option<Self::BorrowedSubscriptionType> {
328 let mut lock = self.subscribers.write();
329 let next_key = K::get_proposed_next(&self.current_latest_subscribed);
330 let pre_reserved_stream = lock.get_mut(&next_key)?;
331 let sub = MultiplexedSubscription {
332 ptr: self,
333 receiver: Some(Mutex::new(pre_reserved_stream.pre_reserved_rx.take()?)),
334 id: next_key,
335 };
336 assert_eq!(K::generate_next(&self.current_latest_subscribed), next_key);
337 Some(sub.into())
338 }
339
340 fn subscribe(&self, id: Self::ID) -> Self::BorrowedSubscriptionType {
341 let mut lock = self.subscribers.write();
342 let (tx, receiver) = unbounded_channel();
343 let sub = MultiplexedSubscription {
344 ptr: self,
345 receiver: Some(Mutex::new(receiver)),
346 id,
347 };
348 assert!(lock
349 .insert(
350 id,
351 MemorySender {
352 tx,
353 pre_reserved_rx: None
354 }
355 )
356 .is_none());
357 assert_eq!(K::generate_next(&self.current_latest_subscribed), id);
358 sub.into()
360 }
361
362 fn owned_subscription(&self, id: Self::ID) -> Self::SubscriptionType {
363 self.subscribe(id)
364 }
365
366 fn get_next_id(&self) -> Self::ID {
367 <K as IDGen<K>>::generate_next(&self.id_gen)
368 }
369}
370
371impl<K: MultiplexedConnKey + 'static> Drop for OwnedMultiplexedSubscription<K> {
372 fn drop(&mut self) {
373 close_sequence_for_multiplexed_bistream(self.id, self.ptr.clone())
374 }
375}
376
#[cfg(test)]
mod tests {
    use crate::multiplex::OwnedMultiplexedSubscription;
    use crate::reliable_conn::ReliableOrderedStreamToTargetExt;
    use crate::sync::network_application::NetworkApplication;
    use crate::sync::subscription::{Subscribable, SubscriptionBiStreamExt};
    use crate::sync::test_utils::create_streams;
    use crate::sync::SymmetricConvID;
    use async_recursion::async_recursion;
    use citadel_io::tokio;
    use serde::{Deserialize, Serialize};

    /// Minimal payload exchanged between the two sides of a subscription.
    #[derive(Serialize, Deserialize)]
    struct Packet(usize);

    /// Repeatedly multiplexes a subscription of the previous level, 50 levels
    /// deep, exchanging a packet at every level.
    #[tokio::test]
    async fn nested_multiplexed_stream() {
        let (outer_stream_server, outer_stream_client) = create_streams().await;
        nested(0, 50, outer_stream_server, outer_stream_client).await;
    }

    /// Runs one nesting level and recurses until `idx == max`.
    ///
    /// At each level two subscriptions are opened concurrently on the same
    /// pair of streams. On each, the server sends a packet, the client checks
    /// it, and both sides then multiplex the subscription. The streams
    /// produced by the first pair become the inputs of the next level.
    ///
    /// Returns the server and client streams of the deepest level reached.
    #[async_recursion]
    async fn nested(
        idx: usize,
        max: usize,
        server_stream: NetworkApplication,
        client_stream: NetworkApplication,
    ) -> (NetworkApplication, NetworkApplication) {
        if idx == max {
            return (server_stream, client_stream);
        }

        let (tx, rx) = citadel_io::tokio::sync::oneshot::channel::<()>();
        let (server_stream0, client_stream0) = (server_stream.clone(), client_stream.clone());

        let server = citadel_io::tokio::spawn(async move {
            let next_stream: OwnedMultiplexedSubscription =
                server_stream.initiate_subscription().await.unwrap();
            next_stream.send_serialized(Packet(idx)).await.unwrap();
            // Wait until the client has verified the packet.
            rx.await.unwrap();
            next_stream.multiplex::<SymmetricConvID>().await.unwrap()
        });

        let client = citadel_io::tokio::spawn(async move {
            let next_stream: OwnedMultiplexedSubscription =
                client_stream.initiate_subscription().await.unwrap();
            let val = next_stream.recv_serialized::<Packet>().await.unwrap();
            assert_eq!(val.0, idx);
            tx.send(()).unwrap();
            next_stream.multiplex::<SymmetricConvID>().await.unwrap()
        });

        let (tx1, rx1) = citadel_io::tokio::sync::oneshot::channel::<()>();

        let server1 = citadel_io::tokio::spawn(async move {
            let next_stream: OwnedMultiplexedSubscription =
                server_stream0.initiate_subscription().await.unwrap();
            next_stream.send_serialized(Packet(idx + 10)).await.unwrap();
            // Wait until the client has verified the packet.
            rx1.await.unwrap();
            next_stream.multiplex::<SymmetricConvID>().await.unwrap()
        });

        let client1 = citadel_io::tokio::spawn(async move {
            let next_stream: OwnedMultiplexedSubscription =
                client_stream0.initiate_subscription().await.unwrap();
            let val = next_stream.recv_serialized::<Packet>().await.unwrap();
            assert_eq!(val.0, idx + 10);
            tx1.send(()).unwrap();
            next_stream.multiplex::<SymmetricConvID>().await.unwrap()
        });

        let (next_server_stream, next_client_stream, server1_result, client1_result) =
            citadel_io::tokio::join!(server, client, server1, client1);

        // A panic inside a spawned task is only reported through its join
        // handle. Check the second pair as well so that a failed assertion or
        // unwrap there fails the test instead of being silently discarded.
        // The resulting streams are not needed and are dropped right away.
        drop(server1_result.expect("second server task panicked"));
        drop(client1_result.expect("second client task panicked"));

        nested(
            idx + 1,
            max,
            next_server_stream.unwrap(),
            next_client_stream.unwrap(),
        )
        .await
    }
}