wasm_runner_sdk/
extract.rs

1//! Request extractors for handler functions.
2//!
3//! Extractors allow you to extract data from requests in a type-safe way.
4//! They implement the `FromRequest` trait which is used by the handler system.
5
6use crate::request::Request;
7use crate::response::{IntoResponse, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use std::collections::HashMap;
10
11/// Trait for types that can be extracted from a request.
12///
13/// Implement this trait to create custom extractors.
14pub trait FromRequest: Sized {
15    /// The error type returned when extraction fails.
16    type Error: IntoResponse;
17
18    /// Attempts to extract this type from the request.
19    fn from_request(req: &mut Request) -> Result<Self, Self::Error>;
20}
21
22/// Error type for extraction failures.
23#[derive(Debug, Clone)]
24pub struct ExtractError {
25    status: StatusCode,
26    message: String,
27}
28
29impl ExtractError {
30    /// Creates a new extraction error.
31    pub fn new(status: StatusCode, message: impl Into<String>) -> Self {
32        Self {
33            status,
34            message: message.into(),
35        }
36    }
37
38    /// Creates a bad request error.
39    pub fn bad_request(message: impl Into<String>) -> Self {
40        Self::new(StatusCode::BAD_REQUEST, message)
41    }
42
43    /// Creates an unprocessable entity error.
44    pub fn unprocessable(message: impl Into<String>) -> Self {
45        Self::new(StatusCode::UNPROCESSABLE_ENTITY, message)
46    }
47}
48
49impl IntoResponse for ExtractError {
50    fn into_response(self) -> Response {
51        Response::new(self.status)
52            .with_content_type("application/json")
53            .with_body(format!(r#"{{"error":"{}"}}"#, self.message).into_bytes())
54    }
55}
56
57impl std::fmt::Display for ExtractError {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        write!(f, "{}: {}", self.status, self.message)
60    }
61}
62
63impl std::error::Error for ExtractError {}
64
/// Deserializes the request's query parameters into `T`.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct Pagination {
///     page: Option<u32>,
///     limit: Option<u32>,
/// }
///
/// fn handler(Query(params): Query<Pagination>) -> impl IntoResponse {
///     format!("Page: {:?}, Limit: {:?}", params.page, params.limit)
/// }
/// ```
#[derive(Clone, Debug)]
pub struct Query<T>(pub T);
81
82impl<T: DeserializeOwned> FromRequest for Query<T> {
83    type Error = ExtractError;
84
85    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
86        let pairs = req.query_all();
87
88        // Build a simple object for serde to deserialize
89        // This handles both flat structures and optional fields
90        let query_string: String = pairs
91            .iter()
92            .map(|(k, v)| format!("{}={}", k, v))
93            .collect::<Vec<_>>()
94            .join("&");
95
96        serde_urlencoded::from_str(&query_string)
97            .map(Query)
98            .map_err(|e| ExtractError::bad_request(format!("Invalid query parameters: {}", e)))
99    }
100}
101
/// Extracts path segments positionally, into a `String` or a tuple of
/// `String`s.
///
/// NOTE(review): the example below assumes `Request::path_segments` yields
/// only the captured route parameters. If it yields every segment of the
/// path, `user_id` would be `"users"` here — confirm.
///
/// # Example
/// ```ignore
/// // For path "/users/123/posts/456"
/// fn handler(Path((user_id, post_id)): Path<(String, String)>) -> impl IntoResponse {
///     format!("User: {}, Post: {}", user_id, post_id)
/// }
/// ```
#[derive(Clone, Debug)]
pub struct Path<T>(pub T);
113
114// Implement for single value
115impl FromRequest for Path<String> {
116    type Error = ExtractError;
117
118    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
119        let segments = req.path_segments();
120        if segments.is_empty() {
121            return Err(ExtractError::bad_request("No path segments available"));
122        }
123        Ok(Path(segments[0].clone()))
124    }
125}
126
127// Implement for tuples up to 6 elements
128impl FromRequest for Path<(String,)> {
129    type Error = ExtractError;
130
131    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
132        let segments = req.path_segments();
133        if segments.is_empty() {
134            return Err(ExtractError::bad_request("Expected 1 path segment, got 0"));
135        }
136        Ok(Path((segments[0].clone(),)))
137    }
138}
139
140impl FromRequest for Path<(String, String)> {
141    type Error = ExtractError;
142
143    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
144        let segments = req.path_segments();
145        if segments.len() < 2 {
146            return Err(ExtractError::bad_request(format!(
147                "Expected 2 path segments, got {}",
148                segments.len()
149            )));
150        }
151        Ok(Path((segments[0].clone(), segments[1].clone())))
152    }
153}
154
155impl FromRequest for Path<(String, String, String)> {
156    type Error = ExtractError;
157
158    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
159        let segments = req.path_segments();
160        if segments.len() < 3 {
161            return Err(ExtractError::bad_request(format!(
162                "Expected 3 path segments, got {}",
163                segments.len()
164            )));
165        }
166        Ok(Path((
167            segments[0].clone(),
168            segments[1].clone(),
169            segments[2].clone(),
170        )))
171    }
172}
173
174impl FromRequest for Path<(String, String, String, String)> {
175    type Error = ExtractError;
176
177    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
178        let segments = req.path_segments();
179        if segments.len() < 4 {
180            return Err(ExtractError::bad_request(format!(
181                "Expected 4 path segments, got {}",
182                segments.len()
183            )));
184        }
185        Ok(Path((
186            segments[0].clone(),
187            segments[1].clone(),
188            segments[2].clone(),
189            segments[3].clone(),
190        )))
191    }
192}
193
194/// Extracts all path segments as a vector.
195#[derive(Debug, Clone)]
196pub struct PathSegments(pub Vec<String>);
197
198impl FromRequest for PathSegments {
199    type Error = ExtractError;
200
201    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
202        Ok(PathSegments(req.path_segments().to_vec()))
203    }
204}
205
/// Deserializes the request body as JSON into `T`.
///
/// # Example
/// ```ignore
/// #[derive(Deserialize)]
/// struct CreateUser {
///     name: String,
///     email: String,
/// }
///
/// fn handler(Json(user): Json<CreateUser>) -> impl IntoResponse {
///     format!("Creating user: {}", user.name)
/// }
/// ```
#[derive(Clone, Debug)]
pub struct Json<T>(pub T);
222
223impl<T: DeserializeOwned> FromRequest for Json<T> {
224    type Error = ExtractError;
225
226    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
227        let body = req.body();
228        serde_json::from_slice(body)
229            .map(Json)
230            .map_err(|e| ExtractError::unprocessable(format!("Invalid JSON: {}", e)))
231    }
232}
233
234/// Extracts the raw request body as bytes.
235#[derive(Debug, Clone)]
236pub struct Body(pub Vec<u8>);
237
238impl FromRequest for Body {
239    type Error = ExtractError;
240
241    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
242        Ok(Body(req.body().to_vec()))
243    }
244}
245
246/// Extracts the request body as a string.
247#[derive(Debug, Clone)]
248pub struct BodyString(pub String);
249
250impl FromRequest for BodyString {
251    type Error = ExtractError;
252
253    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
254        let body = req.body();
255        String::from_utf8(body.to_vec())
256            .map(BodyString)
257            .map_err(|e| ExtractError::bad_request(format!("Invalid UTF-8 body: {}", e)))
258    }
259}
260
261/// Extracts all headers as a HashMap.
262#[derive(Debug, Clone)]
263pub struct Headers(pub HashMap<String, String>);
264
265impl FromRequest for Headers {
266    type Error = ExtractError;
267
268    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
269        Ok(Headers(req.headers().clone()))
270    }
271}
272
273/// Extracts the HTTP method.
274#[derive(Debug, Clone)]
275pub struct MethodExtractor(pub crate::request::Method);
276
277impl FromRequest for MethodExtractor {
278    type Error = ExtractError;
279
280    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
281        Ok(MethodExtractor(req.method()))
282    }
283}
284
285/// Extracts the Content-Type header.
286#[derive(Debug, Clone)]
287pub struct ContentType(pub Option<String>);
288
289impl FromRequest for ContentType {
290    type Error = ExtractError;
291
292    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
293        Ok(ContentType(req.content_type().map(|s| s.to_string())))
294    }
295}
296
297/// Extracts the Authorization header.
298#[derive(Debug, Clone)]
299pub struct Authorization(pub Option<String>);
300
301impl FromRequest for Authorization {
302    type Error = ExtractError;
303
304    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
305        Ok(Authorization(req.header("authorization").map(|s| s.to_string())))
306    }
307}
308
309/// Extracts a Bearer token from the Authorization header.
310#[derive(Debug, Clone)]
311pub struct BearerToken(pub String);
312
313impl FromRequest for BearerToken {
314    type Error = ExtractError;
315
316    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
317        let auth = req
318            .header("authorization")
319            .ok_or_else(|| ExtractError::new(StatusCode::UNAUTHORIZED, "Missing Authorization header"))?;
320
321        if !auth.starts_with("Bearer ") {
322            return Err(ExtractError::new(
323                StatusCode::UNAUTHORIZED,
324                "Invalid Authorization header format, expected Bearer token",
325            ));
326        }
327
328        Ok(BearerToken(auth[7..].to_string()))
329    }
330}
331
332/// Extracts all cookies as a HashMap.
333#[derive(Debug, Clone)]
334pub struct Cookies(pub HashMap<String, String>);
335
336impl FromRequest for Cookies {
337    type Error = ExtractError;
338
339    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
340        Ok(Cookies(req.cookies().clone()))
341    }
342}
343
344/// Extracts the full request path.
345#[derive(Debug, Clone)]
346pub struct FullPath(pub String);
347
348impl FromRequest for FullPath {
349    type Error = ExtractError;
350
351    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
352        Ok(FullPath(req.path().to_string()))
353    }
354}
355
356/// Extracts all query parameters as a HashMap.
357#[derive(Debug, Clone)]
358pub struct QueryMap(pub HashMap<String, String>);
359
360impl FromRequest for QueryMap {
361    type Error = ExtractError;
362
363    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
364        Ok(QueryMap(req.query().clone()))
365    }
366}
367
368/// Optional extractor - wraps another extractor and returns None on failure.
369impl<T: FromRequest> FromRequest for Option<T> {
370    type Error = ExtractError;
371
372    fn from_request(req: &mut Request) -> Result<Self, Self::Error> {
373        Ok(T::from_request(req).ok())
374    }
375}