1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use opentelemetry::propagation::{Extractor, Injector};
8use opentelemetry::trace::TraceContextExt;
9use parking_lot::RwLock;
10use slim_config::grpc::client::ClientConfig;
11use slim_tracing::utils::INSTANCE_ID;
12use tokio::sync::mpsc::{self, Sender};
13use tokio_stream::wrappers::ReceiverStream;
14use tokio_stream::{Stream, StreamExt};
15use tokio_util::sync::CancellationToken;
16use tonic::codegen::{Body, StdError};
17use tonic::{Request, Response, Status};
18use tracing::{Span, debug, error, info};
19use tracing_opentelemetry::OpenTelemetrySpanExt;
20
21use crate::api::ProtoPublishType as PublishType;
22use crate::api::ProtoSubscribeType as SubscribeType;
23use crate::api::ProtoUnsubscribeType as UnsubscribeType;
24use crate::api::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
25use crate::api::proto::pubsub::v1::{Message, pub_sub_service_server::PubSubService};
26use crate::connection::{Channel, Connection, Type as ConnectionType};
27use crate::errors::DataPathError;
28use crate::forwarder::Forwarder;
29use crate::messages::Name;
30use crate::tables::connection_table::ConnectionTable;
31use crate::tables::subscription_table::SubscriptionTableImpl;
32
33struct MetadataExtractor<'a>(&'a std::collections::HashMap<String, String>);
35
36impl Extractor for MetadataExtractor<'_> {
37 fn get(&self, key: &str) -> Option<&str> {
38 self.0.get(key).map(|s| s.as_str())
39 }
40
41 fn keys(&self) -> Vec<&str> {
42 self.0.keys().map(|s| s.as_str()).collect()
43 }
44}
45
46struct MetadataInjector<'a>(&'a mut std::collections::HashMap<String, String>);
47
48impl Injector for MetadataInjector<'_> {
49 fn set(&mut self, key: &str, value: String) {
50 self.0.insert(key.to_string(), value);
51 }
52}
53
54fn extract_parent_context(msg: &Message) -> Option<opentelemetry::Context> {
56 let extractor = MetadataExtractor(&msg.metadata);
57 let parent_context =
58 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
59
60 if parent_context.span().span_context().is_valid() {
61 Some(parent_context)
62 } else {
63 None
64 }
65}
66
67fn inject_current_context(msg: &mut Message) {
69 let cx = tracing::Span::current().context();
70 let mut injector = MetadataInjector(&mut msg.metadata);
71 opentelemetry::global::get_text_map_propagator(|propagator| {
72 propagator.inject_context(&cx, &mut injector)
73 });
74}
75
76fn create_span(function: &str, out_conn: u64, msg: &Message) -> Span {
78 let span = tracing::span!(
79 tracing::Level::INFO,
80 "slim_process_message",
81 function = function,
82 source = format!("{}", msg.get_source()),
83 destination = format!("{}", msg.get_dst()),
84 instance_id = %INSTANCE_ID.as_str(),
85 connection_id = out_conn,
86 message_type = msg.get_type().to_string(),
87 telemetry = true
88 );
89
90 if let PublishType(_) = msg.get_type() {
91 span.set_attribute("session_type", msg.get_session_message_type().as_str_name());
92 span.set_attribute(
93 "session_id",
94 msg.get_session_header().get_session_id().to_string(),
95 );
96 span.set_attribute(
97 "message_id",
98 msg.get_session_header().get_message_id().to_string(),
99 );
100 }
101
102 span
103}
104
/// State shared by every clone of [`MessageProcessor`] (held behind an `Arc`).
#[derive(Debug)]
struct MessageProcessorInternal {
    /// Owns the connection and subscription tables and performs message matching.
    forwarder: Forwarder<Connection>,
    /// Watch whose signal stops stream-processing tasks and pending reconnections.
    drain_channel: drain::Watch,
    /// Sender towards the first control-plane local connection, if one registered;
    /// non-publish messages from remote peers are mirrored here.
    tx_control_plane: RwLock<Option<Sender<Result<Message, Status>>>>,
}
111
/// Routes messages between local applications and remote peers.
///
/// Cloning is cheap: all clones share the same internal state through an [`Arc`].
#[derive(Debug, Clone)]
pub struct MessageProcessor {
    internal: Arc<MessageProcessorInternal>,
}
116
117impl MessageProcessor {
118 pub fn new() -> (Self, drain::Signal) {
119 let (signal, watch) = drain::channel();
120 let forwarder = Forwarder::new();
121 let forwarder = MessageProcessorInternal {
122 forwarder,
123 drain_channel: watch,
124 tx_control_plane: RwLock::new(None),
125 };
126
127 (
128 Self {
129 internal: Arc::new(forwarder),
130 },
131 signal,
132 )
133 }
134
135 pub fn with_drain_channel(watch: drain::Watch) -> Self {
136 let forwarder = Forwarder::new();
137 let forwarder = MessageProcessorInternal {
138 forwarder,
139 drain_channel: watch,
140 tx_control_plane: RwLock::new(None),
141 };
142 Self {
143 internal: Arc::new(forwarder),
144 }
145 }
146
147 fn set_tx_control_plane(&self, tx: Sender<Result<Message, Status>>) {
148 let mut tx_guard = self.internal.tx_control_plane.write();
149 *tx_guard = Some(tx);
150 }
151
152 fn get_tx_control_plane(&self) -> Option<Sender<Result<Message, Status>>> {
153 let tx_guard = self.internal.tx_control_plane.read();
154 tx_guard.clone()
155 }
156
157 fn forwarder(&self) -> &Forwarder<Connection> {
158 &self.internal.forwarder
159 }
160
161 fn get_drain_watch(&self) -> drain::Watch {
162 self.internal.drain_channel.clone()
163 }
164
165 async fn try_to_connect<C>(
166 &self,
167 channel: C,
168 client_config: Option<ClientConfig>,
169 local: Option<SocketAddr>,
170 remote: Option<SocketAddr>,
171 existing_conn_index: Option<u64>,
172 max_retry: u32,
173 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
174 where
175 C: tonic::client::GrpcService<tonic::body::Body>,
176 C::Error: Into<StdError>,
177 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
178 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
179 {
180 let mut client: PubSubServiceClient<C> = PubSubServiceClient::new(channel);
181 let mut i = 0;
182 while i < max_retry {
183 let (tx, rx) = mpsc::channel(128);
184 match client
185 .open_channel(Request::new(ReceiverStream::new(rx)))
186 .await
187 {
188 Ok(stream) => {
189 let cancellation_token = CancellationToken::new();
190 let connection = Connection::new(ConnectionType::Remote)
191 .with_local_addr(local)
192 .with_remote_addr(remote)
193 .with_channel(Channel::Client(tx))
194 .with_config_data(client_config.clone())
195 .with_cancellation_token(Some(cancellation_token.clone()));
196
197 debug!(
198 "new connection initiated locally: (remote: {:?} - local: {:?})",
199 connection.remote_addr(),
200 connection.local_addr()
201 );
202
203 let opt = self
205 .forwarder()
206 .on_connection_established(connection, existing_conn_index);
207 if opt.is_none() {
208 error!("error adding connection to the connection table");
209 return Err(DataPathError::ConnectionError(
210 "error adding connection to the connection tables".to_string(),
211 ));
212 }
213
214 let conn_index = opt.unwrap();
215 debug!(
216 "new connection index = {:?}, is local {:?}",
217 conn_index, false
218 );
219
220 let ret = self.process_stream(
222 stream.into_inner(),
223 conn_index,
224 client_config,
225 cancellation_token,
226 false,
227 false,
228 );
229 return Ok((ret, conn_index));
230 }
231 Err(e) => {
232 error!("connection error: {:?}.", e.to_string());
233 }
234 }
235 i += 1;
236
237 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
239 }
240
241 error!("unable to connect to the endpoint");
242 Err(DataPathError::ConnectionError(
243 "reached max connection retries".to_string(),
244 ))
245 }
246
247 pub async fn connect<C>(
248 &self,
249 channel: C,
250 client_config: Option<ClientConfig>,
251 local: Option<SocketAddr>,
252 remote: Option<SocketAddr>,
253 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
254 where
255 C: tonic::client::GrpcService<tonic::body::Body>,
256 C::Error: Into<StdError>,
257 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
258 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
259 {
260 self.try_to_connect(channel, client_config, local, remote, None, 10)
261 .await
262 }
263
264 pub fn disconnect(&self, conn: u64) -> Result<(), DataPathError> {
265 match self.forwarder().get_connection(conn) {
266 None => {
267 error!("error handling disconnect: connection unknown");
268 return Err(DataPathError::DisconnectionError(
269 "connection not found".to_string(),
270 ));
271 }
272 Some(c) => {
273 match c.cancellation_token() {
274 None => {
275 error!("error handling disconnect: missing cancellation token");
276 }
277 Some(t) => {
278 t.cancel();
282 }
283 }
284 }
285 }
286
287 Ok(())
288 }
289
290 pub fn register_local_connection(
291 &self,
292 from_control_plane: bool,
293 ) -> (
294 u64,
295 tokio::sync::mpsc::Sender<Result<Message, Status>>,
296 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
297 ) {
298 let (tx1, rx1) = mpsc::channel(128);
300
301 debug!("establishing new local app connection");
302
303 let (tx2, rx2) = mpsc::channel(128);
305
306 if from_control_plane && self.get_tx_control_plane().is_none() {
309 self.set_tx_control_plane(tx2.clone());
310 }
311
312 let cancellation_token = CancellationToken::new();
314 let connection = Connection::new(ConnectionType::Local)
315 .with_channel(Channel::Server(tx2))
316 .with_cancellation_token(Some(cancellation_token.clone()));
317
318 let conn_id = self
320 .forwarder()
321 .on_connection_established(connection, None)
322 .unwrap();
323
324 debug!("local connection established with id: {:?}", conn_id);
325 info!(telemetry = true, counter.num_active_connections = 1);
326
327 self.process_stream(
329 ReceiverStream::new(rx1),
330 conn_id,
331 None,
332 cancellation_token,
333 true,
334 from_control_plane,
335 );
336
337 (conn_id, tx1, rx2)
339 }
340
341 pub async fn send_msg(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
342 let connection = self.forwarder().get_connection(out_conn);
343 match connection {
344 Some(conn) => {
345 msg.clear_slim_header();
347
348 let parent_context = extract_parent_context(&msg);
350 let span = create_span("send_message", out_conn, &msg);
351
352 if let Some(ctx) = parent_context {
353 span.set_parent(ctx);
354 }
355 let _guard = span.enter();
356 inject_current_context(&mut msg);
357 match conn.channel() {
360 Channel::Server(s) => s
361 .send(Ok(msg))
362 .await
363 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
364 Channel::Client(s) => s
365 .send(msg)
366 .await
367 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
368 _ => Err(DataPathError::MessageSendError(
369 "connection not found".to_string(),
370 )),
371 }
372 }
373 None => Err(DataPathError::MessageSendError(format!(
374 "connection {:?} not found",
375 out_conn
376 ))),
377 }
378 }
379
380 async fn match_and_forward_msg(
381 &self,
382 msg: Message,
383 name: Name,
384 in_connection: u64,
385 fanout: u32,
386 ) -> Result<(), DataPathError> {
387 debug!(
388 "match and forward message: name: {} - fanout: {}",
389 name, fanout,
390 );
391
392 if let Some(val) = msg.get_forward_to() {
395 debug!("forwarding message to connection {}", val);
396 return self
397 .send_msg(msg, val)
398 .await
399 .map_err(|e| DataPathError::PublicationError(e.to_string()));
400 }
401
402 match self
403 .forwarder()
404 .on_publish_msg_match(name, in_connection, fanout)
405 {
406 Ok(out_vec) => {
407 let mut i = 0;
410 while i < out_vec.len() - 1 {
411 self.send_msg(msg.clone(), out_vec[i])
412 .await
413 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
414 i += 1;
415 }
416 self.send_msg(msg, out_vec[i])
417 .await
418 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
419 Ok(())
420 }
421 Err(e) => Err(DataPathError::PublicationError(e.to_string())),
422 }
423 }
424
425 async fn process_publish(&self, msg: Message, in_connection: u64) -> Result<(), DataPathError> {
426 debug!(
427 "received publication from connection {}: {:?}",
428 in_connection, msg
429 );
430
431 info!(
433 telemetry = true,
434 monotonic_counter.num_messages_by_type = 1,
435 method = "publish"
436 );
437 let header = msg.get_slim_header();
441
442 let dst = header.get_dst();
443
444 let fanout = msg.get_fanout();
447
448 return self
450 .match_and_forward_msg(msg, dst, in_connection, fanout)
451 .await;
452 }
453
454 async fn process_subscription(
458 &self,
459 msg: Message,
460 in_connection: u64,
461 add: bool,
462 ) -> Result<(), DataPathError> {
463 debug!(
464 "received subscription (add = {}) from connection {}: {:?}",
465 add, in_connection, msg
466 );
467
468 info!(
470 telemetry = true,
471 monotonic_counter.num_messages_by_type = 1,
472 message_type = { if add { "subscribe" } else { "unsubscribe" } }
473 );
474 let dst = msg.get_dst();
477
478 let header = msg.get_slim_header();
480
481 let (conn, forward) = header.get_in_out_connections();
483
484 let connection = self
487 .forwarder()
488 .get_connection(conn)
489 .ok_or_else(|| DataPathError::SubscriptionError("connection not found".to_string()))?;
490
491 debug!(
492 "subscription update (add = {}) for name: {} - connection: {}",
493 add, dst, conn
494 );
495
496 if let Err(e) = self.forwarder().on_subscription_msg(
497 dst.clone(),
498 conn,
499 connection.is_local_connection(),
500 add,
501 ) {
502 return Err(DataPathError::SubscriptionError(e.to_string()));
503 }
504
505 match forward {
506 None => {
507 Ok(())
509 }
510 Some(out_conn) => {
511 debug!("forward subscription (add = {}) to {}", add, out_conn);
512
513 let source = msg.get_source();
515
516 match self.send_msg(msg, out_conn).await {
518 Ok(_) => {
519 self.forwarder()
520 .on_forwarded_subscription(source, dst, out_conn, add);
521 Ok(())
522 }
523 Err(e) => Err(DataPathError::UnsubscriptionError(e.to_string())),
524 }
525 }
526 }
527 }
528
529 pub async fn process_message(
530 &self,
531 msg: Message,
532 in_connection: u64,
533 ) -> Result<(), DataPathError> {
534 match msg.get_type() {
536 SubscribeType(_) => self.process_subscription(msg, in_connection, true).await,
537 UnsubscribeType(_) => self.process_subscription(msg, in_connection, false).await,
538 PublishType(_) => self.process_publish(msg, in_connection).await,
539 }
540 }
541
542 async fn handle_new_message(
543 &self,
544 conn_index: u64,
545 is_local: bool,
546 mut msg: Message,
547 ) -> Result<(), DataPathError> {
548 debug!(%conn_index, "received message from connection");
549 info!(
550 telemetry = true,
551 monotonic_counter.num_processed_messages = 1
552 );
553
554 if let Err(err) = msg.validate() {
556 info!(
557 telemetry = true,
558 monotonic_counter.num_messages_by_type = 1,
559 message_type = "none"
560 );
561
562 return Err(DataPathError::InvalidMessage(err.to_string()));
563 }
564
565 msg.set_incoming_conn(Some(conn_index));
567
568 if is_local {
572 let span = create_span("process_local", conn_index, &msg);
574
575 let _guard = span.enter();
576
577 inject_current_context(&mut msg);
578 } else {
579 let parent_context = extract_parent_context(&msg);
581
582 let span = create_span("process_local", conn_index, &msg);
583
584 if let Some(ctx) = parent_context {
585 span.set_parent(ctx);
586 }
587 let _guard = span.enter();
588
589 inject_current_context(&mut msg);
590 }
591 match self.process_message(msg, conn_index).await {
594 Ok(_) => Ok(()),
595 Err(e) => {
596 info!(
598 telemetry = true,
599 monotonic_counter.num_message_process_errors = 1
600 );
601 Err(DataPathError::ProcessingError(e.to_string()))
605 }
606 }
607 }
608
609 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
610 let connection = self.forwarder().get_connection(conn_index);
611 match connection {
612 Some(conn) => {
613 debug!("try to notify the error to the local application");
614 if let Channel::Server(tx) = conn.channel() {
615 let status = Status::new(
617 tonic::Code::Internal,
618 format!("error processing message: {:?}", err),
619 );
620
621 if tx.send(Err(status)).await.is_err() {
622 debug!("unable to notify the error to the local app: {:?}", err);
623 }
624 }
625 }
626 None => {
627 error!(
628 "error sending error to local app: connection {:?} not found",
629 conn_index
630 );
631 }
632 }
633 }
634
635 async fn reconnect(
636 &self,
637 client_conf: Option<ClientConfig>,
638 conn_index: u64,
639 cancellation_token: &CancellationToken,
640 ) -> bool {
641 let config = client_conf.unwrap();
642 match config.to_channel() {
643 Err(e) => {
644 error!(
645 "cannot parse connection config, unable to reconnect {:?}",
646 e.to_string()
647 );
648 false
649 }
650 Ok(channel) => {
651 info!("connection lost with remote endpoint, try to reconnect");
652 let remote_subscriptions = self
657 .forwarder()
658 .get_subscriptions_forwarded_on_connection(conn_index);
659
660 tokio::select! {
661 _ = cancellation_token.cancelled() => {
662 debug!("cancellation token signaled, stopping reconnection process");
663 false
664 }
665 _ = self.get_drain_watch().signaled() => {
666 debug!("drain watch signaled, stopping reconnection process");
667 false
668 }
669 res = self.try_to_connect(channel, Some(config), None, None, Some(conn_index), 120) => {
670 match res {
671 Ok(_) => {
672 info!("connection re-established");
673 for r in remote_subscriptions.iter() {
675 let sub_msg = Message::new_subscribe(
676 r.source(),
677 r.name(),
678 None,
679 );
680 if self.send_msg(sub_msg, conn_index).await.is_err() {
681 error!("error restoring subscription on remote node");
682 }
683 }
684 true
685 }
686 Err(e) => {
687 error!("unable to connect to remote node {:?}", e.to_string());
689 false
690 }
691 }
692 }
693 }
694 }
695 }
696 }
697
698 fn process_stream(
699 &self,
700 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
701 conn_index: u64,
702 client_config: Option<ClientConfig>,
703 cancellation_token: CancellationToken,
704 is_local: bool,
705 from_control_plane: bool,
706 ) -> tokio::task::JoinHandle<()> {
707 let self_clone = self.clone();
709 let token_clone = cancellation_token.clone();
710 let client_conf_clone = client_config.clone();
711 let tx_cp: Option<Sender<Result<Message, Status>>> = self.get_tx_control_plane();
712
713 tokio::spawn(async move {
714 let mut try_to_reconnect = true;
715 loop {
716 tokio::select! {
717 next = stream.next() => {
718 match next {
719 Some(result) => {
720 match result {
721 Ok(msg) => {
722 if !is_local && !from_control_plane && tx_cp.is_some(){
728 match msg.get_type() {
729 PublishType(_) => {}
730 _ => {
731 let _ = tx_cp.as_ref().unwrap().send(Ok(msg.clone())).await;
734 }
735 }
736 }
737
738 if let Err(e) = self_clone.handle_new_message(conn_index, is_local, msg).await {
739 error!(%conn_index, %e, "error processing incoming message");
740 if is_local {
742 self_clone.send_error_to_local_app(conn_index, e).await;
744 }
745 }
746 }
747 Err(e) => {
748 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
749 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
750 info!(%conn_index, "connection closed by peer");
751 }
752 } else {
753 error!("error receiving messages {:?}", e);
754 }
755 break;
756 }
757 }
758 }
759 None => {
760 debug!(%conn_index, "end of stream");
761 break;
762 }
763 }
764 }
765 _ = self_clone.get_drain_watch().signaled() => {
766 debug!("shutting down stream on drain: {}", conn_index);
767 try_to_reconnect = false;
768 break;
769 }
770 _ = token_clone.cancelled() => {
771 debug!("shutting down stream on cancellation token: {}", conn_index);
772 try_to_reconnect = false;
773 break;
774 }
775 }
776 }
777
778 drop(stream);
782
783 let mut connected = false;
784
785 if try_to_reconnect && client_conf_clone.is_some() {
786 connected = self_clone
787 .reconnect(client_conf_clone, conn_index, &token_clone)
788 .await;
789 } else {
790 debug!("close connection {}", conn_index)
791 }
792
793 if !connected {
794 self_clone
796 .forwarder()
797 .on_connection_drop(conn_index, is_local);
798
799 info!(telemetry = true, counter.num_active_connections = -1);
800 }
801 })
802 }
803
804 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
805 let mut err: &(dyn std::error::Error + 'static) = err_status;
806
807 loop {
808 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
809 return Some(io_err);
810 }
811
812 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
815 if let Some(io_err) = h2_err.get_io() {
816 return Some(io_err);
817 }
818 }
819
820 err = err.source()?;
821 }
822 }
823
824 pub fn subscription_table(&self) -> &SubscriptionTableImpl {
825 &self.internal.forwarder.subscription_table
826 }
827
828 pub fn connection_table(&self) -> &ConnectionTable<Connection> {
829 &self.internal.forwarder.connection_table
830 }
831}
832
833#[tonic::async_trait]
834impl PubSubService for MessageProcessor {
835 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
836
837 async fn open_channel(
838 &self,
839 request: Request<tonic::Streaming<Message>>,
840 ) -> Result<Response<Self::OpenChannelStream>, Status> {
841 let remote_addr = request.remote_addr();
842 let local_addr = request.local_addr();
843
844 let stream = request.into_inner();
845 let (tx, rx) = mpsc::channel(128);
846
847 let connection = Connection::new(ConnectionType::Remote)
848 .with_remote_addr(remote_addr)
849 .with_local_addr(local_addr)
850 .with_channel(Channel::Server(tx));
851
852 debug!(
853 "new connection received from remote: (remote: {:?} - local: {:?})",
854 connection.remote_addr(),
855 connection.local_addr()
856 );
857 info!(telemetry = true, counter.num_active_connections = 1);
858
859 let conn_index = self
861 .forwarder()
862 .on_connection_established(connection, None)
863 .unwrap();
864
865 self.process_stream(
866 stream,
867 conn_index,
868 None,
869 CancellationToken::new(),
870 false,
871 false,
872 );
873
874 let out_stream = ReceiverStream::new(rx);
875 Ok(Response::new(
876 Box::pin(out_stream) as Self::OpenChannelStream
877 ))
878 }
879}