Skip to main content

netidx_core/
utils.rs

1//! Utility types and functions.
2use crate::pack::{Pack, PackError};
3use anyhow::{self, Result};
4use bytes::{Bytes, BytesMut};
5use futures::{
6    channel::mpsc,
7    prelude::*,
8    sink::Sink,
9    stream::FusedStream,
10    task::{Context, Poll},
11};
12use sha3::{Digest, Sha3_512};
13use std::{
14    borrow::Borrow,
15    cell::RefCell,
16    cmp::{Ord, Ordering, PartialOrd},
17    hash::Hash,
18    iter::{IntoIterator, Iterator},
19    net::{IpAddr, SocketAddr},
20    pin::Pin,
21};
22
/// Unwrap a `Result` inside a loop, diverting control flow on `Err`.
///
/// On `Ok(x)` the macro evaluates to `x`. On `Err(e)` it either `continue`s
/// or `break`s the enclosing (optionally labeled) loop; a `break` makes the
/// loop evaluate to `Err(Error::from(e))`, where `Error` is resolved at the
/// call site. When a message is supplied the error is first logged at info
/// level through the `log` crate.
///
/// Forms:
/// - `try_cf!(continue, 'lbl, e)` / `try_cf!(break, 'lbl, e)`
/// - `try_cf!(continue, e)` / `try_cf!(break, e)`
/// - `try_cf!(msg, continue, 'lbl, e)` / `try_cf!(msg, break, 'lbl, e)` —
///   `msg` is used directly as the format string and must contain one `{}`
///   for the error
/// - `try_cf!(msg, continue, e)` / `try_cf!(msg, break, e)` — logged as
///   `"{msg}: {error}"`
/// - `try_cf!(msg, e)` — logs, then breaks
/// - `try_cf!(e)` — breaks
///
/// The arms that begin with a bare `continue`/`break` keyword must come
/// before any `$msg:expr` arm: `continue` and `break` are themselves
/// expressions, so `($msg:expr, $e:expr)` would otherwise capture
/// `try_cf!(continue, e)` as a "message" and break instead of continuing.
#[macro_export]
macro_rules! try_cf {
    (continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue $lbl;
            }
        }
    };
    (break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break $lbl Err(Error::from(e));
            }
        }
    };
    (continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue;
            }
        }
    };
    (break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                continue $lbl;
            }
        }
    };
    ($msg:expr, break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                break $lbl Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                continue;
            }
        }
    };
    ($msg:expr, break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
}
111
/// Define a `Copy` newtype `$name(u64)` whose `new()` hands out ids from a
/// process-wide counter.
///
/// Each call to `new()` returns the counter's current value and increments
/// it, starting at 0 (the atomic counter wraps on overflow). The generated
/// type also:
/// - derives the std comparison/hash traits plus `Serialize`/`Deserialize`,
///   which must be in scope at the call site,
/// - implements `nohash::IsEnabled`, so the caller needs the `nohash` crate,
/// - implements this crate's `Pack` as a single varint.
#[macro_export]
macro_rules! atomic_id {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        pub struct $name(u64);

        impl nohash::IsEnabled for $name {}

        impl $name {
            /// Allocate the next id from the counter shared by all values of this type.
            pub fn new() -> Self {
                use std::sync::atomic::{AtomicU64, Ordering};
                static NEXT: AtomicU64 = AtomicU64::new(0);
                $name(NEXT.fetch_add(1, Ordering::Relaxed))
            }

            /// The raw numeric value of this id.
            pub fn inner(&self) -> u64 {
                self.0
            }

            // Test-only constructor for fixed ids. `dead_code` is allowed
            // because a given test build may not call it.
            #[cfg(test)]
            #[allow(dead_code)]
            pub fn mk(i: u64) -> Self {
                $name(i)
            }
        }

        // `$crate` (rather than a hard-coded `netidx_core`) makes these paths
        // resolve wherever the macro is expanded, including inside this crate
        // and under a renamed dependency.
        impl $crate::pack::Pack for $name {
            fn encoded_len(&self) -> usize {
                $crate::pack::varint_len(self.0)
            }

            fn encode(
                &self,
                buf: &mut impl bytes::BufMut,
            ) -> std::result::Result<(), $crate::pack::PackError> {
                $crate::pack::encode_varint(self.0, buf);
                Ok(())
            }

            fn decode(
                buf: &mut impl bytes::Buf,
            ) -> std::result::Result<Self, $crate::pack::PackError> {
                Ok(Self($crate::pack::decode_varint(buf)?))
            }
        }
    };
}
169
170pub fn check_addr<A>(ip: IpAddr, resolvers: &[(SocketAddr, A)]) -> Result<()> {
171    match ip {
172        IpAddr::V4(ip) if ip.is_link_local() => {
173            bail!("addr is a link local address");
174        }
175        IpAddr::V4(ip) if ip.is_broadcast() => {
176            bail!("addr is a broadcast address");
177        }
178        IpAddr::V4(ip) if ip.is_private() => {
179            let ok = resolvers.iter().all(|(a, _)| match a.ip() {
180                IpAddr::V4(ip) if ip.is_private() || ip.is_loopback() => true,
181                IpAddr::V6(_) => true,
182                _ => false,
183            });
184            if !ok {
185                bail!("addr is a private address, and the resolver is not")
186            }
187        }
188        _ => (),
189    }
190    if ip.is_unspecified() {
191        bail!("addr is an unspecified address");
192    }
193    if ip.is_multicast() {
194        bail!("addr is a multicast address");
195    }
196    if ip.is_loopback() && !resolvers.iter().all(|(a, _)| a.ip().is_loopback()) {
197        bail!("addr is a loopback address and the resolver is not");
198    }
199    Ok(())
200}
201
202thread_local! {
203    static BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(512));
204}
205
206pub fn make_sha3_token<'a>(data: impl IntoIterator<Item = &'a [u8]> + 'a) -> Bytes {
207    let mut hash = Sha3_512::new();
208    for v in data.into_iter() {
209        hash.update(v);
210    }
211    BUF.with(|buf| {
212        let mut b = buf.borrow_mut();
213        b.extend(hash.finalize().into_iter());
214        b.split().freeze()
215    })
216}
217
218/// pack T and return a bytesmut from the global thread local buffer
219pub fn pack<T: Pack>(t: &T) -> Result<BytesMut, PackError> {
220    BUF.with(|buf| {
221        let mut b = buf.borrow_mut();
222        t.encode(&mut *b)?;
223        Ok(b.split())
224    })
225}
226
227pub fn bytesmut(t: &[u8]) -> BytesMut {
228    BUF.with(|buf| {
229        let mut b = buf.borrow_mut();
230        b.extend_from_slice(t);
231        b.split()
232    })
233}
234
235pub fn bytes(t: &[u8]) -> Bytes {
236    bytesmut(t).freeze()
237}
238
239#[derive(Clone, Debug)]
240pub struct ChanWrap<T>(pub mpsc::Sender<T>);
241
242impl<T> PartialEq for ChanWrap<T> {
243    fn eq(&self, other: &ChanWrap<T>) -> bool {
244        self.0.same_receiver(&other.0)
245    }
246}
247
248impl<T> Eq for ChanWrap<T> {}
249
250impl<T> Hash for ChanWrap<T> {
251    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
252        self.0.hash_receiver(state)
253    }
254}
255
256#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
257pub struct ChanId(u64);
258
259impl nohash::IsEnabled for ChanId {}
260
261impl ChanId {
262    pub fn new() -> Self {
263        use std::sync::atomic::{AtomicU64, Ordering};
264        static NEXT: AtomicU64 = AtomicU64::new(0);
265        ChanId(NEXT.fetch_add(1, Ordering::Relaxed))
266    }
267}
268
/// An item produced by the `Batched` stream adapter.
#[derive(Debug, Clone)]
pub enum BatchItem<T> {
    /// An element taken from the wrapped stream.
    InBatch(T),
    /// Marks the end of the batch made of the `InBatch` items since the
    /// previous marker.
    EndBatch,
}
274
/// Stream adapter that wraps each item of the inner stream in
/// `BatchItem::InBatch` and inserts `BatchItem::EndBatch` markers.
///
/// `poll_next` emits an `EndBatch` when the current batch has reached `max`
/// items, when the inner stream returns `Pending` after at least one item
/// of the current batch, or when the inner stream ends with a non-empty
/// batch outstanding.
#[must_use = "streams do nothing unless polled"]
pub struct Batched<S: Stream> {
    // the wrapped stream; structurally pinned through the `unsafe_pinned!`
    // accessor in the inherent impl
    stream: S,
    // set once the inner stream has returned `Ready(None)`
    ended: bool,
    // batch size limit compared against `current`
    max: usize,
    // number of `InBatch` items yielded since the last `EndBatch`
    current: usize,
}
282
impl<S: Stream> Batched<S> {
    // Pin projection for `stream` (the macros are not imported in this file;
    // presumably `pin_utils` via `#[macro_use]` — confirm). Generates the
    // `stream()` accessor the Stream/Sink impls call on `Pin<&mut Self>`.
    // this is safe because,
    // - Batched doesn't implement Drop
    // - Batched has no manual Unpin impl; the auto-derived one holds only
    //   when `S: Unpin`, since the other fields are plain Copy types
    // - Batched isn't #[repr(packed)]
    unsafe_pinned!(stream: S);

    // these are safe because both types are copy (never treated as pinned),
    // and give `ended()` / `current()` accessors returning `&mut` to the field
    unsafe_unpinned!(ended: bool);
    unsafe_unpinned!(current: usize);

    /// Wrap `stream`; `poll_next` forces an `EndBatch` once `max` items have
    /// accumulated in the current batch.
    pub fn new(stream: S, max: usize) -> Batched<S> {
        Batched { stream, max, ended: false, current: 0 }
    }

    /// Borrow the wrapped stream.
    pub fn inner(&self) -> &S {
        &self.stream
    }

    /// Mutably borrow the wrapped stream. This needs `&mut Batched`, which a
    /// pinned `Batched` only hands out safely when it is `Unpin` (i.e. when
    /// `S: Unpin`), so it cannot be used to move a pinned inner stream.
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Consume the adapter and return the wrapped stream; the current batch
    /// count and end-of-stream flag are discarded.
    pub fn into_inner(self) -> S {
        self.stream
    }
}
310
311impl<S: Stream> Stream for Batched<S> {
312    type Item = BatchItem<<S as Stream>::Item>;
313
314    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
315        if self.ended {
316            Poll::Ready(None)
317        } else if self.current >= self.max {
318            *self.current() = 0;
319            Poll::Ready(Some(BatchItem::EndBatch))
320        } else {
321            match self.as_mut().stream().poll_next(cx) {
322                Poll::Ready(Some(v)) => {
323                    *self.as_mut().current() += 1;
324                    Poll::Ready(Some(BatchItem::InBatch(v)))
325                }
326                Poll::Ready(None) => {
327                    *self.as_mut().ended() = true;
328                    if self.current == 0 {
329                        Poll::Ready(None)
330                    } else {
331                        *self.current() = 0;
332                        Poll::Ready(Some(BatchItem::EndBatch))
333                    }
334                }
335                Poll::Pending => {
336                    if self.current == 0 {
337                        Poll::Pending
338                    } else {
339                        *self.current() = 0;
340                        Poll::Ready(Some(BatchItem::EndBatch))
341                    }
342                }
343            }
344        }
345    }
346
347    fn size_hint(&self) -> (usize, Option<usize>) {
348        self.stream.size_hint()
349    }
350}
351
352impl<S: Stream> FusedStream for Batched<S> {
353    fn is_terminated(&self) -> bool {
354        self.ended
355    }
356}
357
358impl<Item, S: Stream + Sink<Item>> Sink<Item> for Batched<S> {
359    type Error = <S as Sink<Item>>::Error;
360
361    fn poll_ready(
362        self: Pin<&mut Self>,
363        cx: &mut Context,
364    ) -> Poll<Result<(), Self::Error>> {
365        self.stream().poll_ready(cx)
366    }
367
368    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
369        self.stream().start_send(item)
370    }
371
372    fn poll_flush(
373        self: Pin<&mut Self>,
374        cx: &mut Context,
375    ) -> Poll<Result<(), Self::Error>> {
376        self.stream().poll_flush(cx)
377    }
378
379    fn poll_close(
380        self: Pin<&mut Self>,
381        cx: &mut Context,
382    ) -> Poll<Result<(), Self::Error>> {
383        self.stream().poll_close(cx)
384    }
385}
386
/// A `SocketAddr` wrapper that implements `Ord`, so clients can be kept in
/// ordered sets and maps.
///
/// The ordering is exactly `SocketAddr`'s own: IPv4 before IPv6, then by IP
/// address, then by port (std may break remaining ties on further IPv6
/// fields). Delegating to it keeps `Ord` consistent with the
/// `Borrow<SocketAddr>` impl, so an ordered collection of `Addr` can be
/// queried with a plain `&SocketAddr`. It also keeps `Ord` as consistent
/// with the derived `Eq` as std's own ordering is.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Addr(pub SocketAddr);

impl From<SocketAddr> for Addr {
    fn from(addr: SocketAddr) -> Self {
        Addr(addr)
    }
}

impl Borrow<SocketAddr> for Addr {
    fn borrow(&self) -> &SocketAddr {
        &self.0
    }
}

impl PartialOrd for Addr {
    // canonical form: the total order lives in `Ord::cmp`
    fn partial_cmp(&self, other: &Addr) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Addr {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.cmp(&other.0)
    }
}
432
/// A value of one of two types. The derived ordering puts every `Left`
/// before every `Right`.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub enum Either<T, U> {
    Left(T),
    Right(U),
}

impl<T, U> Either<T, U> {
    /// Whether this is the `Left` variant.
    pub fn is_left(&self) -> bool {
        matches!(self, Either::Left(_))
    }

    /// Whether this is the `Right` variant.
    pub fn is_right(&self) -> bool {
        !self.is_left()
    }

    /// The `Left` payload, or `None` for a `Right`.
    pub fn left(self) -> Option<T> {
        if let Either::Left(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// The `Right` payload, or `None` for a `Left`.
    pub fn right(self) -> Option<U> {
        if let Either::Right(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
468
469impl<I, T: Iterator<Item = I>, U: Iterator<Item = I>> Iterator for Either<T, U> {
470    type Item = I;
471    fn next(&mut self) -> Option<I> {
472        match self {
473            Either::Left(t) => t.next(),
474            Either::Right(t) => t.next(),
475        }
476    }
477}