1use std::{
2 ops::DerefMut,
3 sync::{Arc, Mutex},
4 task::Poll,
5 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
6};
7
8use poem::{Body, Endpoint, IntoResponse, Middleware, Response};
9use serde::Serialize;
10use tokio::io::AsyncRead;
11
/// One recorded request, as stored in the call buffer.
///
/// Serialized with camelCase field names (`timestampMs`, `durationUs`, `path`,
/// `response`).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Call {
    /// Wall-clock time of the call in milliseconds since the Unix epoch.
    /// The middleware in this file stores 0 when the system clock reads
    /// earlier than the epoch.
    pub timestamp_ms: u64,
    /// Duration of the call in microseconds. The middleware in this file
    /// measures the time the wrapped endpoint took to return its result.
    pub duration_us: u64,
    /// Request target; the middleware in this file stores the request's
    /// original URI rendered as a string.
    pub path: String,
    /// How the call ended.
    pub response: CallResponse,
}
20
/// Outcome of a recorded call.
///
/// `rename_all = "camelCase"` makes the variants serialize as `ok` and `error`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub enum CallResponse {
    /// The call produced a response; `length` is the number of body bytes
    /// counted for it.
    Ok { length: usize },
    /// The call failed; `code` is a status code. The middleware in this file
    /// stores the error's HTTP status, or `u16::MAX` when reading the response
    /// body failed.
    Error { code: u16 },
}
27
/// Circular buffer of [`Call`] records parameterized by `SIZE`
/// (presumably its fixed capacity — see the `circular_buffer` crate docs).
pub type Buffer<const SIZE: usize> = circular_buffer::CircularBuffer<SIZE, Call>;
29
/// Shared, mutex-protected handle to a [`Buffer`]; cloning it yields another
/// handle to the same buffer.
pub type BufferRef<const SIZE: usize> = Arc<Mutex<Buffer<SIZE>>>;
31
/// Middleware that records every call passing through it into a shared
/// buffer of `SIZE` entries.
///
/// Cloning the middleware clones the `Arc`, so all clones write to the same
/// buffer.
#[derive(Clone)]
pub struct CallMiddleware<const SIZE: usize> {
    /// Shared buffer the recorded calls are pushed into.
    calls: BufferRef<SIZE>,
}
36
37impl<const SIZE: usize> Default for CallMiddleware<SIZE> {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl<const SIZE: usize> CallMiddleware<SIZE> {
44 pub fn new() -> Self {
45 Self {
46 calls: Arc::new(Mutex::new(Buffer::new())),
47 }
48 }
49
50 pub fn get(&self) -> BufferRef<SIZE> {
51 self.calls.clone()
52 }
53}
54
55impl Call {
56 pub fn successfull(
57 timestamp_ms: u64,
58 duration: Duration,
59 path: String,
60 reponse_length: usize,
61 ) -> Self {
62 Self {
63 timestamp_ms,
64 duration_us: duration.as_micros() as u64,
65 path,
66 response: CallResponse::Ok {
67 length: reponse_length,
68 },
69 }
70 }
71
72 pub fn error(timestamp_ms: u64, duration: Duration, path: String, response_code: u16) -> Self {
73 Self {
74 timestamp_ms,
75 duration_us: duration.as_micros() as u64,
76 path,
77 response: CallResponse::Error {
78 code: response_code,
79 },
80 }
81 }
82}
83
/// Endpoint wrapper produced by `CallMiddleware::transform`.
///
/// Holds the wrapped endpoint together with a handle to the shared call
/// buffer that the endpoint's calls are recorded into.
pub struct CallMiddlewareImpl<const SIZE: usize, E: Endpoint> {
    /// The endpoint whose calls are being recorded.
    endpoint: E,
    /// Shared buffer the records are pushed into.
    calls: BufferRef<SIZE>,
}
88
89impl<const SIZE: usize, E: Endpoint> Middleware<E> for CallMiddleware<SIZE> {
90 type Output = CallMiddlewareImpl<SIZE, E>;
91
92 fn transform(&self, ep: E) -> Self::Output {
93 CallMiddlewareImpl {
94 endpoint: ep,
95 calls: self.calls.clone(),
96 }
97 }
98}
99
100impl<const SIZE: usize, E: Endpoint> Endpoint for CallMiddlewareImpl<SIZE, E> {
101 type Output = Response;
102
103 async fn call(&self, request: poem::Request) -> poem::Result<Self::Output> {
104 let path = request.original_uri().to_string();
105 let timestamp_ms = SystemTime::now()
106 .duration_since(UNIX_EPOCH)
107 .as_ref()
108 .map(Duration::as_millis)
109 .unwrap_or(0) as u64;
110 let now = Instant::now();
111 let res = self.endpoint.call(request).await;
112 let duration = now.elapsed();
113 match res {
114 Ok(response) => {
115 let r = response.into_response();
116 let (parts, body) = r.into_parts();
117 let async_read = body.into_async_read();
118 Ok(Response::from_parts(
119 parts,
120 Body::from_async_read(BodyReader {
121 timestamp_ms,
122 duration,
123 path,
124 calls: self.calls.clone(),
125 wrapped: async_read,
126 length: 0,
127 }),
128 ))
129 }
133 Err(err) => {
134 let mut guard = self.calls.lock().expect("can lock mutex");
135 guard.push_back(Call::error(
136 timestamp_ms,
137 duration,
138 path,
139 err.status().as_u16(),
140 ));
141 Err(err)
142 }
143 }
144 }
145}
146
/// `AsyncRead` adapter that forwards reads to `wrapped`, counts the bytes
/// read, and pushes a [`Call`] record into `calls` when the read reports
/// end-of-data or an error.
struct BodyReader<const SIZE: usize, T: AsyncRead + Unpin> {
    /// Underlying reader every read is forwarded to.
    wrapped: T,
    /// Timestamp stored in the record that is eventually pushed.
    timestamp_ms: u64,
    /// Duration stored in the record that is eventually pushed.
    duration: Duration,
    /// Path stored in the record; moved out (leaving an empty string) when
    /// the record is pushed.
    path: String,
    /// Shared buffer the record is pushed into.
    calls: BufferRef<SIZE>,
    /// Number of bytes read through this adapter so far.
    length: usize,
}
155
156impl<const SIZE: usize, T: AsyncRead + Unpin> AsyncRead for BodyReader<SIZE, T> {
157 fn poll_read(
158 mut self: std::pin::Pin<&mut Self>,
159 cx: &mut std::task::Context<'_>,
160 buf: &mut tokio::io::ReadBuf<'_>,
161 ) -> Poll<std::io::Result<()>> {
162 let initial = buf.filled().len();
163 let r = unsafe { std::pin::Pin::new_unchecked(&mut self.deref_mut().wrapped) }
164 .poll_read(cx, buf);
165 if let Poll::Ready(Err(_)) = &r {
166 let mut path = String::new();
167 std::mem::swap(&mut path, &mut self.path);
168 self.calls.lock().expect("can lock").push_back(Call::error(
169 self.timestamp_ms,
170 self.duration,
171 path,
172 u16::MAX,
173 ));
174 return r;
175 }
176 match buf.filled().len() - initial {
177 0 => {
178 let mut path = String::new();
179 std::mem::swap(&mut path, &mut self.path);
180 self.calls
181 .lock()
182 .expect("can lock")
183 .push_back(Call::successfull(
184 self.timestamp_ms,
185 self.duration,
186 path,
187 self.length,
188 ));
189 }
190 v => self.length += v,
191 }
192 r
193 }
194}