1use crate::{
6 rpc_protocol::{
7 fill_remote_error,
8 parse::{
9 build_message_identifier, parse_header, parse_message_identifier,
10 parse_protocol_message, ParseErrors,
11 },
12 RemoteError, Response, RpcMessageTypes, StreamMessage,
13 },
14 server::{ServerError, ServerInternalError, ServerResult, ServerResultError},
15 service_module_definition::{
16 BiStreamsResponse, ClientStreamsResponse, ServerStreamsResponse, UnaryResponse,
17 },
18 stream_protocol::Generator,
19 transports::{Transport, TransportError},
20 CommonError,
21};
22use async_channel::Sender as AsyncChannelSender;
23use log::{debug, error, info};
24use prost::Message;
25use std::{collections::HashMap, sync::Arc};
26use tokio::{
27 select,
28 sync::{
29 oneshot::{
30 channel as oneshot_channel, Receiver as OneShotReceiver, Sender as OneShotSender,
31 },
32 Mutex,
33 },
34};
35use tokio_util::sync::CancellationToken;
36
/// Server-side dispatcher that turns procedure results into outgoing messages and
/// routes incoming client-stream messages to their registered listeners.
///
/// Every `process_*` method spawns a detached tokio task and returns immediately.
#[derive(Default)]
#[cfg(feature = "server")]
pub struct ServerMessagesHandler {
    /// Shared helper used to send stream messages and wait for the peer's ACKs.
    pub streams_handler: Arc<StreamsHandler>,
    /// Client-stream listeners keyed by the client stream id; fed by
    /// `notify_new_client_stream`.
    listeners: Mutex<HashMap<u32, AsyncChannelSender<StreamPackage>>>,
}

#[cfg(feature = "server")]
impl ServerMessagesHandler {
    /// Creates a handler with a fresh [`StreamsHandler`] and no registered listeners.
    pub fn new() -> Self {
        Self {
            streams_handler: Arc::new(StreamsHandler::new()),
            listeners: Mutex::new(HashMap::new()),
        }
    }

    /// Spawns a task that awaits a unary procedure and sends its outcome back.
    ///
    /// On success a `Response` tagged with `message_number` is sent through
    /// `transport`; on failure the procedure's `RemoteError` is sent instead. If the
    /// response cannot be sent for a reason other than a closed transport, an
    /// `UnexpectedErrorOnTransport` remote error is sent as a fallback.
    pub fn process_unary_request<T: Transport + ?Sized + 'static>(
        &self,
        transport: Arc<T>,
        message_number: u32,
        procedure_handler: UnaryResponse,
    ) {
        tokio::spawn(async move {
            match procedure_handler.await {
                Ok(procedure_response) => {
                    let response = Response {
                        message_identifier: build_message_identifier(
                            RpcMessageTypes::Response as u32,
                            message_number,
                        ),
                        payload: procedure_response,
                    };

                    if let Err(err) = transport.send(response.encode_to_vec()).await {
                        error!("> ServerMessagesHandler > Error on sending procedure response through a transport - message num: {message_number} - error: {err:?}");
                        // A closed transport cannot carry the fallback error either.
                        if !matches!(err, TransportError::Closed) {
                            send_remote_error(
                                transport,
                                message_number,
                                ServerError::UnexpectedErrorOnTransport.into(),
                            )
                            .await;
                        }
                    }
                }
                Err(procedure_remote_error) => {
                    send_remote_error(transport, message_number, procedure_remote_error).await;
                }
            };
        });
    }

    /// Spawns a task that opens a server stream on `port_id` and, once the peer
    /// acknowledges the opening message, forwards every item the procedure's
    /// generator yields.
    pub fn process_server_streams_request<T: Transport + ?Sized + 'static>(
        self: Arc<Self>,
        transport: Arc<T>,
        message_number: u32,
        port_id: u32,
        procedure_handler: ServerStreamsResponse,
    ) {
        tokio::spawn(async move {
            self.stream_procedure_response(
                transport,
                message_number,
                port_id,
                procedure_handler,
                "process_server_streams_request",
            )
            .await;
        });
    }

    /// Spawns a task that registers `listener` for `client_stream_id`, awaits the
    /// client-streams procedure and sends back either its response or its remote
    /// error.
    ///
    /// NOTE(review): the listener is not unregistered here; confirm it is removed
    /// elsewhere via `unregister_listener`.
    pub fn process_client_streams_request<T: Transport + ?Sized + 'static>(
        self: Arc<Self>,
        transport: Arc<T>,
        message_number: u32,
        client_stream_id: u32,
        procedure_handler: ClientStreamsResponse,
        listener: AsyncChannelSender<(RpcMessageTypes, u32, StreamMessage)>,
    ) {
        tokio::spawn(async move {
            self.register_listener(client_stream_id, listener).await;
            match procedure_handler.await {
                Ok(procedure_response) => {
                    self.send_response(transport, message_number, procedure_response)
                        .await;
                }
                Err(procedure_remote_err) => {
                    send_remote_error(transport, message_number, procedure_remote_err).await;
                }
            }
        });
    }

    /// Spawns a task for a bidirectional-streams request.
    ///
    /// The task registers `listener` for `client_stream_id` so incoming client
    /// messages can be routed to it. It then behaves like
    /// `process_server_streams_request` for the server half.
    pub fn process_bidir_streams_request<T: Transport + ?Sized + 'static>(
        self: Arc<Self>,
        transport: Arc<T>,
        message_number: u32,
        port_id: u32,
        client_stream_id: u32,
        listener: AsyncChannelSender<(RpcMessageTypes, u32, StreamMessage)>,
        procedure_handler: BiStreamsResponse,
    ) {
        tokio::spawn(async move {
            self.register_listener(client_stream_id, listener).await;
            self.stream_procedure_response(
                transport,
                message_number,
                port_id,
                procedure_handler,
                "process_bidir_streams_request",
            )
            .await;
        });
    }

    /// Shared server-stream pipeline: open the stream, wait for the peer's ACK of the
    /// opening message, await the procedure, then stream its generator.
    ///
    /// `caller` only labels the log lines. Failures that are not caused by a closed
    /// transport are reported to the peer as `UnexpectedErrorOnTransport`; a
    /// procedure failure is reported with the procedure's own `RemoteError`.
    async fn stream_procedure_response<T, F>(
        &self,
        transport: Arc<T>,
        message_number: u32,
        port_id: u32,
        procedure_handler: F,
        caller: &'static str,
    ) where
        T: Transport + ?Sized,
        F: std::future::Future<Output = Result<Generator<Vec<u8>>, RemoteError>>,
    {
        let open_ack_listener = match self
            .open_server_stream(transport.clone(), message_number, port_id)
            .await
        {
            Ok(open_ack_listener) => open_ack_listener,
            Err(error) => {
                error!("> ServerMessagesHandler > {caller} > Error on opening a server stream: {error:?}");
                if !matches!(
                    error,
                    ServerResultError::Internal(ServerInternalError::TransportWasClosed)
                ) {
                    send_remote_error(
                        transport,
                        message_number,
                        ServerError::UnexpectedErrorOnTransport.into(),
                    )
                    .await;
                }
                return;
            }
        };

        if let Err(error) = open_ack_listener.await {
            error!("> ServerMessagesHandler > {caller} > Error on receiving on a open_ack_listener, sender seems to be dropped: {error:?}");
            return;
        }

        let generator = match procedure_handler.await {
            Ok(generator) => generator,
            Err(procedure_remote_error) => {
                send_remote_error(transport, message_number, procedure_remote_error).await;
                return;
            }
        };

        if let Err(error) = self
            .streams_handler
            .send_streams_through_transport(transport.clone(), generator, port_id, message_number)
            .await
        {
            error!("> ServerMessagesHandler > {caller} > Error while executing StreamsHandler::send_streams_through_transport - Error: {error:?}");
            if !matches!(error, CommonError::TransportWasClosed) {
                send_remote_error(
                    transport,
                    message_number,
                    ServerError::UnexpectedErrorOnTransport.into(),
                )
                .await;
            }
        }
    }

    /// Spawns a task that decodes `payload` as a `StreamMessage` and forwards it to
    /// the listener registered under `message_number`, if there is one.
    pub fn notify_new_client_stream(self: Arc<Self>, message_number: u32, payload: Vec<u8>) {
        tokio::spawn(async move {
            // Clone the sender out so the map lock is released before the
            // `send(..).await` below instead of being held across it.
            let listener = self.listeners.lock().await.get(&message_number).cloned();
            if let Some(listener) = listener {
                match StreamMessage::decode(payload.as_slice()) {
                    Ok(stream_message) => {
                        if listener
                            .send((
                                RpcMessageTypes::StreamMessage,
                                message_number,
                                stream_message,
                            ))
                            .await
                            .is_err()
                        {
                            error!("> ServerMessagesHandler > notify_new_client_stream > Error while sending through the listener, channel seems to be closed ");
                        }
                    }
                    Err(_) => {
                        error!("> ServerMessagesHandler > notify_new_client_stream > Error while decoding payload into StreamMessage, something is corrupted or bad implemented");
                    }
                }
            }
        });
    }

    /// Sends `payload` as a `Response` for `message_number`.
    ///
    /// If sending fails and the transport is not closed, an
    /// `UnexpectedErrorOnTransport` remote error is sent instead.
    pub async fn send_response<T: Transport + ?Sized>(
        &self,
        transport: Arc<T>,
        message_number: u32,
        payload: Vec<u8>,
    ) {
        let response = Response {
            message_identifier: build_message_identifier(
                RpcMessageTypes::Response as u32,
                message_number,
            ),
            payload,
        };

        if let Err(err) = transport.send(response.encode_to_vec()).await {
            if !matches!(err, TransportError::Closed) {
                error!("> ServerMessagesHandler > send_response > Error while sending the original response through transport but it seems not to be closed");
                send_remote_error(
                    transport,
                    message_number,
                    ServerError::UnexpectedErrorOnTransport.into(),
                )
                .await;
            } else {
                error!("> ServerMessagesHandler > send_response > Error while sending response through transport, it seems to be closed");
            }
        }
    }

    /// Sends the opening `StreamMessage` (sequence id 0, not closed, not ack, empty
    /// payload) and returns the receiver that gets the peer's reply to it.
    async fn open_server_stream<T: Transport + ?Sized>(
        &self,
        transport: Arc<T>,
        message_number: u32,
        port_id: u32,
    ) -> ServerResult<OneShotReceiver<Vec<u8>>> {
        let opening_message = StreamMessage {
            closed: false,
            ack: false,
            sequence_id: 0,
            message_identifier: build_message_identifier(
                RpcMessageTypes::StreamMessage as u32,
                message_number,
            ),
            port_id,
            payload: vec![],
        };

        let receiver = self
            .streams_handler
            .send_stream(transport, opening_message)
            .await
            .map_err(|err| {
                if matches!(err, CommonError::TransportWasClosed) {
                    return ServerResultError::Internal(ServerInternalError::TransportWasClosed);
                }
                ServerResultError::Internal(ServerInternalError::TransportError)
            })?;

        Ok(receiver)
    }

    /// Registers `callback` as the listener for `message_id`, replacing any
    /// previous one.
    pub async fn register_listener(
        &self,
        message_id: u32,
        callback: AsyncChannelSender<(RpcMessageTypes, u32, StreamMessage)>,
    ) {
        let mut lock = self.listeners.lock().await;
        lock.insert(message_id, callback);
    }

    /// Removes the listener registered for `message_id`, if any.
    pub async fn unregister_listener(&self, message_id: u32) {
        let mut lock = self.listeners.lock().await;
        lock.remove(&message_id);
    }
}
375
376type StreamPackage = (RpcMessageTypes, u32, StreamMessage);
377
378#[cfg(feature = "client")]
386pub struct ClientMessagesHandler<T: Transport + ?Sized> {
387 pub transport: Arc<T>,
389 pub streams_handler: Arc<StreamsHandler>,
391 one_time_listeners: Mutex<HashMap<u32, OneShotSender<Vec<u8>>>>,
401 listeners: Mutex<HashMap<u32, AsyncChannelSender<StreamPackage>>>,
411 process_cancellation_token: CancellationToken,
416}
417
418#[cfg(feature = "client")]
419impl<T: Transport + ?Sized + 'static> ClientMessagesHandler<T> {
420 pub fn new(transport: Arc<T>) -> Self {
421 Self {
422 transport,
423 one_time_listeners: Mutex::new(HashMap::new()),
424 process_cancellation_token: CancellationToken::new(),
425 listeners: Mutex::new(HashMap::new()),
426 streams_handler: Arc::new(StreamsHandler::new()),
427 }
428 }
429
430 pub fn start(self: Arc<Self>) {
435 let token = self.process_cancellation_token.clone();
436 tokio::spawn(async move {
437 select! {
438 _ = token.cancelled() => {
439 debug!("> ClientMessagesHandler > cancelled!");
440 self.transport.close().await;
441 },
442 _ = self.process() => {
443
444 }
445 }
446 });
447 }
448
449 pub fn stop(&self) {
451 self.process_cancellation_token.cancel();
452 }
453
454 async fn process(&self) {
456 loop {
457 match self.transport.receive().await {
458 Ok(data) => {
459 let message_header = parse_header(&data);
460 match message_header {
461 Some((message_type, message_number)) => {
462 let mut read_callbacks = self.one_time_listeners.lock().await;
463 let sender = read_callbacks.remove(&message_number);
465 if let Some(sender) = sender {
466 match sender.send(data) {
467 Ok(()) => {}
468 Err(_) => {
469 error!(
470 "> ClientMessagesHandler > process > error while sending {} response",
471 message_number
472 );
473 continue;
474 }
475 }
476 } else {
477 let listeners = self.listeners.lock().await;
479 let listener = listeners.get(&message_number);
480
481 if let Some(listener) = listener {
482 if let Ok(stream_message) =
483 StreamMessage::decode(data.as_slice())
484 {
485 if let Err(error) = listener
486 .send((message_type, message_number, stream_message))
487 .await
488 {
489 error!(
490 "> ClientMessagesHandler > process > Error while sending StreamMessage to a listener {error:?}")
491 }
492 } else {
493 error!("> ClientMessagesHandler > process > Error while decoding bytes into a StreamMessage, something seems to be bad implemented")
494 }
495 } else {
496 self.streams_handler
498 .clone()
499 .message_acknowledged_by_peer(message_number, data)
500 }
501 }
502 }
503 None => {
504 error!("> ClientMessagesHandler > process > Error on parsing message header - impossible to communicate the error to a listener, the message is corrupted or invalid");
506 continue;
507 }
508 }
509 }
510 Err(error) => {
511 error!("> ClientMessagesHandler > process > Error on receive {error:?}");
512 if matches!(error, TransportError::Closed) {
513 info!("> ClientMessagesHandler > process > closing...");
514 break;
515 }
516 }
517 }
518 }
519 }
520
521 pub fn await_server_ack_open_and_send_streams<M: Message + 'static>(
528 self: Arc<Self>,
529 open_promise: OneShotReceiver<Vec<u8>>,
530 client_stream: Generator<M>,
531 port_id: u32,
532 client_message_id: u32,
533 ) {
534 let transport = self.transport.clone();
535 tokio::spawn(async move {
536 match open_promise.await {
537 Ok(encoded_ack_response) => {
538 if let Ok(stream_message) =
539 StreamMessage::decode(encoded_ack_response.as_slice())
540 {
541 if stream_message.closed {
542 return;
543 }
544
545 let new_generator = Generator::from_generator(client_stream, |item| {
546 Some(item.encode_to_vec())
547 });
548
549 if let Err(error) = self
550 .streams_handler
551 .send_streams_through_transport(
552 transport,
553 new_generator,
554 port_id,
555 client_message_id,
556 )
557 .await
558 {
559 error!("> ClientMessagesHandler > await_server_ack_open_and_send_streams > Error while executing StreamsHandler::send_streams_through_transport - Error: {error:?}");
560 }
561 } else {
562 error!("> ClientMessagesHandler > await_server_ack_open_and_send_streams > Error while decoding bytes into StreamMessage")
563 }
564 }
565 Err(_) => {
566 error!("> ClientMessagesHandler > await_server_ack_open_and_send_streams > Error while awaiting the server to send the ACK for Open stream message, sender half seems to be dropped");
567 }
568 }
569 });
570 }
571
572 pub async fn register_one_time_listener(
574 &self,
575 message_number: u32,
576 callback: OneShotSender<Vec<u8>>,
577 ) {
578 let mut lock = self.one_time_listeners.lock().await;
579 lock.insert(message_number, callback);
580 }
581
582 pub async fn register_listener(
584 &self,
585 message_number: u32,
586 callback: AsyncChannelSender<(RpcMessageTypes, u32, StreamMessage)>,
587 ) {
588 let mut lock = self.listeners.lock().await;
589 lock.insert(message_number, callback);
590 }
591
592 pub async fn unregister_listener(&self, message_number: u32) {
594 let mut lock = self.listeners.lock().await;
595 lock.remove(&message_number);
596 }
597}
598
599#[derive(Default)]
601pub struct StreamsHandler {
602 ack_listeners: Mutex<HashMap<String, OneShotSender<Vec<u8>>>>,
603}
604
605impl StreamsHandler {
606 pub fn new() -> Self {
607 Self {
608 ack_listeners: Mutex::new(HashMap::new()),
609 }
610 }
611
612 async fn close_stream<T: Transport + ?Sized>(
614 &self,
615 transport: Arc<T>,
616 sequence_id: u32,
617 message_number: u32,
618 port_id: u32,
619 ) -> Result<(), CommonError> {
620 let close_message = StreamMessage {
621 closed: true,
622 ack: false,
623 sequence_id,
624 message_identifier: build_message_identifier(
625 RpcMessageTypes::StreamMessage as u32,
626 message_number,
627 ),
628 port_id,
629 payload: vec![],
630 };
631
632 transport
633 .send(close_message.encode_to_vec())
634 .await
635 .map_err(|_| CommonError::TransportError)?;
636
637 Ok(())
638 }
639
640 pub async fn send_streams_through_transport<T: Transport + ?Sized>(
647 &self,
648 transport: Arc<T>,
649 mut stream_generator: Generator<Vec<u8>>,
650 port_id: u32,
651 message_number: u32,
652 ) -> Result<(), CommonError> {
653 let mut sequence_number = 0;
654 let mut was_closed_by_peer = false;
655 while let Some(message) = stream_generator.next().await {
656 sequence_number += 1;
657 let current_message = StreamMessage {
658 closed: false,
659 ack: false,
660 sequence_id: sequence_number,
661 message_identifier: build_message_identifier(
662 RpcMessageTypes::StreamMessage as u32,
663 message_number,
664 ),
665 port_id,
666 payload: message,
667 };
668 let transport_cloned = transport.clone();
669
670 match self.send_stream(transport_cloned, current_message).await {
671 Ok(listener) => {
672 let ack_message = match listener.await {
673 Ok(msg) => match StreamMessage::decode(msg.as_slice()) {
674 Ok(msg) => msg,
675 Err(_) => {
676 error!("> StreamsHandler > send_streams_through_transport > Error while decoding bytes into a StreamMessage");
677 return Err(CommonError::ProtocolError);
678 }
679 },
680 Err(e) => {
681 error!("> StreamsHandler > send_streams_through_transport > Error while waiting for an ACK Message, the sender half seems to be dropped: {e:?}");
682 return Err(CommonError::UnexpectedError(
683 "The sender half of a listener seems to be droppped".to_string(),
684 ));
685 }
686 };
687 if ack_message.ack {
688 debug!("> StreamsHandler > send_streams_through_transport > Listener received the ack for a message, continuing with the next stream");
689 continue;
690 } else if ack_message.closed {
691 debug!("> StreamsHandler > send_streams_through_transport > stream was closed by the other peer");
692 was_closed_by_peer = true;
693 stream_generator.close();
694 break;
695 }
696 }
697 Err(err) => {
698 error!("> StreamsHandler > send_streams_through_transport > Error while streaming a server stream {err:?}");
699 return Err(err);
700 }
701 }
702 }
703
704 if !was_closed_by_peer {
705 self.close_stream(transport, sequence_number, message_number, port_id)
706 .await?;
707 }
708
709 Ok(())
710 }
711
712 async fn send_stream<T: Transport + ?Sized>(
717 &self,
718 transport: Arc<T>,
719 message: StreamMessage,
720 ) -> Result<OneShotReceiver<Vec<u8>>, CommonError> {
721 let (_, message_number) = parse_message_identifier(message.message_identifier);
722 let (tx, rx) = oneshot_channel();
723 {
724 let mut lock = self.ack_listeners.lock().await;
725 if lock
726 .insert(
727 format!(
728 "{}-{}-{}",
729 message.port_id, message_number, message.sequence_id
730 ),
731 tx,
732 )
733 .is_some()
734 {
735 error!("> StreamsHandler > send_stram > Overriding TX for message_number: {message_number} and sequence_id: {}", message.sequence_id);
736 }
737 }
738
739 if let Err(error) = transport.send(message.encode_to_vec()).await {
740 error!("> StreamsHandler > send_stream > Error while sending through transport - message: {message:?} - Error: {error:?}");
741 {
742 let mut lock = self.ack_listeners.lock().await;
744 lock.remove(&format!(
745 "{}-{}-{}",
746 message.port_id, message_number, message.sequence_id
747 ));
748 }
749 if !matches!(error, TransportError::Closed) {
750 return Err(CommonError::TransportError);
751 } else {
752 return Err(CommonError::TransportWasClosed);
753 }
754 }
755
756 Ok(rx)
757 }
758
759 pub fn message_acknowledged_by_peer(self: Arc<Self>, message_number: u32, payload: Vec<u8>) {
761 tokio::spawn(async move {
762 match parse_protocol_message::<StreamMessage>(&payload) {
763 Ok((_, _, stream_message)) => {
764 let listener = {
765 let mut lock = self.ack_listeners.lock().await;
766 lock.remove(&format!(
768 "{}-{}-{}",
769 stream_message.port_id, message_number, stream_message.sequence_id
770 ))
771 };
772 match listener {
773 Some(sender) => {
774 if sender.send(payload).is_err() {
775 error!("> Streams Handler > message_acknowledged_by_peer > Error while sending through the ack listener, seems to be dropped")
776 }
777 }
778 None => {
779 debug!("> Streams Handler > message_acknowledged_by_peer > ack listener not found")
780 }
781 }
782 }
783 Err(ParseErrors::IsARemoteError((_, remote_error))) => {
784 error!("> Streams Handler > message_acknowledged_by_peer > Remote Error: {remote_error:?}")
785 }
786 Err(err) => {
787 error!("> Streams Handler > message_acknowledged_by_peer > Error on parsing: {err:?}");
788 }
789 }
790 });
791 }
792}
793
794async fn send_remote_error<T: Transport + ?Sized>(
796 transport: Arc<T>,
797 message_number: u32,
798 mut remote_error: RemoteError,
799) {
800 fill_remote_error(&mut remote_error, message_number);
803 if let Err(err) = transport.send(remote_error.encode_to_vec()).await {
804 error!("> send_remote_error > Error while sending the remote error through a transport > RemoteError: {remote_error:?} - Error: {err:?}")
805 }
806}