1use std::net::SocketAddr;
5use std::{pin::Pin, sync::Arc};
6
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::{Stream, StreamExt};
10use tokio_util::sync::CancellationToken;
11use tonic::codegen::{Body, StdError};
12use tonic::{Request, Response, Status};
13use tracing::{debug, error, info, trace};
14
15use crate::connection::{Channel, Connection, Type as ConnectionType};
16use crate::errors::DataPathError;
17use crate::forwarder::Forwarder;
18use crate::messages::utils::{
19 add_incoming_connection, get_agent_id, get_fanout, process_name, CommandType,
20};
21use crate::messages::AgentClass;
22use crate::pubsub::proto::pubsub::v1::message::MessageType::Publish as PublishType;
23use crate::pubsub::proto::pubsub::v1::message::MessageType::Subscribe as SubscribeType;
24use crate::pubsub::proto::pubsub::v1::message::MessageType::Unsubscribe as UnsubscribeType;
25use crate::pubsub::proto::pubsub::v1::pub_sub_service_client::PubSubServiceClient;
26use crate::pubsub::proto::pubsub::v1::{pub_sub_service_server::PubSubService, Message};
27
28#[derive(Debug)]
29struct MessageProcessorInternal {
30 forwarder: Forwarder<Connection>,
31 drain_channel: drain::Watch,
32}
33
34#[derive(Debug, Clone)]
35pub struct MessageProcessor {
36 internal: Arc<MessageProcessorInternal>,
37}
38
39impl MessageProcessor {
40 pub fn new() -> (Self, drain::Signal) {
41 let (signal, watch) = drain::channel();
42 let forwarder = Forwarder::new();
43 let forwarder = MessageProcessorInternal {
44 forwarder,
45 drain_channel: watch,
46 };
47
48 (
49 Self {
50 internal: Arc::new(forwarder),
51 },
52 signal,
53 )
54 }
55
56 pub fn with_drain_channel(watch: drain::Watch) -> Self {
57 let forwarder = Forwarder::new();
58 let forwarder = MessageProcessorInternal {
59 forwarder,
60 drain_channel: watch,
61 };
62 Self {
63 internal: Arc::new(forwarder),
64 }
65 }
66
67 fn forwarder(&self) -> &Forwarder<Connection> {
68 &self.internal.forwarder
69 }
70
71 fn get_drain_watch(&self) -> drain::Watch {
72 self.internal.drain_channel.clone()
73 }
74
75 pub async fn connect<C>(
76 &self,
77 channel: C,
78 local: Option<SocketAddr>,
79 remote: Option<SocketAddr>,
80 ) -> Result<(tokio::task::JoinHandle<()>, CancellationToken, u64), DataPathError>
81 where
82 C: tonic::client::GrpcService<tonic::body::BoxBody>,
83 C::Error: Into<StdError>,
84 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
85 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
86 {
87 let mut client = PubSubServiceClient::new(channel);
88 let (tx, rx) = mpsc::channel(128);
89 let stream = client
90 .open_channel(Request::new(ReceiverStream::new(rx)))
91 .await
92 .map_err(|e| DataPathError::ConnectionError(e.to_string()))?
93 .into_inner();
94
95 let connection = Connection::new(ConnectionType::Remote)
96 .with_local_addr(local)
97 .with_remote_addr(remote)
98 .with_channel(Channel::Client(tx));
99
100 info!(
101 "new connection initiated locally: (remote: {:?} - local: {:?})",
102 connection.remote_addr(),
103 connection.local_addr()
104 );
105
106 let conn_index = self.forwarder().on_connection_established(connection);
108
109 let ret = self.process_stream(stream, conn_index, false);
111 Ok((ret.0, ret.1, conn_index))
112 }
113
114 pub fn register_local_connection(
115 &self,
116 ) -> (
117 tokio::sync::mpsc::Sender<Result<Message, Status>>,
118 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
119 ) {
120 let (tx1, rx1) = mpsc::channel(128);
122
123 info!("establishing new local app connection");
124
125 let (tx2, rx2) = mpsc::channel(128);
127
128 let connection = Connection::new(ConnectionType::Local).with_channel(Channel::Server(tx2));
130
131 let conn_id = self.forwarder().on_connection_established(connection);
133
134 debug!("local connection established with id: {:?}", conn_id);
135
136 self.process_stream(ReceiverStream::new(rx1), conn_id, true);
138
139 (tx1, rx2)
141 }
142
143 pub async fn send_msg(
144 &self,
145 msg: Message,
146 out_conn: u64,
147 ) -> Result<(), Box<dyn std::error::Error>> {
148 let connection = self.forwarder().get_connection(out_conn);
149 match connection {
150 Some(conn) => match conn.channel() {
151 Channel::Server(s) => s.send(Ok(msg)).await?,
152 Channel::Client(s) => s.send(msg).await?,
153 _ => error!("error reading channel"),
154 },
155 None => error!("connection {:?} not found", out_conn),
156 }
157 Ok(())
158 }
159
160 async fn match_and_forward_msg(
161 &self,
162 msg: Message,
163 class: AgentClass,
164 in_connection: u64,
165 fanout: u32,
166 agent_id: Option<u64>,
167 ) -> Result<(), DataPathError> {
168 debug!(
169 "match and forward message: class: {:?} - agent_id: {:?} - fanout: {:?}",
170 class, agent_id, fanout,
171 );
172
173 if fanout == 1 {
174 match self
175 .forwarder()
176 .on_publish_msg_match_one(class, agent_id, in_connection)
177 {
178 Ok(out) => match self.send_msg(msg, out).await {
179 Ok(_) => Ok(()),
180 Err(e) => {
181 error!("error sending a message {:?}", e);
182 Err(DataPathError::PublicationError(e.to_string()))
183 }
184 },
185 Err(e) => {
186 error!("error matching a message {:?}", e);
187 Err(DataPathError::PublicationError(e.to_string()))
188 }
189 }
190 } else {
191 match self
192 .forwarder()
193 .on_publish_msg_match_all(class, agent_id, in_connection)
194 {
195 Ok(out_set) => {
196 for out in out_set {
197 match self.send_msg(msg.clone(), out).await {
198 Ok(_) => {}
199 Err(e) => {
200 error!("error sending a message {:?}", e);
201 return Err(DataPathError::PublicationError(e.to_string()));
202 }
203 }
204 }
205 Ok(())
206 }
207 Err(e) => {
208 error!("error sending a message {:?}", e);
209 Err(DataPathError::PublicationError(e.to_string()))
210 }
211 }
212 }
213 }
214
215 async fn process_publish(
216 &self,
217 mut msg: Message,
218 in_connection: u64,
219 ) -> Result<(), DataPathError> {
220 let pubmsg = match &msg.message_type {
221 Some(PublishType(p)) => p,
222 _ => panic!("wrong message type"),
224 };
225
226 match process_name(&pubmsg.name) {
227 Ok(class) => {
228 let fanout = get_fanout(pubmsg);
229 let agent_id = get_agent_id(&pubmsg.name);
230
231 debug!(
232 "received publication from connection {}: {:?}",
233 in_connection, pubmsg
234 );
235
236 add_incoming_connection(&mut msg, in_connection);
238
239 return self
241 .match_and_forward_msg(msg, class, in_connection, fanout, agent_id)
242 .await;
243 }
244 Err(e) => {
245 error!("error processing publication message {:?}", e);
246 Err(DataPathError::PublicationError(e.to_string()))
247 }
248 }
249 }
250
251 fn process_command(&self, msg: &Message) -> Result<(CommandType, u64), DataPathError> {
252 if !msg.metadata.is_empty() {
253 match msg.metadata.get(&CommandType::ReceivedFrom.to_string()) {
254 None => {}
255 Some(out_str) => match out_str.parse::<u64>() {
256 Err(e) => {
257 error! {"error parsing the connection in command type ReceivedFrom: {:?}", e};
258 return Err(DataPathError::CommandError(e.to_string()));
259 }
260 Ok(out) => {
261 debug!(%out, "received subscription_from command, register subscription");
262 return Ok((CommandType::ReceivedFrom, out));
263 }
264 },
265 }
266 match msg.metadata.get(&CommandType::ForwardTo.to_string()) {
267 None => {}
268 Some(out_str) => match out_str.parse::<u64>() {
269 Err(e) => {
270 error! {"error parsing the connection in command type ForwardTo: {:?}", e};
271 return Err(DataPathError::CommandError(e.to_string()));
272 }
273 Ok(out) => {
274 debug!(%out, "received forward_to command, register subscription and forward");
275 return Ok((CommandType::ForwardTo, out));
276 }
277 },
278 }
279 }
280 Ok((CommandType::Unknown, 0))
281 }
282
283 async fn process_unsubscription(
284 &self,
285 mut msg: Message,
286 in_connection: u64,
287 ) -> Result<(), DataPathError> {
288 let unsubmsg = match &msg.message_type {
289 Some(UnsubscribeType(s)) => s,
290 _ => panic!("wrong message type"),
292 };
293
294 match process_name(&unsubmsg.name) {
295 Ok(class) => {
296 let command = self.process_command(&msg);
298 let mut conn = in_connection;
299 let mut forward = false;
300 let mut out_conn = in_connection;
302 match command {
303 Err(e) => {
304 return Err(e);
305 }
306 Ok(tuple) => match tuple.0 {
307 CommandType::ReceivedFrom => {
308 conn = tuple.1;
309 }
310 CommandType::ForwardTo => {
311 forward = true;
312 out_conn = tuple.1;
313 }
314 _ => {}
315 },
316 }
317 let connection = self.forwarder().get_connection(in_connection);
318 if connection.is_none() {
319 error!("incoming connection does not exists");
321 return Err(DataPathError::SubscriptionError(
322 "incoming connection does not exists".to_string(),
323 ));
324 }
325 match self.forwarder().on_unsubscription_msg(
326 class,
327 get_agent_id(&unsubmsg.name),
328 conn,
329 connection.unwrap().is_local_connection(),
330 ) {
331 Ok(_) => {}
332 Err(e) => {
333 return Err(DataPathError::UnsubscriptionError(e.to_string()));
334 }
335 }
336 if forward {
337 debug!("forward subscription to {:?}", out_conn);
338 msg.metadata.clear();
339 match self.send_msg(msg, out_conn).await {
340 Ok(_) => {}
341 Err(e) => {
342 error!("error sending a message {:?}", e);
343 return Err(DataPathError::SubscriptionError(e.to_string()));
344 }
345 };
346 }
347 Ok(())
348 }
349 Err(e) => {
350 error!("error processing unsubscription message {:?}", e);
351 Err(DataPathError::UnsubscriptionError(e.to_string()))
352 }
353 }
354 }
355
356 async fn process_subscription(
357 &self,
358 mut msg: Message,
359 in_connection: u64,
360 ) -> Result<(), DataPathError> {
361 let submsg = match &msg.message_type {
362 Some(SubscribeType(s)) => s,
363 _ => panic!("wrong message type"),
365 };
366
367 debug!(
368 "received subscription from connection {}: {:?}",
369 in_connection, submsg
370 );
371
372 match process_name(&submsg.name) {
373 Ok(class) => {
374 trace!("process command");
376 let command = self.process_command(&msg);
377 let mut conn = in_connection;
378 let mut forward = false;
379
380 let mut out_conn = in_connection;
382 match command {
383 Err(e) => {
384 return Err(e);
385 }
386 Ok(tuple) => match tuple.0 {
387 CommandType::ReceivedFrom => {
388 conn = tuple.1;
389 trace!("received subscription_from command, register subscription with conn id {:?}", tuple.1);
390 }
391 CommandType::ForwardTo => {
392 forward = true;
393 out_conn = tuple.1;
394 trace!("received forward_to command, register subscription and forward to conn id {:?}", out_conn);
395 }
396 _ => {}
397 },
398 }
399 let connection = self.forwarder().get_connection(in_connection);
400 if connection.is_none() {
401 error!("incoming connection does not exists");
403 return Err(DataPathError::SubscriptionError(
404 "incoming connection does not exists".to_string(),
405 ));
406 }
407 match self.forwarder().on_subscription_msg(
408 class,
409 get_agent_id(&submsg.name),
410 conn,
411 connection.unwrap().is_local_connection(),
412 ) {
413 Ok(_) => {}
414 Err(e) => {
415 return Err(DataPathError::SubscriptionError(e.to_string()));
416 }
417 }
418
419 if forward {
420 debug!("forward subscription {:?} to {:?}", msg, out_conn);
421 msg.metadata.clear();
422 match self.send_msg(msg, out_conn).await {
423 Ok(_) => {}
424 Err(e) => {
425 error!("error sending a message {:?}", e);
426 return Err(DataPathError::SubscriptionError(e.to_string()));
427 }
428 };
429 }
430 Ok(())
431 }
432 Err(e) => {
433 error!("error processing subscription message {:?}", e);
434 Err(DataPathError::SubscriptionError(e.to_string()))
435 }
436 }
437 }
438
439 pub async fn process_message(
440 &self,
441 msg: Message,
442 in_connection: u64,
443 ) -> Result<(), DataPathError> {
444 match &msg.message_type {
445 None => {
446 error!(
447 "received message without message type from connection {}: {:?}",
448 in_connection, msg
449 );
450 Err(DataPathError::UnknownMsgType("".to_string()))
451 }
452 Some(msg_type) => match msg_type {
453 SubscribeType(s) => {
454 debug!(
455 "received subscription from connection {}: {:?}",
456 in_connection, s
457 );
458 match self.process_subscription(msg, in_connection).await {
459 Err(e) => {
460 error! {"error processing subscription {:?}", e}
461 Err(e)
462 }
463 Ok(_) => Ok(()),
464 }
465 }
466 UnsubscribeType(u) => {
467 debug!(
468 "Received ubsubscription from client {}: {:?}",
469 in_connection, u
470 );
471 match self.process_unsubscription(msg, in_connection).await {
472 Err(e) => {
473 error! {"error processing unsubscription {:?}", e}
474 Err(e)
475 }
476 Ok(_) => Ok(()),
477 }
478 }
479 PublishType(p) => {
480 debug!("Received publish from client {}: {:?}", in_connection, p);
481 match self.process_publish(msg, in_connection).await {
482 Err(e) => {
483 error! {"error processing publication {:?}", e}
484 Err(e)
485 }
486 Ok(_) => Ok(()),
487 }
488 }
489 },
490 }
491 }
492
493 async fn handle_new_message(
494 &self,
495 conn_index: u64,
496 result: Result<Message, Status>,
497 ) -> Result<(), DataPathError> {
498 debug!(%conn_index, "Received message from connection");
499
500 match result {
501 Ok(msg) => {
502 match self.process_message(msg, conn_index).await {
503 Ok(_) => Ok(()),
504 Err(e) => {
505 error!(
507 "error processing message from connection {:?}: {:?}",
508 conn_index, e
509 );
510 Ok(())
511 }
512 }
513 }
514 Err(e) => {
515 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
516 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
517 info!("Connection {:?} closed by peer", conn_index);
518 return Err(DataPathError::StreamError(e.to_string()));
519 }
520 }
521 error!("error receiving messages {:?}", e);
522 let connection = self.forwarder().get_connection(conn_index);
523 match connection {
524 Some(conn) => {
525 match conn.channel() {
526 Channel::Server(tx) => tx
527 .send(Err(e))
528 .await
529 .map_err(|e| DataPathError::MessageSendError(e.to_string())),
530 _ => Err(DataPathError::WrongChannelType), }
532 }
533 None => {
534 error!("connection {:?} not found", conn_index);
535 Err(DataPathError::ConnectionNotFound(conn_index.to_string()))
536 }
537 }
538 }
539 }
540 }
541
542 fn process_stream(
543 &self,
544 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
545 conn_index: u64,
546 is_local: bool,
547 ) -> (tokio::task::JoinHandle<()>, CancellationToken) {
548 let self_clone = self.clone();
550 let token = CancellationToken::new();
551 let token_clone = token.clone();
552 let handle = tokio::spawn(async move {
553 loop {
554 tokio::select! {
555 res = stream.next() => {
556 match res {
557 Some(msg) => {
558 if let Err(e) = self_clone.handle_new_message(conn_index, msg).await {
559 error!("error handling stream {:?}", e);
560 break;
561 }
562 }
563 None => {
564 info!(%conn_index, "end of stream");
565 break;
566 }
567 }
568 }
569 _ = self_clone.get_drain_watch().signaled() => {
570 info!("shutting down stream on drain: {}", conn_index);
571 break;
572 }
573 _ = token_clone.cancelled() => {
574 info!("shutting down stream cancellation token: {}", conn_index);
575 break;
576 }
577 }
578 }
579
580 self_clone
581 .forwarder()
582 .on_connection_drop(conn_index, is_local);
583 });
584
585 (handle, token)
586 }
587
588 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
589 let mut err: &(dyn std::error::Error + 'static) = err_status;
590
591 loop {
592 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
593 return Some(io_err);
594 }
595
596 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
599 if let Some(io_err) = h2_err.get_io() {
600 return Some(io_err);
601 }
602 }
603
604 err = err.source()?;
605 }
606 }
607}
608
609#[tonic::async_trait]
610impl PubSubService for MessageProcessor {
611 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
612
613 async fn open_channel(
614 &self,
615 request: Request<tonic::Streaming<Message>>,
616 ) -> Result<Response<Self::OpenChannelStream>, Status> {
617 let remote_addr = request.remote_addr();
618 let local_addr = request.local_addr();
619
620 let stream = request.into_inner();
621 let (tx, rx) = mpsc::channel(128);
622
623 let connection = Connection::new(ConnectionType::Remote)
624 .with_remote_addr(remote_addr)
625 .with_local_addr(local_addr)
626 .with_channel(Channel::Server(tx));
627
628 info!(
629 "new connection received from remote: (remote: {:?} - local: {:?})",
630 connection.remote_addr(),
631 connection.local_addr()
632 );
633
634 let conn_index = self.forwarder().on_connection_established(connection);
636
637 self.process_stream(stream, conn_index, false);
638
639 let out_stream = ReceiverStream::new(rx);
640 Ok(Response::new(
641 Box::pin(out_stream) as Self::OpenChannelStream
642 ))
643 }
644}