1use chrono::Utc;
2use reqwest::Client;
3use rocket::fairing::{Fairing, Info, Kind};
4use rocket::http::Status;
5use rocket::request::{FromRequest, Outcome};
6use rocket::{Data, Request, Response};
7use serde::Serialize;
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
11#[derive(Debug, Clone, Serialize)]
12struct RequestData {
13 hostname: String,
14 ip_address: Option<String>,
15 path: String,
16 user_agent: String,
17 method: String,
18 response_time: u32,
19 status: u16,
20 user_id: Option<String>,
21 created_at: String,
22}
23
24type StringMapper = dyn for<'a> Fn(&Request<'a>) -> String + Send + Sync;
25
26struct Config {
27 privacy_level: i32,
28 server_url: String,
29 get_hostname: Box<StringMapper>,
30 get_ip_address: Box<StringMapper>,
31 get_path: Box<StringMapper>,
32 get_user_agent: Box<StringMapper>,
33 get_user_id: Box<StringMapper>,
34}
35
36impl Default for Config {
37 fn default() -> Self {
38 Self {
39 privacy_level: 0,
40 server_url: String::from("https://www.apianalytics-server.com/"),
41 get_hostname: Box::new(get_hostname),
42 get_ip_address: Box::new(get_ip_address),
43 get_path: Box::new(get_path),
44 get_user_agent: Box::new(get_user_agent),
45 get_user_id: Box::new(get_user_id),
46 }
47 }
48}
49
50fn get_hostname(req: &Request) -> String {
51 req.host()
52 .map(|h| h.to_string())
53 .unwrap_or_default()
54}
55
56fn get_ip_address(req: &Request) -> String {
57 if let Some(val) = req
58 .headers()
59 .get_one("cf-connecting-ip")
60 .map(|s| s.trim().to_owned())
61 .filter(|s| !s.is_empty())
62 {
63 return val;
64 }
65
66 if let Some(val) = req
67 .headers()
68 .get_one("x-forwarded-for")
69 .and_then(|s| s.split(',').next())
70 .map(|s| s.trim().to_owned())
71 .filter(|s| !s.is_empty())
72 {
73 return val;
74 }
75
76 if let Some(val) = req
77 .headers()
78 .get_one("x-real-ip")
79 .map(|s| s.trim().to_owned())
80 .filter(|s| !s.is_empty())
81 {
82 return val;
83 }
84
85 req.client_ip()
86 .map(|ip| ip.to_string())
87 .unwrap_or_default()
88}
89
90fn get_path(req: &Request) -> String {
91 req.uri().path().to_string()
92}
93
94fn get_user_agent(req: &Request) -> String {
95 req.headers()
96 .get_one("User-Agent")
97 .unwrap_or_default()
98 .to_owned()
99}
100
101fn get_user_id(_req: &Request) -> String {
102 String::new()
103}
104
105struct RequestBuffer {
106 requests: Vec<RequestData>,
107 last_posted: Instant,
108}
109
110impl RequestBuffer {
111 fn new() -> Self {
112 Self {
113 requests: Vec::new(),
114 last_posted: Instant::now(),
115 }
116 }
117}
118
119pub struct Analytics {
120 api_key: String,
121 config: Config,
122 buffer: Arc<Mutex<RequestBuffer>>,
123 client: Arc<Client>,
124}
125
126impl Analytics {
127 pub fn new(api_key: String) -> Self {
128 let client = Client::builder()
129 .timeout(Duration::from_secs(10))
130 .build()
131 .unwrap_or_else(|_| Client::new());
132
133 Self {
134 api_key,
135 config: Config::default(),
136 buffer: Arc::new(Mutex::new(RequestBuffer::new())),
137 client: Arc::new(client),
138 }
139 }
140
141 pub fn with_privacy_level(mut self, privacy_level: i32) -> Self {
142 self.config.privacy_level = privacy_level;
143 self
144 }
145
146 pub fn with_server_url(mut self, server_url: String) -> Self {
147 self.config.server_url = if server_url.ends_with('/') {
148 server_url
149 } else {
150 server_url + "/"
151 };
152 self
153 }
154
155 pub fn with_hostname_mapper<F>(mut self, mapper: F) -> Self
156 where
157 F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
158 {
159 self.config.get_hostname = Box::new(mapper);
160 self
161 }
162
163 pub fn with_ip_address_mapper<F>(mut self, mapper: F) -> Self
164 where
165 F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
166 {
167 self.config.get_ip_address = Box::new(mapper);
168 self
169 }
170
171 pub fn with_path_mapper<F>(mut self, mapper: F) -> Self
172 where
173 F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
174 {
175 self.config.get_path = Box::new(mapper);
176 self
177 }
178
179 pub fn with_user_agent_mapper<F>(mut self, mapper: F) -> Self
180 where
181 F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
182 {
183 self.config.get_user_agent = Box::new(mapper);
184 self
185 }
186
187 pub fn with_user_id_mapper<F>(mut self, mapper: F) -> Self
188 where
189 F: for<'a> Fn(&Request<'a>) -> String + Send + Sync + 'static,
190 {
191 self.config.get_user_id = Box::new(mapper);
192 self
193 }
194}
195
/// Request start time.
///
/// The fairing caches a `Start<Option<Instant>>` on each request; the
/// default `Start<Instant>` form is what the request guard hands out.
#[derive(Clone)]
pub struct Start<T = Instant>(T);
198
199#[derive(Debug, Clone, Serialize)]
200struct Payload {
201 api_key: String,
202 requests: Vec<RequestData>,
203 framework: String,
204 privacy_level: i32,
205}
206
207impl Payload {
208 pub fn new(api_key: String, requests: Vec<RequestData>, privacy_level: i32) -> Self {
209 Self {
210 api_key,
211 requests,
212 framework: String::from("Rocket"),
213 privacy_level,
214 }
215 }
216}
217
218async fn post_requests(client: &Client, data: Payload, server_url: &str) {
219 let url = format!("{}api/log-request", server_url);
220 let _ = client.post(url).json(&data).send().await;
221}
222
223#[rocket::async_trait]
224impl Fairing for Analytics {
225 fn info(&self) -> Info {
226 Info {
227 name: "API Analytics",
228 kind: Kind::Request | Kind::Response,
229 }
230 }
231
232 async fn on_request(&self, req: &mut Request<'_>, _data: &mut Data<'_>) {
233 req.local_cache(|| Start::<Option<Instant>>(Some(Instant::now())));
234 }
235
236 async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
237 let start = &req.local_cache(|| Start::<Option<Instant>>(None)).0;
238
239 let hostname = (self.config.get_hostname)(req);
240 let ip_address = if self.config.privacy_level >= 2 {
241 None
242 } else {
243 let val = (self.config.get_ip_address)(req);
244 if val.is_empty() { None } else { Some(val) }
245 };
246 let method = req.method().to_string();
247 let user_agent = (self.config.get_user_agent)(req);
248 let path = (self.config.get_path)(req);
249 let user_id = {
250 let val = (self.config.get_user_id)(req);
251 if val.is_empty() { None } else { Some(val) }
252 };
253 let response_time = start
254 .map(|s| s.elapsed().as_millis().min(u32::MAX as u128) as u32)
255 .unwrap_or(0);
256
257 let request_data = RequestData {
258 hostname,
259 ip_address,
260 path,
261 user_agent,
262 method,
263 response_time,
264 status: res.status().code,
265 user_id,
266 created_at: Utc::now().to_rfc3339(),
267 };
268
269 let batch = {
270 let mut buf = self.buffer.lock().unwrap();
271 buf.requests.push(request_data);
272 if buf.last_posted.elapsed().as_secs_f64() > 60.0 {
273 buf.last_posted = Instant::now();
274 std::mem::take(&mut buf.requests)
275 } else {
276 vec![]
277 }
278 };
279
280 if !batch.is_empty() {
281 let payload = Payload::new(self.api_key.clone(), batch, self.config.privacy_level);
282 let server_url = self.config.server_url.clone();
283 let client = Arc::clone(&self.client);
284 tokio::spawn(async move {
285 post_requests(&client, payload, &server_url).await;
286 });
287 }
288 }
289}
290
291#[rocket::async_trait]
293impl<'r> FromRequest<'r> for Start {
294 type Error = ();
295
296 async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
297 match &*request.local_cache(|| Start::<Option<Instant>>(None)) {
298 Start(Some(start)) => Outcome::Success(Start(start.to_owned())),
299 Start(None) => Outcome::Error((Status::InternalServerError, ())),
300 }
301 }
302}