cat_dev/net/additions/
stream_id.rs

//! A unique identifier for a specific 'live' connection, or TCP Stream.
//!
//! On UDP this still has a fairly strong likelihood of being unique. Far fewer
//! guarantees can be made there than for TCP, but the hope is that it is still
//! decently unique.
6
7use crate::{
8	errors::CatBridgeError,
9	net::models::{FromRequest, FromRequestParts, Request, Response},
10};
11use std::{
12	convert::Infallible,
13	fmt::{Display, Formatter, Result as FmtResult},
14	ops::Deref,
15};
16use tower::{Layer, Service};
17use tracing::{
18	Id as TracingId, error_span,
19	instrument::{Instrument, Instrumented},
20};
21use valuable::Valuable;
22
23/// A unique identifier for a particular TCP Stream, or connection.
24#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Valuable)]
25pub struct StreamID(u64);
26
27impl StreamID {
28	#[must_use]
29	pub const fn from_existing(id: u64) -> Self {
30		Self(id)
31	}
32
33	/// Convert this stream id to a raw value.
34	#[must_use]
35	pub fn to_raw(&self) -> u64 {
36		self.0
37	}
38}
39
40impl Display for StreamID {
41	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
42		write!(fmt, "{}", self.0)
43	}
44}
45
46impl Deref for StreamID {
47	type Target = u64;
48
49	fn deref(&self) -> &Self::Target {
50		&self.0
51	}
52}
53
54impl<State: Clone + Send + Sync + 'static> FromRequestParts<State> for StreamID {
55	async fn from_request_parts(parts: &mut Request<State>) -> Result<Self, CatBridgeError> {
56		Ok(Self::from_existing(parts.stream_id()))
57	}
58}
59
60impl<State: Clone + Send + Sync + 'static> FromRequest<State> for StreamID {
61	async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
62		Ok(Self::from_existing(req.stream_id()))
63	}
64}
65
/// A [`Layer`] that wraps a service in [`LayeredStreamID`].
#[derive(Debug, Clone)]
pub struct StreamIDLayer;
68
69impl<Layered> Layer<Layered> for StreamIDLayer
70where
71	Layered: Clone,
72{
73	type Service = LayeredStreamID<Layered>;
74
75	fn layer(&self, inner: Layered) -> Self::Service {
76		LayeredStreamID { inner }
77	}
78}
79
/// The service produced by [`StreamIDLayer`]; it holds the wrapped service.
#[derive(Clone)]
pub struct LayeredStreamID<Inner> {
	/// The wrapped service that requests are forwarded to.
	inner: Inner,
}
84
85impl<Layered, State: Clone + Send + Sync + 'static> Service<Request<State>>
86	for LayeredStreamID<Layered>
87where
88	Layered:
89		Service<Request<State>, Response = Response, Error = Infallible> + Clone + Send + 'static,
90	Layered::Future: Send + 'static,
91{
92	type Response = Layered::Response;
93	type Error = Layered::Error;
94	type Future = Instrumented<Layered::Future>;
95
96	#[inline]
97	fn poll_ready(
98		&mut self,
99		ctx: &mut std::task::Context<'_>,
100	) -> std::task::Poll<Result<(), Self::Error>> {
101		self.inner.poll_ready(ctx)
102	}
103
104	fn call(&mut self, mut req: Request<State>) -> Self::Future {
105		let parent_span = req
106			.extensions()
107			.get::<Option<TracingId>>()
108			.cloned()
109			.unwrap_or(None);
110		let stream_id = StreamID::from_existing(req.stream_id());
111
112		let span = error_span!(
113		  parent: parent_span,
114		  "WithStreamID",
115		  request.stream_id = %stream_id,
116		);
117		req.extensions_mut().insert::<StreamID>(stream_id);
118		req.extensions_mut().insert::<Option<TracingId>>(span.id());
119		self.inner.call(req).instrument(span.or_current())
120	}
121}