1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
//! Core channel driver--wraps the underlying netidx pack_channel with
//! useful specialized functions.

use anyhow::{anyhow, bail, Result};
use api::{
    utils::messaging::MaybeRequest, Address, ComponentId, Envelope, MaybeSplit, Stamp,
    TypedMessage, UserId,
};
use futures_util::{select_biased, FutureExt};
use log::{debug, error};
use netidx::{path::Path, subscriber::Subscriber};
use netidx_protocols::pack_channel;
use std::sync::{Arc, RwLock};
use tokio::{
    sync::{broadcast, oneshot, watch},
    task,
};

/// Channel id sent during the handshake when the caller passes `None` as the
/// `channel_id` argument of [`ChannelDriver::new`].
static DEFAULT_CHANNEL_ID: u32 = 1;

/// State of an established connection, stored by the background task once
/// the handshake has completed.
struct Channel {
    /// The underlying netidx pack_channel connection.
    channel: Arc<pack_channel::client::Connection>,
    /// Address received from the peer during the handshake (in reply to the
    /// channel id); used as the `src` of outgoing envelopes.
    src: Address,
}

/// Client-side driver for a component channel.
///
/// [`ChannelDriver::new`] spawns a background task that connects, performs
/// the handshake and rebroadcasts every received batch of envelopes; the
/// fields below are the handles shared with that task.
pub struct ChannelDriver {
    /// Connection state written by the background task after a successful
    /// handshake; `None` until then.
    channel: Arc<RwLock<Option<Channel>>>,
    /// Set to `true` after a successful handshake and reset to `false`
    /// whenever a connection attempt ends.
    channel_ready: watch::Receiver<bool>,
    /// Path passed to `new`; used for every connection attempt.
    channel_path: Path,
    /// Sending half of the broadcast channel carrying received message
    /// batches; new receivers are handed out by `subscribe`.
    tx: broadcast::Sender<Arc<Vec<Envelope<TypedMessage>>>>,
    /// Receiver created together with `tx`; never read by this type.
    /// NOTE(review): presumably held so the broadcast channel always has at
    /// least one live receiver and `tx.send` does not fail -- confirm.
    _rx: broadcast::Receiver<Arc<Vec<Envelope<TypedMessage>>>>,
    /// Shutdown signal and join handle of the background task; taken (and
    /// left as `None`) by `close`.
    close: Option<(oneshot::Sender<()>, task::JoinHandle<()>)>,
}

impl ChannelDriver {
    pub fn new(
        subscriber: &Subscriber,
        channel_path: Path,
        channel_id: Option<u32>,
    ) -> Self {
        let channel = Arc::new(RwLock::new(None));
        let (mut channel_ready_tx, channel_ready_rx) = watch::channel(false);
        let (close_tx, mut close_rx) = oneshot::channel();
        let (tx, rx) = broadcast::channel(1000);
        let channel_task = {
            let subscriber = subscriber.clone();
            let channel_path = channel_path.clone();
            let channel = channel.clone();
            let tx = tx.clone();
            task::spawn({
                async move {
                    loop {
                        let res = Self::connect_inner(
                            &subscriber,
                            channel_path.clone(),
                            channel_id,
                            channel.clone(),
                            &mut channel_ready_tx,
                            &mut close_rx,
                            tx.clone(),
                        )
                        .await;
                        channel_ready_tx.send_replace(false);
                        if let Err(e) = res {
                            error!("channel driver error, reconnecting in 1s: {}", e);
                            let delay = std::time::Duration::from_secs(1);
                            tokio::time::sleep(delay).await;
                        } else {
                            // graceful shutdown
                            break;
                        }
                    }
                }
            })
        };
        Self {
            channel,
            channel_ready: channel_ready_rx,
            channel_path,
            tx,
            _rx: rx,
            close: Some((close_tx, channel_task)),
        }
    }

    /// Run a single connection attempt: connect, handshake, then pump
    /// incoming messages until shutdown or failure.
    ///
    /// Steps:
    /// 1. Connect to `channel_path` and send the channel id
    ///    (`DEFAULT_CHANNEL_ID` when `channel_id` is `None`).
    /// 2. Receive this side's [`Address`], store the connection in `channel`
    ///    and set the readiness flag to `true`.
    /// 3. Forward every received batch of envelopes to `tx` until `close_rx`
    ///    resolves.
    ///
    /// Returns `Ok(())` only when `close_rx` resolved (graceful shutdown).
    ///
    /// # Errors
    ///
    /// Returns an error if connecting or the handshake fails, if the shared
    /// lock is poisoned, or if receiving from the connection fails. The
    /// caller treats any error as "reconnect", so a receive failure must be
    /// reported rather than ignored; otherwise this loop would keep polling a
    /// broken connection and the reconnect logic would never run.
    async fn connect_inner(
        subscriber: &Subscriber,
        channel_path: Path,
        channel_id: Option<u32>,
        channel: Arc<RwLock<Option<Channel>>>,
        channel_ready_tx: &mut watch::Sender<bool>,
        close_rx: &mut oneshot::Receiver<()>,
        tx: broadcast::Sender<Arc<Vec<Envelope<TypedMessage>>>>,
    ) -> Result<()> {
        let channel_id = channel_id.unwrap_or(DEFAULT_CHANNEL_ID);
        let conn = Arc::new(
            pack_channel::client::Connection::connect(subscriber, channel_path.clone())
                .await?,
        );
        debug!("beginning channel handshake, channel_id = {}", channel_id);
        conn.send_one(&channel_id)?;
        let src: Address = conn.recv_one().await?;
        {
            // Scoped so the lock guard is released before the next `.await`.
            if let Ok(mut channel) = channel.write() {
                *channel = Some(Channel { channel: conn.clone(), src: src.clone() });
            } else {
                bail!("BUG: channel ready lock poisoned");
            }
        }
        channel_ready_tx.send_replace(true);
        debug!("channel handshake complete, channel = {}", src);
        let mut messages: Vec<Envelope<TypedMessage>> = vec![];
        let mut close_rx = close_rx.fuse();
        loop {
            let mut closed = false;
            let mut recv_res = Ok(());
            // Biased: a pending close request wins over further receives.
            select_biased! {
                _ = &mut close_rx => { closed = true; },
                r = conn.recv(|m| { messages.push(m); true }).fuse() => { recv_res = r; }
            }
            // Hand whatever was collected to subscribers before acting on a
            // close request or a receive error, so nothing already received
            // is lost.
            let buf = std::mem::take(&mut messages);
            if !buf.is_empty() {
                if let Err(e) = tx.send(Arc::new(buf)) {
                    error!("channel driver send error, dropping: {}", e);
                }
            }
            // Surface receive failures so the caller can reconnect.
            recv_res?;
            if closed {
                break Ok(());
            }
        }
    }

    /// Resolve once the readiness flag is `true`, i.e. after the background
    /// task has completed a handshake.
    ///
    /// # Errors
    ///
    /// Propagates the error from the watch receiver's `wait_for`.
    pub async fn wait_connected(&self) -> Result<()> {
        // Work on a private clone so `&self` is sufficient.
        let mut ready_rx = self.channel_ready.clone();
        ready_rx.wait_for(|is_ready| *is_ready).await.map(drop)?;
        Ok(())
    }

    /// Close the channel: signal the background task to stop and wait for it
    /// to finish.
    ///
    /// NOTE(review): nothing visible in this file flushes outgoing messages
    /// queued on the connection before shutdown -- confirm whether callers
    /// rely on that.
    ///
    /// # Errors
    ///
    /// Fails with "channel already closed" if `close` was called before or if
    /// the background task is no longer listening for the signal, and
    /// propagates a join error from the task.
    pub async fn close(&mut self) -> Result<()> {
        match self.close.take() {
            None => bail!("channel already closed"),
            Some((close_tx, join)) => {
                if close_tx.send(()).is_err() {
                    bail!("channel already closed");
                }
                join.await?;
                Ok(())
            }
        }
    }

    pub fn path(&self) -> &Path {
        &self.channel_path
    }

    pub fn user_id(&self) -> Result<UserId> {
        let channel =
            self.channel.read().map_err(|_| anyhow!("channel ready lock poisoned"))?;
        if let Some(channel) = &*channel {
            match channel.src {
                Address::Channel(user_id, _) => Ok(user_id),
                _ => bail!("channel not a user channel"),
            }
        } else {
            bail!("channel not ready")
        }
    }

    /// Run `f` with the underlying netidx pack_channel connection and a clone
    /// of this channel's address, returning whatever `f` returns.
    ///
    /// The read lock is held for the duration of `f`.
    ///
    /// # Errors
    ///
    /// Fails if the lock is poisoned or if no handshake has completed yet.
    pub fn with_channel<R>(
        &self,
        f: impl FnOnce(&pack_channel::client::Connection, Address) -> R,
    ) -> Result<R> {
        let guard = match self.channel.read() {
            Ok(guard) => guard,
            Err(_) => bail!("channel ready lock poisoned"),
        };
        match &*guard {
            Some(ready) => Ok(f(&ready.channel, ready.src.clone())),
            None => bail!("channel not ready"),
        }
    }

    // CR alee: probably want to give these type signatures some more thought;
    // one disadvantage to using Into<TypedMessage> as a bound is how to support
    // custom builds without having to make a new api/sdk;
    /// Send `msg` to component `dst`.
    ///
    /// The message is wrapped in an [`Envelope`] whose `src` is this
    /// channel's address and whose stamp is `Stamp::default()`.
    ///
    /// # Errors
    ///
    /// Fails if the channel is not ready (see `with_channel`) or if the
    /// connection rejects the send.
    pub fn send_to<M>(&self, dst: ComponentId, msg: M) -> Result<()>
    where
        M: Into<TypedMessage>,
    {
        let outcome = self.with_channel(|conn, src| {
            let envelope = Envelope {
                src,
                dst: Address::Component(dst),
                stamp: Stamp::default(),
                msg: msg.into(),
            };
            conn.send_one(&envelope)
        });
        // Outer error: channel unavailable; inner result: the send itself.
        match outcome {
            Ok(send_result) => send_result,
            Err(unavailable) => Err(unavailable),
        }
    }

    /// Obtain a new receiver for the batches of envelopes rebroadcast by the
    /// background task.
    pub fn subscribe(&self) -> broadcast::Receiver<Arc<Vec<Envelope<TypedMessage>>>> {
        let sender = &self.tx;
        sender.subscribe()
    }

    /// Wait until a message arrives for which predicate `f` returns `true`.
    ///
    /// A simpler form of [`Self::wait_for`] for callers that do not need a
    /// value extracted from the matching message.
    pub async fn wait_until<R>(&self, mut f: impl FnMut(R) -> bool) -> Result<()>
    where
        TypedMessage: TryInto<MaybeSplit<TypedMessage, R>>,
    {
        // Adapt the boolean predicate to the `Option`-returning form.
        let matcher = |msg: R| match f(msg) {
            true => Some(()),
            false => None,
        };
        self.wait_for(matcher).await
    }

    /// Wait for a response that satisfies `f`.
    /// Ignores and discards any intervening messages.
    pub async fn wait_for<R, T>(&self, mut f: impl FnMut(R) -> Option<T>) -> Result<T>
    where
        TypedMessage: TryInto<MaybeSplit<TypedMessage, R>>,
    {
        let mut rx = self.tx.subscribe();
        while let Ok(envs) = rx.recv().await {
            for env in envs.iter() {
                if let Ok((_orig, msg)) =
                    env.msg.clone().try_into().map(MaybeSplit::parts)
                {
                    if let Some(t) = f(msg) {
                        return Ok(t);
                    }
                }
            }
        }
        Err(anyhow!("lost connection to component channel"))
    }

    /// Send message to a component and wait for a response that satisfies `f`.
    /// Ignores and discards any intervening messages.
    pub async fn send_to_and_wait_for<M, R, T>(
        &self,
        dst: ComponentId,
        msg: M,
        f: impl FnMut(R) -> Option<T>,
    ) -> Result<T>
    where
        M: Into<TypedMessage>,
        TypedMessage: TryInto<MaybeSplit<TypedMessage, R>>,
    {
        let waiter = self.wait_for(f);
        self.send_to(dst, msg)?;
        waiter.await
    }

    /// Send request `msg` to component `dst` and wait for the response whose
    /// `response_id()` equals the request's `request_id()`. The matching
    /// response is passed to `unwrap`, whose result is returned.
    ///
    /// Ignores and discards any intervening messages.
    ///
    /// # Errors
    ///
    /// Fails if sending or waiting fails (see `send_to_and_wait_for`), or
    /// with whatever error `unwrap` returns for the matching response.
    pub async fn request_and_wait_for<M, R, T>(
        &self,
        dst: ComponentId,
        msg: M,
        unwrap: impl Fn(R) -> Result<T>,
    ) -> Result<T>
    where
        M: MaybeRequest + Into<TypedMessage>,
        R: MaybeRequest,
        TypedMessage: TryInto<MaybeSplit<TypedMessage, R>>,
    {
        // Capture the id up front; `msg` is consumed by the send below.
        let req_id = msg.request_id();
        let on_response = |res: R| {
            let is_ours = res.response_id() == req_id;
            if is_ours {
                Some(unwrap(res))
            } else {
                None
            }
        };
        // Outer `?`: transport/wait failure; inner result: `unwrap`'s verdict.
        let unwrapped = self.send_to_and_wait_for(dst, msg, on_response).await?;
        unwrapped
    }
}