1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use agp_config::grpc::client::ClientConfig;
8use agp_tracing::utils::INSTANCE_ID;
9use opentelemetry::propagation::{Extractor, Injector};
10use opentelemetry::trace::TraceContextExt;
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::{Stream, StreamExt};
14use tokio_util::sync::CancellationToken;
15use tonic::codegen::{Body, StdError};
16use tonic::{Request, Response, Status};
17use tracing::{Span, debug, error, info};
18use tracing_opentelemetry::OpenTelemetrySpanExt;
19
20use crate::connection::{Channel, Connection, Type as ConnectionType};
21use crate::errors::DataPathError;
22use crate::forwarder::Forwarder;
23use crate::messages::AgentType;
24use crate::pubsub::proto::pubsub::v1::message::MessageType::Publish as PublishType;
25use crate::pubsub::proto::pubsub::v1::message::MessageType::Subscribe as SubscribeType;
26use crate::pubsub::proto::pubsub::v1::message::MessageType::Unsubscribe as UnsubscribeType;
27use crate::pubsub::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
28use crate::pubsub::proto::pubsub::v1::{Message, pub_sub_service_server::PubSubService};
29
30struct MetadataExtractor<'a>(&'a std::collections::HashMap<String, String>);
32
33impl Extractor for MetadataExtractor<'_> {
34 fn get(&self, key: &str) -> Option<&str> {
35 self.0.get(key).map(|s| s.as_str())
36 }
37
38 fn keys(&self) -> Vec<&str> {
39 self.0.keys().map(|s| s.as_str()).collect()
40 }
41}
42
43struct MetadataInjector<'a>(&'a mut std::collections::HashMap<String, String>);
44
45impl Injector for MetadataInjector<'_> {
46 fn set(&mut self, key: &str, value: String) {
47 self.0.insert(key.to_string(), value);
48 }
49}
50
51fn extract_parent_context(msg: &Message) -> Option<opentelemetry::Context> {
53 let extractor = MetadataExtractor(&msg.metadata);
54 let parent_context =
55 opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor));
56
57 if parent_context.span().span_context().is_valid() {
58 Some(parent_context)
59 } else {
60 None
61 }
62}
63
64fn inject_current_context(msg: &mut Message) {
66 let cx = tracing::Span::current().context();
67 let mut injector = MetadataInjector(&mut msg.metadata);
68 opentelemetry::global::get_text_map_propagator(|propagator| {
69 propagator.inject_context(&cx, &mut injector)
70 });
71}
72
73fn create_span(function: &str, out_conn: u64, msg: &Message) -> Span {
75 let span = tracing::span!(
76 tracing::Level::INFO,
77 "agp_process_message",
78 function = function,
79 source = format!("{}", msg.get_source()),
80 destination = format!("{}",msg.get_name_as_agent()),
81 instance_id = %INSTANCE_ID.as_str(),
82 connection_id = out_conn,
83 message_type = msg.get_type().to_string(),
84 telemetry = true
85 );
86
87 match msg.get_type() {
88 PublishType(_) => {
89 span.set_attribute("session_type", msg.get_header_type().as_str_name());
90 span.set_attribute(
91 "session_id",
92 msg.get_session_header().get_session_id().to_string(),
93 );
94 span.set_attribute(
95 "message_id",
96 msg.get_session_header().get_message_id().to_string(),
97 );
98 }
99 _ => {} }
101
102 span
103}
104
105#[derive(Debug)]
106struct MessageProcessorInternal {
107 forwarder: Forwarder<Connection>,
108 drain_channel: drain::Watch,
109}
110
111#[derive(Debug, Clone)]
112pub struct MessageProcessor {
113 internal: Arc<MessageProcessorInternal>,
114}
115
116impl MessageProcessor {
117 pub fn new() -> (Self, drain::Signal) {
118 let (signal, watch) = drain::channel();
119 let forwarder = Forwarder::new();
120 let forwarder = MessageProcessorInternal {
121 forwarder,
122 drain_channel: watch,
123 };
124
125 (
126 Self {
127 internal: Arc::new(forwarder),
128 },
129 signal,
130 )
131 }
132
133 pub fn with_drain_channel(watch: drain::Watch) -> Self {
134 let forwarder = Forwarder::new();
135 let forwarder = MessageProcessorInternal {
136 forwarder,
137 drain_channel: watch,
138 };
139 Self {
140 internal: Arc::new(forwarder),
141 }
142 }
143
144 fn forwarder(&self) -> &Forwarder<Connection> {
145 &self.internal.forwarder
146 }
147
148 fn get_drain_watch(&self) -> drain::Watch {
149 self.internal.drain_channel.clone()
150 }
151
152 async fn try_to_connect<C>(
153 &self,
154 channel: C,
155 client_config: Option<ClientConfig>,
156 local: Option<SocketAddr>,
157 remote: Option<SocketAddr>,
158 existing_conn_index: Option<u64>,
159 max_retry: u32,
160 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
161 where
162 C: tonic::client::GrpcService<tonic::body::Body>,
163 C::Error: Into<StdError>,
164 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
165 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
166 {
167 let mut client: PubSubServiceClient<C> = PubSubServiceClient::new(channel);
168 let mut i = 0;
169 while i < max_retry {
170 let (tx, rx) = mpsc::channel(128);
171 match client
172 .open_channel(Request::new(ReceiverStream::new(rx)))
173 .await
174 {
175 Ok(stream) => {
176 let cancellation_token = CancellationToken::new();
177 let connection = Connection::new(ConnectionType::Remote)
178 .with_local_addr(local)
179 .with_remote_addr(remote)
180 .with_channel(Channel::Client(tx))
181 .with_cancellation_token(Some(cancellation_token.clone()));
182
183 debug!(
184 "new connection initiated locally: (remote: {:?} - local: {:?})",
185 connection.remote_addr(),
186 connection.local_addr()
187 );
188
189 let opt = self
191 .forwarder()
192 .on_connection_established(connection, existing_conn_index);
193 if opt.is_none() {
194 error!("error adding connection to the connection table");
195 return Err(DataPathError::ConnectionError(
196 "error adding connection to the connection tables".to_string(),
197 ));
198 }
199
200 let conn_index = opt.unwrap();
201 debug!(
202 "new connection index = {:?}, is local {:?}",
203 conn_index, false
204 );
205
206 let ret = self.process_stream(
208 stream.into_inner(),
209 conn_index,
210 client_config,
211 cancellation_token,
212 false,
213 );
214 return Ok((ret, conn_index));
215 }
216 Err(e) => {
217 error!("connection error: {:?}.", e.to_string());
218 }
219 }
220 i += 1;
221
222 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
224 }
225
226 error!("unable to connect to the endpoint");
227 Err(DataPathError::ConnectionError(
228 "reached max connection retries".to_string(),
229 ))
230 }
231
232 pub async fn connect<C>(
233 &self,
234 channel: C,
235 client_config: Option<ClientConfig>,
236 local: Option<SocketAddr>,
237 remote: Option<SocketAddr>,
238 ) -> Result<(tokio::task::JoinHandle<()>, u64), DataPathError>
239 where
240 C: tonic::client::GrpcService<tonic::body::Body>,
241 C::Error: Into<StdError>,
242 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
243 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
244 {
245 self.try_to_connect(channel, client_config, local, remote, None, 10)
246 .await
247 }
248
249 pub fn disconnect(&self, conn: u64) -> Result<(), DataPathError> {
250 match self.forwarder().get_connection(conn) {
251 None => {
252 error!("error handling disconnect: connection unknown");
253 return Err(DataPathError::DisconnectionError(
254 "connection not found".to_string(),
255 ));
256 }
257 Some(c) => {
258 match c.cancellation_token() {
259 None => {
260 error!("error handling disconnect: missing cancellation token");
261 }
262 Some(t) => {
263 t.cancel();
267 }
268 }
269 }
270 }
271
272 Ok(())
273 }
274
275 pub fn register_local_connection(
276 &self,
277 ) -> (
278 u64,
279 tokio::sync::mpsc::Sender<Result<Message, Status>>,
280 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
281 ) {
282 let (tx1, rx1) = mpsc::channel(128);
284
285 debug!("establishing new local app connection");
286
287 let (tx2, rx2) = mpsc::channel(128);
289
290 let cancellation_token = CancellationToken::new();
292 let connection = Connection::new(ConnectionType::Local)
293 .with_channel(Channel::Server(tx2))
294 .with_cancellation_token(Some(cancellation_token.clone()));
295
296 let conn_id = self
298 .forwarder()
299 .on_connection_established(connection, None)
300 .unwrap();
301
302 debug!("local connection established with id: {:?}", conn_id);
303 info!(telemetry = true, counter.num_active_connections = 1);
304
305 self.process_stream(
307 ReceiverStream::new(rx1),
308 conn_id,
309 None,
310 cancellation_token,
311 true,
312 );
313
314 (conn_id, tx1, rx2)
316 }
317
318 pub async fn send_msg(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
319 let connection = self.forwarder().get_connection(out_conn);
320 match connection {
321 Some(conn) => {
322 msg.clear_agp_header();
324
325 let parent_context = extract_parent_context(&msg);
327 let span = create_span("send_message", out_conn, &msg);
328
329 if let Some(ctx) = parent_context {
330 span.set_parent(ctx);
331 }
332 let _guard = span.enter();
333 inject_current_context(&mut msg);
334 match conn.channel() {
337 Channel::Server(s) => s
338 .send(Ok(msg))
339 .await
340 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
341 Channel::Client(s) => s
342 .send(msg)
343 .await
344 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
345 _ => Err(DataPathError::MessageSendError(
346 "connection not found".to_string(),
347 )),
348 }
349 }
350 None => Err(DataPathError::MessageSendError(format!(
351 "connection {:?} not found",
352 out_conn
353 ))),
354 }
355 }
356
357 async fn match_and_forward_msg(
358 &self,
359 msg: Message,
360 agent_type: AgentType,
361 in_connection: u64,
362 fanout: u32,
363 agent_id: Option<u64>,
364 ) -> Result<(), DataPathError> {
365 debug!(
366 "match and forward message: type: {} - agent_id: {:?} - fanout: {}",
367 agent_type, agent_id, fanout,
368 );
369
370 if let Some(val) = msg.get_forward_to() {
373 debug!("forwarding message to connection {}", val);
374 return self
375 .send_msg(msg, val)
376 .await
377 .map_err(|e| DataPathError::PublicationError(e.to_string()));
378 }
379
380 match self
381 .forwarder()
382 .on_publish_msg_match(agent_type, agent_id, in_connection, fanout)
383 {
384 Ok(out_vec) => {
385 let mut i = 0;
388 while i < out_vec.len() - 1 {
389 self.send_msg(msg.clone(), out_vec[i])
390 .await
391 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
392 i += 1;
393 }
394 self.send_msg(msg, out_vec[i])
395 .await
396 .map_err(|e| DataPathError::PublicationError(e.to_string()))?;
397 Ok(())
398 }
399 Err(e) => Err(DataPathError::PublicationError(e.to_string())),
400 }
401 }
402
403 async fn process_publish(&self, msg: Message, in_connection: u64) -> Result<(), DataPathError> {
404 debug!(
405 "received publication from connection {}: {:?}",
406 in_connection, msg
407 );
408
409 info!(
411 telemetry = true,
412 monotonic_counter.num_messages_by_type = 1,
413 method = "publish"
414 );
415 let header = msg.get_agp_header();
419
420 let (agent_type, agent_id) = header.get_dst();
421
422 let fanout = msg.get_fanout();
425
426 return self
428 .match_and_forward_msg(msg, agent_type, in_connection, fanout, agent_id)
429 .await;
430 }
431
432 async fn process_subscription(
436 &self,
437 msg: Message,
438 in_connection: u64,
439 add: bool,
440 ) -> Result<(), DataPathError> {
441 debug!(
442 "received subscription (add = {}) from connection {}: {:?}",
443 add, in_connection, msg
444 );
445
446 info!(
448 telemetry = true,
449 monotonic_counter.num_messages_by_type = 1,
450 message_type = { if add { "subscribe" } else { "unsubscribe" } }
451 );
452 let (agent_type, agent_id) = msg.get_name();
455
456 let header = msg.get_agp_header();
458
459 let (conn, forward) = header.get_in_out_connections();
461
462 let connection = self
465 .forwarder()
466 .get_connection(conn)
467 .ok_or_else(|| DataPathError::SubscriptionError("connection not found".to_string()))?;
468
469 debug!(
470 "subscription update (add = {}) for agent type: {} (agent id: {:?}) - connection: {}",
471 add, agent_type, agent_id, conn
472 );
473
474 if let Err(e) = self.forwarder().on_subscription_msg(
475 agent_type.clone(),
476 agent_id,
477 conn,
478 connection.is_local_connection(),
479 add,
480 ) {
481 return Err(DataPathError::SubscriptionError(e.to_string()));
482 }
483
484 match forward {
485 None => {
486 Ok(())
488 }
489 Some(out_conn) => {
490 debug!("forward subscription (add = {}) to {}", add, out_conn);
491
492 let source_agent = msg.get_source();
494
495 match self.send_msg(msg, out_conn).await {
497 Ok(_) => {
498 self.forwarder().on_forwarded_subscription(
499 source_agent,
500 agent_type,
501 agent_id,
502 out_conn,
503 add,
504 );
505 Ok(())
506 }
507 Err(e) => Err(DataPathError::UnsubscriptionError(e.to_string())),
508 }
509 }
510 }
511 }
512
513 pub async fn process_message(
514 &self,
515 msg: Message,
516 in_connection: u64,
517 ) -> Result<(), DataPathError> {
518 match msg.get_type() {
520 SubscribeType(_) => self.process_subscription(msg, in_connection, true).await,
521 UnsubscribeType(_) => self.process_subscription(msg, in_connection, false).await,
522 PublishType(_) => self.process_publish(msg, in_connection).await,
523 }
524 }
525
526 async fn handle_new_message(
527 &self,
528 conn_index: u64,
529 is_local: bool,
530 mut msg: Message,
531 ) -> Result<(), DataPathError> {
532 debug!(%conn_index, "received message from connection");
533 info!(
534 telemetry = true,
535 monotonic_counter.num_processed_messages = 1
536 );
537
538 if let Err(err) = msg.validate() {
540 info!(
541 telemetry = true,
542 monotonic_counter.num_messages_by_type = 1,
543 message_type = "none"
544 );
545
546 return Err(DataPathError::InvalidMessage(err.to_string()));
547 }
548
549 msg.set_incoming_conn(Some(conn_index));
551
552 if is_local {
556 let span = create_span("process_local", conn_index, &msg);
558
559 let _guard = span.enter();
560
561 inject_current_context(&mut msg);
562 } else {
563 let parent_context = extract_parent_context(&msg);
565
566 let span = create_span("process_local", conn_index, &msg);
567
568 if let Some(ctx) = parent_context {
569 span.set_parent(ctx);
570 }
571 let _guard = span.enter();
572
573 inject_current_context(&mut msg);
574 }
575 match self.process_message(msg, conn_index).await {
578 Ok(_) => Ok(()),
579 Err(e) => {
580 info!(
582 telemetry = true,
583 monotonic_counter.num_message_process_errors = 1
584 );
585 Err(DataPathError::ProcessingError(e.to_string()))
589 }
590 }
591 }
592
593 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
594 let connection = self.forwarder().get_connection(conn_index);
595 match connection {
596 Some(conn) => {
597 debug!("try to notify the error to the local application");
598 if let Channel::Server(tx) = conn.channel() {
599 let status = Status::new(
601 tonic::Code::Internal,
602 format!("error processing message: {:?}", err),
603 );
604
605 if tx.send(Err(status)).await.is_err() {
606 debug!("unable to notify the error to the local app: {:?}", err);
607 }
608 }
609 }
610 None => {
611 error!(
612 "error sending error to local app: connection {:?} not found",
613 conn_index
614 );
615 }
616 }
617 }
618
619 async fn reconnect(&self, client_conf: Option<ClientConfig>, conn_index: u64) -> bool {
620 let config = client_conf.unwrap();
621 match config.to_channel() {
622 Err(e) => {
623 error!(
624 "cannot parse connection config, unable to reconnect {:?}",
625 e.to_string()
626 );
627 false
628 }
629 Ok(channel) => {
630 info!("connection lost with remote endpoint, try to reconnect");
631 let remote_subscriptions = self
636 .forwarder()
637 .get_subscriptions_forwarded_on_connection(conn_index);
638
639 match self
640 .try_to_connect(channel, Some(config), None, None, Some(conn_index), 120)
641 .await
642 {
643 Ok(_) => {
644 info!("connection re-established");
645 for r in remote_subscriptions.iter() {
647 let sub_msg = Message::new_subscribe(
648 r.source(),
649 r.name().agent_type(),
650 r.name().agent_id_option(),
651 None,
652 );
653 if self.send_msg(sub_msg, conn_index).await.is_err() {
654 error!("error restoring subscription on remote node");
655 }
656 }
657 true
658 }
659 Err(e) => {
660 error!("unable to connect to remote node {:?}", e.to_string());
662 false
663 }
664 }
665 }
666 }
667 }
668
669 fn process_stream(
670 &self,
671 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
672 conn_index: u64,
673 client_config: Option<ClientConfig>,
674 cancellation_token: CancellationToken,
675 is_local: bool,
676 ) -> tokio::task::JoinHandle<()> {
677 let self_clone = self.clone();
679 let token_clone = cancellation_token.clone();
680 let client_conf_clone = client_config.clone();
681 let handle = tokio::spawn(async move {
682 let mut try_to_reconnect = true;
683 loop {
684 tokio::select! {
685 next = stream.next() => {
686 match next {
687 Some(result) => {
688 match result {
689 Ok(msg) => {
690 if let Err(e) = self_clone.handle_new_message(conn_index, is_local, msg).await {
691 error!(%conn_index, %e, "error processing incoming message");
692 if is_local {
694 self_clone.send_error_to_local_app(conn_index, e).await;
696 }
697 }
698 }
699 Err(e) => {
700 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
701 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
702 info!("connection {:?} closed by peer", conn_index);
703 }
704 } else {
705 error!("error receiving messages {:?}", e);
706 }
707 break;
708 }
709 }
710 }
711 None => {
712 debug!(%conn_index, "end of stream");
713 break;
714 }
715 }
716 }
717 _ = self_clone.get_drain_watch().signaled() => {
718 debug!("shutting down stream on drain: {}", conn_index);
719 try_to_reconnect = false;
720 break;
721 }
722 _ = token_clone.cancelled() => {
723 debug!("shutting down stream on cancellation token: {}", conn_index);
724 try_to_reconnect = false;
725 break;
726 }
727 }
728 }
729
730 drop(stream);
734
735 let mut connected = false;
736
737 if try_to_reconnect && client_conf_clone.is_some() {
738 connected = self_clone.reconnect(client_conf_clone, conn_index).await;
739 } else {
740 debug!("close connection {}", conn_index)
741 }
742
743 if !connected {
744 self_clone
746 .forwarder()
747 .on_connection_drop(conn_index, is_local);
748
749 info!(telemetry = true, counter.num_active_connections = -1);
750 }
751 });
752
753 handle
754 }
755
756 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
757 let mut err: &(dyn std::error::Error + 'static) = err_status;
758
759 loop {
760 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
761 return Some(io_err);
762 }
763
764 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
767 if let Some(io_err) = h2_err.get_io() {
768 return Some(io_err);
769 }
770 }
771
772 err = err.source()?;
773 }
774 }
775}
776
777#[tonic::async_trait]
778impl PubSubService for MessageProcessor {
779 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
780
781 async fn open_channel(
782 &self,
783 request: Request<tonic::Streaming<Message>>,
784 ) -> Result<Response<Self::OpenChannelStream>, Status> {
785 let remote_addr = request.remote_addr();
786 let local_addr = request.local_addr();
787
788 let stream = request.into_inner();
789 let (tx, rx) = mpsc::channel(128);
790
791 let connection = Connection::new(ConnectionType::Remote)
792 .with_remote_addr(remote_addr)
793 .with_local_addr(local_addr)
794 .with_channel(Channel::Server(tx));
795
796 debug!(
797 "new connection received from remote: (remote: {:?} - local: {:?})",
798 connection.remote_addr(),
799 connection.local_addr()
800 );
801 info!(telemetry = true, counter.num_active_connections = 1);
802
803 let conn_index = self
805 .forwarder()
806 .on_connection_established(connection, None)
807 .unwrap();
808
809 self.process_stream(stream, conn_index, None, CancellationToken::new(), false);
810
811 let out_stream = ReceiverStream::new(rx);
812 Ok(Response::new(
813 Box::pin(out_stream) as Self::OpenChannelStream
814 ))
815 }
816}