potatonet_common/
bus_message.rs

1use crate::{LocalServiceId, NodeId, ServiceId};
2use anyhow::Result;
3use async_std::net::TcpStream;
4use bytes::Bytes;
5use futures::channel::mpsc::{Receiver, Sender};
6use futures::prelude::*;
7use std::io::Cursor;
8use std::sync::Arc;
9
/// Upper bound, in bytes, on the payload of a single framed message (1 MiB).
///
/// Frames that announce a larger payload are rejected when read.
const MAX_DATA_SIZE: usize = 1 << 20;
11
/// A single protocol message exchanged between bus clients and the server.
///
/// Messages are framed on the wire as a 4-byte little-endian length followed by
/// the MessagePack encoding of this enum.
///
/// NOTE(review): with rmp_serde's compact encoding, variant order and field order
/// may be part of the wire format — confirm before reordering or renaming anything.
#[derive(Serialize, Deserialize, Debug)]
pub enum Message {
    /// Ping sent by the client.
    Ping,

    /// Sent by the client when it disconnects.
    Bye,

    /// Welcome message sent by the server, carrying a `NodeId`
    /// (presumably the id assigned to this client node — confirm).
    Hello(NodeId),

    /// Register a service under `name` with the client-local id `id`.
    RegisterService { name: String, id: LocalServiceId },

    /// Unregister the service with the client-local id `id`.
    UnregisterService { id: LocalServiceId },

    /// Request sent by the client to the service named `to_service`.
    Req {
        // presumably matched against `Rep::seq` to pair responses with requests — confirm
        seq: u32,
        from: LocalServiceId,
        to_service: String,
        method: u32,
        data: Bytes,
    },

    /// Request sent by the server to the client-local service `to`.
    XReq {
        from: ServiceId,
        to: LocalServiceId,
        seq: u32,
        method: u32,
        data: Bytes,
    },

    /// Response sent by the server; `result` is either the reply payload or an
    /// error string.
    Rep {
        seq: u32,
        result: Result<Bytes, String>,
    },

    /// Notification sent by the client to the service named `to_service`.
    Notify {
        from: LocalServiceId,
        to_service: String,
        method: u32,
        data: Bytes,
    },

    /// Notification sent by the client to one specific service instance.
    NotifyTo {
        from: LocalServiceId,
        to: ServiceId,
        method: u32,
        data: Bytes,
    },

    /// Notification sent by the server for the service named `to_service`.
    XNotify {
        from: ServiceId,
        to_service: String,
        method: u32,
        data: Bytes,
    },

    /// Notification sent by the server to one specific client-local service.
    XNotifyTo {
        from: ServiceId,
        to: LocalServiceId,
        method: u32,
        data: Bytes,
    },

    /// Client request to subscribe to `topic`.
    Subscribe { topic: String },

    /// Client request to unsubscribe from `topic`.
    Unsubscribe { topic: String },

    /// Message published by the client on `topic`.
    Publish { topic: String, data: Bytes },

    /// Message published by the server on `topic`.
    XPublish { topic: String, data: Bytes },
}
97
98async fn read_message<R: AsyncRead + Unpin>(mut r: R, buf: &mut Vec<u8>) -> Result<Message> {
99    let mut len = [0u8; 4];
100    r.read_exact(&mut len).await?;
101    let data_size = u32::from_le_bytes(len) as usize;
102    if data_size > MAX_DATA_SIZE {
103        bail!("data length exceeding the limit");
104    }
105    buf.resize(data_size, 0);
106    r.read_exact(buf).await?;
107    let msg: Message = rmp_serde::from_read(Cursor::new(&buf))?;
108    Ok(msg)
109}
110
111async fn write_message<W: AsyncWrite + Unpin>(
112    mut w: W,
113    msg: &Message,
114    buf: &mut Vec<u8>,
115) -> Result<()> {
116    buf.clear();
117    rmp_serde::encode::write(buf, &msg)?;
118    w.write(&(buf.len() as u32).to_le_bytes()).await?;
119    w.write(&buf).await?;
120    Ok(())
121}
122
123pub async fn read_one_message<R: AsyncRead + Unpin>(r: R) -> Result<Message> {
124    let mut buf = Vec::new();
125    read_message(r, &mut buf).await
126}
127
128pub async fn read_messages(stream: Arc<TcpStream>, mut tx: Sender<Message>) {
129    let mut buf = Vec::with_capacity(1024);
130    while let Ok(msg) = read_message(&*stream, &mut buf).await {
131        if let Err(_) = tx.send(msg).await {
132            // 连接已断开
133            break;
134        }
135    }
136}
137
138pub async fn write_messages(stream: Arc<TcpStream>, mut rx: Receiver<Message>) {
139    let mut buf = Vec::with_capacity(1024);
140
141    while let Some(msg) = rx.next().await {
142        if let Err(_) = write_message(&*stream, &msg, &mut buf).await {
143            // 连接已断开
144            break;
145        }
146    }
147}