Skip to main content

dbrest_core/api_request/
payload.rs

//! Request body (payload) parsing.
//!
//! Handles HTTP request payload parsing and validation: JSON, URL-encoded, and raw
//! payloads are parsed according to the request's content type and action.
5
6use bytes::Bytes;
7use compact_str::CompactString;
8use std::collections::HashSet;
9
10use crate::error::Error;
11use crate::types::media::MediaType;
12
13use super::query_params::QueryParams;
14use super::types::{Action, DbAction, InvokeMethod, Mutation, Payload};
15
16/// Parse the request body into a Payload based on content type and action.
17///
18/// Parse and validate the request payload.
19///
20/// Returns: `(Option<Payload>, columns)` where columns is the set of column names
21/// derived from either the payload keys or the &columns parameter.
22pub fn get_payload(
23    body: Bytes,
24    content_type: &MediaType,
25    query_params: &QueryParams,
26    action: &Action,
27) -> Result<(Option<Payload>, HashSet<CompactString>), Error> {
28    if !should_parse_payload(action) {
29        return Ok((None, HashSet::new()));
30    }
31
32    let is_proc = is_procedure(action);
33    let columns_param = &query_params.columns;
34
35    let payload = parse_payload(&body, content_type, is_proc, columns_param)?;
36
37    let cols = match (&payload, get_action_columns(action, &query_params.columns)) {
38        (Some(Payload::ProcessedJSON { keys, .. }), _) => keys.clone(),
39        (Some(Payload::ProcessedUrlEncoded { keys, .. }), _) => keys.clone(),
40        (Some(Payload::RawJSON(_)), Some(cls)) => cls.clone(),
41        _ => HashSet::new(),
42    };
43
44    Ok((payload, cols))
45}
46
47fn should_parse_payload(action: &Action) -> bool {
48    matches!(
49        action,
50        Action::Db(DbAction::RelationMut {
51            mutation: Mutation::MutationCreate
52                | Mutation::MutationUpdate
53                | Mutation::MutationSingleUpsert,
54            ..
55        }) | Action::Db(DbAction::Routine {
56            inv_method: InvokeMethod::Inv,
57            ..
58        })
59    )
60}
61
62fn is_procedure(action: &Action) -> bool {
63    matches!(action, Action::Db(DbAction::Routine { .. }))
64}
65
66fn get_action_columns<'a>(
67    action: &Action,
68    columns: &'a Option<HashSet<CompactString>>,
69) -> Option<&'a HashSet<CompactString>> {
70    match action {
71        Action::Db(DbAction::RelationMut {
72            mutation: Mutation::MutationCreate | Mutation::MutationUpdate,
73            ..
74        })
75        | Action::Db(DbAction::Routine {
76            inv_method: InvokeMethod::Inv,
77            ..
78        }) => columns.as_ref(),
79        _ => None,
80    }
81}
82
83fn parse_payload(
84    body: &Bytes,
85    content_type: &MediaType,
86    is_proc: bool,
87    columns_param: &Option<HashSet<CompactString>>,
88) -> Result<Option<Payload>, Error> {
89    match (content_type, is_proc) {
90        (MediaType::ApplicationJson, _) => {
91            if columns_param.is_some() {
92                // When &columns is specified, pass raw JSON through
93                Ok(Some(Payload::RawJSON(body.clone())))
94            } else {
95                parse_json_payload(body, is_proc)
96            }
97        }
98        (MediaType::ApplicationFormUrlEncoded, true) => {
99            // URL-encoded for RPC
100            let params: Vec<(CompactString, CompactString)> = form_urlencoded::parse(body)
101                .map(|(k, v)| {
102                    (
103                        CompactString::from(k.as_ref()),
104                        CompactString::from(v.as_ref()),
105                    )
106                })
107                .collect();
108            let keys: HashSet<CompactString> = params.iter().map(|(k, _)| k.clone()).collect();
109            Ok(Some(Payload::ProcessedUrlEncoded { params, keys }))
110        }
111        (MediaType::ApplicationFormUrlEncoded, false) => {
112            // URL-encoded for non-RPC: convert to JSON-like structure
113            let params: Vec<(CompactString, CompactString)> = form_urlencoded::parse(body)
114                .map(|(k, v)| {
115                    (
116                        CompactString::from(k.as_ref()),
117                        CompactString::from(v.as_ref()),
118                    )
119                })
120                .collect();
121            let keys: HashSet<CompactString> = params.iter().map(|(k, _)| k.clone()).collect();
122            // Build JSON from params
123            let json_map: serde_json::Map<String, serde_json::Value> = params
124                .iter()
125                .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string())))
126                .collect();
127            let raw =
128                serde_json::to_vec(&json_map).map_err(|e| Error::InvalidBody(e.to_string()))?;
129            Ok(Some(Payload::ProcessedJSON {
130                raw: Bytes::from(raw),
131                keys,
132            }))
133        }
134        (MediaType::TextPlain, true)
135        | (MediaType::ApplicationXml, true)
136        | (MediaType::ApplicationOctetStream, true) => Ok(Some(Payload::RawPayload(body.clone()))),
137        (ct, _) => Err(Error::InvalidContentType(format!(
138            "Content-Type not acceptable: {}",
139            ct
140        ))),
141    }
142}
143
144fn parse_json_payload(body: &Bytes, is_proc: bool) -> Result<Option<Payload>, Error> {
145    if body.is_empty() && is_proc {
146        // Empty body for RPC is treated as empty object
147        let keys = HashSet::new();
148        return Ok(Some(Payload::ProcessedJSON {
149            raw: Bytes::from_static(b"{}"),
150            keys,
151        }));
152    }
153
154    if body.is_empty() {
155        return Err(Error::InvalidBody("Empty or invalid json".to_string()));
156    }
157
158    let parsed: serde_json::Value = serde_json::from_slice(body)
159        .map_err(|_| Error::InvalidBody("Empty or invalid json".to_string()))?;
160
161    match &parsed {
162        serde_json::Value::Array(arr) => {
163            if arr.is_empty() {
164                return Ok(Some(Payload::ProcessedJSON {
165                    raw: Bytes::from_static(b"[]"),
166                    keys: HashSet::new(),
167                }));
168            }
169
170            // Check that all objects have the same keys
171            if let Some(serde_json::Value::Object(first)) = arr.first() {
172                let canonical_keys: HashSet<CompactString> = first
173                    .keys()
174                    .map(|k| CompactString::from(k.as_str()))
175                    .collect();
176
177                let uniform = arr.iter().all(|item| {
178                    if let serde_json::Value::Object(obj) = item {
179                        let item_keys: HashSet<CompactString> = obj
180                            .keys()
181                            .map(|k| CompactString::from(k.as_str()))
182                            .collect();
183                        item_keys == canonical_keys
184                    } else {
185                        false
186                    }
187                });
188
189                if uniform {
190                    Ok(Some(Payload::ProcessedJSON {
191                        raw: body.clone(),
192                        keys: canonical_keys,
193                    }))
194                } else {
195                    Err(Error::InvalidBody("All object keys must match".to_string()))
196                }
197            } else {
198                Err(Error::InvalidBody("All object keys must match".to_string()))
199            }
200        }
201        serde_json::Value::Object(obj) => {
202            let keys: HashSet<CompactString> = obj
203                .keys()
204                .map(|k| CompactString::from(k.as_str()))
205                .collect();
206            Ok(Some(Payload::ProcessedJSON {
207                raw: body.clone(),
208                keys,
209            }))
210        }
211        _ => {
212            // Non-object, non-array: treat as empty array
213            Ok(Some(Payload::ProcessedJSON {
214                raw: Bytes::from_static(b"[]"),
215                keys: HashSet::new(),
216            }))
217        }
218    }
219}
220
221// ==========================================================================
222// Tests
223// ==========================================================================
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::identifiers::QualifiedIdentifier;

    fn create_action() -> Action {
        Action::Db(DbAction::RelationMut {
            qi: QualifiedIdentifier::new("public", "items"),
            mutation: Mutation::MutationCreate,
        })
    }

    fn rpc_action() -> Action {
        Action::Db(DbAction::Routine {
            qi: QualifiedIdentifier::new("public", "my_func"),
            inv_method: InvokeMethod::Inv,
        })
    }

    fn read_action() -> Action {
        Action::Db(DbAction::RelationRead {
            qi: QualifiedIdentifier::new("public", "items"),
            headers_only: false,
        })
    }

    fn default_qp() -> QueryParams {
        QueryParams::default()
    }

    /// Query params with `&columns` set to `names`.
    fn qp_with_columns(names: &[&str]) -> QueryParams {
        let mut qp = default_qp();
        qp.columns = Some(names.iter().map(|n| CompactString::from(*n)).collect());
        qp
    }

    /// Unwrap a `ProcessedJSON` payload into its raw bytes and key set.
    fn expect_processed_json(payload: Option<Payload>) -> (Bytes, HashSet<CompactString>) {
        match payload {
            Some(Payload::ProcessedJSON { raw, keys }) => (raw, keys),
            _ => panic!("expected a ProcessedJSON payload"),
        }
    }

    #[test]
    fn test_json_object_payload() {
        let body = Bytes::from(r#"{"id":1,"name":"test"}"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(cols.contains("id"));
        assert!(cols.contains("name"));
        assert!(matches!(payload, Payload::ProcessedJSON { .. }));
    }

    #[test]
    fn test_json_array_payload() {
        let body = Bytes::from(r#"[{"id":1,"name":"a"},{"id":2,"name":"b"}]"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(matches!(payload, Payload::ProcessedJSON { .. }));
    }

    #[test]
    fn test_json_array_key_order_does_not_matter() {
        let body = Bytes::from(r#"[{"id":1,"name":"a"},{"name":"b","id":2}]"#);
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &default_qp(), &create_action())
                .unwrap();
        let (_, keys) = expect_processed_json(payload);
        assert_eq!(keys.len(), 2);
        assert_eq!(cols.len(), 2);
        assert!(cols.contains("id"));
        assert!(cols.contains("name"));
    }

    #[test]
    fn test_json_array_non_uniform_keys() {
        let body = Bytes::from(r#"[{"id":1},{"name":"b"}]"#);
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_json_array_key_superset_rejected() {
        let body = Bytes::from(r#"[{"id":1},{"id":2,"name":"b"}]"#);
        let result =
            get_payload(body, &MediaType::ApplicationJson, &default_qp(), &create_action());
        assert!(matches!(result, Err(Error::InvalidBody(_))));
    }

    #[test]
    fn test_json_array_first_element_not_object() {
        let body = Bytes::from(r#"[1,{"id":1}]"#);
        let result =
            get_payload(body, &MediaType::ApplicationJson, &default_qp(), &create_action());
        assert!(matches!(result, Err(Error::InvalidBody(_))));
    }

    #[test]
    fn test_json_array_later_element_not_object() {
        let body = Bytes::from(r#"[{"id":1},2]"#);
        let result =
            get_payload(body, &MediaType::ApplicationJson, &default_qp(), &create_action());
        assert!(matches!(result, Err(Error::InvalidBody(_))));
    }

    #[test]
    fn test_json_scalar_becomes_empty_array() {
        for &body in &["42", "\"text\"", "true", "null"] {
            let (payload, cols) = get_payload(
                Bytes::from(body),
                &MediaType::ApplicationJson,
                &default_qp(),
                &create_action(),
            )
            .unwrap();
            let (raw, keys) = expect_processed_json(payload);
            assert_eq!(&raw[..], b"[]");
            assert!(keys.is_empty());
            assert!(cols.is_empty());
        }
    }

    #[test]
    fn test_empty_json_for_rpc() {
        let body = Bytes::new();
        let qp = default_qp();
        let (payload, _) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &rpc_action()).unwrap();
        assert!(payload.is_some());
    }

    #[test]
    fn test_empty_json_for_rpc_is_empty_object() {
        let (payload, cols) =
            get_payload(Bytes::new(), &MediaType::ApplicationJson, &default_qp(), &rpc_action())
                .unwrap();
        let (raw, keys) = expect_processed_json(payload);
        assert_eq!(&raw[..], b"{}");
        assert!(keys.is_empty());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_empty_json_non_rpc_error() {
        let body = Bytes::new();
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_json() {
        let body = Bytes::from("not json");
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_url_encoded_rpc() {
        let body = Bytes::from("id=1&name=test");
        let qp = default_qp();
        let (payload, cols) = get_payload(
            body,
            &MediaType::ApplicationFormUrlEncoded,
            &qp,
            &rpc_action(),
        )
        .unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(matches!(payload, Payload::ProcessedUrlEncoded { .. }));
    }

    #[test]
    fn test_url_encoded_non_rpc() {
        let body = Bytes::from("id=1&name=test");
        let qp = default_qp();
        let (payload, cols) = get_payload(
            body,
            &MediaType::ApplicationFormUrlEncoded,
            &qp,
            &create_action(),
        )
        .unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(matches!(payload, Payload::ProcessedJSON { .. }));
    }

    #[test]
    fn test_url_encoded_non_rpc_duplicate_keys_last_wins() {
        let (payload, cols) = get_payload(
            Bytes::from("a=1&a=2"),
            &MediaType::ApplicationFormUrlEncoded,
            &default_qp(),
            &create_action(),
        )
        .unwrap();
        let (raw, keys) = expect_processed_json(payload);
        assert_eq!(&raw[..], br#"{"a":"2"}"#);
        assert_eq!(keys.len(), 1);
        assert_eq!(cols.len(), 1);
    }

    #[test]
    fn test_url_encoded_ignores_columns_param() {
        // `&columns` only switches JSON bodies to raw pass-through.
        let qp = qp_with_columns(&["x"]);
        let (payload, cols) = get_payload(
            Bytes::from("id=1"),
            &MediaType::ApplicationFormUrlEncoded,
            &qp,
            &rpc_action(),
        )
        .unwrap();
        assert!(matches!(payload, Some(Payload::ProcessedUrlEncoded { .. })));
        assert_eq!(cols.len(), 1);
        assert!(cols.contains("id"));
    }

    #[test]
    fn test_raw_payload_rpc() {
        let body = Bytes::from("raw text content");
        let qp = default_qp();
        let (payload, _) = get_payload(body, &MediaType::TextPlain, &qp, &rpc_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawPayload(_)));
    }

    #[test]
    fn test_octet_stream_rpc() {
        let body = Bytes::from(vec![0u8, 1, 2, 3]);
        let qp = default_qp();
        let (payload, _) =
            get_payload(body, &MediaType::ApplicationOctetStream, &qp, &rpc_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawPayload(_)));
    }

    #[test]
    fn test_raw_content_types_rejected_for_non_rpc() {
        for ct in &[
            MediaType::TextPlain,
            MediaType::ApplicationXml,
            MediaType::ApplicationOctetStream,
        ] {
            let result = get_payload(Bytes::from("data"), ct, &default_qp(), &create_action());
            assert!(matches!(result, Err(Error::InvalidContentType(_))));
        }
    }

    #[test]
    fn test_unsupported_content_type() {
        let body = Bytes::from("data");
        let qp = default_qp();
        let result = get_payload(body, &MediaType::TextCsv, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_no_payload_for_read() {
        let body = Bytes::from("data");
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &read_action()).unwrap();
        assert!(payload.is_none());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_raw_json_with_columns() {
        let body = Bytes::from(r#"{"id":1,"name":"test"}"#);
        let mut qp = default_qp();
        let mut cols_set = HashSet::new();
        cols_set.insert(CompactString::from("id"));
        cols_set.insert(CompactString::from("name"));
        qp.columns = Some(cols_set.clone());

        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawJSON(_)));
        assert_eq!(cols.len(), 2);
    }

    #[test]
    fn test_raw_json_with_columns_for_rpc() {
        let qp = qp_with_columns(&["a", "b"]);
        let (payload, cols) = get_payload(
            Bytes::from(r#"{"a":1}"#),
            &MediaType::ApplicationJson,
            &qp,
            &rpc_action(),
        )
        .unwrap();
        assert!(matches!(payload, Some(Payload::RawJSON(_))));
        assert_eq!(cols.len(), 2);
        assert!(cols.contains("a"));
        assert!(cols.contains("b"));
    }

    #[test]
    fn test_empty_json_array() {
        let body = Bytes::from("[]");
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        assert!(payload.is_some());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_payload_keys() {
        let body = Bytes::from(r#"{"a":1,"b":2,"c":3}"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 3);
        assert_eq!(payload.keys().len(), 3);
    }
}