1#![allow(dead_code)]
2
3use crate::error::Error;
4use crate::protocol::{ErrorMessageBody, InitMessageBody, Message};
5use crate::waitgroup::WaitGroup;
6use crate::{rpc_err_to_response, RPCResult};
7use async_trait::async_trait;
8use futures::FutureExt;
9use log::{debug, error, info, warn};
10use serde::Serialize;
11use serde_json::Value;
12use simple_error::bail;
13use std::collections::HashMap;
14use std::future::Future;
15use std::sync::atomic::AtomicU64;
16use std::sync::atomic::Ordering::{AcqRel, Release};
17use std::sync::Arc;
18use tokio::io::{stdin, stdout, AsyncBufReadExt, AsyncRead, AsyncWriteExt, BufReader, Stdout};
19use tokio::select;
20use tokio::sync::oneshot::Sender;
21use tokio::sync::{mpsc, Mutex, OnceCell};
22use tokio::task::JoinHandle;
23use tokio_context::context::Context;
24
/// Crate-wide result type: the error side is any error boxed as a thread-safe trait object,
/// so `?` works on most error types and results can cross task/thread boundaries.
pub type Result<T> = core::result::Result<T, Box<dyn std::error::Error + Sync + Send>>;
26
/// Handle to the node runtime.
///
/// All state lives in a single reference-counted [`Inter`]; every clone of a `Runtime`
/// points at the same shared state.
pub struct Runtime {
    /// Shared runtime state (membership, handler, pending RPCs, output, task tracking).
    inter: Arc<Inter>,
}
32
/// Shared state behind every [`Runtime`] handle.
struct Inter {
    /// Counter for outgoing message ids. It starts at 1 and is reset to 1 when membership
    /// is set.
    msg_id: AtomicU64,

    /// This node's id and the cluster members, filled in once from the `init` message.
    membership: OnceCell<MembershipState>,

    /// User message handler; it can be installed at most once (`Runtime::with_handler`).
    handler: OnceCell<Arc<dyn Node>>,

    /// Pending outgoing RPCs keyed by request `msg_id`. An incoming message whose
    /// `in_reply_to` matches a key is forwarded to the stored sender.
    rpc: Mutex<HashMap<u64, Sender<Message>>>,

    /// Stdout behind an async mutex; a message and its trailing newline are written while
    /// the lock is held.
    out: Mutex<Stdout>,

    /// Wait group that tasks started via `Runtime::spawn` hold a clone of while running;
    /// `Runtime::done` waits on it.
    serving: WaitGroup,
}
50
/// Message handler plugged into a [`Runtime`] with `Runtime::with_handler`.
#[async_trait]
pub trait Node: Sync + Send {
    /// Handles one incoming request.
    ///
    /// Replies to this node's own RPCs (messages with a non-zero `in_reply_to`) are routed
    /// to the waiting caller and never reach this method. `init` messages do reach it, after
    /// the runtime has recorded membership.
    ///
    /// # Errors
    ///
    /// If `rpc_err_to_response` maps the returned error to a response, the runtime sends it
    /// back as the reply. Otherwise the error propagates to the run loop. There, errors
    /// other than `Error::NotSupported` stop the node; `NotSupported` is only logged.
    async fn process(&self, runtime: Runtime, request: Message) -> Result<()>;
}
80
81#[allow(clippy::needless_pass_by_value)]
102pub fn done(runtime: Runtime, message: Message) -> Result<()> {
103 if message.get_type() == "init" {
104 return Ok(());
105 }
106
107 let err = Error::NotSupported(message.body.typ.clone());
108 let msg: ErrorMessageBody = err.clone().into();
109
110 let runtime0 = runtime.clone();
111 runtime.spawn(async move {
112 let _ = runtime0.reply(message, msg).await;
113 });
114
115 Err(Box::new(err))
116}
117
/// Identity of this node and the cluster it belongs to, taken from the `init` message.
#[derive(Clone, Debug, Eq, PartialEq, Default)]
pub struct MembershipState {
    /// This node's id.
    pub node_id: String,
    /// Ids of the cluster's nodes. Presumably this includes this node's own id, since
    /// `Runtime::neighbours` filters it out. TODO confirm.
    pub nodes: Vec<String>,
}
123
124impl Runtime {
125 pub fn init<F: Future>(future: F) -> F::Output {
126 let runtime = tokio::runtime::Runtime::new().unwrap();
127 let _guard = runtime.enter();
128
129 crate::log::builder().init();
130 debug!("inited");
131
132 runtime.block_on(future)
133 }
134}
135
136impl Runtime {
137 #[must_use]
138 pub fn new() -> Self {
139 Runtime::default()
140 }
141
142 #[must_use]
143 pub fn with_handler(self, handler: Arc<dyn Node + Send + Sync>) -> Self {
144 assert!(
145 self.inter.handler.set(handler).is_ok(),
146 "runtime handler is already initialized"
147 );
148 self
149 }
150
151 pub async fn send_raw(&self, msg: &str) -> Result<()> {
152 {
153 let mut out = self.inter.out.lock().await;
154 out.write_all(msg.as_bytes()).await?;
155 out.write_all(b"\n").await?;
156 }
157 info!("Sent {}", msg);
158 Ok(())
159 }
160
161 pub fn send_async<T>(&self, to: impl Into<String>, message: T) -> Result<()>
162 where
163 T: Serialize + Send,
164 {
165 let runtime = self.clone();
166 let msg = crate::protocol::message(self.node_id(), to, message)?;
167 let ans = serde_json::to_string(&msg)?;
168 self.spawn(async move {
169 if let Err(err) = runtime.send_raw(ans.as_str()).await {
170 error!("send error: {}", err);
171 }
172 });
173 Ok(())
174 }
175
176 pub async fn send<T>(&self, to: impl Into<String>, message: T) -> Result<()>
177 where
178 T: Serialize,
179 {
180 let msg = crate::protocol::message(self.node_id(), to, message)?;
181 let ans = serde_json::to_string(&msg)?;
182 self.send_raw(ans.as_str()).await
183 }
184
185 pub async fn send_back<T>(&self, req: Message, resp: T) -> Result<()>
186 where
187 T: Serialize,
188 {
189 self.send(req.src, resp).await
190 }
191
192 pub async fn reply<T>(&self, req: Message, resp: T) -> Result<()>
193 where
194 T: Serialize,
195 {
196 let mut msg = crate::protocol::message(self.node_id(), req.src, resp)?;
197 msg.body.in_reply_to = req.body.msg_id;
198
199 if !msg.body.extra.contains_key("type") && !req.body.typ.is_empty() {
200 let key = "type".to_string();
201 let value = Value::String(req.body.typ + "_ok");
202 msg.body.extra.insert(key, value);
203 }
204
205 let answer = serde_json::to_string(&msg)?;
206 self.send_raw(answer.as_str()).await
207 }
208
209 pub async fn reply_ok(&self, req: Message) -> Result<()> {
210 self.reply(req, Runtime::empty_response()).await
211 }
212
213 #[track_caller]
214 pub fn spawn<T>(&self, future: T) -> JoinHandle<T::Output>
215 where
216 T: Future + Send + 'static,
217 T::Output: Send + 'static,
218 {
219 let h = self.inter.serving.clone();
220 tokio::spawn(future.then(|x| async move {
221 drop(h);
222 x
223 }))
224 }
225
226 pub fn rpc<T>(
274 &self,
275 to: impl Into<String>,
276 request: T,
277 ) -> impl Future<Output = Result<RPCResult>>
278 where
279 T: Serialize,
280 {
281 let msg = crate::protocol::message(self.node_id(), to, request);
282
283 let req_msg_id = self.next_msg_id();
284 let req_res: Result<String> = match msg {
285 Ok(mut t) => {
286 t.body.msg_id = req_msg_id;
287 match serde_json::to_string(&t) {
288 Ok(s) => Ok(s),
289 Err(e) => Err(Box::new(e)),
290 }
291 }
292 Err(e) => Err(e),
293 };
294
295 crate::rpc(self.clone(), req_msg_id, req_res)
296 }
297
298 pub async fn call<T>(&self, ctx: Context, to: impl Into<String>, request: T) -> Result<Message>
305 where
306 T: Serialize,
307 {
308 let mut call = self.rpc(to, request).await?;
309 call.done_with(ctx).await
310 }
311
312 pub fn call_async<T>(&self, to: impl Into<String>, request: T)
315 where
316 T: Serialize + 'static,
317 {
318 self.spawn(self.rpc(to.into(), request));
319 }
320
321 #[must_use]
322 pub fn node_id(&self) -> &str {
323 if let Some(v) = self.inter.membership.get() {
324 return v.node_id.as_str();
325 }
326 ""
327 }
328
329 #[must_use]
330 pub fn nodes(&self) -> &[String] {
331 if let Some(v) = self.inter.membership.get() {
332 return v.nodes.as_slice();
333 }
334 &[]
335 }
336
337 pub fn set_membership_state(&self, state: MembershipState) -> Result<()> {
338 debug!("new {:?}", state);
339
340 if let Err(e) = self.inter.membership.set(state) {
341 bail!("membership is inited: {}", e);
342 }
343
344 self.inter.msg_id.store(1, Release);
346
347 Ok(())
348 }
349
350 pub async fn done(&self) {
351 self.inter.serving.wait().await;
352 }
353
354 pub async fn run(&self) -> Result<()> {
355 self.run_with(BufReader::new(stdin())).await
356 }
357
358 pub async fn run_with<R>(&self, input: BufReader<R>) -> Result<()>
359 where
360 R: AsyncRead + Unpin,
361 {
362 let stdin = input;
363
364 let (tx_err, mut rx_err) = mpsc::channel::<Result<()>>(1);
365 let mut tx_out: Result<()> = Ok(());
366
367 let mut lines_from_stdin = stdin.lines();
368 loop {
369 select! {
370 Ok(read) = lines_from_stdin.next_line().fuse() => {
371 match read {
372 Some(line) =>{
373 if line.trim().is_empty() {
374 continue;
375 }
376
377 info!("Received {}", line);
378
379 let tx_err0 = tx_err.clone();
380 self.spawn(Self::process_request(self.clone(), line).then(|result| async move {
381 if let Err(e) = result {
382 if let Some(Error::NotSupported(t)) = e.downcast_ref::<Error>() {
383 warn!("message type not supported: {}", t);
384 } else {
385 error!("process_request error: {}", e);
386 let _ = tx_err0.send(Err(e)).await;
387 }
388 }
389 }));
390 }
391 None => break
392 }
393 },
394 Some(e) = rx_err.recv() => { tx_out = e; break },
395 else => break
396 }
397 }
398
399 select! {
400 _ = self.done() => {},
401 Some(e) = rx_err.recv() => tx_out = e,
402 }
403
404 if tx_out.is_ok() {
405 if let Ok(err) = rx_err.try_recv() {
406 tx_out = err;
407 }
408 }
409
410 rx_err.close();
411
412 if let Err(e) = tx_out {
413 debug!("node error: {}", e);
414 return Err(e);
415 }
416
417 debug!("node done");
419
420 Ok(())
421 }
422
423 async fn process_request(runtime: Runtime, line: String) -> Result<()> {
424 let msg = match serde_json::from_str::<Message>(line.as_str()) {
425 Ok(v) => v,
426 Err(err) => return Err(Box::new(err)),
427 };
428
429 if msg.body.in_reply_to > 0 {
431 let mut guard = runtime.inter.rpc.lock().await;
432 if let Some(tx) = guard.remove(&msg.body.in_reply_to) {
433 drop(guard);
435 drop(tx.send(msg));
438 }
439 return Ok(());
440 }
441
442 let mut init_source: Option<(String, u64)> = None;
443 let is_init = msg.get_type() == "init";
444 if is_init {
445 init_source = Some((msg.src.clone(), msg.body.msg_id));
446 runtime.process_init(&msg)?;
447 }
448
449 if let Some(handler) = runtime.inter.handler.get() {
450 let res = handler.process(runtime.clone(), msg.clone()).await;
452 if res.is_err() {
453 if let Some(user_err) = rpc_err_to_response(&res) {
455 runtime.reply(msg, user_err).await?;
456 } else {
457 return res;
458 }
459 }
460 }
461
462 if is_init {
463 let (dst, msg_id) = init_source.unwrap();
464 let init_resp: Value = serde_json::from_str(
465 format!(r#"{{"in_reply_to":{msg_id},"type":"init_ok"}}"#).as_str(),
466 )?;
467 return runtime.send(dst, init_resp).await;
468 }
469
470 Ok(())
471 }
472
473 fn process_init(&self, message: &Message) -> Result<()> {
474 let raw = message.body.extra.clone();
475 let init = serde_json::from_value::<InitMessageBody>(Value::Object(raw))?;
476 self.set_membership_state(MembershipState {
477 node_id: init.node_id,
478 nodes: init.nodes,
479 })
480 }
481
482 #[inline]
483 #[must_use]
484 pub fn next_msg_id(&self) -> u64 {
485 self.inter.msg_id.fetch_add(1, AcqRel)
486 }
487
488 #[inline]
489 #[must_use]
490 pub fn empty_response() -> Value {
491 Value::Object(serde_json::Map::default())
492 }
493
494 #[inline]
495 pub(crate) async fn insert_rpc_sender(
496 &self,
497 id: u64,
498 tx: Sender<Message>,
499 ) -> Option<Sender<Message>> {
500 self.inter.rpc.lock().await.insert(id, tx)
501 }
502
503 #[inline]
504 pub(crate) async fn release_rpc_sender(&self, id: u64) -> Option<Sender<Message>> {
505 self.inter.rpc.lock().await.remove(&id)
506 }
507
508 #[inline]
509 #[must_use]
510 pub fn is_client(&self, src: &String) -> bool {
511 !src.is_empty() && src.starts_with('c')
512 }
513
514 #[inline]
515 #[must_use]
516 pub fn is_from_cluster(&self, src: &String) -> bool {
517 !src.is_empty() && src.starts_with('n')
519 }
520
521 #[inline]
523 pub fn neighbours(&self) -> impl Iterator<Item = &String> {
524 let n = self.node_id();
525 self.nodes()
526 .iter()
527 .filter(move |t: &&String| t.as_str() != n)
528 }
529}
530
531impl Default for Runtime {
532 fn default() -> Self {
533 Runtime {
534 inter: Arc::new(Inter {
535 msg_id: AtomicU64::new(1),
536 membership: OnceCell::new(),
537 handler: OnceCell::new(),
538 rpc: Mutex::default(),
539 out: Mutex::new(stdout()),
540 serving: WaitGroup::new(),
541 }),
542 }
543 }
544}
545
546impl Clone for Runtime {
547 fn clone(&self) -> Self {
548 Runtime {
549 inter: self.inter.clone(),
550 }
551 }
552}
553
554#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
555pub struct BlackHoleNode {}
556
557#[async_trait]
558impl Node for BlackHoleNode {
559 async fn process(&self, _: Runtime, _: Message) -> Result<()> {
560 Ok(())
561 }
562}
563
564#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
566pub struct IOFailingNode {}
567
568#[async_trait]
569impl Node for IOFailingNode {
570 async fn process(&self, _: Runtime, _: Message) -> Result<()> {
571 bail!("IOFailingNode: process failed")
572 }
573}
574
575#[derive(Default, Copy, Clone, PartialEq, Eq, Debug)]
576pub struct EchoNode {}
577
578#[async_trait]
579impl Node for EchoNode {
580 async fn process(&self, runtime: Runtime, req: Message) -> Result<()> {
581 let resp = Value::Object(serde_json::Map::default());
582 runtime.reply(req, resp).await
583 }
584}
585
#[cfg(test)]
mod test {
    use crate::{MembershipState, Result, Runtime};
    use tokio::io::BufReader;
    use tokio_util::sync::CancellationToken;

    /// Membership can be set only once. The second attempt fails and the first state stays.
    #[test]
    fn membership() -> Result<()> {
        let tokio_runtime = tokio::runtime::Runtime::new()?;
        tokio_runtime.block_on(async move {
            let runtime = Runtime::new();
            let runtime0 = runtime.clone();
            let s1 = MembershipState::example("n0", &["n0", "n1"]);
            let s2 = MembershipState::example("n1", &["n0", "n1"]);
            // Await the handle: tokio swallows panics in spawned tasks, so without this a
            // failing assertion inside the task would never fail the test.
            runtime
                .spawn(async move {
                    runtime0.set_membership_state(s1).unwrap();
                    assert!(runtime0.set_membership_state(s2).is_err());
                })
                .await
                .expect("membership task panicked");
            runtime.done().await;
            assert_eq!(
                runtime.node_id(),
                "n0",
                "invalid node id, can't be anything else"
            );
        });
        Ok(())
    }

    impl MembershipState {
        /// Test helper: membership for node `n` in a cluster of `s`.
        fn example(n: &str, s: &[&str]) -> Self {
            MembershipState {
                node_id: n.to_string(),
                nodes: s.iter().map(|x| x.to_string()).collect(),
            }
        }
    }

    /// A failing handler makes `run_with` return an error.
    #[tokio::test]
    async fn io_failure() {
        let handler = std::sync::Arc::new(crate::IOFailingNode::default());
        let runtime = Runtime::new().with_handler(handler);
        let cursor = std::io::Cursor::new(
            r#"

            {"src":"c0","dest":"n0","body":{"type":"echo","msg_id":1}}
            "#,
        );
        // This task never finishes, so `done()` stays pending. `run_with` must therefore
        // return because of the handler error, not because all tasks completed.
        let token = CancellationToken::new();
        runtime.spawn(async move { token.cancelled().await });
        let run = runtime.run_with(BufReader::new(cursor));
        assert!(run.await.is_err());
    }
}
641}