netidx_core/
utils.rs

1use crate::{
2    pack::{Pack, PackError},
3    pool::{Pool, Poolable, Pooled},
4};
5use anyhow::{self, Result};
6use bytes::{Bytes, BytesMut};
7use digest::Digest;
8use futures::{
9    channel::mpsc,
10    prelude::*,
11    sink::Sink,
12    stream::FusedStream,
13    task::{Context, Poll},
14};
15use fxhash::FxHashMap;
16use sha3::Sha3_512;
17use std::{
18    any::{Any, TypeId},
19    borrow::Borrow,
20    borrow::Cow,
21    cell::RefCell,
22    cmp::{Ord, Ordering, PartialOrd},
23    collections::HashMap,
24    hash::Hash,
25    iter::{IntoIterator, Iterator},
26    net::{IpAddr, SocketAddr},
27    pin::Pin,
28    str,
29};
30
/// Unwrap a `Result` inside a loop, diverting control flow on `Err`.
///
/// On `Ok(x)` the macro evaluates to `x`. On `Err(e)` it either
/// `continue`s the (optionally labeled) loop, or `break`s out of it with
/// `Err(Error::from(e))`. `Error` is resolved at the call site, so a type
/// named `Error` with a suitable `From` impl must be in scope there, and
/// the loop being broken must evaluate to a `Result`.
///
/// Forms:
/// - `try_cf!(msg, continue, 'lbl, e)` / `try_cf!(msg, break, 'lbl, e)`:
///   log with `log::info!(msg, err)` (`msg` is itself the format string and
///   receives the error as its single argument), then continue/break `'lbl`.
/// - `try_cf!(msg, continue, e)` / `try_cf!(msg, break, e)`: log
///   `"<msg>: <err>"`, then continue/break the innermost loop.
/// - `try_cf!(continue, 'lbl, e)` / `try_cf!(break, 'lbl, e)`: no logging.
/// - `try_cf!(continue, e)` / `try_cf!(break, e)`: no logging, innermost loop.
/// - `try_cf!(msg, e)`: log `"<msg>: <err>"` and break the innermost loop.
/// - `try_cf!(e)`: break the innermost loop without logging.
#[macro_export]
macro_rules! try_cf {
    // Logged, labeled forms.
    ($msg:expr, continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                continue $lbl;
            }
        }
    };
    ($msg:expr, break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                break $lbl Err(Error::from(e));
            }
        }
    };
    // Logged forms acting on the innermost loop.
    ($msg:expr, continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                continue;
            }
        }
    };
    ($msg:expr, break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    // Unlogged, labeled forms.
    (continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue $lbl;
            }
        }
    };
    (break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break $lbl Err(Error::from(e));
            }
        }
    };
    // Unlogged forms acting on the innermost loop. These MUST come before
    // `($msg:expr, $e:expr)`: `continue` and `break` are themselves valid
    // expressions, so that arm would otherwise capture the keyword as `$msg`
    // and these arms could never be reached.
    (continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue;
            }
        }
    };
    (break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
}
119
/// Define a `Copy` newtype id `$name(u64)` with a per-type atomic counter
/// and a varint `Pack` encoding.
///
/// The expansion derives `Serialize`/`Deserialize`, so those derive macros
/// must be in scope at the call site, and it names `bytes::Buf`/`bytes::BufMut`,
/// so the calling crate needs `bytes` as a dependency. Paths into this crate
/// go through `$crate`, so the macro works both inside netidx_core and from
/// crates that depend on it (even under a renamed dependency).
#[macro_export]
macro_rules! atomic_id {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        pub struct $name(u64);

        impl $name {
            /// Allocate the next id. Each expansion of the macro has its
            /// own counter, starting at 0.
            pub fn new() -> Self {
                use std::sync::atomic::{AtomicU64, Ordering};
                static NEXT: AtomicU64 = AtomicU64::new(0);
                $name(NEXT.fetch_add(1, Ordering::Relaxed))
            }

            /// The raw numeric value of this id.
            pub fn inner(&self) -> u64 {
                self.0
            }

            /// Build an id from an arbitrary value (test builds only).
            #[cfg(test)]
            #[allow(dead_code)]
            pub fn mk(i: u64) -> Self {
                $name(i)
            }
        }

        impl $crate::pack::Pack for $name {
            fn encoded_len(&self) -> usize {
                $crate::pack::varint_len(self.0)
            }

            fn encode(
                &self,
                buf: &mut impl bytes::BufMut,
            ) -> std::result::Result<(), $crate::pack::PackError> {
                Ok($crate::pack::encode_varint(self.0, buf))
            }

            fn decode(
                buf: &mut impl bytes::Buf,
            ) -> std::result::Result<Self, $crate::pack::PackError> {
                Ok(Self($crate::pack::decode_varint(buf)?))
            }
        }
    };
}
175
176pub fn check_addr<A>(ip: IpAddr, resolvers: &[(SocketAddr, A)]) -> Result<()> {
177    match ip {
178        IpAddr::V4(ip) if ip.is_link_local() => {
179            bail!("addr is a link local address");
180        }
181        IpAddr::V4(ip) if ip.is_broadcast() => {
182            bail!("addr is a broadcast address");
183        }
184        IpAddr::V4(ip) if ip.is_private() => {
185            let ok = resolvers.iter().all(|(a, _)| match a.ip() {
186                IpAddr::V4(ip) if ip.is_private() || ip.is_loopback() => true,
187                IpAddr::V6(_) => true,
188                _ => false,
189            });
190            if !ok {
191                bail!("addr is a private address, and the resolver is not")
192            }
193        }
194        _ => (),
195    }
196    if ip.is_unspecified() {
197        bail!("addr is an unspecified address");
198    }
199    if ip.is_multicast() {
200        bail!("addr is a multicast address");
201    }
202    if ip.is_loopback() && !resolvers.iter().all(|(a, _)| a.ip().is_loopback()) {
203        bail!("addr is a loopback address and the resolver is not");
204    }
205    Ok(())
206}
207
/// Stateful predicate for splitting on `sep` while honoring `escape`.
///
/// Feed it every character of a string in order (front to back), sharing
/// one `esc` flag that starts out `false`. It returns `true` only for
/// separators that are not escaped. A separator is escaped when it
/// immediately follows an odd-length run of `escape` characters.
pub fn is_sep(esc: &mut bool, c: char, escape: char, sep: char) -> bool {
    if c == sep {
        let split = !*esc;
        // Whether or not the separator was escaped, the escape has now been
        // consumed and must not carry over to the next character, or a
        // following unescaped separator (as in `a\//b`) would be missed.
        *esc = false;
        split
    } else {
        // Escapes toggle so that `\\` is an escaped escape.
        *esc = c == escape && !*esc;
        false
    }
}
216
/// Append `s` to `buf`, inserting `esc` in front of every character that is
/// either listed in `spec` or is `esc` itself.
pub fn escape_to<T>(s: &T, buf: &mut String, esc: char, spec: &[char])
where
    T: AsRef<str> + ?Sized,
{
    s.as_ref().chars().for_each(|ch| {
        let needs_escape = ch == esc || spec.contains(&ch);
        if needs_escape {
            buf.push(esc);
        }
        buf.push(ch);
    })
}
236
237/// escape the specified string using the specified escape character
238/// and a slice of special characters that need escaping.
239pub fn escape<'a, 'b, T>(s: &'a T, esc: char, spec: &'b [char]) -> Cow<'a, str>
240where
241    T: AsRef<str> + ?Sized,
242    'a: 'b,
243{
244    let s = s.as_ref();
245    if s.find(|c: char| spec.contains(&c) || c == esc).is_none() {
246        Cow::Borrowed(s.as_ref())
247    } else {
248        let mut out = String::with_capacity(s.len());
249        escape_to(s, &mut out, esc, spec);
250        Cow::Owned(out)
251    }
252}
253
/// Append `s` to `buf` with escapes removed: an unescaped `esc` is dropped,
/// and the character after it is copied literally. A trailing lone `esc` is
/// dropped.
pub fn unescape_to<T>(s: &T, buf: &mut String, esc: char)
where
    T: AsRef<str> + ?Sized,
{
    let mut pending = false;
    for ch in s.as_ref().chars() {
        if !pending && ch == esc {
            pending = true;
            continue;
        }
        pending = false;
        buf.push(ch);
    }
}
271
272/// unescape the specified string using the specified escape character
273pub fn unescape<T>(s: &T, esc: char) -> Cow<str>
274where
275    T: AsRef<str> + ?Sized,
276{
277    let s = s.as_ref();
278    if !s.contains(esc) {
279        Cow::Borrowed(s.as_ref())
280    } else {
281        let mut res = String::with_capacity(s.len());
282        unescape_to(s, &mut res, esc);
283        Cow::Owned(res)
284    }
285}
286
/// Report whether the character at byte index `i` of `s` is escaped,
/// i.e. immediately preceded by an odd-length run of `esc` characters.
///
/// Indices that do not fall on a char boundary, or that lie past the end of
/// `s`, are reported as escaped (`true`).
///
/// Comparison is done on `char`s rather than on the truncated byte
/// `esc as u8`, so non-ASCII escape characters work. With a byte
/// comparison, e.g. `'Ŝ' as u8 == b'\\'` would be misdetected, and a real
/// multibyte escape would never be seen.
pub fn is_escaped(s: &str, esc: char, i: usize) -> bool {
    if !s.is_char_boundary(i) {
        return true;
    }
    let run = s[..i].chars().rev().take_while(|&c| c == esc).count();
    run % 2 == 1
}
301
302pub fn splitn_escaped(
303    s: &str,
304    n: usize,
305    escape: char,
306    sep: char,
307) -> impl Iterator<Item = &str> {
308    s.splitn(n, {
309        let mut esc = false;
310        move |c| is_sep(&mut esc, c, escape, sep)
311    })
312}
313
314pub fn split_escaped(s: &str, escape: char, sep: char) -> impl Iterator<Item = &str> {
315    s.split({
316        let mut esc = false;
317        move |c| is_sep(&mut esc, c, escape, sep)
318    })
319}
320
/// Split `s` on unescaped `sep` characters, yielding the pieces from last to
/// first (like `str::rsplit`). A separator is escaped when it immediately
/// follows an odd-length run of `escape` characters. Pieces are slices of
/// `s`, so escapes are kept.
///
/// Whether a separator is escaped depends on everything before it, so
/// scanning back-to-front (as `str::rsplit` with a stateful predicate would)
/// cannot track escapes correctly. Instead, the split points are found in
/// one forward pass and the pieces are then yielded in reverse. This
/// allocates a `Vec` of slices.
pub fn rsplit_escaped(s: &str, escape: char, sep: char) -> impl Iterator<Item = &str> {
    let sep_len = sep.len_utf8();
    let mut pieces = Vec::new();
    let mut start = 0;
    let mut esc = false;
    for (i, c) in s.char_indices() {
        if c == sep {
            if !esc {
                pieces.push(&s[start..i]);
                start = i + sep_len;
            }
            // The escape (if any) was consumed by this separator.
            esc = false;
        } else {
            esc = c == escape && !esc;
        }
    }
    pieces.push(&s[start..]);
    pieces.into_iter().rev()
}
327
328thread_local! {
329    static BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(512));
330}
331
332pub fn make_sha3_token<'a>(data: impl IntoIterator<Item = &'a [u8]> + 'a) -> Bytes {
333    let mut hash = Sha3_512::new();
334    for v in data.into_iter() {
335        hash.update(v);
336    }
337    BUF.with(|buf| {
338        let mut b = buf.borrow_mut();
339        b.extend(hash.finalize().into_iter());
340        b.split().freeze()
341    })
342}
343
344/// pack T and return a bytesmut from the global thread local buffer
345pub fn pack<T: Pack>(t: &T) -> Result<BytesMut, PackError> {
346    BUF.with(|buf| {
347        let mut b = buf.borrow_mut();
348        t.encode(&mut *b)?;
349        Ok(b.split())
350    })
351}
352
353thread_local! {
354    static POOLS: RefCell<FxHashMap<TypeId, Box<dyn Any>>> =
355        RefCell::new(HashMap::default());
356}
357
358/// Take a poolable type T from the generic thread local pool set.
359/// Note it is much more efficient to construct your own pools.
360/// size and max are the pool parameters used if the pool doesn't
361/// already exist.
362pub fn take_t<T: Any + Poolable + Send + 'static>(size: usize, max: usize) -> Pooled<T> {
363    POOLS.with(|pools| {
364        let mut pools = pools.borrow_mut();
365        let pool: &mut Pool<T> = pools
366            .entry(TypeId::of::<T>())
367            .or_insert_with(|| Box::new(Pool::<T>::new(size, max)))
368            .downcast_mut()
369            .unwrap();
370        pool.take()
371    })
372}
373
374pub fn bytesmut(t: &[u8]) -> BytesMut {
375    BUF.with(|buf| {
376        let mut b = buf.borrow_mut();
377        b.extend_from_slice(t);
378        b.split()
379    })
380}
381
382pub fn bytes(t: &[u8]) -> Bytes {
383    bytesmut(t).freeze()
384}
385
386#[derive(Clone, Debug)]
387pub struct ChanWrap<T>(pub mpsc::Sender<T>);
388
389impl<T> PartialEq for ChanWrap<T> {
390    fn eq(&self, other: &ChanWrap<T>) -> bool {
391        self.0.same_receiver(&other.0)
392    }
393}
394
395impl<T> Eq for ChanWrap<T> {}
396
397impl<T> Hash for ChanWrap<T> {
398    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
399        self.0.hash_receiver(state)
400    }
401}
402
/// Process-wide unique channel identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ChanId(u64);

impl ChanId {
    /// Allocate the next id from a counter that starts at 0.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Relaxed);
        Self(id)
    }
}
413
/// An element produced by a `Batched` stream.
#[derive(Clone, Debug)]
pub enum BatchItem<T> {
    /// An item from the wrapped stream, belonging to the current batch.
    InBatch(T),
    /// Marks the end of the current batch.
    EndBatch,
}
419
420#[must_use = "streams do nothing unless polled"]
421pub struct Batched<S: Stream> {
422    stream: S,
423    ended: bool,
424    max: usize,
425    current: usize,
426}
427
428impl<S: Stream> Batched<S> {
429    // this is safe because,
430    // - Batched doesn't implement Drop
431    // - Batched doesn't implement Unpin
432    // - Batched isn't #[repr(packed)]
433    unsafe_pinned!(stream: S);
434
435    // these are safe because both types are copy
436    unsafe_unpinned!(ended: bool);
437    unsafe_unpinned!(current: usize);
438
439    pub fn new(stream: S, max: usize) -> Batched<S> {
440        Batched { stream, max, ended: false, current: 0 }
441    }
442
443    pub fn inner(&self) -> &S {
444        &self.stream
445    }
446
447    pub fn inner_mut(&mut self) -> &mut S {
448        &mut self.stream
449    }
450
451    pub fn into_inner(self) -> S {
452        self.stream
453    }
454}
455
456impl<S: Stream> Stream for Batched<S> {
457    type Item = BatchItem<<S as Stream>::Item>;
458
459    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
460        if self.ended {
461            Poll::Ready(None)
462        } else if self.current >= self.max {
463            *self.current() = 0;
464            Poll::Ready(Some(BatchItem::EndBatch))
465        } else {
466            match self.as_mut().stream().poll_next(cx) {
467                Poll::Ready(Some(v)) => {
468                    *self.as_mut().current() += 1;
469                    Poll::Ready(Some(BatchItem::InBatch(v)))
470                }
471                Poll::Ready(None) => {
472                    *self.as_mut().ended() = true;
473                    if self.current == 0 {
474                        Poll::Ready(None)
475                    } else {
476                        *self.current() = 0;
477                        Poll::Ready(Some(BatchItem::EndBatch))
478                    }
479                }
480                Poll::Pending => {
481                    if self.current == 0 {
482                        Poll::Pending
483                    } else {
484                        *self.current() = 0;
485                        Poll::Ready(Some(BatchItem::EndBatch))
486                    }
487                }
488            }
489        }
490    }
491
492    fn size_hint(&self) -> (usize, Option<usize>) {
493        self.stream.size_hint()
494    }
495}
496
497impl<S: Stream> FusedStream for Batched<S> {
498    fn is_terminated(&self) -> bool {
499        self.ended
500    }
501}
502
503impl<Item, S: Stream + Sink<Item>> Sink<Item> for Batched<S> {
504    type Error = <S as Sink<Item>>::Error;
505
506    fn poll_ready(
507        self: Pin<&mut Self>,
508        cx: &mut Context,
509    ) -> Poll<Result<(), Self::Error>> {
510        self.stream().poll_ready(cx)
511    }
512
513    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
514        self.stream().start_send(item)
515    }
516
517    fn poll_flush(
518        self: Pin<&mut Self>,
519        cx: &mut Context,
520    ) -> Poll<Result<(), Self::Error>> {
521        self.stream().poll_flush(cx)
522    }
523
524    fn poll_close(
525        self: Pin<&mut Self>,
526        cx: &mut Context,
527    ) -> Poll<Result<(), Self::Error>> {
528        self.stream().poll_close(cx)
529    }
530}
531
/// A `SocketAddr` wrapper that implements `Ord`, so clients can be kept in
/// ordered collections such as `BTreeSet`/`BTreeMap`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Addr(pub SocketAddr);

impl From<SocketAddr> for Addr {
    fn from(addr: SocketAddr) -> Self {
        Addr(addr)
    }
}

impl Borrow<SocketAddr> for Addr {
    fn borrow(&self) -> &SocketAddr {
        &self.0
    }
}

impl PartialOrd for Addr {
    fn partial_cmp(&self, other: &Addr) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Addr {
    /// IPv4 sorts before IPv6. Within a family, addresses are ordered by IP
    /// octets, then port. For IPv6, `flowinfo` and `scope_id` follow.
    ///
    /// The derived `PartialEq` compares `flowinfo`/`scope_id` too. Leaving
    /// them out here would make distinct addresses compare `Equal`, breaking
    /// `Ord`'s required consistency with `Eq` (a `BTreeSet` would silently
    /// merge them).
    ///
    /// NOTE(review): `Borrow<SocketAddr>` requires this to agree with std's
    /// own `SocketAddr` ordering. Recent std orders the same way; confirm for
    /// the minimum supported toolchain.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.0, other.0) {
            (SocketAddr::V4(a), SocketAddr::V4(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port())),
            (SocketAddr::V6(a), SocketAddr::V6(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port()))
                .then_with(|| a.flowinfo().cmp(&b.flowinfo()))
                .then_with(|| a.scope_id().cmp(&b.scope_id())),
            (SocketAddr::V4(_), SocketAddr::V6(_)) => Ordering::Less,
            (SocketAddr::V6(_), SocketAddr::V4(_)) => Ordering::Greater,
        }
    }
}
577
/// A value of one of two types: `Left(T)` or `Right(U)`.
///
/// When both sides are iterators over the same item type, `Either` is itself
/// an iterator. This lets code return one of two different iterator types
/// behind a single concrete type.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub enum Either<T, U> {
    Left(T),
    Right(U),
}

impl<T, U> Either<T, U> {
    /// `true` if this is a `Left`.
    pub fn is_left(&self) -> bool {
        match self {
            Self::Left(_) => true,
            Self::Right(_) => false,
        }
    }

    /// `true` if this is a `Right`.
    pub fn is_right(&self) -> bool {
        match self {
            Self::Left(_) => false,
            Self::Right(_) => true,
        }
    }

    /// The `Left` value, if any.
    pub fn left(self) -> Option<T> {
        match self {
            Either::Left(t) => Some(t),
            Either::Right(_) => None,
        }
    }

    /// The `Right` value, if any.
    pub fn right(self) -> Option<U> {
        match self {
            Either::Right(u) => Some(u),
            Either::Left(_) => None,
        }
    }
}

impl<I, T: Iterator<Item = I>, U: Iterator<Item = I>> Iterator for Either<T, U> {
    type Item = I;

    fn next(&mut self) -> Option<I> {
        match self {
            Either::Left(t) => t.next(),
            Either::Right(u) => u.next(),
        }
    }

    /// Forward the active side's size hint. The default would be
    /// `(0, None)`, which throws away information that e.g. `collect` uses to
    /// preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Either::Left(t) => t.size_hint(),
            Either::Right(u) => u.size_hint(),
        }
    }
}