1use std::collections::HashMap;
5use std::net::ToSocketAddrs;
6use std::pin::Pin;
7use std::sync::{Arc, OnceLock};
8
9use slim_config::tls::client::TlsClientConfig;
10use tokio::sync::mpsc;
11use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
12use tokio_util::sync::CancellationToken;
13use tonic::codegen::{Body, StdError};
14use tonic::{Request, Response, Status};
15use tracing::{debug, error, info};
16
17use crate::api::proto::api::v1::{
18 Ack, ConnectionEntry, ControlMessage, SubscriptionEntry,
19 controller_service_client::ControllerServiceClient,
20 controller_service_server::ControllerService as GrpcControllerService,
21};
22use crate::api::proto::api::v1::{
23 ConnectionListResponse, ConnectionType, SubscriptionListResponse,
24};
25use crate::errors::ControllerError;
26
27use slim_config::grpc::client::ClientConfig;
28use slim_datapath::api::proto::pubsub::v1::Message as PubsubMessage;
29use slim_datapath::message_processing::MessageProcessor;
30use slim_datapath::messages::utils::SlimHeaderFlags;
31use slim_datapath::messages::{Agent, AgentType};
32use slim_datapath::tables::SubscriptionTable;
33
34#[derive(Debug, Clone)]
35pub struct ControllerService {
36 message_processor: Arc<MessageProcessor>,
38
39 tx_slim: OnceLock<mpsc::Sender<Result<PubsubMessage, Status>>>,
41
42 connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,
44}
45
46impl ControllerService {
47 pub fn new(message_processor: Arc<MessageProcessor>) -> Self {
48 ControllerService {
49 message_processor,
50 tx_slim: OnceLock::new(),
51 connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
52 }
53 }
54
55 async fn handle_new_control_message(
56 &self,
57 msg: ControlMessage,
58 tx: mpsc::Sender<Result<ControlMessage, Status>>,
59 ) -> Result<(), ControllerError> {
60 match msg.payload {
61 Some(ref payload) => {
62 match payload {
63 crate::api::proto::api::v1::control_message::Payload::ConfigCommand(config) => {
64 for conn in &config.connections_to_create {
65 let client_endpoint =
66 format!("{}:{}", conn.remote_address, conn.remote_port);
67
68 let mut addrs_iter = client_endpoint
69 .as_str()
70 .to_socket_addrs()
71 .map_err(|e| ControllerError::ConnectionError(e.to_string()))?;
72 let remote_sock = addrs_iter
73 .next()
74 .ok_or_else(|| ControllerError::ConnectionError(format!("could not resolve {}", client_endpoint)))?;
75
76 if !self.connections.read().contains_key(&client_endpoint) {
78 let client_config = ClientConfig {
79 endpoint: format!("http://{}", client_endpoint),
80 tls_setting: TlsClientConfig::default().with_insecure(true),
81 ..ClientConfig::default()
82 };
83
84 match client_config.to_channel() {
85 Err(e) => {
86 error!("error reading channel config {:?}", e);
87 }
88 Ok(channel) => {
89 let ret = self
90 .message_processor
91 .connect(
92 channel,
93 Some(client_config.clone()),
94 None,
95 Some(remote_sock),
96 )
97 .await
98 .map_err(|e| {
99 ControllerError::ConnectionError(e.to_string())
100 });
101
102 let conn_id = match ret {
103 Err(e) => {
104 error!("connection error: {:?}", e);
105 return Err(ControllerError::ConnectionError(
106 e.to_string(),
107 ));
108 }
109 Ok(conn_id) => conn_id.1,
110 };
111
112 self.connections.write().insert(client_endpoint, conn_id);
113 }
114 }
115 }
116 }
117
118 for subscription in &config.subscriptions_to_set {
119 if !self.connections.read().contains_key(&subscription.connection_id) {
120 error!("connection {} not found", subscription.connection_id);
121 continue;
122 }
123
124 let conn = self
125 .connections
126 .read()
127 .get(&subscription.connection_id)
128 .cloned()
129 .unwrap();
130 let source = Agent::from_strings(
131 subscription.organization.as_str(),
132 subscription.namespace.as_str(),
133 subscription.agent_type.as_str(),
134 0,
135 );
136 let agent_type = AgentType::from_strings(
137 subscription.organization.as_str(),
138 subscription.namespace.as_str(),
139 subscription.agent_type.as_str(),
140 );
141
142 let msg = PubsubMessage::new_subscribe(
143 &source,
144 &agent_type,
145 subscription.agent_id,
146 Some(SlimHeaderFlags::default().with_recv_from(conn)),
147 );
148
149 if let Err(e) = self.send_control_message(msg).await {
150 error!("failed to subscribe: {}", e);
151 }
152 }
153
154 for subscription in &config.subscriptions_to_delete {
155 if !self.connections.read().contains_key(&subscription.connection_id) {
156 error!("connection {} not found", subscription.connection_id);
157 continue;
158 }
159
160 let conn = self
161 .connections
162 .read()
163 .get(&subscription.connection_id)
164 .cloned()
165 .unwrap();
166 let source = Agent::from_strings(
167 subscription.organization.as_str(),
168 subscription.namespace.as_str(),
169 subscription.agent_type.as_str(),
170 0,
171 );
172 let agent_type = AgentType::from_strings(
173 subscription.organization.as_str(),
174 subscription.namespace.as_str(),
175 subscription.agent_type.as_str(),
176 );
177
178 let msg = PubsubMessage::new_unsubscribe(
179 &source,
180 &agent_type,
181 subscription.agent_id,
182 Some(SlimHeaderFlags::default().with_recv_from(conn)),
183 );
184
185 if let Err(e) = self.send_control_message(msg).await {
186 error!("failed to unsubscribe: {}", e);
187 }
188 }
189
190 let ack = Ack {
191 original_message_id: msg.message_id.clone(),
192 success: true,
193 messages: vec![],
194 };
195
196 let reply = ControlMessage {
197 message_id: uuid::Uuid::new_v4().to_string(),
198 payload: Some(
199 crate::api::proto::api::v1::control_message::Payload::Ack(ack),
200 ),
201 };
202
203 if let Err(e) = tx.send(Ok(reply)).await {
204 eprintln!("failed to send ACK: {}", e);
205 }
206 }
207 crate::api::proto::api::v1::control_message::Payload::SubscriptionListRequest(_) => {
208 const CHUNK_SIZE: usize = 100;
209
210 let conn_table = self.message_processor.connection_table();
211 let mut entries = Vec::new();
212
213 self
214 .message_processor
215 .subscription_table()
216 .for_each(|agent_type, agent_id, local, remote| {
217 let mut entry = SubscriptionEntry {
218 organization: agent_type.organization_string().unwrap_or_else(|| agent_type.organization().to_string()),
219 namespace: agent_type.namespace_string().unwrap_or_else(|| agent_type.organization().to_string()),
220 agent_type: agent_type.agent_type_string().unwrap_or_else(|| agent_type.organization().to_string()),
221 agent_id: Some(agent_id),
222 ..Default::default()
223 };
224
225 for &cid in local {
226 entry.local_connections.push(ConnectionEntry {
227 id: cid,
228 connection_type: ConnectionType::Local as i32,
229 ip: String::new(),
230 port: 0,
231 });
232 }
233
234 for &cid in remote {
235 if let Some(conn) = conn_table.get(cid as usize) {
236 if let Some(sock) = conn.remote_addr() {
237 entry.remote_connections.push(ConnectionEntry {
238 id: cid,
239 connection_type: ConnectionType::Remote as i32,
240 ip: sock.ip().to_string(),
241 port: sock.port() as u32,
242 });
243 } else {
244 entry.remote_connections.push(ConnectionEntry {
245 id: cid,
246 connection_type: ConnectionType::Remote as i32,
247 ip: String::new(),
248 port: 0,
249 });
250 }
251 } else {
252 error!("no connection entry for id {}", cid);
253 entry.remote_connections.push(ConnectionEntry {
254 id: cid,
255 connection_type: ConnectionType::Remote as i32,
256 ip: String::new(),
257 port: 0,
258 });
259 }
260 }
261
262 entries.push(entry);
263 });
264
265 for chunk in entries.chunks(CHUNK_SIZE) {
266 let resp = ControlMessage {
267 message_id: uuid::Uuid::new_v4().to_string(),
268 payload: Some(
269 crate::api::proto::api::v1::control_message::Payload::SubscriptionListResponse(
270 SubscriptionListResponse {
271 entries: chunk.to_vec(),
272 }
273 )
274 ),
275 };
276
277 if let Err(e) = tx.try_send(Ok(resp)) {
278 error!("failed to send subscription batch: {}", e);
279 }
280 }
281 }
282 crate::api::proto::api::v1::control_message::Payload::ConnectionListRequest(_) => {
283 let mut all_entries = Vec::new();
284 self.message_processor
285 .connection_table()
286 .for_each(|id, conn| {
287 let (ip, port) = conn
288 .remote_addr()
289 .map(|sock| (sock.ip().to_string(), sock.port() as u32))
290 .unwrap_or_else(|| ("".into(), 0));
291
292 all_entries.push(ConnectionEntry {
293 id: id as u64,
294 connection_type: ConnectionType::Remote as i32,
295 ip,
296 port,
297 });
298 });
299
300 const CHUNK_SIZE: usize = 100;
301 for chunk in all_entries.chunks(CHUNK_SIZE) {
302 let resp = ControlMessage {
303 message_id: uuid::Uuid::new_v4().to_string(),
304 payload: Some(
305 crate::api::proto::api::v1::control_message::Payload::ConnectionListResponse(
306 ConnectionListResponse {
307 entries: chunk.to_vec(),
308 },
309 ),
310 ),
311 };
312
313 if let Err(e) = tx.try_send(Ok(resp)) {
314 error!("failed to send connection list batch: {}", e);
315 }
316 }
317 }
318 crate::api::proto::api::v1::control_message::Payload::Ack(_ack) => {
319 }
321 crate::api::proto::api::v1::control_message::Payload::SubscriptionListResponse(_) => {
322 }
324 crate::api::proto::api::v1::control_message::Payload::ConnectionListResponse(_) => {
325 }
327 }
328 }
329 None => {
330 println!(
331 "received control message {} with no payload",
332 msg.message_id
333 );
334 }
335 }
336
337 Ok(())
338 }
339
340 async fn send_control_message(&self, msg: PubsubMessage) -> Result<(), ControllerError> {
341 let sender = self.tx_slim.get_or_init(|| {
342 let (_, tx_slim, _) = self.message_processor.register_local_connection();
343 tx_slim
344 });
345
346 sender.send(Ok(msg)).await.map_err(|e| {
347 error!("error sending message into datapath: {}", e);
348 ControllerError::DatapathError(e.to_string())
349 })
350 }
351
352 async fn process_control_message_stream(
353 &self,
354 cancellation_token: CancellationToken,
355 mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
356 tx: mpsc::Sender<Result<ControlMessage, Status>>,
357 ) -> tokio::task::JoinHandle<()> {
358 let svc = self.clone();
359 let token = cancellation_token.clone();
360
361 tokio::spawn(async move {
362 loop {
363 tokio::select! {
364 next = stream.next() => {
365 match next {
366 Some(Ok(msg)) => {
367 if let Err(e) = svc.handle_new_control_message(msg, tx.clone()).await {
368 error!("error processing incoming control message: {:?}", e);
369 }
370 }
371 Some(Err(e)) => {
372 if let Some(io_err) = ControllerService::match_for_io_error(&e) {
373 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
374 info!("connection closed by peer");
375 }
376 } else {
377 error!("error receiving control messages: {:?}", e);
378 }
379 break;
380 }
381 None => {
382 debug!("end of stream");
383 break;
384 }
385 }
386 }
387 _ = token.cancelled() => {
388 debug!("shutting down stream on cancellation token");
389 break;
390 }
391 }
392 }
393 })
394 }
395
396 pub async fn connect<C>(
397 &self,
398 channel: C,
399 ) -> Result<tokio::task::JoinHandle<()>, ControllerError>
400 where
401 C: tonic::client::GrpcService<tonic::body::Body>,
402 C::Error: Into<StdError>,
403 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
404 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
405 {
406 let max_retry = 10;
408
409 let mut client: ControllerServiceClient<C> = ControllerServiceClient::new(channel);
410 let mut i = 0;
411 while i < max_retry {
412 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
413 let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
414
415 match client.open_control_channel(Request::new(out_stream)).await {
416 Ok(stream) => {
417 let ret = self
418 .process_control_message_stream(
419 CancellationToken::new(),
420 stream.into_inner(),
421 tx,
422 )
423 .await;
424 return Ok(ret);
425 }
426 Err(e) => {
427 error!("connection error: {:?}.", e.to_string());
428 }
429 };
430
431 i += 1;
432
433 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
435 }
436
437 error!("unable to connect to the endpoint");
438 Err(ControllerError::ConnectionError(
439 "reached max connection retries".to_string(),
440 ))
441 }
442
443 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
444 let mut err: &(dyn std::error::Error + 'static) = err_status;
445
446 loop {
447 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
448 return Some(io_err);
449 }
450
451 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
454 if let Some(io_err) = h2_err.get_io() {
455 return Some(io_err);
456 }
457 }
458
459 err = err.source()?;
460 }
461 }
462}
463
464#[tonic::async_trait]
465impl GrpcControllerService for ControllerService {
466 type OpenControlChannelStream =
467 Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
468
469 async fn open_control_channel(
470 &self,
471 request: Request<tonic::Streaming<ControlMessage>>,
472 ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
473 let stream = request.into_inner();
474 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
475
476 self.process_control_message_stream(CancellationToken::new(), stream, tx.clone())
477 .await;
478
479 let out_stream = ReceiverStream::new(rx);
480 Ok(Response::new(
481 Box::pin(out_stream) as Self::OpenControlChannelStream
482 ))
483 }
484}