atm0s_sdn_rpc/
rpc_queue.rs

1use std::{
2    collections::{HashMap, VecDeque},
3    sync::Arc,
4};
5
6use atm0s_sdn_identity::NodeId;
7use atm0s_sdn_network::msg::{MsgHeader, TransportMsg};
8use atm0s_sdn_router::RouteRule;
9use atm0s_sdn_utils::awaker::Awaker;
10
11use crate::{
12    rpc_id_gen::RpcIdGenerate,
13    rpc_msg::{RpcError, RpcMsg},
14    rpc_reliable::{
15        msg::{MSG_ACK, MSG_DATA},
16        recv::RpcReliableReceiver,
17        send::RpcReliableSender,
18    },
19};
20
21pub struct RpcQueue<LD> {
22    node_id: NodeId,
23    service_id: u8,
24    id_gen: RpcIdGenerate,
25    reqs: HashMap<u64, (u64, LD)>,
26    reliable_receiver: RpcReliableReceiver,
27    reliable_sender: RpcReliableSender,
28    outs: VecDeque<TransportMsg>,
29    awaker: Option<Arc<dyn Awaker>>,
30    // we should set should_awake to true if outs is empty, then should_awake is set to false when called awake_if_need
31    should_awake: bool,
32}
33
34impl<LD> RpcQueue<LD> {
35    pub fn new(node_id: NodeId, service_id: u8) -> Self {
36        Self {
37            node_id,
38            service_id,
39            id_gen: Default::default(),
40            reqs: HashMap::new(),
41            reliable_receiver: RpcReliableReceiver::new(node_id),
42            reliable_sender: RpcReliableSender::new(node_id),
43            outs: VecDeque::new(),
44            awaker: None,
45            should_awake: true,
46        }
47    }
48
49    pub fn set_awaker(&mut self, awaker: Arc<dyn Awaker>) {
50        self.awaker = Some(awaker);
51    }
52
53    pub fn add_request<Req: Into<Vec<u8>>>(&mut self, now_ms: u64, service_id: u8, rule: RouteRule, cmd: &str, param: Req, local_data: LD, timeout_after_ms: u64) {
54        log::info!("[RpcQueue] add request {}", cmd);
55        let req_id = self.id_gen.generate();
56        let rpc = RpcMsg::create_request(self.node_id, self.service_id, cmd, req_id, param);
57
58        let mut header = MsgHeader::build(self.service_id, service_id, rule);
59        header.from_node = Some(self.node_id);
60        let payload = bincode::serialize(&rpc).expect("Should ok");
61
62        if self.reliable_sender.add_msg(now_ms, header, &payload).is_some() {
63            self.reqs.insert(req_id, (now_ms + timeout_after_ms, local_data));
64            while let Some(msg) = self.reliable_sender.pop_transport_msg() {
65                self.outs.push_back(msg);
66            }
67        }
68        self.awake_if_need();
69    }
70
71    pub fn add_event<E: Into<Vec<u8>>>(&mut self, now_ms: u64, service_id: u8, rule: RouteRule, cmd: &str, event: E) {
72        log::info!("[RpcQueue] add event {}", cmd);
73        let rpc = RpcMsg::create_event(self.node_id, self.service_id, cmd, event);
74        let mut header = MsgHeader::build(self.service_id, service_id, rule);
75        header.from_node = Some(self.node_id);
76        let payload = bincode::serialize(&rpc).expect("Should ok");
77
78        if self.reliable_sender.add_msg(now_ms, header, &payload).is_some() {
79            while let Some(msg) = self.reliable_sender.pop_transport_msg() {
80                self.outs.push_back(msg);
81            }
82        }
83
84        self.awake_if_need();
85    }
86
87    pub fn answer_for<Res: Into<Vec<u8>>>(&mut self, now_ms: u64, req: &RpcMsg, param: Result<Res, RpcError>) {
88        log::info!("[RpcQueue] answer {}", req.cmd);
89        let answer = req.answer(self.node_id, self.service_id, param);
90        let header = MsgHeader::build(self.service_id, req.from_service_id, RouteRule::ToNode(req.from_node_id)).set_from_node(Some(self.node_id));
91        let payload = bincode::serialize(&answer).expect("Should ok");
92
93        if self.reliable_sender.add_msg(now_ms, header, &payload).is_some() {
94            while let Some(msg) = self.reliable_sender.pop_transport_msg() {
95                self.outs.push_back(msg);
96            }
97        }
98        self.awake_if_need();
99    }
100
101    pub fn on_msg(&mut self, now_ms: u64, msg: TransportMsg) -> Option<RpcMsg> {
102        match msg.header.meta {
103            MSG_ACK => {
104                self.reliable_sender.on_ack(msg.header.stream_id);
105                None
106            }
107            MSG_DATA => {
108                let res = self.reliable_receiver.on_msg(now_ms, msg);
109                while let Some(msg) = self.reliable_receiver.pop_msg() {
110                    self.outs.push_back(msg);
111                }
112                res.map(|(header, payload)| RpcMsg::from_header_payload(&header, &payload)).flatten()
113            }
114            _ => None,
115        }
116    }
117
118    pub fn take_request(&mut self, req_id: u64) -> Option<LD> {
119        self.reqs.remove(&req_id).map(|(_, ld)| ld)
120    }
121
122    pub fn pop_timeout(&mut self, now_ms: u64) -> Option<(u64, LD)> {
123        self.reliable_sender.on_tick(now_ms);
124        self.reliable_receiver.on_tick(now_ms);
125        while let Some(msg) = self.reliable_sender.pop_transport_msg() {
126            self.outs.push_back(msg);
127        }
128        while let Some(msg) = self.reliable_receiver.pop_msg() {
129            self.outs.push_back(msg);
130        }
131
132        let mut timeout = None;
133        for (req_id, (timeout_at, _ld)) in &self.reqs {
134            if now_ms >= *timeout_at {
135                timeout = Some(*req_id);
136                break;
137            }
138        }
139
140        timeout.map(|req_id| {
141            let ld = self.reqs.remove(&req_id).expect("Should has").1;
142            (req_id, ld)
143        })
144    }
145
146    pub fn pop_transmit(&mut self) -> Option<TransportMsg> {
147        if self.outs.len() == 1 {
148            self.should_awake = true;
149        }
150        self.outs.pop_front()
151    }
152
153    fn awake_if_need(&mut self) {
154        if self.should_awake && !self.outs.is_empty() {
155            self.should_awake = false;
156            if let Some(awaker) = &self.awaker {
157                awaker.notify();
158            }
159        }
160    }
161}
162
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use atm0s_sdn_network::msg::{MsgHeader, TransportMsg};
    use atm0s_sdn_router::RouteRule;
    use atm0s_sdn_utils::awaker::{Awaker, MockAwaker};

    use crate::{
        rpc_reliable::msg::{build_stream_id, MSG_ACK, MSG_DATA},
        RpcMsg, RpcMsgParam, RpcQueue,
    };

    /// An event is queued as a transport message carrying the event and fires the awaker once.
    #[test]
    fn create_event() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let awaker = Arc::new(MockAwaker::default());
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);
        queue.set_awaker(awaker.clone());

        queue.add_event(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1, 2, 3]);
        assert_eq!(awaker.pop_awake_count(), 1);
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Event(vec![1, 2, 3]),
            }
        );
    }

    /// A large event fires the awaker once, and the trigger is re-armed after the queue is fully drained.
    #[test]
    fn create_big_event_should_fire_awake() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let awaker = Arc::new(MockAwaker::default());
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);
        queue.set_awaker(awaker.clone());

        queue.add_event(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1; 20000]);
        assert_eq!(awaker.pop_awake_count(), 1);

        while queue.pop_transmit().is_some() {}

        queue.add_event(0, to_service_id, RouteRule::ToService(0), "cmd2", vec![2; 20000]);
        assert_eq!(awaker.pop_awake_count(), 1);
    }

    /// A request is queued for transmission and its local data can be taken exactly once.
    #[test]
    fn create_request() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);

        queue.add_request(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1, 2, 3], 12345, 1000);
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Request { req_id: 0, param: vec![1, 2, 3] },
            }
        );

        assert_eq!(queue.take_request(0), Some(12345));
        // Taking a request removes it: it can neither be taken again nor time out later.
        assert_eq!(queue.take_request(0), None);
        assert_eq!(queue.pop_timeout(1000), None);
    }

    /// A request that is never answered is returned by `pop_timeout` once its deadline is reached, and only once.
    #[test]
    fn create_request_timeout() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);

        queue.add_request(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1, 2, 3], 12345, 1000);
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Request { req_id: 0, param: vec![1, 2, 3] },
            }
        );

        assert_eq!(queue.pop_timeout(999), None);
        assert_eq!(queue.pop_timeout(1000), Some((0, 12345)));
        assert_eq!(queue.pop_timeout(1000), None);
    }

    /// Answering an incoming request produces an `Answer` carrying the original `req_id`.
    #[test]
    fn create_answer() {
        let node_id = 1;
        let service_id = 100;
        let from_node_id = 2;
        let from_service_id = 200;
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);

        let incoming_req = RpcMsg {
            cmd: "cmd1".to_string(),
            from_node_id,
            from_service_id,
            param: RpcMsgParam::Request { req_id: 123, param: vec![1, 2, 3] },
        };

        queue.answer_for(0, &incoming_req, Ok(vec![3, 4, 5]));
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Answer {
                    req_id: 123,
                    param: Ok(vec![3, 4, 5])
                },
            }
        );
    }

    /// An incoming `MSG_DATA` request is decoded and an empty ACK is queued back to the sender.
    #[test]
    fn queue_handle_incoming_request() {
        let mut queue = RpcQueue::<u32>::new(10, 100);

        let expected_req = RpcMsg {
            cmd: "cmd1".to_string(),
            from_node_id: 11,
            from_service_id: 101,
            param: RpcMsgParam::Request { req_id: 123, param: vec![1, 2, 3] },
        };

        let header = MsgHeader::build(101, 100, RouteRule::Direct)
            .set_from_node(Some(11))
            .set_stream_id(build_stream_id(0, 0, 0))
            .set_meta(MSG_DATA);

        let received_req = queue
            .on_msg(0, TransportMsg::build_raw(header, &bincode::serialize(&expected_req).expect("")))
            .expect("Should finish req");
        assert_eq!(received_req, expected_req);

        let ack_msg = queue.pop_transmit().expect("Should has");
        assert_eq!(ack_msg.header.from_node, Some(10));
        assert_eq!(ack_msg.header.route, RouteRule::ToNode(11));
        assert_eq!(ack_msg.header.from_service_id, 100);
        assert_eq!(ack_msg.header.to_service_id, 101);
        assert_eq!(ack_msg.header.meta, MSG_ACK);
        assert!(ack_msg.payload().is_empty());
    }
}