1use axum::{body::Body, extract::ConnectInfo, http::Request, response::Response};
2use chrono::Utc;
3use futures::Future;
4use http::{
5 header::{HeaderValue, HOST, USER_AGENT},
6 Extensions, HeaderMap,
7};
8use lazy_static::lazy_static;
9use reqwest::Client;
10use serde::Serialize;
11use std::{
12 net::{IpAddr, SocketAddr},
13 task::{Context, Poll},
14 time::Instant,
15};
16use std::{pin::Pin, sync::Arc};
17use tokio::sync::RwLock;
18use tokio::time::{interval, Duration};
19use tower::{Layer, Service};
20
/// One recorded request/response pair, uploaded as an element of `Payload::requests`.
///
/// NOTE: this type derives `Serialize`, so the field declaration order below is the
/// key order of the serialized object — do not reorder fields casually.
#[derive(Debug, Clone, Serialize)]
struct RequestData {
    hostname: String,
    ip_address: String,
    path: String,
    user_agent: String,
    // HTTP method as text (e.g. "GET").
    method: String,
    // Milliseconds between the middleware receiving the request and the inner
    // service producing a response; the middleware saturates this at u32::MAX.
    response_time: u32,
    // Numeric HTTP status code of the response.
    status: u16,
    user_id: String,
    // RFC 3339 timestamp taken when the record is built, after the response is ready.
    created_at: String,
}
33
34impl RequestData {
35 #[allow(clippy::too_many_arguments)]
36 pub fn new(
37 hostname: String,
38 ip_address: String,
39 path: String,
40 user_agent: String,
41 method: String,
42 status: u16,
43 response_time: u32,
44 user_id: String,
45 created_at: String,
46 ) -> Self {
47 Self {
48 hostname,
49 ip_address,
50 path,
51 user_agent,
52 method,
53 response_time,
54 status,
55 user_id,
56 created_at,
57 }
58 }
59}
60
61type StringMapper = dyn for<'a> Fn(&Request<Body>) -> String + Send + Sync;
62
63#[derive(Clone)]
64struct Config {
65 privacy_level: i32,
66 server_url: String,
67 get_hostname: Arc<StringMapper>,
68 get_ip_address: Arc<StringMapper>,
69 get_path: Arc<StringMapper>,
70 get_user_agent: Arc<StringMapper>,
71 get_user_id: Arc<StringMapper>,
72}
73
74impl Default for Config {
75 fn default() -> Self {
76 Self {
77 privacy_level: 0,
78 server_url: "https://www.apianalytics-server.com/".to_string(),
79 get_hostname: Arc::new(get_hostname),
80 get_ip_address: Arc::new(get_ip_address),
81 get_path: Arc::new(get_path),
82 get_user_agent: Arc::new(get_user_agent),
83 get_user_id: Arc::new(get_user_id),
84 }
85 }
86}
87
88pub trait HeaderValueExt {
89 fn to_string(&self) -> String;
90}
91
92impl HeaderValueExt for HeaderValue {
93 fn to_string(&self) -> String {
94 self.to_str().unwrap_or_default().to_string()
95 }
96}
97
98fn get_hostname(req: &Request<Body>) -> String {
99 req.headers()
100 .get(HOST)
101 .map(|x| x.to_string())
102 .unwrap_or_default()
103}
104
105fn get_ip_address(req: &Request<Body>) -> String {
106 let extensions = req.extensions();
107 let headers = req.headers();
108 let mut ip_address = String::new();
109 if let Some(val) = ip_from_x_forwarded_for(headers) {
110 ip_address = val.to_string();
111 } else if let Some(val) = ip_from_x_real_ip(headers) {
112 ip_address = val.to_string();
113 } else if let Some(val) = ip_from_connect_info(extensions) {
114 ip_address = val.to_string();
115 }
116 ip_address
117}
118
119fn ip_from_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
120 headers
121 .get("x-forwarded-for")
122 .and_then(|hv| hv.to_str().ok())
123 .and_then(|s| {
124 s.split(',')
125 .rev()
126 .find_map(|s| s.trim().parse::<IpAddr>().ok())
127 })
128}
129
130fn ip_from_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
131 headers
132 .get("x-real-ip")
133 .and_then(|hv| hv.to_str().ok())
134 .and_then(|s| s.parse::<IpAddr>().ok())
135}
136
137fn ip_from_connect_info(extensions: &Extensions) -> Option<IpAddr> {
138 extensions
139 .get::<ConnectInfo<SocketAddr>>()
140 .map(|ConnectInfo(addr)| addr.ip())
141}
142
143fn get_path(req: &Request<Body>) -> String {
144 req.uri().path().to_owned()
145}
146
147fn get_user_agent(req: &Request<Body>) -> String {
148 req.headers()
149 .get(USER_AGENT)
150 .map(|x| x.to_string())
151 .unwrap_or_default()
152}
153
154fn get_user_id(_req: &Request<Body>) -> String {
155 String::new()
156}
157
158#[derive(Clone)]
159pub struct Analytics {
160 config: Config,
161}
162
163impl Analytics {
164 pub fn new(api_key: String) -> Self {
165 let config = Config::default();
166 let api_key_clone = api_key.clone();
167 let privacy_level = config.privacy_level;
168 let server_url = config.server_url.clone();
169
170 tokio::spawn(async move {
171 let mut interval = interval(Duration::from_secs(60));
172 loop {
173 interval.tick().await; let mut requests = REQUESTS.write().await;
176 if !requests.is_empty() {
177 let payload = Payload::new(
178 api_key_clone.clone(),
179 requests.clone(),
180 privacy_level,
181 );
182
183 requests.clear(); drop(requests); let url = server_url.clone();
187
188 tokio::spawn(async move { post_requests(payload, url).await });
189 }
190 }
191 });
192
193 Self { config }
194 }
195
196 pub fn with_privacy_level(mut self, privacy_level: i32) -> Self {
197 self.config.privacy_level = privacy_level;
198 self
199 }
200
201 pub fn with_server_url(mut self, server_url: String) -> Self {
202 if server_url.ends_with("/") {
203 self.config.server_url = server_url;
204 } else {
205 self.config.server_url = server_url + "/";
206 }
207 self
208 }
209
210 pub fn with_hostname_mapper<F>(mut self, mapper: F) -> Self
211 where
212 F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
213 {
214 self.config.get_hostname = Arc::new(mapper);
215 self
216 }
217
218 pub fn with_ip_address_mapper<F>(mut self, mapper: F) -> Self
219 where
220 F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
221 {
222 self.config.get_ip_address = Arc::new(mapper);
223 self
224 }
225
226 pub fn with_path_mapper<F>(mut self, mapper: F) -> Self
227 where
228 F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
229 {
230 self.config.get_path = Arc::new(mapper);
231 self
232 }
233
234 pub fn with_user_agent_mapper<F>(mut self, mapper: F) -> Self
235 where
236 F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
237 {
238 self.config.get_user_agent = Arc::new(mapper);
239 self
240 }
241}
242
243impl<S> Layer<S> for Analytics {
244 type Service = AnalyticsMiddleware<S>;
245
246 fn layer(&self, inner: S) -> Self::Service {
247 AnalyticsMiddleware {
248 config: Arc::new(self.config.clone()),
249 inner,
250 }
251 }
252}
253
254#[derive(Clone)]
255pub struct AnalyticsMiddleware<S> {
256 config: Arc<Config>,
257 inner: S,
258}
259
260lazy_static! {
261 static ref REQUESTS: Arc<RwLock<Vec<RequestData>>> = Arc::new(RwLock::new(vec![]));
262}
263
/// JSON body of one batch upload, POSTed by `post_requests` to
/// `{server_url}api/log-request`.
///
/// NOTE: this type derives `Serialize`, so field declaration order is the key
/// order of the serialized object.
#[derive(Debug, Clone, Serialize)]
struct Payload {
    api_key: String,
    // Records drained from the global buffer since the previous flush.
    requests: Vec<RequestData>,
    // Always "Axum" when built via `Payload::new`.
    framework: String,
    privacy_level: i32,
}
271
272impl Payload {
273 pub fn new(api_key: String, requests: Vec<RequestData>, privacy_level: i32) -> Self {
274 Self {
275 api_key,
276 requests,
277 framework: "Axum".to_string(),
278 privacy_level,
279 }
280 }
281}
282
283async fn post_requests(data: Payload, server_url: String) {
284 let client = Client::new();
285 let response = client
286 .post(server_url + "api/log-request")
287 .json(&data)
288 .send()
289 .await;
290
291 match response {
292 Ok(resp) if resp.status().is_success() => {
293 return;
294 }
295 Ok(resp) => {
296 eprintln!(
297 "Failed to send analytics data. Server responded with status: {}",
298 resp.status()
299 );
300 }
301 Err(err) => {
302 eprintln!("Error sending analytics data: {}", err);
303 }
304 }
305}
306
307async fn log_request(request_data: RequestData) {
308 let mut requests = REQUESTS.write().await;
309 requests.push(request_data);
310}
311
312impl<S> Service<Request<Body>> for AnalyticsMiddleware<S>
313where
314 S: Service<Request<Body>, Response = Response> + Clone + Send + Sync + 'static,
315 S::Future: Send + 'static,
316{
317 type Response = S::Response;
318 type Error = S::Error;
319 type Future =
320 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
321
322 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
323 self.inner.poll_ready(cx)
324 }
325
326 fn call(&mut self, req: Request<Body>) -> Self::Future {
327 let now = Instant::now();
328
329 let hostname = (self.config.get_hostname)(&req);
330 let ip_address = (self.config.get_ip_address)(&req);
331 let path = (self.config.get_path)(&req);
332 let method = req.method().to_string();
333 let user_agent = (self.config.get_user_agent)(&req);
334 let user_id = (self.config.get_user_id)(&req);
335
336 let future = self.inner.call(req);
337
338 Box::pin(async move {
339 let res: Response = future.await?;
340
341 let response_time = now.elapsed().as_millis().min(u32::MAX as u128) as u32;
342
343 let request_data = RequestData::new(
344 hostname,
345 ip_address,
346 path,
347 user_agent,
348 method,
349 res.status().as_u16(),
350 response_time,
351 user_id,
352 Utc::now().to_rfc3339(),
353 );
354
355 tokio::spawn(log_request(request_data));
356
357 Ok(res)
358 })
359 }
360}