potatonet_bus/
lib.rs

1#![recursion_limit = "512"]
2
3mod requests;
4mod subscribes;
5
6#[macro_use]
7extern crate log;
8
9use crate::requests::Requests;
10use crate::subscribes::Subscribes;
11use anyhow::Result;
12use async_std::net::{TcpListener, TcpStream, ToSocketAddrs};
13use async_std::stream;
14use async_std::task;
15use futures::channel::mpsc::{channel, Sender};
16use futures::lock::Mutex;
17use futures::prelude::*;
18use futures::select;
19use potatonet_common::{bus_message, LocalServiceId, NodeId, ServiceId};
20use slab::Slab;
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24
25/// 节点
26struct Node {
27    /// 节点提供的服务
28    services: HashMap<LocalServiceId, String>,
29
30    /// 上次心跳时间
31    hb: Instant,
32
33    /// 主动断开节点连接通知通道
34    tx_close: Sender<()>,
35
36    /// 发送数据通道
37    tx: Sender<bus_message::Message>,
38}
39
40/// 消息总线数据
41#[derive(Default)]
42struct Bus {
43    /// 节点集合
44    nodes: Slab<Node>,
45
46    /// 按服务名索引的服务id
47    services: HashMap<String, Vec<ServiceId>>,
48
49    /// 未完成的请求
50    /// 如果5秒内未收到节点发来的响应,则从该表删除
51    pending_requests: Requests,
52
53    /// 订阅信息
54    subscribes: Subscribes,
55}
56
57impl Bus {
58    fn find_service(&self, name: &str) -> Option<ServiceId> {
59        match self.services.get(name) {
60            Some(nodes) if !nodes.is_empty() => {
61                nodes.get(rand::random::<usize>() % nodes.len()).copied()
62            }
63            _ => None,
64        }
65    }
66
67    fn create_node(&mut self, tx: Sender<bus_message::Message>, tx_close: Sender<()>) -> NodeId {
68        let id = self.nodes.insert(Node {
69            services: Default::default(),
70            hb: Instant::now(),
71            tx_close,
72            tx,
73        });
74        NodeId(id as u32)
75    }
76}
77
78pub async fn run<A: ToSocketAddrs>(addr: A) -> Result<()> {
79    let bus: Arc<Mutex<Bus>> = Default::default();
80    let listener = TcpListener::bind(addr).await?;
81
82    let mut incoming = listener.incoming();
83    while let Some(stream) = incoming.next().await {
84        if let Ok(stream) = stream {
85            task::spawn(client_process(bus.clone(), stream));
86        }
87    }
88
89    Ok(())
90}
91
92async fn process_incoming_msg(bus: Arc<Mutex<Bus>>, node_id: NodeId, msg: bus_message::Message) {
93    match msg {
94        // 退出
95        bus_message::Message::Bye => {
96            trace!("[{}/MSG:BYE]", node_id);
97        }
98
99        // ping消息
100        bus_message::Message::Ping => {
101            trace!("[{}/MSG:PING]", node_id);
102            if let Some(node) = bus.lock().await.nodes.get_mut(node_id.0 as usize) {
103                node.hb = Instant::now();
104            }
105        }
106
107        // 注册服务
108        bus_message::Message::RegisterService { name, id } => {
109            trace!("[{}/MSG:REGISTER_SERVICE] name={} id={}", node_id, name, id);
110            let mut bus = bus.lock().await;
111            if let Some(node) = bus.nodes.get_mut(node_id.0 as usize) {
112                let service_id = id.to_global(node_id);
113                node.services.insert(id, name.clone());
114                bus.services
115                    .entry(name)
116                    .and_modify(|ids| ids.push(service_id))
117                    .or_insert_with(|| vec![service_id]);
118            }
119        }
120
121        // 注销服务
122        bus_message::Message::UnregisterService { id } => {
123            trace!("[{}/MSG:UNREGISTER_SERVICE] id={}", node_id, id);
124            let mut bus = bus.lock().await;
125            let service_id = id.to_global(node_id);
126            for (_, ids) in &mut bus.services {
127                if let Some(pos) = ids.iter().position(|x| *x == service_id) {
128                    ids.remove(pos);
129                    break;
130                }
131            }
132            if let Some(node) = bus.nodes.get_mut(node_id.0 as usize) {
133                node.services.remove(&id);
134            }
135        }
136
137        // 请求
138        bus_message::Message::Req {
139            seq,
140            from,
141            to_service,
142            method,
143            data,
144        } => {
145            trace!(
146                "[{}/MSG:REQUEST] seq={} from={} to_service={}, method={}",
147                node_id,
148                seq,
149                from,
150                to_service,
151                method
152            );
153            let from = from.to_global(node_id);
154            let mut bus_inner = bus.lock().await;
155            let to = match bus_inner.find_service(&to_service) {
156                Some(to) => to,
157                None => {
158                    // 服务不存在
159                    let err_msg = format!("service '{}' not exists", to_service);
160                    if let Some(node) = bus_inner.nodes.get_mut(node_id.0 as usize) {
161                        if let Err(_) = node.tx.try_send(bus_message::Message::Rep {
162                            seq,
163                            result: Err(err_msg),
164                        }) {
165                            // 数据发送失败,断开连接
166                            node.tx_close.try_send(()).ok();
167                        }
168                    }
169                    return;
170                }
171            };
172            let new_seq = bus_inner.pending_requests.add(seq, node_id);
173            if let Some(to_node) = bus_inner.nodes.get_mut(to.node_id.0 as usize) {
174                if let Err(_) = to_node.tx.try_send(bus_message::Message::XReq {
175                    from,
176                    to: to.local_service_id,
177                    seq: new_seq as u32,
178                    method,
179                    data,
180                }) {
181                    // 数据发送失败,断开连接
182                    to_node.tx_close.try_send(()).ok();
183                }
184            }
185
186            // 5秒未收到响应则删除
187            task::spawn({
188                let bus = bus.clone();
189                async move {
190                    task::sleep(Duration::from_secs(5)).await;
191                    let mut bus = bus.lock().await;
192                    bus.pending_requests.remove(new_seq);
193                }
194            });
195        }
196
197        // 响应
198        bus_message::Message::Rep { seq, result } => {
199            trace!("[{}/MSG:RESPONSE] seq={}", node_id, seq);
200            let mut bus = bus.lock().await;
201            if let Some((origin_seq, to_node_id)) = bus.pending_requests.remove(seq) {
202                if let Some(node) = bus.nodes.get_mut(to_node_id.0 as usize) {
203                    if let Err(_) = node.tx.try_send(bus_message::Message::Rep {
204                        seq: origin_seq,
205                        result,
206                    }) {
207                        // 数据发送失败,断开连接
208                        node.tx_close.try_send(()).ok();
209                    }
210                }
211            };
212        }
213
214        // 发送通知
215        bus_message::Message::Notify {
216            from,
217            to_service,
218            method,
219            data,
220        } => {
221            trace!(
222                "[{}/MSG:SEND_NOTIFY] from={} to_service={} method={}",
223                node_id,
224                from,
225                to_service,
226                method
227            );
228
229            // 通知其它节点的指定服务
230            let mut bus = bus.lock().await;
231            let bus = &mut *bus;
232
233            if let Some(services) = bus.services.get(&to_service) {
234                for service_id in services {
235                    if node_id == service_id.node_id {
236                        // 不通知来源节点
237                        continue;
238                    }
239
240                    let to_node = bus.nodes.get_mut(service_id.node_id.0 as usize).unwrap();
241                    if let Err(_) = to_node.tx.try_send(bus_message::Message::XNotify {
242                        from: from.to_global(node_id),
243                        to_service: to_service.clone(),
244                        method,
245                        data: data.clone(),
246                    }) {
247                        // 数据发送失败,断开连接
248                        to_node.tx_close.try_send(()).ok();
249                    }
250                }
251            }
252        }
253
254        // 给指定服务发送通知
255        bus_message::Message::NotifyTo {
256            from,
257            to,
258            method,
259            data,
260        } => {
261            trace!(
262                "[{}/MSG:SEND_NOTIFY_TO] from={} to={} method={}",
263                node_id,
264                from,
265                to,
266                method
267            );
268
269            // 通知其它节点的指定服务
270            let mut bus = bus.lock().await;
271            if let Some(node) = bus.nodes.get_mut(to.node_id.0 as usize) {
272                if let Err(_) = node.tx.try_send(bus_message::Message::XNotifyTo {
273                    from: from.to_global(node_id),
274                    to: to.local_service_id,
275                    method: method,
276                    data: data.clone(),
277                }) {
278                    // 数据发送失败,断开连接
279                    node.tx_close.try_send(()).ok();
280                }
281            }
282        }
283
284        // 订阅
285        bus_message::Message::Subscribe { topic } => {
286            trace!("[{}/MSG:SUBSCRIBE] topic={}", node_id, topic);
287            let mut bus = bus.lock().await;
288            bus.subscribes.subscribe(topic, node_id);
289        }
290
291        // 取消订阅
292        bus_message::Message::Unsubscribe { topic } => {
293            trace!("[{}/MSG:UNSUBSCRIBE] topic={}", node_id, topic);
294            let mut bus = bus.lock().await;
295            bus.subscribes.unsubscribe(topic, node_id);
296        }
297
298        // 发布消息
299        bus_message::Message::Publish { topic, data } => {
300            trace!("[{}/MSG:PUBLISH] topic={}", node_id, topic);
301            let mut bus = bus.lock().await;
302            let bus = &mut *bus;
303            for to_node_id in bus.subscribes.query(&topic) {
304                if let Some(to_node) = bus.nodes.get_mut(to_node_id.0 as usize) {
305                    if let Err(_) = to_node.tx.try_send(bus_message::Message::XPublish {
306                        topic: topic.clone(),
307                        data: data.clone(),
308                    }) {
309                        // 数据发送失败,断开连接
310                        to_node.tx_close.try_send(()).ok();
311                    }
312                }
313            }
314        }
315
316        _ => {}
317    }
318}
319
320/// 客户端消息处理
321async fn client_process(bus: Arc<Mutex<Bus>>, stream: TcpStream) {
322    let stream = Arc::new(stream);
323    let (tx_close, mut rx_close) = channel(1);
324    let (tx_incoming_msg, mut rx_incoming_msg) = channel(64);
325    let (mut tx_outgoing_msg, rx_outgoing_msg) = channel(64);
326    let node_id = bus
327        .lock()
328        .await
329        .create_node(tx_outgoing_msg.clone(), tx_close);
330
331    // 接收消息任务
332    // 当心跳超时后,通过abort_handle来关闭消息接收任务
333    let (reader_task, abort_reader) =
334        future::abortable(bus_message::read_messages(stream.clone(), tx_incoming_msg));
335    let reader_handle = task::spawn(reader_task);
336
337    // 写消息任务
338    let (writer_task, abort_writer) =
339        future::abortable(bus_message::write_messages(stream.clone(), rx_outgoing_msg));
340    let writer_handle = task::spawn(writer_task);
341    trace!("[{}/CONNECTED]", node_id);
342
343    // 发送欢迎消息
344    tx_outgoing_msg
345        .try_send(bus_message::Message::Hello(node_id))
346        .ok();
347    drop(tx_outgoing_msg);
348
349    // 心跳检测定时器
350    let mut check_hb = stream::interval(Duration::from_secs(1)).fuse();
351
352    loop {
353        select! {
354            _ = rx_close.next() => {
355                // 主动断开连接
356                break;
357            }
358            _ = check_hb.next() => {
359                if let Some(node) = bus.lock().await.nodes.get(node_id.0 as usize) {
360                    if node.hb.elapsed() > Duration::from_secs(30) {
361                        // 心跳超时
362                        trace!("[{}/MSG:HEARTBEAT_TIMEOUT]", node_id);
363                        break;
364                    }
365                }
366            }
367            msg = rx_incoming_msg.next() => {
368                if let Some(msg) = msg {
369                    let mut exit = false;
370                    if let bus_message::Message::Bye = &msg {
371                        exit = true;
372                    }
373                    process_incoming_msg(bus.clone(), node_id, msg).await;
374                    if exit {
375                        // 客户端退出
376                        break;
377                    }
378                } else {
379                    // 连接已断开
380                    trace!("client connection close. node_id={}", node_id);
381                    break;
382                }
383            }
384        }
385    }
386
387    // 节点下线
388    let mut bus = bus.lock().await;
389    bus.subscribes.remove_node(node_id);
390    for (_, ids) in &mut bus.services {
391        ids.retain(|id| id.node_id != node_id);
392    }
393    bus.nodes.remove(node_id.0 as usize);
394
395    // 等待读写任务关闭
396    abort_reader.abort();
397    abort_writer.abort();
398    reader_handle.await.ok();
399    writer_handle.await.ok();
400
401    trace!("[{}/DISCONNECTED]", node_id);
402}