cat_dev/net/server/requestable/
extension.rs

//! Handles fetching state that is dynamic per request (e.g. a request id) and
//! has been attached to the request's extensions.
3//!
4//! If you're dealing with app/server wide state, prefer [`crate::net::requestable::State`].
5
6use crate::{
7	errors::CatBridgeError,
8	net::{
9		errors::CommonNetAPIError,
10		models::{FromRef, FromRequest, FromRequestParts, Request},
11		server::models::{FromResponseStreamEvent, ResponseStreamEvent},
12	},
13};
14use std::{
15	fmt::{Debug, Formatter, Result as FmtResult},
16	ops::{Deref, DerefMut},
17	task::{Context, Poll},
18};
19use tower::{Layer, Service};
20use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
21
/// A value of type `Ty` that was cloned out of an extension map.
///
/// It serves two roles. As a handler argument, it extracts the attached value.
/// Through its layer implementation, it attaches a value to everything that
/// passes through a router.
pub struct Extension<Ty>(
	/// The extracted (or to-be-attached) value.
	pub Ty,
);
24
25impl<Ty, State: Clone + Send + Sync + 'static> FromRef<Request<State>> for Option<Extension<Ty>>
26where
27	Ty: Clone + Send + Sync + 'static,
28{
29	fn from_ref(input: &Request<State>) -> Self {
30		input
31			.extensions()
32			.get::<Ty>()
33			.map(|ext| Extension(ext.clone()))
34	}
35}
36
37impl<Ty, State: Clone + Send + Sync + 'static> FromRequestParts<State> for Extension<Ty>
38where
39	Ty: Clone + Send + Sync + 'static,
40{
41	async fn from_request_parts(req: &mut Request<State>) -> Result<Self, CatBridgeError> {
42		req.extensions()
43			.get::<Ty>()
44			.map(|ext| Extension(ext.clone()))
45			.ok_or_else(|| CommonNetAPIError::ExtensionNotPresent.into())
46	}
47}
48
49impl<Ty, State: Clone + Send + Sync + 'static> FromRequest<State> for Extension<Ty>
50where
51	Ty: Clone + Send + Sync + 'static,
52{
53	async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
54		req.extensions()
55			.get::<Ty>()
56			.map(|ext| Extension(ext.clone()))
57			.ok_or_else(|| CommonNetAPIError::ExtensionNotPresent.into())
58	}
59}
60
61impl<Ty, State: Clone + Send + Sync + 'static> FromResponseStreamEvent<State> for Extension<Ty>
62where
63	Ty: Clone + Send + Sync + 'static,
64{
65	async fn from_stream_event(
66		event: &mut ResponseStreamEvent<State>,
67	) -> Result<Self, CatBridgeError> {
68		event
69			.extensions()
70			.get::<Ty>()
71			.map(|ext| Extension(ext.clone()))
72			.ok_or_else(|| CommonNetAPIError::ExtensionNotPresent.into())
73	}
74}
75
76impl<ServiceTy, ExtensionTy> Layer<ServiceTy> for Extension<ExtensionTy>
77where
78	ExtensionTy: Clone + Send + Sync + 'static,
79{
80	type Service = AddExtension<ServiceTy, ExtensionTy>;
81
82	fn layer(&self, inner: ServiceTy) -> Self::Service {
83		AddExtension {
84			value: self.0.clone(),
85			wrapped: inner,
86		}
87	}
88}
89
90impl<Ty: Clone + Send + Sync + 'static> Deref for Extension<Ty> {
91	type Target = Ty;
92
93	fn deref(&self) -> &Self::Target {
94		&self.0
95	}
96}
97
98impl<Ty: Clone + Send + Sync + 'static> DerefMut for Extension<Ty> {
99	fn deref_mut(&mut self) -> &mut Self::Target {
100		&mut self.0
101	}
102}
103
104impl<Ty> Debug for Extension<Ty>
105where
106	Ty: Debug,
107{
108	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
109		fmt.debug_struct("Extension")
110			.field("inner", &self.0)
111			.finish()
112	}
113}
114
115const EXTENSION_FIELDS: &[NamedField<'static>] = &[NamedField::new("inner")];
116
117impl<Ty> Structable for Extension<Ty>
118where
119	Ty: Valuable,
120{
121	fn definition(&self) -> StructDef<'_> {
122		StructDef::new_static("Extension", Fields::Named(EXTENSION_FIELDS))
123	}
124}
125
126impl<Ty> Valuable for Extension<Ty>
127where
128	Ty: Valuable,
129{
130	fn as_value(&self) -> Value<'_> {
131		Value::Structable(self)
132	}
133
134	fn visit(&self, visitor: &mut dyn Visit) {
135		visitor.visit_named_fields(&NamedValues::new(EXTENSION_FIELDS, &[self.0.as_value()]));
136	}
137}
138
/// Middleware for adding some shareable value to a request extensions field.
///
/// This is the actual tower service that gets processed during requests.
/// It is built by the `Layer` impl on `Extension`. For every request or
/// response stream event it handles, it clones `value` into that item's
/// extensions and then delegates to `wrapped`.
#[derive(Clone, Debug)]
pub struct AddExtension<ServiceTy, ExtensionTy> {
	/// The value cloned into each request's (or stream event's) extensions.
	value: ExtensionTy,
	/// The inner service that receives the item once the value is attached.
	wrapped: ServiceTy,
}
147
148impl<ServiceTy, ExtensionTy, State> Service<Request<State>> for AddExtension<ServiceTy, ExtensionTy>
149where
150	ServiceTy: Service<Request<State>>,
151	ExtensionTy: Clone + Send + Sync + 'static,
152	State: Clone + Send + Sync + 'static,
153{
154	type Response = ServiceTy::Response;
155	type Error = ServiceTy::Error;
156	type Future = ServiceTy::Future;
157
158	#[inline]
159	fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160		self.wrapped.poll_ready(ctx)
161	}
162
163	fn call(&mut self, mut req: Request<State>) -> Self::Future {
164		req.extensions_mut().insert(self.value.clone());
165		self.wrapped.call(req)
166	}
167}
168
169impl<ServiceTy, ExtensionTy, State> Service<ResponseStreamEvent<State>>
170	for AddExtension<ServiceTy, ExtensionTy>
171where
172	ServiceTy: Service<ResponseStreamEvent<State>>,
173	ExtensionTy: Clone + Send + Sync + 'static,
174	State: Clone + Send + Sync + 'static,
175{
176	type Response = ServiceTy::Response;
177	type Error = ServiceTy::Error;
178	type Future = ServiceTy::Future;
179
180	#[inline]
181	fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182		self.wrapped.poll_ready(ctx)
183	}
184
185	fn call(&mut self, mut evt: ResponseStreamEvent<State>) -> Self::Future {
186		evt.extensions_mut().insert(self.value.clone());
187		self.wrapped.call(evt)
188	}
189}
190
#[cfg(test)]
mod unit_tests {
	use super::*;
	use crate::net::server::{Router, test_helpers::router_body_no_close};
	use std::sync::{Arc, atomic::AtomicU8};

	/// Values attached with `Router::layer(Extension(..))` must reach handlers
	/// that extract `Extension<T>`, even when several extension layers are stacked.
	#[tokio::test]
	pub async fn test_extension() {
		async fn echo_extension(
			Extension(message): Extension<String>,
			Extension(_counter): Extension<Arc<AtomicU8>>,
		) -> String {
			message
		}

		let greeting = String::from("Hey from an extension!");
		let counter = Arc::new(AtomicU8::new(0));

		let mut router = Router::new();
		router
			.add_route(&[0x1], echo_extension)
			.expect("Failed to add route to router!");
		router.layer(Extension(greeting));
		router.layer(Extension(counter));

		let body = router_body_no_close(&mut router, &[0x1, 0x2, 0x3, 0x4]).await;
		assert_eq!(body, b"Hey from an extension!");
	}
}