1use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use base64::Engine;
9
10use slim_config::component::id::ID;
11use slim_config::grpc::server::ServerConfig;
12use slim_config::metadata::MetadataValue;
13use tokio::sync::mpsc;
14use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
15use tokio_util::sync::CancellationToken;
16use tonic::{Request, Response, Status};
17use tracing::{debug, error, info};
18
19use crate::api::proto::api::v1::control_message::Payload;
20use crate::api::proto::api::v1::controller_service_server::ControllerServiceServer;
21use crate::api::proto::api::v1::{
22 self, ConnectionListResponse, ConnectionType, SubscriptionListResponse,
23};
24use crate::api::proto::api::v1::{
25 Ack, ConnectionDetails, ConnectionEntry, ControlMessage, SubscriptionEntry,
26 controller_service_client::ControllerServiceClient,
27 controller_service_server::ControllerService as GrpcControllerService,
28};
29use crate::errors::ControllerError;
30use slim_auth::auth_provider::{AuthProvider, AuthVerifier};
31use slim_auth::traits::TokenProvider;
32use slim_config::grpc::client::ClientConfig;
33use slim_datapath::api::ProtoMessage as DataPlaneMessage;
34use slim_datapath::api::{ProtoSessionMessageType, ProtoSessionType, SessionHeader, SlimHeader};
35use slim_datapath::message_processing::MessageProcessor;
36use slim_datapath::messages::Name;
37use slim_datapath::messages::encoder::calculate_hash;
38use slim_datapath::messages::utils::{SLIM_IDENTITY, SlimHeaderFlags};
39use slim_datapath::tables::SubscriptionTable;
40
/// Sender half of the channel carrying outgoing control messages to one control plane.
type TxChannel = mpsc::Sender<Result<ControlMessage, Status>>;
/// Outgoing control-message senders, keyed by control-plane endpoint.
type TxChannels = HashMap<String, TxChannel>;
43
44pub static CONTROLLER_SOURCE_NAME: std::sync::LazyLock<slim_datapath::messages::Name> =
46 std::sync::LazyLock::new(|| {
47 slim_datapath::messages::Name::from_strings(["controller", "controller", "controller"])
48 .with_id(0)
49 });
50
/// Everything needed to build a [`ControlPlane`].
#[derive(Clone)]
pub struct ControlPlaneSettings {
    /// Identifier of this node; sent as `node_id` when registering with a control plane.
    pub id: ID,
    /// Optional group name, forwarded in the node registration request.
    pub group_name: Option<String>,
    /// Server configurations started by `ControlPlane::run`.
    pub servers: Vec<ServerConfig>,
    /// Control planes that `ControlPlane::run` connects to.
    pub clients: Vec<ClientConfig>,
    /// Drain watch used to shut down servers and background tasks.
    pub drain_rx: drain::Watch,
    /// Data-plane message processor the controller registers a local connection with.
    pub message_processor: Arc<MessageProcessor>,
    /// Server configurations converted into the connection details advertised at registration.
    pub pubsub_servers: Vec<ServerConfig>,
    /// When set, used to attach an identity token to the channel messages the controller sends.
    pub auth_provider: Option<AuthProvider>,
    /// Stored by the controller but not read anywhere in the code visible here.
    pub auth_verifier: Option<AuthVerifier>,
}
73
/// Shared state behind [`ControllerService`]; always accessed through an `Arc`.
struct ControllerServiceInternal {
    /// Identifier of this node, sent in the registration request.
    id: ID,

    /// Optional group name, sent in the registration request.
    group_name: Option<String>,

    /// Data-plane message processor (connections, subscription and connection tables).
    message_processor: Arc<MessageProcessor>,

    /// Data-plane connections created on request: client endpoint -> connection id.
    connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,

    /// Sender used to inject messages into the data plane.
    tx_slim: mpsc::Sender<Result<DataPlaneMessage, Status>>,

    /// Outgoing control-message senders, keyed by control-plane endpoint.
    tx_channels: parking_lot::RwLock<TxChannels>,

    /// Cancellation tokens of running servers, clients and the data-plane listener,
    /// keyed by endpoint (the listener uses the key `"DATA_PLANE"`).
    cancellation_tokens: parking_lot::RwLock<HashMap<String, CancellationToken>>,

    /// Drain watch signalling shutdown to background tasks.
    drain_rx: drain::Watch,

    /// Connection details advertised to the control plane in the registration request.
    connection_details: Vec<ConnectionDetails>,

    /// When set, provides the identity token attached to channel messages.
    auth_provider: Option<AuthProvider>,

    /// Kept for later use; not read anywhere in the code visible here.
    _auth_verifier: Option<AuthVerifier>,
}
112
/// Cheaply clonable handle to the controller state; this is the type that
/// implements the gRPC controller service.
#[derive(Clone)]
struct ControllerService {
    /// Shared state; cloning the service only bumps the reference count.
    inner: Arc<ControllerServiceInternal>,
}
118
/// Owner of the controller service: starts the configured servers and clients
/// and forwards data-plane subscription changes to the connected control planes.
pub struct ControlPlane {
    /// Server configurations started by `run`.
    servers: Vec<ServerConfig>,

    /// Control planes `run` connects to.
    clients: Vec<ClientConfig>,

    /// Service handle shared with the servers and background tasks.
    controller: ControllerService,

    /// Receiver for messages coming from the data plane; taken (left `None`) by `run`.
    rx_slim_option: Option<mpsc::Receiver<Result<DataPlaneMessage, Status>>>,
}
134
135impl Drop for ControlPlane {
138 fn drop(&mut self) {
139 for (_endpoint, token) in self.controller.inner.cancellation_tokens.write().drain() {
141 token.cancel();
142 }
143 }
144}
145
146fn from_server_config(server_config: &ServerConfig) -> ConnectionDetails {
147 let group_name = server_config
148 .metadata
149 .as_ref()
150 .and_then(|m| m.get("group_name"))
151 .and_then(|v| match v {
152 MetadataValue::String(s) => Some(s.clone()),
153 _ => None,
154 });
155 let local_endpoint = server_config
156 .metadata
157 .as_ref()
158 .and_then(|m| m.get("local_endpoint"))
159 .and_then(|v| match v {
160 MetadataValue::String(s) => Some(s.clone()),
161 _ => None,
162 });
163 let external_endpoint = server_config
164 .metadata
165 .as_ref()
166 .and_then(|m| m.get("external_endpoint"))
167 .and_then(|v| match v {
168 MetadataValue::String(s) => Some(s.clone()),
169 _ => None,
170 });
171 ConnectionDetails {
172 endpoint: server_config.endpoint.clone(),
173 mtls_required: !server_config.tls_setting.insecure,
174 group_name,
175 local_endpoint,
176 external_endpoint,
177 }
178}
179
180impl ControlPlane {
182 pub fn new(config: ControlPlaneSettings) -> Self {
195 let (_, tx_slim, rx_slim) = config.message_processor.register_local_connection(true);
197
198 let connection_details = config
199 .pubsub_servers
200 .iter()
201 .map(from_server_config)
202 .collect();
203
204 ControlPlane {
205 servers: config.servers,
206 clients: config.clients,
207 controller: ControllerService {
208 inner: Arc::new(ControllerServiceInternal {
209 id: config.id,
210 group_name: config.group_name,
211 message_processor: config.message_processor,
212 connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
213 tx_slim,
214 tx_channels: parking_lot::RwLock::new(HashMap::new()),
215 cancellation_tokens: parking_lot::RwLock::new(HashMap::new()),
216 drain_rx: config.drain_rx,
217 connection_details,
218 auth_provider: config.auth_provider,
219 _auth_verifier: config.auth_verifier,
220 }),
221 },
222 rx_slim_option: Some(rx_slim),
223 }
224 }
225
226 pub fn with_clients(mut self, clients: Vec<ClientConfig>) -> Self {
228 self.clients = clients;
229 self
230 }
231
232 pub fn with_servers(mut self, servers: Vec<ServerConfig>) -> Self {
234 self.servers = servers;
235 self
236 }
237
238 pub async fn run(&mut self) -> Result<(), ControllerError> {
245 info!("starting controller service");
246
247 let servers = self.servers.clone();
249 let clients = self.clients.clone();
250
251 for server in servers {
253 self.run_server(server)?;
254 }
255
256 for client in clients {
258 self.run_client(client).await?;
259 }
260
261 let rx = self.rx_slim_option.take();
262 self.listen_from_data_plane(rx.unwrap()).await;
263
264 Ok(())
265 }
266
267 async fn listen_from_data_plane(
268 &mut self,
269 mut rx: mpsc::Receiver<Result<DataPlaneMessage, Status>>,
270 ) {
271 let cancellation_token = CancellationToken::new();
272 let cancellation_token_clone = cancellation_token.clone();
273 let drain = self.controller.inner.drain_rx.clone();
274
275 self.controller
276 .inner
277 .cancellation_tokens
278 .write()
279 .insert("DATA_PLANE".to_string(), cancellation_token_clone);
280
281 let clients = self.clients.clone();
282 let inner = self.controller.inner.clone();
283
284 let subscribe_msg =
286 DataPlaneMessage::new_subscribe(&CONTROLLER_SOURCE_NAME, &CONTROLLER_SOURCE_NAME, None);
287
288 if let Err(e) = inner.tx_slim.send(Ok(subscribe_msg)).await {
290 error!("failed to send subscribe message to data plane: {}", e);
291 }
292
293 tokio::spawn(async move {
294 loop {
295 tokio::select! {
296 next = rx.recv() => {
297 match next {
298 Some(res) => {
299 match res {
300 Ok(msg) => {
301 debug!("Send sub/unsub to control plane for message: {:?}", msg);
302
303 let mut sub_vec = vec![];
304 let mut unsub_vec = vec![];
305
306 let dst = msg.get_dst();
307 let components = dst.components_strings().unwrap();
308 let cmd = v1::Subscription {
309 component_0: components[0].to_string(),
310 component_1: components[1].to_string(),
311 component_2: components[2].to_string(),
312 id: Some(dst.id()),
313 connection_id: "n/a".to_string(),
314 };
315 match msg.get_type() {
316 slim_datapath::api::MessageType::Subscribe(_) => {
317 sub_vec.push(cmd);
318 },
319 slim_datapath::api::MessageType::Unsubscribe(_) => {
320 unsub_vec.push(cmd);
321 }
322 slim_datapath::api::MessageType::Publish(_) => {
323 continue;
325 },
326 }
327
328 let ctrl = ControlMessage {
329 message_id: uuid::Uuid::new_v4().to_string(),
330 payload: Some(Payload::ConfigCommand(
331 v1::ConfigurationCommand {
332 connections_to_create: vec![],
333 subscriptions_to_set: sub_vec,
334 subscriptions_to_delete: unsub_vec
335 })),
336 };
337
338 for c in &clients {
339 let tx = match inner.tx_channels.read().get(&c.endpoint) {
340 Some(tx) => tx.clone(),
341 None => continue,
342 };
343 if (tx.send(Ok(ctrl.clone())).await).is_err() {
344 error!("error while notifiyng the control plane");
345 };
346
347 }
348 }
349 Err(e) => {
350 error!("received error from the data plane {}", e.to_string());
351 continue;
352 }
353 }
354 }
355 None => {
356 debug!("Data plane receiver channel closed.");
357 break;
358 }
359 }
360 }
361 _ = cancellation_token.cancelled() => {
362 debug!("shutting down stream on cancellation token");
363 break;
364 }
365 _ = drain.clone().signaled() => {
366 debug!("shutting down stream on drain");
367 break;
368 }
369 }
370 }
371 });
372 }
373
374 pub fn stop(&mut self) {
378 info!("stopping controller service");
379
380 for (endpoint, token) in self.controller.inner.cancellation_tokens.write().drain() {
382 info!(%endpoint, "stopping");
383 token.cancel();
384 }
385 }
386
387 async fn run_client(&mut self, client: ClientConfig) -> Result<(), ControllerError> {
391 if self
392 .controller
393 .inner
394 .cancellation_tokens
395 .read()
396 .contains_key(&client.endpoint)
397 {
398 return Err(ControllerError::ConfigError(format!(
399 "client {} is already running",
400 client.endpoint
401 )));
402 }
403
404 let cancellation_token = CancellationToken::new();
405
406 let tx = self
407 .controller
408 .connect(client.clone(), cancellation_token.clone())
409 .await?;
410
411 self.controller
413 .inner
414 .cancellation_tokens
415 .write()
416 .insert(client.endpoint.clone(), cancellation_token);
417
418 self.controller
420 .inner
421 .tx_channels
422 .write()
423 .insert(client.endpoint.clone(), tx);
424
425 Ok(())
427 }
428
429 pub fn run_server(&mut self, config: ServerConfig) -> Result<(), ControllerError> {
433 info!(%config.endpoint, "starting control plane server");
434
435 if self
437 .controller
438 .inner
439 .cancellation_tokens
440 .read()
441 .contains_key(&config.endpoint)
442 {
443 error!("server {} is already running", config.endpoint);
444 return Err(ControllerError::ConfigError(format!(
445 "server {} is already running",
446 config.endpoint
447 )));
448 }
449
450 let token = config
451 .run_server(
452 &[ControllerServiceServer::new(self.controller.clone())],
453 self.controller.inner.drain_rx.clone(),
454 )
455 .map_err(|e| {
456 error!("failed to run server {}: {}", config.endpoint, e);
457 ControllerError::ConfigError(e.to_string())
458 })?;
459
460 self.controller
462 .inner
463 .cancellation_tokens
464 .write()
465 .insert(config.endpoint.clone(), token.clone());
466
467 info!(%config.endpoint, "control plane server started");
468
469 Ok(())
470 }
471}
472
473fn generate_session_id(moderator: &Name, channel: &Name) -> u32 {
474 let mut all: [u64; 8] = [0; 8];
477 let m = moderator.components();
478 let c = channel.components();
479 all[..4].copy_from_slice(m);
480 all[4..].copy_from_slice(c);
481
482 let hash = calculate_hash(&all);
483 (hash ^ (hash >> 32)) as u32
484}
485
486fn get_name_from_string(string_name: &String) -> Result<Name, ControllerError> {
487 let parts: Vec<&str> = string_name.split('/').collect();
488 if parts.len() < 3 {
489 return Err(ControllerError::ConfigError(format!(
490 "invalid name format: {}",
491 string_name
492 )));
493 }
494
495 if parts.len() == 4 {
496 let id = parts[3].parse::<u64>().map_err(|_| {
497 ControllerError::ConfigError(format!("invalid moderator ID: {}", parts[3]))
498 })?;
499 Ok(Name::from_strings([parts[0], parts[1], parts[2]]).with_id(id))
500 } else {
501 Ok(Name::from_strings([parts[0], parts[1], parts[2]]))
502 }
503}
504
505#[allow(clippy::too_many_arguments)]
506fn create_channel_message(
507 source: &Name,
508 destination: &Name,
509 channel: Option<&Name>,
513 request_type: ProtoSessionMessageType,
514 session_id: u32,
515 message_id: u32,
516 mut metadata: HashMap<String, String>,
517 payload: Vec<u8>,
518 auth_provider: &Option<AuthProvider>,
519) -> DataPlaneMessage {
520 let slim_header = Some(SlimHeader::new(source, destination, None));
521 let dest = channel.unwrap_or(destination);
522
523 let session_header = Some(SessionHeader::new(
524 ProtoSessionType::SessionMulticast.into(),
525 request_type.into(),
526 session_id,
527 message_id,
528 &None,
529 &Some(dest.clone()),
530 ));
531
532 if let Some(auth) = auth_provider {
533 let identity_token = auth
534 .get_token()
535 .map_err(|e| {
536 error!("failed to generate identity token: {}", e);
537 ControllerError::DatapathError(e.to_string())
538 })
539 .unwrap();
540
541 metadata.insert(SLIM_IDENTITY.to_string(), identity_token);
542 }
543 let mut msg =
544 DataPlaneMessage::new_publish_with_headers(slim_header, session_header, "", payload);
545
546 msg.set_metadata_map(metadata);
547
548 msg
549}
550
551fn new_channel_message(
552 controller: &Name,
553 moderator: &Name,
554 channel: &Name,
555 auth_provider: &Option<AuthProvider>,
556) -> DataPlaneMessage {
557 let session_id = generate_session_id(moderator, channel);
558
559 #[derive(Debug, Clone, bincode::Encode, bincode::Decode)]
563 struct JoinMessagePayloadLocal {
564 channel_name: Name,
565 moderator_name: Name,
566 }
567 let p = JoinMessagePayloadLocal {
568 channel_name: channel.clone(),
569 moderator_name: moderator.clone(),
570 };
571 let invite_payload: Vec<u8> = bincode::encode_to_vec(p, bincode::config::standard())
572 .expect("unable to encode channel join payload");
573
574 let mut metadata = HashMap::new();
575
576 metadata.insert("IS_MODERATOR".to_string(), "true".to_string());
577
578 metadata.insert("MLS_ENABLED".to_string(), "true".to_string());
582
583 create_channel_message(
584 controller,
585 moderator,
586 Some(channel),
587 ProtoSessionMessageType::ChannelJoinRequest,
588 session_id,
589 rand::random::<u32>(),
590 metadata,
591 invite_payload,
592 auth_provider,
593 )
594}
595
596fn delete_channel_message(
597 controller: &Name,
598 moderator: &Name,
599 channel_name: &Name,
600 auth_provider: &Option<AuthProvider>,
601) -> DataPlaneMessage {
602 let session_id = generate_session_id(moderator, channel_name);
603
604 let mut metadata = HashMap::new();
605 metadata.insert("DELETE_GROUP".to_string(), "true".to_string());
606
607 create_channel_message(
608 controller,
609 moderator,
610 None,
611 ProtoSessionMessageType::ChannelLeaveRequest,
612 session_id,
613 rand::random::<u32>(),
614 metadata,
615 vec![],
616 auth_provider,
617 )
618}
619
620fn invite_participant_message(
621 controller: &Name,
622 moderator: &Name,
623 participant: &Name,
624 channel_name: &Name,
625 auth_provider: &Option<AuthProvider>,
626) -> DataPlaneMessage {
627 let session_id = generate_session_id(moderator, channel_name);
628 let mut metadata = HashMap::new();
629
630 let encoded_participant: Vec<u8> =
631 bincode::encode_to_vec(participant, bincode::config::standard())
632 .expect("unable to encode channel join payload");
633 let encoded_participant_str =
634 base64::engine::general_purpose::STANDARD.encode(&encoded_participant);
635
636 metadata.insert("PARTICIPANT_NAME".to_string(), encoded_participant_str);
637
638 create_channel_message(
639 controller,
640 moderator,
641 None,
642 ProtoSessionMessageType::ChannelDiscoveryRequest,
643 session_id,
644 rand::random::<u32>(),
645 metadata,
646 vec![],
647 auth_provider,
648 )
649}
650
651fn remove_participant_message(
652 controller: &Name,
653 moderator: &Name,
654 participant: &Name,
655 channel_name: &Name,
656 auth_provider: &Option<AuthProvider>,
657) -> DataPlaneMessage {
658 let session_id = generate_session_id(moderator, channel_name);
659
660 let mut metadata = HashMap::new();
661 let encoded_participant: Vec<u8> =
662 bincode::encode_to_vec(participant, bincode::config::standard())
663 .expect("unable to encode channel join payload");
664 let encoded_participant_str =
665 base64::engine::general_purpose::STANDARD.encode(&encoded_participant);
666 metadata.insert("PARTICIPANT_NAME".to_string(), encoded_participant_str);
667
668 create_channel_message(
669 controller,
670 moderator,
671 None,
672 ProtoSessionMessageType::ChannelLeaveRequest,
673 session_id,
674 rand::random::<u32>(),
675 metadata,
676 vec![],
677 auth_provider,
678 )
679}
680
681impl ControllerService {
    /// Maximum number of attempts `connect` makes to open the control channel.
    const MAX_RETRIES: i32 = 10;
683
684 async fn handle_new_control_message(
686 &self,
687 msg: ControlMessage,
688 tx: &mpsc::Sender<Result<ControlMessage, Status>>,
689 ) -> Result<(), ControllerError> {
690 match msg.payload {
691 Some(ref payload) => {
692 match payload {
693 Payload::ConfigCommand(config) => {
694 for conn in &config.connections_to_create {
695 info!("received a connection to create: {:?}", conn);
696 let client_config =
697 serde_json::from_str::<ClientConfig>(&conn.config_data)
698 .map_err(|e| ControllerError::ConfigError(e.to_string()))?;
699 let client_endpoint = &client_config.endpoint;
700
701 if !self.inner.connections.read().contains_key(client_endpoint) {
703 match client_config.to_channel() {
704 Err(e) => {
705 error!("error reading channel config {:?}", e);
706 }
707 Ok(channel) => {
708 let ret = self
709 .inner
710 .message_processor
711 .connect(
712 channel,
713 Some(client_config.clone()),
714 None,
715 None,
716 )
717 .await
718 .map_err(|e| {
719 ControllerError::ConnectionError(e.to_string())
720 });
721
722 let conn_id = match ret {
723 Err(e) => {
724 error!("connection error: {:?}", e);
725 return Err(ControllerError::ConnectionError(
726 e.to_string(),
727 ));
728 }
729 Ok(conn_id) => conn_id.1,
730 };
731
732 self.inner
733 .connections
734 .write()
735 .insert(client_endpoint.clone(), conn_id);
736 }
737 }
738 }
739 }
740
741 for subscription in &config.subscriptions_to_set {
742 if !self
743 .inner
744 .connections
745 .read()
746 .contains_key(&subscription.connection_id)
747 {
748 error!("connection {} not found", subscription.connection_id);
749 continue;
750 }
751
752 let conn = self
753 .inner
754 .connections
755 .read()
756 .get(&subscription.connection_id)
757 .cloned()
758 .unwrap();
759 let source = Name::from_strings([
760 subscription.component_0.as_str(),
761 subscription.component_1.as_str(),
762 subscription.component_2.as_str(),
763 ])
764 .with_id(0);
765 let name = Name::from_strings([
766 subscription.component_0.as_str(),
767 subscription.component_1.as_str(),
768 subscription.component_2.as_str(),
769 ])
770 .with_id(subscription.id.unwrap_or(Name::NULL_COMPONENT));
771
772 let msg = DataPlaneMessage::new_subscribe(
773 &source,
774 &name,
775 Some(SlimHeaderFlags::default().with_recv_from(conn)),
776 );
777
778 if let Err(e) = self.send_control_message(msg).await {
779 error!("failed to subscribe: {}", e);
780 }
781 }
782
783 for subscription in &config.subscriptions_to_delete {
784 if !self
785 .inner
786 .connections
787 .read()
788 .contains_key(&subscription.connection_id)
789 {
790 error!("connection {} not found", subscription.connection_id);
791 continue;
792 }
793
794 let conn = self
795 .inner
796 .connections
797 .read()
798 .get(&subscription.connection_id)
799 .cloned()
800 .unwrap();
801 let source = Name::from_strings([
802 subscription.component_0.as_str(),
803 subscription.component_1.as_str(),
804 subscription.component_2.as_str(),
805 ])
806 .with_id(0);
807 let name = Name::from_strings([
808 subscription.component_0.as_str(),
809 subscription.component_1.as_str(),
810 subscription.component_2.as_str(),
811 ])
812 .with_id(subscription.id.unwrap_or(Name::NULL_COMPONENT));
813
814 let msg = DataPlaneMessage::new_unsubscribe(
815 &source,
816 &name,
817 Some(SlimHeaderFlags::default().with_recv_from(conn)),
818 );
819
820 if let Err(e) = self.send_control_message(msg).await {
821 error!("failed to unsubscribe: {}", e);
822 }
823 }
824
825 let ack = Ack {
826 original_message_id: msg.message_id.clone(),
827 success: true,
828 messages: vec![],
829 };
830
831 let reply = ControlMessage {
832 message_id: uuid::Uuid::new_v4().to_string(),
833 payload: Some(Payload::Ack(ack)),
834 };
835
836 if let Err(e) = tx.send(Ok(reply)).await {
837 error!("failed to send ACK: {}", e);
838 }
839 }
840 Payload::SubscriptionListRequest(_) => {
841 const CHUNK_SIZE: usize = 100;
842
843 let conn_table = self.inner.message_processor.connection_table();
844 let mut entries = Vec::new();
845
846 self.inner.message_processor.subscription_table().for_each(
847 |name, id, local, remote| {
848 let mut entry = SubscriptionEntry {
849 component_0: name.components_strings().unwrap()[0].to_string(),
850 component_1: name.components_strings().unwrap()[1].to_string(),
851 component_2: name.components_strings().unwrap()[2].to_string(),
852 id: Some(id),
853 ..Default::default()
854 };
855
856 for &cid in local {
857 entry.local_connections.push(ConnectionEntry {
858 id: cid,
859 connection_type: ConnectionType::Local as i32,
860 config_data: "{}".to_string(),
861 });
862 }
863
864 for &cid in remote {
865 if let Some(conn) = conn_table.get(cid as usize) {
866 entry.remote_connections.push(ConnectionEntry {
867 id: cid,
868 connection_type: ConnectionType::Remote as i32,
869 config_data: match conn.config_data() {
870 Some(data) => serde_json::to_string(data)
871 .unwrap_or_else(|_| "{}".to_string()),
872 None => "{}".to_string(),
873 },
874 });
875 } else {
876 error!("no connection entry for id {}", cid);
877 }
878 }
879 entries.push(entry);
880 },
881 );
882
883 for chunk in entries.chunks(CHUNK_SIZE) {
884 let resp = ControlMessage {
885 message_id: uuid::Uuid::new_v4().to_string(),
886 payload: Some(Payload::SubscriptionListResponse(
887 SubscriptionListResponse {
888 entries: chunk.to_vec(),
889 },
890 )),
891 };
892
893 if let Err(e) = tx.try_send(Ok(resp)) {
894 error!("failed to send subscription batch: {}", e);
895 }
896 }
897 }
898 Payload::ConnectionListRequest(_) => {
899 let mut all_entries = Vec::new();
900 self.inner
901 .message_processor
902 .connection_table()
903 .for_each(|id, conn| {
904 all_entries.push(ConnectionEntry {
905 id: id as u64,
906 connection_type: ConnectionType::Remote as i32,
907 config_data: match conn.config_data() {
908 Some(data) => serde_json::to_string(data)
909 .unwrap_or_else(|_| "{}".to_string()),
910 None => "{}".to_string(),
911 },
912 });
913 });
914
915 const CHUNK_SIZE: usize = 100;
916 for chunk in all_entries.chunks(CHUNK_SIZE) {
917 let resp = ControlMessage {
918 message_id: uuid::Uuid::new_v4().to_string(),
919 payload: Some(Payload::ConnectionListResponse(
920 ConnectionListResponse {
921 entries: chunk.to_vec(),
922 },
923 )),
924 };
925
926 if let Err(e) = tx.try_send(Ok(resp)) {
927 error!("failed to send connection list batch: {}", e);
928 }
929 }
930 }
931 Payload::Ack(_ack) => {
932 }
934 Payload::SubscriptionListResponse(_) => {
935 }
937 Payload::ConnectionListResponse(_) => {
938 }
940 Payload::RegisterNodeRequest(_) => {
941 error!("received a register node request");
942 }
943 Payload::RegisterNodeResponse(_) => {
944 }
946 Payload::DeregisterNodeRequest(_) => {
947 error!("received a deregister node request");
948 }
949 Payload::DeregisterNodeResponse(_) => {
950 }
952 Payload::CreateChannelRequest(req) => {
953 info!("received a create channel request");
954
955 let mut success = true;
956 if let Some(first_moderator) = req.moderators.first() {
958 let moderator_name = get_name_from_string(first_moderator)?;
959 if !moderator_name.has_id() {
960 error!("invalid moderator ID");
961 success = false;
962 } else {
963 let channel_name = get_name_from_string(&req.channel_name)?;
964
965 let creation_msg = new_channel_message(
966 &CONTROLLER_SOURCE_NAME,
967 &moderator_name,
968 &channel_name,
969 &self.inner.auth_provider,
970 );
971
972 debug!("Send session creation message: {:?}", creation_msg);
973 if let Err(e) = self.send_control_message(creation_msg).await {
974 error!("failed to send channel creation: {}", e);
975 success = false;
976 }
977 }
978 } else {
979 error!("no moderators specified create channel request");
980 success = false;
981 };
982
983 let ack = Ack {
984 original_message_id: msg.message_id.clone(),
985 success,
986 messages: vec![msg.message_id.clone()],
987 };
988
989 let reply = ControlMessage {
990 message_id: uuid::Uuid::new_v4().to_string(),
991 payload: Some(Payload::Ack(ack)),
992 };
993
994 if let Err(e) = tx.send(Ok(reply)).await {
995 error!("failed to send Ack: {}", e);
996 }
997 }
998 Payload::DeleteChannelRequest(req) => {
999 info!("received a channel delete request");
1000 let mut success = true;
1001
1002 if let Some(first_moderator) = req.moderators.first() {
1004 let moderator_name = get_name_from_string(first_moderator)?;
1005 if !moderator_name.has_id() {
1006 error!("invalid moderator ID");
1007 success = false;
1008 } else {
1009 let channel_name = get_name_from_string(&req.channel_name)?;
1010
1011 let delete_msg = delete_channel_message(
1012 &CONTROLLER_SOURCE_NAME,
1013 &moderator_name,
1014 &channel_name,
1015 &self.inner.auth_provider,
1016 );
1017
1018 debug!("Send delete session message: {:?}", delete_msg);
1019 if let Err(e) = self.send_control_message(delete_msg).await {
1020 error!("failed to send delete channel: {}", e);
1021 success = false;
1022 }
1023 }
1024 } else {
1025 error!("no moderators specified in delete channel request");
1026 success = false;
1027 };
1028
1029 let ack = Ack {
1030 original_message_id: msg.message_id.clone(),
1031 success,
1032 messages: vec![msg.message_id.clone()],
1033 };
1034
1035 let reply = ControlMessage {
1036 message_id: uuid::Uuid::new_v4().to_string(),
1037 payload: Some(Payload::Ack(ack)),
1038 };
1039
1040 if let Err(e) = tx.send(Ok(reply)).await {
1041 error!("failed to send Ack: {}", e);
1042 }
1043 }
1044 Payload::AddParticipantRequest(req) => {
1045 info!(
1046 "received a participant add request for channel: {}, participant: {}",
1047 req.channel_name, req.participant_name
1048 );
1049
1050 let mut success = true;
1051
1052 if let Some(first_moderator) = req.moderators.first() {
1053 let moderator_name = get_name_from_string(first_moderator)?;
1054 if !moderator_name.has_id() {
1055 error!("invalid moderator ID");
1056 success = false;
1057 } else {
1058 let channel_name = get_name_from_string(&req.channel_name)?;
1059 let participant_name = get_name_from_string(&req.participant_name)?;
1060
1061 let invite_msg = invite_participant_message(
1062 &CONTROLLER_SOURCE_NAME,
1063 &moderator_name,
1064 &participant_name,
1065 &channel_name,
1066 &self.inner.auth_provider,
1067 );
1068
1069 debug!("Send invite participant: {:?}", invite_msg);
1070
1071 if let Err(e) = self.send_control_message(invite_msg).await {
1072 error!("failed to send channel creation: {}", e);
1073 success = false;
1074 }
1075 }
1076 } else {
1077 error!("no moderators specified in add participant request");
1078 };
1079
1080 let ack = Ack {
1081 original_message_id: msg.message_id.clone(),
1082 success,
1083 messages: vec![msg.message_id.clone()],
1084 };
1085
1086 let reply = ControlMessage {
1087 message_id: uuid::Uuid::new_v4().to_string(),
1088 payload: Some(Payload::Ack(ack)),
1089 };
1090
1091 if let Err(e) = tx.send(Ok(reply)).await {
1092 error!("failed to send Ack: {}", e);
1093 }
1094 }
1095 Payload::DeleteParticipantRequest(req) => {
1096 info!("received a participant delete request");
1097
1098 let mut success = true;
1099
1100 if let Some(first_moderator) = req.moderators.first() {
1101 let moderator_name = get_name_from_string(first_moderator)?;
1102 if !moderator_name.has_id() {
1103 error!("invalid moderator ID");
1104 success = false;
1105 } else {
1106 let channel_name = get_name_from_string(&req.channel_name)?;
1107 let participant_name = get_name_from_string(&req.participant_name)?;
1108
1109 let remove_msg = remove_participant_message(
1110 &CONTROLLER_SOURCE_NAME,
1111 &moderator_name,
1112 &participant_name,
1113 &channel_name,
1114 &self.inner.auth_provider,
1115 );
1116
1117 if let Err(e) = self.send_control_message(remove_msg).await {
1118 error!("failed to send channel creation: {}", e);
1119 success = false;
1120 }
1121 }
1122 } else {
1123 error!("no moderators specified in remove participant request");
1124 success = false;
1125 };
1126
1127 let ack = Ack {
1128 original_message_id: msg.message_id.clone(),
1129 success,
1130 messages: vec![msg.message_id.clone()],
1131 };
1132
1133 let reply = ControlMessage {
1134 message_id: uuid::Uuid::new_v4().to_string(),
1135 payload: Some(Payload::Ack(ack)),
1136 };
1137
1138 if let Err(e) = tx.send(Ok(reply)).await {
1139 error!("failed to send Ack: {}", e);
1140 }
1141 }
1142 Payload::ListChannelRequest(_) => {}
1143 Payload::ListChannelResponse(_) => {}
1144 Payload::ListParticipantsRequest(_) => {}
1145 Payload::ListParticipantsResponse(_) => {}
1146 }
1147 }
1148 None => {
1149 error!(
1150 "received control message {} with no payload",
1151 msg.message_id
1152 );
1153 }
1154 }
1155
1156 Ok(())
1157 }
1158
1159 async fn send_control_message(&self, msg: DataPlaneMessage) -> Result<(), ControllerError> {
1161 self.inner.tx_slim.send(Ok(msg)).await.map_err(|e| {
1162 error!("error sending message into datapath: {}", e);
1163 ControllerError::DatapathError(e.to_string())
1164 })
1165 }
1166
1167 fn process_control_message_stream(
1169 &self,
1170 config: Option<ClientConfig>,
1171 mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
1172 tx: mpsc::Sender<Result<ControlMessage, Status>>,
1173 cancellation_token: CancellationToken,
1174 ) -> tokio::task::JoinHandle<()> {
1175 let this = self.clone();
1176 let drain = this.inner.drain_rx.clone();
1177 tokio::spawn(async move {
1178 let endpoint = config
1180 .as_ref()
1181 .map(|c| c.endpoint.clone())
1182 .unwrap_or_else(|| "unknown".to_string());
1183 info!(%endpoint, "connected to control plane");
1184
1185 let mut retry_connect = false;
1186
1187 let register_request = ControlMessage {
1188 message_id: uuid::Uuid::new_v4().to_string(),
1189 payload: Some(Payload::RegisterNodeRequest(v1::RegisterNodeRequest {
1190 node_id: this.inner.id.to_string(),
1191 group_name: this.inner.group_name.clone(),
1192 connection_details: this.inner.connection_details.clone(),
1193 })),
1194 };
1195
1196 if config.is_some()
1198 && let Err(e) = tx.send(Ok(register_request)).await
1199 {
1200 error!("failed to send register request: {}", e);
1201 return;
1202 }
1203
1204 loop {
1207 tokio::select! {
1208 next = stream.next() => {
1209 match next {
1210 Some(Ok(msg)) => {
1211 if let Err(e) = this.handle_new_control_message(msg, &tx).await {
1212 error!("error processing incoming control message: {:?}", e);
1213 }
1214 }
1215 Some(Err(e)) => {
1216 if let Some(io_err) = Self::match_for_io_error(&e) {
1217 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
1218 info!("connection closed by peer");
1219 retry_connect = true;
1220 }
1221 } else {
1222 error!(%e, "error receiving control messages");
1223 }
1224
1225 break;
1226 }
1227 None => {
1228 debug!("end of stream");
1229 retry_connect = true;
1230 break;
1231 }
1232 }
1233 }
1234 _ = cancellation_token.cancelled() => {
1235 debug!("shutting down stream on cancellation token");
1236 break;
1237 }
1238 _ = drain.clone().signaled() => {
1239 debug!("shutting down stream on drain");
1240 break;
1241 }
1242 }
1243 }
1244
1245 info!(%endpoint, "control plane stream closed");
1246
1247 if retry_connect && let Some(config) = config {
1248 info!(%config.endpoint, "retrying connection to control plane");
1249 this.connect(config.clone(), cancellation_token)
1250 .await
1251 .map_or_else(
1252 |e| {
1253 error!("failed to reconnect to control plane: {}", e);
1254 },
1255 |tx| {
1256 info!(%config.endpoint, "reconnected to control plane");
1257
1258 this.inner
1259 .tx_channels
1260 .write()
1261 .insert(config.endpoint.clone(), tx);
1262 },
1263 )
1264 }
1265 })
1266 }
1267
1268 async fn connect(
1272 &self,
1273 config: ClientConfig,
1274 cancellation_token: CancellationToken,
1275 ) -> Result<mpsc::Sender<Result<ControlMessage, Status>>, ControllerError> {
1276 info!(%config.endpoint, "connecting to control plane");
1277
1278 let channel = config.to_channel().map_err(|e| {
1279 error!("error reading channel config: {}", e);
1280 ControllerError::ConfigError(e.to_string())
1281 })?;
1282
1283 let mut client = ControllerServiceClient::new(channel);
1284 for i in 0..Self::MAX_RETRIES {
1285 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
1286 let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
1287 match client.open_control_channel(Request::new(out_stream)).await {
1288 Ok(stream) => {
1289 self.process_control_message_stream(
1291 Some(config),
1292 stream.into_inner(),
1293 tx.clone(),
1294 cancellation_token.clone(),
1295 );
1296
1297 return Ok(tx);
1298 }
1299 Err(e) => {
1300 error!(%e, "connection error, retrying {}/{}", i + 1, Self::MAX_RETRIES);
1301 }
1302 };
1303
1304 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
1306 }
1307
1308 Err(ControllerError::ConfigError(format!(
1309 "failed to connect to control plane after {} retries",
1310 Self::MAX_RETRIES
1311 )))
1312 }
1313
1314 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
1315 let mut err: &(dyn std::error::Error + 'static) = err_status;
1316
1317 loop {
1318 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
1319 return Some(io_err);
1320 }
1321
1322 if let Some(h2_err) = err.downcast_ref::<h2::Error>()
1325 && let Some(io_err) = h2_err.get_io()
1326 {
1327 return Some(io_err);
1328 }
1329
1330 err = err.source()?;
1331 }
1332 }
1333}
1334
1335#[tonic::async_trait]
1336impl GrpcControllerService for ControllerService {
1337 type OpenControlChannelStream =
1338 Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
1339
1340 async fn open_control_channel(
1341 &self,
1342 request: Request<tonic::Streaming<ControlMessage>>,
1343 ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
1344 let remote_endpoint = request
1346 .remote_addr()
1347 .map(|addr| addr.to_string())
1348 .unwrap_or_else(|| "unknown".to_string());
1349
1350 let stream = request.into_inner();
1351 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
1352
1353 let cancellation_token = CancellationToken::new();
1354
1355 self.process_control_message_stream(None, stream, tx.clone(), cancellation_token.clone());
1356
1357 self.inner
1359 .tx_channels
1360 .write()
1361 .insert(remote_endpoint.clone(), tx);
1362
1363 self.inner
1365 .cancellation_tokens
1366 .write()
1367 .insert(remote_endpoint.clone(), cancellation_token);
1368
1369 let out_stream = ReceiverStream::new(rx);
1370 Ok(Response::new(
1371 Box::pin(out_stream) as Self::OpenControlChannelStream
1372 ))
1373 }
1374}
1375
#[cfg(test)]
mod tests {
    use super::*;
    use slim_config::component::id::Kind;
    use tracing_test::traced_test;

    /// Runs a server-side and a client-side control plane against each other
    /// on localhost and checks, through the captured logs, that a register
    /// node request was received.
    #[tokio::test]
    #[traced_test]
    async fn test_end_to_end() {
        let server_id =
            ID::new_with_name(Kind::new("slim").unwrap(), "test-server-instance").unwrap();
        let client_id =
            ID::new_with_name(Kind::new("slim").unwrap(), "test-client-instance").unwrap();

        // Plain-text gRPC endpoints: the server listens, the client dials it.
        let grpc_server = ServerConfig::with_endpoint("127.0.0.1:50051")
            .with_tls_settings(slim_config::tls::server::TlsServerConfig::insecure());
        let grpc_client = ClientConfig::with_endpoint("http://127.0.0.1:50051")
            .with_tls_setting(slim_config::tls::client::TlsClientConfig::insecure());

        // One drain channel per side so each can be shut down on its own.
        let (server_signal, server_watch) = drain::channel();
        let (client_signal, client_watch) = drain::channel();

        let client_processor = MessageProcessor::with_drain_channel(client_watch.clone());
        let server_processor = MessageProcessor::with_drain_channel(server_watch.clone());

        let pubsub = vec![grpc_server.clone()];

        let server_settings = ControlPlaneSettings {
            id: server_id,
            group_name: None,
            servers: vec![grpc_server],
            clients: vec![],
            drain_rx: server_watch,
            message_processor: Arc::new(server_processor),
            pubsub_servers: pubsub.clone(),
            auth_provider: None,
            auth_verifier: None,
        };
        let mut server_plane = ControlPlane::new(server_settings);

        let client_settings = ControlPlaneSettings {
            id: client_id,
            group_name: None,
            servers: vec![],
            clients: vec![grpc_client],
            drain_rx: client_watch,
            message_processor: Arc::new(client_processor),
            pubsub_servers: pubsub.clone(),
            auth_provider: None,
            auth_verifier: None,
        };
        let mut client_plane = ControlPlane::new(client_settings);

        let settle = tokio::time::Duration::from_millis(100);

        // Bring the server up first and give it a moment before the client
        // starts.
        server_plane.run().await.unwrap();
        tokio::time::sleep(settle).await;

        // Then start the client and wait again before inspecting the logs.
        client_plane.run().await.unwrap();
        tokio::time::sleep(settle).await;

        assert!(logs_contain("received a register node request"));

        // Drop both control planes before draining their signals.
        drop(server_plane);
        drop(client_plane);

        server_signal.drain().await;
        client_signal.drain().await;
    }

    /// Checks that `generate_session_id` is deterministic, changes when either
    /// the moderator id or the channel name changes, and is non-zero for the
    /// sampled inputs.
    #[test]
    fn test_generate_session_id() {
        let mod_42 = Name::from_strings(["Org", "Ns", "Moderator"]).with_id(42);
        let mod_43 = Name::from_strings(["Org", "Ns", "Moderator"]).with_id(43);
        let chan_x = Name::from_strings(["Org", "Ns", "ChannelX"]).with_id(7);
        let chan_y = Name::from_strings(["Org", "Ns", "ChannelY"]).with_id(7);

        // Same inputs twice.
        let first = generate_session_id(&mod_42, &chan_x);
        let repeat = generate_session_id(&mod_42, &chan_x);
        assert_eq!(first, repeat, "hash must be deterministic for same inputs");

        // Same channel, different moderator id.
        let other_moderator = generate_session_id(&mod_43, &chan_x);
        assert_ne!(
            first, other_moderator,
            "changing moderator id should change session id"
        );

        // Same moderator, different channel name.
        let other_channel = generate_session_id(&mod_42, &chan_y);
        assert_ne!(
            first, other_channel,
            "changing channel name should change session id"
        );

        let all_non_zero = [first, other_moderator, other_channel]
            .iter()
            .all(|id| *id != 0);
        assert!(all_non_zero, "session ids should not be zero");
    }
}