Skip to main content

rocket_analytics/
analytics.rs

1use chrono::Utc;
2use reqwest::Client;
3use rocket::fairing::{Fairing, Info, Kind};
4use rocket::http::Status;
5use rocket::request::{FromRequest, Outcome};
6use rocket::{Data, Request, Response};
7use serde::Serialize;
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
11#[derive(Debug, Clone, Serialize)]
12struct RequestData {
13    hostname: String,
14    ip_address: Option<String>,
15    path: String,
16    user_agent: String,
17    method: String,
18    response_time: u32,
19    status: u16,
20    user_id: Option<String>,
21    created_at: String,
22}
23
24type StringMapper = dyn for<'a> Fn(&Request<'a>) -> String + Send + Sync;
25
26struct Config {
27    privacy_level: i32,
28    server_url: String,
29    get_hostname: Box<StringMapper>,
30    get_ip_address: Box<StringMapper>,
31    get_path: Box<StringMapper>,
32    get_user_agent: Box<StringMapper>,
33    get_user_id: Box<StringMapper>,
34}
35
36impl Default for Config {
37    fn default() -> Self {
38        Self {
39            privacy_level: 0,
40            server_url: String::from("https://www.apianalytics-server.com/"),
41            get_hostname: Box::new(get_hostname),
42            get_ip_address: Box::new(get_ip_address),
43            get_path: Box::new(get_path),
44            get_user_agent: Box::new(get_user_agent),
45            get_user_id: Box::new(get_user_id),
46        }
47    }
48}
49
50fn get_hostname(req: &Request) -> String {
51    req.host()
52        .map(|h| h.to_string())
53        .unwrap_or_default()
54}
55
56fn get_ip_address(req: &Request) -> String {
57    if let Some(val) = req
58        .headers()
59        .get_one("cf-connecting-ip")
60        .map(|s| s.trim().to_owned())
61        .filter(|s| !s.is_empty())
62    {
63        return val;
64    }
65
66    if let Some(val) = req
67        .headers()
68        .get_one("x-forwarded-for")
69        .and_then(|s| s.split(',').next())
70        .map(|s| s.trim().to_owned())
71        .filter(|s| !s.is_empty())
72    {
73        return val;
74    }
75
76    if let Some(val) = req
77        .headers()
78        .get_one("x-real-ip")
79        .map(|s| s.trim().to_owned())
80        .filter(|s| !s.is_empty())
81    {
82        return val;
83    }
84
85    req.client_ip()
86        .map(|ip| ip.to_string())
87        .unwrap_or_default()
88}
89
90fn get_path(req: &Request) -> String {
91    req.uri().path().to_string()
92}
93
94fn get_user_agent(req: &Request) -> String {
95    req.headers()
96        .get_one("User-Agent")
97        .unwrap_or_default()
98        .to_owned()
99}
100
101fn get_user_id(_req: &Request) -> String {
102    String::new()
103}
104
105struct RequestBuffer {
106    requests: Vec<RequestData>,
107    last_posted: Instant,
108}
109
110impl RequestBuffer {
111    fn new() -> Self {
112        Self {
113            requests: Vec::new(),
114            last_posted: Instant::now(),
115        }
116    }
117}
118
119pub struct Analytics {
120    api_key: String,
121    config: Config,
122    buffer: Arc<Mutex<RequestBuffer>>,
123    client: Arc<Client>,
124}
125
126impl Analytics {
127    pub fn new(api_key: String) -> Self {
128        let client = Client::builder()
129            .timeout(Duration::from_secs(10))
130            .build()
131            .unwrap_or_else(|_| Client::new());
132
133        Self {
134            api_key,
135            config: Config::default(),
136            buffer: Arc::new(Mutex::new(RequestBuffer::new())),
137            client: Arc::new(client),
138        }
139    }
140
141    pub fn with_privacy_level(mut self, privacy_level: i32) -> Self {
142        self.config.privacy_level = privacy_level;
143        self
144    }
145
146    pub fn with_server_url(mut self, server_url: String) -> Self {
147        self.config.server_url = if server_url.ends_with('/') {
148            server_url
149        } else {
150            server_url + "/"
151        };
152        self
153    }
154
155    pub fn with_hostname_mapper<F>(mut self, mapper: F) -> Self
156    where
157        F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
158    {
159        self.config.get_hostname = Box::new(mapper);
160        self
161    }
162
163    pub fn with_ip_address_mapper<F>(mut self, mapper: F) -> Self
164    where
165        F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
166    {
167        self.config.get_ip_address = Box::new(mapper);
168        self
169    }
170
171    pub fn with_path_mapper<F>(mut self, mapper: F) -> Self
172    where
173        F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
174    {
175        self.config.get_path = Box::new(mapper);
176        self
177    }
178
179    pub fn with_user_agent_mapper<F>(mut self, mapper: F) -> Self
180    where
181        F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
182    {
183        self.config.get_user_agent = Box::new(mapper);
184        self
185    }
186
187    pub fn with_user_id_mapper<F>(mut self, mapper: F) -> Self
188    where
189        F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
190    {
191        self.config.get_user_id = Box::new(mapper);
192        self
193    }
194}
195
/// Time at which the analytics fairing first saw a request.
///
/// Routes receive a `Start` (i.e. `Start<Instant>`) through the request guard
/// and read it with [`Start::instant`] or [`Start::elapsed`]. Internally the
/// fairing caches a `Start<Option<Instant>>` in the request-local cache so that
/// a missing timestamp can be represented.
#[derive(Debug, Clone, Copy)]
pub struct Start<T = Instant>(T);

impl Start {
    /// Returns the instant at which the request was first observed.
    #[must_use]
    pub fn instant(&self) -> Instant {
        self.0
    }

    /// Returns the time elapsed since the request was first observed.
    #[must_use]
    pub fn elapsed(&self) -> Duration {
        self.0.elapsed()
    }
}
198
199#[derive(Debug, Clone, Serialize)]
200struct Payload {
201    api_key: String,
202    requests: Vec<RequestData>,
203    framework: String,
204    privacy_level: i32,
205}
206
207impl Payload {
208    pub fn new(api_key: String, requests: Vec<RequestData>, privacy_level: i32) -> Self {
209        Self {
210            api_key,
211            requests,
212            framework: String::from("Rocket"),
213            privacy_level,
214        }
215    }
216}
217
218async fn post_requests(client: &Client, data: Payload, server_url: &str) {
219    let url = format!("{}api/log-request", server_url);
220    let _ = client.post(url).json(&data).send().await;
221}
222
223#[rocket::async_trait]
224impl Fairing for Analytics {
225    fn info(&self) -> Info {
226        Info {
227            name: "API Analytics",
228            kind: Kind::Request | Kind::Response,
229        }
230    }
231
232    async fn on_request(&self, req: &mut Request<'_>, _data: &mut Data<'_>) {
233        req.local_cache(|| Start::<Option<Instant>>(Some(Instant::now())));
234    }
235
236    async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
237        let start = &req.local_cache(|| Start::<Option<Instant>>(None)).0;
238
239        let hostname = (self.config.get_hostname)(req);
240        let ip_address = if self.config.privacy_level >= 2 {
241            None
242        } else {
243            let val = (self.config.get_ip_address)(req);
244            if val.is_empty() { None } else { Some(val) }
245        };
246        let method = req.method().to_string();
247        let user_agent = (self.config.get_user_agent)(req);
248        let path = (self.config.get_path)(req);
249        let user_id = {
250            let val = (self.config.get_user_id)(req);
251            if val.is_empty() { None } else { Some(val) }
252        };
253        let response_time = start
254            .map(|s| s.elapsed().as_millis().min(u32::MAX as u128) as u32)
255            .unwrap_or(0);
256
257        let request_data = RequestData {
258            hostname,
259            ip_address,
260            path,
261            user_agent,
262            method,
263            response_time,
264            status: res.status().code,
265            user_id,
266            created_at: Utc::now().to_rfc3339(),
267        };
268
269        let batch = {
270            let mut buf = self.buffer.lock().unwrap();
271            buf.requests.push(request_data);
272            if buf.last_posted.elapsed().as_secs_f64() > 60.0 {
273                buf.last_posted = Instant::now();
274                std::mem::take(&mut buf.requests)
275            } else {
276                vec![]
277            }
278        };
279
280        if !batch.is_empty() {
281            let payload = Payload::new(self.api_key.clone(), batch, self.config.privacy_level);
282            let server_url = self.config.server_url.clone();
283            let client = Arc::clone(&self.client);
284            tokio::spawn(async move {
285                post_requests(&client, payload, &server_url).await;
286            });
287        }
288    }
289}
290
291// Allows a route to access the start time
292#[rocket::async_trait]
293impl<'r> FromRequest<'r> for Start {
294    type Error = ();
295
296    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
297        match &*request.local_cache(|| Start::<Option<Instant>>(None)) {
298            Start(Some(start)) => Outcome::Success(Start(start.to_owned())),
299            Start(None) => Outcome::Error((Status::InternalServerError, ())),
300        }
301    }
302}