// Fail the build early when the `tokio` feature is disabled: per the message
// below, tokio is the only runtime this crate supports.
#[cfg(not(feature = "tokio"))]
compile_error!("Only tokio runtime is supported, and it is required to use `tower-embed`.");
44
45use std::{
46 convert::Infallible,
47 marker::PhantomData,
48 pin::Pin,
49 sync::Arc,
50 task::{Context, Poll},
51};
52
53#[doc(inline)]
54pub use tower_embed_impl::Embed;
55
56#[doc(inline)]
57pub use tower_embed_core as core;
58
59#[doc(inline)]
60pub use tower_embed_core::{Embed, http::Body};
61
62#[doc(hidden)]
63pub mod file;
64
65pub struct ResponseFuture(ResponseFutureInner);
67
68type ResponseFutureInner =
69 Pin<Box<dyn Future<Output = Result<http::Response<Body>, Infallible>> + Send>>;
70
71impl ResponseFuture {
72 fn new<F>(future: F) -> Self
73 where
74 F: Future<Output = Result<http::Response<Body>, Infallible>> + Send + 'static,
75 {
76 ResponseFuture(Box::pin(future))
77 }
78}
79
80impl Future for ResponseFuture {
81 type Output = Result<http::Response<Body>, Infallible>;
82
83 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84 self.0.as_mut().poll(cx)
85 }
86}
87
88pub struct ServeEmbed<E = ()> {
90 _embed: PhantomData<E>,
91 not_found_service: Option<NotFoundService>,
93}
94
95type NotFoundService =
96 tower::util::BoxCloneSyncService<http::Request<()>, http::Response<Body>, Infallible>;
97
98impl<E> Clone for ServeEmbed<E> {
99 fn clone(&self) -> Self {
100 Self {
101 _embed: PhantomData,
102 not_found_service: self.not_found_service.clone(),
103 }
104 }
105}
106
107impl<E: Embed> Default for ServeEmbed<E> {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl<E: Embed> ServeEmbed<E> {
114 pub fn new() -> Self {
116 ServeEmbedBuilder::new().build::<E>()
117 }
118}
119
120impl ServeEmbed<()> {
121 pub fn builder() -> ServeEmbedBuilder {
123 ServeEmbedBuilder::new()
124 }
125}
126
127impl<E, ReqBody> tower::Service<http::Request<ReqBody>> for ServeEmbed<E>
128where
129 E: Embed + Send + 'static,
130{
131 type Response = http::Response<Body>;
132 type Error = std::convert::Infallible;
133 type Future = ResponseFuture;
134
135 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
136 Poll::Ready(Ok(()))
137 }
138
139 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
140 let req = req.map(|_| ());
141 let not_found_service = self.not_found_service.clone();
142 ResponseFuture::new(async move {
143 let response =
144 if req.method() != http::Method::GET && req.method() != http::Method::HEAD {
145 method_not_allowed()
146 } else {
147 let path = req.uri().path().trim_start_matches('/');
148 tracing::trace!("Serving embedded resource '{path}'");
149 handle_request(E::get(path), req, not_found_service).await
150 };
151 Ok(response)
152 })
153 }
154}
155
156#[derive(Default)]
158pub struct ServeEmbedBuilder {
159 not_found_service: Option<NotFoundService>,
160}
161
162impl ServeEmbedBuilder {
163 pub fn new() -> Self {
165 Self::default()
166 }
167
168 pub fn not_found_service<S>(mut self, service: S) -> Self
170 where
171 S: tower::Service<http::Request<()>, Response = http::Response<Body>, Error = Infallible>
172 + Send
173 + Sync
174 + Clone
175 + 'static,
176 S::Future: Send + 'static,
177 {
178 self.not_found_service = Some(tower::util::BoxCloneSyncService::new(service));
179 self
180 }
181
182 pub fn build<E: Embed>(self) -> ServeEmbed<E> {
184 ServeEmbed {
185 _embed: PhantomData,
186 not_found_service: self.not_found_service,
187 }
188 }
189}
190
191pub trait EmbedExt: Embed + Sized {
193 fn not_found_page(path: &str) -> NotFoundPage<Self> {
195 NotFoundPage::new(path.to_string())
196 }
197}
198
199impl<T> EmbedExt for T where T: Embed + Sized {}
200
/// A service that answers every request with one fixed embedded file of `E`.
///
/// Created with [`EmbedExt::not_found_page`]. Clones are cheap: they share the
/// configured page path through an [`Arc`].
pub struct NotFoundPage<E>(Arc<NotFoundPageInner<E>>);

impl<E> Clone for NotFoundPage<E> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}

/// Shared state behind a [`NotFoundPage`].
struct NotFoundPageInner<E> {
    // `fn() -> E` rather than `E`: no `E` value is stored, so `NotFoundPage<E>`
    // must be `Send + Sync` whatever `E` is. `ServeEmbedBuilder::not_found_service`
    // requires both. With a plain `PhantomData<E>` the `Arc` would only be
    // `Send + Sync` when `E` is.
    _embed: PhantomData<fn() -> E>,
    /// Path of the embedded file to serve.
    page: String,
}

impl<E> NotFoundPage<E> {
    /// Wraps `page`, the path of the embedded file to serve.
    fn new(page: String) -> Self {
        Self(Arc::new(NotFoundPageInner {
            _embed: PhantomData,
            page,
        }))
    }
}
223
224impl<E> tower::Service<http::Request<()>> for NotFoundPage<E>
225where
226 E: Embed,
227{
228 type Response = http::Response<Body>;
229 type Error = Infallible;
230 type Future = ResponseFuture;
231
232 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
233 Poll::Ready(Ok(()))
234 }
235
236 fn call(&mut self, req: http::Request<()>) -> Self::Future {
237 let embedded = E::get(&self.0.page);
238 ResponseFuture::new(async move { Ok(handle_request(embedded, req, None).await) })
239 }
240}
241
242async fn handle_request<F>(
243 embedded: F,
244 request: http::Request<()>,
245 not_found_service: Option<NotFoundService>,
246) -> http::Response<Body>
247where
248 F: Future<Output = std::io::Result<core::Embedded>> + Send,
249{
250 use core::headers::{self, HeaderMapExt};
251
252 let path = request.uri().path().trim_start_matches('/');
253 let core::Embedded { content, metadata } = match embedded.await {
254 Ok(embedded) => embedded,
255 Err(err)
256 if err.kind() == std::io::ErrorKind::NotFound
257 || err.kind() == std::io::ErrorKind::NotADirectory =>
258 {
259 tracing::trace!("Embedded resource not found: '{path}'");
260 return not_found_response(request, not_found_service).await;
261 }
262 Err(err) => {
263 tracing::error!("Failed to get embedded resource '{path}': {err}");
264 return server_error_response(err);
265 }
266 };
267
268 let if_none_match = request.headers().typed_get::<headers::IfNoneMatch>();
269 if let Some(if_none_match) = if_none_match
270 && let Some(etag) = &metadata.etag
271 && !if_none_match.condition_passes(etag)
272 {
273 tracing::trace!("ETag match for embedded resource '{path}'");
274 return not_modified_response();
275 }
276
277 let if_modified_since = request.headers().typed_get::<headers::IfModifiedSince>();
278 if let Some(if_modified_since) = if_modified_since
279 && let Some(last_modified) = &metadata.last_modified
280 && !if_modified_since.condition_passes(last_modified)
281 {
282 tracing::trace!("Last-Modified match for embedded resource '{path}'");
283 return not_modified_response();
284 }
285
286 let mut response = http::Response::builder()
287 .status(http::StatusCode::OK)
288 .body(Body::stream(content))
289 .unwrap();
290
291 response.headers_mut().typed_insert(metadata.content_type);
292 if let Some(etag) = metadata.etag {
293 response.headers_mut().typed_insert(etag);
294 }
295 if let Some(last_modified) = metadata.last_modified {
296 response.headers_mut().typed_insert(last_modified);
297 }
298
299 response
300}
301
302async fn not_found_response(
303 request: http::Request<()>,
304 mut not_found_service: Option<NotFoundService>,
305) -> http::Response<Body> {
306 use tower::ServiceExt;
307
308 let mut response = match not_found_service.take() {
309 Some(service) => {
310 let service = service.ready_oneshot().await.unwrap();
311 service.oneshot(request).await.unwrap()
312 }
313 None => http::Response::builder()
314 .status(http::StatusCode::NOT_FOUND)
315 .body(Body::empty())
316 .unwrap(),
317 };
318 response.headers_mut().insert(
319 http::header::CACHE_CONTROL,
320 http::HeaderValue::from_static("no-store"),
321 );
322 response
323}
324
325fn not_modified_response() -> http::Response<Body> {
326 http::Response::builder()
327 .status(http::StatusCode::NOT_MODIFIED)
328 .body(Body::empty())
329 .unwrap()
330}
331
332fn method_not_allowed() -> http::Response<Body> {
333 http::Response::builder()
334 .header(
335 http::header::ALLOW,
336 http::HeaderValue::from_static("GET, HEAD"),
337 )
338 .status(http::StatusCode::METHOD_NOT_ALLOWED)
339 .body(Body::empty())
340 .unwrap()
341}
342
343fn server_error_response(_err: std::io::Error) -> http::Response<Body> {
344 http::Response::builder()
345 .status(http::StatusCode::INTERNAL_SERVER_ERROR)
346 .header(http::header::CACHE_CONTROL, "no-store")
347 .body(Body::empty())
348 .unwrap()
349}