cat_dev/net/additions/
request_id.rs1use crate::{
12 errors::CatBridgeError,
13 net::{
14 errors::CommonNetAPIError,
15 models::{FromRequest, FromRequestParts, Request, Response},
16 },
17};
18use rand::{TryRngCore, rng};
19use std::{
20 convert::Infallible,
21 fmt::{Debug, Display, Formatter, Result as FmtResult, Write},
22 ops::Deref,
23 sync::Arc,
24};
25use tower::{Layer, Service};
26use tracing::{
27 Id as TracingId, error_span,
28 field::valuable,
29 instrument::{Instrument, Instrumented},
30};
31use valuable::{Valuable, Value, Visit};
32
/// An identifier attached to a single request.
///
/// The ID text is stored behind an [`Arc`]. Cloning a `RequestID` therefore
/// only bumps a reference count and never copies the string. Such clones happen
/// when the ID is pulled back out of a request's extensions.
#[derive(Eq, PartialEq, Clone)]
pub struct RequestID(Arc<String>);
35
36impl RequestID {
37 #[must_use]
41 pub fn generate() -> Self {
42 let mut buff = [0_u8; 16];
43 _ = rng().try_fill_bytes(&mut buff);
46 let mut id = String::with_capacity(32);
47 for byte in buff {
48 _ = write!(&mut id, "{byte:02x}");
49 }
50 Self(Arc::new(id))
51 }
52
53 #[must_use]
57 pub fn from_existing(id: String) -> Self {
58 Self(Arc::new(id))
59 }
60
61 #[must_use]
66 pub fn fatal_unknown() -> Self {
67 Self(Arc::new("<unknown>".to_owned()))
68 }
69
70 #[must_use]
72 pub fn str(&self) -> &str {
73 self.0.as_str()
74 }
75}
76
77impl<State: Clone + Send + Sync + 'static> FromRequestParts<State> for RequestID {
78 async fn from_request_parts(parts: &mut Request<State>) -> Result<Self, CatBridgeError> {
79 parts
80 .extensions()
81 .get::<RequestID>()
82 .cloned()
83 .ok_or(CommonNetAPIError::ExtensionNotPresent.into())
84 }
85}
86
87impl<State: Clone + Send + Sync + 'static> FromRequest<State> for RequestID {
88 async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
89 req.extensions()
90 .get::<RequestID>()
91 .cloned()
92 .ok_or(CommonNetAPIError::ExtensionNotPresent.into())
93 }
94}
95
96impl Debug for RequestID {
97 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
98 fmt.debug_struct("RequestID").field("id", &self.0).finish()
99 }
100}
101
102impl Display for RequestID {
103 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
104 write!(fmt, "{}", self.0)
105 }
106}
107
108impl Deref for RequestID {
109 type Target = str;
110
111 fn deref(&self) -> &Self::Target {
112 self.str()
113 }
114}
115
116impl Valuable for RequestID {
117 fn as_value(&self) -> Value<'_> {
118 Value::String(self.0.as_str())
119 }
120
121 fn visit(&self, visitor: &mut dyn Visit) {
122 visitor.visit_value(self.as_value());
123 }
124}
125
/// A tower layer that wraps a service in `LayeredRequestID`.
///
/// The wrapper tags every request with a freshly generated `RequestID`. The
/// wrapped `String` is the service name, recorded on each request span.
#[derive(Clone, Debug)]
pub struct RequestIDLayer(String);

impl RequestIDLayer {
    /// Builds a layer whose request spans record `service_name` as `lisa.subsystem`.
    #[must_use]
    pub const fn new(service_name: String) -> Self {
        RequestIDLayer(service_name)
    }
}
135
136impl<Layered> Layer<Layered> for RequestIDLayer
137where
138 Layered: Clone,
139{
140 type Service = LayeredRequestID<Layered>;
141
142 fn layer(&self, inner: Layered) -> Self::Service {
143 LayeredRequestID {
144 inner,
145 service_name: self.0.clone(),
146 }
147 }
148}
149
/// A service produced by `RequestIDLayer`.
///
/// It wraps an inner service and gives each request a new `RequestID` plus a
/// tracing span.
#[derive(Clone)]
pub struct LayeredRequestID<Layered> {
    /// The wrapped service that requests are forwarded to.
    inner: Layered,
    /// Recorded as `lisa.subsystem` on every request span.
    service_name: String,
}
155
156impl<Layered, State: Clone + Send + Sync + 'static> Service<Request<State>>
157 for LayeredRequestID<Layered>
158where
159 Layered:
160 Service<Request<State>, Response = Response, Error = Infallible> + Clone + Send + 'static,
161 Layered::Future: Send + 'static,
162{
163 type Response = Layered::Response;
164 type Error = Layered::Error;
165 type Future = Instrumented<Layered::Future>;
166
167 #[inline]
168 fn poll_ready(
169 &mut self,
170 ctx: &mut std::task::Context<'_>,
171 ) -> std::task::Poll<Result<(), Self::Error>> {
172 self.inner.poll_ready(ctx)
173 }
174
175 fn call(&mut self, mut req: Request<State>) -> Self::Future {
176 let parent_span = req
177 .extensions()
178 .get::<Option<TracingId>>()
179 .cloned()
180 .unwrap_or(None);
181 let req_id = RequestID::generate();
182
183 let span = error_span!(
184 parent: parent_span,
185 "WithRequestID",
186 lisa.subsystem = %self.service_name,
187 request.id = valuable(&req_id),
188 );
189 req.extensions_mut().insert::<RequestID>(req_id);
190 req.extensions_mut().insert::<Option<TracingId>>(span.id());
191 self.inner.call(req).instrument(span.or_current())
192 }
193}
194
#[cfg(test)]
mod unit_tests {
    use super::*;

    /// Compile-time probe: this only type-checks when `T` meets the bounds
    /// that request extensions require here (`Clone + Send + Sync + 'static`).
    fn require_extension_bounds<T: Clone + Send + Sync + 'static>() {}

    #[test]
    pub fn assert_is_extensionable() {
        require_extension_bounds::<RequestID>();
    }
}