1use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use slim_config::component::id::ID;
9use slim_config::grpc::server::ServerConfig;
10use tokio::sync::mpsc;
11use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
12use tokio_util::sync::CancellationToken;
13use tonic::{Request, Response, Status};
14use tracing::{debug, error, info};
15
16use crate::api::proto::api::v1::control_message::Payload;
17use crate::api::proto::api::v1::controller_service_server::ControllerServiceServer;
18use crate::api::proto::api::v1::{
19 self, ConnectionListResponse, ConnectionType, SubscriptionListResponse,
20};
21use crate::api::proto::api::v1::{
22 Ack, ConnectionEntry, ControlMessage, SubscriptionEntry,
23 controller_service_client::ControllerServiceClient,
24 controller_service_server::ControllerService as GrpcControllerService,
25};
26use crate::errors::ControllerError;
27use slim_config::grpc::client::ClientConfig;
28use slim_datapath::api::ProtoMessage as PubsubMessage;
29use slim_datapath::message_processing::MessageProcessor;
30use slim_datapath::messages::Name;
31use slim_datapath::messages::utils::SlimHeaderFlags;
32use slim_datapath::tables::SubscriptionTable;
33
34type TxChannel = mpsc::Sender<Result<ControlMessage, Status>>;
35type TxChannels = HashMap<String, TxChannel>;
36
37#[derive(Debug)]
42struct ControllerServiceInternal {
43 id: ID,
45
46 message_processor: Arc<MessageProcessor>,
48
49 connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,
51
52 tx_slim: mpsc::Sender<Result<PubsubMessage, Status>>,
54
55 _rx_slim: mpsc::Receiver<Result<PubsubMessage, Status>>,
57
58 tx_channels: parking_lot::RwLock<TxChannels>,
60
61 cancellation_tokens: parking_lot::RwLock<HashMap<String, CancellationToken>>,
63
64 drain_rx: drain::Watch,
66}
67
68#[derive(Debug, Clone)]
69struct ControllerService {
70 inner: Arc<ControllerServiceInternal>,
72}
73
74#[derive(Debug)]
76pub struct ControlPlane {
77 servers: Vec<ServerConfig>,
79
80 clients: Vec<ClientConfig>,
82
83 controller: ControllerService,
85}
86
87impl Drop for ControlPlane {
90 fn drop(&mut self) {
91 for (_endpoint, token) in self.controller.inner.cancellation_tokens.write().drain() {
93 token.cancel();
94 }
95 }
96}
97
98impl ControlPlane {
100 pub fn new(
112 id: ID,
113 servers: Vec<ServerConfig>,
114 clients: Vec<ClientConfig>,
115 drain_rx: drain::Watch,
116 message_processor: Arc<MessageProcessor>,
117 ) -> Self {
118 let (_, tx_slim, rx_slim) = message_processor.register_local_connection();
120
121 ControlPlane {
122 servers,
123 clients,
124 controller: ControllerService {
125 inner: Arc::new(ControllerServiceInternal {
126 id,
127 message_processor,
128 connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
129 tx_slim,
130 _rx_slim: rx_slim,
131 tx_channels: parking_lot::RwLock::new(HashMap::new()),
132 cancellation_tokens: parking_lot::RwLock::new(HashMap::new()),
133 drain_rx,
134 }),
135 },
136 }
137 }
138
139 pub fn with_clients(mut self, clients: Vec<ClientConfig>) -> Self {
141 self.clients = clients;
142 self
143 }
144
145 pub fn with_servers(mut self, servers: Vec<ServerConfig>) -> Self {
147 self.servers = servers;
148 self
149 }
150
151 pub async fn run(&mut self) -> Result<(), ControllerError> {
158 info!("starting controller service");
159
160 let servers = self.servers.clone();
162 let clients = self.clients.clone();
163
164 for server in servers {
166 self.run_server(server)?;
167 }
168
169 for client in clients {
171 self.run_client(client).await?;
172 }
173
174 Ok(())
175 }
176
177 pub fn stop(&mut self) {
181 info!("stopping controller service");
182
183 for (endpoint, token) in self.controller.inner.cancellation_tokens.write().drain() {
185 info!(%endpoint, "stopping");
186 token.cancel();
187 }
188 }
189
190 async fn run_client(&mut self, client: ClientConfig) -> Result<(), ControllerError> {
194 if self
195 .controller
196 .inner
197 .cancellation_tokens
198 .read()
199 .contains_key(&client.endpoint)
200 {
201 return Err(ControllerError::ConfigError(format!(
202 "client {} is already running",
203 client.endpoint
204 )));
205 }
206
207 let cancellation_token = CancellationToken::new();
208
209 let tx = self
210 .controller
211 .connect(client.clone(), cancellation_token.clone())
212 .await?;
213
214 self.controller
216 .inner
217 .cancellation_tokens
218 .write()
219 .insert(client.endpoint.clone(), cancellation_token);
220
221 self.controller
223 .inner
224 .tx_channels
225 .write()
226 .insert(client.endpoint.clone(), tx);
227
228 Ok(())
230 }
231
232 pub fn run_server(&mut self, config: ServerConfig) -> Result<(), ControllerError> {
236 info!(%config.endpoint, "starting control plane server");
237
238 if self
240 .controller
241 .inner
242 .cancellation_tokens
243 .read()
244 .contains_key(&config.endpoint)
245 {
246 error!("server {} is already running", config.endpoint);
247 return Err(ControllerError::ConfigError(format!(
248 "server {} is already running",
249 config.endpoint
250 )));
251 }
252
253 let token = config
254 .run_server(
255 &[ControllerServiceServer::new(self.controller.clone())],
256 self.controller.inner.drain_rx.clone(),
257 )
258 .map_err(|e| {
259 error!("failed to run server {}: {}", config.endpoint, e);
260 ControllerError::ConfigError(e.to_string())
261 })?;
262
263 self.controller
265 .inner
266 .cancellation_tokens
267 .write()
268 .insert(config.endpoint.clone(), token.clone());
269
270 info!(%config.endpoint, "control plane server started");
271
272 Ok(())
273 }
274}
275
276impl ControllerService {
/// Maximum number of attempts `connect` makes to open the control channel.
const MAX_RETRIES: i32 = 10;
278
279 async fn handle_new_control_message(
281 &self,
282 msg: ControlMessage,
283 tx: &mpsc::Sender<Result<ControlMessage, Status>>,
284 ) -> Result<(), ControllerError> {
285 match msg.payload {
286 Some(ref payload) => {
287 match payload {
288 Payload::ConfigCommand(config) => {
289 for conn in &config.connections_to_create {
290 info!("received a connection to create: {:?}", conn);
291 let client_config =
292 serde_json::from_str::<ClientConfig>(&conn.config_data)
293 .map_err(|e| ControllerError::ConfigError(e.to_string()))?;
294 let client_endpoint = &client_config.endpoint;
295
296 if !self.inner.connections.read().contains_key(client_endpoint) {
298 match client_config.to_channel() {
299 Err(e) => {
300 error!("error reading channel config {:?}", e);
301 }
302 Ok(channel) => {
303 let ret = self
304 .inner
305 .message_processor
306 .connect(
307 channel,
308 Some(client_config.clone()),
309 None,
310 None,
311 )
312 .await
313 .map_err(|e| {
314 ControllerError::ConnectionError(e.to_string())
315 });
316
317 let conn_id = match ret {
318 Err(e) => {
319 error!("connection error: {:?}", e);
320 return Err(ControllerError::ConnectionError(
321 e.to_string(),
322 ));
323 }
324 Ok(conn_id) => conn_id.1,
325 };
326
327 self.inner
328 .connections
329 .write()
330 .insert(client_endpoint.clone(), conn_id);
331 }
332 }
333 }
334 }
335
336 for subscription in &config.subscriptions_to_set {
337 if !self
338 .inner
339 .connections
340 .read()
341 .contains_key(&subscription.connection_id)
342 {
343 error!("connection {} not found", subscription.connection_id);
344 continue;
345 }
346
347 let conn = self
348 .inner
349 .connections
350 .read()
351 .get(&subscription.connection_id)
352 .cloned()
353 .unwrap();
354 let source = Name::from_strings([
355 subscription.component_0.as_str(),
356 subscription.component_1.as_str(),
357 subscription.component_2.as_str(),
358 ])
359 .with_id(0);
360 let name = Name::from_strings([
361 subscription.component_0.as_str(),
362 subscription.component_1.as_str(),
363 subscription.component_2.as_str(),
364 ])
365 .with_id(subscription.id.unwrap_or(Name::NULL_COMPONENT));
366
367 let msg = PubsubMessage::new_subscribe(
368 &source,
369 &name,
370 Some(SlimHeaderFlags::default().with_recv_from(conn)),
371 );
372
373 if let Err(e) = self.send_control_message(msg).await {
374 error!("failed to subscribe: {}", e);
375 }
376 }
377
378 for subscription in &config.subscriptions_to_delete {
379 if !self
380 .inner
381 .connections
382 .read()
383 .contains_key(&subscription.connection_id)
384 {
385 error!("connection {} not found", subscription.connection_id);
386 continue;
387 }
388
389 let conn = self
390 .inner
391 .connections
392 .read()
393 .get(&subscription.connection_id)
394 .cloned()
395 .unwrap();
396 let source = Name::from_strings([
397 subscription.component_0.as_str(),
398 subscription.component_1.as_str(),
399 subscription.component_2.as_str(),
400 ])
401 .with_id(0);
402 let name = Name::from_strings([
403 subscription.component_0.as_str(),
404 subscription.component_1.as_str(),
405 subscription.component_2.as_str(),
406 ])
407 .with_id(subscription.id.unwrap_or(Name::NULL_COMPONENT));
408
409 let msg = PubsubMessage::new_unsubscribe(
410 &source,
411 &name,
412 Some(SlimHeaderFlags::default().with_recv_from(conn)),
413 );
414
415 if let Err(e) = self.send_control_message(msg).await {
416 error!("failed to unsubscribe: {}", e);
417 }
418 }
419
420 let ack = Ack {
421 original_message_id: msg.message_id.clone(),
422 success: true,
423 messages: vec![],
424 };
425
426 let reply = ControlMessage {
427 message_id: uuid::Uuid::new_v4().to_string(),
428 payload: Some(Payload::Ack(ack)),
429 };
430
431 if let Err(e) = tx.send(Ok(reply)).await {
432 error!("failed to send ACK: {}", e);
433 }
434 }
435 Payload::SubscriptionListRequest(_) => {
436 const CHUNK_SIZE: usize = 100;
437
438 let conn_table = self.inner.message_processor.connection_table();
439 let mut entries = Vec::new();
440
441 self.inner.message_processor.subscription_table().for_each(
442 |name, id, local, remote| {
443 let mut entry = SubscriptionEntry {
444 component_0: name.components_strings().unwrap()[0].to_string(),
445 component_1: name.components_strings().unwrap()[1].to_string(),
446 component_2: name.components_strings().unwrap()[2].to_string(),
447 id: Some(id),
448 ..Default::default()
449 };
450
451 for &cid in local {
452 entry.local_connections.push(ConnectionEntry {
453 id: cid,
454 connection_type: ConnectionType::Local as i32,
455 config_data: "{}".to_string(),
456 });
457 }
458
459 for &cid in remote {
460 if let Some(conn) = conn_table.get(cid as usize) {
461 entry.remote_connections.push(ConnectionEntry {
462 id: cid,
463 connection_type: ConnectionType::Remote as i32,
464 config_data: match conn.config_data() {
465 Some(data) => serde_json::to_string(data)
466 .unwrap_or_else(|_| "{}".to_string()),
467 None => "{}".to_string(),
468 },
469 });
470 } else {
471 error!("no connection entry for id {}", cid);
472 }
473 }
474 entries.push(entry);
475 },
476 );
477
478 for chunk in entries.chunks(CHUNK_SIZE) {
479 let resp = ControlMessage {
480 message_id: uuid::Uuid::new_v4().to_string(),
481 payload: Some(Payload::SubscriptionListResponse(
482 SubscriptionListResponse {
483 entries: chunk.to_vec(),
484 },
485 )),
486 };
487
488 if let Err(e) = tx.try_send(Ok(resp)) {
489 error!("failed to send subscription batch: {}", e);
490 }
491 }
492 }
493 Payload::ConnectionListRequest(_) => {
494 let mut all_entries = Vec::new();
495 self.inner
496 .message_processor
497 .connection_table()
498 .for_each(|id, conn| {
499 all_entries.push(ConnectionEntry {
500 id: id as u64,
501 connection_type: ConnectionType::Remote as i32,
502 config_data: match conn.config_data() {
503 Some(data) => serde_json::to_string(data)
504 .unwrap_or_else(|_| "{}".to_string()),
505 None => "{}".to_string(),
506 },
507 });
508 });
509
510 const CHUNK_SIZE: usize = 100;
511 for chunk in all_entries.chunks(CHUNK_SIZE) {
512 let resp = ControlMessage {
513 message_id: uuid::Uuid::new_v4().to_string(),
514 payload: Some(Payload::ConnectionListResponse(
515 ConnectionListResponse {
516 entries: chunk.to_vec(),
517 },
518 )),
519 };
520
521 if let Err(e) = tx.try_send(Ok(resp)) {
522 error!("failed to send connection list batch: {}", e);
523 }
524 }
525 }
526 Payload::Ack(_ack) => {
527 }
529 Payload::SubscriptionListResponse(_) => {
530 }
532 Payload::ConnectionListResponse(_) => {
533 }
535 Payload::RegisterNodeRequest(_) => {
536 error!("received a register node request, this should not happen");
537 }
538 Payload::RegisterNodeResponse(_) => {
539 }
541 Payload::DeregisterNodeRequest(_) => {
542 error!("received a deregister node request, this should not happen");
543 }
544 Payload::DeregisterNodeResponse(_) => {
545 }
547 }
548 }
549 None => {
550 error!(
551 "received control message {} with no payload",
552 msg.message_id
553 );
554 }
555 }
556
557 Ok(())
558 }
559
560 async fn send_control_message(&self, msg: PubsubMessage) -> Result<(), ControllerError> {
562 self.inner.tx_slim.send(Ok(msg)).await.map_err(|e| {
563 error!("error sending message into datapath: {}", e);
564 ControllerError::DatapathError(e.to_string())
565 })
566 }
567
568 fn process_control_message_stream(
570 &self,
571 config: Option<ClientConfig>,
572 mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
573 tx: mpsc::Sender<Result<ControlMessage, Status>>,
574 cancellation_token: CancellationToken,
575 ) -> tokio::task::JoinHandle<()> {
576 let this = self.clone();
577 let drain = this.inner.drain_rx.clone();
578 tokio::spawn(async move {
579 let endpoint = config
581 .as_ref()
582 .map(|c| c.endpoint.clone())
583 .unwrap_or_else(|| "unknown".to_string());
584 info!(%endpoint, "connected to control plane");
585
586 let mut retry_connect = false;
587
588 let register_request = ControlMessage {
589 message_id: uuid::Uuid::new_v4().to_string(),
590 payload: Some(Payload::RegisterNodeRequest(v1::RegisterNodeRequest {
591 node_id: this.inner.id.to_string(),
592 })),
593 };
594
595 if config.is_some() {
597 if let Err(e) = tx.send(Ok(register_request)).await {
598 error!("failed to send register request: {}", e);
599 return;
600 }
601 }
602
603 loop {
606 tokio::select! {
607 next = stream.next() => {
608 match next {
609 Some(Ok(msg)) => {
610 if let Err(e) = this.handle_new_control_message(msg, &tx).await {
611 error!("error processing incoming control message: {:?}", e);
612 }
613 }
614 Some(Err(e)) => {
615 if let Some(io_err) = Self::match_for_io_error(&e) {
616 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
617 info!("connection closed by peer");
618 retry_connect = true;
619 }
620 } else {
621 error!(%e, "error receiving control messages");
622 }
623
624 break;
625 }
626 None => {
627 debug!("end of stream");
628 retry_connect = true;
629 break;
630 }
631 }
632 }
633 _ = cancellation_token.cancelled() => {
634 debug!("shutting down stream on cancellation token");
635 break;
636 }
637 _ = drain.clone().signaled() => {
638 debug!("shutting down stream on drain");
639 break;
640 }
641 }
642 }
643
644 info!(%endpoint, "control plane stream closed");
645
646 if retry_connect {
647 if let Some(config) = config {
648 info!(%config.endpoint, "retrying connection to control plane");
649 this.connect(config.clone(), cancellation_token)
650 .await
651 .map_or_else(
652 |e| {
653 error!("failed to reconnect to control plane: {}", e);
654 },
655 |tx| {
656 info!(%config.endpoint, "reconnected to control plane");
657
658 this.inner
659 .tx_channels
660 .write()
661 .insert(config.endpoint.clone(), tx);
662 },
663 )
664 }
665 }
666 })
667 }
668
669 async fn connect(
673 &self,
674 config: ClientConfig,
675 cancellation_token: CancellationToken,
676 ) -> Result<mpsc::Sender<Result<ControlMessage, Status>>, ControllerError> {
677 info!(%config.endpoint, "connecting to control plane");
678
679 let channel = config.to_channel().map_err(|e| {
680 error!("error reading channel config: {}", e);
681 ControllerError::ConfigError(e.to_string())
682 })?;
683
684 let mut client = ControllerServiceClient::new(channel);
685 for i in 0..Self::MAX_RETRIES {
686 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
687 let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
688 match client.open_control_channel(Request::new(out_stream)).await {
689 Ok(stream) => {
690 self.process_control_message_stream(
692 Some(config),
693 stream.into_inner(),
694 tx.clone(),
695 cancellation_token.clone(),
696 );
697
698 return Ok(tx);
699 }
700 Err(e) => {
701 error!(%e, "connection error, retrying {}/{}", i + 1, Self::MAX_RETRIES);
702 }
703 };
704
705 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
707 }
708
709 Err(ControllerError::ConfigError(format!(
710 "failed to connect to control plane after {} retries",
711 Self::MAX_RETRIES
712 )))
713 }
714
715 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
716 let mut err: &(dyn std::error::Error + 'static) = err_status;
717
718 loop {
719 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
720 return Some(io_err);
721 }
722
723 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
726 if let Some(io_err) = h2_err.get_io() {
727 return Some(io_err);
728 }
729 }
730
731 err = err.source()?;
732 }
733 }
734}
735
736#[tonic::async_trait]
737impl GrpcControllerService for ControllerService {
738 type OpenControlChannelStream =
739 Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
740
741 async fn open_control_channel(
742 &self,
743 request: Request<tonic::Streaming<ControlMessage>>,
744 ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
745 let remote_endpoint = request
747 .remote_addr()
748 .map(|addr| addr.to_string())
749 .unwrap_or_else(|| "unknown".to_string());
750
751 let stream = request.into_inner();
752 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
753
754 let cancellation_token = CancellationToken::new();
755
756 self.process_control_message_stream(None, stream, tx.clone(), cancellation_token.clone());
757
758 self.inner
760 .tx_channels
761 .write()
762 .insert(remote_endpoint.clone(), tx);
763
764 self.inner
766 .cancellation_tokens
767 .write()
768 .insert(remote_endpoint.clone(), cancellation_token);
769
770 let out_stream = ReceiverStream::new(rx);
771 Ok(Response::new(
772 Box::pin(out_stream) as Self::OpenControlChannelStream
773 ))
774 }
775}
776
#[cfg(test)]
mod tests {
    use super::*;
    use slim_config::component::id::Kind;
    use tracing_test::traced_test;

    /// Starts a control plane server and a control plane client on
    /// localhost and checks, through the captured logs, that the server
    /// received the client's register request.
    #[tokio::test]
    #[traced_test]
    async fn test_end_to_end() {
        // Server side.
        let (server_signal, server_watch) = drain::channel();
        let server_processor = MessageProcessor::with_drain_channel(server_watch.clone());
        let mut server_plane = ControlPlane::new(
            ID::new_with_name(Kind::new("slim").unwrap(), "test_server_instance").unwrap(),
            vec![
                ServerConfig::with_endpoint("127.0.0.1:50051")
                    .with_tls_settings(slim_config::tls::server::TlsServerConfig::insecure()),
            ],
            vec![],
            server_watch,
            Arc::new(server_processor),
        );

        // Client side, pointing at the server above.
        let (client_signal, client_watch) = drain::channel();
        let client_processor = MessageProcessor::with_drain_channel(client_watch.clone());
        let mut client_plane = ControlPlane::new(
            ID::new_with_name(Kind::new("slim").unwrap(), "test_client_instance").unwrap(),
            vec![],
            vec![
                ClientConfig::with_endpoint("http://127.0.0.1:50051")
                    .with_tls_setting(slim_config::tls::client::TlsClientConfig::insecure()),
            ],
            client_watch,
            Arc::new(client_processor),
        );

        server_plane.run().await.unwrap();

        // Give the server some time to start listening.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        client_plane.run().await.unwrap();

        // Give the register request some time to reach the server.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        assert!(logs_contain(
            "received a register node request, this should not happen"
        ));

        // Dropping the control planes cancels their tokens.
        drop(server_plane);
        drop(client_plane);

        server_signal.drain().await;
        client_signal.drain().await;
    }
}