1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11use tonic::codegen::{Body, StdError};
12use tonic::{Request, Response, Status};
13use tracing::{debug, error, info, trace};
14
15use crate::connection::{Channel, Connection, Type as ConnectionType};
16use crate::errors::DataPathError;
17use crate::forwarder::Forwarder;
18use crate::messages::utils::{
19 add_incoming_connection, get_agent_id, get_fanout, process_name, CommandType,
20};
21use crate::messages::AgentClass;
22use crate::pubsub::proto::pubsub::v1::message::MessageType::Publish as PublishType;
23use crate::pubsub::proto::pubsub::v1::message::MessageType::Subscribe as SubscribeType;
24use crate::pubsub::proto::pubsub::v1::message::MessageType::Unsubscribe as UnsubscribeType;
25use crate::pubsub::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
26use crate::pubsub::proto::pubsub::v1::{pub_sub_service_server::PubSubService, Message};
27
28#[derive(Debug)]
29struct MessageProcessorInternal {
30 forwarder: Forwarder<Connection>,
31 drain_channel: drain::Watch,
32}
33
34#[derive(Debug, Clone)]
35pub struct MessageProcessor {
36 internal: Arc<MessageProcessorInternal>,
37}
38
39impl MessageProcessor {
40 pub fn new() -> (Self, drain::Signal) {
41 let (signal, watch) = drain::channel();
42 let forwarder = Forwarder::new();
43 let forwarder = MessageProcessorInternal {
44 forwarder,
45 drain_channel: watch,
46 };
47
48 (
49 Self {
50 internal: Arc::new(forwarder),
51 },
52 signal,
53 )
54 }
55
56 pub fn with_drain_channel(watch: drain::Watch) -> Self {
57 let forwarder = Forwarder::new();
58 let forwarder = MessageProcessorInternal {
59 forwarder,
60 drain_channel: watch,
61 };
62 Self {
63 internal: Arc::new(forwarder),
64 }
65 }
66
67 fn forwarder(&self) -> &Forwarder<Connection> {
68 &self.internal.forwarder
69 }
70
71 fn get_drain_watch(&self) -> drain::Watch {
72 self.internal.drain_channel.clone()
73 }
74
75 pub async fn connect<C>(
76 &self,
77 channel: C,
78 local: Option<SocketAddr>,
79 remote: Option<SocketAddr>,
80 ) -> Result<(tokio::task::JoinHandle<()>, CancellationToken, u64), DataPathError>
81 where
82 C: tonic::client::GrpcService<tonic::body::BoxBody>,
83 C::Error: Into<StdError>,
84 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
85 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
86 {
87 let mut client = PubSubServiceClient::new(channel);
88 let (tx, rx) = mpsc::channel(128);
89 let stream = client
90 .open_channel(Request::new(ReceiverStream::new(rx)))
91 .await
92 .map_err(|e| DataPathError::ConnectionError(e.to_string()))?
93 .into_inner();
94
95 let connection = Connection::new(ConnectionType::Remote)
96 .with_local_addr(local)
97 .with_remote_addr(remote)
98 .with_channel(Channel::Client(tx));
99
100 info!(
101 "new connection initiated locally: (remote: {:?} - local: {:?})",
102 connection.remote_addr(),
103 connection.local_addr()
104 );
105
106 let conn_index = self.forwarder().on_connection_established(connection);
108
109 let ret = self.process_stream(stream, conn_index, false);
111 Ok((ret.0, ret.1, conn_index))
112 }
113
114 pub fn register_local_connection(
115 &self,
116 ) -> (
117 tokio::sync::mpsc::Sender<Result<Message, Status>>,
118 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
119 ) {
120 let (tx1, rx1) = mpsc::channel(128);
122
123 info!("establishing new local app connection");
124
125 let (tx2, rx2) = mpsc::channel(128);
127
128 let connection = Connection::new(ConnectionType::Local).with_channel(Channel::Server(tx2));
130
131 let conn_id = self.forwarder().on_connection_established(connection);
133
134 debug!("local connection established with id: {:?}", conn_id);
135 info!(telemetry = true, counter.num_active_connections = 1);
136
137 self.process_stream(ReceiverStream::new(rx1), conn_id, true);
139
140 (tx1, rx2)
142 }
143
144 pub async fn send_msg(
145 &self,
146 msg: Message,
147 out_conn: u64,
148 ) -> Result<(), Box<dyn std::error::Error>> {
149 let connection = self.forwarder().get_connection(out_conn);
150 match connection {
151 Some(conn) => match conn.channel() {
152 Channel::Server(s) => s.send(Ok(msg)).await?,
153 Channel::Client(s) => s.send(msg).await?,
154 _ => error!("error reading channel"),
155 },
156 None => error!("connection {:?} not found", out_conn),
157 }
158 Ok(())
159 }
160
161 async fn match_and_forward_msg(
162 &self,
163 msg: Message,
164 class: AgentClass,
165 in_connection: u64,
166 fanout: u32,
167 agent_id: Option<u64>,
168 ) -> Result<(), DataPathError> {
169 debug!(
170 "match and forward message: class: {:?} - agent_id: {:?} - fanout: {:?}",
171 class, agent_id, fanout,
172 );
173
174 if fanout == 1 {
175 match self
176 .forwarder()
177 .on_publish_msg_match_one(class, agent_id, in_connection)
178 {
179 Ok(out) => match self.send_msg(msg, out).await {
180 Ok(_) => Ok(()),
181 Err(e) => {
182 error!("error sending a message {:?}", e);
183 Err(DataPathError::PublicationError(e.to_string()))
184 }
185 },
186 Err(e) => {
187 error!("error matching a message {:?}", e);
188 Err(DataPathError::PublicationError(e.to_string()))
189 }
190 }
191 } else {
192 match self
193 .forwarder()
194 .on_publish_msg_match_all(class, agent_id, in_connection)
195 {
196 Ok(out_set) => {
197 for out in out_set {
198 match self.send_msg(msg.clone(), out).await {
199 Ok(_) => {}
200 Err(e) => {
201 error!("error sending a message {:?}", e);
202 return Err(DataPathError::PublicationError(e.to_string()));
203 }
204 }
205 }
206 Ok(())
207 }
208 Err(e) => {
209 error!("error sending a message {:?}", e);
210 Err(DataPathError::PublicationError(e.to_string()))
211 }
212 }
213 }
214 }
215
216 async fn process_publish(
217 &self,
218 mut msg: Message,
219 in_connection: u64,
220 ) -> Result<(), DataPathError> {
221 let pubmsg = match &msg.message_type {
222 Some(PublishType(p)) => p,
223 _ => panic!("wrong message type"),
225 };
226
227 match process_name(&pubmsg.name) {
228 Ok(class) => {
229 let fanout = get_fanout(pubmsg);
230 let agent_id = get_agent_id(&pubmsg.name);
231
232 debug!(
233 "received publication from connection {}: {:?}",
234 in_connection, pubmsg
235 );
236
237 add_incoming_connection(&mut msg, in_connection);
239
240 return self
242 .match_and_forward_msg(msg, class, in_connection, fanout, agent_id)
243 .await;
244 }
245 Err(e) => {
246 error!("error processing publication message {:?}", e);
247 Err(DataPathError::PublicationError(e.to_string()))
248 }
249 }
250 }
251
252 fn process_command(&self, msg: &Message) -> Result<(CommandType, u64), DataPathError> {
253 if !msg.metadata.is_empty() {
254 match msg.metadata.get(&CommandType::ReceivedFrom.to_string()) {
255 None => {}
256 Some(out_str) => match out_str.parse::<u64>() {
257 Err(e) => {
258 error! {"error parsing the connection in command type ReceivedFrom: {:?}", e};
259 return Err(DataPathError::CommandError(e.to_string()));
260 }
261 Ok(out) => {
262 debug!(%out, "received subscription_from command, register subscription");
263 return Ok((CommandType::ReceivedFrom, out));
264 }
265 },
266 }
267 match msg.metadata.get(&CommandType::ForwardTo.to_string()) {
268 None => {}
269 Some(out_str) => match out_str.parse::<u64>() {
270 Err(e) => {
271 error! {"error parsing the connection in command type ForwardTo: {:?}", e};
272 return Err(DataPathError::CommandError(e.to_string()));
273 }
274 Ok(out) => {
275 debug!(%out, "received forward_to command, register subscription and forward");
276 return Ok((CommandType::ForwardTo, out));
277 }
278 },
279 }
280 }
281 Ok((CommandType::Unknown, 0))
282 }
283
284 async fn process_unsubscription(
285 &self,
286 mut msg: Message,
287 in_connection: u64,
288 ) -> Result<(), DataPathError> {
289 let unsubmsg = match &msg.message_type {
290 Some(UnsubscribeType(s)) => s,
291 _ => panic!("wrong message type"),
293 };
294
295 match process_name(&unsubmsg.name) {
296 Ok(class) => {
297 let command = self.process_command(&msg);
299 let mut conn = in_connection;
300 let mut forward = false;
301 let mut out_conn = in_connection;
303 match command {
304 Err(e) => {
305 return Err(e);
306 }
307 Ok(tuple) => match tuple.0 {
308 CommandType::ReceivedFrom => {
309 conn = tuple.1;
310 }
311 CommandType::ForwardTo => {
312 forward = true;
313 out_conn = tuple.1;
314 }
315 _ => {}
316 },
317 }
318 let connection = self.forwarder().get_connection(in_connection);
319 if connection.is_none() {
320 error!("incoming connection does not exists");
322 return Err(DataPathError::SubscriptionError(
323 "incoming connection does not exists".to_string(),
324 ));
325 }
326 match self.forwarder().on_unsubscription_msg(
327 class,
328 get_agent_id(&unsubmsg.name),
329 conn,
330 connection.unwrap().is_local_connection(),
331 ) {
332 Ok(_) => {}
333 Err(e) => {
334 return Err(DataPathError::UnsubscriptionError(e.to_string()));
335 }
336 }
337 if forward {
338 debug!("forward subscription to {:?}", out_conn);
339 msg.metadata.clear();
340 match self.send_msg(msg, out_conn).await {
341 Ok(_) => {}
342 Err(e) => {
343 error!("error sending a message {:?}", e);
344 return Err(DataPathError::SubscriptionError(e.to_string()));
345 }
346 };
347 }
348 Ok(())
349 }
350 Err(e) => {
351 error!("error processing unsubscription message {:?}", e);
352 Err(DataPathError::UnsubscriptionError(e.to_string()))
353 }
354 }
355 }
356
357 async fn process_subscription(
358 &self,
359 mut msg: Message,
360 in_connection: u64,
361 ) -> Result<(), DataPathError> {
362 let submsg = match &msg.message_type {
363 Some(SubscribeType(s)) => s,
364 _ => panic!("wrong message type"),
366 };
367
368 debug!(
369 "received subscription from connection {}: {:?}",
370 in_connection, submsg
371 );
372
373 match process_name(&submsg.name) {
374 Ok(class) => {
375 trace!("process command");
377 let command = self.process_command(&msg);
378 let mut conn = in_connection;
379 let mut forward = false;
380
381 let mut out_conn = in_connection;
383 match command {
384 Err(e) => {
385 return Err(e);
386 }
387 Ok(tuple) => match tuple.0 {
388 CommandType::ReceivedFrom => {
389 conn = tuple.1;
390 trace!("received subscription_from command, register subscription with conn id {:?}", tuple.1);
391 }
392 CommandType::ForwardTo => {
393 forward = true;
394 out_conn = tuple.1;
395 trace!("received forward_to command, register subscription and forward to conn id {:?}", out_conn);
396 }
397 _ => {}
398 },
399 }
400 let connection = self.forwarder().get_connection(in_connection);
401 if connection.is_none() {
402 error!("incoming connection does not exists");
404 return Err(DataPathError::SubscriptionError(
405 "incoming connection does not exists".to_string(),
406 ));
407 }
408 match self.forwarder().on_subscription_msg(
409 class,
410 get_agent_id(&submsg.name),
411 conn,
412 connection.unwrap().is_local_connection(),
413 ) {
414 Ok(_) => {}
415 Err(e) => {
416 return Err(DataPathError::SubscriptionError(e.to_string()));
417 }
418 }
419
420 if forward {
421 debug!("forward subscription {:?} to {:?}", msg, out_conn);
422 msg.metadata.clear();
423 match self.send_msg(msg, out_conn).await {
424 Ok(_) => {}
425 Err(e) => {
426 error!("error sending a message {:?}", e);
427 return Err(DataPathError::SubscriptionError(e.to_string()));
428 }
429 };
430 }
431 Ok(())
432 }
433 Err(e) => {
434 error!("error processing subscription message {:?}", e);
435 Err(DataPathError::SubscriptionError(e.to_string()))
436 }
437 }
438 }
439
440 pub async fn process_message(
441 &self,
442 msg: Message,
443 in_connection: u64,
444 ) -> Result<(), DataPathError> {
445 match &msg.message_type {
446 None => {
447 error!(
448 "received message without message type from connection {}: {:?}",
449 in_connection, msg
450 );
451 info!(
452 telemetry = true,
453 monotonic_counter.num_messages_by_type = 1,
454 message_type = "none"
455 );
456 Err(DataPathError::UnknownMsgType("".to_string()))
457 }
458 Some(msg_type) => match msg_type {
459 SubscribeType(s) => {
460 debug!(
461 "received subscription from connection {}: {:?}",
462 in_connection, s
463 );
464 info!(
465 telemetry = true,
466 monotonic_counter.num_messages_by_type = 1,
467 message_type = "subscribe"
468 );
469 match self.process_subscription(msg, in_connection).await {
470 Err(e) => {
471 error! {"error processing subscription {:?}", e}
472 Err(e)
473 }
474 Ok(_) => Ok(()),
475 }
476 }
477 UnsubscribeType(u) => {
478 debug!(
479 "Received ubsubscription from client {}: {:?}",
480 in_connection, u
481 );
482 info!(
483 telemetry = true,
484 monotonic_counter.num_messages_by_type = 1,
485 message_type = "unsubscribe"
486 );
487 match self.process_unsubscription(msg, in_connection).await {
488 Err(e) => {
489 error! {"error processing unsubscription {:?}", e}
490 Err(e)
491 }
492 Ok(_) => Ok(()),
493 }
494 }
495 PublishType(p) => {
496 debug!("Received publish from client {}: {:?}", in_connection, p);
497 info!(
498 telemetry = true,
499 monotonic_counter.num_messages_by_type = 1,
500 method = "publish"
501 );
502 match self.process_publish(msg, in_connection).await {
503 Err(e) => {
504 error! {"error processing publication {:?}", e}
505 Err(e)
506 }
507 Ok(_) => Ok(()),
508 }
509 }
510 },
511 }
512 }
513
514 async fn handle_new_message(
515 &self,
516 conn_index: u64,
517 result: Result<Message, Status>,
518 ) -> Result<(), DataPathError> {
519 debug!(%conn_index, "Received message from connection");
520 info!(
521 telemetry = true,
522 monotonic_counter.num_processed_messages = 1
523 );
524
525 match result {
526 Ok(msg) => {
527 match self.process_message(msg, conn_index).await {
528 Ok(_) => Ok(()),
529 Err(e) => {
530 error!(
532 "error processing message from connection {:?}: {:?}",
533 conn_index, e
534 );
535 info!(
536 telemetry = true,
537 monotonic_counter.num_message_process_errors = 1
538 );
539 Ok(())
540 }
541 }
542 }
543 Err(e) => {
544 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
545 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
546 info!("Connection {:?} closed by peer", conn_index);
547 return Err(DataPathError::StreamError(e.to_string()));
548 }
549 }
550 error!("error receiving messages {:?}", e);
551 let connection = self.forwarder().get_connection(conn_index);
552 match connection {
553 Some(conn) => {
554 match conn.channel() {
555 Channel::Server(tx) => tx
556 .send(Err(e))
557 .await
558 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
559 _ => Err(DataPathError::WrongChannelType), }
561 }
562 None => {
563 error!("connection {:?} not found", conn_index);
564 Err(DataPathError::ConnectionNotFound(conn_index.to_string()))
565 }
566 }
567 }
568 }
569 }
570
571 #[tracing::instrument(fields(telemetry = true), skip(stream))]
572 fn process_stream(
573 &self,
574 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
575 conn_index: u64,
576 is_local: bool,
577 ) -> (tokio::task::JoinHandle<()>, CancellationToken) {
578 let self_clone = self.clone();
580 let token = CancellationToken::new();
581 let token_clone = token.clone();
582 let handle = tokio::spawn(async move {
583 loop {
584 tokio::select! {
585 res = stream.next() => {
586 match res {
587 Some(msg) => {
588 if let Err(e) = self_clone.handle_new_message(conn_index, msg).await {
589 error!("error handling stream {:?}", e);
590 break;
591 }
592 }
593 None => {
594 info!(%conn_index, "end of stream");
595 break;
596 }
597 }
598 }
599 _ = self_clone.get_drain_watch().signaled() => {
600 info!("shutting down stream on drain: {}", conn_index);
601 break;
602 }
603 _ = token_clone.cancelled() => {
604 info!("shutting down stream cancellation token: {}", conn_index);
605 break;
606 }
607 }
608 }
609
610 info!(telemetry = true, counter.num_active_connections = -1);
611
612 self_clone
613 .forwarder()
614 .on_connection_drop(conn_index, is_local);
615 });
616
617 (handle, token)
618 }
619
620 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
621 let mut err: &(dyn std::error::Error + 'static) = err_status;
622
623 loop {
624 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
625 return Some(io_err);
626 }
627
628 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
631 if let Some(io_err) = h2_err.get_io() {
632 return Some(io_err);
633 }
634 }
635
636 err = err.source()?;
637 }
638 }
639}
640
641#[tonic::async_trait]
642impl PubSubService for MessageProcessor {
643 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
644
645 #[tracing::instrument(fields(telemetry = true))]
646 async fn open_channel(
647 &self,
648 request: Request<tonic::Streaming<Message>>,
649 ) -> Result<Response<Self::OpenChannelStream>, Status> {
650 let remote_addr = request.remote_addr();
651 let local_addr = request.local_addr();
652
653 let stream = request.into_inner();
654 let (tx, rx) = mpsc::channel(128);
655
656 let connection = Connection::new(ConnectionType::Remote)
657 .with_remote_addr(remote_addr)
658 .with_local_addr(local_addr)
659 .with_channel(Channel::Server(tx));
660
661 info!(
662 "new connection received from remote: (remote: {:?} - local: {:?})",
663 connection.remote_addr(),
664 connection.local_addr()
665 );
666 info!(telemetry = true, counter.num_active_connections = 1);
667
668 let conn_index = self.forwarder().on_connection_established(connection);
670
671 self.process_stream(stream, conn_index, false);
672
673 let out_stream = ReceiverStream::new(rx);
674 Ok(Response::new(
675 Box::pin(out_stream) as Self::OpenChannelStream
676 ))
677 }
678}