netsim_embed_core/
lib.rs

1use async_io::Timer;
2use futures::channel::mpsc;
3use futures::future::FutureExt;
4use futures::stream::{Stream, StreamExt};
5use std::collections::VecDeque;
6use std::net::Ipv4Addr;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::{Duration, Instant};
10
11mod addr;
12mod packet;
13mod range;
14
15pub use packet::{Packet, Protocol};
16pub use range::Ipv4Range;
17
18#[derive(Clone, Copy, Debug)]
19pub struct Ipv4Route {
20    dest: Ipv4Range,
21    gateway: Option<Ipv4Addr>,
22}
23
24impl Ipv4Route {
25    /// Create a new route with the given destination and gateway.
26    pub fn new(dest: Ipv4Range, gateway: Option<Ipv4Addr>) -> Self {
27        Self { dest, gateway }
28    }
29
30    /// Returns the destination IP range of the route.
31    pub fn dest(&self) -> Ipv4Range {
32        self.dest
33    }
34
35    /// Returns the route's gateway (if any).
36    pub fn gateway(&self) -> Option<Ipv4Addr> {
37        self.gateway
38    }
39}
40
41impl From<Ipv4Range> for Ipv4Route {
42    fn from(range: Ipv4Range) -> Self {
43        Self::new(range, None)
44    }
45}
46
47impl From<Ipv4Addr> for Ipv4Route {
48    fn from(addr: Ipv4Addr) -> Self {
49        Self::new(addr.into(), None)
50    }
51}
52
53#[derive(Debug)]
54pub struct Plug {
55    tx: mpsc::UnboundedSender<Vec<u8>>,
56    rx: mpsc::UnboundedReceiver<Vec<u8>>,
57}
58
59impl Plug {
60    pub fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<Vec<u8>>> {
61        Pin::new(&mut self.rx).poll_next(cx)
62    }
63
64    pub async fn incoming(&mut self) -> Option<Vec<u8>> {
65        self.rx.next().await
66    }
67
68    pub fn unbounded_send(&mut self, packet: Vec<u8>) {
69        let _ = self.tx.unbounded_send(packet);
70    }
71
72    pub fn split(
73        self,
74    ) -> (
75        mpsc::UnboundedSender<Vec<u8>>,
76        mpsc::UnboundedReceiver<Vec<u8>>,
77    ) {
78        (self.tx, self.rx)
79    }
80}
81
82pub fn wire() -> (Plug, Plug) {
83    let (a_tx, b_rx) = mpsc::unbounded();
84    let (b_tx, a_rx) = mpsc::unbounded();
85    let a = Plug { tx: a_tx, rx: a_rx };
86    let b = Plug { tx: b_tx, rx: b_rx };
87    (a, b)
88}
89
90#[derive(Clone, Copy, Debug)]
91pub struct DelayBuffer {
92    delay: Duration,
93    buffer_size: usize,
94}
95
96impl Default for DelayBuffer {
97    fn default() -> Self {
98        Self::new()
99    }
100}
101
102impl DelayBuffer {
103    pub fn new() -> Self {
104        Self {
105            delay: Duration::from_millis(0),
106            buffer_size: usize::MAX,
107        }
108    }
109
110    pub fn set_delay(&mut self, delay: Duration) {
111        self.delay = delay;
112    }
113
114    pub fn set_buffer_size(&mut self, buffer_size: usize) {
115        self.buffer_size = buffer_size;
116    }
117
118    pub fn spawn(self, mut b: Plug) -> Plug {
119        #[allow(non_snake_case)]
120        let DURATION_MAX: Duration = Duration::from_secs(10000);
121        let (mut c, d) = wire();
122        async_global_executor::spawn(async move {
123            let mut b_tx_buffer_size = 0;
124            let mut b_tx_buffer = VecDeque::new();
125            let mut c_tx_buffer_size = 0;
126            let mut c_tx_buffer = VecDeque::new();
127            let mut idle = true;
128            let mut timer = Timer::after(DURATION_MAX);
129            loop {
130                futures::select! {
131                    packet = b.incoming().fuse() => {
132                        if let Some(packet) = packet {
133                            if c_tx_buffer_size + packet.len() < self.buffer_size {
134                                c_tx_buffer_size += packet.len();
135                                let time = Instant::now();
136                                c_tx_buffer.push_back((packet, time + self.delay));
137                                if idle {
138                                    timer.set_after(self.delay);
139                                    idle = false;
140                                }
141                            }
142                        } else {
143                            break;
144                        }
145                    }
146                    packet = c.incoming().fuse() => {
147                        if let Some(packet) = packet {
148                            if b_tx_buffer_size + packet.len() < self.buffer_size {
149                                b_tx_buffer_size += packet.len();
150                                let time = Instant::now();
151                                b_tx_buffer.push_back((packet, time + self.delay));
152                                if idle {
153                                    timer.set_after(self.delay);
154                                    idle = false;
155                                }
156                            }
157                        } else {
158                            break;
159                        }
160                    }
161                    now = FutureExt::fuse(&mut timer) => {
162                        let mut wtime = DURATION_MAX;
163                        while let Some((packet, time)) = b_tx_buffer.front() {
164                            if *time <= now {
165                                b_tx_buffer_size -= packet.len();
166                                b.unbounded_send(b_tx_buffer.pop_front().unwrap().0);
167                            } else {
168                                let bwtime = time.duration_since(now);
169                                if wtime > bwtime {
170                                    wtime = bwtime;
171                                }
172                                break;
173                            }
174                        }
175                        while let Some((packet, time)) = c_tx_buffer.front() {
176                            if *time <= now {
177                                c_tx_buffer_size -= packet.len();
178                                c.unbounded_send(c_tx_buffer.pop_front().unwrap().0);
179                            } else {
180                                let cwtime = time.duration_since(now);
181                                if wtime > cwtime {
182                                    wtime = cwtime;
183                                }
184                                break;
185                            }
186                        }
187                        timer.set_after(wtime);
188                        idle = wtime == DURATION_MAX
189                    }
190                }
191            }
192        })
193        .detach();
194        d
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Extra latency tolerated beyond a packet's configured delay. A 2 ms
    /// window fails spuriously whenever the scheduler or timer wakes up a
    /// little late (e.g. on a busy CI host).
    const SLACK: Duration = Duration::from_millis(20);

    /// Asserts `min <= elapsed < min + SLACK`, printing `elapsed` for diagnosis.
    fn assert_elapsed(elapsed: Duration, min: Duration) {
        println!("{:?}", elapsed);
        assert!(
            elapsed >= min,
            "arrived after {:?}, expected at least {:?}",
            elapsed,
            min
        );
        assert!(
            elapsed < min + SLACK,
            "arrived after {:?}, expected less than {:?}",
            elapsed,
            min + SLACK
        );
    }

    #[async_std::test]
    async fn test_delay() {
        let delay = Duration::from_millis(100);
        let (mut a, b) = wire();
        let mut w = DelayBuffer::new();
        w.set_delay(delay);
        let mut b = w.spawn(b);

        // Record each batch's send instant so sleep overshoot between the
        // batches cannot be mistaken for extra delay in the relay.
        let first_batch = Instant::now();
        a.unbounded_send(vec![1]);
        a.unbounded_send(vec![2]);
        async_std::task::sleep(Duration::from_millis(10)).await;
        let second_batch = Instant::now();
        a.unbounded_send(vec![3]);
        a.unbounded_send(vec![4]);

        // Each packet must arrive intact, in order, and roughly `delay` after
        // it was sent. The clock is sampled once per packet so the printed
        // value is the one asserted on.
        let expected = [
            (1u8, first_batch),
            (2, first_batch),
            (3, second_batch),
            (4, second_batch),
        ];
        for &(payload, sent) in &expected {
            assert_eq!(b.incoming().await, Some(vec![payload]));
            assert_elapsed(sent.elapsed(), delay);
        }
    }
}
231}