cat_dev/net/server/requestable/
extension.rs1use crate::{
7 errors::CatBridgeError,
8 net::{
9 errors::CommonNetAPIError,
10 models::{FromRef, FromRequest, FromRequestParts, Request},
11 server::models::{FromResponseStreamEvent, ResponseStreamEvent},
12 },
13};
14use std::{
15 fmt::{Debug, Formatter, Result as FmtResult},
16 ops::{Deref, DerefMut},
17 task::{Context, Poll},
18};
19use tower::{Layer, Service};
20use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
21
/// Carries a shared value of type `Ty` through request and response-stream
/// processing.
///
/// It plays two roles:
///
/// - **Extractor**: through its `FromRequestParts`, `FromRequest`, and
///   `FromResponseStreamEvent` implementations it clones the `Ty` stored in the
///   request/event extensions, failing with
///   `CommonNetAPIError::ExtensionNotPresent` when no such value exists.
///   `Option<Extension<Ty>>` implements `FromRef<Request<_>>` and yields `None`
///   in that case instead.
/// - **Layer**: applied to a service, it wraps that service in
///   [`AddExtension`], which inserts a clone of the held value into every
///   request/event before forwarding it.
///
/// `Clone`, `Copy`, and `Default` are derived (each only applies when `Ty`
/// implements the same trait) so the wrapper is as easy to handle as the
/// value it holds.
#[derive(Clone, Copy, Default)]
pub struct Extension<Ty>(pub Ty);

25impl<Ty, State: Clone + Send + Sync + 'static> FromRef<Request<State>> for Option<Extension<Ty>>
26where
27 Ty: Clone + Send + Sync + 'static,
28{
29 fn from_ref(input: &Request<State>) -> Self {
30 input
31 .extensions()
32 .get::<Ty>()
33 .map(|ext| Extension(ext.clone()))
34 }
35}
36
37impl<Ty, State: Clone + Send + Sync + 'static> FromRequestParts<State> for Extension<Ty>
38where
39 Ty: Clone + Send + Sync + 'static,
40{
41 async fn from_request_parts(req: &mut Request<State>) -> Result<Self, CatBridgeError> {
42 req.extensions()
43 .get::<Ty>()
44 .map(|ext| Extension(ext.clone()))
45 .ok_or_else(|| CommonNetAPIError::ExtensionNotPresent.into())
46 }
47}
48
49impl<Ty, State: Clone + Send + Sync + 'static> FromRequest<State> for Extension<Ty>
50where
51 Ty: Clone + Send + Sync + 'static,
52{
53 async fn from_request(req: Request<State>) -> Result<Self, CatBridgeError> {
54 req.extensions()
55 .get::<Ty>()
56 .map(|ext| Extension(ext.clone()))
57 .ok_or_else(|| CommonNetAPIError::ExtensionNotPresent.into())
58 }
59}
60
61impl<Ty, State: Clone + Send + Sync + 'static> FromResponseStreamEvent<State> for Extension<Ty>
62where
63 Ty: Clone + Send + Sync + 'static,
64{
65 async fn from_stream_event(
66 event: &mut ResponseStreamEvent<State>,
67 ) -> Result<Self, CatBridgeError> {
68 event
69 .extensions()
70 .get::<Ty>()
71 .map(|ext| Extension(ext.clone()))
72 .ok_or_else(|| CommonNetAPIError::ExtensionNotPresent.into())
73 }
74}
75
76impl<ServiceTy, ExtensionTy> Layer<ServiceTy> for Extension<ExtensionTy>
77where
78 ExtensionTy: Clone + Send + Sync + 'static,
79{
80 type Service = AddExtension<ServiceTy, ExtensionTy>;
81
82 fn layer(&self, inner: ServiceTy) -> Self::Service {
83 AddExtension {
84 value: self.0.clone(),
85 wrapped: inner,
86 }
87 }
88}
89
90impl<Ty: Clone + Send + Sync + 'static> Deref for Extension<Ty> {
91 type Target = Ty;
92
93 fn deref(&self) -> &Self::Target {
94 &self.0
95 }
96}
97
98impl<Ty: Clone + Send + Sync + 'static> DerefMut for Extension<Ty> {
99 fn deref_mut(&mut self) -> &mut Self::Target {
100 &mut self.0
101 }
102}
103
104impl<Ty> Debug for Extension<Ty>
105where
106 Ty: Debug,
107{
108 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
109 fmt.debug_struct("Extension")
110 .field("inner", &self.0)
111 .finish()
112 }
113}
114
115const EXTENSION_FIELDS: &[NamedField<'static>] = &[NamedField::new("inner")];
116
117impl<Ty> Structable for Extension<Ty>
118where
119 Ty: Valuable,
120{
121 fn definition(&self) -> StructDef<'_> {
122 StructDef::new_static("Extension", Fields::Named(EXTENSION_FIELDS))
123 }
124}
125
126impl<Ty> Valuable for Extension<Ty>
127where
128 Ty: Valuable,
129{
130 fn as_value(&self) -> Value<'_> {
131 Value::Structable(self)
132 }
133
134 fn visit(&self, visitor: &mut dyn Visit) {
135 visitor.visit_named_fields(&NamedValues::new(EXTENSION_FIELDS, &[self.0.as_value()]));
136 }
137}
138
/// The [`Service`] produced by applying an [`Extension`] as a layer.
///
/// Before forwarding each request or response stream event to the wrapped
/// service, it inserts a clone of `value` into that item's extensions.
///
/// `Debug` is derived alongside `Clone` (each applies only when both type
/// parameters implement the same trait) so the service can appear in
/// diagnostic output like other public types.
#[derive(Clone, Debug)]
pub struct AddExtension<ServiceTy, ExtensionTy> {
	/// Value cloned into every request/event passed through this service.
	value: ExtensionTy,
	/// Inner service that receives each request/event after the insertion.
	wrapped: ServiceTy,
}

148impl<ServiceTy, ExtensionTy, State> Service<Request<State>> for AddExtension<ServiceTy, ExtensionTy>
149where
150 ServiceTy: Service<Request<State>>,
151 ExtensionTy: Clone + Send + Sync + 'static,
152 State: Clone + Send + Sync + 'static,
153{
154 type Response = ServiceTy::Response;
155 type Error = ServiceTy::Error;
156 type Future = ServiceTy::Future;
157
158 #[inline]
159 fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160 self.wrapped.poll_ready(ctx)
161 }
162
163 fn call(&mut self, mut req: Request<State>) -> Self::Future {
164 req.extensions_mut().insert(self.value.clone());
165 self.wrapped.call(req)
166 }
167}
168
169impl<ServiceTy, ExtensionTy, State> Service<ResponseStreamEvent<State>>
170 for AddExtension<ServiceTy, ExtensionTy>
171where
172 ServiceTy: Service<ResponseStreamEvent<State>>,
173 ExtensionTy: Clone + Send + Sync + 'static,
174 State: Clone + Send + Sync + 'static,
175{
176 type Response = ServiceTy::Response;
177 type Error = ServiceTy::Error;
178 type Future = ServiceTy::Future;
179
180 #[inline]
181 fn poll_ready(&mut self, ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182 self.wrapped.poll_ready(ctx)
183 }
184
185 fn call(&mut self, mut evt: ResponseStreamEvent<State>) -> Self::Future {
186 evt.extensions_mut().insert(self.value.clone());
187 self.wrapped.call(evt)
188 }
189}
190
#[cfg(test)]
mod unit_tests {
	use super::*;
	use crate::net::server::{Router, test_helpers::router_body_no_close};
	use std::sync::{Arc, atomic::AtomicU8};

	/// Installs two extensions of different types as router layers and checks
	/// that a handler extracting both receives the `String` one, which it
	/// returns as its output.
	#[tokio::test]
	pub async fn test_extension() {
		/// Handler requiring both extensions; only the message is used.
		async fn reply_with_message(
			Extension(message): Extension<String>,
			Extension(_counter): Extension<Arc<AtomicU8>>,
		) -> String {
			message
		}

		let message_layer = Extension::<String>("Hey from an extension!".to_owned());
		let counter_layer = Extension::<Arc<AtomicU8>>(Arc::new(AtomicU8::new(0)));

		let mut router = Router::new();
		router
			.add_route(&[0x1], reply_with_message)
			.expect("Failed to add route to router!");
		router.layer(message_layer);
		router.layer(counter_layer);

		let body = router_body_no_close(&mut router, &[0x1, 0x2, 0x3, 0x4]).await;
		assert_eq!(body, b"Hey from an extension!");
	}
}