agp_datapath/
message_processing.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 Cisco and/or its affiliates.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11use tonic::codegen::{Body, StdError};
12use tonic::{Request, Response, Status};
13use tracing::{debug, error, info, trace};
14
15use crate::connection::{Channel, Connection, Type as ConnectionType};
16use crate::errors::DataPathError;
17use crate::forwarder::Forwarder;
18use crate::messages::utils::{
19    add_incoming_connection, get_agent_id, get_fanout, process_name, CommandType,
20};
21use crate::messages::AgentClass;
22use crate::pubsub::proto::pubsub::v1::message::MessageType::Publish as PublishType;
23use crate::pubsub::proto::pubsub::v1::message::MessageType::Subscribe as SubscribeType;
24use crate::pubsub::proto::pubsub::v1::message::MessageType::Unsubscribe as UnsubscribeType;
25use crate::pubsub::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
26use crate::pubsub::proto::pubsub::v1::{pub_sub_service_server::PubSubService, Message};
27
28#[derive(Debug)]
29struct MessageProcessorInternal {
30    forwarder: Forwarder<Connection>,
31    drain_channel: drain::Watch,
32}
33
34#[derive(Debug, Clone)]
35pub struct MessageProcessor {
36    internal: Arc<MessageProcessorInternal>,
37}
38
39impl MessageProcessor {
40    pub fn new() -> (Self, drain::Signal) {
41        let (signal, watch) = drain::channel();
42        let forwarder = Forwarder::new();
43        let forwarder = MessageProcessorInternal {
44            forwarder,
45            drain_channel: watch,
46        };
47
48        (
49            Self {
50                internal: Arc::new(forwarder),
51            },
52            signal,
53        )
54    }
55
56    pub fn with_drain_channel(watch: drain::Watch) -> Self {
57        let forwarder = Forwarder::new();
58        let forwarder = MessageProcessorInternal {
59            forwarder,
60            drain_channel: watch,
61        };
62        Self {
63            internal: Arc::new(forwarder),
64        }
65    }
66
67    fn forwarder(&self) -> &Forwarder<Connection> {
68        &self.internal.forwarder
69    }
70
71    fn get_drain_watch(&self) -> drain::Watch {
72        self.internal.drain_channel.clone()
73    }
74
75    pub async fn connect<C>(
76        &self,
77        channel: C,
78        local: Option<SocketAddr>,
79        remote: Option<SocketAddr>,
80    ) -> Result<(tokio::task::JoinHandle<()>, CancellationToken, u64), DataPathError>
81    where
82        C: tonic::client::GrpcService<tonic::body::BoxBody>,
83        C::Error: Into<StdError>,
84        C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
85        <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
86    {
87        let mut client = PubSubServiceClient::new(channel);
88        let (tx, rx) = mpsc::channel(128);
89        let stream = client
90            .open_channel(Request::new(ReceiverStream::new(rx)))
91            .await
92            .map_err(|e| DataPathError::ConnectionError(e.to_string()))?
93            .into_inner();
94
95        let connection = Connection::new(ConnectionType::Remote)
96            .with_local_addr(local)
97            .with_remote_addr(remote)
98            .with_channel(Channel::Client(tx));
99
100        info!(
101            "new connection initiated locally: (remote: {:?} - local: {:?})",
102            connection.remote_addr(),
103            connection.local_addr()
104        );
105
106        // insert connection into connection table
107        let conn_index = self.forwarder().on_connection_established(connection);
108
109        // Start loop to process messages
110        let ret = self.process_stream(stream, conn_index, false);
111        Ok((ret.0, ret.1, conn_index))
112    }
113
114    pub fn register_local_connection(
115        &self,
116    ) -> (
117        tokio::sync::mpsc::Sender<Result<Message, Status>>,
118        tokio::sync::mpsc::Receiver<Result<Message, Status>>,
119    ) {
120        // create a pair tx, rx to be able to send messages with the standard processing loop
121        let (tx1, rx1) = mpsc::channel(128);
122
123        info!("establishing new local app connection");
124
125        // create a pair tx, rx to be able to receive messages and insert it into the connection table
126        let (tx2, rx2) = mpsc::channel(128);
127
128        // create a connection
129        let connection = Connection::new(ConnectionType::Local).with_channel(Channel::Server(tx2));
130
131        // add it to the connection table
132        let conn_id = self.forwarder().on_connection_established(connection);
133
134        debug!("local connection established with id: {:?}", conn_id);
135
136        // this loop will process messages from the local app
137        self.process_stream(ReceiverStream::new(rx1), conn_id, true);
138
139        // return the handles to be used to send and receive messages
140        (tx1, rx2)
141    }
142
143    pub async fn send_msg(
144        &self,
145        msg: Message,
146        out_conn: u64,
147    ) -> Result<(), Box<dyn std::error::Error>> {
148        let connection = self.forwarder().get_connection(out_conn);
149        match connection {
150            Some(conn) => match conn.channel() {
151                Channel::Server(s) => s.send(Ok(msg)).await?,
152                Channel::Client(s) => s.send(msg).await?,
153                _ => error!("error reading channel"),
154            },
155            None => error!("connection {:?} not found", out_conn),
156        }
157        Ok(())
158    }
159
160    async fn match_and_forward_msg(
161        &self,
162        msg: Message,
163        class: AgentClass,
164        in_connection: u64,
165        fanout: u32,
166        agent_id: Option<u64>,
167    ) -> Result<(), DataPathError> {
168        debug!(
169            "match and forward message: class: {:?} - agent_id: {:?} - fanout: {:?}",
170            class, agent_id, fanout,
171        );
172
173        if fanout == 1 {
174            match self
175                .forwarder()
176                .on_publish_msg_match_one(class, agent_id, in_connection)
177            {
178                Ok(out) => match self.send_msg(msg, out).await {
179                    Ok(_) => Ok(()),
180                    Err(e) => {
181                        error!("error sending a message {:?}", e);
182                        Err(DataPathError::PublicationError(e.to_string()))
183                    }
184                },
185                Err(e) => {
186                    error!("error matching a message {:?}", e);
187                    Err(DataPathError::PublicationError(e.to_string()))
188                }
189            }
190        } else {
191            match self
192                .forwarder()
193                .on_publish_msg_match_all(class, agent_id, in_connection)
194            {
195                Ok(out_set) => {
196                    for out in out_set {
197                        match self.send_msg(msg.clone(), out).await {
198                            Ok(_) => {}
199                            Err(e) => {
200                                error!("error sending a message {:?}", e);
201                                return Err(DataPathError::PublicationError(e.to_string()));
202                            }
203                        }
204                    }
205                    Ok(())
206                }
207                Err(e) => {
208                    error!("error sending a message {:?}", e);
209                    Err(DataPathError::PublicationError(e.to_string()))
210                }
211            }
212        }
213    }
214
215    async fn process_publish(
216        &self,
217        mut msg: Message,
218        in_connection: u64,
219    ) -> Result<(), DataPathError> {
220        let pubmsg = match &msg.message_type {
221            Some(PublishType(p)) => p,
222            // this should never happen
223            _ => panic!("wrong message type"),
224        };
225
226        match process_name(&pubmsg.name) {
227            Ok(class) => {
228                let fanout = get_fanout(pubmsg);
229                let agent_id = get_agent_id(&pubmsg.name);
230
231                debug!(
232                    "received publication from connection {}: {:?}",
233                    in_connection, pubmsg
234                );
235
236                // add incoming connection to the metadata
237                add_incoming_connection(&mut msg, in_connection);
238
239                // if we get valid class also the name is valid so we can safely unwrap
240                return self
241                    .match_and_forward_msg(msg, class, in_connection, fanout, agent_id)
242                    .await;
243            }
244            Err(e) => {
245                error!("error processing publication message {:?}", e);
246                Err(DataPathError::PublicationError(e.to_string()))
247            }
248        }
249    }
250
251    fn process_command(&self, msg: &Message) -> Result<(CommandType, u64), DataPathError> {
252        if !msg.metadata.is_empty() {
253            match msg.metadata.get(&CommandType::ReceivedFrom.to_string()) {
254                None => {}
255                Some(out_str) => match out_str.parse::<u64>() {
256                    Err(e) => {
257                        error! {"error parsing the connection in command type ReceivedFrom: {:?}", e};
258                        return Err(DataPathError::CommandError(e.to_string()));
259                    }
260                    Ok(out) => {
261                        debug!(%out, "received subscription_from command, register subscription");
262                        return Ok((CommandType::ReceivedFrom, out));
263                    }
264                },
265            }
266            match msg.metadata.get(&CommandType::ForwardTo.to_string()) {
267                None => {}
268                Some(out_str) => match out_str.parse::<u64>() {
269                    Err(e) => {
270                        error! {"error parsing the connection in command type ForwardTo: {:?}", e};
271                        return Err(DataPathError::CommandError(e.to_string()));
272                    }
273                    Ok(out) => {
274                        debug!(%out, "received forward_to command, register subscription and forward");
275                        return Ok((CommandType::ForwardTo, out));
276                    }
277                },
278            }
279        }
280        Ok((CommandType::Unknown, 0))
281    }
282
283    async fn process_unsubscription(
284        &self,
285        mut msg: Message,
286        in_connection: u64,
287    ) -> Result<(), DataPathError> {
288        let unsubmsg = match &msg.message_type {
289            Some(UnsubscribeType(s)) => s,
290            // this should never happen
291            _ => panic!("wrong message type"),
292        };
293
294        match process_name(&unsubmsg.name) {
295            Ok(class) => {
296                // process command
297                let command = self.process_command(&msg);
298                let mut conn = in_connection;
299                let mut forward = false;
300                // only used if the subscription needs to be forwarded
301                let mut out_conn = in_connection;
302                match command {
303                    Err(e) => {
304                        return Err(e);
305                    }
306                    Ok(tuple) => match tuple.0 {
307                        CommandType::ReceivedFrom => {
308                            conn = tuple.1;
309                        }
310                        CommandType::ForwardTo => {
311                            forward = true;
312                            out_conn = tuple.1;
313                        }
314                        _ => {}
315                    },
316                }
317                let connection = self.forwarder().get_connection(in_connection);
318                if connection.is_none() {
319                    // this should never happen
320                    error!("incoming connection does not exists");
321                    return Err(DataPathError::SubscriptionError(
322                        "incoming connection does not exists".to_string(),
323                    ));
324                }
325                match self.forwarder().on_unsubscription_msg(
326                    class,
327                    get_agent_id(&unsubmsg.name),
328                    conn,
329                    connection.unwrap().is_local_connection(),
330                ) {
331                    Ok(_) => {}
332                    Err(e) => {
333                        return Err(DataPathError::UnsubscriptionError(e.to_string()));
334                    }
335                }
336                if forward {
337                    debug!("forward subscription to {:?}", out_conn);
338                    msg.metadata.clear();
339                    match self.send_msg(msg, out_conn).await {
340                        Ok(_) => {}
341                        Err(e) => {
342                            error!("error sending a message {:?}", e);
343                            return Err(DataPathError::SubscriptionError(e.to_string()));
344                        }
345                    };
346                }
347                Ok(())
348            }
349            Err(e) => {
350                error!("error processing unsubscription message {:?}", e);
351                Err(DataPathError::UnsubscriptionError(e.to_string()))
352            }
353        }
354    }
355
356    async fn process_subscription(
357        &self,
358        mut msg: Message,
359        in_connection: u64,
360    ) -> Result<(), DataPathError> {
361        let submsg = match &msg.message_type {
362            Some(SubscribeType(s)) => s,
363            // this should never happen
364            _ => panic!("wrong message type"),
365        };
366
367        debug!(
368            "received subscription from connection {}: {:?}",
369            in_connection, submsg
370        );
371
372        match process_name(&submsg.name) {
373            Ok(class) => {
374                // process command
375                trace!("process command");
376                let command = self.process_command(&msg);
377                let mut conn = in_connection;
378                let mut forward = false;
379
380                // only used if the subscription needs to be forwarded
381                let mut out_conn = in_connection;
382                match command {
383                    Err(e) => {
384                        return Err(e);
385                    }
386                    Ok(tuple) => match tuple.0 {
387                        CommandType::ReceivedFrom => {
388                            conn = tuple.1;
389                            trace!("received subscription_from command, register subscription with conn id {:?}", tuple.1);
390                        }
391                        CommandType::ForwardTo => {
392                            forward = true;
393                            out_conn = tuple.1;
394                            trace!("received forward_to command, register subscription and forward to conn id {:?}", out_conn);
395                        }
396                        _ => {}
397                    },
398                }
399                let connection = self.forwarder().get_connection(in_connection);
400                if connection.is_none() {
401                    // this should never happen
402                    error!("incoming connection does not exists");
403                    return Err(DataPathError::SubscriptionError(
404                        "incoming connection does not exists".to_string(),
405                    ));
406                }
407                match self.forwarder().on_subscription_msg(
408                    class,
409                    get_agent_id(&submsg.name),
410                    conn,
411                    connection.unwrap().is_local_connection(),
412                ) {
413                    Ok(_) => {}
414                    Err(e) => {
415                        return Err(DataPathError::SubscriptionError(e.to_string()));
416                    }
417                }
418
419                if forward {
420                    debug!("forward subscription {:?} to {:?}", msg, out_conn);
421                    msg.metadata.clear();
422                    match self.send_msg(msg, out_conn).await {
423                        Ok(_) => {}
424                        Err(e) => {
425                            error!("error sending a message {:?}", e);
426                            return Err(DataPathError::SubscriptionError(e.to_string()));
427                        }
428                    };
429                }
430                Ok(())
431            }
432            Err(e) => {
433                error!("error processing subscription message {:?}", e);
434                Err(DataPathError::SubscriptionError(e.to_string()))
435            }
436        }
437    }
438
439    pub async fn process_message(
440        &self,
441        msg: Message,
442        in_connection: u64,
443    ) -> Result<(), DataPathError> {
444        match &msg.message_type {
445            None => {
446                error!(
447                    "received message without message type from connection {}: {:?}",
448                    in_connection, msg
449                );
450                Err(DataPathError::UnknownMsgType("".to_string()))
451            }
452            Some(msg_type) => match msg_type {
453                SubscribeType(s) => {
454                    debug!(
455                        "received subscription from connection {}: {:?}",
456                        in_connection, s
457                    );
458                    match self.process_subscription(msg, in_connection).await {
459                        Err(e) => {
460                            error! {"error processing subscription {:?}", e}
461                            Err(e)
462                        }
463                        Ok(_) => Ok(()),
464                    }
465                }
466                UnsubscribeType(u) => {
467                    debug!(
468                        "Received ubsubscription from client {}: {:?}",
469                        in_connection, u
470                    );
471                    match self.process_unsubscription(msg, in_connection).await {
472                        Err(e) => {
473                            error! {"error processing unsubscription {:?}", e}
474                            Err(e)
475                        }
476                        Ok(_) => Ok(()),
477                    }
478                }
479                PublishType(p) => {
480                    debug!("Received publish from client {}: {:?}", in_connection, p);
481                    match self.process_publish(msg, in_connection).await {
482                        Err(e) => {
483                            error! {"error processing publication {:?}", e}
484                            Err(e)
485                        }
486                        Ok(_) => Ok(()),
487                    }
488                }
489            },
490        }
491    }
492
493    async fn handle_new_message(
494        &self,
495        conn_index: u64,
496        result: Result<Message, Status>,
497    ) -> Result<(), DataPathError> {
498        debug!(%conn_index, "Received message from connection");
499
500        match result {
501            Ok(msg) => {
502                match self.process_message(msg, conn_index).await {
503                    Ok(_) => Ok(()),
504                    Err(e) => {
505                        // drop message and log
506                        error!(
507                            "error processing message from connection {:?}: {:?}",
508                            conn_index, e
509                        );
510                        Ok(())
511                    }
512                }
513            }
514            Err(e) => {
515                if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
516                    if io_err.kind() == std::io::ErrorKind::BrokenPipe {
517                        info!("Connection {:?} closed by peer", conn_index);
518                        return Err(DataPathError::StreamError(e.to_string()));
519                    }
520                }
521                error!("error receiving messages {:?}", e);
522                let connection = self.forwarder().get_connection(conn_index);
523                match connection {
524                    Some(conn) => {
525                        match conn.channel() {
526                            Channel::Server(tx) => tx
527                                .send(Err(e))
528                                .await
529                                .map_err(|e| DataPathError::MessageSendError(e.to_string())),
530                            _ => Err(DataPathError::WrongChannelType), // error
531                        }
532                    }
533                    None => {
534                        error!("connection {:?} not found", conn_index);
535                        Err(DataPathError::ConnectionNotFound(conn_index.to_string()))
536                    }
537                }
538            }
539        }
540    }
541
542    fn process_stream(
543        &self,
544        mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
545        conn_index: u64,
546        is_local: bool,
547    ) -> (tokio::task::JoinHandle<()>, CancellationToken) {
548        // Clone self to be able to move it into the spawned task
549        let self_clone = self.clone();
550        let token = CancellationToken::new();
551        let token_clone = token.clone();
552        let handle = tokio::spawn(async move {
553            loop {
554                tokio::select! {
555                    res = stream.next() => {
556                        match res {
557                            Some(msg) => {
558                                if let Err(e) = self_clone.handle_new_message(conn_index, msg).await {
559                                    error!("error handling stream {:?}", e);
560                                    break;
561                                }
562                            }
563                            None => {
564                                info!(%conn_index, "end of stream");
565                                break;
566                            }
567                        }
568                    }
569                    _ = self_clone.get_drain_watch().signaled() => {
570                        info!("shutting down stream on drain: {}", conn_index);
571                        break;
572                    }
573                    _ = token_clone.cancelled() => {
574                        info!("shutting down stream cancellation token: {}", conn_index);
575                        break;
576                    }
577                }
578            }
579
580            self_clone
581                .forwarder()
582                .on_connection_drop(conn_index, is_local);
583        });
584
585        (handle, token)
586    }
587
588    fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
589        let mut err: &(dyn std::error::Error + 'static) = err_status;
590
591        loop {
592            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
593                return Some(io_err);
594            }
595
596            // h2::Error do not expose std::io::Error with `source()`
597            // https://github.com/hyperium/h2/pull/462
598            if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
599                if let Some(io_err) = h2_err.get_io() {
600                    return Some(io_err);
601                }
602            }
603
604            err = err.source()?;
605        }
606    }
607}
608
609#[tonic::async_trait]
610impl PubSubService for MessageProcessor {
611    type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
612
613    async fn open_channel(
614        &self,
615        request: Request<tonic::Streaming<Message>>,
616    ) -> Result<Response<Self::OpenChannelStream>, Status> {
617        let remote_addr = request.remote_addr();
618        let local_addr = request.local_addr();
619
620        let stream = request.into_inner();
621        let (tx, rx) = mpsc::channel(128);
622
623        let connection = Connection::new(ConnectionType::Remote)
624            .with_remote_addr(remote_addr)
625            .with_local_addr(local_addr)
626            .with_channel(Channel::Server(tx));
627
628        info!(
629            "new connection received from remote: (remote: {:?} - local: {:?})",
630            connection.remote_addr(),
631            connection.local_addr()
632        );
633
634        // insert connection into connection table
635        let conn_index = self.forwarder().on_connection_established(connection);
636
637        self.process_stream(stream, conn_index, false);
638
639        let out_stream = ReceiverStream::new(rx);
640        Ok(Response::new(
641            Box::pin(out_stream) as Self::OpenChannelStream
642        ))
643    }
644}