groundwork/
call.rs

1use std::{
2    ops::DerefMut,
3    sync::{Arc, Mutex},
4    task::Poll,
5    time::{Duration, Instant, SystemTime, UNIX_EPOCH},
6};
7
8use poem::{Body, Endpoint, IntoResponse, Middleware, Response};
9use serde::Serialize;
10use tokio::io::AsyncRead;
11
/// One request recorded by `CallMiddleware`.
///
/// Serialized with camelCase field names (`timestampMs`, `durationUs`, `path`, `response`).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Call {
    /// Wall-clock time at which the middleware started handling the request, in milliseconds
    /// since the Unix epoch.
    pub timestamp_ms: u64,
    /// Time the wrapped endpoint took to return its response, in microseconds. Streaming the
    /// response body afterwards is not included.
    pub duration_us: u64,
    /// The request's original URI, rendered with `to_string()`.
    pub path: String,
    /// Outcome of the call.
    pub response: CallResponse,
}
20
/// Outcome of a recorded call.
///
/// `rename_all = "camelCase"` on an enum renames the variants, so they serialize as `ok` and
/// `error`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub enum CallResponse {
    /// The response body was read to its end; `length` is the number of body bytes read.
    Ok { length: usize },
    /// The endpoint returned an error, and `code` is its HTTP status code. `u16::MAX` is used
    /// when reading the response body failed.
    Error { code: u16 },
}
27
/// Fixed-capacity ring buffer (`circular_buffer::CircularBuffer`) holding up to `SIZE` calls.
/// NOTE(review): assumed to evict the oldest entry on `push_back` once full — confirm against
/// the `circular_buffer` crate docs.
pub type Buffer<const SIZE: usize> = circular_buffer::CircularBuffer<SIZE, Call>;
29
/// Shared, mutex-protected handle to a `Buffer`. Clones of it are handed to every wrapped
/// endpoint and to every response-body reader.
pub type BufferRef<const SIZE: usize> = Arc<Mutex<Buffer<SIZE>>>;
31
/// `poem` middleware that records the calls of every endpoint it wraps into a shared `Buffer`
/// with capacity `SIZE`.
///
/// Cloning is cheap, and clones share the same buffer.
#[derive(Clone)]
pub struct CallMiddleware<const SIZE: usize> {
    /// Shared buffer. A clone of this handle is given to every endpoint produced by `transform`.
    calls: BufferRef<SIZE>,
}
36
37impl<const SIZE: usize> Default for CallMiddleware<SIZE> {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl<const SIZE: usize> CallMiddleware<SIZE> {
44    pub fn new() -> Self {
45        Self {
46            calls: Arc::new(Mutex::new(Buffer::new())),
47        }
48    }
49
50    pub fn get(&self) -> BufferRef<SIZE> {
51        self.calls.clone()
52    }
53}
54
55impl Call {
56    pub fn successfull(
57        timestamp_ms: u64,
58        duration: Duration,
59        path: String,
60        reponse_length: usize,
61    ) -> Self {
62        Self {
63            timestamp_ms,
64            duration_us: duration.as_micros() as u64,
65            path,
66            response: CallResponse::Ok {
67                length: reponse_length,
68            },
69        }
70    }
71
72    pub fn error(timestamp_ms: u64, duration: Duration, path: String, response_code: u16) -> Self {
73        Self {
74            timestamp_ms,
75            duration_us: duration.as_micros() as u64,
76            path,
77            response: CallResponse::Error {
78                code: response_code,
79            },
80        }
81    }
82}
83
/// Endpoint produced by `CallMiddleware::transform`. It forwards requests to `endpoint` and
/// records each call into `calls`.
pub struct CallMiddlewareImpl<const SIZE: usize, E: Endpoint> {
    /// The wrapped endpoint that actually handles requests.
    endpoint: E,
    /// Buffer shared with the `CallMiddleware` that created this endpoint.
    calls: BufferRef<SIZE>,
}
88
89impl<const SIZE: usize, E: Endpoint> Middleware<E> for CallMiddleware<SIZE> {
90    type Output = CallMiddlewareImpl<SIZE, E>;
91
92    fn transform(&self, ep: E) -> Self::Output {
93        CallMiddlewareImpl {
94            endpoint: ep,
95            calls: self.calls.clone(),
96        }
97    }
98}
99
100impl<const SIZE: usize, E: Endpoint> Endpoint for CallMiddlewareImpl<SIZE, E> {
101    type Output = Response;
102
103    async fn call(&self, request: poem::Request) -> poem::Result<Self::Output> {
104        let path = request.original_uri().to_string();
105        let timestamp_ms = SystemTime::now()
106            .duration_since(UNIX_EPOCH)
107            .as_ref()
108            .map(Duration::as_millis)
109            .unwrap_or(0) as u64;
110        let now = Instant::now();
111        let res = self.endpoint.call(request).await;
112        let duration = now.elapsed();
113        match res {
114            Ok(response) => {
115                let r = response.into_response();
116                let (parts, body) = r.into_parts();
117                let async_read = body.into_async_read();
118                Ok(Response::from_parts(
119                    parts,
120                    Body::from_async_read(BodyReader {
121                        timestamp_ms,
122                        duration,
123                        path,
124                        calls: self.calls.clone(),
125                        wrapped: async_read,
126                        length: 0,
127                    }),
128                ))
129                // let mut guard = self.calls.lock().expect("can lock mutex");
130                // guard.push_back(Call::successfull(timestamp_ms, duration, path, body.len()));
131                // Ok(Response::from_parts(parts, Body::from_bytes(body)))
132            }
133            Err(err) => {
134                let mut guard = self.calls.lock().expect("can lock mutex");
135                guard.push_back(Call::error(
136                    timestamp_ms,
137                    duration,
138                    path,
139                    err.status().as_u16(),
140                ));
141                Err(err)
142            }
143        }
144    }
145}
146
/// `AsyncRead` adapter around a successful response body.
///
/// It counts the bytes read through it and appends a `Call` to `calls` when the body reaches
/// end-of-stream, or when a read fails.
struct BodyReader<const SIZE: usize, T: AsyncRead + Unpin> {
    /// The original response body reader.
    wrapped: T,
    /// Request start time, in milliseconds since the Unix epoch.
    timestamp_ms: u64,
    /// Time the inner endpoint took to produce the response. Body streaming is excluded.
    duration: Duration,
    /// Request URI. It is moved out (left empty) when the call is recorded.
    path: String,
    /// Shared buffer that the finished call is pushed into.
    calls: BufferRef<SIZE>,
    /// Number of body bytes read so far.
    length: usize,
}
155
156impl<const SIZE: usize, T: AsyncRead + Unpin> AsyncRead for BodyReader<SIZE, T> {
157    fn poll_read(
158        mut self: std::pin::Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160        buf: &mut tokio::io::ReadBuf<'_>,
161    ) -> Poll<std::io::Result<()>> {
162        let initial = buf.filled().len();
163        let r = unsafe { std::pin::Pin::new_unchecked(&mut self.deref_mut().wrapped) }
164            .poll_read(cx, buf);
165        if let Poll::Ready(Err(_)) = &r {
166            let mut path = String::new();
167            std::mem::swap(&mut path, &mut self.path);
168            self.calls.lock().expect("can lock").push_back(Call::error(
169                self.timestamp_ms,
170                self.duration,
171                path,
172                u16::MAX,
173            ));
174            return r;
175        }
176        match buf.filled().len() - initial {
177            0 => {
178                let mut path = String::new();
179                std::mem::swap(&mut path, &mut self.path);
180                self.calls
181                    .lock()
182                    .expect("can lock")
183                    .push_back(Call::successfull(
184                        self.timestamp_ms,
185                        self.duration,
186                        path,
187                        self.length,
188                    ));
189            }
190            v => self.length += v,
191        }
192        r
193    }
194}