1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
mod multiplex_state;
mod pipe_pool;
mod stream;
mod trace;
use std::{
    any::Any,
    sync::Arc,
    time::{Duration, Instant},
};

use concurrent_queue::ConcurrentQueue;

use futures_intrusive::sync::ManualResetEvent;
use parking_lot::Mutex;
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};
use smol::{
    channel::{Receiver, Sender},
    future::FutureExt,
};
use stdcode::StdcodeSerializeExt;
// pub use congestion::*;
use crate::Pipe;
pub use stream::MuxStream;

use self::{multiplex_state::MultiplexState, pipe_pool::PipePool};

/// A multiplex session running over one or more [`Pipe`]s, carrying reliable streams
/// ([`MuxStream`]) opened by either side.
///
/// NOTE(review): earlier docs described this as running over a sosistab session and also
/// carrying unreliable messages; neither is visible in this file — confirm before relying on it.
pub struct Multiplex {
    /// The underlying pipes that packets are sent over and received from.
    pipe_pool: Arc<PipePool>,
    /// Shared multiplex state (streams and long-term keys); also held by the background task.
    state: Arc<Mutex<MultiplexState>>,
    /// Arbitrary values kept alive until the multiplex is dropped (see `add_drop_friend`).
    friends: ConcurrentQueue<Box<dyn Any + Send>>,
    /// Streams opened by the peer, delivered here by the background task.
    recv_accepted: Receiver<MuxStream>,

    /// Handle of the background task driving the multiplex. Stored here so the task lives as
    /// long as `self` (smol cancels a task when its handle is dropped).
    _task: smol::Task<()>,
}

/// Wraps any boxable error value in an [`std::io::Error`] whose kind is always
/// [`ConnectionReset`](std::io::ErrorKind::ConnectionReset).
fn to_ioerror<T: Into<Box<dyn std::error::Error + Send + Sync>>>(val: T) -> std::io::Error {
    let kind = std::io::ErrorKind::ConnectionReset;
    std::io::Error::new(kind, val)
}

impl Multiplex {
    /// Creates a new multiplex using `local_sk` as this side's long-term secret key.
    ///
    /// If `preshared_peer_pk` is given, the other side is verified to have that public key;
    /// otherwise the peer's key only becomes known later (see [`Multiplex::peer_pk`]).
    ///
    /// Pipes are attached with [`Multiplex::add_pipe`]. A background task driving the
    /// multiplex is spawned immediately and is owned by (and dropped with) the returned value.
    pub fn new(local_sk: MuxSecret, preshared_peer_pk: Option<MuxPublic>) -> Self {
        let stream_update = Arc::new(ManualResetEvent::new(false));
        let state = Arc::new(Mutex::new(MultiplexState::new(
            stream_update.clone(),
            local_sk,
            preshared_peer_pk,
        )));
        // NOTE(review): the meaning of `10` is defined by `PipePool::new`; consider naming it.
        let pipe_pool = Arc::new(PipePool::new(10, preshared_peer_pk.is_none()));
        let (send_accepted, recv_accepted) = smol::channel::unbounded();
        let _task = smolscale::spawn(multiplex_loop(
            state.clone(),
            stream_update,
            pipe_pool.clone(),
            send_accepted,
        ));
        Self {
            pipe_pool,
            state,
            friends: ConcurrentQueue::unbounded(),
            recv_accepted,
            _task,
        }
    }

    /// Returns this side's long-term public key.
    pub fn local_pk(&self) -> MuxPublic {
        self.state.lock().local_lsk.to_public()
    }

    /// Returns the other side's public key, or `None` if it is not yet known.
    ///
    /// This is useful for "binding"-type authentication on the application layer, where the
    /// other end of the multiplex has no preshared public key but one that can be verified by,
    /// e.g., a signature.
    pub fn peer_pk(&self) -> Option<MuxPublic> {
        self.state.lock().peer_lpk
    }

    /// Adds an arbitrary "friend" that is dropped together with the multiplex.
    ///
    /// This is useful for RAII resources (tasks, tables, ...) that should live exactly as long
    /// as a particular multiplex.
    pub fn add_drop_friend(&self, friend: impl Any + Send) {
        self.friends
            .push(Box::new(friend))
            .expect("friends queue is unbounded and never closed")
    }

    /// Adds a pipe to the multiplex.
    pub fn add_pipe(&self, pipe: impl Pipe) {
        self.pipe_pool.add_pipe(pipe)
    }

    /// Returns the pipe most recently used by this multiplex for sending, if any.
    pub fn last_send_pipe(&self) -> Option<impl Pipe> {
        self.pipe_pool.last_send_pipe()
    }

    /// Returns the pipe most recently used by this multiplex for receiving, if any.
    pub fn last_recv_pipe(&self) -> Option<impl Pipe> {
        self.pipe_pool.last_recv_pipe()
    }

    /// Iterates through *all* the underlying pipes.
    pub fn iter_pipes(&self) -> impl Iterator<Item = impl Pipe> + '_ {
        self.pipe_pool.all_pipes().into_iter()
    }

    /// Keeps only the underlying pipes for which `f` returns `true`.
    pub fn retain(&self, f: impl FnMut(&dyn Pipe) -> bool) {
        self.pipe_pool.retain(f)
    }

    /// Opens a reliable stream to the other end, passing `additional` along with the request.
    ///
    /// # Errors
    /// Fails with [`std::io::ErrorKind::ConnectionReset`] if the stream cannot be started, and
    /// propagates any error from waiting for it to connect.
    pub async fn open_conn(&self, additional: &str) -> std::io::Result<MuxStream> {
        // Create a pre-open stream, then wait until ticking makes it open. The lock guard is a
        // temporary of this statement, so it is released before the `.await` below.
        let stream = self
            .state
            .lock()
            .start_open_stream(additional)
            .map_err(to_ioerror)?;
        stream.wait_connected().await?;
        Ok(stream)
    }

    /// Accepts the next reliable stream opened by the other end.
    ///
    /// # Errors
    /// Fails with [`std::io::ErrorKind::ConnectionReset`] once the background task has stopped
    /// and no accepted streams remain.
    pub async fn accept_conn(&self) -> std::io::Result<MuxStream> {
        self.recv_accepted.recv().await.map_err(to_ioerror)
    }
}

/// The master loop that starts the other loops
async fn multiplex_loop(
    state: Arc<Mutex<MultiplexState>>,
    stream_update: Arc<ManualResetEvent>,
    pipe_pool: Arc<PipePool>,
    send_accepted: Sender<MuxStream>,
) {
    // we don't spawn more things to avoid unnecessary contention over mutexes etc
    let ticker = tick_loop(state.clone(), stream_update, pipe_pool.clone());
    let incomer = incoming_loop(state, pipe_pool, send_accepted);
    if let Err(err) = ticker.race(incomer).await {
        log::error!("BUG: ticker or incomer died: {:?}", err)
    }
}

/// Handles incoming packets until receiving from the pipe pool fails.
///
/// Each packet is decoded and handed to the shared [`MultiplexState`]. Replies it produces
/// are sent back through `pipe_pool`, and streams opened by the peer are forwarded to
/// `send_accepted`. Packets that fail to decode, and messages the state cannot process, are
/// logged at trace level and skipped.
///
/// # Errors
/// Returns the error from `pipe_pool.recv()` when receiving fails.
async fn incoming_loop(
    state: Arc<Mutex<MultiplexState>>,
    pipe_pool: Arc<PipePool>,
    send_accepted: Sender<MuxStream>,
) -> anyhow::Result<()> {
    // Reused across iterations so the reply buffer is not reallocated per packet.
    let mut send_queue = vec![];
    loop {
        let incoming = pipe_pool.recv().await?;
        log::trace!("incoming {} bytes", incoming.len());
        let msg = match stdcode::deserialize(&incoming) {
            Ok(msg) => msg,
            Err(err) => {
                // Undecodable packets are still dropped, but no longer invisibly.
                log::trace!("could not decode incoming message: {:?}", err);
                continue;
            }
        };

        // Have the state process the message. The lock guard is a temporary of this
        // statement, so it is released before the `.await`s below.
        state
            .lock()
            .recv_msg(
                msg,
                |reply| send_queue.push(reply),
                |stream| {
                    // Only fails once every receiver (held by `Multiplex`) is gone, in which
                    // case nobody could accept the stream anyway.
                    let _ = send_accepted.try_send(stream);
                },
            )
            .unwrap_or_else(|e| {
                log::trace!("could not process message: {:?}", e);
            });

        // Send all replies produced while processing this message.
        for reply in send_queue.drain(..) {
            pipe_pool.send(reply.stdcode().into()).await;
        }
    }
}

/// Drives the streams by repeatedly "ticking" the shared [`MultiplexState`].
///
/// Each iteration:
/// 1. ticks the state;
/// 2. transmits every message the tick produced;
/// 3. sleeps 1 ms;
/// 4. waits until either the deadline returned by the tick passes or `stream_update` is set.
///
/// The loop has no exit. The `anyhow::Result` return type only exists so this future can be
/// raced against [`incoming_loop`].
async fn tick_loop(
    state: Arc<Mutex<MultiplexState>>,
    stream_update: Arc<ManualResetEvent>,
    pipe_pool: Arc<PipePool>,
) -> anyhow::Result<()> {
    // One timer is reused both for the fixed 1 ms pause and for the next-tick deadline.
    let mut timer = smol::Timer::after(Duration::from_secs(0));
    // Reused across iterations so the outgoing buffer is not reallocated per tick.
    let mut send_queue = vec![];
    loop {
        let start = Instant::now();
        // The lock guard is a temporary of this statement, so it is released before any `.await`.
        let next_tick = state.lock().tick(|msg| send_queue.push(msg));
        log::trace!("tick took {:?}", start.elapsed());
        // Transmit everything the tick queued.
        for msg in send_queue.drain(..) {
            pipe_pool.send(msg.stdcode().into()).await;
        }
        // Sleep first to avoid looping around too aggressively; this pause is also the basis
        // for the delayed-ack handling.
        timer.set_at(Instant::now() + Duration::from_millis(1));
        (&mut timer).await;
        timer.set_at(next_tick);
        // Wake on whichever comes first: `stream_update` being set (it is handed to
        // `MultiplexState::new`, presumably to signal stream activity) or the tick deadline.
        // The event is reset after waking so the next iteration waits for a fresh signal.
        async {
            stream_update.wait().await;
            stream_update.reset();
            log::trace!("update woken");
        }
        .or(async {
            (&mut timer).await;
            log::trace!("timer woken");
        })
        .await;
    }
}

/// A long-term x25519 public key for one end of the end-to-end multiplex.
///
/// Both ends have one: see `Multiplex::local_pk` and `Multiplex::peer_pk`. It serializes
/// transparently as the wrapped `x25519_dalek::PublicKey`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(transparent)]
pub struct MuxPublic(pub(crate) x25519_dalek::PublicKey);

impl MuxPublic {
    /// Returns the bytes representation.
    pub fn as_bytes(&self) -> &[u8; 32] {
        self.0.as_bytes()
    }

    /// Convert from bytes.
    pub fn from_bytes(b: [u8; 32]) -> Self {
        Self(x25519_dalek::PublicKey::from(b))
    }
}

/// A long-term x25519 secret key for one end of the end-to-end multiplex.
///
/// Every `Multiplex` is created with one (see `Multiplex::new`), regardless of which side it is.
///
/// NOTE(review): this derives `Serialize`/`Deserialize`, so the raw secret can be written out
/// by any serializer — make sure it only ever goes to trusted storage.
#[derive(Clone, Serialize, Deserialize)]
pub struct MuxSecret(pub(crate) x25519_dalek::StaticSecret);

impl MuxSecret {
    /// Returns the bytes representation.
    pub fn to_bytes(&self) -> [u8; 32] {
        self.0.to_bytes()
    }

    /// Convert from bytes.
    pub fn from_bytes(b: [u8; 32]) -> Self {
        Self(x25519_dalek::StaticSecret::from(b))
    }

    /// Generate.
    pub fn generate() -> Self {
        Self(x25519_dalek::StaticSecret::new(OsRng {}))
    }

    /// Convert to a public key.
    pub fn to_public(&self) -> MuxPublic {
        MuxPublic((&self.0).into())
    }
}