netidx_core/
utils.rs

1use crate::pack::{Pack, PackError};
2use anyhow::{self, Result};
3use bytes::{Bytes, BytesMut};
4use digest::Digest;
5use futures::{
6    channel::mpsc,
7    prelude::*,
8    sink::Sink,
9    stream::FusedStream,
10    task::{Context, Poll},
11};
12use sha3::Sha3_512;
13use std::{
14    borrow::Borrow,
15    cell::RefCell,
16    cmp::{Ord, Ordering, PartialOrd},
17    hash::Hash,
18    iter::{IntoIterator, Iterator},
19    net::{IpAddr, SocketAddr},
20    pin::Pin,
21};
22
/// Unwrap a `Result` inside a loop, or leave the current iteration on error.
///
/// On `Ok(x)` the macro evaluates to `x`. On `Err(e)` it either `continue`s
/// or `break`s with `Err(Error::from(e))`, optionally logging the error at
/// info level first and optionally targeting a labelled loop.
///
/// Forms:
/// - `try_cf!(expr)` / `try_cf!(break, expr)`: break with the error.
/// - `try_cf!(continue, expr)`: silently continue.
/// - `try_cf!(break, 'lbl, expr)` / `try_cf!(continue, 'lbl, expr)`: same,
///   targeting the loop labelled `'lbl`.
/// - `try_cf!(msg, expr)`, `try_cf!(msg, break, expr)`,
///   `try_cf!(msg, continue, expr)`: log `"{msg}: {err}"` first.
/// - `try_cf!(fmt, break, 'lbl, expr)` / `try_cf!(fmt, continue, 'lbl, expr)`:
///   here the first argument is used as the *format string* and must contain
///   exactly one placeholder for the error.
///
/// Every breaking form expands to `Error::from(e)`, so a type named `Error`
/// that implements `From` for the error type must be in scope at the call
/// site. The logging forms expand to `log::info!`.
#[macro_export]
macro_rules! try_cf {
    ($msg:expr, continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                continue $lbl;
            }
        }
    };
    ($msg:expr, break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                break $lbl Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                continue;
            }
        }
    };
    ($msg:expr, break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    (continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue $lbl;
            }
        }
    };
    (break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break $lbl Err(Error::from(e));
            }
        }
    };
    // The two bare keyword arms must come before `($msg:expr, $e:expr)`:
    // `continue` and `break` are valid expressions, so the `$msg` arm would
    // otherwise capture them and these arms could never be selected.
    (continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue;
            }
        }
    };
    (break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
}
111
/// Declare a `u64` newtype `$name` whose `new()` hands out increasing ids
/// from a process-wide atomic counter starting at 0.
///
/// The generated type derives the usual comparison/hash traits plus
/// `Serialize` and `Deserialize` (both must be in scope at the call site),
/// and implements this crate's `Pack` trait by encoding the inner value as a
/// varint. The expansion also names `bytes::BufMut` / `bytes::Buf`, so the
/// calling crate needs the `bytes` crate available under that name.
///
/// Paths into this crate go through `$crate`, so the macro works regardless
/// of the name the crate is imported under, including from inside the crate
/// itself.
#[macro_export]
macro_rules! atomic_id {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        pub struct $name(u64);

        impl $name {
            /// Allocate the next id; each call returns a distinct value
            /// (until the 64 bit counter wraps).
            pub fn new() -> Self {
                use std::sync::atomic::{AtomicU64, Ordering};
                // one counter per generated type
                static NEXT: AtomicU64 = AtomicU64::new(0);
                $name(NEXT.fetch_add(1, Ordering::Relaxed))
            }

            /// The raw numeric value of this id.
            pub fn inner(&self) -> u64 {
                self.0
            }

            /// Build an id with a chosen value; only compiled for tests.
            #[cfg(test)]
            #[allow(dead_code)]
            pub fn mk(i: u64) -> Self {
                $name(i)
            }
        }

        impl $crate::pack::Pack for $name {
            fn encoded_len(&self) -> usize {
                $crate::pack::varint_len(self.0)
            }

            fn encode(
                &self,
                buf: &mut impl bytes::BufMut,
            ) -> std::result::Result<(), $crate::pack::PackError> {
                Ok($crate::pack::encode_varint(self.0, buf))
            }

            fn decode(
                buf: &mut impl bytes::Buf,
            ) -> std::result::Result<Self, $crate::pack::PackError> {
                Ok(Self($crate::pack::decode_varint(buf)?))
            }
        }
    };
}
167
168pub fn check_addr<A>(ip: IpAddr, resolvers: &[(SocketAddr, A)]) -> Result<()> {
169    match ip {
170        IpAddr::V4(ip) if ip.is_link_local() => {
171            bail!("addr is a link local address");
172        }
173        IpAddr::V4(ip) if ip.is_broadcast() => {
174            bail!("addr is a broadcast address");
175        }
176        IpAddr::V4(ip) if ip.is_private() => {
177            let ok = resolvers.iter().all(|(a, _)| match a.ip() {
178                IpAddr::V4(ip) if ip.is_private() || ip.is_loopback() => true,
179                IpAddr::V6(_) => true,
180                _ => false,
181            });
182            if !ok {
183                bail!("addr is a private address, and the resolver is not")
184            }
185        }
186        _ => (),
187    }
188    if ip.is_unspecified() {
189        bail!("addr is an unspecified address");
190    }
191    if ip.is_multicast() {
192        bail!("addr is a multicast address");
193    }
194    if ip.is_loopback() && !resolvers.iter().all(|(a, _)| a.ip().is_loopback()) {
195        bail!("addr is a loopback address and the resolver is not");
196    }
197    Ok(())
198}
199
// Per-thread scratch buffer shared by `make_sha3_token`, `pack` and
// `bytesmut`. Each of them appends its output to the buffer and then
// `split()`s off what it wrote. Starts with 512 bytes of capacity.
thread_local! {
    static BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(512));
}
203
204pub fn make_sha3_token<'a>(data: impl IntoIterator<Item = &'a [u8]> + 'a) -> Bytes {
205    let mut hash = Sha3_512::new();
206    for v in data.into_iter() {
207        hash.update(v);
208    }
209    BUF.with(|buf| {
210        let mut b = buf.borrow_mut();
211        b.extend(hash.finalize().into_iter());
212        b.split().freeze()
213    })
214}
215
216/// pack T and return a bytesmut from the global thread local buffer
217pub fn pack<T: Pack>(t: &T) -> Result<BytesMut, PackError> {
218    BUF.with(|buf| {
219        let mut b = buf.borrow_mut();
220        t.encode(&mut *b)?;
221        Ok(b.split())
222    })
223}
224
225pub fn bytesmut(t: &[u8]) -> BytesMut {
226    BUF.with(|buf| {
227        let mut b = buf.borrow_mut();
228        b.extend_from_slice(t);
229        b.split()
230    })
231}
232
233pub fn bytes(t: &[u8]) -> Bytes {
234    bytesmut(t).freeze()
235}
236
237#[derive(Clone, Debug)]
238pub struct ChanWrap<T>(pub mpsc::Sender<T>);
239
240impl<T> PartialEq for ChanWrap<T> {
241    fn eq(&self, other: &ChanWrap<T>) -> bool {
242        self.0.same_receiver(&other.0)
243    }
244}
245
246impl<T> Eq for ChanWrap<T> {}
247
248impl<T> Hash for ChanWrap<T> {
249    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
250        self.0.hash_receiver(state)
251    }
252}
253
/// Opaque identifier drawn from a process-wide counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ChanId(u64);

impl ChanId {
    /// Allocate the next id. The counter starts at 0 and is bumped
    /// atomically, so every call returns a different value.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Relaxed);
        Self(id)
    }
}
264
/// Item type produced by [`Batched`].
#[derive(Clone, Debug)]
pub enum BatchItem<T> {
    /// A value from the wrapped stream, belonging to the current batch.
    InBatch(T),
    /// Marker emitted after the last item of a batch.
    EndBatch,
}
270
271#[must_use = "streams do nothing unless polled"]
272pub struct Batched<S: Stream> {
273    stream: S,
274    ended: bool,
275    max: usize,
276    current: usize,
277}
278
279impl<S: Stream> Batched<S> {
280    // this is safe because,
281    // - Batched doesn't implement Drop
282    // - Batched doesn't implement Unpin
283    // - Batched isn't #[repr(packed)]
284    unsafe_pinned!(stream: S);
285
286    // these are safe because both types are copy
287    unsafe_unpinned!(ended: bool);
288    unsafe_unpinned!(current: usize);
289
290    pub fn new(stream: S, max: usize) -> Batched<S> {
291        Batched { stream, max, ended: false, current: 0 }
292    }
293
294    pub fn inner(&self) -> &S {
295        &self.stream
296    }
297
298    pub fn inner_mut(&mut self) -> &mut S {
299        &mut self.stream
300    }
301
302    pub fn into_inner(self) -> S {
303        self.stream
304    }
305}
306
307impl<S: Stream> Stream for Batched<S> {
308    type Item = BatchItem<<S as Stream>::Item>;
309
310    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
311        if self.ended {
312            Poll::Ready(None)
313        } else if self.current >= self.max {
314            *self.current() = 0;
315            Poll::Ready(Some(BatchItem::EndBatch))
316        } else {
317            match self.as_mut().stream().poll_next(cx) {
318                Poll::Ready(Some(v)) => {
319                    *self.as_mut().current() += 1;
320                    Poll::Ready(Some(BatchItem::InBatch(v)))
321                }
322                Poll::Ready(None) => {
323                    *self.as_mut().ended() = true;
324                    if self.current == 0 {
325                        Poll::Ready(None)
326                    } else {
327                        *self.current() = 0;
328                        Poll::Ready(Some(BatchItem::EndBatch))
329                    }
330                }
331                Poll::Pending => {
332                    if self.current == 0 {
333                        Poll::Pending
334                    } else {
335                        *self.current() = 0;
336                        Poll::Ready(Some(BatchItem::EndBatch))
337                    }
338                }
339            }
340        }
341    }
342
343    fn size_hint(&self) -> (usize, Option<usize>) {
344        self.stream.size_hint()
345    }
346}
347
348impl<S: Stream> FusedStream for Batched<S> {
349    fn is_terminated(&self) -> bool {
350        self.ended
351    }
352}
353
354impl<Item, S: Stream + Sink<Item>> Sink<Item> for Batched<S> {
355    type Error = <S as Sink<Item>>::Error;
356
357    fn poll_ready(
358        self: Pin<&mut Self>,
359        cx: &mut Context,
360    ) -> Poll<Result<(), Self::Error>> {
361        self.stream().poll_ready(cx)
362    }
363
364    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
365        self.stream().start_send(item)
366    }
367
368    fn poll_flush(
369        self: Pin<&mut Self>,
370        cx: &mut Context,
371    ) -> Poll<Result<(), Self::Error>> {
372        self.stream().poll_flush(cx)
373    }
374
375    fn poll_close(
376        self: Pin<&mut Self>,
377        cx: &mut Context,
378    ) -> Poll<Result<(), Self::Error>> {
379        self.stream().poll_close(cx)
380    }
381}
382
/// A `SocketAddr` wrapper with a total order, so clients can be kept in
/// ordered sets and maps.
///
/// Ordering: every IPv4 address sorts before every IPv6 address. Within a
/// family, addresses are ordered by IP octets, then port; IPv6 addresses
/// that still tie are ordered by flow info and then scope id, so that
/// `cmp` returns `Equal` exactly when `==` holds.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Addr(pub SocketAddr);

impl From<SocketAddr> for Addr {
    fn from(addr: SocketAddr) -> Self {
        Addr(addr)
    }
}

impl Borrow<SocketAddr> for Addr {
    fn borrow(&self) -> &SocketAddr {
        &self.0
    }
}

impl Ord for Addr {
    fn cmp(&self, other: &Self) -> Ordering {
        match (&self.0, &other.0) {
            (SocketAddr::V4(a), SocketAddr::V4(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port())),
            (SocketAddr::V6(a), SocketAddr::V6(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port()))
                // tiebreaks keep the order consistent with the derived `Eq`
                .then_with(|| a.flowinfo().cmp(&b.flowinfo()))
                .then_with(|| a.scope_id().cmp(&b.scope_id())),
            (SocketAddr::V4(_), SocketAddr::V6(_)) => Ordering::Less,
            (SocketAddr::V6(_), SocketAddr::V4(_)) => Ordering::Greater,
        }
    }
}

impl PartialOrd for Addr {
    fn partial_cmp(&self, other: &Addr) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
428
/// A value that is one of two alternatives.
///
/// When both alternatives are iterators with the same item type, `Either`
/// is itself an iterator that delegates to whichever side it holds.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub enum Either<T, U> {
    Left(T),
    Right(U),
}

impl<T, U> Either<T, U> {
    /// True if this is the `Left` alternative.
    pub fn is_left(&self) -> bool {
        matches!(self, Self::Left(_))
    }

    /// True if this is the `Right` alternative.
    pub fn is_right(&self) -> bool {
        matches!(self, Self::Right(_))
    }

    /// The `Left` value, or `None` if this is `Right`.
    pub fn left(self) -> Option<T> {
        match self {
            Either::Left(t) => Some(t),
            Either::Right(_) => None,
        }
    }

    /// The `Right` value, or `None` if this is `Left`.
    pub fn right(self) -> Option<U> {
        match self {
            Either::Right(u) => Some(u),
            Either::Left(_) => None,
        }
    }
}

impl<I, T: Iterator<Item = I>, U: Iterator<Item = I>> Iterator for Either<T, U> {
    type Item = I;

    fn next(&mut self) -> Option<I> {
        match self {
            Either::Left(t) => t.next(),
            Either::Right(u) => u.next(),
        }
    }

    // Forward the hint of the active side; the default `(0, None)` would
    // stop consumers such as `collect` from pre-allocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Either::Left(t) => t.size_hint(),
            Either::Right(u) => u.size_hint(),
        }
    }
}