sbd_server/
cslot.rs

//! Attempt to pre-allocate as much as possible, including our tokio tasks.
//! Ideally this would include a frame buffer that we could fill on ws
//! recv and use as a reference for ws send, but alas, fastwebsockets
//! doesn't seem up to the task. tungstenite will willy-nilly allocate
//! buffers for us, but at least we should only be dealing with one at a
//! time per connection.
7
8use super::*;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, Weak};
11
/// Process-wide source of per-connection unique ids (see `SlabEntry::uniq`).
/// The first id handed out is 1.
static U: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(1);
13
14enum TaskMsg {
15    NewWs {
16        uniq: u64,
17        index: usize,
18        ws: Arc<dyn SbdWebsocket>,
19        ip: Arc<Ipv6Addr>,
20        pk: PubKey,
21        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
22    },
23    Close,
24}
25
26struct SlotEntry {
27    send: tokio::sync::mpsc::UnboundedSender<TaskMsg>,
28}
29
30struct SlabEntry {
31    uniq: u64,
32    handshake_complete: bool,
33    weak_ws: Weak<dyn SbdWebsocket>,
34}
35
36struct CSlotInner {
37    max_count: usize,
38    slots: Vec<SlotEntry>,
39    slab: slab::Slab<SlabEntry>,
40    pk_to_index: HashMap<PubKey, usize>,
41    ip_to_index: HashMap<Arc<Ipv6Addr>, Vec<usize>>,
42    task_list: Vec<tokio::task::JoinHandle<()>>,
43    open_connections: opentelemetry::metrics::UpDownCounter<i64>,
44}
45
46impl Drop for CSlotInner {
47    fn drop(&mut self) {
48        for task in self.task_list.iter() {
49            task.abort();
50        }
51    }
52}
53
54/// A weak reference to a connection slot container.
55#[derive(Clone)]
56pub struct WeakCSlot(Weak<Mutex<CSlotInner>>);
57
58impl WeakCSlot {
59    /// Upgrade this weak reference to a strong reference.
60    pub fn upgrade(&self) -> Option<CSlot> {
61        self.0.upgrade().map(CSlot)
62    }
63}
64
65/// A connection slot container.
66///
67/// Note this is not clone to ensure that when the single top-level handle
68/// is dropped, that everything is shutdown properly.
69pub struct CSlot(Arc<Mutex<CSlotInner>>);
70
71impl CSlot {
72    /// Create a new connection slot container.
73    pub fn new(
74        config: Arc<Config>,
75        ip_rate: Arc<IpRate>,
76        meter: opentelemetry::metrics::Meter,
77    ) -> Self {
78        let count = config.limit_clients as usize;
79
80        let ip_rate_counter = meter
81            .u64_counter("sbd.server.ip_rate_limited")
82            .with_description("Total number of IP rate limited events")
83            .with_unit("count")
84            .build();
85
86        Self(Arc::new_cyclic(|this| {
87            let mut slots = Vec::with_capacity(count);
88            let mut task_list = Vec::with_capacity(count);
89            for _ in 0..count {
90                let (send, recv) = tokio::sync::mpsc::unbounded_channel();
91                slots.push(SlotEntry { send });
92                task_list.push(tokio::task::spawn(top_task(
93                    config.clone(),
94                    ip_rate.clone(),
95                    WeakCSlot(this.clone()),
96                    recv,
97                    ip_rate_counter.clone(),
98                )));
99            }
100
101            let open_connections = meter
102                .i64_up_down_counter("sbd.server.open_connections")
103                .with_description("Number of open client connections")
104                .build();
105
106            Mutex::new(CSlotInner {
107                max_count: count,
108                slots,
109                slab: slab::Slab::with_capacity(count),
110                pk_to_index: HashMap::with_capacity(count),
111                ip_to_index: HashMap::with_capacity(count),
112                task_list,
113                open_connections,
114            })
115        }))
116    }
117
118    /// Get a weak reference to this connection slot container.
119    pub fn weak(&self) -> WeakCSlot {
120        WeakCSlot(Arc::downgrade(&self.0))
121    }
122
123    /// Remove a websocket from its slot.
124    fn remove(&self, uniq: u64, index: usize) {
125        let mut lock = self.0.lock().unwrap();
126
127        match lock.slab.get(index) {
128            None => return,
129            Some(s) => {
130                if s.uniq != uniq {
131                    return;
132                }
133            }
134        }
135
136        let _ = lock.slots.get(index).unwrap().send.send(TaskMsg::Close);
137        lock.slab.remove(index);
138        lock.pk_to_index.retain(|_, i| *i != index);
139        lock.ip_to_index.retain(|_, v| {
140            v.retain(|i| *i != index);
141            !v.is_empty()
142        });
143
144        // Decrement the open connections metric
145        lock.open_connections.add(-1, &[])
146    }
147
148    /// Inner helper for inserting a websocket into an available slot.
149    // oi clippy, this is super straight forward...
150    #[allow(clippy::type_complexity)]
151    fn insert_and_get_rate_send_list(
152        &self,
153        ip: Arc<Ipv6Addr>,
154        pk: PubKey,
155        ws: Arc<dyn SbdWebsocket>,
156        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
157    ) -> std::result::Result<
158        Vec<(u64, usize, Weak<dyn SbdWebsocket>)>,
159        Arc<dyn SbdWebsocket>,
160    > {
161        let mut lock = self.0.lock().unwrap();
162
163        if lock.slab.len() >= lock.max_count {
164            return Err(ws);
165        }
166
167        let weak_ws = Arc::downgrade(&ws);
168
169        let uniq = U.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
170
171        let index = lock.slab.insert(SlabEntry {
172            uniq,
173            weak_ws,
174            handshake_complete: false,
175        });
176
177        lock.pk_to_index.insert(pk.clone(), index);
178
179        let rate_send_list = {
180            let list = {
181                // WARN - allocation here!
182                // Also, do we want to limit the max connections from same ip?
183
184                let e = lock
185                    .ip_to_index
186                    .entry(ip.clone())
187                    .or_insert_with(|| Vec::with_capacity(1024));
188
189                e.push(index);
190
191                e.clone()
192            };
193
194            let mut rate_send_list = Vec::with_capacity(list.len());
195
196            for index in list.iter() {
197                if let Some(slab) = lock.slab.get(*index) {
198                    rate_send_list.push((
199                        slab.uniq,
200                        *index,
201                        slab.weak_ws.clone(),
202                    ));
203                }
204            }
205
206            rate_send_list
207        };
208
209        let send = lock.slots.get(index).unwrap().send.clone();
210        let _ = send.send(TaskMsg::NewWs {
211            uniq,
212            index,
213            ws,
214            ip,
215            pk,
216            maybe_auth,
217        });
218
219        // Increment the open connections metric
220        lock.open_connections.add(1, &[]);
221
222        Ok(rate_send_list)
223    }
224
225    /// Insert a connection to be managed by this container.
226    pub async fn insert(
227        &self,
228        config: &Config,
229        ip: Arc<Ipv6Addr>,
230        pk: PubKey,
231        ws: Arc<impl SbdWebsocket>,
232        maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
233    ) {
234        let rate_send_list =
235            self.insert_and_get_rate_send_list(ip, pk, ws, maybe_auth);
236
237        match rate_send_list {
238            Ok(rate_send_list) => {
239                let rate = if config.disable_rate_limiting {
240                    1
241                } else {
242                    let mut rate = config.limit_ip_byte_nanos() as u64
243                        * rate_send_list.len() as u64;
244                    if rate > i32::MAX as u64 {
245                        rate = i32::MAX as u64;
246                    }
247                    rate as i32
248                };
249
250                for (uniq, index, weak_ws) in rate_send_list {
251                    if let Some(ws) = weak_ws.upgrade() {
252                        if ws
253                            .send(cmd::SbdCmd::limit_byte_nanos(rate))
254                            .await
255                            .is_err()
256                        {
257                            self.remove(uniq, index);
258                        }
259                    }
260                }
261            }
262            Err(ws) => {
263                ws.close().await;
264                drop(ws);
265            }
266        }
267    }
268
269    /// Mark a slotted websocket as ready.
270    fn mark_ready(&self, uniq: u64, index: usize) {
271        let mut lock = self.0.lock().unwrap();
272        if let Some(slab) = lock.slab.get_mut(index) {
273            if slab.uniq == uniq {
274                slab.handshake_complete = true;
275            }
276        }
277    }
278
279    /// Get a websocket from its slot.
280    fn get_sender(
281        &self,
282        pk: &PubKey,
283    ) -> Result<(u64, usize, Arc<dyn SbdWebsocket>)> {
284        let lock = self.0.lock().unwrap();
285
286        let index = match lock.pk_to_index.get(pk) {
287            None => return Err(Error::other("no such peer")),
288            Some(index) => *index,
289        };
290
291        let slab = lock.slab.get(index).unwrap();
292
293        if !slab.handshake_complete {
294            return Err(Error::other("no such peer"));
295        }
296
297        let uniq = slab.uniq;
298        let ws = match slab.weak_ws.upgrade() {
299            None => return Err(Error::other("no such peer")),
300            Some(ws) => ws,
301        };
302
303        Ok((uniq, index, ws))
304    }
305
306    /// Send via a slotted websocket.
307    async fn send(&self, pk: &PubKey, payload: Payload) -> Result<()> {
308        let (uniq, index, ws) = self.get_sender(pk)?;
309
310        match ws.send(payload).await {
311            Err(err) => {
312                self.remove(uniq, index);
313                Err(err)
314            }
315            Ok(_) => Ok(()),
316        }
317    }
318}
319
320/// This top-task waits for incoming websockets, processes them until
321/// completion, and then waits for a new incoming websocket.
322async fn top_task(
323    config: Arc<Config>,
324    ip_rate: Arc<IpRate>,
325    weak: WeakCSlot,
326    mut recv: tokio::sync::mpsc::UnboundedReceiver<TaskMsg>,
327    ip_rate_counter: opentelemetry::metrics::Counter<u64>,
328) {
329    let mut item = recv.recv().await;
330    loop {
331        let uitem = match item {
332            None => break,
333            Some(uitem) => uitem,
334        };
335
336        item = if let TaskMsg::NewWs {
337            uniq,
338            index,
339            ws,
340            ip,
341            pk,
342            maybe_auth,
343        } = uitem
344        {
345            // we have a websocket! process to completion
346            let next_i = tokio::select! {
347                i = recv.recv() => Some(i),
348                _ = ws_task(
349                    &config,
350                    &ip_rate,
351                    &weak,
352                    &ws,
353                    ip,
354                    pk,
355                    uniq,
356                    index,
357                    maybe_auth,
358                    &ip_rate_counter,
359                ) => None,
360            };
361
362            // our websocket task ended, clean up
363            ws.close().await;
364            drop(ws);
365            if let Some(cslot) = weak.upgrade() {
366                cslot.remove(uniq, index);
367            }
368
369            match next_i {
370                Some(i) => i,
371                None => recv.recv().await,
372            }
373        } else {
374            recv.recv().await
375        };
376    }
377}
378
379/// Process a single websocket until completion.
380#[allow(clippy::too_many_arguments)]
381async fn ws_task(
382    config: &Arc<Config>,
383    ip_rate: &IpRate,
384    weak_cslot: &WeakCSlot,
385    ws: &Arc<dyn SbdWebsocket>,
386    ip: Arc<Ipv6Addr>,
387    pk: PubKey,
388    uniq: u64,
389    index: usize,
390    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
391    ip_rate_counter: &opentelemetry::metrics::Counter<u64>,
392) {
393    let pub_key =
394        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(*pk.0);
395    let auth_res = tokio::time::timeout(config.idle_dur(), async {
396        use rand::Rng;
397        let mut nonce = [0xdb; 32];
398        rand::thread_rng().fill(&mut nonce[..]);
399
400        // send them a nonce to prove they can sign with private key
401        ws.send(cmd::SbdCmd::auth_req(&nonce)).await?;
402
403        loop {
404            let auth_res = ws.recv().await?;
405
406            if !ip_rate.is_ok(&ip, auth_res.as_ref().len()).await {
407                ip_rate_counter.add(
408                    1,
409                    &[
410                        opentelemetry::KeyValue::new(
411                            "pub_key",
412                            pub_key.clone(),
413                        ),
414                        opentelemetry::KeyValue::new("kind", "auth"),
415                    ],
416                );
417
418                return Err(Error::other("ip rate limited"));
419            }
420
421            if let Some((token, token_tracker)) = &maybe_auth {
422                // we already know they had a valid token
423                // when they opened this connection.
424                // just using this for side-effect marking token use time
425                let _ =
426                    token_tracker.check_is_token_valid(config, token.clone());
427            }
428
429            match cmd::SbdCmd::parse(auth_res)? {
430                cmd::SbdCmd::AuthRes(sig) => {
431                    if !pk.verify(&sig, &nonce) {
432                        return Err(Error::other("invalid sig"));
433                    }
434                    break;
435                }
436                cmd::SbdCmd::Message(_) => {
437                    return Err(Error::other(
438                        "invalid forward before handshake",
439                    ));
440                }
441                _ => continue,
442            }
443        }
444
445        // NOTE: the byte_nanos limit is sent during the cslot insert
446
447        ws.send(cmd::SbdCmd::limit_idle_millis(config.limit_idle_millis))
448            .await?;
449
450        if let Some(cslot) = weak_cslot.upgrade() {
451            cslot.mark_ready(uniq, index);
452        } else {
453            return Err(Error::other("closed"));
454        }
455
456        ws.send(cmd::SbdCmd::ready()).await?;
457
458        Ok(())
459    })
460    .await;
461
462    if auth_res.is_err() {
463        return;
464    }
465
466    // auth/init complete, now loop over incoming data
467
468    while let Ok(Ok(payload)) =
469        tokio::time::timeout(config.idle_dur(), ws.recv()).await
470    {
471        if !ip_rate.is_ok(&ip, payload.len()).await {
472            ip_rate_counter.add(
473                1,
474                &[
475                    opentelemetry::KeyValue::new("pub_key", pub_key),
476                    opentelemetry::KeyValue::new("kind", "msg"),
477                ],
478            );
479
480            break;
481        }
482
483        if let Some((token, token_tracker)) = &maybe_auth {
484            // we already know they had a valid token
485            // when they opened this connection.
486            // just using this for side-effect marking token use time
487            let _ = token_tracker.check_is_token_valid(config, token.clone());
488        }
489
490        let cmd = match cmd::SbdCmd::parse(payload) {
491            Err(_) => break,
492            Ok(cmd) => cmd,
493        };
494
495        match cmd {
496            // don't need to do anything... we just get a new timeout above
497            cmd::SbdCmd::Keepalive => (),
498            // auth responses are invalid at this stage
499            cmd::SbdCmd::AuthRes(_) => break,
500            // ignore unknown messages
501            cmd::SbdCmd::Unknown => (),
502            // forward an actual message to a peer
503            cmd::SbdCmd::Message(mut payload) => {
504                let dest = {
505                    let payload = payload.to_mut();
506
507                    let mut dest = [0; 32];
508                    dest.copy_from_slice(&payload[..32]);
509                    let dest = PubKey(Arc::new(dest));
510
511                    payload[..32].copy_from_slice(&pk.0[..]);
512
513                    dest
514                };
515
516                if let Some(cslot) = weak_cslot.upgrade() {
517                    let _ = cslot.send(&dest, payload).await;
518                } else {
519                    break;
520                }
521            }
522        }
523    }
524
525    tracing::debug!("Closed connection for {ip}");
526}