agent_api/
api.rs

1#![allow(warnings)]
2use anyhow::Context;
3use chrono::Local;
4use prost::bytes::BytesMut;
5use std::str::FromStr;
6use std::{sync::Mutex };
7use tonic::{ Request, Response, Status };
8use tracing::info;
9
10use aya::{ maps::{ MapData, PerfEventArray }, util::online_cpus };
11use std::result::Result::Ok;
12use tonic::async_trait;
13
14use std::collections::HashMap;
15use aya::maps::HashMap as ayaHashMap;
16use tokio::sync::mpsc;
17use tokio::task;
18
19// *  contains agent api configuration
20use crate::agent::{
21    agent_server::Agent,
22    ActiveConnectionResponse,
23    RequestActiveConnections,
24    AddIpToBlocklistRequest,
25    BlocklistResponse,
26};
27use aya::maps::Map;
28use bytemuck_derive::Zeroable;
29use cortexflow_identity::enums::IpProtocols;
30use std::net::Ipv4Addr;
31use tracing::warn;
32
33#[repr(C)]
34#[derive(Clone, Copy, Zeroable)]
35pub struct PacketLog {
36    pub proto: u8,
37    pub src_ip: u32,
38    pub src_port: u16,
39    pub dst_ip: u32,
40    pub dst_port: u16,
41    pub pid: u32,
42}
43unsafe impl aya::Pod for PacketLog {}
44
45pub struct AgentApi {
46    //* event_rx is an istance of a mpsc receiver.
47    //* is used to receive the data from the transmitter (tx)
48    event_rx: Mutex<mpsc::Receiver<Result<HashMap<String, String>, Status>>>,
49    event_tx: mpsc::Sender<Result<HashMap<String, String>, Status>>,
50}
51
52//* Event sender trait. Takes an event from a map and send that to the mpsc channel
53//* using the send_map function
54#[async_trait]
55pub trait EventSender: Send + Sync + 'static {
56    async fn send_event(&self, event: HashMap<String, String>);
57    async fn send_map(
58        &self,
59        map: HashMap<String, String>,
60        tx: mpsc::Sender<Result<HashMap<String, String>, Status>>
61    ) {
62        let status = Status::new(tonic::Code::Ok, "success");
63        let event = Ok(map);
64
65        let _ = tx.send(event).await;
66    }
67}
68// send event function. takes an HashMap and send that using mpsc event_tx
69#[async_trait]
70impl EventSender for AgentApi {
71    async fn send_event(&self, event: HashMap<String, String>) {
72        self.send_map(event, self.event_tx.clone()).await;
73    }
74}
75
/// Environment-variable name `BPF_PATH`.
/// NOTE(review): not read anywhere in this file.
const BPF_PATH: &str = "BPF_PATH";

/// Environment-variable name under which the pin path of the blocklist map
/// is expected.
const PIN_BLOCKLIST_MAP_PATH: &str = "PIN_BLOCKLIST_MAP_PATH";
78
79//initialize a default trait for AgentApi. Loads a name and a bpf istance.
80//this trait is essential for init the Agent.
81impl Default for AgentApi {
82    //TODO:this part needs a better error handling
83    fn default() -> Self {
84        // load maps mapdata
85        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/events_map").expect(
86            "cannot open events_map Mapdata"
87        );
88        let map = Map::PerfEventArray(mapdata); //creates a PerfEventArray from the mapdata
89
90        //init a mpsc channel
91        let (tx, rx) = mpsc::channel(1024);
92        let api = AgentApi {
93            event_rx: rx.into(),
94            event_tx: tx.clone(),
95        };
96
97        let mut events_array = PerfEventArray::try_from(map).expect(
98            "Error while initializing events array"
99        );
100
101        //spawn an event reader
102        task::spawn(async move {
103            let mut net_events_buffer = Vec::new();
104            //scan the cpus to read the data
105
106            for cpu_id in online_cpus()
107                .map_err(|e| anyhow::anyhow!("Error {:?}", e))
108                .unwrap() {
109                let buf = events_array
110                    .open(cpu_id, None)
111                    .expect("Error during the creation of net_events_buf structure");
112
113                let buffers = vec![BytesMut::with_capacity(4096); 8];
114                net_events_buffer.push((buf, buffers));
115            }
116
117            info!("Starting event listener");
118            //send the data through a mpsc channel
119            loop {
120                for (buf, buffers) in net_events_buffer.iter_mut() {
121                    match buf.read_events(buffers) {
122                        Ok(events) => {
123                            //read the events, this function is similar to the one used in identity/helpers.rs/display_events
124                            if events.read > 0 {
125                                for i in 0..events.read {
126                                    let data = &buffers[i];
127                                    if data.len() >= std::mem::size_of::<PacketLog>() {
128                                        let pl: PacketLog = unsafe {
129                                            std::ptr::read(data.as_ptr() as *const _)
130                                        };
131                                        let src = Ipv4Addr::from(u32::from_be(pl.src_ip));
132                                        let dst = Ipv4Addr::from(u32::from_be(pl.dst_ip));
133                                        let src_port = u16::from_be(pl.src_port as u16);
134                                        let dst_port = u16::from_be(pl.dst_port as u16);
135                                        let event_id = pl.pid;
136
137                                        match IpProtocols::try_from(pl.proto) {
138                                            Ok(proto) => {
139                                                info!(
140                                                    "Event Id: {} Protocol: {:?} SRC: {}:{} -> DST: {}:{}",
141                                                    event_id,
142                                                    proto,
143                                                    src,
144                                                    src_port,
145                                                    dst,
146                                                    dst_port
147                                                );
148                                                info!("creating hashmap for the aggregated data");
149                                                let mut evt = HashMap::new();
150                                                // insert event in the hashmap
151                                                info!("Inserting events into the hashmap");
152                                                //TODO: use a Arc<str> or Box<str> type instead of String type.
153                                                //The data doesn't need to implement any .copy() or .clone() trait
154                                                // using an Arc<str> type will also waste less resources
155                                                evt.insert(
156                                                    format!("{:?}", src.to_string()),
157                                                    format!("{:?}", event_id.to_string())
158                                                );
159                                                info!("sending events to the MPSC channel");
160                                                let _ = tx.send(Ok(evt)).await;
161                                            }
162                                            Err(_) => {
163                                                info!(
164                                                    "Event Id: {} Protocol: Unknown ({})",
165                                                    event_id,
166                                                    pl.proto
167                                                );
168                                            }
169                                        };
170                                    } else {
171                                        warn!(
172                                            "Received packet data too small: {} bytes",
173                                            data.len()
174                                        );
175                                    }
176                                }
177                            } else if events.read == 0 {
178                                info!("[Agent/API] 0 Events found");
179                                let mut evt = HashMap::new();
180                                evt.insert("0".to_string(), "0".to_string());
181                                let _ = tx.send(Ok(evt)).await;
182                            }
183                        }
184                        Err(e) => {
185                            eprintln!("Errore nella lettura eventi: {}", e);
186                            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
187                        }
188                    }
189                }
190                // small delay to avoid cpu congestion
191                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
192            }
193        });
194
195        api
196    }
197}
198
199//declare the blocklist hashmap structure
200//TODO: finish the creation of a blocklist hashmap
201#[async_trait]
202impl Agent for AgentApi {
203    // * read the incoming active_connections requests and returns a response with the
204    // * active connections. The data are transformed and sent to the api with a mpsc channel
205    async fn active_connections(
206        &self,
207        request: Request<RequestActiveConnections>
208    ) -> Result<Response<ActiveConnectionResponse>, Status> {
209        //read request
210        let req = request.into_inner();
211
212        //create the hashmap to process events from the mpsc channel queue
213        let mut aggregated_events: HashMap<String, String> = HashMap::new();
214
215        //aggregate events
216        while let Ok(evt) = self.event_rx.lock().unwrap().try_recv() {
217            if let Ok(map) = evt {
218                aggregated_events.extend(map);
219            }
220        }
221
222        //if 'exclude' flag is not None exclude the events from the aggregated events
223        //TODO: move this section into the event reader
224        //TODO: transform the block_list parameter in a parameter that the user can pass using the CLI
225        let block_list = "135.171.168.192".to_string();
226        if aggregated_events.contains_key(&block_list) {
227            aggregated_events.remove(&block_list);
228            info!("Blocked ip from block_list: {:?}", block_list);
229        }
230
231        //log response for debugging
232        info!("DEBUGGING RESPONSE FROM ACTIVE CONNECTION REQUEST: {:?}", aggregated_events);
233
234        //return response
235        Ok(
236            Response::new(ActiveConnectionResponse {
237                status: "success".to_string(),
238                events: aggregated_events,
239            })
240        )
241    }
242
243    // * creates and add ip to the blocklist
244    async fn add_ip_to_blocklist(
245        &self,
246        request: Request<AddIpToBlocklistRequest>
247    ) -> Result<Response<BlocklistResponse>, Status> {
248        //read request
249        let req = request.into_inner();
250
251        //open blocklist map
252        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/blocklist_map").expect(
253            "cannot open blocklist_map Mapdata"
254        );
255        let blocklist_mapdata = Map::HashMap(mapdata); //load mapdata
256        let mut blocklist_map: ayaHashMap<MapData, [u8; 4], [u8; 4]> = ayaHashMap
257            ::try_from(blocklist_mapdata)
258            .unwrap();
259
260        if req.ip.is_none() {
261            // log blocklist event
262            info!("IP field in request is none");
263            info!("CURRENT BLOCKLIST: {:?}", blocklist_map);
264        } else {
265            // add ip to the blocklist
266            // log blocklist event
267            let datetime = Local::now().to_string();
268            let ip = req.ip.unwrap();
269            //convert ip from string to [u8;4] type and insert into the bpf map
270            let u8_4_ip = Ipv4Addr::from_str(&ip).unwrap().octets();
271            //TODO: convert datetime in a kernel compatible format
272            blocklist_map.insert(u8_4_ip, u8_4_ip, 0);
273            info!("CURRENT BLOCKLIST: {:?}", blocklist_map);
274        }
275        let path = std::env
276            ::var(PIN_BLOCKLIST_MAP_PATH)
277            .context("Blocklist map path not found!")
278            .unwrap();
279
280        //convert the maps with a buffer to match the protobuffer types
281        let mut converted_blocklist_map: HashMap<String, String> = HashMap::new();
282        for item in blocklist_map.iter() {
283            let (k, v) = item.unwrap();
284            // convert keys and values from [u8;4] to String
285            let key = String::from_utf8(k.to_vec()).unwrap();
286            let value = String::from_utf8(v.to_vec()).unwrap();
287            converted_blocklist_map.insert(key, value);
288        }
289
290        //save ip into the blocklist
291        Ok(
292            Response::new(BlocklistResponse {
293                status: "success".to_string(),
294                events: converted_blocklist_map,
295            })
296        )
297    }
298
299    async fn check_blocklist(
300        &self,
301        request: Request<()>
302    ) -> Result<Response<BlocklistResponse>, Status> {
303        info!("Returning blocklist hashmap");
304        //open blocklist map
305        let mapdata = MapData::from_pin("/sys/fs/bpf/maps/blocklist_map").expect(
306            "cannot open blocklist_map Mapdata"
307        );
308        let blocklist_mapdata = Map::HashMap(mapdata); //load mapdata
309
310        let blocklist_map: ayaHashMap<MapData, [u8; 4], [u8; 4]> = ayaHashMap
311            ::try_from(blocklist_mapdata)
312            .unwrap();
313
314        //convert the maps with a buffer to match the protobuffer types
315
316        let mut converted_blocklist_map: HashMap<String, String> = HashMap::new();
317        for item in blocklist_map.iter() {
318            let (k, v) = item.unwrap();
319            // convert keys and values from [u8;4] to String
320            let key = String::from_utf8(k.to_vec()).unwrap();
321            let value = String::from_utf8(v.to_vec()).unwrap();
322            converted_blocklist_map.insert(key, value);
323        }
324        Ok(
325            Response::new(BlocklistResponse {
326                status: "success".to_string(),
327                events: converted_blocklist_map,
328            })
329        )
330    }
331}