1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use opentelemetry::propagation::{Extractor, Injector};
8use opentelemetry::trace::TraceContextExt;
9use slim_config::grpc::client::ClientConfig;
10use slim_tracing::utils::INSTANCE_ID;
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::{Stream, StreamExt};
14use tokio_util::sync::CancellationToken;
15use tonic::codegen::{Body, StdError};
16use tonic::{Request, Response, Status};
17use tracing::{Span, debug, error, info};
18use tracing_opentelemetry::OpenTelemetrySpanExt;
19
20use crate::api::ProtoPublishType as PublishType;
21use crate::api::ProtoSubscribeType as SubscribeType;
22use crate::api::ProtoUnsubscribeType as UnsubscribeType;
23use crate::api::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
24use crate::api::proto::pubsub::v1::{Message, pub_sub_service_server::PubSubService};
25use crate::connection::{Channel, Connection, Type as ConnectionType};
26use crate::errors::DataPathError;
27use crate::forwarder::Forwarder;
28use crate::messages::Name;
29use crate::tables::connection_table::ConnectionTable;
30use crate::tables::subscription_table::SubscriptionTableImpl;
31
32struct MetadataExtractor<'a>(&'a std::collections::HashMap<String, String>);
34
35impl Extractor for MetadataExtractor<'_> {
36 fn get(&self, key: &str) -> Option<&str> {
37 self.0.get(key).map(|s| s.as_str())
38 }
39
40 fn keys(&self) -> Vec<&str> {
41 self.0.keys().map(|s| s.as_str()).collect()
42 }
43}
44
45struct MetadataInjector<'a>(&'a mut std::collections::HashMap<String, String>);
46
47impl Injector for MetadataInjector<'_> {
48 fn set(&mut self, key: &str, value: String) {
49 self.0.insert(key.to_string(), value);
50 }
51}
52
53fn extract_parent_context(msg: &Message) -> Option<opentelemetry::Context> {
55 let extractor = MetadataExtractor(&msg.metadata);
56 let parent_context =
57 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
58
59 if parent_context.span().span_context().is_valid() {
60 Some(parent_context)
61 } else {
62 None
63 }
64}
65
66fn inject_current_context(msg: &mut Message) {
68 let cx = tracing::Span::current().context();
69 let mut injector = MetadataInjector(&mut msg.metadata);
70 opentelemetry::global::get_text_map_propagator(|propagator| {
71 propagator.inject_context(&cx, &mut injector)
72 });
73}
74
75fn create_span(function: &str, out_conn: u64, msg: &Message) -> Span {
77 let span = tracing::span!(
78 tracing::Level::INFO,
79 "slim_process_message",
80 function = function,
81 source = format!("{}", msg.get_source()),
82 destination = format!("{}", msg.get_dst()),
83 instance_id = %INSTANCE_ID.as_str(),
84 connection_id = out_conn,
85 message_type = msg.get_type().to_string(),
86 telemetry = true
87 );
88
89 if let PublishType(_) = msg.get_type() {
90 span.set_attribute("session_type", msg.get_session_message_type().as_str_name());
91 span.set_attribute(
92 "session_id",
93 msg.get_session_header().get_session_id().to_string(),
94 );
95 span.set_attribute(
96 "message_id",
97 msg.get_session_header().get_message_id().to_string(),
98 );
99 }
100
101 span
102}
103
104#[derive(Debug)]
105struct MessageProcessorInternal {
106 forwarder: Forwarder<Connection>,
107 drain_channel: drain::Watch,
108}
109
110#[derive(Debug, Clone)]
111pub struct MessageProcessor {
112 internal: Arc<MessageProcessorInternal>,
113}
114
115impl MessageProcessor {
116 pub fn new() -> (Self, drain::Signal) {
117 let (signal, watch) = drain::channel();
118 let forwarder = Forwarder::new();
119 let forwarder = MessageProcessorInternal {
120 forwarder,
121 drain_channel: watch,
122 };
123
124 (
125 Self {
126 internal: Arc::new(forwarder),
127 },
128 signal,
129 )
130 }
131
132 pub fn with_drain_channel(watch: drain::Watch) -> Self {
133 let forwarder = Forwarder::new();
134 let forwarder = MessageProcessorInternal {
135 forwarder,
136 drain_channel: watch,
137 };
138 Self {
139 internal: Arc::new(forwarder),
140 }
141 }
142
143 fn forwarder(&self) -> &Forwarder<Connection> {
144 &self.internal.forwarder
145 }
146
147 fn get_drain_watch(&self) -> drain::Watch {
148 self.internal.drain_channel.clone()
149 }
150
151 async fn try_to_connect<C>(
152 &self,
153 channel: C,
154 client_config: Option<ClientConfig>,
155 local: Option<SocketAddr>,
156 remote: Option<SocketAddr>,
157 existing_conn_index: Option<u64>,
158 max_retry: u32,
159 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
160 where
161 C: tonic::client::GrpcService<tonic::body::Body>,
162 C::Error: Into<StdError>,
163 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
164 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
165 {
166 let mut client: PubSubServiceClient<C> = PubSubServiceClient::new(channel);
167 let mut i = 0;
168 while i < max_retry {
169 let (tx, rx) = mpsc::channel(128);
170 match client
171 .open_channel(Request::new(ReceiverStream::new(rx)))
172 .await
173 {
174 Ok(stream) => {
175 let cancellation_token = CancellationToken::new();
176 let connection = Connection::new(ConnectionType::Remote)
177 .with_local_addr(local)
178 .with_remote_addr(remote)
179 .with_channel(Channel::Client(tx))
180 .with_config_data(client_config.clone())
181 .with_cancellation_token(Some(cancellation_token.clone()));
182
183 debug!(
184 "new connection initiated locally: (remote: {:?} - local: {:?})",
185 connection.remote_addr(),
186 connection.local_addr()
187 );
188
189 let opt = self
191 .forwarder()
192 .on_connection_established(connection, existing_conn_index);
193 if opt.is_none() {
194 error!("error adding connection to the connection table");
195 return Err(DataPathError::ConnectionError(
196 "error adding connection to the connection tables".to_string(),
197 ));
198 }
199
200 let conn_index = opt.unwrap();
201 debug!(
202 "new connection index = {:?}, is local {:?}",
203 conn_index, false
204 );
205
206 let ret = self.process_stream(
208 stream.into_inner(),
209 conn_index,
210 client_config,
211 cancellation_token,
212 false,
213 );
214 return Ok((ret, conn_index));
215 }
216 Err(e) => {
217 error!("connection error: {:?}.", e.to_string());
218 }
219 }
220 i += 1;
221
222 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
224 }
225
226 error!("unable to connect to the endpoint");
227 Err(DataPathError::ConnectionError(
228 "reached max connection retries".to_string(),
229 ))
230 }
231
232 pub async fn connect<C>(
233 &self,
234 channel: C,
235 client_config: Option<ClientConfig>,
236 local: Option<SocketAddr>,
237 remote: Option<SocketAddr>,
238 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
239 where
240 C: tonic::client::GrpcService<tonic::body::Body>,
241 C::Error: Into<StdError>,
242 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
243 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
244 {
245 self.try_to_connect(channel, client_config, local, remote, None, 10)
246 .await
247 }
248
249 pub fn disconnect(&self, conn: u64) -> Result<(), DataPathError> {
250 match self.forwarder().get_connection(conn) {
251 None => {
252 error!("error handling disconnect: connection unknown");
253 return Err(DataPathError::DisconnectionError(
254 "connection not found".to_string(),
255 ));
256 }
257 Some(c) => {
258 match c.cancellation_token() {
259 None => {
260 error!("error handling disconnect: missing cancellation token");
261 }
262 Some(t) => {
263 t.cancel();
267 }
268 }
269 }
270 }
271
272 Ok(())
273 }
274
275 pub fn register_local_connection(
276 &self,
277 ) -> (
278 u64,
279 tokio::sync::mpsc::Sender<Result<Message, Status>>,
280 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
281 ) {
282 let (tx1, rx1) = mpsc::channel(128);
284
285 debug!("establishing new local app connection");
286
287 let (tx2, rx2) = mpsc::channel(128);
289
290 let cancellation_token = CancellationToken::new();
292 let connection = Connection::new(ConnectionType::Local)
293 .with_channel(Channel::Server(tx2))
294 .with_cancellation_token(Some(cancellation_token.clone()));
295
296 let conn_id = self
298 .forwarder()
299 .on_connection_established(connection, None)
300 .unwrap();
301
302 debug!("local connection established with id: {:?}", conn_id);
303 info!(telemetry = true, counter.num_active_connections = 1);
304
305 self.process_stream(
307 ReceiverStream::new(rx1),
308 conn_id,
309 None,
310 cancellation_token,
311 true,
312 );
313
314 (conn_id, tx1, rx2)
316 }
317
318 pub async fn send_msg(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
319 let connection = self.forwarder().get_connection(out_conn);
320 match connection {
321 Some(conn) => {
322 msg.clear_slim_header();
324
325 let parent_context = extract_parent_context(&msg);
327 let span = create_span("send_message", out_conn, &msg);
328
329 if let Some(ctx) = parent_context {
330 span.set_parent(ctx);
331 }
332 let _guard = span.enter();
333 inject_current_context(&mut msg);
334 match conn.channel() {
337 Channel::Server(s) => s
338 .send(Ok(msg))
339 .await
340 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
341 Channel::Client(s) => s
342 .send(msg)
343 .await
344 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
345 _ => Err(DataPathError::MessageSendError(
346 "connection not found".to_string(),
347 )),
348 }
349 }
350 None => Err(DataPathError::MessageSendError(format!(
351 "connection {:?} not found",
352 out_conn
353 ))),
354 }
355 }
356
357 async fn match_and_forward_msg(
358 &self,
359 msg: Message,
360 name: Name,
361 in_connection: u64,
362 fanout: u32,
363 ) -> Result<(), DataPathError> {
364 debug!(
365 "match and forward message: name: {} - fanout: {}",
366 name, fanout,
367 );
368
369 if let Some(val) = msg.get_forward_to() {
372 debug!("forwarding message to connection {}", val);
373 return self
374 .send_msg(msg, val)
375 .await
376 .map_err(|e| DataPathError::PublicationError(e.to_string()));
377 }
378
379 match self
380 .forwarder()
381 .on_publish_msg_match(name, in_connection, fanout)
382 {
383 Ok(out_vec) => {
384 let mut i = 0;
387 while i < out_vec.len() - 1 {
388 self.send_msg(msg.clone(), out_vec[i])
389 .await
390 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
391 i += 1;
392 }
393 self.send_msg(msg, out_vec[i])
394 .await
395 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
396 Ok(())
397 }
398 Err(e) => Err(DataPathError::PublicationError(e.to_string())),
399 }
400 }
401
402 async fn process_publish(&self, msg: Message, in_connection: u64) -> Result<(), DataPathError> {
403 debug!(
404 "received publication from connection {}: {:?}",
405 in_connection, msg
406 );
407
408 info!(
410 telemetry = true,
411 monotonic_counter.num_messages_by_type = 1,
412 method = "publish"
413 );
414 let header = msg.get_slim_header();
418
419 let dst = header.get_dst();
420
421 let fanout = msg.get_fanout();
424
425 return self
427 .match_and_forward_msg(msg, dst, in_connection, fanout)
428 .await;
429 }
430
431 async fn process_subscription(
435 &self,
436 msg: Message,
437 in_connection: u64,
438 add: bool,
439 ) -> Result<(), DataPathError> {
440 debug!(
441 "received subscription (add = {}) from connection {}: {:?}",
442 add, in_connection, msg
443 );
444
445 info!(
447 telemetry = true,
448 monotonic_counter.num_messages_by_type = 1,
449 message_type = { if add { "subscribe" } else { "unsubscribe" } }
450 );
451 let dst = msg.get_dst();
454
455 let header = msg.get_slim_header();
457
458 let (conn, forward) = header.get_in_out_connections();
460
461 let connection = self
464 .forwarder()
465 .get_connection(conn)
466 .ok_or_else(|| DataPathError::SubscriptionError("connection not found".to_string()))?;
467
468 debug!(
469 "subscription update (add = {}) for name: {} - connection: {}",
470 add, dst, conn
471 );
472
473 if let Err(e) = self.forwarder().on_subscription_msg(
474 dst.clone(),
475 conn,
476 connection.is_local_connection(),
477 add,
478 ) {
479 return Err(DataPathError::SubscriptionError(e.to_string()));
480 }
481
482 match forward {
483 None => {
484 Ok(())
486 }
487 Some(out_conn) => {
488 debug!("forward subscription (add = {}) to {}", add, out_conn);
489
490 let source = msg.get_source();
492
493 match self.send_msg(msg, out_conn).await {
495 Ok(_) => {
496 self.forwarder()
497 .on_forwarded_subscription(source, dst, out_conn, add);
498 Ok(())
499 }
500 Err(e) => Err(DataPathError::UnsubscriptionError(e.to_string())),
501 }
502 }
503 }
504 }
505
506 pub async fn process_message(
507 &self,
508 msg: Message,
509 in_connection: u64,
510 ) -> Result<(), DataPathError> {
511 match msg.get_type() {
513 SubscribeType(_) => self.process_subscription(msg, in_connection, true).await,
514 UnsubscribeType(_) => self.process_subscription(msg, in_connection, false).await,
515 PublishType(_) => self.process_publish(msg, in_connection).await,
516 }
517 }
518
519 async fn handle_new_message(
520 &self,
521 conn_index: u64,
522 is_local: bool,
523 mut msg: Message,
524 ) -> Result<(), DataPathError> {
525 debug!(%conn_index, "received message from connection");
526 info!(
527 telemetry = true,
528 monotonic_counter.num_processed_messages = 1
529 );
530
531 if let Err(err) = msg.validate() {
533 info!(
534 telemetry = true,
535 monotonic_counter.num_messages_by_type = 1,
536 message_type = "none"
537 );
538
539 return Err(DataPathError::InvalidMessage(err.to_string()));
540 }
541
542 msg.set_incoming_conn(Some(conn_index));
544
545 if is_local {
549 let span = create_span("process_local", conn_index, &msg);
551
552 let _guard = span.enter();
553
554 inject_current_context(&mut msg);
555 } else {
556 let parent_context = extract_parent_context(&msg);
558
559 let span = create_span("process_local", conn_index, &msg);
560
561 if let Some(ctx) = parent_context {
562 span.set_parent(ctx);
563 }
564 let _guard = span.enter();
565
566 inject_current_context(&mut msg);
567 }
568 match self.process_message(msg, conn_index).await {
571 Ok(_) => Ok(()),
572 Err(e) => {
573 info!(
575 telemetry = true,
576 monotonic_counter.num_message_process_errors = 1
577 );
578 Err(DataPathError::ProcessingError(e.to_string()))
582 }
583 }
584 }
585
586 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
587 let connection = self.forwarder().get_connection(conn_index);
588 match connection {
589 Some(conn) => {
590 debug!("try to notify the error to the local application");
591 if let Channel::Server(tx) = conn.channel() {
592 let status = Status::new(
594 tonic::Code::Internal,
595 format!("error processing message: {:?}", err),
596 );
597
598 if tx.send(Err(status)).await.is_err() {
599 debug!("unable to notify the error to the local app: {:?}", err);
600 }
601 }
602 }
603 None => {
604 error!(
605 "error sending error to local app: connection {:?} not found",
606 conn_index
607 );
608 }
609 }
610 }
611
612 async fn reconnect(
613 &self,
614 client_conf: Option<ClientConfig>,
615 conn_index: u64,
616 cancellation_token: &CancellationToken,
617 ) -> bool {
618 let config = client_conf.unwrap();
619 match config.to_channel() {
620 Err(e) => {
621 error!(
622 "cannot parse connection config, unable to reconnect {:?}",
623 e.to_string()
624 );
625 false
626 }
627 Ok(channel) => {
628 info!("connection lost with remote endpoint, try to reconnect");
629 let remote_subscriptions = self
634 .forwarder()
635 .get_subscriptions_forwarded_on_connection(conn_index);
636
637 tokio::select! {
638 _ = cancellation_token.cancelled() => {
639 debug!("cancellation token signaled, stopping reconnection process");
640 false
641 }
642 _ = self.get_drain_watch().signaled() => {
643 debug!("drain watch signaled, stopping reconnection process");
644 false
645 }
646 res = self.try_to_connect(channel, Some(config), None, None, Some(conn_index), 120) => {
647 match res {
648 Ok(_) => {
649 info!("connection re-established");
650 for r in remote_subscriptions.iter() {
652 let sub_msg = Message::new_subscribe(
653 r.source(),
654 r.name(),
655 None,
656 );
657 if self.send_msg(sub_msg, conn_index).await.is_err() {
658 error!("error restoring subscription on remote node");
659 }
660 }
661 true
662 }
663 Err(e) => {
664 error!("unable to connect to remote node {:?}", e.to_string());
666 false
667 }
668 }
669 }
670 }
671 }
672 }
673 }
674
675 fn process_stream(
676 &self,
677 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
678 conn_index: u64,
679 client_config: Option<ClientConfig>,
680 cancellation_token: CancellationToken,
681 is_local: bool,
682 ) -> tokio::task::JoinHandle<()> {
683 let self_clone = self.clone();
685 let token_clone = cancellation_token.clone();
686 let client_conf_clone = client_config.clone();
687
688 tokio::spawn(async move {
689 let mut try_to_reconnect = true;
690 loop {
691 tokio::select! {
692 next = stream.next() => {
693 match next {
694 Some(result) => {
695 match result {
696 Ok(msg) => {
697 if let Err(e) = self_clone.handle_new_message(conn_index, is_local, msg).await {
698 error!(%conn_index, %e, "error processing incoming message");
699 if is_local {
701 self_clone.send_error_to_local_app(conn_index, e).await;
703 }
704 }
705 }
706 Err(e) => {
707 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
708 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
709 info!(%conn_index, "connection closed by peer");
710 }
711 } else {
712 error!("error receiving messages {:?}", e);
713 }
714 break;
715 }
716 }
717 }
718 None => {
719 debug!(%conn_index, "end of stream");
720 break;
721 }
722 }
723 }
724 _ = self_clone.get_drain_watch().signaled() => {
725 debug!("shutting down stream on drain: {}", conn_index);
726 try_to_reconnect = false;
727 break;
728 }
729 _ = token_clone.cancelled() => {
730 debug!("shutting down stream on cancellation token: {}", conn_index);
731 try_to_reconnect = false;
732 break;
733 }
734 }
735 }
736
737 drop(stream);
741
742 let mut connected = false;
743
744 if try_to_reconnect && client_conf_clone.is_some() {
745 connected = self_clone
746 .reconnect(client_conf_clone, conn_index, &token_clone)
747 .await;
748 } else {
749 debug!("close connection {}", conn_index)
750 }
751
752 if !connected {
753 self_clone
755 .forwarder()
756 .on_connection_drop(conn_index, is_local);
757
758 info!(telemetry = true, counter.num_active_connections = -1);
759 }
760 })
761 }
762
763 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
764 let mut err: &(dyn std::error::Error + 'static) = err_status;
765
766 loop {
767 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
768 return Some(io_err);
769 }
770
771 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
774 if let Some(io_err) = h2_err.get_io() {
775 return Some(io_err);
776 }
777 }
778
779 err = err.source()?;
780 }
781 }
782
783 pub fn subscription_table(&self) -> &SubscriptionTableImpl {
784 &self.internal.forwarder.subscription_table
785 }
786
787 pub fn connection_table(&self) -> &ConnectionTable<Connection> {
788 &self.internal.forwarder.connection_table
789 }
790}
791
792#[tonic::async_trait]
793impl PubSubService for MessageProcessor {
794 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
795
796 async fn open_channel(
797 &self,
798 request: Request<tonic::Streaming<Message>>,
799 ) -> Result<Response<Self::OpenChannelStream>, Status> {
800 let remote_addr = request.remote_addr();
801 let local_addr = request.local_addr();
802
803 let stream = request.into_inner();
804 let (tx, rx) = mpsc::channel(128);
805
806 let connection = Connection::new(ConnectionType::Remote)
807 .with_remote_addr(remote_addr)
808 .with_local_addr(local_addr)
809 .with_channel(Channel::Server(tx));
810
811 debug!(
812 "new connection received from remote: (remote: {:?} - local: {:?})",
813 connection.remote_addr(),
814 connection.local_addr()
815 );
816 info!(telemetry = true, counter.num_active_connections = 1);
817
818 let conn_index = self
820 .forwarder()
821 .on_connection_established(connection, None)
822 .unwrap();
823
824 self.process_stream(stream, conn_index, None, CancellationToken::new(), false);
825
826 let out_stream = ReceiverStream::new(rx);
827 Ok(Response::new(
828 Box::pin(out_stream) as Self::OpenChannelStream
829 ))
830 }
831}