1use crate::pack::{Pack, PackError};
2use anyhow::{self, Result};
3use bytes::{Bytes, BytesMut};
4use digest::Digest;
5use futures::{
6 channel::mpsc,
7 prelude::*,
8 sink::Sink,
9 stream::FusedStream,
10 task::{Context, Poll},
11};
12use sha3::Sha3_512;
13use std::{
14 borrow::Borrow,
15 cell::RefCell,
16 cmp::{Ord, Ordering, PartialOrd},
17 hash::Hash,
18 iter::{IntoIterator, Iterator},
19 net::{IpAddr, SocketAddr},
20 pin::Pin,
21};
22
/// Unwraps a `Result` inside a loop, diverting control flow on `Err`.
///
/// On `Ok(x)` the invocation evaluates to `x`. On `Err(e)` it either
/// `continue`s or `break`s the enclosing (optionally labeled) loop. The
/// `break` forms break with the value `Err(Error::from(e))`, where `Error`
/// is whatever that name resolves to at the call site.
///
/// Accepted forms:
/// - `try_cf!(continue, 'lbl, r)`, `try_cf!(break, 'lbl, r)`,
///   `try_cf!(continue, r)`, `try_cf!(break, r)`, `try_cf!(r)`: no logging;
///   `try_cf!(r)` breaks.
/// - `try_cf!(fmt, continue, 'lbl, r)`, `try_cf!(fmt, break, 'lbl, r)`:
///   logs with `log::info!(fmt, e)`, so `fmt` is the format string itself
///   and must contain exactly one placeholder, which receives the error.
/// - `try_cf!(msg, continue, r)`, `try_cf!(msg, break, r)`, `try_cf!(msg, r)`:
///   logs `"{msg}: {e}"` with `log::info!`; `try_cf!(msg, r)` breaks.
///
/// Arm order matters. `continue` and `break` are themselves expressions, so
/// an arm that starts with `$msg:expr` will happily capture them. The
/// keyword-first arms are therefore listed before every `$msg:expr` arm.
/// Otherwise `try_cf!(continue, r)` would be captured by
/// `($msg:expr, $e:expr)` and expanded as a logging `break`.
#[macro_export]
macro_rules! try_cf {
    (continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue $lbl;
            }
        }
    };
    (break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break $lbl Err(Error::from(e));
            }
        }
    };
    (continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(_) => {
                continue;
            }
        }
    };
    (break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                continue $lbl;
            }
        }
    };
    ($msg:expr, break, $lbl:tt, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!($msg, e);
                break $lbl Err(Error::from(e));
            }
        }
    };
    ($msg:expr, continue, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                continue;
            }
        }
    };
    ($msg:expr, break, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($msg:expr, $e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                log::info!("{}: {}", $msg, e);
                break Err(Error::from(e));
            }
        }
    };
    ($e:expr) => {
        match $e {
            Ok(x) => x,
            Err(e) => {
                break Err(Error::from(e));
            }
        }
    };
}
111
/// Declares `pub struct $name(u64)`, an id type whose values come from a
/// process-wide counter.
///
/// Every expansion gets its own counter (a `static` inside `$name::new`),
/// starting at 0, so ids of one type do not repeat within a process unless
/// the `u64` wraps. The generated type derives comparison/hash traits plus
/// `Serialize`/`Deserialize`. Those two derive names are used unqualified,
/// so they must be in scope where the macro is invoked. The type is encoded
/// via `netidx_core::pack` as a single varint.
#[macro_export]
macro_rules! atomic_id {
    ($name:ident) => {
        #[derive(
            Debug,
            Clone,
            Copy,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        pub struct $name(u64);

        impl $name {
            /// Allocates the next id from this type's counter.
            pub fn new() -> Self {
                use std::sync::atomic::{AtomicU64, Ordering};
                // Relaxed suffices: callers only rely on each returned value
                // being unique, which fetch_add guarantees under any ordering.
                static COUNTER: AtomicU64 = AtomicU64::new(0);
                let id = COUNTER.fetch_add(1, Ordering::Relaxed);
                Self(id)
            }

            /// Returns the raw numeric value of the id.
            pub fn inner(&self) -> u64 {
                let Self(raw) = *self;
                raw
            }

            /// Test-only constructor for a specific raw value.
            #[cfg(test)]
            #[allow(dead_code)]
            pub fn mk(i: u64) -> Self {
                Self(i)
            }
        }

        impl netidx_core::pack::Pack for $name {
            fn encoded_len(&self) -> usize {
                netidx_core::pack::varint_len(self.0)
            }

            fn encode(
                &self,
                buf: &mut impl bytes::BufMut,
            ) -> std::result::Result<(), netidx_core::pack::PackError> {
                netidx_core::pack::encode_varint(self.0, buf);
                Ok(())
            }

            fn decode(
                buf: &mut impl bytes::Buf,
            ) -> std::result::Result<Self, netidx_core::pack::PackError> {
                let raw = netidx_core::pack::decode_varint(buf)?;
                Ok(Self(raw))
            }
        }
    };
}
167
168pub fn check_addr<A>(ip: IpAddr, resolvers: &[(SocketAddr, A)]) -> Result<()> {
169 match ip {
170 IpAddr::V4(ip) if ip.is_link_local() => {
171 bail!("addr is a link local address");
172 }
173 IpAddr::V4(ip) if ip.is_broadcast() => {
174 bail!("addr is a broadcast address");
175 }
176 IpAddr::V4(ip) if ip.is_private() => {
177 let ok = resolvers.iter().all(|(a, _)| match a.ip() {
178 IpAddr::V4(ip) if ip.is_private() || ip.is_loopback() => true,
179 IpAddr::V6(_) => true,
180 _ => false,
181 });
182 if !ok {
183 bail!("addr is a private address, and the resolver is not")
184 }
185 }
186 _ => (),
187 }
188 if ip.is_unspecified() {
189 bail!("addr is an unspecified address");
190 }
191 if ip.is_multicast() {
192 bail!("addr is a multicast address");
193 }
194 if ip.is_loopback() && !resolvers.iter().all(|(a, _)| a.ip().is_loopback()) {
195 bail!("addr is a loopback address and the resolver is not");
196 }
197 Ok(())
198}
199
200thread_local! {
201 static BUF: RefCell<BytesMut> = RefCell::new(BytesMut::with_capacity(512));
202}
203
204pub fn make_sha3_token<'a>(data: impl IntoIterator<Item = &'a [u8]> + 'a) -> Bytes {
205 let mut hash = Sha3_512::new();
206 for v in data.into_iter() {
207 hash.update(v);
208 }
209 BUF.with(|buf| {
210 let mut b = buf.borrow_mut();
211 b.extend(hash.finalize().into_iter());
212 b.split().freeze()
213 })
214}
215
216pub fn pack<T: Pack>(t: &T) -> Result<BytesMut, PackError> {
218 BUF.with(|buf| {
219 let mut b = buf.borrow_mut();
220 t.encode(&mut *b)?;
221 Ok(b.split())
222 })
223}
224
225pub fn bytesmut(t: &[u8]) -> BytesMut {
226 BUF.with(|buf| {
227 let mut b = buf.borrow_mut();
228 b.extend_from_slice(t);
229 b.split()
230 })
231}
232
233pub fn bytes(t: &[u8]) -> Bytes {
234 bytesmut(t).freeze()
235}
236
237#[derive(Clone, Debug)]
238pub struct ChanWrap<T>(pub mpsc::Sender<T>);
239
240impl<T> PartialEq for ChanWrap<T> {
241 fn eq(&self, other: &ChanWrap<T>) -> bool {
242 self.0.same_receiver(&other.0)
243 }
244}
245
246impl<T> Eq for ChanWrap<T> {}
247
248impl<T> Hash for ChanWrap<T> {
249 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
250 self.0.hash_receiver(state)
251 }
252}
253
/// Opaque identifier for a channel, allocated from a process-wide counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ChanId(u64);

impl ChanId {
    /// Returns an id different from every other `ChanId` created in this
    /// process (until the 64-bit counter wraps). Successive calls return
    /// increasing values.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering::Relaxed};
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Relaxed);
        Self(id)
    }
}
264
/// One element produced by the `Batched` stream adapter: either an item
/// from the wrapped stream, or a marker that the current batch is complete.
///
/// `PartialEq`/`Eq` are derived so batches can be compared and matched in
/// assertions (they only apply when `T` supports them).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BatchItem<T> {
    /// An item belonging to the current batch.
    InBatch(T),
    /// The current batch is finished; the next `InBatch` starts a new one.
    EndBatch,
}
270
271#[must_use = "streams do nothing unless polled"]
272pub struct Batched<S: Stream> {
273 stream: S,
274 ended: bool,
275 max: usize,
276 current: usize,
277}
278
279impl<S: Stream> Batched<S> {
280 unsafe_pinned!(stream: S);
285
286 unsafe_unpinned!(ended: bool);
288 unsafe_unpinned!(current: usize);
289
290 pub fn new(stream: S, max: usize) -> Batched<S> {
291 Batched { stream, max, ended: false, current: 0 }
292 }
293
294 pub fn inner(&self) -> &S {
295 &self.stream
296 }
297
298 pub fn inner_mut(&mut self) -> &mut S {
299 &mut self.stream
300 }
301
302 pub fn into_inner(self) -> S {
303 self.stream
304 }
305}
306
307impl<S: Stream> Stream for Batched<S> {
308 type Item = BatchItem<<S as Stream>::Item>;
309
310 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
311 if self.ended {
312 Poll::Ready(None)
313 } else if self.current >= self.max {
314 *self.current() = 0;
315 Poll::Ready(Some(BatchItem::EndBatch))
316 } else {
317 match self.as_mut().stream().poll_next(cx) {
318 Poll::Ready(Some(v)) => {
319 *self.as_mut().current() += 1;
320 Poll::Ready(Some(BatchItem::InBatch(v)))
321 }
322 Poll::Ready(None) => {
323 *self.as_mut().ended() = true;
324 if self.current == 0 {
325 Poll::Ready(None)
326 } else {
327 *self.current() = 0;
328 Poll::Ready(Some(BatchItem::EndBatch))
329 }
330 }
331 Poll::Pending => {
332 if self.current == 0 {
333 Poll::Pending
334 } else {
335 *self.current() = 0;
336 Poll::Ready(Some(BatchItem::EndBatch))
337 }
338 }
339 }
340 }
341 }
342
343 fn size_hint(&self) -> (usize, Option<usize>) {
344 self.stream.size_hint()
345 }
346}
347
348impl<S: Stream> FusedStream for Batched<S> {
349 fn is_terminated(&self) -> bool {
350 self.ended
351 }
352}
353
354impl<Item, S: Stream + Sink<Item>> Sink<Item> for Batched<S> {
355 type Error = <S as Sink<Item>>::Error;
356
357 fn poll_ready(
358 self: Pin<&mut Self>,
359 cx: &mut Context,
360 ) -> Poll<Result<(), Self::Error>> {
361 self.stream().poll_ready(cx)
362 }
363
364 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
365 self.stream().start_send(item)
366 }
367
368 fn poll_flush(
369 self: Pin<&mut Self>,
370 cx: &mut Context,
371 ) -> Poll<Result<(), Self::Error>> {
372 self.stream().poll_flush(cx)
373 }
374
375 fn poll_close(
376 self: Pin<&mut Self>,
377 cx: &mut Context,
378 ) -> Poll<Result<(), Self::Error>> {
379 self.stream().poll_close(cx)
380 }
381}
382
/// A `SocketAddr` newtype with a total ordering that agrees with its `Eq`.
///
/// Ordering: every IPv4 address sorts before every IPv6 address. Within a
/// family, addresses compare by IP octets, then port. IPv6 addresses then
/// compare by flow info and scope id, which `SocketAddrV6` equality also
/// considers. As a result, `cmp` returns `Equal` exactly when `==` holds.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Addr(pub SocketAddr);

impl From<SocketAddr> for Addr {
    fn from(addr: SocketAddr) -> Self {
        Addr(addr)
    }
}

// `Eq` and `Hash` are derived over the single field, so they behave exactly
// like `SocketAddr`'s, which is what `Borrow` requires for map lookups.
// NOTE(review): ordered-map lookups by `&SocketAddr` additionally rely on
// `SocketAddr`'s own `Ord` agreeing with the ordering below.
impl Borrow<SocketAddr> for Addr {
    fn borrow(&self) -> &SocketAddr {
        &self.0
    }
}

impl Ord for Addr {
    fn cmp(&self, other: &Self) -> Ordering {
        match (&self.0, &other.0) {
            (SocketAddr::V4(a), SocketAddr::V4(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port())),
            (SocketAddr::V6(a), SocketAddr::V6(b)) => a
                .ip()
                .octets()
                .cmp(&b.ip().octets())
                .then_with(|| a.port().cmp(&b.port()))
                // Tie-break on the remaining fields that equality compares,
                // so unequal addresses never compare as `Equal`.
                .then_with(|| a.flowinfo().cmp(&b.flowinfo()))
                .then_with(|| a.scope_id().cmp(&b.scope_id())),
            (SocketAddr::V4(_), SocketAddr::V6(_)) => Ordering::Less,
            (SocketAddr::V6(_), SocketAddr::V4(_)) => Ordering::Greater,
        }
    }
}

impl PartialOrd for Addr {
    fn partial_cmp(&self, other: &Addr) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
428
/// A value that is one of two types. The derived ordering puts every `Left`
/// before every `Right`.
///
/// When both sides are iterators over the same item type, `Either` is itself
/// an iterator that forwards to whichever side it holds.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub enum Either<T, U> {
    Left(T),
    Right(U),
}

impl<T, U> Either<T, U> {
    /// Returns `true` if this is a `Left`.
    pub fn is_left(&self) -> bool {
        matches!(self, Self::Left(_))
    }

    /// Returns `true` if this is a `Right`.
    pub fn is_right(&self) -> bool {
        matches!(self, Self::Right(_))
    }

    /// Converts into the `Left` value, discarding a `Right`.
    pub fn left(self) -> Option<T> {
        match self {
            Either::Left(t) => Some(t),
            Either::Right(_) => None,
        }
    }

    /// Converts into the `Right` value, discarding a `Left`.
    pub fn right(self) -> Option<U> {
        match self {
            Either::Right(t) => Some(t),
            Either::Left(_) => None,
        }
    }
}

impl<I, T: Iterator<Item = I>, U: Iterator<Item = I>> Iterator for Either<T, U> {
    type Item = I;

    fn next(&mut self) -> Option<I> {
        match self {
            Either::Left(t) => t.next(),
            Either::Right(t) => t.next(),
        }
    }

    // Forwarded so consumers such as `collect` can preallocate; the default
    // would report `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Either::Left(t) => t.size_hint(),
            Either::Right(t) => t.size_hint(),
        }
    }
}