simulon/
state.rs

1use std::{
2    cell::RefCell,
3    collections::{BinaryHeap, HashMap, VecDeque},
4    sync::Arc,
5};
6
7use futures::executor::LocalPool;
8use fxhash::FxHashMap;
9
10use crate::{
11    api::{ConnectError, RemoteAddr},
12    future::{DeferredFuture, DeferredFutureWaker},
13    message::{Ignored, Message, MessageDetail},
14    report::{Metrics, NodeMetrics},
15    storage::TypedStorage,
16};
17
// Per-thread slot holding a raw pointer to the `NodeState` that is currently being executed
// on this thread. It starts out null; `hook_node` overwrites it and `with_node` reads it.
thread_local! {
    static NODES: RefCell<*mut NodeState> = RefCell::new(std::ptr::null_mut());
}
21
22/// Provide the node that is being executed on the current thread.
23pub fn hook_node(node: *mut NodeState) {
24    NODES.with(|cell| {
25        cell.replace(node);
26    })
27}
28
29/// Provide the [`NodeState`] of the node that is currently being executed to the passed
30/// closure.
31pub fn with_node<F, T>(f: F) -> T
32where
33    F: FnOnce(&mut NodeState) -> T,
34{
35    NODES.with(|cell| {
36        let ptr = *cell.borrow();
37
38        if ptr.is_null() {
39            panic!("Simulon API function used outside of executor.");
40        }
41
42        let mut_ref = unsafe { &mut *ptr };
43
44        f(mut_ref)
45    })
46}
47
/// The state of a single node.
///
/// Struct fields are dropped in declaration order, so `_clean_up` has to stay the last
/// field: its `Drop` clears the thread-local node pointer after every other field is gone.
pub struct NodeState {
    /// The global index of this node.
    pub node_id: usize,
    /// The total number of nodes in the simulation.
    pub count_nodes: usize,
    /// The current epoch time in the simulation.
    pub time: u128,
    /// The pool of futures that we need to execute.
    pub spawn_pool: LocalPool,
    /// The resource table that holds the mapping from a resource id to the resource.
    pub resources: HashMap<ResourceId, Resource>,
    /// Ports that we're listening on.
    pub listening: FxHashMap<u16, ListenerState>,
    /// The array of pending outgoing requests.
    pub outgoing: Vec<Message>,
    /// The messages which we have received and should execute when the time comes.
    pub received: BinaryHeap<Message>,
    /// The collected metrics during the entire execution.
    pub metrics: NodeMetrics,
    /// The current metrics being collected for the current frame.
    pub current_metrics: Metrics,
    /// The shared storage across all nodes.
    pub storage: Arc<TypedStorage>,
    /// The already emitted events, mapping each event key to the node time at which it was
    /// emitted.
    pub emitted: FxHashMap<String, u128>,
    /// The next resource id to hand out; incremented on every allocation.
    next_rid: usize,
    /// Drop guard that un-hooks this node from the thread-local slot. Must remain the last
    /// field so that it is dropped after everything else.
    _clean_up: WithCleanUpDrop,
}
77
78// When we're getting dropped some futures may still be pending and hold state which
79// may contain a connection or listener, and their Drop will use the `with_node` function
80// we will have an error if the current thread doesn't have a reference to the node that
81// is being executed. So here we set the current node as the node that is getting executed.
82impl Drop for NodeState {
83    fn drop(&mut self) {
84        hook_node(self as *mut NodeState);
85    }
86}
87
/// This is the last field of the [`NodeState`], which allows us to run a function
/// after the other fields are dropped. We use it to remove the pointer to this node
/// from the thread_local, so that no dangling pointer remains there.
struct WithCleanUpDrop;
92
93impl Drop for WithCleanUpDrop {
94    #[inline(always)]
95    fn drop(&mut self) {
96        hook_node(std::ptr::null_mut());
97    }
98}
99
// We're cool with moving `LocalPool` across different threads.
//
// NOTE(review): these impls override the compiler's auto-trait analysis for `NodeState`.
// Presumably the executor only ever touches a given node from one thread at a time; that
// invariant is not visible in this file — confirm it holds wherever nodes are scheduled.
unsafe impl Sync for NodeState {}
unsafe impl Send for NodeState {}
103
/// The state kept for one port that a node is listening on.
#[derive(Default)]
pub struct ListenerState {
    /// The current ongoing accept future. `None` when nobody is waiting in `accept`.
    pub accept: Option<DeferredFutureWaker<Option<AcceptResponse>>>,
    /// The queued connection requests. We enqueue these when there is
    /// not an active 'accept'. Each entry is the address of the connecting node together
    /// with the resource id it allocated for the connection.
    pub queue: VecDeque<(RemoteAddr, ResourceId)>,
}
112
/// The description of a connection that has just been accepted by a listener.
pub struct AcceptResponse {
    /// The resource id allocated on the accepting node for this connection.
    pub local_rid: ResourceId,
    /// The address of the node that requested the connection.
    pub remote: RemoteAddr,
    /// The resource id the connecting node uses for this connection.
    pub remote_rid: ResourceId,
}
118
/// The key of an entry in a node's resource table. Ids are handed out sequentially by the
/// node that owns the resource.
#[derive(Debug, Clone, Copy, Hash, PartialEq, PartialOrd, Ord, Eq)]
pub struct ResourceId(pub(crate) usize);
121
/// An entry of a node's resource table.
pub enum Resource {
    /// A connection that was requested by this node and has not been answered yet.
    PendingConnection {
        /// Resolves the future returned from `connect` once the remote answers.
        waker: DeferredFutureWaker<Result<ResourceId, ConnectError>>,
        /// Payloads that arrived before the connection got established.
        queue: VecDeque<Vec<u8>>,
    },
    /// A connection that can be used to receive data.
    EstablishedConnection {
        /// The waker of the in-flight `recv` call, if there is one.
        recv: Option<DeferredFutureWaker<Option<Vec<u8>>>>,
        /// Payloads that were received but not yet handed to a `recv` call.
        queue: VecDeque<Vec<u8>>,
    },
}
132
133impl NodeState {
134    /// Create the empty state of a node.
135    pub fn new(storage: Arc<TypedStorage>, count_nodes: usize, node_id: usize) -> Self {
136        Self {
137            node_id,
138            count_nodes,
139            time: 0,
140            spawn_pool: LocalPool::new(),
141            resources: HashMap::with_capacity(16),
142            listening: HashMap::default(),
143            outgoing: Vec::new(),
144            received: BinaryHeap::new(),
145            metrics: NodeMetrics::default(),
146            current_metrics: Metrics::default(),
147            storage,
148            emitted: FxHashMap::default(),
149            next_rid: 0,
150            _clean_up: WithCleanUpDrop,
151        }
152    }
153
154    /// Consumes and returns the next resource id.
155    #[inline(always)]
156    fn get_rid(&mut self) -> ResourceId {
157        let rid = self.next_rid;
158        self.next_rid += 1;
159        ResourceId(rid)
160    }
161
162    /// Returns the current time on the node.
163    pub fn now(&mut self) -> u128 {
164        let now = self.time;
165        self.time += 1;
166        now
167    }
168
169    /// Send a request to establish a connection with the given peer on the provided port number.
170    ///
171    /// Returns a future that will be resolved when the connection is established.
172    pub fn connect(
173        &mut self,
174        remote: RemoteAddr,
175        port: u16,
176    ) -> (ResourceId, DeferredFuture<Result<ResourceId, ConnectError>>) {
177        let rid = self.get_rid();
178        let future = DeferredFuture::new();
179        let resource = Resource::PendingConnection {
180            waker: future.waker(),
181            queue: VecDeque::default(),
182        };
183
184        let message = Message {
185            sender: RemoteAddr(self.node_id),
186            receiver: remote,
187            time: std::cmp::Reverse(self.now()),
188            detail: MessageDetail::Connect { port, rid },
189        };
190
191        self.resources.insert(rid, resource);
192        self.outgoing.push(message);
193        self.current_metrics.connections_requested += 1;
194
195        (rid, future)
196    }
197
198    pub fn listen(&mut self, port: u16) {
199        if self.listening.contains_key(&port) {
200            panic!("Port {port} is already in used.");
201        }
202
203        self.listening.insert(port, ListenerState::default());
204    }
205
206    pub fn close_listener(&mut self, port: u16) {
207        let mut state = self
208            .listening
209            .remove(&port)
210            .unwrap_or_else(|| panic!("Port {port} is not being listened on"));
211
212        if let Some(waker) = state.accept.take() {
213            waker.wake(None);
214        }
215
216        for (addr, rid) in state.queue {
217            self.refuse_connection(addr, rid);
218        }
219    }
220
221    pub fn accept(&mut self, port: u16) -> DeferredFuture<Option<AcceptResponse>> {
222        let listener_state = self.listening.get_mut(&port).expect("Illegal accept call.");
223
224        assert!(
225            listener_state.accept.is_none(),
226            "Another accept call is still on-going."
227        );
228
229        if let Some((addr, rid)) = listener_state.queue.pop_front() {
230            self.current_metrics.connections_accepted += 1;
231            let res = self.accepted(addr, rid);
232            DeferredFuture::resolved(Some(res))
233        } else {
234            let future = DeferredFuture::new();
235            listener_state.accept = Some(future.waker());
236            future
237        }
238    }
239
240    pub fn send(&mut self, remote: RemoteAddr, rid: ResourceId, data: Vec<u8>) {
241        self.current_metrics.msg_sent += 1;
242        self.current_metrics.bytes_sent += data.len() as u64;
243
244        let message = Message {
245            sender: RemoteAddr(self.node_id),
246            receiver: remote,
247            time: std::cmp::Reverse(self.now()),
248            detail: MessageDetail::Data {
249                receiver_rid: rid,
250                data,
251            },
252        };
253
254        self.outgoing.push(message);
255    }
256
257    pub fn recv(&mut self, rid: ResourceId) -> DeferredFuture<Option<Vec<u8>>> {
258        let resource = self
259            .resources
260            .get_mut(&rid)
261            .expect("recv: Resource not found.");
262
263        let (recv, queue) = if let Resource::EstablishedConnection { recv, queue } = resource {
264            (recv, queue)
265        } else {
266            panic!("Invalid resource type.");
267        };
268
269        if let Some(msg) = queue.pop_front() {
270            self.current_metrics.msg_processed += 1;
271            self.current_metrics.bytes_processed += msg.len() as u64;
272            DeferredFuture::resolved(Some(msg))
273        } else {
274            assert!(recv.is_none(), "Another recv is already in progress.");
275            let future = DeferredFuture::<Option<Vec<u8>>>::new();
276            *recv = Some(future.waker());
277            future
278        }
279    }
280
281    pub fn close_connection(&mut self, local_rid: ResourceId, addr: RemoteAddr, rid: ResourceId) {
282        self.close_local_connection(local_rid);
283
284        let message = Message {
285            sender: RemoteAddr(self.node_id),
286            receiver: addr,
287            time: std::cmp::Reverse(self.now()),
288            detail: MessageDetail::ConnectionClosed { receiver_rid: rid },
289        };
290
291        self.outgoing.push(message);
292    }
293
294    pub fn is_connection_open(&mut self, rid: ResourceId) -> bool {
295        self.resources.contains_key(&rid)
296    }
297
298    pub fn sleep(&mut self, duration: u128) -> DeferredFuture<()> {
299        let future = DeferredFuture::new();
300        let waker = future.waker();
301
302        let message = Message {
303            sender: RemoteAddr(self.node_id),
304            receiver: RemoteAddr(self.node_id),
305            time: std::cmp::Reverse(self.now() + duration),
306            detail: MessageDetail::WakeUp {
307                waker: Ignored(waker),
308            },
309        };
310        self.received.push(message);
311
312        future
313    }
314
315    pub fn emit(&mut self, key: String) {
316        assert!(
317            self.emitted.insert(key, self.time).is_none(),
318            "Node has already emitted the event."
319        );
320    }
321
322    fn close_local_connection(&mut self, rid: ResourceId) {
323        let resource = if let Some(resource) = self.resources.remove(&rid) {
324            resource
325        } else {
326            return;
327        };
328
329        self.current_metrics.connections_closed += 1;
330
331        let (mut recv, _queue) = if let Resource::EstablishedConnection { recv, queue } = resource {
332            (recv, queue)
333        } else {
334            panic!("Invalid resource type.");
335        };
336
337        if let Some(waker) = recv.take() {
338            waker.wake(None);
339        }
340    }
341
342    fn process_message(&mut self, our_rid: ResourceId, data: Vec<u8>) {
343        let resource = self
344            .resources
345            .get_mut(&our_rid)
346            .expect("process_message: Resource not found.");
347
348        let (recv, queue) = match resource {
349            Resource::EstablishedConnection { recv, queue } => (recv, queue),
350            Resource::PendingConnection { queue, .. } => {
351                queue.push_back(data);
352                return;
353            },
354        };
355
356        if let Some(waker) = recv.take() {
357            self.current_metrics.msg_processed += 1;
358            self.current_metrics.bytes_processed += data.len() as u64;
359            waker.wake(Some(data));
360        } else {
361            queue.push_back(data);
362        }
363    }
364
365    fn accepted(&mut self, addr: RemoteAddr, remote_rid: ResourceId) -> AcceptResponse {
366        let rid = self.get_rid();
367
368        let resource = Resource::EstablishedConnection {
369            recv: None,
370            queue: VecDeque::new(),
371        };
372
373        let message = Message {
374            sender: RemoteAddr(self.node_id),
375            receiver: addr,
376            time: std::cmp::Reverse(self.now()),
377            detail: MessageDetail::ConnectionAccepted {
378                sender_rid: rid,
379                receiver_rid: remote_rid,
380            },
381        };
382
383        self.resources.insert(rid, resource);
384        self.outgoing.push(message);
385
386        AcceptResponse {
387            remote_rid,
388            local_rid: rid,
389            remote: addr,
390        }
391    }
392
393    fn refuse_connection(&mut self, addr: RemoteAddr, remote_rid: ResourceId) {
394        let message = Message {
395            sender: RemoteAddr(self.node_id),
396            receiver: addr,
397            time: std::cmp::Reverse(self.now()),
398            detail: MessageDetail::ConnectionRefused {
399                receiver_rid: remote_rid,
400            },
401        };
402
403        self.outgoing.push(message);
404    }
405
406    fn maybe_accept_new_connection(&mut self, port: u16, addr: RemoteAddr, rid: ResourceId) {
407        let listener_state = if let Some(s) = self.listening.get_mut(&port) {
408            s
409        } else {
410            self.current_metrics.connections_refused += 1;
411            return self.refuse_connection(addr, rid);
412        };
413
414        if let Some(waker) = listener_state.accept.take() {
415            self.current_metrics.connections_accepted += 1;
416            let res = self.accepted(addr, rid);
417            waker.wake(Some(res));
418        } else {
419            listener_state.queue.push_back((addr, rid));
420        }
421    }
422
423    fn resolve_connection(
424        &mut self,
425        our_rid: ResourceId,
426        result: Result<ResourceId, ConnectError>,
427    ) {
428        let resource = self.resources.remove(&our_rid).unwrap();
429
430        let queue = if let Resource::PendingConnection { waker, queue } = resource {
431            waker.wake(result);
432            queue
433        } else {
434            VecDeque::new()
435        };
436
437        if result.is_ok() {
438            self.resources.insert(
439                our_rid,
440                Resource::EstablishedConnection { recv: None, queue },
441            );
442        }
443    }
444
445    pub fn is_stalled(&self) -> bool {
446        !matches!(self.received.peek(), Some(msg) if msg.time.0 <= self.time)
447    }
448
449    pub fn run_until_stalled(&mut self) {
450        while !self.is_stalled() {
451            let msg = self.received.pop().unwrap();
452
453            // eprintln!("\t>current {}: {:?}", self.node_id, msg);
454
455            match msg.detail {
456                MessageDetail::Connect { port, rid } => {
457                    self.maybe_accept_new_connection(port, msg.sender, rid);
458                },
459                MessageDetail::ConnectionAccepted {
460                    sender_rid,
461                    receiver_rid,
462                } => {
463                    self.resolve_connection(receiver_rid, Ok(sender_rid));
464                },
465                MessageDetail::ConnectionRefused { receiver_rid } => {
466                    self.current_metrics.connections_failed += 1;
467                    self.resolve_connection(receiver_rid, Err(ConnectError::RemoteIsDown));
468                },
469                MessageDetail::ConnectionClosed { receiver_rid: rid } => {
470                    self.close_local_connection(rid);
471                },
472                MessageDetail::Data { receiver_rid, data } => {
473                    self.current_metrics.msg_received += 1;
474                    self.current_metrics.bytes_received += data.len() as u64;
475                    self.process_message(receiver_rid, data);
476                },
477                MessageDetail::WakeUp { waker } => {
478                    waker.wake(());
479                },
480            }
481        }
482
483        self.spawn_pool.run_until_stalled();
484    }
485}