broadcast/
broadcast.rs

/// ```bash
/// $ cargo build --examples
/// $ RUST_LOG=debug maelstrom test -w broadcast --bin ./target/debug/examples/broadcast --node-count 2 --time-limit 20 --rate 10 --log-stderr
/// ```
5use async_trait::async_trait;
6use log::info;
7use maelstrom::protocol::Message;
8use maelstrom::{done, Node, Result, Runtime};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::sync::{Arc, Mutex};
12
13pub(crate) fn main() -> Result<()> {
14    Runtime::init(try_main())
15}
16
17async fn try_main() -> Result<()> {
18    let handler = Arc::new(Handler::default());
19    Runtime::new().with_handler(handler).run().await
20}
21
22#[derive(Clone, Default)]
23struct Handler {
24    inner: Arc<Mutex<Inner>>,
25}
26
/// Mutable state owned by a [`Handler`]; both fields start out empty.
#[derive(Default, Clone)]
struct Inner {
    /// Values accepted through `Handler::try_add` and returned by `Handler::snapshot`.
    s: HashSet<u64>,
    /// Neighbour node ids taken from the most recent `topology` message.
    t: Vec<String>,
}
32
33#[async_trait]
34impl Node for Handler {
35    async fn process(&self, runtime: Runtime, req: Message) -> Result<()> {
36        let msg: Result<Request> = req.body.as_obj();
37        match msg {
38            Ok(Request::Read {}) => {
39                let data = self.snapshot();
40                let msg = Request::ReadOk { messages: data };
41                return runtime.reply(req, msg).await;
42            }
43            Ok(Request::Broadcast { message: element }) => {
44                if self.try_add(element) {
45                    info!("messages now {}", element);
46                    for node in runtime.neighbours() {
47                        runtime.call_async(node, Request::Broadcast { message: element });
48                    }
49                }
50
51                return runtime.reply_ok(req).await;
52            }
53            Ok(Request::Topology { topology }) => {
54                let neighbours = topology.get(runtime.node_id()).unwrap();
55                self.inner.lock().unwrap().t = neighbours.clone();
56                info!("My neighbors are {:?}", neighbours);
57                return runtime.reply_ok(req).await;
58            }
59            _ => done(runtime, req),
60        }
61    }
62}
63
64impl Handler {
65    fn snapshot(&self) -> Vec<u64> {
66        self.inner.lock().unwrap().s.iter().copied().collect()
67    }
68
69    fn try_add(&self, val: u64) -> bool {
70        let mut g = self.inner.lock().unwrap();
71        if !g.s.contains(&val) {
72            g.s.insert(val);
73            return true;
74        }
75        false
76    }
77}
78
79#[derive(Serialize, Deserialize)]
80#[serde(rename_all = "snake_case", tag = "type")]
81enum Request {
82    Init {},
83    Read {},
84    ReadOk {
85        messages: Vec<u64>,
86    },
87    Broadcast {
88        message: u64,
89    },
90    Topology {
91        topology: HashMap<String, Vec<String>>,
92    },
93}