agp_controller/
service.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7
8use agp_config::tls::client::TlsClientConfig;
9use tokio::sync::mpsc;
10use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
11use tokio_util::sync::CancellationToken;
12use tonic::codegen::{Body, StdError};
13use tonic::{Request, Response, Status};
14use tracing::{debug, error, info};
15
16use crate::api::proto::api::v1::{
17    Ack, ControlMessage, controller_service_client::ControllerServiceClient,
18    controller_service_server::ControllerService as GrpcControllerService,
19};
20use crate::errors::ControllerError;
21
22use agp_config::grpc::client::ClientConfig;
23use agp_datapath::message_processing::MessageProcessor;
24use agp_datapath::messages::utils::AgpHeaderFlags;
25use agp_datapath::messages::{Agent, AgentType};
26use agp_datapath::pubsub::proto::pubsub::v1::Message as PubsubMessage;
27
28#[derive(Debug, Clone)]
29pub struct ControllerService {
30    /// underlying message processor
31    message_processor: Arc<MessageProcessor>,
32
33    /// channel to send messages into the datapath
34    tx_gw: OnceLock<mpsc::Sender<Result<PubsubMessage, Status>>>,
35
36    /// map of connection IDs to their configuration
37    connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,
38}
39
40impl ControllerService {
41    pub fn new(message_processor: Arc<MessageProcessor>) -> Self {
42        ControllerService {
43            message_processor,
44            tx_gw: OnceLock::new(),
45            connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
46        }
47    }
48
49    async fn handle_new_message(
50        &self,
51        msg: ControlMessage,
52        tx: mpsc::Sender<Result<ControlMessage, Status>>,
53    ) -> Result<(), ControllerError> {
54        match msg.payload {
55            Some(ref payload) => {
56                match payload {
57                    crate::api::proto::api::v1::control_message::Payload::ConfigCommand(config) => {
58                        for conn in &config.connections_to_create {
59                            let client_endpoint =
60                                format!("{}:{}", conn.remote_address, conn.remote_port);
61
62                            // connect to an endpoint if it's not already connected
63                            if !self.connections.read().contains_key(&client_endpoint) {
64                                let client_config = ClientConfig {
65                                    endpoint: format!("http://{}", client_endpoint),
66                                    tls_setting: TlsClientConfig::default().with_insecure(true),
67                                    ..ClientConfig::default()
68                                };
69
70                                match client_config.to_channel() {
71                                    Err(e) => {
72                                        error!("error reading channel config {:?}", e);
73                                    }
74                                    Ok(channel) => {
75                                        let ret = self
76                                            .message_processor
77                                            .connect(
78                                                channel,
79                                                Some(client_config.clone()),
80                                                None,
81                                                None,
82                                            )
83                                            .await
84                                            .map_err(|e| {
85                                                ControllerError::ConnectionError(e.to_string())
86                                            });
87
88                                        let conn_id = match ret {
89                                            Err(e) => {
90                                                error!("connection error: {:?}", e);
91                                                return Err(ControllerError::ConnectionError(
92                                                    e.to_string(),
93                                                ));
94                                            }
95                                            Ok(conn_id) => conn_id.1,
96                                        };
97
98                                        self.connections.write().insert(client_endpoint, conn_id);
99                                    }
100                                }
101                            }
102                        }
103
104                        for route in &config.routes_to_set {
105                            if !self.connections.read().contains_key(&route.connection_id) {
106                                error!("connection {} not found", route.connection_id);
107                                continue;
108                            }
109
110                            let conn = self
111                                .connections
112                                .read()
113                                .get(&route.connection_id)
114                                .cloned()
115                                .unwrap();
116                            let source = Agent::from_strings(
117                                route.company.as_str(),
118                                route.namespace.as_str(),
119                                route.agent_name.as_str(),
120                                0,
121                            );
122                            let agent_type = AgentType::from_strings(
123                                route.company.as_str(),
124                                route.namespace.as_str(),
125                                route.agent_name.as_str(),
126                            );
127
128                            let msg = PubsubMessage::new_subscribe(
129                                &source,
130                                &agent_type,
131                                route.agent_id,
132                                Some(AgpHeaderFlags::default().with_recv_from(conn)),
133                            );
134
135                            if let Err(e) = self.send_message(msg).await {
136                                error!("failed to subscribe: {}", e);
137                            }
138                        }
139
140                        for route in &config.routes_to_delete {
141                            if !self.connections.read().contains_key(&route.connection_id) {
142                                error!("connection {} not found", route.connection_id);
143                                continue;
144                            }
145
146                            let conn = self
147                                .connections
148                                .read()
149                                .get(&route.connection_id)
150                                .cloned()
151                                .unwrap();
152                            let source = Agent::from_strings(
153                                route.company.as_str(),
154                                route.namespace.as_str(),
155                                route.agent_name.as_str(),
156                                0,
157                            );
158                            let agent_type = AgentType::from_strings(
159                                route.company.as_str(),
160                                route.namespace.as_str(),
161                                route.agent_name.as_str(),
162                            );
163
164                            let msg = PubsubMessage::new_unsubscribe(
165                                &source,
166                                &agent_type,
167                                route.agent_id,
168                                Some(AgpHeaderFlags::default().with_recv_from(conn)),
169                            );
170
171                            if let Err(e) = self.send_message(msg).await {
172                                error!("failed to unsubscribe: {}", e);
173                            }
174                        }
175
176                        let ack = Ack {
177                            original_message_id: msg.message_id.clone(),
178                            success: true,
179                            messages: vec![],
180                        };
181
182                        let reply = ControlMessage {
183                            message_id: uuid::Uuid::new_v4().to_string(),
184                            payload: Some(
185                                crate::api::proto::api::v1::control_message::Payload::Ack(ack),
186                            ),
187                        };
188
189                        if let Err(e) = tx.send(Ok(reply)).await {
190                            eprintln!("failed to send ACK: {}", e);
191                        }
192                    }
193                    crate::api::proto::api::v1::control_message::Payload::Ack(_ack) => {
194                        // received an ack, do nothing - this should not happen
195                    }
196                }
197            }
198            None => {
199                println!(
200                    "received control message {} with no payload",
201                    msg.message_id
202                );
203            }
204        }
205
206        Ok(())
207    }
208
209    async fn send_message(&self, msg: PubsubMessage) -> Result<(), ControllerError> {
210        let sender = self.tx_gw.get_or_init(|| {
211            let (_, tx_gw, _) = self.message_processor.register_local_connection();
212            tx_gw
213        });
214
215        sender.send(Ok(msg)).await.map_err(|e| {
216            error!("error sending message into datapath: {}", e);
217            ControllerError::DatapathError(e.to_string())
218        })
219    }
220
221    async fn process_stream(
222        &self,
223        cancellation_token: CancellationToken,
224        mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
225        tx: mpsc::Sender<Result<ControlMessage, Status>>,
226    ) -> tokio::task::JoinHandle<()> {
227        let svc = self.clone();
228        let token = cancellation_token.clone();
229
230        tokio::spawn(async move {
231            loop {
232                tokio::select! {
233                    next = stream.next() => {
234                        match next {
235                            Some(Ok(msg)) => {
236                                if let Err(e) = svc.handle_new_message(msg, tx.clone()).await {
237                                    error!("error processing incoming control message: {:?}", e);
238                                }
239                            }
240                            Some(Err(e)) => {
241                                if let Some(io_err) = ControllerService::match_for_io_error(&e) {
242                                    if io_err.kind() == std::io::ErrorKind::BrokenPipe {
243                                        info!("connection closed by peer");
244                                    }
245                                } else {
246                                    error!("error receiving control messages: {:?}", e);
247                                }
248                                break;
249                            }
250                            None => {
251                                debug!("end of stream");
252                                break;
253                            }
254                        }
255                    }
256                    _ = token.cancelled() => {
257                        debug!("shutting down stream on cancellation token");
258                        break;
259                    }
260                }
261            }
262        })
263    }
264
265    pub async fn connect<C>(
266        &self,
267        channel: C,
268    ) -> Result<tokio::task::JoinHandle<()>, ControllerError>
269    where
270        C: tonic::client::GrpcService<tonic::body::Body>,
271        C::Error: Into<StdError>,
272        C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
273        <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
274    {
275        //TODO(zkacsand): make this a constant or make it configurable?
276        let max_retry = 10;
277
278        let mut client: ControllerServiceClient<C> = ControllerServiceClient::new(channel);
279        let mut i = 0;
280        while i < max_retry {
281            let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
282            let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
283
284            match client.open_control_channel(Request::new(out_stream)).await {
285                Ok(stream) => {
286                    let ret = self
287                        .process_stream(CancellationToken::new(), stream.into_inner(), tx)
288                        .await;
289                    return Ok(ret);
290                }
291                Err(e) => {
292                    error!("connection error: {:?}.", e.to_string());
293                }
294            };
295
296            i += 1;
297
298            // sleep 1 sec between each connection retry
299            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
300        }
301
302        error!("unable to connect to the endpoint");
303        Err(ControllerError::ConnectionError(
304            "reached max connection retries".to_string(),
305        ))
306    }
307
308    fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
309        let mut err: &(dyn std::error::Error + 'static) = err_status;
310
311        loop {
312            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
313                return Some(io_err);
314            }
315
316            // h2::Error do not expose std::io::Error with `source()`
317            // https://github.com/hyperium/h2/pull/462
318            if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
319                if let Some(io_err) = h2_err.get_io() {
320                    return Some(io_err);
321                }
322            }
323
324            err = err.source()?;
325        }
326    }
327}
328
329#[tonic::async_trait]
330impl GrpcControllerService for ControllerService {
331    type OpenControlChannelStream =
332        Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
333
334    async fn open_control_channel(
335        &self,
336        request: Request<tonic::Streaming<ControlMessage>>,
337    ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
338        let stream = request.into_inner();
339        let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
340
341        self.process_stream(CancellationToken::new(), stream, tx.clone())
342            .await;
343
344        let out_stream = ReceiverStream::new(rx);
345        Ok(Response::new(
346            Box::pin(out_stream) as Self::OpenControlChannelStream
347        ))
348    }
349}