1use std::sync::Arc;
46
47use aws_sdk_s3::{
48 Client as S3Client,
49 error::SdkError,
50 operation::get_object::{
51 GetObjectError,
52 GetObjectOutput,
53 builders::GetObjectFluentBuilder
54 },
55};
56use axum::response::IntoResponse;
57use std::{
58 convert::Infallible,
59 future::Future,
60 pin::Pin,
61 task::{Context, Poll},
62};
63use tower_service::Service;
64
65
66#[cfg(feature = "trace")]
67#[allow(unused_imports)]
68use tracing::{info, error};
69#[cfg(feature = "trace")]
70#[allow(unused_imports)]
71use tracing::Instrument;
72
/// No-op stand-in for `tracing::info!`, compiled only when the `trace`
/// feature is disabled.
///
/// Accepts any token sequence and expands to an empty unit statement, so the
/// arguments are never evaluated.
#[cfg(not(feature = "trace"))]
#[allow(unused_macros)] // Kept as a fallback; not every build invokes it.
macro_rules! info {
    ($($ignored:tt)*) => {
        ();
    };
}
81
82mod adapter;
83use adapter::TryStreamAdapater;
84
85mod builder;
86pub use builder::S3OriginBuilder;
87
88#[derive(Clone)]
89pub(crate) struct S3OriginInner {
90 bucket: String,
91 bucket_prefix: String,
92 s3_client: Arc<S3Client>,
93 prune_path: usize,
94 max_size: Option<i64>,
95}
96
97#[derive(Clone)]
98pub struct S3Origin {
99 inner: Arc<S3OriginInner>,
100}
101
102
/// Builds the S3 object key for a request path.
///
/// When `prune_path` is non-zero, the first `prune_path` `/`-separated
/// segments of `uri_path` are dropped (a leading `/` yields an empty first
/// segment, which counts). Any leading slashes left on the remainder are
/// stripped and the result is appended to `bucket_prefix` verbatim.
fn request_to_key(bucket_prefix: &str, uri_path: &str, prune_path: usize) -> String {
    let remainder = if prune_path == 0 {
        uri_path
    } else {
        // Splitting into at most `prune_path + 1` pieces leaves everything
        // after the pruned segments in the final piece, separators intact.
        uri_path
            .splitn(prune_path.saturating_add(1), '/')
            .nth(prune_path)
            .unwrap_or("")
    };

    [bucket_prefix, remainder.trim_start_matches('/')].concat()
}
112
113
114impl Service<axum::extract::Request> for S3Origin {
115 type Error = Infallible;
116 type Response = axum::response::Response<axum::body::Body>;
117 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static >>;
118
119 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121 Poll::Ready(Ok(()))
122 }
123
124 fn call(&mut self, req: axum::extract::Request) -> Self::Future {
126 #[cfg(feature = "trace")]
127 tracing::info!("S3Origin: Serving request");
128
129 if req.method() != axum::http::Method::GET {
131 #[cfg(feature = "trace")]
132 tracing::info!("S3Origin: {} method not allowed", req.method());
133
134 return Box::pin(async move {
135 Ok(axum::response::Response::builder().status(axum::http::StatusCode::METHOD_NOT_ALLOWED).body(axum::body::Body::from("Method not allowed")).unwrap())
136 });
137 }
138
139 let this = self.inner.clone();
140 let path = req.uri().path();
141 let path = path.strip_prefix("/").unwrap_or(path);
142
143 let mut path = path.to_string();
144
145 if this.prune_path > 0 {
146 path = path.split('/').skip(this.prune_path).collect::<Vec<_>>().join("/");
147 }
148
149 let client = this.s3_client.clone();
150 let key = request_to_key(&this.bucket_prefix, &path, this.prune_path);
151
152 #[cfg(feature = "trace")]
153 {
154 let current_span = tracing::Span::current();
155 current_span.record("s3_url", &format!("s3://{}/{}", this.bucket, key));
156 }
157
158 let get_s3_fut = async move {
159 let builder = client.get_object()
160 .bucket(&this.bucket)
161 .key(&key);
162 let builder = make_request_builder(&req, builder);
163
164 let response;
165 #[cfg(feature = "trace")]
166 {
167 response = builder.send()
168 .instrument(
169 tracing::info_span!("s3_get_object", bucket = %this.bucket, key = %key)
170 ).await;
171 }
172 #[cfg(not(feature = "trace"))]
173 {
174 response = builder.send().await;
175 }
176
177 let rv = wrap_create_response(response, this.max_size)
178 .unwrap_or_else(|e| {
179 e.into_response()
180 });
181
182 Ok(rv)
183 };
184
185 Box::pin(get_s3_fut)
186 }
187}
188
189
190fn make_request_builder(request: &axum::extract::Request, mut builder: GetObjectFluentBuilder) -> GetObjectFluentBuilder {
191 if let Some(range) = request.headers().get(axum::http::header::RANGE) {
193 builder = builder.range(range.to_str().unwrap());
194 }
195 builder
196}
197
198
199fn wrap_create_response<E>(s3_response: Result<GetObjectOutput, SdkError<GetObjectError, E>>, max_size: Option<i64>) -> Result<axum::response::Response, S3Error> {
200 #[cfg(feature = "trace")]
201 {
202 tracing::debug!("S3Origin: Wrapping response: {}",
203 if s3_response.is_ok() { "OK".to_owned() } else { format!("Error: {}", s3_response.as_ref().err().unwrap().to_string()) }
204 );
205 }
206
207 let s3_response = s3_response.map_err(S3Error::from)?;
209
210 let content_type = s3_response.content_type().map(|ct| ct.to_owned());
212 let content_length = s3_response.content_length().map(|cl| cl.to_owned());
213
214 if let Some(max_size) = max_size {
215 if let Some(size) = content_length.as_ref() {
216 if size > &max_size {
217 return Err(S3Error::MaxSizeExceeded);
218 }
219 }
220 }
221
222 let body = TryStreamAdapater { stream: s3_response.body.into_async_read()};
223 let body = axum::body::Body::from_stream(body);
224 let mut response = axum::response::Response::builder()
225 .status(200)
226 .body(body)
227 .unwrap(); if let Some(content_type) = content_type {
231 response.headers_mut().insert(
232 axum::http::header::CONTENT_TYPE,
233 content_type
234 .parse()
235 .map_err(|_| S3Error::InternalServerError)?
236 );
237 } else {
238 response.headers_mut().insert(axum::http::header::CONTENT_TYPE, "application/octet-stream".parse().unwrap()); }
240 if let Some(content_length) = content_length {
242 response.headers_mut().insert(axum::http::header::CONTENT_LENGTH, content_length.to_string().parse().unwrap()); }
244
245 Ok(response)
246}
247
248
249impl<E> From<SdkError<GetObjectError, E>> for S3Error {
250 fn from(error: SdkError<GetObjectError, E>) -> Self {
251 match error {
252 SdkError::ServiceError(error) => {
253 if error.err().is_no_such_key() {
254 S3Error::NotFound
255 } else {
256 S3Error::BadGateway
257 }
258 }
259 _ => S3Error::InternalServerError,
260 }
261 }
262}
263
264impl axum::response::IntoResponse for S3Error {
265 fn into_response(self) -> axum::response::Response {
266 #[warn(unreachable_patterns)]
267 match self {
268 S3Error::NotFound => axum::response::Response::builder().status(axum::http::StatusCode::NOT_FOUND).body(axum::body::Body::from("Not found")).unwrap(),
269 S3Error::BadGateway => axum::response::Response::builder().status(axum::http::StatusCode::BAD_GATEWAY).body(axum::body::Body::from("Bad gateway")).unwrap(),
270 S3Error::InternalServerError => axum::response::Response::builder().status(axum::http::StatusCode::INTERNAL_SERVER_ERROR).body(axum::body::Body::from("Internal server error")).unwrap(),
271 S3Error::MaxSizeExceeded => axum::response::Response::builder().status(axum::http::StatusCode::PAYLOAD_TOO_LARGE).body(axum::body::Body::from("Requested file size exceeds the maximum allowed size")).unwrap(),
272 }
273 }
274}
275
276
/// Failure modes of serving an object, each rendered as one HTTP status.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum S3Error {
    /// The requested key does not exist; rendered as `404 Not Found`.
    NotFound,
    /// S3 answered with a service error other than a missing key; rendered
    /// as `502 Bad Gateway`.
    BadGateway,
    /// Any other failure; rendered as `500 Internal Server Error`.
    InternalServerError,
    /// The object is larger than the configured maximum; rendered as
    /// `413 Payload Too Large`.
    MaxSizeExceeded,
}
283
284
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time trait checks: each helper only type-checks its argument.
    fn assert_clone<T: Clone>(_: &T) {}
    fn assert_send<T: Send>(_: &T) {}
    fn assert_sync<T: Sync>(_: &T) {}
    fn assert_service<T: Service<axum::extract::Request>>(_: &T) {}

    /// The built origin satisfies the bounds a router needs and can be nested.
    #[test]
    fn can_route_to_s3_origin() {
        use axum::Router;
        let origin = S3OriginBuilder::new()
            .bucket("my-bucket")
            .prefix("my-prefix")
            .build()
            .unwrap();

        assert_clone(&origin);
        assert_send(&origin);
        assert_sync(&origin);
        assert_service(&origin);

        #[allow(dead_code, unused_must_use)]
        let _app = Router::<()>::new().nest_service("/static", origin);
    }

    /// Sanity check that plain router nesting works alongside the origin.
    #[test]
    fn test_nest_route_route() {
        use axum::{Router, routing::get};
        let subroute: Router<()> = Router::new().route("/", get(|| async { "Hello, world!" }));
        let _app = Router::new().nest("/foo", subroute);
    }

    /// Without pruning, only leading slashes are stripped before prefixing.
    #[test]
    fn request_to_key_without_pruning() {
        assert_eq!(request_to_key("assets/", "/img/logo.png", 0), "assets/img/logo.png");
        assert_eq!(request_to_key("", "img/logo.png", 0), "img/logo.png");
    }

    /// Pruning drops leading `/`-separated segments; a leading slash counts
    /// as an empty first segment.
    #[test]
    fn request_to_key_with_pruning() {
        assert_eq!(request_to_key("assets/", "static/img/logo.png", 1), "assets/img/logo.png");
        assert_eq!(request_to_key("", "/img/logo.png", 1), "img/logo.png");
        assert_eq!(request_to_key("assets/", "static", 3), "assets/");
    }
}