slim_controller/
service.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use slim_config::component::id::ID;
9use slim_config::grpc::server::ServerConfig;
10use tokio::sync::mpsc;
11use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
12use tokio_util::sync::CancellationToken;
13use tonic::{Request, Response, Status};
14use tracing::{debug, error, info};
15
16use crate::api::proto::api::v1::control_message::Payload;
17use crate::api::proto::api::v1::controller_service_server::ControllerServiceServer;
18use crate::api::proto::api::v1::{
19    self, ConnectionListResponse, ConnectionType, SubscriptionListResponse,
20};
21use crate::api::proto::api::v1::{
22    Ack, ConnectionEntry, ControlMessage, SubscriptionEntry,
23    controller_service_client::ControllerServiceClient,
24    controller_service_server::ControllerService as GrpcControllerService,
25};
26use crate::errors::ControllerError;
27use slim_config::grpc::client::ClientConfig;
28use slim_datapath::api::ProtoMessage as PubsubMessage;
29use slim_datapath::message_processing::MessageProcessor;
30use slim_datapath::messages::Name;
31use slim_datapath::messages::utils::SlimHeaderFlags;
32use slim_datapath::tables::SubscriptionTable;
33
/// Sender half of a channel carrying control messages (or a gRPC `Status` error).
type TxChannel = mpsc::Sender<Result<ControlMessage, Status>>;
/// Control-message senders indexed by a string key; in this file the key is the
/// endpoint of the client the sender belongs to.
type TxChannels = HashMap<String, TxChannel>;
36
/// Inner structure for the controller service.
/// This structure holds the internal state of the controller service,
/// including the ID, message processor, connections, and channels.
/// It is normally wrapped in an Arc to allow shared ownership across multiple threads.
#[derive(Debug)]
struct ControllerServiceInternal {
    /// ID of this SLIM instance
    id: ID,

    /// underlying message processor
    message_processor: Arc<MessageProcessor>,

    /// map from a connection key (the client endpoint the connection was created
    /// for) to the connection ID returned by the message processor
    connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,

    /// channel to send messages into the datapath
    tx_slim: mpsc::Sender<Result<PubsubMessage, Status>>,

    /// channels to send control messages, keyed by client endpoint
    tx_channels: parking_lot::RwLock<TxChannels>,

    /// cancellation tokens of the running listeners, keyed by endpoint
    /// (the data plane listener is stored under the "DATA_PLANE" key)
    cancellation_tokens: parking_lot::RwLock<HashMap<String, CancellationToken>>,

    /// drain watch channel
    drain_rx: drain::Watch,
}
64
/// Clonable handle on the controller state.
/// Cloning only clones the inner `Arc`, so every clone shares the same
/// `ControllerServiceInternal`.
#[derive(Debug, Clone)]
struct ControllerService {
    /// internal service state
    inner: Arc<ControllerServiceInternal>,
}
70
/// The ControlPlane service is the main entry point for the controller service.
/// It owns the server and client configurations and the shared controller state.
#[derive(Debug)]
pub struct ControlPlane {
    /// server configurations started by `run`
    servers: Vec<ServerConfig>,

    /// client configurations started by `run`
    clients: Vec<ClientConfig>,

    /// controller
    controller: ControllerService,

    /// channel to receive messages from the datapath,
    /// to be used in listen_from_data_plane; it is `None` once `run` has taken it
    rx_slim_option: Option<mpsc::Receiver<Result<PubsubMessage, Status>>>,
}
87
88/// ControllerServiceInternal implements Drop trait to cancel all running listeners and
89/// clean up resources.
90impl Drop for ControlPlane {
91    fn drop(&mut self) {
92        // cancel all running listeners
93        for (_endpoint, token) in self.controller.inner.cancellation_tokens.write().drain() {
94            token.cancel();
95        }
96    }
97}
98
99/// ControlPlane implements the service trait for the controller service.
100impl ControlPlane {
101    /// Create a new ControlPlane service instance
102    /// This function initializes the ControlPlane with the given ID, servers, clients, and message processor.
103    /// It also sets up the internal state, including the connections and channels.
104    /// # Arguments
105    /// * `id` - The ID of the SLIM instance.
106    /// * `servers` - A vector of server configurations.
107    /// * `clients` - A vector of client configurations.
108    /// * `drain_rx` - A drain watch channel for graceful shutdown.
109    /// * `message_processor` - An Arc to the message processor instance.
110    /// # Returns
111    /// A new instance of ControlPlane.
112    pub fn new(
113        id: ID,
114        servers: Vec<ServerConfig>,
115        clients: Vec<ClientConfig>,
116        drain_rx: drain::Watch,
117        message_processor: Arc<MessageProcessor>,
118    ) -> Self {
119        let (_, tx_slim, rx_slim) = message_processor.register_local_connection(true);
120
121        ControlPlane {
122            servers,
123            clients,
124            controller: ControllerService {
125                inner: Arc::new(ControllerServiceInternal {
126                    id,
127                    message_processor,
128                    connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
129                    tx_slim,
130                    tx_channels: parking_lot::RwLock::new(HashMap::new()),
131                    cancellation_tokens: parking_lot::RwLock::new(HashMap::new()),
132                    drain_rx,
133                }),
134            },
135            rx_slim_option: Some(rx_slim),
136        }
137    }
138
    /// Builder-style setter: consume the ControlPlane, replace its client
    /// configurations with `clients`, and return it.
    pub fn with_clients(mut self, clients: Vec<ClientConfig>) -> Self {
        self.clients = clients;
        self
    }
144
    /// Builder-style setter: consume the ControlPlane, replace its server
    /// configurations with `servers`, and return it.
    pub fn with_servers(mut self, servers: Vec<ServerConfig>) -> Self {
        self.servers = servers;
        self
    }
150
151    /// Run the clients and servers of the ControlPlane service.
152    /// This function starts all the servers and clients defined in the ControlPlane.
153    /// # Returns
154    /// A Result indicating success or failure of the operation.
155    /// # Errors
156    /// If there is an error starting any of the servers or clients, it will return a ControllerError.
157    pub async fn run(&mut self) -> Result<(), ControllerError> {
158        info!("starting controller service");
159
160        // Collect servers to avoid borrowing self both mutably and immutably
161        let servers = self.servers.clone();
162        let clients = self.clients.clone();
163
164        // run all servers
165        for server in servers {
166            self.run_server(server)?;
167        }
168
169        // run all clients
170        for client in clients {
171            self.run_client(client).await?;
172        }
173
174        let rx = self.rx_slim_option.take();
175        self.listen_from_data_plane(rx.unwrap()).await;
176
177        Ok(())
178    }
179
180    async fn listen_from_data_plane(
181        &mut self,
182        mut rx: mpsc::Receiver<Result<PubsubMessage, Status>>,
183    ) {
184        let cancellation_token = CancellationToken::new();
185        let cancellation_token_clone = cancellation_token.clone();
186        let drain = self.controller.inner.drain_rx.clone();
187
188        self.controller
189            .inner
190            .cancellation_tokens
191            .write()
192            .insert("DATA_PLANE".to_string(), cancellation_token_clone);
193
194        let clients = self.clients.clone();
195        let inner = self.controller.inner.clone();
196
197        tokio::spawn(async move {
198            loop {
199                tokio::select! {
200                    next = rx.recv() => {
201                        match next {
202                            Some(res) => {
203                                match res {
204                                    Ok(msg) => {
205                                        debug!("Send sub/unsub to control plane for message: {:?}", msg);
206
207                                        let mut sub_vec = vec![];
208                                        let mut unsub_vec = vec![];
209
210                                        let dst = msg.get_dst();
211                                        let components = dst.components_strings().unwrap();
212                                        let cmd = v1::Subscription {
213                                                    component_0: components[0].to_string(),
214                                                    component_1: components[1].to_string(),
215                                                    component_2: components[2].to_string(),
216                                                    id: Some(dst.id()),
217                                                    connection_id: "n/a".to_string(),
218                                        };
219                                        match msg.get_type() {
220                                            slim_datapath::api::MessageType::Subscribe(_) => {
221                                                sub_vec.push(cmd);
222                                            },
223                                            slim_datapath::api::MessageType::Unsubscribe(_) => {
224                                                unsub_vec.push(cmd);
225                                            }
226                                            slim_datapath::api::MessageType::Publish(_) => {
227                                                // drop publication messages
228                                                continue;
229                                            },
230                                        }
231
232                                        let ctrl = ControlMessage {
233                                            message_id: uuid::Uuid::new_v4().to_string(),
234                                            payload: Some(Payload::ConfigCommand(
235                                                v1::ConfigurationCommand {
236                                                    connections_to_create: vec![],
237                                                    subscriptions_to_set: sub_vec,
238                                                    subscriptions_to_delete: unsub_vec
239                                                })),
240                                        };
241
242                                        for c in &clients {
243                                            let tx = match inner.tx_channels.read().get(&c.endpoint) {
244                                                Some(tx) => tx.clone(),
245                                                None => continue,
246                                            };
247                                            if (tx.send(Ok(ctrl.clone())).await).is_err() {
248                                                error!("error while notifiyng the control plane");
249                                            };
250
251                                        }
252                                    }
253                                    Err(e) => {
254                                        error!("received error from the data plane {}", e.to_string());
255                                        continue;
256                                    }
257                                }
258                            }
259                            None => {
260                                debug!("Data plane receiver channel closed.");
261                                break;
262                            }
263                        }
264                    }
265                    _ = cancellation_token.cancelled() => {
266                        debug!("shutting down stream on cancellation token");
267                        break;
268                    }
269                    _ = drain.clone().signaled() => {
270                        debug!("shutting down stream on drain");
271                        break;
272                    }
273                }
274            }
275        });
276    }
277
278    /// Stop the ControlPlane service.
279    /// This function stops all running listeners and cancels any ongoing operations.
280    /// It cleans up the internal state and ensures that all resources are released properly.
281    pub fn stop(&mut self) {
282        info!("stopping controller service");
283
284        // cancel all running listeners
285        for (endpoint, token) in self.controller.inner.cancellation_tokens.write().drain() {
286            info!(%endpoint, "stopping");
287            token.cancel();
288        }
289    }
290
291    /// Run a client configuration.
292    /// This function connects to the control plane using the provided client configuration.
293    /// It checks if the client is already running and if not, it starts a new connection.
294    async fn run_client(&mut self, client: ClientConfig) -> Result<(), ControllerError> {
295        if self
296            .controller
297            .inner
298            .cancellation_tokens
299            .read()
300            .contains_key(&client.endpoint)
301        {
302            return Err(ControllerError::ConfigError(format!(
303                "client {} is already running",
304                client.endpoint
305            )));
306        }
307
308        let cancellation_token = CancellationToken::new();
309
310        let tx = self
311            .controller
312            .connect(client.clone(), cancellation_token.clone())
313            .await?;
314
315        // Store the cancellation token in the controller service
316        self.controller
317            .inner
318            .cancellation_tokens
319            .write()
320            .insert(client.endpoint.clone(), cancellation_token);
321
322        // Store the sender in the tx_channels map
323        self.controller
324            .inner
325            .tx_channels
326            .write()
327            .insert(client.endpoint.clone(), tx);
328
329        // return the sender for control messages
330        Ok(())
331    }
332
333    /// Run a server configuration.
334    /// This function starts a server using the provided server configuration.
335    /// It checks if the server is already running and if not, it starts a new server.
336    pub fn run_server(&mut self, config: ServerConfig) -> Result<(), ControllerError> {
337        info!(%config.endpoint, "starting control plane server");
338
339        // Check if the server is already running
340        if self
341            .controller
342            .inner
343            .cancellation_tokens
344            .read()
345            .contains_key(&config.endpoint)
346        {
347            error!("server {} is already running", config.endpoint);
348            return Err(ControllerError::ConfigError(format!(
349                "server {} is already running",
350                config.endpoint
351            )));
352        }
353
354        let token = config
355            .run_server(
356                &[ControllerServiceServer::new(self.controller.clone())],
357                self.controller.inner.drain_rx.clone(),
358            )
359            .map_err(|e| {
360                error!("failed to run server {}: {}", config.endpoint, e);
361                ControllerError::ConfigError(e.to_string())
362            })?;
363
364        // Store the cancellation token in the controller service
365        self.controller
366            .inner
367            .cancellation_tokens
368            .write()
369            .insert(config.endpoint.clone(), token.clone());
370
371        info!(%config.endpoint, "control plane server started");
372
373        Ok(())
374    }
375}
376
377impl ControllerService {
    /// Retry limit. NOTE(review): its use is not visible in this part of the file;
    /// presumably it caps the reconnection attempts to the control plane - confirm.
    const MAX_RETRIES: i32 = 10;
379
380    /// Handle new control messages.
381    async fn handle_new_control_message(
382        &self,
383        msg: ControlMessage,
384        tx: &mpsc::Sender<Result<ControlMessage, Status>>,
385    ) -> Result<(), ControllerError> {
386        match msg.payload {
387            Some(ref payload) => {
388                match payload {
389                    Payload::ConfigCommand(config) => {
390                        for conn in &config.connections_to_create {
391                            info!("received a connection to create: {:?}", conn);
392                            let client_config =
393                                serde_json::from_str::<ClientConfig>(&conn.config_data)
394                                    .map_err(|e| ControllerError::ConfigError(e.to_string()))?;
395                            let client_endpoint = &client_config.endpoint;
396
397                            // connect to an endpoint if it's not already connected
398                            if !self.inner.connections.read().contains_key(client_endpoint) {
399                                match client_config.to_channel() {
400                                    Err(e) => {
401                                        error!("error reading channel config {:?}", e);
402                                    }
403                                    Ok(channel) => {
404                                        let ret = self
405                                            .inner
406                                            .message_processor
407                                            .connect(
408                                                channel,
409                                                Some(client_config.clone()),
410                                                None,
411                                                None,
412                                            )
413                                            .await
414                                            .map_err(|e| {
415                                                ControllerError::ConnectionError(e.to_string())
416                                            });
417
418                                        let conn_id = match ret {
419                                            Err(e) => {
420                                                error!("connection error: {:?}", e);
421                                                return Err(ControllerError::ConnectionError(
422                                                    e.to_string(),
423                                                ));
424                                            }
425                                            Ok(conn_id) => conn_id.1,
426                                        };
427
428                                        self.inner
429                                            .connections
430                                            .write()
431                                            .insert(client_endpoint.clone(), conn_id);
432                                    }
433                                }
434                            }
435                        }
436
437                        for subscription in &config.subscriptions_to_set {
438                            if !self
439                                .inner
440                                .connections
441                                .read()
442                                .contains_key(&subscription.connection_id)
443                            {
444                                error!("connection {} not found", subscription.connection_id);
445                                continue;
446                            }
447
448                            let conn = self
449                                .inner
450                                .connections
451                                .read()
452                                .get(&subscription.connection_id)
453                                .cloned()
454                                .unwrap();
455                            let source = Name::from_strings([
456                                subscription.component_0.as_str(),
457                                subscription.component_1.as_str(),
458                                subscription.component_2.as_str(),
459                            ])
460                            .with_id(0);
461                            let name = Name::from_strings([
462                                subscription.component_0.as_str(),
463                                subscription.component_1.as_str(),
464                                subscription.component_2.as_str(),
465                            ])
466                            .with_id(subscription.id.unwrap_or(Name::NULL_COMPONENT));
467
468                            let msg = PubsubMessage::new_subscribe(
469                                &source,
470                                &name,
471                                Some(SlimHeaderFlags::default().with_recv_from(conn)),
472                            );
473
474                            if let Err(e) = self.send_control_message(msg).await {
475                                error!("failed to subscribe: {}", e);
476                            }
477                        }
478
479                        for subscription in &config.subscriptions_to_delete {
480                            if !self
481                                .inner
482                                .connections
483                                .read()
484                                .contains_key(&subscription.connection_id)
485                            {
486                                error!("connection {} not found", subscription.connection_id);
487                                continue;
488                            }
489
490                            let conn = self
491                                .inner
492                                .connections
493                                .read()
494                                .get(&subscription.connection_id)
495                                .cloned()
496                                .unwrap();
497                            let source = Name::from_strings([
498                                subscription.component_0.as_str(),
499                                subscription.component_1.as_str(),
500                                subscription.component_2.as_str(),
501                            ])
502                            .with_id(0);
503                            let name = Name::from_strings([
504                                subscription.component_0.as_str(),
505                                subscription.component_1.as_str(),
506                                subscription.component_2.as_str(),
507                            ])
508                            .with_id(subscription.id.unwrap_or(Name::NULL_COMPONENT));
509
510                            let msg = PubsubMessage::new_unsubscribe(
511                                &source,
512                                &name,
513                                Some(SlimHeaderFlags::default().with_recv_from(conn)),
514                            );
515
516                            if let Err(e) = self.send_control_message(msg).await {
517                                error!("failed to unsubscribe: {}", e);
518                            }
519                        }
520
521                        let ack = Ack {
522                            original_message_id: msg.message_id.clone(),
523                            success: true,
524                            messages: vec![],
525                        };
526
527                        let reply = ControlMessage {
528                            message_id: uuid::Uuid::new_v4().to_string(),
529                            payload: Some(Payload::Ack(ack)),
530                        };
531
532                        if let Err(e) = tx.send(Ok(reply)).await {
533                            error!("failed to send ACK: {}", e);
534                        }
535                    }
536                    Payload::SubscriptionListRequest(_) => {
537                        const CHUNK_SIZE: usize = 100;
538
539                        let conn_table = self.inner.message_processor.connection_table();
540                        let mut entries = Vec::new();
541
542                        self.inner.message_processor.subscription_table().for_each(
543                            |name, id, local, remote| {
544                                let mut entry = SubscriptionEntry {
545                                    component_0: name.components_strings().unwrap()[0].to_string(),
546                                    component_1: name.components_strings().unwrap()[1].to_string(),
547                                    component_2: name.components_strings().unwrap()[2].to_string(),
548                                    id: Some(id),
549                                    ..Default::default()
550                                };
551
552                                for &cid in local {
553                                    entry.local_connections.push(ConnectionEntry {
554                                        id: cid,
555                                        connection_type: ConnectionType::Local as i32,
556                                        config_data: "{}".to_string(),
557                                    });
558                                }
559
560                                for &cid in remote {
561                                    if let Some(conn) = conn_table.get(cid as usize) {
562                                        entry.remote_connections.push(ConnectionEntry {
563                                            id: cid,
564                                            connection_type: ConnectionType::Remote as i32,
565                                            config_data: match conn.config_data() {
566                                                Some(data) => serde_json::to_string(data)
567                                                    .unwrap_or_else(|_| "{}".to_string()),
568                                                None => "{}".to_string(),
569                                            },
570                                        });
571                                    } else {
572                                        error!("no connection entry for id {}", cid);
573                                    }
574                                }
575                                entries.push(entry);
576                            },
577                        );
578
579                        for chunk in entries.chunks(CHUNK_SIZE) {
580                            let resp = ControlMessage {
581                                message_id: uuid::Uuid::new_v4().to_string(),
582                                payload: Some(Payload::SubscriptionListResponse(
583                                    SubscriptionListResponse {
584                                        entries: chunk.to_vec(),
585                                    },
586                                )),
587                            };
588
589                            if let Err(e) = tx.try_send(Ok(resp)) {
590                                error!("failed to send subscription batch: {}", e);
591                            }
592                        }
593                    }
594                    Payload::ConnectionListRequest(_) => {
595                        let mut all_entries = Vec::new();
596                        self.inner
597                            .message_processor
598                            .connection_table()
599                            .for_each(|id, conn| {
600                                all_entries.push(ConnectionEntry {
601                                    id: id as u64,
602                                    connection_type: ConnectionType::Remote as i32,
603                                    config_data: match conn.config_data() {
604                                        Some(data) => serde_json::to_string(data)
605                                            .unwrap_or_else(|_| "{}".to_string()),
606                                        None => "{}".to_string(),
607                                    },
608                                });
609                            });
610
611                        const CHUNK_SIZE: usize = 100;
612                        for chunk in all_entries.chunks(CHUNK_SIZE) {
613                            let resp = ControlMessage {
614                                message_id: uuid::Uuid::new_v4().to_string(),
615                                payload: Some(Payload::ConnectionListResponse(
616                                    ConnectionListResponse {
617                                        entries: chunk.to_vec(),
618                                    },
619                                )),
620                            };
621
622                            if let Err(e) = tx.try_send(Ok(resp)) {
623                                error!("failed to send connection list batch: {}", e);
624                            }
625                        }
626                    }
627                    Payload::Ack(_ack) => {
628                        // received an ack, do nothing - this should not happen
629                    }
630                    Payload::SubscriptionListResponse(_) => {
631                        // received a subscription list response, do nothing - this should not happen
632                    }
633                    Payload::ConnectionListResponse(_) => {
634                        // received a connection list response, do nothing - this should not happen
635                    }
636                    Payload::RegisterNodeRequest(_) => {
637                        error!("received a register node request, this should not happen");
638                    }
639                    Payload::RegisterNodeResponse(_) => {
640                        // received a register node response, do nothing
641                    }
642                    Payload::DeregisterNodeRequest(_) => {
643                        error!("received a deregister node request, this should not happen");
644                    }
645                    Payload::DeregisterNodeResponse(_) => {
646                        // received a deregister node response, do nothing
647                    }
648                    Payload::CreateChannelRequest(_) => {}
649                    Payload::CreateChannelResponse(_) => {}
650                    Payload::DeleteChannelRequest(_) => {}
651                    Payload::AddParticipantRequest(_) => {}
652                    Payload::DeleteParticipantRequest(_) => {}
653                    Payload::ListChannelRequest(_) => {}
654                    Payload::ListChannelResponse(_) => {}
655                    Payload::ListParticipantsRequest(_) => {}
656                    Payload::ListParticipantsResponse(_) => {}
657                }
658            }
659            None => {
660                error!(
661                    "received control message {} with no payload",
662                    msg.message_id
663                );
664            }
665        }
666
667        Ok(())
668    }
669
670    /// Send a control message to SLIM.
671    async fn send_control_message(&self, msg: PubsubMessage) -> Result<(), ControllerError> {
672        self.inner.tx_slim.send(Ok(msg)).await.map_err(|e| {
673            error!("error sending message into datapath: {}", e);
674            ControllerError::DatapathError(e.to_string())
675        })
676    }
677
678    /// Process the control message stream.
679    fn process_control_message_stream(
680        &self,
681        config: Option<ClientConfig>,
682        mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
683        tx: mpsc::Sender<Result<ControlMessage, Status>>,
684        cancellation_token: CancellationToken,
685    ) -> tokio::task::JoinHandle<()> {
686        let this = self.clone();
687        let drain = this.inner.drain_rx.clone();
688        tokio::spawn(async move {
689            // Send a register message to the control plane
690            let endpoint = config
691                .as_ref()
692                .map(|c| c.endpoint.clone())
693                .unwrap_or_else(|| "unknown".to_string());
694            info!(%endpoint, "connected to control plane");
695
696            let mut retry_connect = false;
697
698            let register_request = ControlMessage {
699                message_id: uuid::Uuid::new_v4().to_string(),
700                payload: Some(Payload::RegisterNodeRequest(v1::RegisterNodeRequest {
701                    node_id: this.inner.id.to_string(),
702                })),
703            };
704
705            // send register request if client
706            if config.is_some() {
707                if let Err(e) = tx.send(Ok(register_request)).await {
708                    error!("failed to send register request: {}", e);
709                    return;
710                }
711            }
712
713            // TODO; here we should wait for an ack
714
715            loop {
716                tokio::select! {
717                    next = stream.next() => {
718                        match next {
719                            Some(Ok(msg)) => {
720                                if let Err(e) = this.handle_new_control_message(msg, &tx).await {
721                                    error!("error processing incoming control message: {:?}", e);
722                                }
723                            }
724                            Some(Err(e)) => {
725                                if let Some(io_err) = Self::match_for_io_error(&e) {
726                                    if io_err.kind() == std::io::ErrorKind::BrokenPipe {
727                                        info!("connection closed by peer");
728                                        retry_connect = true;
729                                    }
730                                } else {
731                                    error!(%e, "error receiving control messages");
732                                }
733
734                                break;
735                            }
736                            None => {
737                                debug!("end of stream");
738                                retry_connect = true;
739                                break;
740                            }
741                        }
742                    }
743                    _ = cancellation_token.cancelled() => {
744                        debug!("shutting down stream on cancellation token");
745                        break;
746                    }
747                    _ = drain.clone().signaled() => {
748                        debug!("shutting down stream on drain");
749                        break;
750                    }
751                }
752            }
753
754            info!(%endpoint, "control plane stream closed");
755
756            if retry_connect {
757                if let Some(config) = config {
758                    info!(%config.endpoint, "retrying connection to control plane");
759                    this.connect(config.clone(), cancellation_token)
760                        .await
761                        .map_or_else(
762                            |e| {
763                                error!("failed to reconnect to control plane: {}", e);
764                            },
765                            |tx| {
766                                info!(%config.endpoint, "reconnected to control plane");
767
768                                this.inner
769                                    .tx_channels
770                                    .write()
771                                    .insert(config.endpoint.clone(), tx);
772                            },
773                        )
774                }
775            }
776        })
777    }
778
779    /// Connect to the control plane using the provided client configuration.
780    /// This function attempts to establish a connection to the control plane and returns a sender for control messages.
781    /// It retries the connection a specified number of times if it fails.
782    async fn connect(
783        &self,
784        config: ClientConfig,
785        cancellation_token: CancellationToken,
786    ) -> Result<mpsc::Sender<Result<ControlMessage, Status>>, ControllerError> {
787        info!(%config.endpoint, "connecting to control plane");
788
789        let channel = config.to_channel().map_err(|e| {
790            error!("error reading channel config: {}", e);
791            ControllerError::ConfigError(e.to_string())
792        })?;
793
794        let mut client = ControllerServiceClient::new(channel);
795        for i in 0..Self::MAX_RETRIES {
796            let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
797            let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
798            match client.open_control_channel(Request::new(out_stream)).await {
799                Ok(stream) => {
800                    // process the control message stream
801                    self.process_control_message_stream(
802                        Some(config),
803                        stream.into_inner(),
804                        tx.clone(),
805                        cancellation_token.clone(),
806                    );
807
808                    return Ok(tx);
809                }
810                Err(e) => {
811                    error!(%e, "connection error, retrying {}/{}", i + 1, Self::MAX_RETRIES);
812                }
813            };
814
815            // sleep 1 sec between each connection retry
816            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
817        }
818
819        Err(ControllerError::ConfigError(format!(
820            "failed to connect to control plane after {} retries",
821            Self::MAX_RETRIES
822        )))
823    }
824
825    fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
826        let mut err: &(dyn std::error::Error + 'static) = err_status;
827
828        loop {
829            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
830                return Some(io_err);
831            }
832
833            // h2::Error do not expose std::io::Error with `source()`
834            // https://github.com/hyperium/h2/pull/462
835            if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
836                if let Some(io_err) = h2_err.get_io() {
837                    return Some(io_err);
838                }
839            }
840
841            err = err.source()?;
842        }
843    }
844}
845
846#[tonic::async_trait]
847impl GrpcControllerService for ControllerService {
848    type OpenControlChannelStream =
849        Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
850
851    async fn open_control_channel(
852        &self,
853        request: Request<tonic::Streaming<ControlMessage>>,
854    ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
855        // Get the remote endpoint from the request metadata
856        let remote_endpoint = request
857            .remote_addr()
858            .map(|addr| addr.to_string())
859            .unwrap_or_else(|| "unknown".to_string());
860
861        let stream = request.into_inner();
862        let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
863
864        let cancellation_token = CancellationToken::new();
865
866        self.process_control_message_stream(None, stream, tx.clone(), cancellation_token.clone());
867
868        // store the sender in the tx_channels map
869        self.inner
870            .tx_channels
871            .write()
872            .insert(remote_endpoint.clone(), tx);
873
874        // store the cancellation token in the controller service
875        self.inner
876            .cancellation_tokens
877            .write()
878            .insert(remote_endpoint.clone(), cancellation_token);
879
880        let out_stream = ReceiverStream::new(rx);
881        Ok(Response::new(
882            Box::pin(out_stream) as Self::OpenControlChannelStream
883        ))
884    }
885}
886
#[cfg(test)]
mod tests {
    use super::*;
    use slim_config::component::id::Kind;
    use tracing_test::traced_test;

    /// Pause granted to a control plane after `run()` before the test moves on.
    const SETTLE_DELAY: tokio::time::Duration = tokio::time::Duration::from_millis(100);

    /// Starts a server control plane and a client control plane on the loopback
    /// interface, checks the expected log line was emitted, then drops both and drains.
    #[tokio::test]
    #[traced_test]
    async fn test_end_to_end() {
        // Identities of the two SLIM instances
        let server_id =
            ID::new_with_name(Kind::new("slim").unwrap(), "test-server-instance").unwrap();
        let client_id =
            ID::new_with_name(Kind::new("slim").unwrap(), "test-client-instance").unwrap();

        // Listening side configuration
        let listen_config = ServerConfig::with_endpoint("127.0.0.1:50051")
            .with_tls_settings(slim_config::tls::server::TlsServerConfig::insecure());

        // Dialing side configuration
        let dial_config = ClientConfig::with_endpoint("http://127.0.0.1:50051")
            .with_tls_setting(slim_config::tls::client::TlsClientConfig::insecure());

        // One drain channel per instance
        let (server_signal, server_watch) = drain::channel();
        let (client_signal, client_watch) = drain::channel();

        // One message processor per instance
        let client_processor = MessageProcessor::with_drain_channel(client_watch.clone());
        let server_processor = MessageProcessor::with_drain_channel(server_watch.clone());

        // Server control plane: only listens
        let mut server = ControlPlane::new(
            server_id,
            vec![server_config_of(listen_config)],
            vec![],
            server_watch,
            Arc::new(server_processor),
        );

        // Client control plane: only dials
        let mut client = ControlPlane::new(
            client_id,
            vec![],
            vec![dial_config],
            client_watch,
            Arc::new(client_processor),
        );

        // Bring the server up and give it time to be ready
        server.run().await.unwrap();
        tokio::time::sleep(SETTLE_DELAY).await;

        // Bring the client up and give it time to be ready
        client.run().await.unwrap();
        tokio::time::sleep(SETTLE_DELAY).await;

        // The log line emitted when a register node request is received must be present
        assert!(logs_contain(
            "received a register node request, this should not happen"
        ));

        // Dropping both instances should cancel the running listeners
        // and close the connections gracefully.
        drop(server);
        drop(client);

        // Nothing should be left to drain (this should not block)
        server_signal.drain().await;
        client_signal.drain().await;
    }

    /// Identity helper that keeps the server configuration expression on one line.
    fn server_config_of(config: ServerConfig) -> ServerConfig {
        config
    }
}