1use crate::pack::{Pack, PackError};
3use anyhow::{self, Result};
4use bytes::{Bytes, BytesMut};
5use futures::{
6 channel::mpsc,
7 prelude::*,
8 sink::Sink,
9 stream::FusedStream,
10 task::{Context, Poll},
11};
12use sha3::{Digest, Sha3_512};
13use std::{
14 borrow::Borrow,
15 cell::RefCell,
16 cmp::{Ord, Ordering, PartialOrd},
17 hash::Hash,
18 iter::{IntoIterator, Iterator},
19 net::{IpAddr, SocketAddr},
20 pin::Pin,
21};
22
/// Unwrap a `Result` inside a loop, `continue`-ing or `break`-ing on `Err`.
///
/// Arms are tried top to bottom and the first match wins. The bare-keyword
/// forms are deliberately listed first: `continue` and `break` are themselves
/// expressions, so if the `($msg:expr, $e:expr)` arm came earlier it would
/// capture them as `$msg` and the keyword forms could never be reached.
///
/// - `try_cf!(continue, 'lbl, e)` / `try_cf!(continue, e)`: on `Err`, continue
///   the labeled / innermost loop.
/// - `try_cf!(break, 'lbl, e)` / `try_cf!(break, e)`: on `Err(err)`, break the
///   labeled / innermost loop with `Err(Error::from(err))`.
/// - `try_cf!(msg, continue, ...)` / `try_cf!(msg, break, ...)`: as above, but
///   first log the error with `log::info!`. In the labeled forms `msg` is used as
///   the format string and receives the error as its only argument; in the
///   unlabeled forms it is printed as `"{msg}: {err}"`.
/// - `try_cf!(msg, e)`: log `"{msg}: {err}"`, then break with the error.
/// - `try_cf!(e)`: break with the error.
///
/// `Error` is resolved at the call site and must implement `From` for the error
/// type of `e`. The logging forms additionally need the `log` crate.
#[macro_export]
macro_rules! try_cf {
    (continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue $lbl;
            }
        }
    };
    (break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break $lbl Err(Error::from(e));
            }
        }
    };
    (continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue;
            }
        }
    };
    (break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                continue $lbl;
            }
        }
    };
    ($msg:expr, break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                break $lbl Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                continue;
            }
        }
    };
    ($msg:expr, break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
}
111
/// Define a `Copy` newtype around `u64` named `$name` whose `new()` hands out
/// ids drawn from a per-type process-wide atomic counter, so each call returns
/// a distinct value.
///
/// The generated type derives the comparison/hash traits plus `Serialize` and
/// `Deserialize` (which must be in scope at the call site). It implements
/// `nohash::IsEnabled` (the `nohash` crate must be a dependency of the calling
/// crate), and implements this crate's `Pack` trait as a varint.
///
/// The `Pack` paths use `$crate` so the macro works both inside this crate and
/// from crates that reach it through a re-export or under a renamed dependency.
#[macro_export]
macro_rules! atomic_id {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        pub struct $name(u64);

        impl nohash::IsEnabled for $name {}

        impl $name {
            /// Allocate the next id from this type's counter.
            pub fn new() -> Self {
                use std::sync::atomic::{AtomicU64, Ordering};
                // One counter per generated type. Relaxed is enough because only
                // the uniqueness of the fetched value matters.
                static NEXT: AtomicU64 = AtomicU64::new(0);
                $name(NEXT.fetch_add(1, Ordering::Relaxed))
            }

            /// The raw numeric value of this id.
            pub fn inner(&self) -> u64 {
                self.0
            }

            /// Construct an id with a specific value (tests only).
            #[cfg(test)]
            #[allow(dead_code)]
            pub fn mk(i: u64) -> Self {
                $name(i)
            }
        }

        impl $crate::pack::Pack for $name {
            fn encoded_len(&self) -> usize {
                $crate::pack::varint_len(self.0)
            }

            fn encode(
                &self,
                buf: &mut impl bytes::BufMut,
            ) -> std::result::Result<(), $crate::pack::PackError> {
                $crate::pack::encode_varint(self.0, buf);
                Ok(())
            }

            fn decode(
                buf: &mut impl bytes::Buf,
            ) -> std::result::Result<Self, $crate::pack::PackError> {
                Ok(Self($crate::pack::decode_varint(buf)?))
            }
        }
    };
}
169
170pub fn check_addr<A>(ip: IpAddr, resolvers: &[(SocketAddr, A)]) -> Result<()> {
171 match ip {
172 IpAddr::V4(ip) if ip.is_link_local() => {
173 bail!("addr is a link local address");
174 }
175 IpAddr::V4(ip) if ip.is_broadcast() => {
176 bail!("addr is a broadcast address");
177 }
178 IpAddr::V4(ip) if ip.is_private() => {
179 let ok = resolvers.iter().all(|(a, _)| match a.ip() {
180 IpAddr::V4(ip) if ip.is_private() || ip.is_loopback() => true,
181 IpAddr::V6(_) => true,
182 _ => false,
183 });
184 if !ok {
185 bail!("addr is a private address, and the resolver is not")
186 }
187 }
188 _ => (),
189 }
190 if ip.is_unspecified() {
191 bail!("addr is an unspecified address");
192 }
193 if ip.is_multicast() {
194 bail!("addr is a multicast address");
195 }
196 if ip.is_loopback() && !resolvers.iter().all(|(a, _)| a.ip().is_loopback()) {
197 bail!("addr is a loopback address and the resolver is not");
198 }
199 Ok(())
200}
201
thread_local! {
    /// Per-thread scratch buffer used by `make_sha3_token`, `pack` and `bytesmut`.
    /// Each of them appends to it and then `split()`s off everything written, so
    /// BUF is left empty while keeping any spare capacity for the next call on
    /// the same thread.
    static BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(512));
}
205
206pub fn make_sha3_token<'a>(data: impl IntoIterator<Item = &'a [u8]> + 'a) -> Bytes {
207 let mut hash = Sha3_512::new();
208 for v in data.into_iter() {
209 hash.update(v);
210 }
211 BUF.with(|buf| {
212 let mut b = buf.borrow_mut();
213 b.extend(hash.finalize().into_iter());
214 b.split().freeze()
215 })
216}
217
218pub fn pack<T: Pack>(t: &T) -> Result<BytesMut, PackError> {
220 BUF.with(|buf| {
221 let mut b = buf.borrow_mut();
222 t.encode(&mut *b)?;
223 Ok(b.split())
224 })
225}
226
227pub fn bytesmut(t: &[u8]) -> BytesMut {
228 BUF.with(|buf| {
229 let mut b = buf.borrow_mut();
230 b.extend_from_slice(t);
231 b.split()
232 })
233}
234
235pub fn bytes(t: &[u8]) -> Bytes {
236 bytesmut(t).freeze()
237}
238
239#[derive(Clone, Debug)]
240pub struct ChanWrap<T>(pub mpsc::Sender<T>);
241
242impl<T> PartialEq for ChanWrap<T> {
243 fn eq(&self, other: &ChanWrap<T>) -> bool {
244 self.0.same_receiver(&other.0)
245 }
246}
247
248impl<T> Eq for ChanWrap<T> {}
249
250impl<T> Hash for ChanWrap<T> {
251 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
252 self.0.hash_receiver(state)
253 }
254}
255
256#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
257pub struct ChanId(u64);
258
259impl nohash::IsEnabled for ChanId {}
260
261impl ChanId {
262 pub fn new() -> Self {
263 use std::sync::atomic::{AtomicU64, Ordering};
264 static NEXT: AtomicU64 = AtomicU64::new(0);
265 ChanId(NEXT.fetch_add(1, Ordering::Relaxed))
266 }
267}
268
/// An element yielded by [`Batched`]: either an item from the wrapped stream
/// or a marker closing the current batch.
#[derive(Debug, Clone)]
pub enum BatchItem<T> {
    /// An item produced by the wrapped stream, belonging to the current batch.
    InBatch(T),
    /// Marks the end of the current batch.
    EndBatch,
}
274
/// Stream adapter that groups the items of `stream` into batches, emitting
/// `BatchItem::InBatch` for each item and `BatchItem::EndBatch` to close a batch
/// (see the `Stream` impl for exactly when).
#[must_use = "streams do nothing unless polled"]
pub struct Batched<S: Stream> {
    /// The wrapped stream; structurally pinned via `unsafe_pinned!` below.
    stream: S,
    /// Set once the wrapped stream has returned `None`.
    ended: bool,
    /// Maximum number of `InBatch` items before an `EndBatch` is forced.
    max: usize,
    /// Number of `InBatch` items yielded since the last `EndBatch`.
    current: usize,
}

impl<S: Stream> Batched<S> {
    // NOTE(review): these macros are not imported in this file; presumably they
    // are pin-utils' projection macros brought in at the crate root — confirm.
    // They are only sound while `stream` is never moved out of a pinned
    // `Batched` (no `Drop` impl that moves it, no unconditional `Unpin` impl).
    unsafe_pinned!(stream: S);

    unsafe_unpinned!(ended: bool);
    unsafe_unpinned!(current: usize);

    /// Wrap `stream`, closing a batch after at most `max` items.
    pub fn new(stream: S, max: usize) -> Batched<S> {
        Batched { stream, max, ended: false, current: 0 }
    }

    /// Borrow the wrapped stream.
    pub fn inner(&self) -> &S {
        &self.stream
    }

    /// Mutably borrow the wrapped stream (requires an unpinned `&mut self`).
    pub fn inner_mut(&mut self) -> &mut S {
        &mut self.stream
    }

    /// Consume the adapter and return the wrapped stream.
    pub fn into_inner(self) -> S {
        self.stream
    }
}
309}
310
311impl<S: Stream> Stream for Batched<S> {
312 type Item = BatchItem<<S as Stream>::Item>;
313
314 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
315 if self.ended {
316 Poll::Ready(None)
317 } else if self.current >= self.max {
318 *self.current() = 0;
319 Poll::Ready(Some(BatchItem::EndBatch))
320 } else {
321 match self.as_mut().stream().poll_next(cx) {
322 Poll::Ready(Some(v)) => {
323 *self.as_mut().current() += 1;
324 Poll::Ready(Some(BatchItem::InBatch(v)))
325 }
326 Poll::Ready(None) => {
327 *self.as_mut().ended() = true;
328 if self.current == 0 {
329 Poll::Ready(None)
330 } else {
331 *self.current() = 0;
332 Poll::Ready(Some(BatchItem::EndBatch))
333 }
334 }
335 Poll::Pending => {
336 if self.current == 0 {
337 Poll::Pending
338 } else {
339 *self.current() = 0;
340 Poll::Ready(Some(BatchItem::EndBatch))
341 }
342 }
343 }
344 }
345 }
346
347 fn size_hint(&self) -> (usize, Option<usize>) {
348 self.stream.size_hint()
349 }
350}
351
352impl<S: Stream> FusedStream for Batched<S> {
353 fn is_terminated(&self) -> bool {
354 self.ended
355 }
356}
357
358impl<Item, S: Stream + Sink<Item>> Sink<Item> for Batched<S> {
359 type Error = <S as Sink<Item>>::Error;
360
361 fn poll_ready(
362 self: Pin<&mut Self>,
363 cx: &mut Context,
364 ) -> Poll<Result<(), Self::Error>> {
365 self.stream().poll_ready(cx)
366 }
367
368 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
369 self.stream().start_send(item)
370 }
371
372 fn poll_flush(
373 self: Pin<&mut Self>,
374 cx: &mut Context,
375 ) -> Poll<Result<(), Self::Error>> {
376 self.stream().poll_flush(cx)
377 }
378
379 fn poll_close(
380 self: Pin<&mut Self>,
381 cx: &mut Context,
382 ) -> Poll<Result<(), Self::Error>> {
383 self.stream().poll_close(cx)
384 }
385}
386
/// Newtype over `SocketAddr` that carries its own total ordering.
///
/// It can be borrowed as the inner `SocketAddr`, e.g. for map lookups keyed by
/// a plain address.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Addr(pub SocketAddr);

impl From<SocketAddr> for Addr {
    fn from(inner: SocketAddr) -> Self {
        Self(inner)
    }
}

impl Borrow<SocketAddr> for Addr {
    fn borrow(&self) -> &SocketAddr {
        let Addr(inner) = self;
        inner
    }
}
402}
403
404impl PartialOrd for Addr {
405 fn partial_cmp(&self, other: &Addr) -> Option<Ordering> {
406 match (self.0, other.0) {
407 (SocketAddr::V4(v0), SocketAddr::V4(v1)) => {
408 match v0.ip().octets().partial_cmp(&v1.ip().octets()) {
409 None => None,
410 Some(Ordering::Equal) => v0.port().partial_cmp(&v1.port()),
411 Some(o) => Some(o),
412 }
413 }
414 (SocketAddr::V6(v0), SocketAddr::V6(v1)) => {
415 match v0.ip().octets().partial_cmp(&v1.ip().octets()) {
416 None => None,
417 Some(Ordering::Equal) => v0.port().partial_cmp(&v1.port()),
418 Some(o) => Some(o),
419 }
420 }
421 (SocketAddr::V4(_), SocketAddr::V6(_)) => Some(Ordering::Less),
422 (SocketAddr::V6(_), SocketAddr::V4(_)) => Some(Ordering::Greater),
423 }
424 }
425}
426
427impl Ord for Addr {
428 fn cmp(&self, other: &Self) -> Ordering {
429 self.partial_cmp(other).unwrap()
430 }
431}
432
433#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
434pub enum Either<T, U> {
435 Left(T),
436 Right(U),
437}
438
439impl<T, U> Either<T, U> {
440 pub fn is_left(&self) -> bool {
441 match self {
442 Self::Left(_) => true,
443 Self::Right(_) => false,
444 }
445 }
446
447 pub fn is_right(&self) -> bool {
448 match self {
449 Self::Left(_) => false,
450 Self::Right(_) => true,
451 }
452 }
453
454 pub fn left(self) -> Option<T> {
455 match self {
456 Either::Left(t) => Some(t),
457 Either::Right(_) => None,
458 }
459 }
460
461 pub fn right(self) -> Option<U> {
462 match self {
463 Either::Right(t) => Some(t),
464 Either::Left(_) => None,
465 }
466 }
467}
468
469impl<I, T: Iterator<Item = I>, U: Iterator<Item = I>> Iterator for Either<T, U> {
470 type Item = I;
471 fn next(&mut self) -> Option<I> {
472 match self {
473 Either::Left(t) => t.next(),
474 Either::Right(t) => t.next(),
475 }
476 }
477}