agp_datapath/
message_processing.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11use tonic::codegen::{Body, StdError};
12use tonic::{Request, Response, Status};
13use tracing::{debug, error, info, trace};
14
15use crate::connection::{Channel, Connection, Type as ConnectionType};
16use crate::errors::DataPathError;
17use crate::forwarder::Forwarder;
18use crate::messages::utils::{
19    add_incoming_connection, get_agent_id, get_fanout, process_name, CommandType,
20};
21use crate::messages::AgentClass;
22use crate::pubsub::proto::pubsub::v1::message::MessageType::Publish as PublishType;
23use crate::pubsub::proto::pubsub::v1::message::MessageType::Subscribe as SubscribeType;
24use crate::pubsub::proto::pubsub::v1::message::MessageType::Unsubscribe as UnsubscribeType;
25use crate::pubsub::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
26use crate::pubsub::proto::pubsub::v1::{pub_sub_service_server::PubSubService, Message};
27
28#[derive(Debug)]
29struct MessageProcessorInternal {
30    forwarder: Forwarder<Connection>,
31    drain_channel: drain::Watch,
32}
33
34#[derive(Debug, Clone)]
35pub struct MessageProcessor {
36    internal: Arc<MessageProcessorInternal>,
37}
38
39impl MessageProcessor {
40    pub fn new() -> (Self, drain::Signal) {
41        let (signal, watch) = drain::channel();
42        let forwarder = Forwarder::new();
43        let forwarder = MessageProcessorInternal {
44            forwarder,
45            drain_channel: watch,
46        };
47
48        (
49            Self {
50                internal: Arc::new(forwarder),
51            },
52            signal,
53        )
54    }
55
56    pub fn with_drain_channel(watch: drain::Watch) -> Self {
57        let forwarder = Forwarder::new();
58        let forwarder = MessageProcessorInternal {
59            forwarder,
60            drain_channel: watch,
61        };
62        Self {
63            internal: Arc::new(forwarder),
64        }
65    }
66
67    fn forwarder(&self) -> &Forwarder<Connection> {
68        &self.internal.forwarder
69    }
70
71    fn get_drain_watch(&self) -> drain::Watch {
72        self.internal.drain_channel.clone()
73    }
74
75    pub async fn connect<C>(
76        &self,
77        channel: C,
78        local: Option<SocketAddr>,
79        remote: Option<SocketAddr>,
80    ) -> Result<(tokio::task::JoinHandle<()>, CancellationToken, u64), DataPathError>
81    where
82        C: tonic::client::GrpcService<tonic::body::BoxBody>,
83        C::Error: Into<StdError>,
84        C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
85        <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
86    {
87        let mut client = PubSubServiceClient::new(channel);
88        let (tx, rx) = mpsc::channel(128);
89        let stream = client
90            .open_channel(Request::new(ReceiverStream::new(rx)))
91            .await
92            .map_err(|e| DataPathError::ConnectionError(e.to_string()))?
93            .into_inner();
94
95        let connection = Connection::new(ConnectionType::Remote)
96            .with_local_addr(local)
97            .with_remote_addr(remote)
98            .with_channel(Channel::Client(tx));
99
100        info!(
101            "new connection initiated locally: (remote: {:?} - local: {:?})",
102            connection.remote_addr(),
103            connection.local_addr()
104        );
105
106        // insert connection into connection table
107        let conn_index = self.forwarder().on_connection_established(connection);
108
109        // Start loop to process messages
110        let ret = self.process_stream(stream, conn_index, false);
111        Ok((ret.0, ret.1, conn_index))
112    }
113
114    pub fn register_local_connection(
115        &self,
116    ) -> (
117        tokio::sync::mpsc::Sender<Result<Message, Status>>,
118        tokio::sync::mpsc::Receiver<Result<Message, Status>>,
119    ) {
120        // create a pair tx, rx to be able to send messages with the standard processing loop
121        let (tx1, rx1) = mpsc::channel(128);
122
123        info!("establishing new local app connection");
124
125        // create a pair tx, rx to be able to receive messages and insert it into the connection table
126        let (tx2, rx2) = mpsc::channel(128);
127
128        // create a connection
129        let connection = Connection::new(ConnectionType::Local).with_channel(Channel::Server(tx2));
130
131        // add it to the connection table
132        let conn_id = self.forwarder().on_connection_established(connection);
133
134        debug!("local connection established with id: {:?}", conn_id);
135        info!(telemetry = true, counter.num_active_connections = 1);
136
137        // this loop will process messages from the local app
138        self.process_stream(ReceiverStream::new(rx1), conn_id, true);
139
140        // return the handles to be used to send and receive messages
141        (tx1, rx2)
142    }
143
144    pub async fn send_msg(
145        &self,
146        msg: Message,
147        out_conn: u64,
148    ) -> Result<(), Box<dyn std::error::Error>> {
149        let connection = self.forwarder().get_connection(out_conn);
150        match connection {
151            Some(conn) => match conn.channel() {
152                Channel::Server(s) => s.send(Ok(msg)).await?,
153                Channel::Client(s) => s.send(msg).await?,
154                _ => error!("error reading channel"),
155            },
156            None => error!("connection {:?} not found", out_conn),
157        }
158        Ok(())
159    }
160
161    async fn match_and_forward_msg(
162        &self,
163        msg: Message,
164        class: AgentClass,
165        in_connection: u64,
166        fanout: u32,
167        agent_id: Option<u64>,
168    ) -> Result<(), DataPathError> {
169        debug!(
170            "match and forward message: class: {:?} - agent_id: {:?} - fanout: {:?}",
171            class, agent_id, fanout,
172        );
173
174        if fanout == 1 {
175            match self
176                .forwarder()
177                .on_publish_msg_match_one(class, agent_id, in_connection)
178            {
179                Ok(out) => match self.send_msg(msg, out).await {
180                    Ok(_) => Ok(()),
181                    Err(e) => {
182                        error!("error sending a message {:?}", e);
183                        Err(DataPathError::PublicationError(e.to_string()))
184                    }
185                },
186                Err(e) => {
187                    error!("error matching a message {:?}", e);
188                    Err(DataPathError::PublicationError(e.to_string()))
189                }
190            }
191        } else {
192            match self
193                .forwarder()
194                .on_publish_msg_match_all(class, agent_id, in_connection)
195            {
196                Ok(out_set) => {
197                    for out in out_set {
198                        match self.send_msg(msg.clone(), out).await {
199                            Ok(_) => {}
200                            Err(e) => {
201                                error!("error sending a message {:?}", e);
202                                return Err(DataPathError::PublicationError(e.to_string()));
203                            }
204                        }
205                    }
206                    Ok(())
207                }
208                Err(e) => {
209                    error!("error sending a message {:?}", e);
210                    Err(DataPathError::PublicationError(e.to_string()))
211                }
212            }
213        }
214    }
215
216    async fn process_publish(
217        &self,
218        mut msg: Message,
219        in_connection: u64,
220    ) -> Result<(), DataPathError> {
221        let pubmsg = match &msg.message_type {
222            Some(PublishType(p)) => p,
223            // this should never happen
224            _ => panic!("wrong message type"),
225        };
226
227        match process_name(&pubmsg.name) {
228            Ok(class) => {
229                let fanout = get_fanout(pubmsg);
230                let agent_id = get_agent_id(&pubmsg.name);
231
232                debug!(
233                    "received publication from connection {}: {:?}",
234                    in_connection, pubmsg
235                );
236
237                // add incoming connection to the metadata
238                add_incoming_connection(&mut msg, in_connection);
239
240                // if we get valid class also the name is valid so we can safely unwrap
241                return self
242                    .match_and_forward_msg(msg, class, in_connection, fanout, agent_id)
243                    .await;
244            }
245            Err(e) => {
246                error!("error processing publication message {:?}", e);
247                Err(DataPathError::PublicationError(e.to_string()))
248            }
249        }
250    }
251
252    fn process_command(&self, msg: &Message) -> Result<(CommandType, u64), DataPathError> {
253        if !msg.metadata.is_empty() {
254            match msg.metadata.get(&CommandType::ReceivedFrom.to_string()) {
255                None => {}
256                Some(out_str) => match out_str.parse::<u64>() {
257                    Err(e) => {
258                        error! {"error parsing the connection in command type ReceivedFrom: {:?}", e};
259                        return Err(DataPathError::CommandError(e.to_string()));
260                    }
261                    Ok(out) => {
262                        debug!(%out, "received subscription_from command, register subscription");
263                        return Ok((CommandType::ReceivedFrom, out));
264                    }
265                },
266            }
267            match msg.metadata.get(&CommandType::ForwardTo.to_string()) {
268                None => {}
269                Some(out_str) => match out_str.parse::<u64>() {
270                    Err(e) => {
271                        error! {"error parsing the connection in command type ForwardTo: {:?}", e};
272                        return Err(DataPathError::CommandError(e.to_string()));
273                    }
274                    Ok(out) => {
275                        debug!(%out, "received forward_to command, register subscription and forward");
276                        return Ok((CommandType::ForwardTo, out));
277                    }
278                },
279            }
280        }
281        Ok((CommandType::Unknown, 0))
282    }
283
284    async fn process_unsubscription(
285        &self,
286        mut msg: Message,
287        in_connection: u64,
288    ) -> Result<(), DataPathError> {
289        let unsubmsg = match &msg.message_type {
290            Some(UnsubscribeType(s)) => s,
291            // this should never happen
292            _ => panic!("wrong message type"),
293        };
294
295        match process_name(&unsubmsg.name) {
296            Ok(class) => {
297                // process command
298                let command = self.process_command(&msg);
299                let mut conn = in_connection;
300                let mut forward = false;
301                // only used if the subscription needs to be forwarded
302                let mut out_conn = in_connection;
303                match command {
304                    Err(e) => {
305                        return Err(e);
306                    }
307                    Ok(tuple) => match tuple.0 {
308                        CommandType::ReceivedFrom => {
309                            conn = tuple.1;
310                        }
311                        CommandType::ForwardTo => {
312                            forward = true;
313                            out_conn = tuple.1;
314                        }
315                        _ => {}
316                    },
317                }
318                let connection = self.forwarder().get_connection(in_connection);
319                if connection.is_none() {
320                    // this should never happen
321                    error!("incoming connection does not exists");
322                    return Err(DataPathError::SubscriptionError(
323                        "incoming connection does not exists".to_string(),
324                    ));
325                }
326                match self.forwarder().on_unsubscription_msg(
327                    class,
328                    get_agent_id(&unsubmsg.name),
329                    conn,
330                    connection.unwrap().is_local_connection(),
331                ) {
332                    Ok(_) => {}
333                    Err(e) => {
334                        return Err(DataPathError::UnsubscriptionError(e.to_string()));
335                    }
336                }
337                if forward {
338                    debug!("forward subscription to {:?}", out_conn);
339                    msg.metadata.clear();
340                    match self.send_msg(msg, out_conn).await {
341                        Ok(_) => {}
342                        Err(e) => {
343                            error!("error sending a message {:?}", e);
344                            return Err(DataPathError::SubscriptionError(e.to_string()));
345                        }
346                    };
347                }
348                Ok(())
349            }
350            Err(e) => {
351                error!("error processing unsubscription message {:?}", e);
352                Err(DataPathError::UnsubscriptionError(e.to_string()))
353            }
354        }
355    }
356
357    async fn process_subscription(
358        &self,
359        mut msg: Message,
360        in_connection: u64,
361    ) -> Result<(), DataPathError> {
362        let submsg = match &msg.message_type {
363            Some(SubscribeType(s)) => s,
364            // this should never happen
365            _ => panic!("wrong message type"),
366        };
367
368        debug!(
369            "received subscription from connection {}: {:?}",
370            in_connection, submsg
371        );
372
373        match process_name(&submsg.name) {
374            Ok(class) => {
375                // process command
376                trace!("process command");
377                let command = self.process_command(&msg);
378                let mut conn = in_connection;
379                let mut forward = false;
380
381                // only used if the subscription needs to be forwarded
382                let mut out_conn = in_connection;
383                match command {
384                    Err(e) => {
385                        return Err(e);
386                    }
387                    Ok(tuple) => match tuple.0 {
388                        CommandType::ReceivedFrom => {
389                            conn = tuple.1;
390                            trace!("received subscription_from command, register subscription with conn id {:?}", tuple.1);
391                        }
392                        CommandType::ForwardTo => {
393                            forward = true;
394                            out_conn = tuple.1;
395                            trace!("received forward_to command, register subscription and forward to conn id {:?}", out_conn);
396                        }
397                        _ => {}
398                    },
399                }
400                let connection = self.forwarder().get_connection(in_connection);
401                if connection.is_none() {
402                    // this should never happen
403                    error!("incoming connection does not exists");
404                    return Err(DataPathError::SubscriptionError(
405                        "incoming connection does not exists".to_string(),
406                    ));
407                }
408                match self.forwarder().on_subscription_msg(
409                    class,
410                    get_agent_id(&submsg.name),
411                    conn,
412                    connection.unwrap().is_local_connection(),
413                ) {
414                    Ok(_) => {}
415                    Err(e) => {
416                        return Err(DataPathError::SubscriptionError(e.to_string()));
417                    }
418                }
419
420                if forward {
421                    debug!("forward subscription {:?} to {:?}", msg, out_conn);
422                    msg.metadata.clear();
423                    match self.send_msg(msg, out_conn).await {
424                        Ok(_) => {}
425                        Err(e) => {
426                            error!("error sending a message {:?}", e);
427                            return Err(DataPathError::SubscriptionError(e.to_string()));
428                        }
429                    };
430                }
431                Ok(())
432            }
433            Err(e) => {
434                error!("error processing subscription message {:?}", e);
435                Err(DataPathError::SubscriptionError(e.to_string()))
436            }
437        }
438    }
439
440    pub async fn process_message(
441        &self,
442        msg: Message,
443        in_connection: u64,
444    ) -> Result<(), DataPathError> {
445        match &msg.message_type {
446            None => {
447                error!(
448                    "received message without message type from connection {}: {:?}",
449                    in_connection, msg
450                );
451                info!(
452                    telemetry = true,
453                    monotonic_counter.num_messages_by_type = 1,
454                    message_type = "none"
455                );
456                Err(DataPathError::UnknownMsgType("".to_string()))
457            }
458            Some(msg_type) => match msg_type {
459                SubscribeType(s) => {
460                    debug!(
461                        "received subscription from connection {}: {:?}",
462                        in_connection, s
463                    );
464                    info!(
465                        telemetry = true,
466                        monotonic_counter.num_messages_by_type = 1,
467                        message_type = "subscribe"
468                    );
469                    match self.process_subscription(msg, in_connection).await {
470                        Err(e) => {
471                            error! {"error processing subscription {:?}", e}
472                            Err(e)
473                        }
474                        Ok(_) => Ok(()),
475                    }
476                }
477                UnsubscribeType(u) => {
478                    debug!(
479                        "Received ubsubscription from client {}: {:?}",
480                        in_connection, u
481                    );
482                    info!(
483                        telemetry = true,
484                        monotonic_counter.num_messages_by_type = 1,
485                        message_type = "unsubscribe"
486                    );
487                    match self.process_unsubscription(msg, in_connection).await {
488                        Err(e) => {
489                            error! {"error processing unsubscription {:?}", e}
490                            Err(e)
491                        }
492                        Ok(_) => Ok(()),
493                    }
494                }
495                PublishType(p) => {
496                    debug!("Received publish from client {}: {:?}", in_connection, p);
497                    info!(
498                        telemetry = true,
499                        monotonic_counter.num_messages_by_type = 1,
500                        method = "publish"
501                    );
502                    match self.process_publish(msg, in_connection).await {
503                        Err(e) => {
504                            error! {"error processing publication {:?}", e}
505                            Err(e)
506                        }
507                        Ok(_) => Ok(()),
508                    }
509                }
510            },
511        }
512    }
513
514    async fn handle_new_message(
515        &self,
516        conn_index: u64,
517        result: Result<Message, Status>,
518    ) -> Result<(), DataPathError> {
519        debug!(%conn_index, "Received message from connection");
520        info!(
521            telemetry = true,
522            monotonic_counter.num_processed_messages = 1
523        );
524
525        match result {
526            Ok(msg) => {
527                match self.process_message(msg, conn_index).await {
528                    Ok(_) => Ok(()),
529                    Err(e) => {
530                        // drop message and log
531                        error!(
532                            "error processing message from connection {:?}: {:?}",
533                            conn_index, e
534                        );
535                        info!(
536                            telemetry = true,
537                            monotonic_counter.num_message_process_errors = 1
538                        );
539                        Ok(())
540                    }
541                }
542            }
543            Err(e) => {
544                if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
545                    if io_err.kind() == std::io::ErrorKind::BrokenPipe {
546                        info!("Connection {:?} closed by peer", conn_index);
547                        return Err(DataPathError::StreamError(e.to_string()));
548                    }
549                }
550                error!("error receiving messages {:?}", e);
551                let connection = self.forwarder().get_connection(conn_index);
552                match connection {
553                    Some(conn) => {
554                        match conn.channel() {
555                            Channel::Server(tx) => tx
556                                .send(Err(e))
557                                .await
558                                .map_err(|e| DataPathError::MessageSendError(e.to_string())),
559                            _ => Err(DataPathError::WrongChannelType), // error
560                        }
561                    }
562                    None => {
563                        error!("connection {:?} not found", conn_index);
564                        Err(DataPathError::ConnectionNotFound(conn_index.to_string()))
565                    }
566                }
567            }
568        }
569    }
570
571    #[tracing::instrument(fields(telemetry = true), skip(stream))]
572    fn process_stream(
573        &self,
574        mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
575        conn_index: u64,
576        is_local: bool,
577    ) -> (tokio::task::JoinHandle<()>, CancellationToken) {
578        // Clone self to be able to move it into the spawned task
579        let self_clone = self.clone();
580        let token = CancellationToken::new();
581        let token_clone = token.clone();
582        let handle = tokio::spawn(async move {
583            loop {
584                tokio::select! {
585                    res = stream.next() => {
586                        match res {
587                            Some(msg) => {
588                                if let Err(e) = self_clone.handle_new_message(conn_index, msg).await {
589                                    error!("error handling stream {:?}", e);
590                                    break;
591                                }
592                            }
593                            None => {
594                                info!(%conn_index, "end of stream");
595                                break;
596                            }
597                        }
598                    }
599                    _ = self_clone.get_drain_watch().signaled() => {
600                        info!("shutting down stream on drain: {}", conn_index);
601                        break;
602                    }
603                    _ = token_clone.cancelled() => {
604                        info!("shutting down stream cancellation token: {}", conn_index);
605                        break;
606                    }
607                }
608            }
609
610            info!(telemetry = true, counter.num_active_connections = -1);
611
612            self_clone
613                .forwarder()
614                .on_connection_drop(conn_index, is_local);
615        });
616
617        (handle, token)
618    }
619
620    fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
621        let mut err: &(dyn std::error::Error + 'static) = err_status;
622
623        loop {
624            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
625                return Some(io_err);
626            }
627
628            // h2::Error do not expose std::io::Error with `source()`
629            // https://github.com/hyperium/h2/pull/462
630            if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
631                if let Some(io_err) = h2_err.get_io() {
632                    return Some(io_err);
633                }
634            }
635
636            err = err.source()?;
637        }
638    }
639}
640
641#[tonic::async_trait]
642impl PubSubService for MessageProcessor {
643    type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
644
645    #[tracing::instrument(fields(telemetry = true))]
646    async fn open_channel(
647        &self,
648        request: Request<tonic::Streaming<Message>>,
649    ) -> Result<Response<Self::OpenChannelStream>, Status> {
650        let remote_addr = request.remote_addr();
651        let local_addr = request.local_addr();
652
653        let stream = request.into_inner();
654        let (tx, rx) = mpsc::channel(128);
655
656        let connection = Connection::new(ConnectionType::Remote)
657            .with_remote_addr(remote_addr)
658            .with_local_addr(local_addr)
659            .with_channel(Channel::Server(tx));
660
661        info!(
662            "new connection received from remote: (remote: {:?} - local: {:?})",
663            connection.remote_addr(),
664            connection.local_addr()
665        );
666        info!(telemetry = true, counter.num_active_connections = 1);
667
668        // insert connection into connection table
669        let conn_index = self.forwarder().on_connection_established(connection);
670
671        self.process_stream(stream, conn_index, false);
672
673        let out_stream = ReceiverStream::new(rx);
674        Ok(Response::new(
675            Box::pin(out_stream) as Self::OpenChannelStream
676        ))
677    }
678}