1use bytes::Bytes;
7use compact_str::CompactString;
8use std::collections::HashSet;
9
10use crate::error::Error;
11use crate::types::media::MediaType;
12
13use super::query_params::QueryParams;
14use super::types::{Action, DbAction, InvokeMethod, Mutation, Payload};
15
16pub fn get_payload(
23 body: Bytes,
24 content_type: &MediaType,
25 query_params: &QueryParams,
26 action: &Action,
27) -> Result<(Option<Payload>, HashSet<CompactString>), Error> {
28 if !should_parse_payload(action) {
29 return Ok((None, HashSet::new()));
30 }
31
32 let is_proc = is_procedure(action);
33 let columns_param = &query_params.columns;
34
35 let payload = parse_payload(&body, content_type, is_proc, columns_param)?;
36
37 let cols = match (&payload, get_action_columns(action, &query_params.columns)) {
38 (Some(Payload::ProcessedJSON { keys, .. }), _) => keys.clone(),
39 (Some(Payload::ProcessedUrlEncoded { keys, .. }), _) => keys.clone(),
40 (Some(Payload::RawJSON(_)), Some(cls)) => cls.clone(),
41 _ => HashSet::new(),
42 };
43
44 Ok((payload, cols))
45}
46
47fn should_parse_payload(action: &Action) -> bool {
48 matches!(
49 action,
50 Action::Db(DbAction::RelationMut {
51 mutation: Mutation::MutationCreate
52 | Mutation::MutationUpdate
53 | Mutation::MutationSingleUpsert,
54 ..
55 }) | Action::Db(DbAction::Routine {
56 inv_method: InvokeMethod::Inv,
57 ..
58 })
59 )
60}
61
62fn is_procedure(action: &Action) -> bool {
63 matches!(action, Action::Db(DbAction::Routine { .. }))
64}
65
66fn get_action_columns<'a>(
67 action: &Action,
68 columns: &'a Option<HashSet<CompactString>>,
69) -> Option<&'a HashSet<CompactString>> {
70 match action {
71 Action::Db(DbAction::RelationMut {
72 mutation: Mutation::MutationCreate | Mutation::MutationUpdate,
73 ..
74 })
75 | Action::Db(DbAction::Routine {
76 inv_method: InvokeMethod::Inv,
77 ..
78 }) => columns.as_ref(),
79 _ => None,
80 }
81}
82
83fn parse_payload(
84 body: &Bytes,
85 content_type: &MediaType,
86 is_proc: bool,
87 columns_param: &Option<HashSet<CompactString>>,
88) -> Result<Option<Payload>, Error> {
89 match (content_type, is_proc) {
90 (MediaType::ApplicationJson, _) => {
91 if columns_param.is_some() {
92 Ok(Some(Payload::RawJSON(body.clone())))
94 } else {
95 parse_json_payload(body, is_proc)
96 }
97 }
98 (MediaType::ApplicationFormUrlEncoded, true) => {
99 let params: Vec<(CompactString, CompactString)> = form_urlencoded::parse(body)
101 .map(|(k, v)| {
102 (
103 CompactString::from(k.as_ref()),
104 CompactString::from(v.as_ref()),
105 )
106 })
107 .collect();
108 let keys: HashSet<CompactString> = params.iter().map(|(k, _)| k.clone()).collect();
109 Ok(Some(Payload::ProcessedUrlEncoded { params, keys }))
110 }
111 (MediaType::ApplicationFormUrlEncoded, false) => {
112 let params: Vec<(CompactString, CompactString)> = form_urlencoded::parse(body)
114 .map(|(k, v)| {
115 (
116 CompactString::from(k.as_ref()),
117 CompactString::from(v.as_ref()),
118 )
119 })
120 .collect();
121 let keys: HashSet<CompactString> = params.iter().map(|(k, _)| k.clone()).collect();
122 let json_map: serde_json::Map<String, serde_json::Value> = params
124 .iter()
125 .map(|(k, v)| (k.to_string(), serde_json::Value::String(v.to_string())))
126 .collect();
127 let raw =
128 serde_json::to_vec(&json_map).map_err(|e| Error::InvalidBody(e.to_string()))?;
129 Ok(Some(Payload::ProcessedJSON {
130 raw: Bytes::from(raw),
131 keys,
132 }))
133 }
134 (MediaType::TextPlain, true)
135 | (MediaType::ApplicationXml, true)
136 | (MediaType::ApplicationOctetStream, true) => Ok(Some(Payload::RawPayload(body.clone()))),
137 (ct, _) => Err(Error::InvalidContentType(format!(
138 "Content-Type not acceptable: {}",
139 ct
140 ))),
141 }
142}
143
144fn parse_json_payload(body: &Bytes, is_proc: bool) -> Result<Option<Payload>, Error> {
145 if body.is_empty() && is_proc {
146 let keys = HashSet::new();
148 return Ok(Some(Payload::ProcessedJSON {
149 raw: Bytes::from_static(b"{}"),
150 keys,
151 }));
152 }
153
154 if body.is_empty() {
155 return Err(Error::InvalidBody("Empty or invalid json".to_string()));
156 }
157
158 let parsed: serde_json::Value = serde_json::from_slice(body)
159 .map_err(|_| Error::InvalidBody("Empty or invalid json".to_string()))?;
160
161 match &parsed {
162 serde_json::Value::Array(arr) => {
163 if arr.is_empty() {
164 return Ok(Some(Payload::ProcessedJSON {
165 raw: Bytes::from_static(b"[]"),
166 keys: HashSet::new(),
167 }));
168 }
169
170 if let Some(serde_json::Value::Object(first)) = arr.first() {
172 let canonical_keys: HashSet<CompactString> = first
173 .keys()
174 .map(|k| CompactString::from(k.as_str()))
175 .collect();
176
177 let uniform = arr.iter().all(|item| {
178 if let serde_json::Value::Object(obj) = item {
179 let item_keys: HashSet<CompactString> = obj
180 .keys()
181 .map(|k| CompactString::from(k.as_str()))
182 .collect();
183 item_keys == canonical_keys
184 } else {
185 false
186 }
187 });
188
189 if uniform {
190 Ok(Some(Payload::ProcessedJSON {
191 raw: body.clone(),
192 keys: canonical_keys,
193 }))
194 } else {
195 Err(Error::InvalidBody("All object keys must match".to_string()))
196 }
197 } else {
198 Err(Error::InvalidBody("All object keys must match".to_string()))
199 }
200 }
201 serde_json::Value::Object(obj) => {
202 let keys: HashSet<CompactString> = obj
203 .keys()
204 .map(|k| CompactString::from(k.as_str()))
205 .collect();
206 Ok(Some(Payload::ProcessedJSON {
207 raw: body.clone(),
208 keys,
209 }))
210 }
211 _ => {
212 Ok(Some(Payload::ProcessedJSON {
214 raw: Bytes::from_static(b"[]"),
215 keys: HashSet::new(),
216 }))
217 }
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::identifiers::QualifiedIdentifier;

    // Builds a relation mutation on `public.items` with the given mutation kind.
    fn relation_mut_action(mutation: Mutation) -> Action {
        Action::Db(DbAction::RelationMut {
            qi: QualifiedIdentifier::new("public", "items"),
            mutation,
        })
    }

    fn create_action() -> Action {
        relation_mut_action(Mutation::MutationCreate)
    }

    fn update_action() -> Action {
        relation_mut_action(Mutation::MutationUpdate)
    }

    fn upsert_action() -> Action {
        relation_mut_action(Mutation::MutationSingleUpsert)
    }

    fn rpc_action() -> Action {
        Action::Db(DbAction::Routine {
            qi: QualifiedIdentifier::new("public", "my_func"),
            inv_method: InvokeMethod::Inv,
        })
    }

    fn read_action() -> Action {
        Action::Db(DbAction::RelationRead {
            qi: QualifiedIdentifier::new("public", "items"),
            headers_only: false,
        })
    }

    fn default_qp() -> QueryParams {
        QueryParams::default()
    }

    // Unwraps a `ProcessedJSON` payload into its raw bytes and key set, panicking otherwise.
    fn processed_json(payload: Option<Payload>) -> (Bytes, HashSet<CompactString>) {
        match payload {
            Some(Payload::ProcessedJSON { raw, keys }) => (raw, keys),
            _ => panic!("expected a ProcessedJSON payload"),
        }
    }

    #[test]
    fn test_json_object_payload() {
        let body = Bytes::from(r#"{"id":1,"name":"test"}"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(cols.contains("id"));
        assert!(cols.contains("name"));
        assert!(matches!(payload, Payload::ProcessedJSON { .. }));
    }

    #[test]
    fn test_json_array_payload() {
        let body = Bytes::from(r#"[{"id":1,"name":"a"},{"id":2,"name":"b"}]"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(matches!(payload, Payload::ProcessedJSON { .. }));
    }

    #[test]
    fn test_json_array_non_uniform_keys() {
        let body = Bytes::from(r#"[{"id":1},{"name":"b"}]"#);
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_json_array_first_item_not_object() {
        let body = Bytes::from("[1,2]");
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_json_array_mixed_items() {
        let body = Bytes::from(r#"[{"id":1},2]"#);
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_json_scalar_truncated_to_empty_array() {
        let body = Bytes::from("42");
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let (raw, keys) = processed_json(payload);
        assert_eq!(raw, Bytes::from_static(b"[]"));
        assert!(keys.is_empty());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_empty_json_for_rpc() {
        let body = Bytes::new();
        let qp = default_qp();
        let (payload, _) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &rpc_action()).unwrap();
        assert!(payload.is_some());
    }

    #[test]
    fn test_empty_json_for_rpc_is_empty_object() {
        let body = Bytes::new();
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &rpc_action()).unwrap();
        let (raw, keys) = processed_json(payload);
        assert_eq!(raw, Bytes::from_static(b"{}"));
        assert!(keys.is_empty());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_empty_json_non_rpc_error() {
        let body = Bytes::new();
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_invalid_json() {
        let body = Bytes::from("not json");
        let qp = default_qp();
        let result = get_payload(body, &MediaType::ApplicationJson, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_update_parses_json() {
        let body = Bytes::from(r#"{"name":"x"}"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &update_action()).unwrap();
        assert!(matches!(payload, Some(Payload::ProcessedJSON { .. })));
        assert_eq!(cols.len(), 1);
        assert!(cols.contains("name"));
    }

    #[test]
    fn test_single_upsert_parses_json() {
        let body = Bytes::from(r#"{"id":1,"name":"x"}"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &upsert_action()).unwrap();
        assert!(matches!(payload, Some(Payload::ProcessedJSON { .. })));
        assert_eq!(cols.len(), 2);
    }

    #[test]
    fn test_url_encoded_rpc() {
        let body = Bytes::from("id=1&name=test");
        let qp = default_qp();
        let (payload, cols) = get_payload(
            body,
            &MediaType::ApplicationFormUrlEncoded,
            &qp,
            &rpc_action(),
        )
        .unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(matches!(payload, Payload::ProcessedUrlEncoded { .. }));
    }

    #[test]
    fn test_url_encoded_non_rpc() {
        let body = Bytes::from("id=1&name=test");
        let qp = default_qp();
        let (payload, cols) = get_payload(
            body,
            &MediaType::ApplicationFormUrlEncoded,
            &qp,
            &create_action(),
        )
        .unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 2);
        assert!(matches!(payload, Payload::ProcessedJSON { .. }));
    }

    #[test]
    fn test_url_encoded_non_rpc_repeated_key_keeps_last() {
        let body = Bytes::from("a=1&a=2");
        let qp = default_qp();
        let (payload, cols) = get_payload(
            body,
            &MediaType::ApplicationFormUrlEncoded,
            &qp,
            &create_action(),
        )
        .unwrap();
        let (raw, keys) = processed_json(payload);
        assert_eq!(raw, Bytes::from_static(br#"{"a":"2"}"#));
        assert_eq!(keys.len(), 1);
        assert_eq!(cols.len(), 1);
        assert!(cols.contains("a"));
    }

    #[test]
    fn test_raw_payload_rpc() {
        let body = Bytes::from("raw text content");
        let qp = default_qp();
        let (payload, _) = get_payload(body, &MediaType::TextPlain, &qp, &rpc_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawPayload(_)));
    }

    #[test]
    fn test_xml_rpc() {
        let body = Bytes::from("<a/>");
        let qp = default_qp();
        let (payload, _) =
            get_payload(body, &MediaType::ApplicationXml, &qp, &rpc_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawPayload(_)));
    }

    #[test]
    fn test_octet_stream_rpc() {
        let body = Bytes::from(vec![0u8, 1, 2, 3]);
        let qp = default_qp();
        let (payload, _) =
            get_payload(body, &MediaType::ApplicationOctetStream, &qp, &rpc_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawPayload(_)));
    }

    #[test]
    fn test_text_plain_non_rpc_rejected() {
        let body = Bytes::from("raw text content");
        let qp = default_qp();
        let result = get_payload(body, &MediaType::TextPlain, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_unsupported_content_type() {
        let body = Bytes::from("data");
        let qp = default_qp();
        let result = get_payload(body, &MediaType::TextCsv, &qp, &create_action());
        assert!(result.is_err());
    }

    #[test]
    fn test_no_payload_for_read() {
        let body = Bytes::from("data");
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &read_action()).unwrap();
        assert!(payload.is_none());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_raw_json_with_columns() {
        let body = Bytes::from(r#"{"id":1,"name":"test"}"#);
        let mut qp = default_qp();
        let mut cols_set = HashSet::new();
        cols_set.insert(CompactString::from("id"));
        cols_set.insert(CompactString::from("name"));
        qp.columns = Some(cols_set.clone());

        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        assert!(matches!(payload.unwrap(), Payload::RawJSON(_)));
        assert_eq!(cols.len(), 2);
    }

    #[test]
    fn test_empty_json_array() {
        let body = Bytes::from("[]");
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        assert!(payload.is_some());
        assert!(cols.is_empty());
    }

    #[test]
    fn test_payload_keys() {
        let body = Bytes::from(r#"{"a":1,"b":2,"c":3}"#);
        let qp = default_qp();
        let (payload, cols) =
            get_payload(body, &MediaType::ApplicationJson, &qp, &create_action()).unwrap();
        let payload = payload.unwrap();
        assert_eq!(cols.len(), 3);
        assert_eq!(payload.keys().len(), 3);
    }
}