// netidx_core/utils.rs

1use crate::{
2    pack::{Pack, PackError},
3    pool::{Pool, Poolable, Pooled},
4};
5use anyhow::{self, Result};
6use bytes::{Bytes, BytesMut};
7use digest::Digest;
8use futures::{
9    channel::mpsc,
10    prelude::*,
11    sink::Sink,
12    stream::FusedStream,
13    task::{Context, Poll},
14};
15use fxhash::FxHashMap;
16use sha3::Sha3_512;
17use std::{
18    any::{Any, TypeId},
19    borrow::Borrow,
20    cell::RefCell,
21    cmp::{Ord, Ordering, PartialOrd},
22    collections::HashMap,
23    hash::Hash,
24    iter::{IntoIterator, Iterator},
25    net::{IpAddr, SocketAddr},
26    pin::Pin,
27};
28
/// Unwrap a `Result`, diverting loop control flow on `Err`.
///
/// On `Ok(x)` the macro evaluates to `x`. On `Err(e)` it either `continue`s
/// or `break`s out of the enclosing loop (optionally a labeled one). The
/// `break` forms break with `Err(Error::from(e))`, so a type named `Error`
/// that implements `From<E>` for the error type must be in scope at the call
/// site, and the loop being broken must evaluate to a `Result`.
///
/// Forms (the `msg` forms log the error with `log::info!` first):
/// - `try_cf!(msg, continue, 'lbl, e)` / `try_cf!(msg, break, 'lbl, e)`:
///   `msg` is used directly as the format string and must contain exactly
///   one `{}` placeholder for the error.
/// - `try_cf!(msg, continue, e)` / `try_cf!(msg, break, e)` / `try_cf!(msg, e)`:
///   logged as `"{msg}: {e}"`; the last form breaks.
/// - `try_cf!(continue, 'lbl, e)` / `try_cf!(break, 'lbl, e)`
/// - `try_cf!(continue, e)` / `try_cf!(break, e)`
/// - `try_cf!(e)`: breaks.
///
/// Arm order matters. `continue` and `break` are themselves expressions, so
/// the keyword-led arms must come before the generic `($msg:expr, $e:expr)`
/// arm. Otherwise `try_cf!(continue, e)` would be matched with `continue` as
/// the "message" and would break instead of continuing. The labeled keyword
/// arms must in turn come before the unlabeled ones, because a bare `'lbl`
/// is not a complete expression.
#[macro_export]
macro_rules! try_cf {
    ($msg:expr, continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                continue $lbl;
            }
        }
    };
    ($msg:expr, break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                break $lbl Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                continue;
            }
        }
    };
    ($msg:expr, break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    (continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue $lbl;
            }
        }
    };
    (break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break $lbl Err(Error::from(e));
            }
        }
    };
    // These two must precede `($msg:expr, $e:expr)` (see the doc comment).
    (continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue;
            }
        }
    };
    (break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
}
117
/// Define `pub struct $name(u64)`, an id newtype whose `new()` hands out
/// increasing values from a counter private to that type.
///
/// The generated type derives the standard comparison/hash traits plus
/// `Serialize`/`Deserialize` (those derive macros must be in scope at the
/// invocation site). It also implements `netidx_core::pack::Pack` by encoding
/// the inner `u64` as a varint.
#[macro_export]
macro_rules! atomic_id {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        pub struct $name(u64);

        impl $name {
            /// Take the next value from this type's counter (the first is 0).
            pub fn new() -> Self {
                static COUNTER: std::sync::atomic::AtomicU64 =
                    std::sync::atomic::AtomicU64::new(0);
                let id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                $name(id)
            }

            /// The raw numeric value of this id.
            pub fn inner(&self) -> u64 {
                let $name(raw) = *self;
                raw
            }

            /// Build an id from an arbitrary value (test builds only).
            #[cfg(test)]
            #[allow(dead_code)]
            pub fn mk(i: u64) -> Self {
                $name(i)
            }
        }

        impl netidx_core::pack::Pack for $name {
            fn encoded_len(&self) -> usize {
                let raw = self.0;
                netidx_core::pack::varint_len(raw)
            }

            fn encode(
                &self,
                buf: &mut impl bytes::BufMut,
            ) -> std::result::Result<(), netidx_core::pack::PackError> {
                netidx_core::pack::encode_varint(self.0, buf);
                Ok(())
            }

            fn decode(
                buf: &mut impl bytes::Buf,
            ) -> std::result::Result<Self, netidx_core::pack::PackError> {
                let raw = netidx_core::pack::decode_varint(buf)?;
                Ok($name(raw))
            }
        }
    };
}
173
174pub fn check_addr<A>(ip: IpAddr, resolvers: &[(SocketAddr, A)]) -> Result<()> {
175    match ip {
176        IpAddr::V4(ip) if ip.is_link_local() => {
177            bail!("addr is a link local address");
178        }
179        IpAddr::V4(ip) if ip.is_broadcast() => {
180            bail!("addr is a broadcast address");
181        }
182        IpAddr::V4(ip) if ip.is_private() => {
183            let ok = resolvers.iter().all(|(a, _)| match a.ip() {
184                IpAddr::V4(ip) if ip.is_private() || ip.is_loopback() => true,
185                IpAddr::V6(_) => true,
186                _ => false,
187            });
188            if !ok {
189                bail!("addr is a private address, and the resolver is not")
190            }
191        }
192        _ => (),
193    }
194    if ip.is_unspecified() {
195        bail!("addr is an unspecified address");
196    }
197    if ip.is_multicast() {
198        bail!("addr is a multicast address");
199    }
200    if ip.is_loopback() && !resolvers.iter().all(|(a, _)| a.ip().is_loopback()) {
201        bail!("addr is a loopback address and the resolver is not");
202    }
203    Ok(())
204}
205
206thread_local! {
207    static BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(512));
208}
209
210pub fn make_sha3_token<'a>(data: impl IntoIterator<Item = &'a [u8]> + 'a) -> Bytes {
211    let mut hash = Sha3_512::new();
212    for v in data.into_iter() {
213        hash.update(v);
214    }
215    BUF.with(|buf| {
216        let mut b = buf.borrow_mut();
217        b.extend(hash.finalize().into_iter());
218        b.split().freeze()
219    })
220}
221
222/// pack T and return a bytesmut from the global thread local buffer
223pub fn pack<T: Pack>(t: &T) -> Result<BytesMut, PackError> {
224    BUF.with(|buf| {
225        let mut b = buf.borrow_mut();
226        t.encode(&mut *b)?;
227        Ok(b.split())
228    })
229}
230
231thread_local! {
232    static POOLS: RefCell<FxHashMap<TypeId, Box<dyn Any>>> =
233        RefCell::new(HashMap::default());
234}
235
236/// Take a poolable type T from the generic thread local pool set.
237/// Note it is much more efficient to construct your own pools.
238/// size and max are the pool parameters used if the pool doesn't
239/// already exist.
240pub fn take_t<T: Any + Poolable + Send + 'static>(size: usize, max: usize) -> Pooled<T> {
241    POOLS.with(|pools| {
242        let mut pools = pools.borrow_mut();
243        let pool: &mut Pool<T> = pools
244            .entry(TypeId::of::<T>())
245            .or_insert_with(|| Box::new(Pool::<T>::new(size, max)))
246            .downcast_mut()
247            .unwrap();
248        pool.take()
249    })
250}
251
252pub fn bytesmut(t: &[u8]) -> BytesMut {
253    BUF.with(|buf| {
254        let mut b = buf.borrow_mut();
255        b.extend_from_slice(t);
256        b.split()
257    })
258}
259
260pub fn bytes(t: &[u8]) -> Bytes {
261    bytesmut(t).freeze()
262}
263
264#[derive(Clone, Debug)]
265pub struct ChanWrap<T>(pub mpsc::Sender<T>);
266
267impl<T> PartialEq for ChanWrap<T> {
268    fn eq(&self, other: &ChanWrap<T>) -> bool {
269        self.0.same_receiver(&other.0)
270    }
271}
272
273impl<T> Eq for ChanWrap<T> {}
274
275impl<T> Hash for ChanWrap<T> {
276    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
277        self.0.hash_receiver(state)
278    }
279}
280
/// Opaque channel identifier; each call to `ChanId::new` yields a new value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ChanId(u64);

impl ChanId {
    /// Take the next id from a process-wide counter that starts at 0.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let id = NEXT_ID.fetch_add(1, AtomicOrdering::Relaxed);
        Self(id)
    }
}
291
/// One element produced by `Batched`: either an item passed through from
/// the wrapped stream, or a marker closing the current batch.
#[derive(Debug, Clone)]
pub enum BatchItem<T> {
    /// An item produced by the wrapped stream.
    InBatch(T),
    /// Marks the end of the current batch.
    EndBatch,
}
297
298#[must_use = "streams do nothing unless polled"]
299pub struct Batched<S: Stream> {
300    stream: S,
301    ended: bool,
302    max: usize,
303    current: usize,
304}
305
306impl<S: Stream> Batched<S> {
307    // this is safe because,
308    // - Batched doesn't implement Drop
309    // - Batched doesn't implement Unpin
310    // - Batched isn't #[repr(packed)]
311    unsafe_pinned!(stream: S);
312
313    // these are safe because both types are copy
314    unsafe_unpinned!(ended: bool);
315    unsafe_unpinned!(current: usize);
316
317    pub fn new(stream: S, max: usize) -> Batched<S> {
318        Batched { stream, max, ended: false, current: 0 }
319    }
320
321    pub fn inner(&self) -> &S {
322        &self.stream
323    }
324
325    pub fn inner_mut(&mut self) -> &mut S {
326        &mut self.stream
327    }
328
329    pub fn into_inner(self) -> S {
330        self.stream
331    }
332}
333
334impl<S: Stream> Stream for Batched<S> {
335    type Item = BatchItem<<S as Stream>::Item>;
336
337    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
338        if self.ended {
339            Poll::Ready(None)
340        } else if self.current >= self.max {
341            *self.current() = 0;
342            Poll::Ready(Some(BatchItem::EndBatch))
343        } else {
344            match self.as_mut().stream().poll_next(cx) {
345                Poll::Ready(Some(v)) => {
346                    *self.as_mut().current() += 1;
347                    Poll::Ready(Some(BatchItem::InBatch(v)))
348                }
349                Poll::Ready(None) => {
350                    *self.as_mut().ended() = true;
351                    if self.current == 0 {
352                        Poll::Ready(None)
353                    } else {
354                        *self.current() = 0;
355                        Poll::Ready(Some(BatchItem::EndBatch))
356                    }
357                }
358                Poll::Pending => {
359                    if self.current == 0 {
360                        Poll::Pending
361                    } else {
362                        *self.current() = 0;
363                        Poll::Ready(Some(BatchItem::EndBatch))
364                    }
365                }
366            }
367        }
368    }
369
370    fn size_hint(&self) -> (usize, Option<usize>) {
371        self.stream.size_hint()
372    }
373}
374
375impl<S: Stream> FusedStream for Batched<S> {
376    fn is_terminated(&self) -> bool {
377        self.ended
378    }
379}
380
381impl<Item, S: Stream + Sink<Item>> Sink<Item> for Batched<S> {
382    type Error = <S as Sink<Item>>::Error;
383
384    fn poll_ready(
385        self: Pin<&mut Self>,
386        cx: &mut Context,
387    ) -> Poll<Result<(), Self::Error>> {
388        self.stream().poll_ready(cx)
389    }
390
391    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
392        self.stream().start_send(item)
393    }
394
395    fn poll_flush(
396        self: Pin<&mut Self>,
397        cx: &mut Context,
398    ) -> Poll<Result<(), Self::Error>> {
399        self.stream().poll_flush(cx)
400    }
401
402    fn poll_close(
403        self: Pin<&mut Self>,
404        cx: &mut Context,
405    ) -> Poll<Result<(), Self::Error>> {
406        self.stream().poll_close(cx)
407    }
408}
409
/// A `SocketAddr` wrapper that implements `Ord`, so clients can be kept in
/// ordered sets and maps.
///
/// Ordering: every IPv4 address sorts before every IPv6 address. Within a
/// family, addresses compare by IP octets, then port. For IPv6, ties are then
/// broken by flow info and scope id. Those fields are part of `SocketAddrV6`
/// equality, so this ordering returns `Equal` exactly when the addresses are
/// equal.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Addr(pub SocketAddr);

impl From<SocketAddr> for Addr {
    fn from(addr: SocketAddr) -> Self {
        Addr(addr)
    }
}

impl Borrow<SocketAddr> for Addr {
    fn borrow(&self) -> &SocketAddr {
        &self.0
    }
}

impl PartialOrd for Addr {
    fn partial_cmp(&self, other: &Addr) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Addr {
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.0, other.0) {
            (SocketAddr::V4(a), SocketAddr::V4(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port())),
            (SocketAddr::V6(a), SocketAddr::V6(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port()))
                // Without these tie-breakers, unequal addresses (e.g. the
                // same link-local IP on two interfaces) would compare
                // `Equal` and collapse into one entry of a BTreeSet/BTreeMap.
                .then_with(|| a.flowinfo().cmp(&b.flowinfo()))
                .then_with(|| a.scope_id().cmp(&b.scope_id())),
            (SocketAddr::V4(_), SocketAddr::V6(_)) => Ordering::Less,
            (SocketAddr::V6(_), SocketAddr::V4(_)) => Ordering::Greater,
        }
    }
}
455
/// A value that is one of two alternatives.
///
/// The derived comparisons order every `Left` before every `Right`. When both
/// sides are iterators over the same item type, `Either` is itself an
/// iterator that delegates to whichever side it holds.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub enum Either<T, U> {
    Left(T),
    Right(U),
}

impl<T, U> Either<T, U> {
    /// `true` if this is `Left`.
    pub fn is_left(&self) -> bool {
        matches!(self, Either::Left(_))
    }

    /// `true` if this is `Right`.
    pub fn is_right(&self) -> bool {
        !self.is_left()
    }

    /// The `Left` payload, or `None` if this is `Right`.
    pub fn left(self) -> Option<T> {
        if let Either::Left(t) = self {
            Some(t)
        } else {
            None
        }
    }

    /// The `Right` payload, or `None` if this is `Left`.
    pub fn right(self) -> Option<U> {
        if let Either::Right(u) = self {
            Some(u)
        } else {
            None
        }
    }
}

impl<I, T: Iterator<Item = I>, U: Iterator<Item = I>> Iterator for Either<T, U> {
    type Item = I;

    fn next(&mut self) -> Option<I> {
        match self {
            Either::Left(t) => t.next(),
            Either::Right(u) => u.next(),
        }
    }

    /// Forward the active side's hint. The default `(0, None)` would stop
    /// consumers such as `collect` from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Either::Left(t) => t.size_hint(),
            Either::Right(u) => u.size_hint(),
        }
    }
}