1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use opentelemetry::propagation::{Extractor, Injector};
8use opentelemetry::trace::TraceContextExt;
9use slim_config::grpc::client::ClientConfig;
10use slim_tracing::utils::INSTANCE_ID;
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::{Stream, StreamExt};
14use tokio_util::sync::CancellationToken;
15use tonic::codegen::{Body, StdError};
16use tonic::{Request, Response, Status};
17use tracing::{Span, debug, error, info};
18use tracing_opentelemetry::OpenTelemetrySpanExt;
19
20use crate::api::proto::pubsub::v1::message::MessageType::Publish as PublishType;
21use crate::api::proto::pubsub::v1::message::MessageType::Subscribe as SubscribeType;
22use crate::api::proto::pubsub::v1::message::MessageType::Unsubscribe as UnsubscribeType;
23use crate::api::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
24use crate::api::proto::pubsub::v1::{Message, pub_sub_service_server::PubSubService};
25use crate::connection::{Channel, Connection, Type as ConnectionType};
26use crate::errors::DataPathError;
27use crate::forwarder::Forwarder;
28use crate::messages::AgentType;
29use crate::tables::connection_table::ConnectionTable;
30use crate::tables::subscription_table::SubscriptionTableImpl;
31
32struct MetadataExtractor<'a>(&'a std::collections::HashMap<String, String>);
34
35impl Extractor for MetadataExtractor<'_> {
36 fn get(&self, key: &str) -> Option<&str> {
37 self.0.get(key).map(|s| s.as_str())
38 }
39
40 fn keys(&self) -> Vec<&str> {
41 self.0.keys().map(|s| s.as_str()).collect()
42 }
43}
44
45struct MetadataInjector<'a>(&'a mut std::collections::HashMap<String, String>);
46
47impl Injector for MetadataInjector<'_> {
48 fn set(&mut self, key: &str, value: String) {
49 self.0.insert(key.to_string(), value);
50 }
51}
52
53fn extract_parent_context(msg: &Message) -> Option<opentelemetry::Context> {
55 let extractor = MetadataExtractor(&msg.metadata);
56 let parent_context =
57 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
58
59 if parent_context.span().span_context().is_valid() {
60 Some(parent_context)
61 } else {
62 None
63 }
64}
65
66fn inject_current_context(msg: &mut Message) {
68 let cx = tracing::Span::current().context();
69 let mut injector = MetadataInjector(&mut msg.metadata);
70 opentelemetry::global::get_text_map_propagator(|propagator| {
71 propagator.inject_context(&cx, &mut injector)
72 });
73}
74
75fn create_span(function: &str, out_conn: u64, msg: &Message) -> Span {
77 let span = tracing::span!(
78 tracing::Level::INFO,
79 "slim_process_message",
80 function = function,
81 source = format!("{}", msg.get_source()),
82 destination = format!("{}",msg.get_name_as_agent()),
83 instance_id = %INSTANCE_ID.as_str(),
84 connection_id = out_conn,
85 message_type = msg.get_type().to_string(),
86 telemetry = true
87 );
88
89 match msg.get_type() {
90 PublishType(_) => {
91 span.set_attribute("session_type", msg.get_header_type().as_str_name());
92 span.set_attribute(
93 "session_id",
94 msg.get_session_header().get_session_id().to_string(),
95 );
96 span.set_attribute(
97 "message_id",
98 msg.get_session_header().get_message_id().to_string(),
99 );
100 }
101 _ => {} }
103
104 span
105}
106
107#[derive(Debug)]
108struct MessageProcessorInternal {
109 forwarder: Forwarder<Connection>,
110 drain_channel: drain::Watch,
111}
112
113#[derive(Debug, Clone)]
114pub struct MessageProcessor {
115 internal: Arc<MessageProcessorInternal>,
116}
117
118impl MessageProcessor {
119 pub fn new() -> (Self, drain::Signal) {
120 let (signal, watch) = drain::channel();
121 let forwarder = Forwarder::new();
122 let forwarder = MessageProcessorInternal {
123 forwarder,
124 drain_channel: watch,
125 };
126
127 (
128 Self {
129 internal: Arc::new(forwarder),
130 },
131 signal,
132 )
133 }
134
135 pub fn with_drain_channel(watch: drain::Watch) -> Self {
136 let forwarder = Forwarder::new();
137 let forwarder = MessageProcessorInternal {
138 forwarder,
139 drain_channel: watch,
140 };
141 Self {
142 internal: Arc::new(forwarder),
143 }
144 }
145
146 fn forwarder(&self) -> &Forwarder<Connection> {
147 &self.internal.forwarder
148 }
149
150 fn get_drain_watch(&self) -> drain::Watch {
151 self.internal.drain_channel.clone()
152 }
153
154 async fn try_to_connect<C>(
155 &self,
156 channel: C,
157 client_config: Option<ClientConfig>,
158 local: Option<SocketAddr>,
159 remote: Option<SocketAddr>,
160 existing_conn_index: Option<u64>,
161 max_retry: u32,
162 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
163 where
164 C: tonic::client::GrpcService<tonic::body::Body>,
165 C::Error: Into<StdError>,
166 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
167 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
168 {
169 let mut client: PubSubServiceClient<C> = PubSubServiceClient::new(channel);
170 let mut i = 0;
171 while i < max_retry {
172 let (tx, rx) = mpsc::channel(128);
173 match client
174 .open_channel(Request::new(ReceiverStream::new(rx)))
175 .await
176 {
177 Ok(stream) => {
178 let cancellation_token = CancellationToken::new();
179 let connection = Connection::new(ConnectionType::Remote)
180 .with_local_addr(local)
181 .with_remote_addr(remote)
182 .with_channel(Channel::Client(tx))
183 .with_cancellation_token(Some(cancellation_token.clone()));
184
185 debug!(
186 "new connection initiated locally: (remote: {:?} - local: {:?})",
187 connection.remote_addr(),
188 connection.local_addr()
189 );
190
191 let opt = self
193 .forwarder()
194 .on_connection_established(connection, existing_conn_index);
195 if opt.is_none() {
196 error!("error adding connection to the connection table");
197 return Err(DataPathError::ConnectionError(
198 "error adding connection to the connection tables".to_string(),
199 ));
200 }
201
202 let conn_index = opt.unwrap();
203 debug!(
204 "new connection index = {:?}, is local {:?}",
205 conn_index, false
206 );
207
208 let ret = self.process_stream(
210 stream.into_inner(),
211 conn_index,
212 client_config,
213 cancellation_token,
214 false,
215 );
216 return Ok((ret, conn_index));
217 }
218 Err(e) => {
219 error!("connection error: {:?}.", e.to_string());
220 }
221 }
222 i += 1;
223
224 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
226 }
227
228 error!("unable to connect to the endpoint");
229 Err(DataPathError::ConnectionError(
230 "reached max connection retries".to_string(),
231 ))
232 }
233
234 pub async fn connect<C>(
235 &self,
236 channel: C,
237 client_config: Option<ClientConfig>,
238 local: Option<SocketAddr>,
239 remote: Option<SocketAddr>,
240 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
241 where
242 C: tonic::client::GrpcService<tonic::body::Body>,
243 C::Error: Into<StdError>,
244 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
245 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
246 {
247 self.try_to_connect(channel, client_config, local, remote, None, 10)
248 .await
249 }
250
251 pub fn disconnect(&self, conn: u64) -> Result<(), DataPathError> {
252 match self.forwarder().get_connection(conn) {
253 None => {
254 error!("error handling disconnect: connection unknown");
255 return Err(DataPathError::DisconnectionError(
256 "connection not found".to_string(),
257 ));
258 }
259 Some(c) => {
260 match c.cancellation_token() {
261 None => {
262 error!("error handling disconnect: missing cancellation token");
263 }
264 Some(t) => {
265 t.cancel();
269 }
270 }
271 }
272 }
273
274 Ok(())
275 }
276
277 pub fn register_local_connection(
278 &self,
279 ) -> (
280 u64,
281 tokio::sync::mpsc::Sender<Result<Message, Status>>,
282 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
283 ) {
284 let (tx1, rx1) = mpsc::channel(128);
286
287 debug!("establishing new local app connection");
288
289 let (tx2, rx2) = mpsc::channel(128);
291
292 let cancellation_token = CancellationToken::new();
294 let connection = Connection::new(ConnectionType::Local)
295 .with_channel(Channel::Server(tx2))
296 .with_cancellation_token(Some(cancellation_token.clone()));
297
298 let conn_id = self
300 .forwarder()
301 .on_connection_established(connection, None)
302 .unwrap();
303
304 debug!("local connection established with id: {:?}", conn_id);
305 info!(telemetry = true, counter.num_active_connections = 1);
306
307 self.process_stream(
309 ReceiverStream::new(rx1),
310 conn_id,
311 None,
312 cancellation_token,
313 true,
314 );
315
316 (conn_id, tx1, rx2)
318 }
319
320 pub async fn send_msg(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
321 let connection = self.forwarder().get_connection(out_conn);
322 match connection {
323 Some(conn) => {
324 msg.clear_slim_header();
326
327 let parent_context = extract_parent_context(&msg);
329 let span = create_span("send_message", out_conn, &msg);
330
331 if let Some(ctx) = parent_context {
332 span.set_parent(ctx);
333 }
334 let _guard = span.enter();
335 inject_current_context(&mut msg);
336 match conn.channel() {
339 Channel::Server(s) => s
340 .send(Ok(msg))
341 .await
342 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
343 Channel::Client(s) => s
344 .send(msg)
345 .await
346 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
347 _ => Err(DataPathError::MessageSendError(
348 "connection not found".to_string(),
349 )),
350 }
351 }
352 None => Err(DataPathError::MessageSendError(format!(
353 "connection {:?} not found",
354 out_conn
355 ))),
356 }
357 }
358
359 async fn match_and_forward_msg(
360 &self,
361 msg: Message,
362 agent_type: AgentType,
363 in_connection: u64,
364 fanout: u32,
365 agent_id: Option<u64>,
366 ) -> Result<(), DataPathError> {
367 debug!(
368 "match and forward message: type: {} - agent_id: {:?} - fanout: {}",
369 agent_type, agent_id, fanout,
370 );
371
372 if let Some(val) = msg.get_forward_to() {
375 debug!("forwarding message to connection {}", val);
376 return self
377 .send_msg(msg, val)
378 .await
379 .map_err(|e| DataPathError::PublicationError(e.to_string()));
380 }
381
382 match self
383 .forwarder()
384 .on_publish_msg_match(agent_type, agent_id, in_connection, fanout)
385 {
386 Ok(out_vec) => {
387 let mut i = 0;
390 while i < out_vec.len() - 1 {
391 self.send_msg(msg.clone(), out_vec[i])
392 .await
393 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
394 i += 1;
395 }
396 self.send_msg(msg, out_vec[i])
397 .await
398 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
399 Ok(())
400 }
401 Err(e) => Err(DataPathError::PublicationError(e.to_string())),
402 }
403 }
404
405 async fn process_publish(&self, msg: Message, in_connection: u64) -> Result<(), DataPathError> {
406 debug!(
407 "received publication from connection {}: {:?}",
408 in_connection, msg
409 );
410
411 info!(
413 telemetry = true,
414 monotonic_counter.num_messages_by_type = 1,
415 method = "publish"
416 );
417 let header = msg.get_slim_header();
421
422 let (agent_type, agent_id) = header.get_dst();
423
424 let fanout = msg.get_fanout();
427
428 return self
430 .match_and_forward_msg(msg, agent_type, in_connection, fanout, agent_id)
431 .await;
432 }
433
434 async fn process_subscription(
438 &self,
439 msg: Message,
440 in_connection: u64,
441 add: bool,
442 ) -> Result<(), DataPathError> {
443 debug!(
444 "received subscription (add = {}) from connection {}: {:?}",
445 add, in_connection, msg
446 );
447
448 info!(
450 telemetry = true,
451 monotonic_counter.num_messages_by_type = 1,
452 message_type = { if add { "subscribe" } else { "unsubscribe" } }
453 );
454 let (agent_type, agent_id) = msg.get_name();
457
458 let header = msg.get_slim_header();
460
461 let (conn, forward) = header.get_in_out_connections();
463
464 let connection = self
467 .forwarder()
468 .get_connection(conn)
469 .ok_or_else(|| DataPathError::SubscriptionError("connection not found".to_string()))?;
470
471 debug!(
472 "subscription update (add = {}) for agent type: {} (agent id: {:?}) - connection: {}",
473 add, agent_type, agent_id, conn
474 );
475
476 if let Err(e) = self.forwarder().on_subscription_msg(
477 agent_type.clone(),
478 agent_id,
479 conn,
480 connection.is_local_connection(),
481 add,
482 ) {
483 return Err(DataPathError::SubscriptionError(e.to_string()));
484 }
485
486 match forward {
487 None => {
488 Ok(())
490 }
491 Some(out_conn) => {
492 debug!("forward subscription (add = {}) to {}", add, out_conn);
493
494 let source_agent = msg.get_source();
496
497 match self.send_msg(msg, out_conn).await {
499 Ok(_) => {
500 self.forwarder().on_forwarded_subscription(
501 source_agent,
502 agent_type,
503 agent_id,
504 out_conn,
505 add,
506 );
507 Ok(())
508 }
509 Err(e) => Err(DataPathError::UnsubscriptionError(e.to_string())),
510 }
511 }
512 }
513 }
514
515 pub async fn process_message(
516 &self,
517 msg: Message,
518 in_connection: u64,
519 ) -> Result<(), DataPathError> {
520 match msg.get_type() {
522 SubscribeType(_) => self.process_subscription(msg, in_connection, true).await,
523 UnsubscribeType(_) => self.process_subscription(msg, in_connection, false).await,
524 PublishType(_) => self.process_publish(msg, in_connection).await,
525 }
526 }
527
528 async fn handle_new_message(
529 &self,
530 conn_index: u64,
531 is_local: bool,
532 mut msg: Message,
533 ) -> Result<(), DataPathError> {
534 debug!(%conn_index, "received message from connection");
535 info!(
536 telemetry = true,
537 monotonic_counter.num_processed_messages = 1
538 );
539
540 if let Err(err) = msg.validate() {
542 info!(
543 telemetry = true,
544 monotonic_counter.num_messages_by_type = 1,
545 message_type = "none"
546 );
547
548 return Err(DataPathError::InvalidMessage(err.to_string()));
549 }
550
551 msg.set_incoming_conn(Some(conn_index));
553
554 if is_local {
558 let span = create_span("process_local", conn_index, &msg);
560
561 let _guard = span.enter();
562
563 inject_current_context(&mut msg);
564 } else {
565 let parent_context = extract_parent_context(&msg);
567
568 let span = create_span("process_local", conn_index, &msg);
569
570 if let Some(ctx) = parent_context {
571 span.set_parent(ctx);
572 }
573 let _guard = span.enter();
574
575 inject_current_context(&mut msg);
576 }
577 match self.process_message(msg, conn_index).await {
580 Ok(_) => Ok(()),
581 Err(e) => {
582 info!(
584 telemetry = true,
585 monotonic_counter.num_message_process_errors = 1
586 );
587 Err(DataPathError::ProcessingError(e.to_string()))
591 }
592 }
593 }
594
595 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
596 let connection = self.forwarder().get_connection(conn_index);
597 match connection {
598 Some(conn) => {
599 debug!("try to notify the error to the local application");
600 if let Channel::Server(tx) = conn.channel() {
601 let status = Status::new(
603 tonic::Code::Internal,
604 format!("error processing message: {:?}", err),
605 );
606
607 if tx.send(Err(status)).await.is_err() {
608 debug!("unable to notify the error to the local app: {:?}", err);
609 }
610 }
611 }
612 None => {
613 error!(
614 "error sending error to local app: connection {:?} not found",
615 conn_index
616 );
617 }
618 }
619 }
620
621 async fn reconnect(
622 &self,
623 client_conf: Option<ClientConfig>,
624 conn_index: u64,
625 cancellation_token: &CancellationToken,
626 ) -> bool {
627 let config = client_conf.unwrap();
628 match config.to_channel() {
629 Err(e) => {
630 error!(
631 "cannot parse connection config, unable to reconnect {:?}",
632 e.to_string()
633 );
634 false
635 }
636 Ok(channel) => {
637 info!("connection lost with remote endpoint, try to reconnect");
638 let remote_subscriptions = self
643 .forwarder()
644 .get_subscriptions_forwarded_on_connection(conn_index);
645
646 tokio::select! {
647 _ = cancellation_token.cancelled() => {
648 debug!("cancellation token signaled, stopping reconnection process");
649 false
650 }
651 _ = self.get_drain_watch().signaled() => {
652 debug!("drain watch signaled, stopping reconnection process");
653 false
654 }
655 res = self.try_to_connect(channel, Some(config), None, None, Some(conn_index), 120) => {
656 match res {
657 Ok(_) => {
658 info!("connection re-established");
659 for r in remote_subscriptions.iter() {
661 let sub_msg = Message::new_subscribe(
662 r.source(),
663 r.name().agent_type(),
664 r.name().agent_id_option(),
665 None,
666 );
667 if self.send_msg(sub_msg, conn_index).await.is_err() {
668 error!("error restoring subscription on remote node");
669 }
670 }
671 true
672 }
673 Err(e) => {
674 error!("unable to connect to remote node {:?}", e.to_string());
676 false
677 }
678 }
679 }
680 }
681 }
682 }
683 }
684
685 fn process_stream(
686 &self,
687 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
688 conn_index: u64,
689 client_config: Option<ClientConfig>,
690 cancellation_token: CancellationToken,
691 is_local: bool,
692 ) -> tokio::task::JoinHandle<()> {
693 let self_clone = self.clone();
695 let token_clone = cancellation_token.clone();
696 let client_conf_clone = client_config.clone();
697 let handle = tokio::spawn(async move {
698 let mut try_to_reconnect = true;
699 loop {
700 tokio::select! {
701 next = stream.next() => {
702 match next {
703 Some(result) => {
704 match result {
705 Ok(msg) => {
706 if let Err(e) = self_clone.handle_new_message(conn_index, is_local, msg).await {
707 error!(%conn_index, %e, "error processing incoming message");
708 if is_local {
710 self_clone.send_error_to_local_app(conn_index, e).await;
712 }
713 }
714 }
715 Err(e) => {
716 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
717 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
718 info!("connection {:?} closed by peer", conn_index);
719 }
720 } else {
721 error!("error receiving messages {:?}", e);
722 }
723 break;
724 }
725 }
726 }
727 None => {
728 debug!(%conn_index, "end of stream");
729 break;
730 }
731 }
732 }
733 _ = self_clone.get_drain_watch().signaled() => {
734 debug!("shutting down stream on drain: {}", conn_index);
735 try_to_reconnect = false;
736 break;
737 }
738 _ = token_clone.cancelled() => {
739 debug!("shutting down stream on cancellation token: {}", conn_index);
740 try_to_reconnect = false;
741 break;
742 }
743 }
744 }
745
746 drop(stream);
750
751 let mut connected = false;
752
753 if try_to_reconnect && client_conf_clone.is_some() {
754 connected = self_clone
755 .reconnect(client_conf_clone, conn_index, &token_clone)
756 .await;
757 } else {
758 debug!("close connection {}", conn_index)
759 }
760
761 if !connected {
762 self_clone
764 .forwarder()
765 .on_connection_drop(conn_index, is_local);
766
767 info!(telemetry = true, counter.num_active_connections = -1);
768 }
769 });
770
771 handle
772 }
773
774 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
775 let mut err: &(dyn std::error::Error + 'static) = err_status;
776
777 loop {
778 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
779 return Some(io_err);
780 }
781
782 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
785 if let Some(io_err) = h2_err.get_io() {
786 return Some(io_err);
787 }
788 }
789
790 err = err.source()?;
791 }
792 }
793
794 pub fn subscription_table(&self) -> &SubscriptionTableImpl {
795 &self.internal.forwarder.subscription_table
796 }
797
798 pub fn connection_table(&self) -> &ConnectionTable<Connection> {
799 &self.internal.forwarder.connection_table
800 }
801}
802
803#[tonic::async_trait]
804impl PubSubService for MessageProcessor {
805 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
806
807 async fn open_channel(
808 &self,
809 request: Request<tonic::Streaming<Message>>,
810 ) -> Result<Response<Self::OpenChannelStream>, Status> {
811 let remote_addr = request.remote_addr();
812 let local_addr = request.local_addr();
813
814 let stream = request.into_inner();
815 let (tx, rx) = mpsc::channel(128);
816
817 let connection = Connection::new(ConnectionType::Remote)
818 .with_remote_addr(remote_addr)
819 .with_local_addr(local_addr)
820 .with_channel(Channel::Server(tx));
821
822 debug!(
823 "new connection received from remote: (remote: {:?} - local: {:?})",
824 connection.remote_addr(),
825 connection.local_addr()
826 );
827 info!(telemetry = true, counter.num_active_connections = 1);
828
829 let conn_index = self
831 .forwarder()
832 .on_connection_established(connection, None)
833 .unwrap();
834
835 self.process_stream(stream, conn_index, None, CancellationToken::new(), false);
836
837 let out_stream = ReceiverStream::new(rx);
838 Ok(Response::new(
839 Box::pin(out_stream) as Self::OpenChannelStream
840 ))
841 }
842}