cat_dev/net/additions/
stream_id.rs1use crate::{
8 errors::CatBridgeError,
9 net::models::{FromRequest, FromRequestParts, Request, Response},
10};
11use std::{
12 convert::Infallible,
13 fmt::{Display, Formatter, Result as FmtResult},
14 ops::Deref,
15};
16use tower::{Layer, Service};
17use tracing::{
18 Id as TracingId, error_span,
19 instrument::{Instrument, Instrumented},
20};
21use valuable::Valuable;
22
23#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Valuable)]
25pub struct StreamID(u64);
26
27impl StreamID {
28 #[must_use]
29 pub const fn from_existing(id: u64) -> Self {
30 Self(id)
31 }
32
33 #[must_use]
35 pub fn to_raw(&self) -> u64 {
36 self.0
37 }
38}
39
40impl Display for StreamID {
41 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
42 write!(fmt, "{}", self.0)
43 }
44}
45
46impl Deref for StreamID {
47 type Target = u64;
48
49 fn deref(&self) -> &Self::Target {
50 &self.0
51 }
52}
53
54impl<State: Clone + Send + Sync + 'static> FromRequestParts<State> for StreamID {
55 async fn from_request_parts(parts: &mut Request<State>) -> Result<Self, CatBridgeError> {
56 Ok(Self::from_existing(parts.stream_id()))
57 }
58}
59
60impl<State: Clone + Send + Sync + 'static> FromRequest<State> for StreamID {
61 async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
62 Ok(Self::from_existing(req.stream_id()))
63 }
64}
65
66#[derive(Clone, Debug)]
67pub struct StreamIDLayer;
68
69impl<Layered> Layer<Layered> for StreamIDLayer
70where
71 Layered: Clone,
72{
73 type Service = LayeredStreamID<Layered>;
74
75 fn layer(&self, inner: Layered) -> Self::Service {
76 LayeredStreamID { inner }
77 }
78}
79
80#[derive(Clone)]
81pub struct LayeredStreamID<Layered> {
82 inner: Layered,
83}
84
85impl<Layered, State: Clone + Send + Sync + 'static> Service<Request<State>>
86 for LayeredStreamID<Layered>
87where
88 Layered:
89 Service<Request<State>, Response = Response, Error = Infallible> + Clone + Send + 'static,
90 Layered::Future: Send + 'static,
91{
92 type Response = Layered::Response;
93 type Error = Layered::Error;
94 type Future = Instrumented<Layered::Future>;
95
96 #[inline]
97 fn poll_ready(
98 &mut self,
99 ctx: &mut std::task::Context<'_>,
100 ) -> std::task::Poll<Result<(), Self::Error>> {
101 self.inner.poll_ready(ctx)
102 }
103
104 fn call(&mut self, mut req: Request<State>) -> Self::Future {
105 let parent_span = req
106 .extensions()
107 .get::<Option<TracingId>>()
108 .cloned()
109 .unwrap_or(None);
110 let stream_id = StreamID::from_existing(req.stream_id());
111
112 let span = error_span!(
113 parent: parent_span,
114 "WithStreamID",
115 request.stream_id = %stream_id,
116 );
117 req.extensions_mut().insert::<StreamID>(stream_id);
118 req.extensions_mut().insert::<Option<TracingId>>(span.id());
119 self.inner.call(req).instrument(span.or_current())
120 }
121}