1use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::{Arc, OnceLock};
7
8use agp_config::tls::client::TlsClientConfig;
9use tokio::sync::mpsc;
10use tokio_stream::{Stream, StreamExt, wrappers::ReceiverStream};
11use tokio_util::sync::CancellationToken;
12use tonic::codegen::{Body, StdError};
13use tonic::{Request, Response, Status};
14use tracing::{debug, error, info};
15
16use crate::api::proto::api::v1::{
17 Ack, ControlMessage, controller_service_client::ControllerServiceClient,
18 controller_service_server::ControllerService as GrpcControllerService,
19};
20use crate::errors::ControllerError;
21
22use agp_config::grpc::client::ClientConfig;
23use agp_datapath::message_processing::MessageProcessor;
24use agp_datapath::messages::utils::AgpHeaderFlags;
25use agp_datapath::messages::{Agent, AgentType};
26use agp_datapath::pubsub::proto::pubsub::v1::Message as PubsubMessage;
27
28#[derive(Debug, Clone)]
29pub struct ControllerService {
30 message_processor: Arc<MessageProcessor>,
32
33 tx_gw: OnceLock<mpsc::Sender<Result<PubsubMessage, Status>>>,
35
36 connections: Arc<parking_lot::RwLock<HashMap<String, u64>>>,
38}
39
40impl ControllerService {
41 pub fn new(message_processor: Arc<MessageProcessor>) -> Self {
42 ControllerService {
43 message_processor,
44 tx_gw: OnceLock::new(),
45 connections: Arc::new(parking_lot::RwLock::new(HashMap::new())),
46 }
47 }
48
49 async fn handle_new_message(
50 &self,
51 msg: ControlMessage,
52 tx: mpsc::Sender<Result<ControlMessage, Status>>,
53 ) -> Result<(), ControllerError> {
54 match msg.payload {
55 Some(ref payload) => {
56 match payload {
57 crate::api::proto::api::v1::control_message::Payload::ConfigCommand(config) => {
58 for conn in &config.connections_to_create {
59 let client_endpoint =
60 format!("{}:{}", conn.remote_address, conn.remote_port);
61
62 if !self.connections.read().contains_key(&client_endpoint) {
64 let client_config = ClientConfig {
65 endpoint: format!("http://{}", client_endpoint),
66 tls_setting: TlsClientConfig::default().with_insecure(true),
67 ..ClientConfig::default()
68 };
69
70 match client_config.to_channel() {
71 Err(e) => {
72 error!("error reading channel config {:?}", e);
73 }
74 Ok(channel) => {
75 let ret = self
76 .message_processor
77 .connect(
78 channel,
79 Some(client_config.clone()),
80 None,
81 None,
82 )
83 .await
84 .map_err(|e| {
85 ControllerError::ConnectionError(e.to_string())
86 });
87
88 let conn_id = match ret {
89 Err(e) => {
90 error!("connection error: {:?}", e);
91 return Err(ControllerError::ConnectionError(
92 e.to_string(),
93 ));
94 }
95 Ok(conn_id) => conn_id.1,
96 };
97
98 self.connections.write().insert(client_endpoint, conn_id);
99 }
100 }
101 }
102 }
103
104 for route in &config.routes_to_set {
105 if !self.connections.read().contains_key(&route.connection_id) {
106 error!("connection {} not found", route.connection_id);
107 continue;
108 }
109
110 let conn = self
111 .connections
112 .read()
113 .get(&route.connection_id)
114 .cloned()
115 .unwrap();
116 let source = Agent::from_strings(
117 route.company.as_str(),
118 route.namespace.as_str(),
119 route.agent_name.as_str(),
120 0,
121 );
122 let agent_type = AgentType::from_strings(
123 route.company.as_str(),
124 route.namespace.as_str(),
125 route.agent_name.as_str(),
126 );
127
128 let msg = PubsubMessage::new_subscribe(
129 &source,
130 &agent_type,
131 route.agent_id,
132 Some(AgpHeaderFlags::default().with_recv_from(conn)),
133 );
134
135 if let Err(e) = self.send_message(msg).await {
136 error!("failed to subscribe: {}", e);
137 }
138 }
139
140 for route in &config.routes_to_delete {
141 if !self.connections.read().contains_key(&route.connection_id) {
142 error!("connection {} not found", route.connection_id);
143 continue;
144 }
145
146 let conn = self
147 .connections
148 .read()
149 .get(&route.connection_id)
150 .cloned()
151 .unwrap();
152 let source = Agent::from_strings(
153 route.company.as_str(),
154 route.namespace.as_str(),
155 route.agent_name.as_str(),
156 0,
157 );
158 let agent_type = AgentType::from_strings(
159 route.company.as_str(),
160 route.namespace.as_str(),
161 route.agent_name.as_str(),
162 );
163
164 let msg = PubsubMessage::new_unsubscribe(
165 &source,
166 &agent_type,
167 route.agent_id,
168 Some(AgpHeaderFlags::default().with_recv_from(conn)),
169 );
170
171 if let Err(e) = self.send_message(msg).await {
172 error!("failed to unsubscribe: {}", e);
173 }
174 }
175
176 let ack = Ack {
177 original_message_id: msg.message_id.clone(),
178 success: true,
179 messages: vec![],
180 };
181
182 let reply = ControlMessage {
183 message_id: uuid::Uuid::new_v4().to_string(),
184 payload: Some(
185 crate::api::proto::api::v1::control_message::Payload::Ack(ack),
186 ),
187 };
188
189 if let Err(e) = tx.send(Ok(reply)).await {
190 eprintln!("failed to send ACK: {}", e);
191 }
192 }
193 crate::api::proto::api::v1::control_message::Payload::Ack(_ack) => {
194 }
196 }
197 }
198 None => {
199 println!(
200 "received control message {} with no payload",
201 msg.message_id
202 );
203 }
204 }
205
206 Ok(())
207 }
208
209 async fn send_message(&self, msg: PubsubMessage) -> Result<(), ControllerError> {
210 let sender = self.tx_gw.get_or_init(|| {
211 let (_, tx_gw, _) = self.message_processor.register_local_connection();
212 tx_gw
213 });
214
215 sender.send(Ok(msg)).await.map_err(|e| {
216 error!("error sending message into datapath: {}", e);
217 ControllerError::DatapathError(e.to_string())
218 })
219 }
220
221 async fn process_stream(
222 &self,
223 cancellation_token: CancellationToken,
224 mut stream: impl Stream<Item = Result<ControlMessage, Status>> + Unpin + Send + 'static,
225 tx: mpsc::Sender<Result<ControlMessage, Status>>,
226 ) -> tokio::task::JoinHandle<()> {
227 let svc = self.clone();
228 let token = cancellation_token.clone();
229
230 tokio::spawn(async move {
231 loop {
232 tokio::select! {
233 next = stream.next() => {
234 match next {
235 Some(Ok(msg)) => {
236 if let Err(e) = svc.handle_new_message(msg, tx.clone()).await {
237 error!("error processing incoming control message: {:?}", e);
238 }
239 }
240 Some(Err(e)) => {
241 if let Some(io_err) = ControllerService::match_for_io_error(&e) {
242 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
243 info!("connection closed by peer");
244 }
245 } else {
246 error!("error receiving control messages: {:?}", e);
247 }
248 break;
249 }
250 None => {
251 debug!("end of stream");
252 break;
253 }
254 }
255 }
256 _ = token.cancelled() => {
257 debug!("shutting down stream on cancellation token");
258 break;
259 }
260 }
261 }
262 })
263 }
264
265 pub async fn connect<C>(
266 &self,
267 channel: C,
268 ) -> Result<tokio::task::JoinHandle<()>, ControllerError>
269 where
270 C: tonic::client::GrpcService<tonic::body::Body>,
271 C::Error: Into<StdError>,
272 C::ResponseBody: Body<Data = bytes::Bytes> + std::marker::Send + 'static,
273 <C::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
274 {
275 let max_retry = 10;
277
278 let mut client: ControllerServiceClient<C> = ControllerServiceClient::new(channel);
279 let mut i = 0;
280 while i < max_retry {
281 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
282 let out_stream = ReceiverStream::new(rx).map(|res| res.expect("mapping error"));
283
284 match client.open_control_channel(Request::new(out_stream)).await {
285 Ok(stream) => {
286 let ret = self
287 .process_stream(CancellationToken::new(), stream.into_inner(), tx)
288 .await;
289 return Ok(ret);
290 }
291 Err(e) => {
292 error!("connection error: {:?}.", e.to_string());
293 }
294 };
295
296 i += 1;
297
298 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
300 }
301
302 error!("unable to connect to the endpoint");
303 Err(ControllerError::ConnectionError(
304 "reached max connection retries".to_string(),
305 ))
306 }
307
308 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
309 let mut err: &(dyn std::error::Error + 'static) = err_status;
310
311 loop {
312 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
313 return Some(io_err);
314 }
315
316 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
319 if let Some(io_err) = h2_err.get_io() {
320 return Some(io_err);
321 }
322 }
323
324 err = err.source()?;
325 }
326 }
327}
328
329#[tonic::async_trait]
330impl GrpcControllerService for ControllerService {
331 type OpenControlChannelStream =
332 Pin<Box<dyn Stream<Item = Result<ControlMessage, Status>> + Send + 'static>>;
333
334 async fn open_control_channel(
335 &self,
336 request: Request<tonic::Streaming<ControlMessage>>,
337 ) -> Result<Response<Self::OpenControlChannelStream>, Status> {
338 let stream = request.into_inner();
339 let (tx, rx) = mpsc::channel::<Result<ControlMessage, Status>>(128);
340
341 self.process_stream(CancellationToken::new(), stream, tx.clone())
342 .await;
343
344 let out_stream = ReceiverStream::new(rx);
345 Ok(Response::new(
346 Box::pin(out_stream) as Self::OpenControlChannelStream
347 ))
348 }
349}