slim_controller/
service.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::net::ToSocketAddrs;
6use std::pin::Pin;
7use std::sync::{Arc, OnceLock};
8
9use slim_config::tls::client::TlsClientConfig;
10use tokio::sync::mpsc;
11use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
12use tokio_util::sync::CancellationToken;
13use tonic::codegen::{Body, StdError};
14use tonic::{Request, Response, Status};
15use tracing::{debug, error, info};
16
17use crate::api::proto::api::v1::{
18    Ack, ConnectionEntry, ControlMessage, SubscriptionEntry,
19    controller_service_client::ControllerServiceClient,
20    controller_service_server::ControllerService as GrpcControllerService,
21};
22use crate::api::proto::api::v1::{
23    ConnectionListResponse, ConnectionType, SubscriptionListResponse,
24};
25use crate::errors::ControllerError;
26
27use slim_config::grpc::client::ClientConfig;
28use slim_datapath::api::proto::pubsub::v1::Message as PubsubMessage;
29use slim_datapath::message_processing::MessageProcessor;
30use slim_datapath::messages::utils::SlimHeaderFlags;
31use slim_datapath::messages::{Agent, AgentType};
32use slim_datapath::tables::SubscriptionTable;
33
/// Control-plane service that applies control messages to the datapath.
///
/// It implements the server side of the controller gRPC API (see the
/// `GrpcControllerService` implementation) and can also dial out to a remote
/// controller through `connect`.
#[derive(Debug, Clone)]
pub struct ControllerService {
    /// underlying message processor
    message_processor: Arc<MessageProcessor>,

    /// channel to send messages into the datapath; set lazily on first use
    // NOTE(review): the derived `Clone` copies the `OnceLock` by value, so a
    // clone taken before the sender is initialized appears to register its own
    // local connection on first use — confirm this is intended.
    tx_slim: OnceLock<mpsc::Sender<Result<PubsubMessage, Status>>>,

    /// map of remote endpoints (`"address:port"`) to their datapath connection ID
    connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,
}
45
46impl ControllerService {
47    pub fn new(message_processor: Arc<MessageProcessor>) -> Self {
48        ControllerService {
49            message_processor,
50            tx_slim: OnceLock::new(),
51            connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
52        }
53    }
54
55    async fn handle_new_control_message(
56        &self,
57        msg: ControlMessage,
58        tx: mpsc::Sender<Result<ControlMessage, Status>>,
59    ) -> Result<(), ControllerError> {
60        match msg.payload {
61            Some(ref payload) => {
62                match payload {
63                    crate::api::proto::api::v1::control_message::Payload::ConfigCommand(config) => {
64                        for conn in &config.connections_to_create {
65                            let client_endpoint =
66                                format!("{}:{}", conn.remote_address, conn.remote_port);
67
68                            let mut addrs_iter = client_endpoint
69                                .as_str()
70                                .to_socket_addrs()
71                                .map_err(|e| ControllerError::ConnectionError(e.to_string()))?;
72                            let remote_sock = addrs_iter
73                                .next()
74                                .ok_or_else(|| ControllerError::ConnectionError(format!("could not resolve {}", client_endpoint)))?;
75
76                            // connect to an endpoint if it's not already connected
77                            if !self.connections.read().contains_key(&client_endpoint) {
78                                let client_config = ClientConfig {
79                                    endpoint: format!("http://{}", client_endpoint),
80                                    tls_setting: TlsClientConfig::default().with_insecure(true),
81                                    ..ClientConfig::default()
82                                };
83
84                                match client_config.to_channel() {
85                                    Err(e) => {
86                                        error!("error reading channel config {:?}", e);
87                                    }
88                                    Ok(channel) => {
89                                        let ret = self
90                                            .message_processor
91                                            .connect(
92                                                channel,
93                                                Some(client_config.clone()),
94                                                None,
95                                                Some(remote_sock),
96                                            )
97                                            .await
98                                            .map_err(|e| {
99                                                ControllerError::ConnectionError(e.to_string())
100                                            });
101
102                                        let conn_id = match ret {
103                                            Err(e) => {
104                                                error!("connection error: {:?}", e);
105                                                return Err(ControllerError::ConnectionError(
106                                                    e.to_string(),
107                                                ));
108                                            }
109                                            Ok(conn_id) => conn_id.1,
110                                        };
111
112                                        self.connections.write().insert(client_endpoint, conn_id);
113                                    }
114                                }
115                            }
116                        }
117
118                        for subscription in &config.subscriptions_to_set {
119                            if !self.connections.read().contains_key(&subscription.connection_id) {
120                                error!("connection {} not found", subscription.connection_id);
121                                continue;
122                            }
123
124                            let conn = self
125                                .connections
126                                .read()
127                                .get(&subscription.connection_id)
128                                .cloned()
129                                .unwrap();
130                            let source = Agent::from_strings(
131                                subscription.organization.as_str(),
132                                subscription.namespace.as_str(),
133                                subscription.agent_type.as_str(),
134                                0,
135                            );
136                            let agent_type = AgentType::from_strings(
137                                subscription.organization.as_str(),
138                                subscription.namespace.as_str(),
139                                subscription.agent_type.as_str(),
140                            );
141
142                            let msg = PubsubMessage::new_subscribe(
143                                &source,
144                                &agent_type,
145                                subscription.agent_id,
146                                Some(SlimHeaderFlags::default().with_recv_from(conn)),
147                            );
148
149                            if let Err(e) = self.send_control_message(msg).await {
150                                error!("failed to subscribe: {}", e);
151                            }
152                        }
153
154                        for subscription in &config.subscriptions_to_delete {
155                            if !self.connections.read().contains_key(&subscription.connection_id) {
156                                error!("connection {} not found", subscription.connection_id);
157                                continue;
158                            }
159
160                            let conn = self
161                                .connections
162                                .read()
163                                .get(&subscription.connection_id)
164                                .cloned()
165                                .unwrap();
166                            let source = Agent::from_strings(
167                                subscription.organization.as_str(),
168                                subscription.namespace.as_str(),
169                                subscription.agent_type.as_str(),
170                                0,
171                            );
172                            let agent_type = AgentType::from_strings(
173                                subscription.organization.as_str(),
174                                subscription.namespace.as_str(),
175                                subscription.agent_type.as_str(),
176                            );
177
178                            let msg = PubsubMessage::new_unsubscribe(
179                                &source,
180                                &agent_type,
181                                subscription.agent_id,
182                                Some(SlimHeaderFlags::default().with_recv_from(conn)),
183                            );
184
185                            if let Err(e) = self.send_control_message(msg).await {
186                                error!("failed to unsubscribe: {}", e);
187                            }
188                        }
189
190                        let ack = Ack {
191                            original_message_id: msg.message_id.clone(),
192                            success: true,
193                            messages: vec![],
194                        };
195
196                        let reply = ControlMessage {
197                            message_id: uuid::Uuid::new_v4().to_string(),
198                            payload: Some(
199                                crate::api::proto::api::v1::control_message::Payload::Ack(ack),
200                            ),
201                        };
202
203                        if let Err(e) = tx.send(Ok(reply)).await {
204                            eprintln!("failed to send ACK: {}", e);
205                        }
206                    }
207                    crate::api::proto::api::v1::control_message::Payload::SubscriptionListRequest(_) => {
208                        const CHUNK_SIZE: usize = 100;
209
210                        let conn_table = self.message_processor.connection_table();
211                        let mut entries = Vec::new();
212
213                        self
214                            .message_processor
215                            .subscription_table()
216                            .for_each(|agent_type, agent_id, local, remote| {
217                                let mut entry = SubscriptionEntry {
218                                    organization: agent_type.organization_string().unwrap_or_else(|| agent_type.organization().to_string()),
219                                    namespace: agent_type.namespace_string().unwrap_or_else(|| agent_type.organization().to_string()),
220                                    agent_type: agent_type.agent_type_string().unwrap_or_else(|| agent_type.organization().to_string()),
221                                    agent_id: Some(agent_id),
222                                    ..Default::default()
223                                };
224
225                                for &cid in local {
226                                    entry.local_connections.push(ConnectionEntry {
227                                        id:   cid,
228                                        connection_type: ConnectionType::Local as i32,
229                                        ip:   String::new(),
230                                        port: 0,
231                                    });
232                                }
233
234                                for &cid in remote {
235                                    if let Some(conn) = conn_table.get(cid as usize) {
236                                        if let Some(sock) = conn.remote_addr() {
237                                            entry.remote_connections.push(ConnectionEntry {
238                                                id:   cid,
239                                                connection_type: ConnectionType::Remote as i32,
240                                                ip:   sock.ip().to_string(),
241                                                port: sock.port() as u32,
242                                            });
243                                        } else {
244                                            entry.remote_connections.push(ConnectionEntry {
245                                                id:   cid,
246                                                connection_type: ConnectionType::Remote as i32,
247                                                ip:   String::new(),
248                                                port: 0,
249                                            });
250                                        }
251                                    } else {
252                                        error!("no connection entry for id {}", cid);
253                                        entry.remote_connections.push(ConnectionEntry {
254                                            id:   cid,
255                                            connection_type: ConnectionType::Remote as i32,
256                                            ip:   String::new(),
257                                            port: 0,
258                                        });
259                                    }
260                                }
261
262                                entries.push(entry);
263                            });
264
265                        for chunk in entries.chunks(CHUNK_SIZE) {
266                            let resp = ControlMessage {
267                                message_id: uuid::Uuid::new_v4().to_string(),
268                                payload: Some(
269                                    crate::api::proto::api::v1::control_message::Payload::SubscriptionListResponse(
270                                        SubscriptionListResponse {
271                                            entries: chunk.to_vec(),
272                                        }
273                                    )
274                                ),
275                            };
276
277                            if let Err(e) = tx.try_send(Ok(resp)) {
278                                error!("failed to send subscription batch: {}", e);
279                            }
280                        }
281                    }
282                    crate::api::proto::api::v1::control_message::Payload::ConnectionListRequest(_) => {
283                        let mut all_entries = Vec::new();
284                        self.message_processor
285                            .connection_table()
286                            .for_each(|id, conn| {
287                                let (ip, port) = conn
288                                    .remote_addr()
289                                    .map(|sock| (sock.ip().to_string(), sock.port() as u32))
290                                    .unwrap_or_else(|| ("".into(), 0));
291
292                                all_entries.push(ConnectionEntry {
293                                    id: id as u64,
294                                    connection_type: ConnectionType::Remote as i32,
295                                    ip,
296                                    port,
297                                });
298                            });
299
300                        const CHUNK_SIZE: usize = 100;
301                        for chunk in all_entries.chunks(CHUNK_SIZE) {
302                            let resp = ControlMessage {
303                                message_id: uuid::Uuid::new_v4().to_string(),
304                                payload: Some(
305                                    crate::api::proto::api::v1::control_message::Payload::ConnectionListResponse(
306                                        ConnectionListResponse {
307                                            entries: chunk.to_vec(),
308                                        },
309                                    ),
310                                ),
311                            };
312
313                            if let Err(e) = tx.try_send(Ok(resp)) {
314                                error!("failed to send connection list batch: {}", e);
315                            }
316                        }
317                    }
318                    crate::api::proto::api::v1::control_message::Payload::Ack(_ack) => {
319                        // received an ack, do nothing - this should not happen
320                    }
321                    crate::api::proto::api::v1::control_message::Payload::SubscriptionListResponse(_) => {
322                        // received a subscription list response, do nothing - this should not happen
323                    }
324                    crate::api::proto::api::v1::control_message::Payload::ConnectionListResponse(_) => {
325                        // received a connection list response, do nothing - this should not happen
326                    }
327                }
328            }
329            None => {
330                println!(
331                    "received control message {} with no payload",
332                    msg.message_id
333                );
334            }
335        }
336
337        Ok(())
338    }
339
340    async fn send_control_message(&self, msg: PubsubMessage) -> Result<(), ControllerError> {
341        let sender = self.tx_slim.get_or_init(|| {
342            let (_, tx_slim, _) = self.message_processor.register_local_connection();
343            tx_slim
344        });
345
346        sender.send(Ok(msg)).await.map_err(|e| {
347            error!("error sending message into datapath: {}", e);
348            ControllerError::DatapathError(e.to_string())
349        })
350    }
351
352    async fn process_control_message_stream(
353        &self,
354        cancellation_token: CancellationToken,
355        mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
356        tx: mpsc::Sender<Result<ControlMessage, Status>>,
357    ) -> tokio::task::JoinHandle<()> {
358        let svc = self.clone();
359        let token = cancellation_token.clone();
360
361        tokio::spawn(async move {
362            loop {
363                tokio::select! {
364                    next = stream.next() => {
365                        match next {
366                            Some(Ok(msg)) => {
367                                if let Err(e) = svc.handle_new_control_message(msg, tx.clone()).await {
368                                    error!("error processing incoming control message: {:?}", e);
369                                }
370                            }
371                            Some(Err(e)) => {
372                                if let Some(io_err) = ControllerService::match_for_io_error(&e) {
373                                    if io_err.kind() == std::io::ErrorKind::BrokenPipe {
374                                        info!("connection closed by peer");
375                                    }
376                                } else {
377                                    error!("error receiving control messages: {:?}", e);
378                                }
379                                break;
380                            }
381                            None => {
382                                debug!("end of stream");
383                                break;
384                            }
385                        }
386                    }
387                    _ = token.cancelled() => {
388                        debug!("shutting down stream on cancellation token");
389                        break;
390                    }
391                }
392            }
393        })
394    }
395
396    pub async fn connect<C>(
397        &self,
398        channel: C,
399    ) -> Result<tokio::task::JoinHandle<()>, ControllerError>
400    where
401        C: tonic::client::GrpcService<tonic::body::Body>,
402        C::Error: Into<StdError>,
403        C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
404        <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
405    {
406        //TODO(zkacsand): make this a constant or make it configurable?
407        let max_retry = 10;
408
409        let mut client: ControllerServiceClient<C> = ControllerServiceClient::new(channel);
410        let mut i = 0;
411        while i < max_retry {
412            let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
413            let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
414
415            match client.open_control_channel(Request::new(out_stream)).await {
416                Ok(stream) => {
417                    let ret = self
418                        .process_control_message_stream(
419                            CancellationToken::new(),
420                            stream.into_inner(),
421                            tx,
422                        )
423                        .await;
424                    return Ok(ret);
425                }
426                Err(e) => {
427                    error!("connection error: {:?}.", e.to_string());
428                }
429            };
430
431            i += 1;
432
433            // sleep 1 sec between each connection retry
434            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
435        }
436
437        error!("unable to connect to the endpoint");
438        Err(ControllerError::ConnectionError(
439            "reached max connection retries".to_string(),
440        ))
441    }
442
443    fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
444        let mut err: &(dyn std::error::Error + 'static) = err_status;
445
446        loop {
447            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
448                return Some(io_err);
449            }
450
451            // h2::Error do not expose std::io::Error with `source()`
452            // https://github.com/hyperium/h2/pull/462
453            if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
454                if let Some(io_err) = h2_err.get_io() {
455                    return Some(io_err);
456                }
457            }
458
459            err = err.source()?;
460        }
461    }
462}
463
464#[tonic::async_trait]
465impl GrpcControllerService for ControllerService {
466    type OpenControlChannelStream =
467        Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
468
469    async fn open_control_channel(
470        &self,
471        request: Request<tonic::Streaming<ControlMessage>>,
472    ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
473        let stream = request.into_inner();
474        let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
475
476        self.process_control_message_stream(CancellationToken::new(), stream, tx.clone())
477            .await;
478
479        let out_stream = ReceiverStream::new(rx);
480        Ok(Response::new(
481            Box::pin(out_stream) as Self::OpenControlChannelStream
482        ))
483    }
484}