1use std::{
2 collections::{HashMap, VecDeque},
3 sync::Arc,
4};
5
6use atm0s_sdn_identity::NodeId;
7use atm0s_sdn_network::msg::{MsgHeader, TransportMsg};
8use atm0s_sdn_router::RouteRule;
9use atm0s_sdn_utils::awaker::Awaker;
10
11use crate::{
12 rpc_id_gen::RpcIdGenerate,
13 rpc_msg::{RpcError, RpcMsg},
14 rpc_reliable::{
15 msg::{MSG_ACK, MSG_DATA},
16 recv::RpcReliableReceiver,
17 send::RpcReliableSender,
18 },
19};
20
/// Queue of RPC traffic (requests, events, answers) for one service on one node.
///
/// Outgoing messages go through a reliable sender. Incoming data goes through a
/// reliable receiver. Every transport message that must be sent is buffered in
/// `outs` until it is taken with `pop_transmit`.
///
/// `LD` is caller-owned data attached to each pending request.
pub struct RpcQueue<LD> {
    /// Local node id, used as the sender identity of outgoing messages.
    node_id: NodeId,
    /// Id of the owning service, used as the source service of outgoing messages.
    service_id: u8,
    /// Generator for outgoing request ids.
    id_gen: RpcIdGenerate,
    /// Pending requests: request id -> (timeout deadline in ms, caller's local data).
    reqs: HashMap<u64, (u64, LD)>,
    /// Handles incoming `MSG_DATA` messages and produces the messages to send back.
    reliable_receiver: RpcReliableReceiver,
    /// Handles outgoing payloads and incoming `MSG_ACK` messages.
    reliable_sender: RpcReliableSender,
    /// Transport messages waiting to be handed out by `pop_transmit`.
    outs: VecDeque<TransportMsg>,
    /// Optional notifier fired when outgoing messages become pending.
    awaker: Option<Arc<dyn Awaker>>,
    /// True when the next non-empty `outs` should notify the awaker. It is cleared
    /// on notify and set again when `pop_transmit` hands out the last message.
    should_awake: bool,
}
33
34impl<LD> RpcQueue<LD> {
35 pub fn new(node_id: NodeId, service_id: u8) -> Self {
36 Self {
37 node_id,
38 service_id,
39 id_gen: Default::default(),
40 reqs: HashMap::new(),
41 reliable_receiver: RpcReliableReceiver::new(node_id),
42 reliable_sender: RpcReliableSender::new(node_id),
43 outs: VecDeque::new(),
44 awaker: None,
45 should_awake: true,
46 }
47 }
48
49 pub fn set_awaker(&mut self, awaker: Arc<dyn Awaker>) {
50 self.awaker = Some(awaker);
51 }
52
53 pub fn add_request<Req: Into<Vec<u8>>>(&mut self, now_ms: u64, service_id: u8, rule: RouteRule, cmd: &str, param: Req, local_data: LD, timeout_after_ms: u64) {
54 log::info!("[RpcQueue] add request {}", cmd);
55 let req_id = self.id_gen.generate();
56 let rpc = RpcMsg::create_request(self.node_id, self.service_id, cmd, req_id, param);
57
58 let mut header = MsgHeader::build(self.service_id, service_id, rule);
59 header.from_node = Some(self.node_id);
60 let payload = bincode::serialize(&rpc).expect("Should ok");
61
62 if self.reliable_sender.add_msg(now_ms, header, &payload).is_some() {
63 self.reqs.insert(req_id, (now_ms + timeout_after_ms, local_data));
64 while let Some(msg) = self.reliable_sender.pop_transport_msg() {
65 self.outs.push_back(msg);
66 }
67 }
68 self.awake_if_need();
69 }
70
71 pub fn add_event<E: Into<Vec<u8>>>(&mut self, now_ms: u64, service_id: u8, rule: RouteRule, cmd: &str, event: E) {
72 log::info!("[RpcQueue] add event {}", cmd);
73 let rpc = RpcMsg::create_event(self.node_id, self.service_id, cmd, event);
74 let mut header = MsgHeader::build(self.service_id, service_id, rule);
75 header.from_node = Some(self.node_id);
76 let payload = bincode::serialize(&rpc).expect("Should ok");
77
78 if self.reliable_sender.add_msg(now_ms, header, &payload).is_some() {
79 while let Some(msg) = self.reliable_sender.pop_transport_msg() {
80 self.outs.push_back(msg);
81 }
82 }
83
84 self.awake_if_need();
85 }
86
87 pub fn answer_for<Res: Into<Vec<u8>>>(&mut self, now_ms: u64, req: &RpcMsg, param: Result<Res, RpcError>) {
88 log::info!("[RpcQueue] answer {}", req.cmd);
89 let answer = req.answer(self.node_id, self.service_id, param);
90 let header = MsgHeader::build(self.service_id, req.from_service_id, RouteRule::ToNode(req.from_node_id)).set_from_node(Some(self.node_id));
91 let payload = bincode::serialize(&answer).expect("Should ok");
92
93 if self.reliable_sender.add_msg(now_ms, header, &payload).is_some() {
94 while let Some(msg) = self.reliable_sender.pop_transport_msg() {
95 self.outs.push_back(msg);
96 }
97 }
98 self.awake_if_need();
99 }
100
101 pub fn on_msg(&mut self, now_ms: u64, msg: TransportMsg) -> Option<RpcMsg> {
102 match msg.header.meta {
103 MSG_ACK => {
104 self.reliable_sender.on_ack(msg.header.stream_id);
105 None
106 }
107 MSG_DATA => {
108 let res = self.reliable_receiver.on_msg(now_ms, msg);
109 while let Some(msg) = self.reliable_receiver.pop_msg() {
110 self.outs.push_back(msg);
111 }
112 res.map(|(header, payload)| RpcMsg::from_header_payload(&header, &payload)).flatten()
113 }
114 _ => None,
115 }
116 }
117
118 pub fn take_request(&mut self, req_id: u64) -> Option<LD> {
119 self.reqs.remove(&req_id).map(|(_, ld)| ld)
120 }
121
122 pub fn pop_timeout(&mut self, now_ms: u64) -> Option<(u64, LD)> {
123 self.reliable_sender.on_tick(now_ms);
124 self.reliable_receiver.on_tick(now_ms);
125 while let Some(msg) = self.reliable_sender.pop_transport_msg() {
126 self.outs.push_back(msg);
127 }
128 while let Some(msg) = self.reliable_receiver.pop_msg() {
129 self.outs.push_back(msg);
130 }
131
132 let mut timeout = None;
133 for (req_id, (timeout_at, _ld)) in &self.reqs {
134 if now_ms >= *timeout_at {
135 timeout = Some(*req_id);
136 break;
137 }
138 }
139
140 timeout.map(|req_id| {
141 let ld = self.reqs.remove(&req_id).expect("Should has").1;
142 (req_id, ld)
143 })
144 }
145
146 pub fn pop_transmit(&mut self) -> Option<TransportMsg> {
147 if self.outs.len() == 1 {
148 self.should_awake = true;
149 }
150 self.outs.pop_front()
151 }
152
153 fn awake_if_need(&mut self) {
154 if self.should_awake && !self.outs.is_empty() {
155 self.should_awake = false;
156 if let Some(awaker) = &self.awaker {
157 awaker.notify();
158 }
159 }
160 }
161}
162
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use atm0s_sdn_network::msg::{MsgHeader, TransportMsg};
    use atm0s_sdn_router::RouteRule;
    use atm0s_sdn_utils::awaker::{Awaker, MockAwaker};

    use crate::{
        rpc_reliable::msg::{build_stream_id, MSG_ACK, MSG_DATA},
        RpcMsg, RpcMsgParam, RpcQueue,
    };

    /// An event is transmitted as a single decodable message and fires the awaker once.
    #[test]
    fn create_event() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let awaker = Arc::new(MockAwaker::default());
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);
        queue.set_awaker(awaker.clone());

        queue.add_event(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1, 2, 3]);
        assert_eq!(awaker.pop_awake_count(), 1);
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Event(vec![1, 2, 3]),
            }
        );
    }

    /// After the queue is fully drained, the next event must fire the awaker again.
    #[test]
    fn create_big_event_should_fire_awake() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let awaker = Arc::new(MockAwaker::default());
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);
        queue.set_awaker(awaker.clone());

        queue.add_event(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1; 20000]);
        assert_eq!(awaker.pop_awake_count(), 1);

        // Drain every queued message so the awaker is re-armed.
        while queue.pop_transmit().is_some() {}

        queue.add_event(0, to_service_id, RouteRule::ToService(0), "cmd2", vec![2; 20000]);
        assert_eq!(awaker.pop_awake_count(), 1);
    }

    /// A request is transmitted and its local data is tracked until taken.
    #[test]
    fn create_request() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);

        queue.add_request(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1, 2, 3], 12345, 1000);
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Request { req_id: 0, param: vec![1, 2, 3] },
            }
        );

        assert_eq!(queue.take_request(0), Some(12345));
    }

    /// A request is reported exactly once, when its deadline is reached.
    #[test]
    fn create_request_timeout() {
        let node_id = 1;
        let service_id = 100;
        let to_service_id = 200;
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);

        queue.add_request(0, to_service_id, RouteRule::ToService(0), "cmd1", vec![1, 2, 3], 12345, 1000);
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Request { req_id: 0, param: vec![1, 2, 3] },
            }
        );

        assert_eq!(queue.pop_timeout(999), None);
        assert_eq!(queue.pop_timeout(1000), Some((0, 12345)));
        // A timed-out request is no longer tracked.
        assert_eq!(queue.pop_timeout(1001), None);
        assert_eq!(queue.take_request(0), None);
    }

    /// An answer is addressed from this node/service and echoes the request id.
    #[test]
    fn create_answer() {
        let node_id = 1;
        let service_id = 100;
        let from_node_id = 2;
        let from_service_id = 200;
        let mut queue = RpcQueue::<u32>::new(node_id, service_id);

        let incoming_req = RpcMsg {
            cmd: "cmd1".to_string(),
            from_node_id,
            from_service_id,
            param: RpcMsgParam::Request { req_id: 123, param: vec![1, 2, 3] },
        };

        queue.answer_for(0, &incoming_req, Ok(vec![3, 4, 5]));
        let transmit = queue.pop_transmit().unwrap();
        let rpc_msg = RpcMsg::from_header_payload(&transmit.header, transmit.payload()).unwrap();
        assert_eq!(
            rpc_msg,
            RpcMsg {
                cmd: "cmd1".to_string(),
                from_node_id: node_id,
                from_service_id: service_id,
                param: RpcMsgParam::Answer {
                    req_id: 123,
                    param: Ok(vec![3, 4, 5])
                },
            }
        );
    }

    /// An incoming data message is decoded and acknowledged back to its sender.
    #[test]
    fn queue_handle_incoming_event() {
        let mut queue = RpcQueue::<u32>::new(10, 100);

        let expected_req = RpcMsg {
            cmd: "cmd1".to_string(),
            from_node_id: 11,
            from_service_id: 101,
            param: RpcMsgParam::Request { req_id: 123, param: vec![1, 2, 3] },
        };

        let header = MsgHeader::build(101, 100, RouteRule::Direct)
            .set_from_node(Some(11))
            .set_stream_id(build_stream_id(0, 0, 0))
            .set_meta(MSG_DATA);

        let received_req = queue
            .on_msg(0, TransportMsg::build_raw(header, &bincode::serialize(&expected_req).expect("")))
            .expect("Should finish req");
        assert_eq!(received_req, expected_req);

        let ack_msg = queue.pop_transmit().expect("Should has");
        assert_eq!(ack_msg.header.from_node, Some(10));
        assert_eq!(ack_msg.header.route, RouteRule::ToNode(11));
        assert_eq!(ack_msg.header.from_service_id, 100);
        assert_eq!(ack_msg.header.to_service_id, 101);
        assert_eq!(ack_msg.header.meta, MSG_ACK);
        assert_eq!(ack_msg.payload(), &[]);
    }
}