// axum_static_s3/lib.rs

//! A simple static file service for AWS S3 using Axum.
//!
//! Files are retrieved from S3 and served as responses by a tower `Service`.
//! This is useful for serving static files from S3 in a serverless environment when
//! the static files are independent of the application.
//!
//! A local fallback for development (local axum invocation) is provided as well.
//!
9//! # Basic Usage
10//! 
11//! ```rust
12//! use axum::{Router, routing::get};
13//! use axum_static_s3::S3OriginBuilder;
14//! 
15//! 
16//! #[tokio::main]
17//! async fn main() {
18//!     // Build the S3 origin
19//!     let s3_origin = S3OriginBuilder::new()
20//!         .bucket("my-static-files-bucket")
21//!         .prefix("static/")
//!         .prune_path(1)      // Remove the first request path component
23//!         .max_size(1024 * 1024 * 12) // 12MB
24//!         .build()
25//!         .expect("Failed to build S3 origin");
26//! 
27//!     // Create the router with the S3 static file handler
28//!     let app = Router::new()
29//!         .nest_service("/static/", s3_origin);
30//! 
31//!     // Start the server
32//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
33//!         .await
34//!         .unwrap();
35//!     axum::serve(listener, app).await.unwrap();
36//! }
37//! ```
38//! 
39//! # Features
40//! 
41//! - `trace`: Enable tracing of the S3 requests.
42//! 
43//! 
44//! 
45use std::sync::Arc;
46
47use aws_sdk_s3::{
48    Client as S3Client,
49    error::SdkError,
50    operation::get_object::{
51        GetObjectError, 
52        GetObjectOutput, 
53        builders::GetObjectFluentBuilder
54    },
55};
56use axum::response::IntoResponse;
57use std::{
58    convert::Infallible,
59    future::Future,
60    pin::Pin,
61    task::{Context, Poll},
62};
63use tower_service::Service;
64
65
66#[cfg(feature = "trace")]
67#[allow(unused_imports)]
68use tracing::{info, error};
69#[cfg(feature = "trace")]
70#[allow(unused_imports)]
71use tracing::Instrument;
72
// Without the `trace` feature, `info!` is a no-op: its tokens are discarded, so the
// arguments are never evaluated. It expands to a bare `()` (not `();`) so that it is also
// valid in expression position without triggering `semicolon_in_expressions_from_macros`.
// May be unused: the current call sites use `tracing::info!` behind `#[cfg]` directly.
#[allow(unused_macros)]
#[cfg(not(feature = "trace"))]
macro_rules! info {
    ($($arg:tt)*) => {
        ()
    };
}
81
82mod adapter;
83use adapter::TryStreamAdapater;
84
85mod builder;
86pub use builder::S3OriginBuilder;
87
88#[derive(Clone)]
89pub(crate) struct S3OriginInner {
90    bucket: String,
91    bucket_prefix: String,
92    s3_client: Arc<S3Client>,
93    prune_path: usize,
94    max_size: Option<i64>,
95}
96
97#[derive(Clone)]
98pub struct S3Origin {
99    inner: Arc<S3OriginInner>,
100}
101
102
103/// Takes a request and trims the paths and creates a new S3 key
104fn request_to_key(bucket_prefix: &str, uri_path: &str, prune_path: usize) -> String {
105    let request_path: String = match prune_path {
106        0 => uri_path.to_string(),
107        _ => uri_path.split('/').skip(prune_path).collect::<Vec<_>>().join("/"),
108    };
109
110    format!("{}{}", bucket_prefix, request_path.trim_start_matches('/'))
111}
112
113
114impl Service<axum::extract::Request> for S3Origin {
115    type Error = Infallible;
116    type Response = axum::response::Response<axum::body::Body>;
117    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static >>;
118
119    /// Always ready to serve; no backpressure.
120    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121        Poll::Ready(Ok(()))
122    }
123
124    /// Serve the request.
125    fn call(&mut self, req: axum::extract::Request) -> Self::Future {
126        #[cfg(feature = "trace")]
127        tracing::info!("S3Origin: Serving request");
128
129        // Only GET requests are supported
130        if req.method() != axum::http::Method::GET {
131            #[cfg(feature = "trace")]
132            tracing::info!("S3Origin: {} method not allowed", req.method());
133
134            return Box::pin(async move {
135                Ok(axum::response::Response::builder().status(axum::http::StatusCode::METHOD_NOT_ALLOWED).body(axum::body::Body::from("Method not allowed")).unwrap())
136            });
137        }
138
139        let this = self.inner.clone();
140        let path = req.uri().path();
141        let path = path.strip_prefix("/").unwrap_or(path);
142
143        let mut path = path.to_string();
144
145        if this.prune_path > 0 {
146            path = path.split('/').skip(this.prune_path).collect::<Vec<_>>().join("/");
147        }
148
149        let client = this.s3_client.clone();
150        let key = request_to_key(&this.bucket_prefix, &path, this.prune_path);
151
152        #[cfg(feature = "trace")]
153        {
154            let current_span = tracing::Span::current();
155            current_span.record("s3_url", &format!("s3://{}/{}", this.bucket, key));
156        }
157
158        let get_s3_fut = async move {
159            let builder = client.get_object()
160                .bucket(&this.bucket)
161                .key(&key);
162            let builder = make_request_builder(&req, builder);
163
164            let response;
165            #[cfg(feature = "trace")]
166            {
167                response = builder.send()
168                    .instrument(
169                        tracing::info_span!("s3_get_object", bucket = %this.bucket, key = %key)
170                    ).await;
171            }
172            #[cfg(not(feature = "trace"))]
173            {
174                response = builder.send().await;
175            }
176            
177            let rv = wrap_create_response(response, this.max_size)
178                .unwrap_or_else(|e| {
179                    e.into_response()
180            });
181
182            Ok(rv)
183        };
184
185        Box::pin(get_s3_fut)
186    }
187}
188
189
190fn make_request_builder(request: &axum::extract::Request, mut builder: GetObjectFluentBuilder) -> GetObjectFluentBuilder {
191    // Check if there is a range header
192    if let Some(range) = request.headers().get(axum::http::header::RANGE) {
193        builder = builder.range(range.to_str().unwrap());
194    }
195    builder
196}
197
198
199fn wrap_create_response<E>(s3_response: Result<GetObjectOutput, SdkError<GetObjectError, E>>, max_size: Option<i64>) -> Result<axum::response::Response, S3Error> {
200    #[cfg(feature = "trace")]
201    {
202        tracing::debug!("S3Origin: Wrapping response: {}",
203            if s3_response.is_ok() { "OK".to_owned() } else { format!("Error: {}", s3_response.as_ref().err().unwrap().to_string()) }
204        );
205    }
206
207    // Unwrap the response from S3, mapping to an S3Error if there is an error
208    let s3_response = s3_response.map_err(S3Error::from)?;
209
210    // Response was successful, so we can collect metadata
211    let content_type = s3_response.content_type().map(|ct| ct.to_owned());
212    let content_length = s3_response.content_length().map(|cl| cl.to_owned());
213
214    if let Some(max_size) = max_size {
215        if let Some(size) = content_length.as_ref() {
216            if size > &max_size {
217                return Err(S3Error::MaxSizeExceeded);
218            }
219        }
220    }
221
222    let body = TryStreamAdapater { stream: s3_response.body.into_async_read()};
223    let body = axum::body::Body::from_stream(body);
224    let mut response = axum::response::Response::builder()
225        .status(200)
226        .body(body)
227        .unwrap(); // Safe to unwrap because we know the response is Ok and no headers are set
228
229    // set Content-Type
230    if let Some(content_type) = content_type {
231        response.headers_mut().insert(
232            axum::http::header::CONTENT_TYPE,
233            content_type
234                .parse()
235                .map_err(|_| S3Error::InternalServerError)?
236                );
237    } else {
238        response.headers_mut().insert(axum::http::header::CONTENT_TYPE, "application/octet-stream".parse().unwrap());  // UNWRAP: Safe value
239    }
240    // set Content-Length
241    if let Some(content_length) = content_length {
242        response.headers_mut().insert(axum::http::header::CONTENT_LENGTH, content_length.to_string().parse().unwrap());  // UNWRAP: Safe value
243    }
244
245    Ok(response)
246}
247
248
249impl<E> From<SdkError<GetObjectError, E>> for S3Error {
250    fn from(error: SdkError<GetObjectError, E>) -> Self {
251        match error {
252            SdkError::ServiceError(error) => {
253                if error.err().is_no_such_key() {
254                    S3Error::NotFound
255                } else {
256                    S3Error::BadGateway
257                }
258            }
259            _ => S3Error::InternalServerError,
260        }
261    }
262}
263
264impl axum::response::IntoResponse for S3Error {
265    fn into_response(self) -> axum::response::Response {
266        #[warn(unreachable_patterns)]
267        match self {
268            S3Error::NotFound => axum::response::Response::builder().status(axum::http::StatusCode::NOT_FOUND).body(axum::body::Body::from("Not found")).unwrap(),
269            S3Error::BadGateway => axum::response::Response::builder().status(axum::http::StatusCode::BAD_GATEWAY).body(axum::body::Body::from("Bad gateway")).unwrap(),
270            S3Error::InternalServerError => axum::response::Response::builder().status(axum::http::StatusCode::INTERNAL_SERVER_ERROR).body(axum::body::Body::from("Internal server error")).unwrap(),
271            S3Error::MaxSizeExceeded => axum::response::Response::builder().status(axum::http::StatusCode::PAYLOAD_TOO_LARGE).body(axum::body::Body::from("Requested file size exceeds the maximum allowed size")).unwrap(),
272        }
273    }
274}
275
276
/// Failure modes of serving an object, each rendered as an HTTP error response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum S3Error {
    /// S3 reported `NoSuchKey`; rendered as `404 Not Found`.
    NotFound,
    /// S3 returned any other service error; rendered as `502 Bad Gateway`.
    BadGateway,
    /// Any other failure, such as a non-service SDK error or response metadata that is not
    /// a valid header value; rendered as `500 Internal Server Error`.
    InternalServerError,
    /// The object's reported size exceeds the configured maximum; rendered as
    /// `413 Payload Too Large`.
    MaxSizeExceeded,
}
283
284
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time trait checks: each call fails to compile if the bound regresses.
    fn assert_clone<T: Clone>(_: &T) {}
    fn assert_send<T: Send>(_: &T) {}
    fn assert_sync<T: Sync>(_: &T) {}
    fn assert_service<T: Service<axum::extract::Request>>(_: &T) {}

    #[test]
    fn can_route_to_s3_origin() {
        use axum::Router;
        let origin = S3OriginBuilder::new()
            .bucket("my-bucket")
            .prefix("my-prefix")
            .build()
            .unwrap();

        assert_clone(&origin);
        assert_send(&origin);
        assert_sync(&origin);
        assert_service(&origin);

        let _app = Router::<()>::new().nest_service("/static", origin);
    }

    #[test]
    fn request_to_key_prepends_prefix_and_prunes() {
        assert_eq!(request_to_key("static/", "index.html", 0), "static/index.html");
        assert_eq!(
            request_to_key("static/", "assets/css/site.css", 1),
            "static/css/site.css"
        );
    }

    #[test]
    fn test_nest_route_route() {
        use axum::{Router, routing::get};
        let subroute: Router<()> = Router::new().route("/", get(|| async { "Hello, world!" }));
        let _app = Router::new().nest("/foo", subroute);
    }
}