1#![recursion_limit = "512"]
2
3mod requests;
4mod subscribes;
5
6#[macro_use]
7extern crate log;
8
9use crate::requests::Requests;
10use crate::subscribes::Subscribes;
11use anyhow::Result;
12use async_std::net::{TcpListener, TcpStream, ToSocketAddrs};
13use async_std::stream;
14use async_std::task;
15use futures::channel::mpsc::{channel, Sender};
16use futures::lock::Mutex;
17use futures::prelude::*;
18use futures::select;
19use potatonet_common::{bus_message, LocalServiceId, NodeId, ServiceId};
20use slab::Slab;
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::{Duration, Instant};
24
25struct Node {
27 services: HashMap<LocalServiceId, String>,
29
30 hb: Instant,
32
33 tx_close: Sender<()>,
35
36 tx: Sender<bus_message::Message>,
38}
39
40#[derive(Default)]
42struct Bus {
43 nodes: Slab<Node>,
45
46 services: HashMap<String, Vec<ServiceId>>,
48
49 pending_requests: Requests,
52
53 subscribes: Subscribes,
55}
56
57impl Bus {
58 fn find_service(&self, name: &str) -> Option<ServiceId> {
59 match self.services.get(name) {
60 Some(nodes) if !nodes.is_empty() => {
61 nodes.get(rand::random::<usize>() % nodes.len()).copied()
62 }
63 _ => None,
64 }
65 }
66
67 fn create_node(&mut self, tx: Sender<bus_message::Message>, tx_close: Sender<()>) -> NodeId {
68 let id = self.nodes.insert(Node {
69 services: Default::default(),
70 hb: Instant::now(),
71 tx_close,
72 tx,
73 });
74 NodeId(id as u32)
75 }
76}
77
78pub async fn run<A: ToSocketAddrs>(addr: A) -> Result<()> {
79 let bus: Arc<Mutex<Bus>> = Default::default();
80 let listener = TcpListener::bind(addr).await?;
81
82 let mut incoming = listener.incoming();
83 while let Some(stream) = incoming.next().await {
84 if let Ok(stream) = stream {
85 task::spawn(client_process(bus.clone(), stream));
86 }
87 }
88
89 Ok(())
90}
91
92async fn process_incoming_msg(bus: Arc<Mutex<Bus>>, node_id: NodeId, msg: bus_message::Message) {
93 match msg {
94 bus_message::Message::Bye => {
96 trace!("[{}/MSG:BYE]", node_id);
97 }
98
99 bus_message::Message::Ping => {
101 trace!("[{}/MSG:PING]", node_id);
102 if let Some(node) = bus.lock().await.nodes.get_mut(node_id.0 as usize) {
103 node.hb = Instant::now();
104 }
105 }
106
107 bus_message::Message::RegisterService { name, id } => {
109 trace!("[{}/MSG:REGISTER_SERVICE] name={} id={}", node_id, name, id);
110 let mut bus = bus.lock().await;
111 if let Some(node) = bus.nodes.get_mut(node_id.0 as usize) {
112 let service_id = id.to_global(node_id);
113 node.services.insert(id, name.clone());
114 bus.services
115 .entry(name)
116 .and_modify(|ids| ids.push(service_id))
117 .or_insert_with(|| vec![service_id]);
118 }
119 }
120
121 bus_message::Message::UnregisterService { id } => {
123 trace!("[{}/MSG:UNREGISTER_SERVICE] id={}", node_id, id);
124 let mut bus = bus.lock().await;
125 let service_id = id.to_global(node_id);
126 for (_, ids) in &mut bus.services {
127 if let Some(pos) = ids.iter().position(|x| *x == service_id) {
128 ids.remove(pos);
129 break;
130 }
131 }
132 if let Some(node) = bus.nodes.get_mut(node_id.0 as usize) {
133 node.services.remove(&id);
134 }
135 }
136
137 bus_message::Message::Req {
139 seq,
140 from,
141 to_service,
142 method,
143 data,
144 } => {
145 trace!(
146 "[{}/MSG:REQUEST] seq={} from={} to_service={}, method={}",
147 node_id,
148 seq,
149 from,
150 to_service,
151 method
152 );
153 let from = from.to_global(node_id);
154 let mut bus_inner = bus.lock().await;
155 let to = match bus_inner.find_service(&to_service) {
156 Some(to) => to,
157 None => {
158 let err_msg = format!("service '{}' not exists", to_service);
160 if let Some(node) = bus_inner.nodes.get_mut(node_id.0 as usize) {
161 if let Err(_) = node.tx.try_send(bus_message::Message::Rep {
162 seq,
163 result: Err(err_msg),
164 }) {
165 node.tx_close.try_send(()).ok();
167 }
168 }
169 return;
170 }
171 };
172 let new_seq = bus_inner.pending_requests.add(seq, node_id);
173 if let Some(to_node) = bus_inner.nodes.get_mut(to.node_id.0 as usize) {
174 if let Err(_) = to_node.tx.try_send(bus_message::Message::XReq {
175 from,
176 to: to.local_service_id,
177 seq: new_seq as u32,
178 method,
179 data,
180 }) {
181 to_node.tx_close.try_send(()).ok();
183 }
184 }
185
186 task::spawn({
188 let bus = bus.clone();
189 async move {
190 task::sleep(Duration::from_secs(5)).await;
191 let mut bus = bus.lock().await;
192 bus.pending_requests.remove(new_seq);
193 }
194 });
195 }
196
197 bus_message::Message::Rep { seq, result } => {
199 trace!("[{}/MSG:RESPONSE] seq={}", node_id, seq);
200 let mut bus = bus.lock().await;
201 if let Some((origin_seq, to_node_id)) = bus.pending_requests.remove(seq) {
202 if let Some(node) = bus.nodes.get_mut(to_node_id.0 as usize) {
203 if let Err(_) = node.tx.try_send(bus_message::Message::Rep {
204 seq: origin_seq,
205 result,
206 }) {
207 node.tx_close.try_send(()).ok();
209 }
210 }
211 };
212 }
213
214 bus_message::Message::Notify {
216 from,
217 to_service,
218 method,
219 data,
220 } => {
221 trace!(
222 "[{}/MSG:SEND_NOTIFY] from={} to_service={} method={}",
223 node_id,
224 from,
225 to_service,
226 method
227 );
228
229 let mut bus = bus.lock().await;
231 let bus = &mut *bus;
232
233 if let Some(services) = bus.services.get(&to_service) {
234 for service_id in services {
235 if node_id == service_id.node_id {
236 continue;
238 }
239
240 let to_node = bus.nodes.get_mut(service_id.node_id.0 as usize).unwrap();
241 if let Err(_) = to_node.tx.try_send(bus_message::Message::XNotify {
242 from: from.to_global(node_id),
243 to_service: to_service.clone(),
244 method,
245 data: data.clone(),
246 }) {
247 to_node.tx_close.try_send(()).ok();
249 }
250 }
251 }
252 }
253
254 bus_message::Message::NotifyTo {
256 from,
257 to,
258 method,
259 data,
260 } => {
261 trace!(
262 "[{}/MSG:SEND_NOTIFY_TO] from={} to={} method={}",
263 node_id,
264 from,
265 to,
266 method
267 );
268
269 let mut bus = bus.lock().await;
271 if let Some(node) = bus.nodes.get_mut(to.node_id.0 as usize) {
272 if let Err(_) = node.tx.try_send(bus_message::Message::XNotifyTo {
273 from: from.to_global(node_id),
274 to: to.local_service_id,
275 method: method,
276 data: data.clone(),
277 }) {
278 node.tx_close.try_send(()).ok();
280 }
281 }
282 }
283
284 bus_message::Message::Subscribe { topic } => {
286 trace!("[{}/MSG:SUBSCRIBE] topic={}", node_id, topic);
287 let mut bus = bus.lock().await;
288 bus.subscribes.subscribe(topic, node_id);
289 }
290
291 bus_message::Message::Unsubscribe { topic } => {
293 trace!("[{}/MSG:UNSUBSCRIBE] topic={}", node_id, topic);
294 let mut bus = bus.lock().await;
295 bus.subscribes.unsubscribe(topic, node_id);
296 }
297
298 bus_message::Message::Publish { topic, data } => {
300 trace!("[{}/MSG:PUBLISH] topic={}", node_id, topic);
301 let mut bus = bus.lock().await;
302 let bus = &mut *bus;
303 for to_node_id in bus.subscribes.query(&topic) {
304 if let Some(to_node) = bus.nodes.get_mut(to_node_id.0 as usize) {
305 if let Err(_) = to_node.tx.try_send(bus_message::Message::XPublish {
306 topic: topic.clone(),
307 data: data.clone(),
308 }) {
309 to_node.tx_close.try_send(()).ok();
311 }
312 }
313 }
314 }
315
316 _ => {}
317 }
318}
319
320async fn client_process(bus: Arc<Mutex<Bus>>, stream: TcpStream) {
322 let stream = Arc::new(stream);
323 let (tx_close, mut rx_close) = channel(1);
324 let (tx_incoming_msg, mut rx_incoming_msg) = channel(64);
325 let (mut tx_outgoing_msg, rx_outgoing_msg) = channel(64);
326 let node_id = bus
327 .lock()
328 .await
329 .create_node(tx_outgoing_msg.clone(), tx_close);
330
331 let (reader_task, abort_reader) =
334 future::abortable(bus_message::read_messages(stream.clone(), tx_incoming_msg));
335 let reader_handle = task::spawn(reader_task);
336
337 let (writer_task, abort_writer) =
339 future::abortable(bus_message::write_messages(stream.clone(), rx_outgoing_msg));
340 let writer_handle = task::spawn(writer_task);
341 trace!("[{}/CONNECTED]", node_id);
342
343 tx_outgoing_msg
345 .try_send(bus_message::Message::Hello(node_id))
346 .ok();
347 drop(tx_outgoing_msg);
348
349 let mut check_hb = stream::interval(Duration::from_secs(1)).fuse();
351
352 loop {
353 select! {
354 _ = rx_close.next() => {
355 break;
357 }
358 _ = check_hb.next() => {
359 if let Some(node) = bus.lock().await.nodes.get(node_id.0 as usize) {
360 if node.hb.elapsed() > Duration::from_secs(30) {
361 trace!("[{}/MSG:HEARTBEAT_TIMEOUT]", node_id);
363 break;
364 }
365 }
366 }
367 msg = rx_incoming_msg.next() => {
368 if let Some(msg) = msg {
369 let mut exit = false;
370 if let bus_message::Message::Bye = &msg {
371 exit = true;
372 }
373 process_incoming_msg(bus.clone(), node_id, msg).await;
374 if exit {
375 break;
377 }
378 } else {
379 trace!("client connection close. node_id={}", node_id);
381 break;
382 }
383 }
384 }
385 }
386
387 let mut bus = bus.lock().await;
389 bus.subscribes.remove_node(node_id);
390 for (_, ids) in &mut bus.services {
391 ids.retain(|id| id.node_id != node_id);
392 }
393 bus.nodes.remove(node_id.0 as usize);
394
395 abort_reader.abort();
397 abort_writer.abort();
398 reader_handle.await.ok();
399 writer_handle.await.ok();
400
401 trace!("[{}/DISCONNECTED]", node_id);
402}