agent_api/
api.rs

1#![allow(warnings)]
2use anyhow::Context;
3use chrono::Local;
4use prost::bytes::BytesMut;
5use std::str::FromStr;
6use std::{ sync::Mutex };
7use tonic::{ Request, Response, Status };
8use tracing::info;
9
10use aya::{ maps::{ MapData, PerfEventArray }, util::online_cpus };
11use std::result::Result::Ok;
12use tonic::async_trait;
13
14use std::collections::HashMap;
15use aya::maps::HashMap as ayaHashMap;
16use tokio::sync::mpsc;
17use tokio::task;
18
19use crate::agent::ConnectionEvent;
20// *  contains agent api configuration
21use crate::agent::{
22    agent_server::Agent,
23    ActiveConnectionResponse,
24    RequestActiveConnections,
25    AddIpToBlocklistRequest,
26    BlocklistResponse,
27    RmIpFromBlocklistRequest,
28    RmIpFromBlocklistResponse,
29};
30use aya::maps::Map;
31use bytemuck_derive::Zeroable;
32use cortexflow_identity::enums::IpProtocols;
33use std::net::Ipv4Addr;
34use tracing::warn;
35
36#[repr(C)]
37#[derive(Clone, Copy, Zeroable)]
38pub struct PacketLog {
39    pub proto: u8,
40    pub src_ip: u32,
41    pub src_port: u16,
42    pub dst_ip: u32,
43    pub dst_port: u16,
44    pub pid: u32,
45}
46unsafe impl aya::Pod for PacketLog {}
47
48pub struct AgentApi {
49    //* event_rx is an istance of a mpsc receiver.
50    //* is used to receive the data from the transmitter (tx)
51    event_rx: Mutex<mpsc::Receiver<Result<Vec<ConnectionEvent>, Status>>>,
52    event_tx: mpsc::Sender<Result<Vec<ConnectionEvent>, Status>>,
53}
54
55//* Event sender trait. Takes an event from a map and send that to the mpsc channel
56//* using the send_map function
57#[async_trait]
58pub trait EventSender: Send + Sync + 'static {
59    async fn send_event(&self, event: Vec<ConnectionEvent>);
60    async fn send_map(
61        &self,
62        map: Vec<ConnectionEvent>,
63        tx: mpsc::Sender<Result<Vec<ConnectionEvent>, Status>>
64    ) {
65        let status = Status::new(tonic::Code::Ok, "success");
66        let event = Ok(map);
67
68        let _ = tx.send(event).await;
69    }
70}
71
72// send event function. takes an HashMap and send that using mpsc event_tx
73#[async_trait]
74impl EventSender for AgentApi {
75    async fn send_event(&self, event: Vec<ConnectionEvent>) {
76        self.send_map(event, self.event_tx.clone()).await;
77    }
78}
79
/// Name of an environment variable; this is not a filesystem path itself.
/// NOTE(review): presumably it points at the BPF object; it is not read in this file.
const BPF_PATH: &str = "BPF_PATH";

/// Name of an environment variable; this is not a filesystem path itself.
/// Presumably it holds the pin path of the blocklist map — TODO confirm.
const PIN_BLOCKLIST_MAP_PATH: &str = "PIN_BLOCKLIST_MAP_PATH";
82
83//initialize a default trait for AgentApi. Loads a name and a bpf istance.
84//this trait is essential for init the Agent.
85impl Default for AgentApi {
86    //TODO:this part needs a better error handling
87    fn default() -> Self {
88        // load maps mapdata
89        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/events_map").expect(
90            "cannot open events_map Mapdata"
91        );
92        let map = Map::PerfEventArray(mapdata); //creates a PerfEventArray from the mapdata
93
94        //init a mpsc channel
95        let (tx, rx) = mpsc::channel(1024);
96        let api = AgentApi {
97            event_rx: rx.into(),
98            event_tx: tx.clone(),
99        };
100
101        let mut events_array = PerfEventArray::try_from(map).expect(
102            "Error while initializing events array"
103        );
104
105        //spawn an event reader
106        task::spawn(async move {
107            let mut net_events_buffer = Vec::new();
108            //scan the cpus to read the data
109
110            for cpu_id in online_cpus()
111                .map_err(|e| anyhow::anyhow!("Error {:?}", e))
112                .unwrap() {
113                let buf = events_array
114                    .open(cpu_id, None)
115                    .expect("Error during the creation of net_events_buf structure");
116
117                let buffers = vec![BytesMut::with_capacity(4096); 8];
118                net_events_buffer.push((buf, buffers));
119            }
120
121            info!("Starting event listener");
122            //send the data through a mpsc channel
123            loop {
124                for (buf, buffers) in net_events_buffer.iter_mut() {
125                    match buf.read_events(buffers) {
126                        Ok(events) => {
127                            //read the events, this function is similar to the one used in identity/helpers.rs/display_events
128                            if events.read > 0 {
129                                for i in 0..events.read {
130                                    let data = &buffers[i];
131                                    if data.len() >= std::mem::size_of::<PacketLog>() {
132                                        let pl: PacketLog = unsafe {
133                                            std::ptr::read(data.as_ptr() as *const _)
134                                        };
135                                        let src = Ipv4Addr::from(u32::from_be(pl.src_ip));
136                                        let dst = Ipv4Addr::from(u32::from_be(pl.dst_ip));
137                                        let src_port = u16::from_be(pl.src_port as u16);
138                                        let dst_port = u16::from_be(pl.dst_port as u16);
139                                        let event_id = pl.pid;
140
141                                        match IpProtocols::try_from(pl.proto) {
142                                            Ok(proto) => {
143                                                info!(
144                                                    "Event Id: {} Protocol: {:?} SRC: {}:{} -> DST: {}:{}",
145                                                    event_id,
146                                                    proto,
147                                                    src,
148                                                    src_port,
149                                                    dst,
150                                                    dst_port
151                                                );
152                                                info!("creating vector for the aggregated data");
153                                                let mut evt = Vec::new();
154                                                // insert event in the vector
155                                                info!("Inserting events into the vector");
156                                                //TODO: use a Arc<str> or Box<str> type instead of String type.
157                                                //The data doesn't need to implement any .copy() or .clone() trait
158                                                // using an Arc<str> type will also waste less resources
159                                                evt.push(ConnectionEvent {
160                                                    event_id: event_id.to_string(),
161                                                    src_ip_port: format!("{}:{}", src, src_port),
162                                                    dst_ip_port: format!("{}:{}", dst, dst_port),
163                                                });
164                                                info!("sending events to the MPSC channel");
165                                                let _ = tx.send(Ok(evt)).await;
166                                            }
167                                            Err(_) => {
168                                                info!(
169                                                    "Event Id: {} Protocol: Unknown ({})",
170                                                    event_id,
171                                                    pl.proto
172                                                );
173                                            }
174                                        };
175                                    } else {
176                                        warn!(
177                                            "Received packet data too small: {} bytes",
178                                            data.len()
179                                        );
180                                    }
181                                }
182                            } else if events.read == 0 {
183                                info!("[Agent/API] 0 Events found");
184                                let mut evt = Vec::new();
185                                evt.push(ConnectionEvent {
186                                    event_id: "0".to_string(),
187                                    src_ip_port: "0:0".to_string(),
188                                    dst_ip_port: "0:0".to_string(),
189                                });
190                                let _ = tx.send(Ok(evt)).await;
191                            }
192                        }
193                        Err(e) => {
194                            eprintln!("Errore nella lettura eventi: {}", e);
195                            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
196                        }
197                    }
198                }
199                // small delay to avoid cpu congestion
200                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
201            }
202        });
203
204        api
205    }
206}
207
208//declare the blocklist hashmap structure
209//TODO: finish the creation of a blocklist hashmap
210#[async_trait]
211impl Agent for AgentApi {
212    // * read the incoming active_connections requests and returns a response with the
213    // * active connections. The data are transformed and sent to the api with a mpsc channel
214    async fn active_connections(
215        &self,
216        request: Request<RequestActiveConnections>
217    ) -> Result<Response<ActiveConnectionResponse>, Status> {
218        //read request
219        let req = request.into_inner();
220
221        //create the hashmap to process events from the mpsc channel queue
222        let mut aggregated_events: Vec<ConnectionEvent> = Vec::new();
223
224        //aggregate events
225        while let Ok(evt) = self.event_rx.lock().unwrap().try_recv() {
226            if let Ok(vec) = evt {
227                aggregated_events.extend(vec);
228            }
229        }
230
231        //if 'exclude' flag is not None exclude the events from the aggregated events
232        //TODO: move this section into the event reader
233        //TODO: transform the block_list parameter in a parameter that the user can pass using the CLI
234        //let block_list = "135.171.168.192".to_string();
235        //if aggregated_events.contains(&block_list) {
236        //    aggregated_events.remove(&block_list);
237        //    info!("Blocked ip from block_list: {:?}", block_list);
238        //}
239
240        //log response for debugging
241        info!("DEBUGGING RESPONSE FROM ACTIVE CONNECTION REQUEST: {:?}", aggregated_events);
242
243        //return response
244        Ok(
245            Response::new(ActiveConnectionResponse {
246                status: "success".to_string(),
247                events: aggregated_events,
248            })
249        )
250    }
251
252    // * creates and add ip to the blocklist
253    async fn add_ip_to_blocklist(
254        &self,
255        request: Request<AddIpToBlocklistRequest>
256    ) -> Result<Response<BlocklistResponse>, Status> {
257        //read request
258        let req = request.into_inner();
259
260        //open blocklist map
261        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/blocklist_map").expect(
262            "cannot open blocklist_map Mapdata"
263        );
264        let blocklist_mapdata = Map::HashMap(mapdata); //load mapdata
265        let mut blocklist_map: ayaHashMap<MapData, [u8; 4], [u8; 4]> = ayaHashMap
266            ::try_from(blocklist_mapdata)
267            .unwrap();
268
269        if req.ip.is_none() {
270            // log blocklist event
271            info!("IP field in request is none");
272            info!("CURRENT BLOCKLIST: {:?}", blocklist_map);
273        } else {
274            // add ip to the blocklist
275            // log blocklist event
276            let datetime = Local::now().to_string();
277            let ip = req.ip.unwrap();
278            //convert ip from string to [u8;4] type and insert into the bpf map
279            let u8_4_ip = Ipv4Addr::from_str(&ip).unwrap().octets();
280            //TODO: convert datetime in a kernel compatible format
281            blocklist_map.insert(u8_4_ip, u8_4_ip, 0);
282            info!("CURRENT BLOCKLIST: {:?}", blocklist_map);
283        }
284        let path = std::env
285            ::var(PIN_BLOCKLIST_MAP_PATH)
286            .context("Blocklist map path not found!")
287            .unwrap();
288
289        //convert the maps with a buffer to match the protobuffer types
290        let mut converted_blocklist_map: HashMap<String, String> = HashMap::new();
291        for item in blocklist_map.iter() {
292            let (k, v) = item.unwrap();
293            // convert keys and values from [u8;4] to String
294            let key = Ipv4Addr::from(k).to_string();
295            let value = Ipv4Addr::from(k).to_string();
296            converted_blocklist_map.insert(key, value);
297        }
298
299        //save ip into the blocklist
300        Ok(
301            Response::new(BlocklistResponse {
302                status: "success".to_string(),
303                events: converted_blocklist_map,
304            })
305        )
306    }
307
308    async fn check_blocklist(
309        &self,
310        request: Request<()>
311    ) -> Result<Response<BlocklistResponse>, Status> {
312        info!("Returning blocklist hashmap");
313        //open blocklist map
314        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/blocklist_map").expect(
315            "cannot open blocklist_map Mapdata"
316        );
317        let blocklist_mapdata = Map::HashMap(mapdata); //load mapdata
318
319        let blocklist_map: ayaHashMap<MapData, [u8; 4], [u8; 4]> = ayaHashMap
320            ::try_from(blocklist_mapdata)
321            .unwrap();
322
323        //convert the maps with a buffer to match the protobuffer types
324
325        let mut converted_blocklist_map: HashMap<String, String> = HashMap::new();
326        for item in blocklist_map.iter() {
327            let (k, v) = item.unwrap();
328            // convert keys and values from [u8;4] to String
329            let key = Ipv4Addr::from(k).to_string();
330            let value = Ipv4Addr::from(k).to_string();
331            converted_blocklist_map.insert(key, value);
332        }
333        Ok(
334            Response::new(BlocklistResponse {
335                status: "success".to_string(),
336                events: converted_blocklist_map,
337            })
338        )
339    }
340    async fn rm_ip_from_blocklist(
341        &self,
342        request: Request<RmIpFromBlocklistRequest>
343    ) -> Result<Response<RmIpFromBlocklistResponse>, Status> {
344        //read request
345        let req = request.into_inner();
346        info!("Removing ip from blocklist map");
347        //open blocklist map
348        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/blocklist_map").expect(
349            "cannot open blocklist_map Mapdata"
350        );
351        let blocklist_mapdata = Map::HashMap(mapdata); //load mapdata
352        let mut blocklist_map: ayaHashMap<MapData, [u8; 4], [u8; 4]> = ayaHashMap
353            ::try_from(blocklist_mapdata)
354            .unwrap();
355        //remove the address
356        let ip_to_remove = req.ip;
357        let u8_4_ip_to_remove = Ipv4Addr::from_str(&ip_to_remove).unwrap().octets();
358        blocklist_map.remove(&u8_4_ip_to_remove);
359
360        //convert the maps with a buffer to match the protobuffer types
361        let mut converted_blocklist_map: HashMap<String, String> = HashMap::new();
362        for item in blocklist_map.iter() {
363            let (k, v) = item.unwrap();
364            // convert keys and values from [u8;4] to String
365            let key = Ipv4Addr::from(k).to_string();
366            let value = Ipv4Addr::from(k).to_string();
367            converted_blocklist_map.insert(key, value);
368        }
369
370        Ok(
371            Response::new(RmIpFromBlocklistResponse {
372                status: "Ip removed from blocklist".to_string(),
373                events: converted_blocklist_map,
374            })
375        )
376    }
377}