1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use opentelemetry::propagation::{Extractor, Injector};
8use opentelemetry::trace::TraceContextExt;
9use parking_lot::RwLock;
10use slim_config::grpc::client::ClientConfig;
11use slim_tracing::utils::INSTANCE_ID;
12use tokio::sync::mpsc::{self, Sender};
13use tokio_stream::wrappers::ReceiverStream;
14use tokio_stream::{Stream, StreamExt};
15use tokio_util::sync::CancellationToken;
16use tonic::codegen::{Body, StdError};
17use tonic::{Request, Response, Status};
18use tracing::{Span, debug, error, info};
19use tracing_opentelemetry::OpenTelemetrySpanExt;
20
21use crate::api::ProtoPublishType as PublishType;
22use crate::api::ProtoSubscribeType as SubscribeType;
23use crate::api::ProtoUnsubscribeType as UnsubscribeType;
24use crate::api::proto::dataplane::v1::Message;
25use crate::api::proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient;
26use crate::api::proto::dataplane::v1::data_plane_service_server::DataPlaneService;
27use crate::connection::{Channel, Connection, Type as ConnectionType};
28use crate::errors::DataPathError;
29use crate::forwarder::Forwarder;
30use crate::messages::Name;
31use crate::tables::connection_table::ConnectionTable;
32use crate::tables::subscription_table::SubscriptionTableImpl;
33
34struct MetadataExtractor<'a>(&'a std::collections::HashMap<String, String>);
36
37impl Extractor for MetadataExtractor<'_> {
38 fn get(&self, key: &str) -> Option<&str> {
39 self.0.get(key).map(|s| s.as_str())
40 }
41
42 fn keys(&self) -> Vec<&str> {
43 self.0.keys().map(|s| s.as_str()).collect()
44 }
45}
46
47struct MetadataInjector<'a>(&'a mut std::collections::HashMap<String, String>);
48
49impl Injector for MetadataInjector<'_> {
50 fn set(&mut self, key: &str, value: String) {
51 self.0.insert(key.to_string(), value);
52 }
53}
54
55fn extract_parent_context(msg: &Message) -> Option<opentelemetry::Context> {
57 let extractor = MetadataExtractor(&msg.metadata);
58 let parent_context =
59 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
60
61 if parent_context.span().span_context().is_valid() {
62 Some(parent_context)
63 } else {
64 None
65 }
66}
67
68fn inject_current_context(msg: &mut Message) {
70 let cx = tracing::Span::current().context();
71 let mut injector = MetadataInjector(&mut msg.metadata);
72 opentelemetry::global::get_text_map_propagator(|propagator| {
73 propagator.inject_context(&cx, &mut injector)
74 });
75}
76
77fn create_span(function: &str, out_conn: u64, msg: &Message) -> Span {
79 let span = tracing::span!(
80 tracing::Level::INFO,
81 "slim_process_message",
82 function = function,
83 source = format!("{}", msg.get_source()),
84 destination = format!("{}", msg.get_dst()),
85 instance_id = %INSTANCE_ID.as_str(),
86 connection_id = out_conn,
87 message_type = msg.get_type().to_string(),
88 telemetry = true
89 );
90
91 if let PublishType(_) = msg.get_type() {
92 span.set_attribute("session_type", msg.get_session_message_type().as_str_name());
93 span.set_attribute(
94 "session_id",
95 msg.get_session_header().get_session_id().to_string(),
96 );
97 span.set_attribute(
98 "message_id",
99 msg.get_session_header().get_message_id().to_string(),
100 );
101 }
102
103 span
104}
105
106#[derive(Debug)]
107struct MessageProcessorInternal {
108 forwarder: Forwarder<Connection>,
109 drain_channel: drain::Watch,
110 tx_control_plane: RwLock<Option<Sender<Result<Message, Status>>>>,
111}
112
113#[derive(Debug, Clone)]
114pub struct MessageProcessor {
115 internal: Arc<MessageProcessorInternal>,
116}
117
118impl MessageProcessor {
119 pub fn new() -> (Self, drain::Signal) {
120 let (signal, watch) = drain::channel();
121 let forwarder = Forwarder::new();
122 let forwarder = MessageProcessorInternal {
123 forwarder,
124 drain_channel: watch,
125 tx_control_plane: RwLock::new(None),
126 };
127
128 (
129 Self {
130 internal: Arc::new(forwarder),
131 },
132 signal,
133 )
134 }
135
136 pub fn with_drain_channel(watch: drain::Watch) -> Self {
137 let forwarder = Forwarder::new();
138 let forwarder = MessageProcessorInternal {
139 forwarder,
140 drain_channel: watch,
141 tx_control_plane: RwLock::new(None),
142 };
143 Self {
144 internal: Arc::new(forwarder),
145 }
146 }
147
148 fn set_tx_control_plane(&self, tx: Sender<Result<Message, Status>>) {
149 let mut tx_guard = self.internal.tx_control_plane.write();
150 *tx_guard = Some(tx);
151 }
152
153 fn get_tx_control_plane(&self) -> Option<Sender<Result<Message, Status>>> {
154 let tx_guard = self.internal.tx_control_plane.read();
155 tx_guard.clone()
156 }
157
158 fn forwarder(&self) -> &Forwarder<Connection> {
159 &self.internal.forwarder
160 }
161
162 fn get_drain_watch(&self) -> drain::Watch {
163 self.internal.drain_channel.clone()
164 }
165
166 async fn try_to_connect<C>(
167 &self,
168 channel: C,
169 client_config: Option<ClientConfig>,
170 local: Option<SocketAddr>,
171 remote: Option<SocketAddr>,
172 existing_conn_index: Option<u64>,
173 max_retry: u32,
174 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
175 where
176 C: tonic::client::GrpcService<tonic::body::Body>,
177 C::Error: Into<StdError>,
178 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
179 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
180 {
181 let mut client: DataPlaneServiceClient<C> = DataPlaneServiceClient::new(channel);
182 let mut i = 0;
183 while i < max_retry {
184 let (tx, rx) = mpsc::channel(128);
185 match client
186 .open_channel(Request::new(ReceiverStream::new(rx)))
187 .await
188 {
189 Ok(stream) => {
190 let cancellation_token = CancellationToken::new();
191 let connection = Connection::new(ConnectionType::Remote)
192 .with_local_addr(local)
193 .with_remote_addr(remote)
194 .with_channel(Channel::Client(tx))
195 .with_config_data(client_config.clone())
196 .with_cancellation_token(Some(cancellation_token.clone()));
197
198 debug!(
199 "new connection initiated locally: (remote: {:?} - local: {:?})",
200 connection.remote_addr(),
201 connection.local_addr()
202 );
203
204 let opt = self
206 .forwarder()
207 .on_connection_established(connection, existing_conn_index);
208 if opt.is_none() {
209 error!("error adding connection to the connection table");
210 return Err(DataPathError::ConnectionError(
211 "error adding connection to the connection tables".to_string(),
212 ));
213 }
214
215 let conn_index = opt.unwrap();
216 debug!(
217 "new connection index = {:?}, is local {:?}",
218 conn_index, false
219 );
220
221 let ret = self.process_stream(
223 stream.into_inner(),
224 conn_index,
225 client_config,
226 cancellation_token,
227 false,
228 false,
229 );
230 return Ok((ret, conn_index));
231 }
232 Err(e) => {
233 error!("connection error: {:?}.", e.to_string());
234 }
235 }
236 i += 1;
237
238 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
240 }
241
242 error!("unable to connect to the endpoint");
243 Err(DataPathError::ConnectionError(
244 "reached max connection retries".to_string(),
245 ))
246 }
247
248 pub async fn connect<C>(
249 &self,
250 channel: C,
251 client_config: Option<ClientConfig>,
252 local: Option<SocketAddr>,
253 remote: Option<SocketAddr>,
254 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
255 where
256 C: tonic::client::GrpcService<tonic::body::Body>,
257 C::Error: Into<StdError>,
258 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
259 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
260 {
261 self.try_to_connect(channel, client_config, local, remote, None, 10)
262 .await
263 }
264
265 pub fn disconnect(&self, conn: u64) -> Result<(), DataPathError> {
266 match self.forwarder().get_connection(conn) {
267 None => {
268 error!("error handling disconnect: connection unknown");
269 return Err(DataPathError::DisconnectionError(
270 "connection not found".to_string(),
271 ));
272 }
273 Some(c) => {
274 match c.cancellation_token() {
275 None => {
276 error!("error handling disconnect: missing cancellation token");
277 }
278 Some(t) => {
279 t.cancel();
283 }
284 }
285 }
286 }
287
288 Ok(())
289 }
290
291 pub fn register_local_connection(
292 &self,
293 from_control_plane: bool,
294 ) -> (
295 u64,
296 tokio::sync::mpsc::Sender<Result<Message, Status>>,
297 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
298 ) {
299 let (tx1, rx1) = mpsc::channel(128);
301
302 debug!("establishing new local app connection");
303
304 let (tx2, rx2) = mpsc::channel(128);
306
307 if from_control_plane && self.get_tx_control_plane().is_none() {
310 self.set_tx_control_plane(tx2.clone());
311 }
312
313 let cancellation_token = CancellationToken::new();
315 let connection = Connection::new(ConnectionType::Local)
316 .with_channel(Channel::Server(tx2))
317 .with_cancellation_token(Some(cancellation_token.clone()));
318
319 let conn_id = self
321 .forwarder()
322 .on_connection_established(connection, None)
323 .unwrap();
324
325 debug!("local connection established with id: {:?}", conn_id);
326 info!(telemetry = true, counter.num_active_connections = 1);
327
328 self.process_stream(
330 ReceiverStream::new(rx1),
331 conn_id,
332 None,
333 cancellation_token,
334 true,
335 from_control_plane,
336 );
337
338 (conn_id, tx1, rx2)
340 }
341
342 pub async fn send_msg(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
343 let connection = self.forwarder().get_connection(out_conn);
344 match connection {
345 Some(conn) => {
346 msg.clear_slim_header();
348
349 let parent_context = extract_parent_context(&msg);
351 let span = create_span("send_message", out_conn, &msg);
352
353 if let Some(ctx) = parent_context {
354 span.set_parent(ctx);
355 }
356 let _guard = span.enter();
357 inject_current_context(&mut msg);
358 match conn.channel() {
361 Channel::Server(s) => s
362 .send(Ok(msg))
363 .await
364 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
365 Channel::Client(s) => s
366 .send(msg)
367 .await
368 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
369 _ => Err(DataPathError::MessageSendError(
370 "connection not found".to_string(),
371 )),
372 }
373 }
374 None => Err(DataPathError::MessageSendError(format!(
375 "connection {:?} not found",
376 out_conn
377 ))),
378 }
379 }
380
381 async fn match_and_forward_msg(
382 &self,
383 msg: Message,
384 name: Name,
385 in_connection: u64,
386 fanout: u32,
387 ) -> Result<(), DataPathError> {
388 debug!(
389 "match and forward message: name: {} - fanout: {}",
390 name, fanout,
391 );
392
393 if let Some(val) = msg.get_forward_to() {
396 debug!("forwarding message to connection {}", val);
397 return self
398 .send_msg(msg, val)
399 .await
400 .map_err(|e| DataPathError::PublicationError(e.to_string()));
401 }
402
403 match self
404 .forwarder()
405 .on_publish_msg_match(name, in_connection, fanout)
406 {
407 Ok(out_vec) => {
408 let mut i = 0;
411 while i < out_vec.len() - 1 {
412 self.send_msg(msg.clone(), out_vec[i])
413 .await
414 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
415 i += 1;
416 }
417 self.send_msg(msg, out_vec[i])
418 .await
419 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
420 Ok(())
421 }
422 Err(e) => Err(DataPathError::PublicationError(e.to_string())),
423 }
424 }
425
426 async fn process_publish(&self, msg: Message, in_connection: u64) -> Result<(), DataPathError> {
427 debug!(
428 "received publication from connection {}: {:?}",
429 in_connection, msg
430 );
431
432 info!(
434 telemetry = true,
435 monotonic_counter.num_messages_by_type = 1,
436 method = "publish"
437 );
438 let header = msg.get_slim_header();
442
443 let dst = header.get_dst();
444
445 let fanout = msg.get_fanout();
448
449 return self
451 .match_and_forward_msg(msg, dst, in_connection, fanout)
452 .await;
453 }
454
455 async fn process_subscription(
459 &self,
460 msg: Message,
461 in_connection: u64,
462 add: bool,
463 ) -> Result<(), DataPathError> {
464 debug!(
465 "received subscription (add = {}) from connection {}: {:?}",
466 add, in_connection, msg
467 );
468
469 info!(
471 telemetry = true,
472 monotonic_counter.num_messages_by_type = 1,
473 message_type = { if add { "subscribe" } else { "unsubscribe" } }
474 );
475 let dst = msg.get_dst();
478
479 let header = msg.get_slim_header();
481
482 let (conn, forward) = header.get_in_out_connections();
484
485 let connection = self
488 .forwarder()
489 .get_connection(conn)
490 .ok_or_else(|| DataPathError::SubscriptionError("connection not found".to_string()))?;
491
492 debug!(
493 "subscription update (add = {}) for name: {} - connection: {}",
494 add, dst, conn
495 );
496
497 if let Err(e) = self.forwarder().on_subscription_msg(
498 dst.clone(),
499 conn,
500 connection.is_local_connection(),
501 add,
502 ) {
503 return Err(DataPathError::SubscriptionError(e.to_string()));
504 }
505
506 match forward {
507 None => {
508 Ok(())
510 }
511 Some(out_conn) => {
512 debug!("forward subscription (add = {}) to {}", add, out_conn);
513
514 let source = msg.get_source();
516
517 match self.send_msg(msg, out_conn).await {
519 Ok(_) => {
520 self.forwarder()
521 .on_forwarded_subscription(source, dst, out_conn, add);
522 Ok(())
523 }
524 Err(e) => Err(DataPathError::UnsubscriptionError(e.to_string())),
525 }
526 }
527 }
528 }
529
530 pub async fn process_message(
531 &self,
532 msg: Message,
533 in_connection: u64,
534 ) -> Result<(), DataPathError> {
535 match msg.get_type() {
537 SubscribeType(_) => self.process_subscription(msg, in_connection, true).await,
538 UnsubscribeType(_) => self.process_subscription(msg, in_connection, false).await,
539 PublishType(_) => self.process_publish(msg, in_connection).await,
540 }
541 }
542
543 async fn handle_new_message(
544 &self,
545 conn_index: u64,
546 is_local: bool,
547 mut msg: Message,
548 ) -> Result<(), DataPathError> {
549 debug!(%conn_index, "received message from connection");
550 info!(
551 telemetry = true,
552 monotonic_counter.num_processed_messages = 1
553 );
554
555 if let Err(err) = msg.validate() {
557 info!(
558 telemetry = true,
559 monotonic_counter.num_messages_by_type = 1,
560 message_type = "none"
561 );
562
563 return Err(DataPathError::InvalidMessage(err.to_string()));
564 }
565
566 msg.set_incoming_conn(Some(conn_index));
568
569 if is_local {
573 let span = create_span("process_local", conn_index, &msg);
575
576 let _guard = span.enter();
577
578 inject_current_context(&mut msg);
579 } else {
580 let parent_context = extract_parent_context(&msg);
582
583 let span = create_span("process_local", conn_index, &msg);
584
585 if let Some(ctx) = parent_context {
586 span.set_parent(ctx);
587 }
588 let _guard = span.enter();
589
590 inject_current_context(&mut msg);
591 }
592 match self.process_message(msg, conn_index).await {
595 Ok(_) => Ok(()),
596 Err(e) => {
597 info!(
599 telemetry = true,
600 monotonic_counter.num_message_process_errors = 1
601 );
602 Err(DataPathError::ProcessingError(e.to_string()))
606 }
607 }
608 }
609
610 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
611 let connection = self.forwarder().get_connection(conn_index);
612 match connection {
613 Some(conn) => {
614 debug!("try to notify the error to the local application");
615 if let Channel::Server(tx) = conn.channel() {
616 let status = Status::new(
618 tonic::Code::Internal,
619 format!("error processing message: {:?}", err),
620 );
621
622 if tx.send(Err(status)).await.is_err() {
623 debug!("unable to notify the error to the local app: {:?}", err);
624 }
625 }
626 }
627 None => {
628 error!(
629 "error sending error to local app: connection {:?} not found",
630 conn_index
631 );
632 }
633 }
634 }
635
636 async fn reconnect(
637 &self,
638 client_conf: Option<ClientConfig>,
639 conn_index: u64,
640 cancellation_token: &CancellationToken,
641 ) -> bool {
642 let config = client_conf.unwrap();
643 match config.to_channel() {
644 Err(e) => {
645 error!(
646 "cannot parse connection config, unable to reconnect {:?}",
647 e.to_string()
648 );
649 false
650 }
651 Ok(channel) => {
652 info!("connection lost with remote endpoint, try to reconnect");
653 let remote_subscriptions = self
658 .forwarder()
659 .get_subscriptions_forwarded_on_connection(conn_index);
660
661 tokio::select! {
662 _ = cancellation_token.cancelled() => {
663 debug!("cancellation token signaled, stopping reconnection process");
664 false
665 }
666 _ = self.get_drain_watch().signaled() => {
667 debug!("drain watch signaled, stopping reconnection process");
668 false
669 }
670 res = self.try_to_connect(channel, Some(config), None, None, Some(conn_index), 120) => {
671 match res {
672 Ok(_) => {
673 info!("connection re-established");
674 for r in remote_subscriptions.iter() {
676 let sub_msg = Message::new_subscribe(
677 r.source(),
678 r.name(),
679 None,
680 );
681 if self.send_msg(sub_msg, conn_index).await.is_err() {
682 error!("error restoring subscription on remote node");
683 }
684 }
685 true
686 }
687 Err(e) => {
688 error!("unable to connect to remote node {:?}", e.to_string());
690 false
691 }
692 }
693 }
694 }
695 }
696 }
697 }
698
699 fn process_stream(
700 &self,
701 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
702 conn_index: u64,
703 client_config: Option<ClientConfig>,
704 cancellation_token: CancellationToken,
705 is_local: bool,
706 from_control_plane: bool,
707 ) -> tokio::task::JoinHandle<()> {
708 let self_clone = self.clone();
710 let token_clone = cancellation_token.clone();
711 let client_conf_clone = client_config.clone();
712 let tx_cp: Option<Sender<Result<Message, Status>>> = self.get_tx_control_plane();
713
714 tokio::spawn(async move {
715 let mut try_to_reconnect = true;
716 loop {
717 tokio::select! {
718 next = stream.next() => {
719 match next {
720 Some(result) => {
721 match result {
722 Ok(msg) => {
723 if !is_local && !from_control_plane && tx_cp.is_some(){
729 match msg.get_type() {
730 PublishType(_) => {}
731 _ => {
732 let _ = tx_cp.as_ref().unwrap().send(Ok(msg.clone())).await;
735 }
736 }
737 }
738
739 if let Err(e) = self_clone.handle_new_message(conn_index, is_local, msg).await {
740 error!(%conn_index, %e, "error processing incoming message");
741 if is_local {
743 self_clone.send_error_to_local_app(conn_index, e).await;
745 }
746 }
747 }
748 Err(e) => {
749 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
750 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
751 info!(%conn_index, "connection closed by peer");
752 }
753 } else {
754 error!("error receiving messages {:?}", e);
755 }
756 break;
757 }
758 }
759 }
760 None => {
761 debug!(%conn_index, "end of stream");
762 break;
763 }
764 }
765 }
766 _ = self_clone.get_drain_watch().signaled() => {
767 debug!("shutting down stream on drain: {}", conn_index);
768 try_to_reconnect = false;
769 break;
770 }
771 _ = token_clone.cancelled() => {
772 debug!("shutting down stream on cancellation token: {}", conn_index);
773 try_to_reconnect = false;
774 break;
775 }
776 }
777 }
778
779 drop(stream);
783
784 let mut connected = false;
785
786 if try_to_reconnect && client_conf_clone.is_some() {
787 connected = self_clone
788 .reconnect(client_conf_clone, conn_index, &token_clone)
789 .await;
790 } else {
791 debug!("close connection {}", conn_index)
792 }
793
794 if !connected {
795 self_clone
797 .forwarder()
798 .on_connection_drop(conn_index, is_local);
799
800 info!(telemetry = true, counter.num_active_connections = -1);
801 }
802 })
803 }
804
805 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
806 let mut err: &(dyn std::error::Error + 'static) = err_status;
807
808 loop {
809 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
810 return Some(io_err);
811 }
812
813 if let Some(h2_err) = err.downcast_ref::<h2::Error>()
816 && let Some(io_err) = h2_err.get_io()
817 {
818 return Some(io_err);
819 }
820
821 err = err.source()?;
822 }
823 }
824
825 pub fn subscription_table(&self) -> &SubscriptionTableImpl {
826 &self.internal.forwarder.subscription_table
827 }
828
829 pub fn connection_table(&self) -> &ConnectionTable<Connection> {
830 &self.internal.forwarder.connection_table
831 }
832}
833
834#[tonic::async_trait]
835impl DataPlaneService for MessageProcessor {
836 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
837
838 async fn open_channel(
839 &self,
840 request: Request<tonic::Streaming<Message>>,
841 ) -> Result<Response<Self::OpenChannelStream>, Status> {
842 let remote_addr = request.remote_addr();
843 let local_addr = request.local_addr();
844
845 let stream = request.into_inner();
846 let (tx, rx) = mpsc::channel(128);
847
848 let connection = Connection::new(ConnectionType::Remote)
849 .with_remote_addr(remote_addr)
850 .with_local_addr(local_addr)
851 .with_channel(Channel::Server(tx));
852
853 debug!(
854 "new connection received from remote: (remote: {:?} - local: {:?})",
855 connection.remote_addr(),
856 connection.local_addr()
857 );
858 info!(telemetry = true, counter.num_active_connections = 1);
859
860 let conn_index = self
862 .forwarder()
863 .on_connection_established(connection, None)
864 .unwrap();
865
866 self.process_stream(
867 stream,
868 conn_index,
869 None,
870 CancellationToken::new(),
871 false,
872 false,
873 );
874
875 let out_stream = ReceiverStream::new(rx);
876 Ok(Response::new(
877 Box::pin(out_stream) as Self::OpenChannelStream
878 ))
879 }
880}