maelstrom/
runtime.rs

1#![allow(dead_code)]
2
3use crate::error::Error;
4use crate::protocol::{ErrorMessageBody, InitMessageBody, Message};
5use crate::waitgroup::WaitGroup;
6use crate::{rpc_err_to_response, RPCResult};
7use async_trait::async_trait;
8use futures::FutureExt;
9use log::{debug, error, info, warn};
10use serde::Serialize;
11use serde_json::Value;
12use simple_error::bail;
13use std::collections::HashMap;
14use std::future::Future;
15use std::sync::atomic::AtomicU64;
16use std::sync::atomic::Ordering::{AcqRel, Release};
17use std::sync::Arc;
18use tokio::io::{stdin, stdout, AsyncBufReadExt, AsyncRead, AsyncWriteExt, BufReader, Stdout};
19use tokio::select;
20use tokio::sync::oneshot::Sender;
21use tokio::sync::{mpsc, Mutex, OnceCell};
22use tokio::task::JoinHandle;
23use tokio_context::context::Context;
24
/// Crate-wide result alias whose error side is a boxed, thread-safe
/// [`std::error::Error`] trait object.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
26
/// Handle to the node runtime.
///
/// All state lives behind a shared [`Arc`], so a clone of a `Runtime` refers
/// to the same underlying state.
pub struct Runtime {
    // The Arc<> is needed so that `run()`, which only has `&Self`, can pass
    // `runtime.clone()` further on to the handler.
    inter: Arc<Inter>,
}
32
/// Shared state behind every [`Runtime`] handle.
struct Inter {
    // Counter that `next_msg_id()` increments; created with the value 1.
    msg_id: AtomicU64,

    // A OnceCell seems to work better here than an RwLock, but what about cluster
    // membership changes? How should the API behave if Maelstrom sends a second
    // init message? The approach picked here: a node and its cluster membership
    // can be initialized only once; a membership change must start new nodes
    // and stop the old ones.
    membership: OnceCell<MembershipState>,

    // Message handler, set at most once via `with_handler()`.
    handler: OnceCell<Arc<dyn Node>>,

    // Senders for pending RPC calls; an incoming message whose `in_reply_to`
    // matches a key is delivered through the stored sender.
    rpc: Mutex<HashMap<u64, Sender<Message>>>,

    // Output stream; `send_raw()` holds the mutex while writing a line.
    out: Mutex<Stdout>,

    // Cloned into every task started by `spawn()`; `done()` waits on it.
    serving: WaitGroup,
}
50
/// `Node` is the trait that implements message handling.
#[async_trait]
pub trait Node: Sync + Send {
    /// Main handler function that processes incoming requests.
    ///
    /// Example:
    ///
    /// ```
    /// use async_trait::async_trait;
    /// use maelstrom::protocol::Message;
    /// use maelstrom::{Node, Result, Runtime, done};
    ///
    /// struct Handler {}
    ///
    /// #[async_trait]
    /// impl Node for Handler {
    ///     async fn process(&self, runtime: Runtime, req: Message) -> Result<()> {
    ///         if req.get_type() == "echo" {
    ///             let echo = req.body.clone().with_type("echo_ok");
    ///             return runtime.reply(req, echo).await;
    ///         }
    ///
    ///         // all other types are unsupported
    ///         done(runtime, req)
    ///     }
    /// }
    /// ```
    async fn process(&self, runtime: Runtime, request: Message) -> Result<()>;
}
80
81/// Returns a result with `NotSupported` error meaning that Node.process()
82/// is not aware of specific message type or Ok(()) for init.
83///
84/// Example:
85///
86/// ```
87/// use async_trait::async_trait;
88/// use maelstrom::{Node, Runtime, Result, done};
89/// use maelstrom::protocol::Message;
90///
91/// struct Handler {}
92///
93/// #[async_trait]
94/// impl Node for Handler {
95///     async fn process(&self, runtime: Runtime, req: Message) -> Result<()> {
96///         // would skip init and respond with Code == 10 for any other type.
97///         done(runtime, req)
98///     }
99/// }
100/// ```
101#[allow(clippy::needless_pass_by_value)]
102pub fn done(runtime: Runtime, message: Message) -> Result<()> {
103    if message.get_type() == "init" {
104        return Ok(());
105    }
106
107    let err = Error::NotSupported(message.body.typ.clone());
108    let msg: ErrorMessageBody = err.clone().into();
109
110    let runtime0 = runtime.clone();
111    runtime.spawn(async move {
112        let _ = runtime0.reply(message, msg).await;
113    });
114
115    Err(Box::new(err))
116}
117
118#[derive(Clone, Debug, Eq, PartialEq, Default)]
119pub struct MembershipState {
120    pub node_id: String,
121    pub nodes: Vec<String>,
122}
123
124impl Runtime {
125    pub fn init<F: Future>(future: F) -> F::Output {
126        let runtime = tokio::runtime::Runtime::new().unwrap();
127        let _guard = runtime.enter();
128
129        crate::log::builder().init();
130        debug!("inited");
131
132        runtime.block_on(future)
133    }
134}
135
136impl Runtime {
137    #[must_use]
138    pub fn new() -> Self {
139        Runtime::default()
140    }
141
142    #[must_use]
143    pub fn with_handler(self, handler: Arc<dyn Node + Send + Sync>) -> Self {
144        assert!(
145            self.inter.handler.set(handler).is_ok(),
146            "runtime handler is already initialized"
147        );
148        self
149    }
150
151    pub async fn send_raw(&self, msg: &str) -> Result<()> {
152        {
153            let mut out = self.inter.out.lock().await;
154            out.write_all(msg.as_bytes()).await?;
155            out.write_all(b"\n").await?;
156        }
157        info!("Sent {}", msg);
158        Ok(())
159    }
160
161    pub fn send_async<T>(&self, to: impl Into<String>, message: T) -> Result<()>
162    where
163        T: Serialize + Send,
164    {
165        let runtime = self.clone();
166        let msg = crate::protocol::message(self.node_id(), to, message)?;
167        let ans = serde_json::to_string(&msg)?;
168        self.spawn(async move {
169            if let Err(err) = runtime.send_raw(ans.as_str()).await {
170                error!("send error: {}", err);
171            }
172        });
173        Ok(())
174    }
175
176    pub async fn send<T>(&self, to: impl Into<String>, message: T) -> Result<()>
177    where
178        T: Serialize,
179    {
180        let msg = crate::protocol::message(self.node_id(), to, message)?;
181        let ans = serde_json::to_string(&msg)?;
182        self.send_raw(ans.as_str()).await
183    }
184
185    pub async fn send_back<T>(&self, req: Message, resp: T) -> Result<()>
186    where
187        T: Serialize,
188    {
189        self.send(req.src, resp).await
190    }
191
192    pub async fn reply<T>(&self, req: Message, resp: T) -> Result<()>
193    where
194        T: Serialize,
195    {
196        let mut msg = crate::protocol::message(self.node_id(), req.src, resp)?;
197        msg.body.in_reply_to = req.body.msg_id;
198
199        if !msg.body.extra.contains_key("type") && !req.body.typ.is_empty() {
200            let key = "type".to_string();
201            let value = Value::String(req.body.typ + "_ok");
202            msg.body.extra.insert(key, value);
203        }
204
205        let answer = serde_json::to_string(&msg)?;
206        self.send_raw(answer.as_str()).await
207    }
208
209    pub async fn reply_ok(&self, req: Message) -> Result<()> {
210        self.reply(req, Runtime::empty_response()).await
211    }
212
213    #[track_caller]
214    pub fn spawn<T>(&self, future: T) -> JoinHandle<T::Output>
215    where
216        T: Future + Send + 'static,
217        T::Output: Send + 'static,
218    {
219        let h = self.inter.serving.clone();
220        tokio::spawn(future.then(|x| async move {
221            drop(h);
222            x
223        }))
224    }
225
226    /// rpc() makes a remote call to another node via message passing interface.
227    /// Provided context may serve as a timeout limiter.
228    /// `RPCResult` is immediately canceled on drop.
229    ///
230    /// Example:
231    /// ```
232    /// use maelstrom::{Error, Result, Runtime};
233    /// use std::fmt::{Display, Formatter};
234    /// use serde::Serialize;
235    /// use serde::Deserialize;
236    /// use tokio_context::context::Context;
237    ///
238    /// pub struct Storage {
239    ///     typ: &'static str,
240    ///     runtime: Runtime,
241    /// }
242    ///
243    /// impl Storage {
244    ///     async fn get<T>(&self, ctx: Context, key: String) -> Result<T>
245    ///         where
246    ///             T: Deserialize<'static> + Send,
247    ///     {
248    ///         let req = Message::Read::<String> { key };
249    ///         let mut call = self.runtime.rpc(self.typ, req).await?;
250    ///         let msg = call.done_with(ctx).await?;
251    ///         let data = msg.body.as_obj::<Message<T>>()?;
252    ///         match data {
253    ///             Message::ReadOk { value } => Ok(value),
254    ///             _ => Err(Box::new(Error::Custom(
255    ///                 -1,
256    ///                 "kv: protocol violated".to_string(),
257    ///             ))),
258    ///         }
259    ///     }
260    /// }
261    ///
262    /// #[derive(Serialize, Deserialize)]
263    /// #[serde(rename_all = "snake_case", tag = "type")]
264    /// enum Message<T> {
265    ///     Read {
266    ///         key: String,
267    ///     },
268    ///     ReadOk {
269    ///         value: T,
270    ///     },
271    /// }
272    /// ```
273    pub fn rpc<T>(
274        &self,
275        to: impl Into<String>,
276        request: T,
277    ) -> impl Future<Output = Result<RPCResult>>
278    where
279        T: Serialize,
280    {
281        let msg = crate::protocol::message(self.node_id(), to, request);
282
283        let req_msg_id = self.next_msg_id();
284        let req_res: Result<String> = match msg {
285            Ok(mut t) => {
286                t.body.msg_id = req_msg_id;
287                match serde_json::to_string(&t) {
288                    Ok(s) => Ok(s),
289                    Err(e) => Err(Box::new(e)),
290                }
291            }
292            Err(e) => Err(e),
293        };
294
295        crate::rpc(self.clone(), req_msg_id, req_res)
296    }
297
298    /// call() is the same as `let _: Result<Message> = rpc().await?.done_with(ctx).await;`.
299    /// for examples see [`Runtime::rpc`] and [`RPCResult`].
300    ///
301    /// rpc() makes a remote call to another node via message passing interface.
302    /// Provided context may serve as a timeout limiter.
303    /// `RPCResult` is immediately canceled on drop.
304    pub async fn call<T>(&self, ctx: Context, to: impl Into<String>, request: T) -> Result<Message>
305    where
306        T: Serialize,
307    {
308        let mut call = self.rpc(to, request).await?;
309        call.done_with(ctx).await
310    }
311
312    /// `call_async`() is equivalent to `runtime.spawn(runtime.call(...))`.
313    /// see [`Runtime::call`], [`Runtime::rpc`].
314    pub fn call_async<T>(&self, to: impl Into<String>, request: T)
315    where
316        T: Serialize + 'static,
317    {
318        self.spawn(self.rpc(to.into(), request));
319    }
320
321    #[must_use]
322    pub fn node_id(&self) -> &str {
323        if let Some(v) = self.inter.membership.get() {
324            return v.node_id.as_str();
325        }
326        ""
327    }
328
329    #[must_use]
330    pub fn nodes(&self) -> &[String] {
331        if let Some(v) = self.inter.membership.get() {
332            return v.nodes.as_slice();
333        }
334        &[]
335    }
336
337    pub fn set_membership_state(&self, state: MembershipState) -> Result<()> {
338        debug!("new {:?}", state);
339
340        if let Err(e) = self.inter.membership.set(state) {
341            bail!("membership is inited: {}", e);
342        }
343
344        // new node = new message sequence
345        self.inter.msg_id.store(1, Release);
346
347        Ok(())
348    }
349
    /// Waits on the runtime's wait group, which `spawn()` clones into every
    /// task it starts.
    pub async fn done(&self) {
        self.inter.serving.wait().await;
    }
353
354    pub async fn run(&self) -> Result<()> {
355        self.run_with(BufReader::new(stdin())).await
356    }
357
358    pub async fn run_with<R>(&self, input: BufReader<R>) -> Result<()>
359    where
360        R: AsyncRead + Unpin,
361    {
362        let stdin = input;
363
364        let (tx_err, mut rx_err) = mpsc::channel::<Result<()>>(1);
365        let mut tx_out: Result<()> = Ok(());
366
367        let mut lines_from_stdin = stdin.lines();
368        loop {
369            select! {
370                Ok(read) = lines_from_stdin.next_line().fuse() => {
371                    match read {
372                        Some(line) =>{
373                            if line.trim().is_empty() {
374                                continue;
375                            }
376
377                            info!("Received {}", line);
378
379                            let tx_err0 = tx_err.clone();
380                            self.spawn(Self::process_request(self.clone(), line).then(|result| async move  {
381                                if let Err(e) = result {
382                                    if let Some(Error::NotSupported(t)) = e.downcast_ref::<Error>() {
383                                        warn!("message type not supported: {}", t);
384                                    } else {
385                                        error!("process_request error: {}", e);
386                                        let _ = tx_err0.send(Err(e)).await;
387                                    }
388                                }
389                            }));
390                        }
391                        None => break
392                    }
393                },
394                Some(e) = rx_err.recv() => { tx_out = e; break },
395                else => break
396            }
397        }
398
399        select! {
400            _ = self.done() => {},
401            Some(e) = rx_err.recv() => tx_out = e,
402        }
403
404        if tx_out.is_ok() {
405            if let Ok(err) = rx_err.try_recv() {
406                tx_out = err;
407            }
408        }
409
410        rx_err.close();
411
412        if let Err(e) = tx_out {
413            debug!("node error: {}", e);
414            return Err(e);
415        }
416
417        // TODO: print stats?
418        debug!("node done");
419
420        Ok(())
421    }
422
423    async fn process_request(runtime: Runtime, line: String) -> Result<()> {
424        let msg = match serde_json::from_str::<Message>(line.as_str()) {
425            Ok(v) => v,
426            Err(err) => return Err(Box::new(err)),
427        };
428
429        // rpc call
430        if msg.body.in_reply_to > 0 {
431            let mut guard = runtime.inter.rpc.lock().await;
432            if let Some(tx) = guard.remove(&msg.body.in_reply_to) {
433                // we don't need to hold mutex for doing long blocking tx.send.
434                drop(guard);
435                // at the moment we don't care of the send err because I expect
436                // it will fail only if the rx end is closed.
437                drop(tx.send(msg));
438            }
439            return Ok(());
440        }
441
442        let mut init_source: Option<(String, u64)> = None;
443        let is_init = msg.get_type() == "init";
444        if is_init {
445            init_source = Some((msg.src.clone(), msg.body.msg_id));
446            runtime.process_init(&msg)?;
447        }
448
449        if let Some(handler) = runtime.inter.handler.get() {
450            // I am not happy we are cloning a msg here, but let it go this time.
451            let res = handler.process(runtime.clone(), msg.clone()).await;
452            if res.is_err() {
453                // rpc error is user level error
454                if let Some(user_err) = rpc_err_to_response(&res) {
455                    runtime.reply(msg, user_err).await?;
456                } else {
457                    return res;
458                }
459            }
460        }
461
462        if is_init {
463            let (dst, msg_id) = init_source.unwrap();
464            let init_resp: Value = serde_json::from_str(
465                format!(r#"{{"in_reply_to":{msg_id},"type":"init_ok"}}"#).as_str(),
466            )?;
467            return runtime.send(dst, init_resp).await;
468        }
469
470        Ok(())
471    }
472
473    fn process_init(&self, message: &Message) -> Result<()> {
474        let raw = message.body.extra.clone();
475        let init = serde_json::from_value::<InitMessageBody>(Value::Object(raw))?;
476        self.set_membership_state(MembershipState {
477            node_id: init.node_id,
478            nodes: init.nodes,
479        })
480    }
481
    /// Atomically increments the message id counter and returns the value it
    /// held before the increment.
    #[inline]
    #[must_use]
    pub fn next_msg_id(&self) -> u64 {
        self.inter.msg_id.fetch_add(1, AcqRel)
    }
487
488    #[inline]
489    #[must_use]
490    pub fn empty_response() -> Value {
491        Value::Object(serde_json::Map::default())
492    }
493
    /// Registers `tx` under `id` in the pending RPC map and returns the sender
    /// previously stored under that id, if any.
    #[inline]
    pub(crate) async fn insert_rpc_sender(
        &self,
        id: u64,
        tx: Sender<Message>,
    ) -> Option<Sender<Message>> {
        self.inter.rpc.lock().await.insert(id, tx)
    }
502
    /// Removes and returns the sender stored under `id` in the pending RPC
    /// map, if any.
    #[inline]
    pub(crate) async fn release_rpc_sender(&self, id: u64) -> Option<Sender<Message>> {
        self.inter.rpc.lock().await.remove(&id)
    }
507
508    #[inline]
509    #[must_use]
510    pub fn is_client(&self, src: &String) -> bool {
511        !src.is_empty() && src.starts_with('c')
512    }
513
514    #[inline]
515    #[must_use]
516    pub fn is_from_cluster(&self, src: &String) -> bool {
517        // alternative implementation: self.nodes().contains(src)
518        !src.is_empty() && src.starts_with('n')
519    }
520
521    /// All nodes that are not this node.
522    #[inline]
523    pub fn neighbours(&self) -> impl Iterator<Item = &String> {
524        let n = self.node_id();
525        self.nodes()
526            .iter()
527            .filter(move |t: &&String| t.as_str() != n)
528    }
529}
530
531impl Default for Runtime {
532    fn default() -> Self {
533        Runtime {
534            inter: Arc::new(Inter {
535                msg_id: AtomicU64::new(1),
536                membership: OnceCell::new(),
537                handler: OnceCell::new(),
538                rpc: Mutex::default(),
539                out: Mutex::new(stdout()),
540                serving: WaitGroup::new(),
541            }),
542        }
543    }
544}
545
546impl Clone for Runtime {
547    fn clone(&self) -> Self {
548        Runtime {
549            inter: self.inter.clone(),
550        }
551    }
552}
553
/// A [`Node`] that accepts every message and does nothing with it.
#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
pub struct BlackHoleNode {}

#[async_trait]
impl Node for BlackHoleNode {
    /// Ignores the runtime and the request and always returns `Ok(())`.
    async fn process(&self, _: Runtime, _: Message) -> Result<()> {
        Ok(())
    }
}
563
// TODO: make err customizable
/// A [`Node`] whose `process` fails for every message.
#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
pub struct IOFailingNode {}

#[async_trait]
impl Node for IOFailingNode {
    /// Ignores the runtime and the request and always returns an error.
    async fn process(&self, _: Runtime, _: Message) -> Result<()> {
        bail!("IOFailingNode: process failed")
    }
}
574
575#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
576pub struct EchoNode {}
577
578#[async_trait]
579impl Node for EchoNode {
580    async fn process(&self, runtime: Runtime, req: Message) -> Result<()> {
581        let resp = Value::Object(serde_json::Map::default());
582        runtime.reply(req, resp).await
583    }
584}
585
#[cfg(test)]
mod test {
    use crate::{MembershipState, Result, Runtime};
    use tokio::io::BufReader;
    use tokio_util::sync::CancellationToken;

    impl MembershipState {
        /// Builds a membership state from borrowed strings.
        fn example(n: &str, s: &[&str]) -> Self {
            MembershipState {
                node_id: n.to_string(),
                nodes: s.iter().map(|x| x.to_string()).collect(),
            }
        }
    }

    /// The first membership state is kept; a second attempt to set it is rejected.
    #[test]
    fn membership() -> Result<()> {
        let tokio_runtime = tokio::runtime::Runtime::new()?;
        tokio_runtime.block_on(async move {
            let runtime = Runtime::new();
            let worker = runtime.clone();
            let first = MembershipState::example("n0", &["n0", "n1"]);
            let second = MembershipState::example("n1", &["n0", "n1"]);
            runtime.spawn(async move {
                worker.set_membership_state(first).unwrap();
                assert!(worker.set_membership_state(second).is_err());
            });
            runtime.done().await;
            assert_eq!(
                runtime.node_id(),
                "n0",
                "invalid node id, can't be anything else"
            );
        });
        Ok(())
    }

    /// A failing handler makes `run_with` return an error.
    #[tokio::test]
    async fn io_failure() {
        let handler = std::sync::Arc::new(crate::IOFailingNode::default());
        let runtime = Runtime::new().with_handler(handler);
        let cursor = std::io::Cursor::new(
            r#"
            
            {"src":"c0","dest":"n0","body":{"type":"echo","msg_id":1}}
            "#,
        );
        // The token is never cancelled, so this task never finishes and
        // `run_with` can only return because of the handler error.
        let token = CancellationToken::new();
        runtime.spawn(async move { token.cancelled().await });
        let outcome = runtime.run_with(BufReader::new(cursor)).await;
        assert!(outcome.is_err());
    }
}
641}