// axum_analytics/analytics.rs

use axum::{body::Body, extract::ConnectInfo, http::Request, response::Response};
use chrono::Utc;
use futures::Future;
use http::{
    header::{HeaderValue, HOST, USER_AGENT},
    Extensions, HeaderMap,
};
use lazy_static::lazy_static;
use reqwest::Client;
use serde::Serialize;
use std::{
    net::{IpAddr, SocketAddr},
    task::{Context, Poll},
    time::Instant,
};
use std::{
    pin::Pin,
    sync::{Arc, Mutex, MutexGuard, PoisonError},
};
use tokio::sync::RwLock;
use tokio::time::{interval, Duration};
use tower::{Layer, Service};
20
21#[derive(Debug, Clone, Serialize)]
22struct RequestData {
23    hostname: String,
24    ip_address: String,
25    path: String,
26    user_agent: String,
27    method: String,
28    response_time: u32,
29    status: u16,
30    user_id: String,
31    created_at: String,
32}
33
34impl RequestData {
35    #[allow(clippy::too_many_arguments)]
36    pub fn new(
37        hostname: String,
38        ip_address: String,
39        path: String,
40        user_agent: String,
41        method: String,
42        status: u16,
43        response_time: u32,
44        user_id: String,
45        created_at: String,
46    ) -> Self {
47        Self {
48            hostname,
49            ip_address,
50            path,
51            user_agent,
52            method,
53            response_time,
54            status,
55            user_id,
56            created_at,
57        }
58    }
59}
60
61type StringMapper = dyn for<'a> Fn(&Request<Body>) -> String + Send + Sync;
62
63#[derive(Clone)]
64struct Config {
65    privacy_level: i32,
66    server_url: String,
67    get_hostname: Arc<StringMapper>,
68    get_ip_address: Arc<StringMapper>,
69    get_path: Arc<StringMapper>,
70    get_user_agent: Arc<StringMapper>,
71    get_user_id: Arc<StringMapper>,
72}
73
74impl Default for Config {
75    fn default() -> Self {
76        Self {
77            privacy_level: 0,
78            server_url: "https://www.apianalytics-server.com/".to_string(),
79            get_hostname: Arc::new(get_hostname),
80            get_ip_address: Arc::new(get_ip_address),
81            get_path: Arc::new(get_path),
82            get_user_agent: Arc::new(get_user_agent),
83            get_user_id: Arc::new(get_user_id),
84        }
85    }
86}
87
/// Lossy conversion of a header-like value into an owned `String`.
///
/// The `HeaderValue` impl in this module returns an empty string when the value
/// is not visible ASCII (i.e. when `HeaderValue::to_str` fails).
pub trait HeaderValueExt {
    /// Returns the value as an owned string.
    fn to_string(&self) -> String;
}
91
92impl HeaderValueExt for HeaderValue {
93    fn to_string(&self) -> String {
94        self.to_str().unwrap_or_default().to_string()
95    }
96}
97
98fn get_hostname(req: &Request<Body>) -> String {
99    req.headers()
100        .get(HOST)
101        .map(|x| x.to_string())
102        .unwrap_or_default()
103}
104
105fn get_ip_address(req: &Request<Body>) -> String {
106    let extensions = req.extensions();
107    let headers = req.headers();
108    let mut ip_address = String::new();
109    if let Some(val) = ip_from_x_forwarded_for(headers) {
110        ip_address = val.to_string();
111    } else if let Some(val) = ip_from_x_real_ip(headers) {
112        ip_address = val.to_string();
113    } else if let Some(val) = ip_from_connect_info(extensions) {
114        ip_address = val.to_string();
115    }
116    ip_address
117}
118
119fn ip_from_x_forwarded_for(headers: &HeaderMap) -> Option<IpAddr> {
120    headers
121        .get("x-forwarded-for")
122        .and_then(|hv| hv.to_str().ok())
123        .and_then(|s| {
124            s.split(',')
125                .rev()
126                .find_map(|s| s.trim().parse::<IpAddr>().ok())
127        })
128}
129
130fn ip_from_x_real_ip(headers: &HeaderMap) -> Option<IpAddr> {
131    headers
132        .get("x-real-ip")
133        .and_then(|hv| hv.to_str().ok())
134        .and_then(|s| s.parse::<IpAddr>().ok())
135}
136
137fn ip_from_connect_info(extensions: &Extensions) -> Option<IpAddr> {
138    extensions
139        .get::<ConnectInfo<SocketAddr>>()
140        .map(|ConnectInfo(addr)| addr.ip())
141}
142
143fn get_path(req: &Request<Body>) -> String {
144    req.uri().path().to_owned()
145}
146
147fn get_user_agent(req: &Request<Body>) -> String {
148    req.headers()
149        .get(USER_AGENT)
150        .map(|x| x.to_string())
151        .unwrap_or_default()
152}
153
154fn get_user_id(_req: &Request<Body>) -> String {
155    String::new()
156}
157
158#[derive(Clone)]
159pub struct Analytics {
160    config: Config,
161}
162
163impl Analytics {
164    pub fn new(api_key: String) -> Self {
165        let config = Config::default();
166        let api_key_clone = api_key.clone();
167        let privacy_level = config.privacy_level;
168        let server_url = config.server_url.clone();
169
170        tokio::spawn(async move {
171            let mut interval = interval(Duration::from_secs(60));
172            loop {
173                interval.tick().await; // Wait for the next interval
174
175                let mut requests = REQUESTS.write().await;
176                if !requests.is_empty() {
177                    let payload = Payload::new(
178                        api_key_clone.clone(),
179                        requests.clone(),
180                        privacy_level,
181                    );
182
183                    requests.clear(); // Clear requests after reading
184                    drop(requests); // Explicitly drop the lock
185
186                    let url = server_url.clone();
187
188                    tokio::spawn(async move { post_requests(payload, url).await });
189                }
190            }
191        });
192
193        Self { config }
194    }
195
196    pub fn with_privacy_level(mut self, privacy_level: i32) -> Self {
197        self.config.privacy_level = privacy_level;
198        self
199    }
200
201    pub fn with_server_url(mut self, server_url: String) -> Self {
202        if server_url.ends_with("/") {
203            self.config.server_url = server_url;
204        } else {
205            self.config.server_url = server_url + "/";
206        }
207        self
208    }
209
210    pub fn with_hostname_mapper<F>(mut self, mapper: F) -> Self
211    where
212        F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
213    {
214        self.config.get_hostname = Arc::new(mapper);
215        self
216    }
217
218    pub fn with_ip_address_mapper<F>(mut self, mapper: F) -> Self
219    where
220        F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
221    {
222        self.config.get_ip_address = Arc::new(mapper);
223        self
224    }
225
226    pub fn with_path_mapper<F>(mut self, mapper: F) -> Self
227    where
228        F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
229    {
230        self.config.get_path = Arc::new(mapper);
231        self
232    }
233
234    pub fn with_user_agent_mapper<F>(mut self, mapper: F) -> Self
235    where
236        F: Fn(&Request<Body>) -> String + Send + Sync + 'static,
237    {
238        self.config.get_user_agent = Arc::new(mapper);
239        self
240    }
241}
242
243impl<S> Layer<S> for Analytics {
244    type Service = AnalyticsMiddleware<S>;
245
246    fn layer(&self, inner: S) -> Self::Service {
247        AnalyticsMiddleware {
248            config: Arc::new(self.config.clone()),
249            inner,
250        }
251    }
252}
253
254#[derive(Clone)]
255pub struct AnalyticsMiddleware<S> {
256    config: Arc<Config>,
257    inner: S,
258}
259
260lazy_static! {
261    static ref REQUESTS: Arc<RwLock<Vec<RequestData>>> = Arc::new(RwLock::new(vec![]));
262}
263
264#[derive(Debug, Clone, Serialize)]
265struct Payload {
266    api_key: String,
267    requests: Vec<RequestData>,
268    framework: String,
269    privacy_level: i32,
270}
271
272impl Payload {
273    pub fn new(api_key: String, requests: Vec<RequestData>, privacy_level: i32) -> Self {
274        Self {
275            api_key,
276            requests,
277            framework: "Axum".to_string(),
278            privacy_level,
279        }
280    }
281}
282
283async fn post_requests(data: Payload, server_url: String) {
284    let client = Client::new();
285    let response = client
286        .post(server_url + "api/log-request")
287        .json(&data)
288        .send()
289        .await;
290
291    match response {
292        Ok(resp) if resp.status().is_success() => {
293            return;
294        }
295        Ok(resp) => {
296            eprintln!(
297                "Failed to send analytics data. Server responded with status: {}",
298                resp.status()
299            );
300        }
301        Err(err) => {
302            eprintln!("Error sending analytics data: {}", err);
303        }
304    }
305}
306
307async fn log_request(request_data: RequestData) {
308    let mut requests = REQUESTS.write().await;
309    requests.push(request_data);
310}
311
312impl<S> Service<Request<Body>> for AnalyticsMiddleware<S>
313where
314    S: Service<Request<Body>, Response = Response> + Clone + Send + Sync + 'static,
315    S::Future: Send + 'static,
316{
317    type Response = S::Response;
318    type Error = S::Error;
319    type Future =
320        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
321
322    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
323        self.inner.poll_ready(cx)
324    }
325
326    fn call(&mut self, req: Request<Body>) -> Self::Future {
327        let now = Instant::now();
328
329        let hostname = (self.config.get_hostname)(&req);
330        let ip_address = (self.config.get_ip_address)(&req);
331        let path = (self.config.get_path)(&req);
332        let method = req.method().to_string();
333        let user_agent = (self.config.get_user_agent)(&req);
334        let user_id = (self.config.get_user_id)(&req);
335
336        let future = self.inner.call(req);
337
338        Box::pin(async move {
339            let res: Response = future.await?;
340
341            let response_time = now.elapsed().as_millis().min(u32::MAX as u128) as u32;
342
343            let request_data = RequestData::new(
344                hostname,
345                ip_address,
346                path,
347                user_agent,
348                method,
349                res.status().as_u16(),
350                response_time,
351                user_id,
352                Utc::now().to_rfc3339(),
353            );
354
355            tokio::spawn(log_request(request_data));
356
357            Ok(res)
358        })
359    }
360}