1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, RwLock,
    },
    task::{Context, Poll},
};

#[derive(Clone, Debug)]
pub struct Sender<T: Clone> {
    is_sent: Arc<AtomicBool>,
    value: Arc<RwLock<Option<T>>>,
}

impl<T: Clone> Sender<T> {
    pub fn send(&self, value: T) {
        if !self.is_sent.load(Ordering::Relaxed) {
            self.value.write().unwrap().replace(value);
            self.is_sent.store(true, Ordering::Relaxed);
        }
    }

    pub fn subscribe(&self) -> Receiver<T> {
        Receiver {
            is_sent: self.is_sent.clone(),
            value: self.value.clone(),
        }
    }
}

#[derive(Clone, Debug)]
pub struct Receiver<T: Clone> {
    is_sent: Arc<AtomicBool>,
    value: Arc<RwLock<Option<T>>>,
}

impl<T: Clone> Receiver<T> {
    pub fn recv(&self) -> ReceiverWaiter<T> {
        ReceiverWaiter {
            is_sent: self.is_sent.clone(),
            value: self.value.clone(),
        }
    }
}

pub struct ReceiverWaiter<T: Clone> {
    is_sent: Arc<AtomicBool>,
    value: Arc<RwLock<Option<T>>>,
}

impl<T: Clone> Future for ReceiverWaiter<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.is_sent.load(Ordering::Relaxed) {
            let a = self.get_mut().value.read().unwrap().clone().unwrap();
            return Poll::Ready(a);
        }

        cx.waker().wake_by_ref();

        Poll::Pending
    }
}

pub fn channel<T: Clone>() -> (Sender<T>, Receiver<T>) {
    let is_sent = Arc::new(AtomicBool::new(false));
    let value = Arc::new(RwLock::new(None));

    (
        Sender {
            is_sent: is_sent.clone(),
            value: value.clone(),
        },
        Receiver {
            is_sent: is_sent.clone(),
            value: value.clone(),
        },
    )
}

#[cfg(test)]
mod tests {
    use tokio::join;

    #[tokio::test]
    async fn test() {
        let (tx, rx1) = super::channel::<bool>();

        let rx2 = tx.subscribe();
        let rx3 = tx.subscribe();

        let a = tokio::spawn(async move {
            let val = rx1.recv().await; //wait for value to become false
            assert!(val)
        });

        let b = tokio::spawn(async move {
            let val = rx2.recv().await; //wait for value to become false
            assert!(val)
        });

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        tx.send(true);

        let val = rx3.recv().await;
        assert!(val);

        let (a, b) = join!(a, b);
        a.unwrap();
        b.unwrap();
    }
}