1use std::marker::PhantomData;
2
3use motore::{
4 layer::{Identity, Layer, Stack},
5 service::Service,
6};
7
8use super::NamedService;
9use crate::{
10 Request, Response, Status,
11 body::{Body, BoxBody, boxed},
12 codec::{
13 compression::{CompressionEncoding, ENCODING_HEADER},
14 decode::Kind,
15 },
16 context::{Config, ServerContext},
17 message::{RecvEntryMessage, SendEntryMessage},
18 metadata::MetadataValue,
19};
20
21#[derive(Clone)]
22pub struct ServiceBuilder<S, L> {
23 service: S,
24 layer: L,
25 rpc_config: Config,
26}
27
28impl<S> ServiceBuilder<S, Identity> {
29 pub fn new(service: S) -> Self {
30 Self {
31 service,
32 layer: Identity::new(),
33 rpc_config: Config::default(),
34 }
35 }
36}
37
38impl<S, L> ServiceBuilder<S, L> {
39 pub fn send_compressions(mut self, config: Vec<CompressionEncoding>) -> Self {
44 self.rpc_config.send_compressions = Some(config);
45 self
46 }
47
48 pub fn accept_compressions(mut self, config: Vec<CompressionEncoding>) -> Self {
53 self.rpc_config.accept_compressions = Some(config);
54 self
55 }
56
57 pub fn layer<O>(self, layer: O) -> ServiceBuilder<S, Stack<O, L>> {
58 ServiceBuilder {
59 layer: Stack::new(layer, self.layer),
60 service: self.service,
61 rpc_config: self.rpc_config,
62 }
63 }
64
65 pub fn layer_front<Front>(self, layer: Front) -> ServiceBuilder<S, Stack<L, Front>> {
66 ServiceBuilder {
67 layer: Stack::new(self.layer, layer),
68 service: self.service,
69 rpc_config: self.rpc_config,
70 }
71 }
72
73 pub fn build<T, U>(self) -> CodecService<<L as volo::Layer<S>>::Service, T, U>
74 where
75 L: Layer<S>,
76 {
77 let service = motore::builder::ServiceBuilder::new()
78 .layer(self.layer)
79 .service(self.service);
80
81 CodecService::new(service, self.rpc_config)
82 }
83}
84
85pub struct CodecService<S, T, U> {
86 inner: S,
87 rpc_config: Config,
88 _marker: PhantomData<fn(T, U)>,
89}
90
91impl<S, T, U> Clone for CodecService<S, T, U>
92where
93 S: Clone,
94{
95 fn clone(&self) -> Self {
96 Self {
97 inner: self.inner.clone(),
98 rpc_config: self.rpc_config.clone(),
99 _marker: PhantomData,
100 }
101 }
102}
103
104impl<S, T, U> CodecService<S, T, U> {
105 pub fn new(inner: S, rpc_config: Config) -> Self {
106 Self {
107 inner,
108 rpc_config,
109 _marker: PhantomData,
110 }
111 }
112}
113
114impl<S, T, U> Service<ServerContext, Request<BoxBody>> for CodecService<S, T, U>
115where
116 S: Service<ServerContext, Request<T>, Response = Response<U>> + Sync,
117 S::Error: Into<Status>,
118 T: RecvEntryMessage,
119 U: SendEntryMessage,
120{
121 type Response = Response<BoxBody>;
122 type Error = Status;
123
124 async fn call(
125 &self,
126 cx: &mut ServerContext,
127 req: Request<BoxBody>,
128 ) -> Result<Self::Response, Self::Error> {
129 let (metadata, extensions, body) = req.into_parts();
130 #[cfg(not(feature = "compress"))]
131 let send_compression = None;
132 #[cfg(feature = "compress")]
133 let send_compression = CompressionEncoding::from_accept_encoding_header(
134 metadata.headers(),
135 &self.rpc_config.send_compressions,
136 );
137
138 #[cfg(not(feature = "compress"))]
139 let recv_compression = None;
140 #[cfg(feature = "compress")]
141 let recv_compression = CompressionEncoding::from_encoding_header(
142 metadata.headers(),
143 &self.rpc_config.accept_compressions,
144 )?;
145
146 let message = T::from_body(
147 Some(cx.rpc_info.method().as_str()),
148 body,
149 Kind::Request,
150 recv_compression,
151 )?;
152
153 let volo_req = Request::from_parts(metadata, extensions, message);
154
155 cx.stats.record_process_start_at();
156
157 let volo_resp = self.inner.call(cx, volo_req).await.map_err(Into::into)?;
158
159 cx.stats.record_process_end_at();
160
161 let mut resp =
162 volo_resp.map(|message| boxed(Body::new(message.into_body(send_compression))));
163
164 if let Some(encoding) = send_compression {
165 resp.metadata_mut().insert(
166 ENCODING_HEADER,
167 MetadataValue::unchecked_from_header_value(encoding.into_header_value()),
168 );
169 };
170
171 Ok(resp)
172 }
173}
174
175impl<S: NamedService, T, U> NamedService for CodecService<S, T, U> {
176 const NAME: &'static str = S::NAME;
177}