1use super::frame::MessageKind;
8use serde_json::{Map as JsonMap, Value as JsonValue};
9use std::fmt;
10
11#[derive(Debug, Clone, PartialEq)]
12pub struct InsertDispatchPayload {
13 pub collection: String,
14 pub payload: Option<JsonValue>,
15 pub payloads: Option<Vec<JsonValue>>,
16 pub idempotency_key: Option<String>,
17 pub batch: bool,
18}
19
/// Collection/id pair addressed by a `Get` or `Delete` request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KeyPayload {
    /// Collection that holds the document.
    pub collection: String,
    /// Document identifier within `collection`.
    pub id: String,
}
25
/// Decoded JSON `BulkOk` reply.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BulkOkPayload {
    /// The reply's `affected` count.
    pub affected: u64,
    /// Ids read from the reply's `rids` field, or from `ids` when `rids` is not usable.
    pub rids: Vec<String>,
    /// Ids read from the reply's `ids` field, or mirrored from `rids` when `ids` is not usable.
    pub ids: Vec<String>,
}
32
/// Failure to decode an operation request or reply payload.
///
/// `op` is the operation label (e.g. `"Insert"`, `"Get"`) used as the message prefix.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OperationPayloadError {
    /// The bytes were not valid JSON; `message` is the parser's error text.
    InvalidJson { op: &'static str, message: String },
    /// The JSON parsed, but its top-level value was not an object.
    ExpectedObject { op: &'static str },
    /// `collection` was absent, not a string, or empty.
    MissingCollection { op: &'static str },
    /// `id` was absent, not a string, or empty.
    MissingId { op: &'static str },
    /// A binary `BulkOk` count payload was shorter than 8 bytes.
    TruncatedBulkOkCount,
}

impl fmt::Display for OperationPayloadError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OperationPayloadError::InvalidJson { op, message } => {
                write!(f, "{}: invalid JSON: {}", op, message)
            }
            OperationPayloadError::ExpectedObject { op } => {
                write!(f, "{}: payload must be a JSON object", op)
            }
            OperationPayloadError::MissingCollection { op } => {
                write!(f, "{}: missing 'collection' string", op)
            }
            OperationPayloadError::MissingId { op } => write!(f, "{}: missing 'id' string", op),
            OperationPayloadError::TruncatedBulkOkCount => {
                f.write_str("BulkOk truncated: expected 8-byte count")
            }
        }
    }
}

impl std::error::Error for OperationPayloadError {}
57
58#[derive(Debug, Clone, PartialEq, Eq)]
59pub enum OperationReplyError {
60 Engine(String),
61 UnexpectedKind {
62 expected: &'static str,
63 actual: MessageKind,
64 },
65}
66
67impl fmt::Display for OperationReplyError {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 match self {
70 Self::Engine(message) => write!(f, "{message}"),
71 Self::UnexpectedKind { expected, actual } => {
72 write!(f, "expected {expected}, got {actual:?}")
73 }
74 }
75 }
76}
77
78impl std::error::Error for OperationReplyError {}
79
80pub fn encode_insert_payload(collection: &str, payload: JsonValue) -> Vec<u8> {
81 let mut obj = JsonMap::new();
82 obj.insert(
83 "collection".into(),
84 JsonValue::String(collection.to_string()),
85 );
86 obj.insert("payload".into(), payload);
87 serde_json::to_vec(&JsonValue::Object(obj)).expect("insert payload JSON is serializable")
88}
89
90pub fn encode_bulk_insert_payload(collection: &str, payloads: Vec<JsonValue>) -> Vec<u8> {
91 let mut obj = JsonMap::new();
92 obj.insert(
93 "collection".into(),
94 JsonValue::String(collection.to_string()),
95 );
96 obj.insert("payloads".into(), JsonValue::Array(payloads));
97 serde_json::to_vec(&JsonValue::Object(obj)).expect("bulk insert payload JSON is serializable")
98}
99
100pub fn decode_insert_dispatch_payload(
101 bytes: &[u8],
102) -> Result<InsertDispatchPayload, OperationPayloadError> {
103 let obj = object_from_payload("Insert", bytes)?;
104 let collection = required_collection("Insert", &obj)?;
105 let payload = obj.get("payload").cloned();
106 let payloads = obj
107 .get("payloads")
108 .and_then(JsonValue::as_array)
109 .map(|items| items.to_vec());
110 let idempotency_key = obj
111 .get("idempotency_key")
112 .and_then(JsonValue::as_str)
113 .map(String::from);
114 let batch = obj
115 .get("batch")
116 .and_then(JsonValue::as_bool)
117 .unwrap_or(false);
118 Ok(InsertDispatchPayload {
119 collection,
120 payload,
121 payloads,
122 idempotency_key,
123 batch,
124 })
125}
126
127pub fn encode_key_payload(collection: &str, id: &str) -> Vec<u8> {
128 let mut obj = JsonMap::new();
129 obj.insert(
130 "collection".into(),
131 JsonValue::String(collection.to_string()),
132 );
133 obj.insert("id".into(), JsonValue::String(id.to_string()));
134 serde_json::to_vec(&JsonValue::Object(obj)).expect("key payload JSON is serializable")
135}
136
137pub fn decode_get_payload(bytes: &[u8]) -> Result<KeyPayload, OperationPayloadError> {
138 decode_key_payload("Get", bytes)
139}
140
141pub fn decode_delete_payload(bytes: &[u8]) -> Result<KeyPayload, OperationPayloadError> {
142 decode_key_payload("Delete", bytes)
143}
144
145pub fn encode_query_result_summary_payload(statement: &str, affected: u64) -> Vec<u8> {
146 let mut obj = JsonMap::new();
147 obj.insert("ok".into(), JsonValue::Bool(true));
148 obj.insert("statement".into(), JsonValue::String(statement.to_string()));
149 obj.insert("affected".into(), JsonValue::Number(affected.into()));
150 serde_json::to_vec(&JsonValue::Object(obj)).expect("query result payload JSON is serializable")
151}
152
153pub fn decode_query_result_payload(bytes: &[u8]) -> Result<JsonValue, OperationPayloadError> {
154 json_value_from_payload("QueryResult", bytes)
155}
156
157pub fn encode_get_result_payload(found: bool) -> Vec<u8> {
158 let mut obj = JsonMap::new();
159 obj.insert("ok".into(), JsonValue::Bool(true));
160 obj.insert("found".into(), JsonValue::Bool(found));
161 serde_json::to_vec(&JsonValue::Object(obj)).expect("get result payload JSON is serializable")
162}
163
164pub fn decode_get_result_payload(bytes: &[u8]) -> Result<JsonValue, OperationPayloadError> {
165 json_value_from_payload("GetResult", bytes)
166}
167
/// Decodes a plain-text body into an owned `String`.
///
/// Never fails. Invalid UTF-8 sequences are replaced with U+FFFD
/// (REPLACEMENT CHARACTER).
pub fn decode_text_payload(bytes: &[u8]) -> String {
    match std::str::from_utf8(bytes) {
        Ok(text) => text.to_owned(),
        Err(_) => String::from_utf8_lossy(bytes).into_owned(),
    }
}
171
/// Decodes the text carried by an `Error` frame.
///
/// Uses the same lossy UTF-8 rule as plain-text bodies: invalid sequences
/// become U+FFFD instead of failing.
pub fn decode_error_payload(bytes: &[u8]) -> String {
    String::from_utf8_lossy(bytes).into_owned()
}
175
176pub fn expect_result_or_error(
177 kind: MessageKind,
178 payload: &[u8],
179) -> Result<&[u8], OperationReplyError> {
180 expect_payload_or_error(kind, payload, MessageKind::Result, "Result/Error")
181}
182
183pub fn expect_bulk_ok_or_error(
184 kind: MessageKind,
185 payload: &[u8],
186) -> Result<&[u8], OperationReplyError> {
187 expect_payload_or_error(kind, payload, MessageKind::BulkOk, "BulkOk/Error")
188}
189
190pub fn expect_delete_ok_or_error(
191 kind: MessageKind,
192 payload: &[u8],
193) -> Result<&[u8], OperationReplyError> {
194 expect_payload_or_error(kind, payload, MessageKind::DeleteOk, "DeleteOk/Error")
195}
196
197pub fn expect_pong_reply(kind: MessageKind) -> Result<(), OperationReplyError> {
198 if kind == MessageKind::Pong {
199 Ok(())
200 } else {
201 Err(OperationReplyError::UnexpectedKind {
202 expected: "Pong",
203 actual: kind,
204 })
205 }
206}
207
208pub fn encode_bulk_ok_payload(affected: u64, ids: Vec<JsonValue>) -> Vec<u8> {
209 let mut obj = JsonMap::new();
210 obj.insert("affected".into(), JsonValue::Number(affected.into()));
211 obj.insert("ids".into(), JsonValue::Array(ids));
212 serde_json::to_vec(&JsonValue::Object(obj)).expect("bulk ok payload JSON is serializable")
213}
214
215pub fn encode_bulk_ok_payload_from_json_ids_bytes(affected: u64, ids: &[u8]) -> Vec<u8> {
216 let ids = match serde_json::from_slice::<JsonValue>(ids) {
217 Ok(JsonValue::Array(items)) => items,
218 _ => Vec::new(),
219 };
220 encode_bulk_ok_payload(affected, ids)
221}
222
223pub fn encode_bulk_ok_payload_from_json_id_literals<I, S>(affected: u64, ids: I) -> Vec<u8>
224where
225 I: IntoIterator<Item = S>,
226 S: AsRef<str>,
227{
228 let ids = ids
229 .into_iter()
230 .map(|id| {
231 serde_json::from_str::<JsonValue>(id.as_ref())
232 .unwrap_or_else(|_| JsonValue::String(id.as_ref().to_string()))
233 })
234 .collect();
235 encode_bulk_ok_payload(affected, ids)
236}
237
238pub fn decode_bulk_ok_payload(bytes: &[u8]) -> Result<BulkOkPayload, OperationPayloadError> {
239 let obj = object_from_payload("BulkOk", bytes)?;
240 let affected = obj.get("affected").and_then(JsonValue::as_u64).unwrap_or(0);
241 let rids: Vec<String> = obj
242 .get("rids")
243 .or_else(|| obj.get("ids"))
244 .and_then(JsonValue::as_array)
245 .map(|items| items.iter().filter_map(json_id_to_string).collect())
246 .unwrap_or_default();
247 let ids: Vec<String> = obj
248 .get("ids")
249 .and_then(JsonValue::as_array)
250 .map(|items| items.iter().filter_map(json_id_to_string).collect())
251 .unwrap_or_else(|| rids.clone());
252 Ok(BulkOkPayload {
253 affected,
254 rids,
255 ids,
256 })
257}
258
/// Encodes the legacy binary `BulkOk` body: `count` as 8 little-endian bytes.
pub fn encode_bulk_ok_count_payload(count: u64) -> Vec<u8> {
    Vec::from(count.to_le_bytes())
}
262
263pub fn decode_bulk_ok_count_payload(bytes: &[u8]) -> Result<u64, OperationPayloadError> {
264 if bytes.len() < 8 {
265 return Err(OperationPayloadError::TruncatedBulkOkCount);
266 }
267 let mut count = [0u8; 8];
268 count.copy_from_slice(&bytes[..8]);
269 Ok(u64::from_le_bytes(count))
270}
271
272pub fn decode_delete_ok_affected(bytes: &[u8]) -> Result<u64, OperationPayloadError> {
273 let obj = object_from_payload("DeleteOk", bytes)?;
274 Ok(obj.get("affected").and_then(JsonValue::as_u64).unwrap_or(0))
275}
276
277pub fn encode_delete_ok_payload(affected: u64) -> Vec<u8> {
278 let mut obj = JsonMap::new();
279 obj.insert("affected".into(), JsonValue::Number(affected.into()));
280 serde_json::to_vec(&JsonValue::Object(obj)).expect("delete ok payload JSON is serializable")
281}
282
283fn expect_payload_or_error<'a>(
284 actual: MessageKind,
285 payload: &'a [u8],
286 ok: MessageKind,
287 expected: &'static str,
288) -> Result<&'a [u8], OperationReplyError> {
289 match actual {
290 kind if kind == ok => Ok(payload),
291 MessageKind::Error => Err(OperationReplyError::Engine(decode_error_payload(payload))),
292 other => Err(OperationReplyError::UnexpectedKind {
293 expected,
294 actual: other,
295 }),
296 }
297}
298
299fn decode_key_payload(op: &'static str, bytes: &[u8]) -> Result<KeyPayload, OperationPayloadError> {
300 let obj = object_from_payload(op, bytes)?;
301 let collection = required_collection(op, &obj)?;
302 let id = match obj.get("id").and_then(JsonValue::as_str) {
303 Some(value) if !value.is_empty() => value.to_string(),
304 _ => return Err(OperationPayloadError::MissingId { op }),
305 };
306 Ok(KeyPayload { collection, id })
307}
308
309fn json_value_from_payload(
310 op: &'static str,
311 bytes: &[u8],
312) -> Result<JsonValue, OperationPayloadError> {
313 let value: JsonValue =
314 serde_json::from_slice(bytes).map_err(|err| OperationPayloadError::InvalidJson {
315 op,
316 message: err.to_string(),
317 })?;
318 match value {
319 JsonValue::Object(_) => Ok(value),
320 _ => Err(OperationPayloadError::ExpectedObject { op }),
321 }
322}
323
324fn object_from_payload(
325 op: &'static str,
326 bytes: &[u8],
327) -> Result<JsonMap<String, JsonValue>, OperationPayloadError> {
328 let value: JsonValue =
329 serde_json::from_slice(bytes).map_err(|err| OperationPayloadError::InvalidJson {
330 op,
331 message: err.to_string(),
332 })?;
333 match value {
334 JsonValue::Object(obj) => Ok(obj),
335 _ => Err(OperationPayloadError::ExpectedObject { op }),
336 }
337}
338
339fn required_collection(
340 op: &'static str,
341 obj: &JsonMap<String, JsonValue>,
342) -> Result<String, OperationPayloadError> {
343 match obj.get("collection").and_then(JsonValue::as_str) {
344 Some(value) if !value.is_empty() => Ok(value.to_string()),
345 _ => Err(OperationPayloadError::MissingCollection { op }),
346 }
347}
348
349fn json_id_to_string(value: &JsonValue) -> Option<String> {
350 value
351 .as_str()
352 .map(String::from)
353 .or_else(|| value.as_u64().map(|n| n.to_string()))
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    // Both insert request shapes survive an encode/decode round trip, and the
    // decoder leaves the field belonging to the other shape unset.
    #[test]
    fn insert_payload_round_trips_single_and_bulk_shapes() {
        let single = decode_insert_dispatch_payload(&encode_insert_payload(
            "users",
            serde_json::json!({"name":"Ada"}),
        ))
        .unwrap();
        assert_eq!(single.collection, "users");
        assert_eq!(single.payload.unwrap(), serde_json::json!({"name":"Ada"}));
        assert!(single.payloads.is_none());

        let bulk = decode_insert_dispatch_payload(&encode_bulk_insert_payload(
            "users",
            vec![serde_json::json!({"name":"Ada"})],
        ))
        .unwrap();
        assert_eq!(bulk.collection, "users");
        assert_eq!(bulk.payloads.unwrap().len(), 1);
        assert!(bulk.payload.is_none());
    }

    // `Get` and `Delete` decode the same key body produced by `encode_key_payload`.
    #[test]
    fn key_payload_round_trips_get_and_delete_contracts() {
        let bytes = encode_key_payload("users", "42");
        assert_eq!(
            decode_get_payload(&bytes).unwrap(),
            KeyPayload {
                collection: "users".into(),
                id: "42".into(),
            }
        );
        assert_eq!(
            decode_delete_payload(&bytes).unwrap(),
            KeyPayload {
                collection: "users".into(),
                id: "42".into(),
            }
        );
    }

    // Numeric and string ids both decode to strings, through each BulkOk encoder.
    #[test]
    fn bulk_ok_decodes_ids_and_affected_count() {
        // The encoder writes no `rids` key, so the decoded `rids` mirrors `ids`.
        let payload = encode_bulk_ok_payload(2, vec![JsonValue::Number(1.into()), "2".into()]);
        assert_eq!(
            decode_bulk_ok_payload(&payload).unwrap(),
            BulkOkPayload {
                affected: 2,
                rids: vec!["1".into(), "2".into()],
                ids: vec!["1".into(), "2".into()],
            }
        );

        let payload = encode_bulk_ok_payload_from_json_ids_bytes(2, br#"[1,"2"]"#);
        assert_eq!(decode_bulk_ok_payload(&payload).unwrap().ids.len(), 2);

        // `1` parses as a JSON number and `"2"` (quoted) as a JSON string;
        // both come back as plain strings.
        let payload = encode_bulk_ok_payload_from_json_id_literals(2, ["1", r#""2""#]);
        assert_eq!(
            decode_bulk_ok_payload(&payload).unwrap().ids,
            vec!["1".to_string(), "2".to_string()]
        );
    }

    // Reply-side helpers: JSON summary bodies, text/error decoding, and the
    // frame-kind checks that surface `Error` frames and reject unexpected kinds.
    #[test]
    fn operation_reply_payloads_encode_wire_visible_json_contracts() {
        let query =
            decode_query_result_payload(&encode_query_result_summary_payload("INSERT", 3)).unwrap();
        assert_eq!(query["ok"], JsonValue::Bool(true));
        assert_eq!(query["statement"], JsonValue::String("INSERT".into()));
        assert_eq!(query["affected"], JsonValue::Number(3.into()));

        let get = decode_get_result_payload(&encode_get_result_payload(false)).unwrap();
        assert_eq!(get["ok"], JsonValue::Bool(true));
        assert_eq!(get["found"], JsonValue::Bool(false));
        assert_eq!(decode_text_payload(b"raw result"), "raw result");
        assert_eq!(decode_error_payload(b"engine failed"), "engine failed");
        assert_eq!(
            expect_result_or_error(MessageKind::Result, b"ok").unwrap(),
            b"ok"
        );
        // An `Error` frame is reported as an engine error carrying its text.
        assert_eq!(
            expect_bulk_ok_or_error(MessageKind::Error, b"failed").unwrap_err(),
            OperationReplyError::Engine("failed".to_string())
        );
        // Any other kind is reported with the caller's expected-kind label.
        assert_eq!(
            expect_delete_ok_or_error(MessageKind::Pong, b"").unwrap_err(),
            OperationReplyError::UnexpectedKind {
                expected: "DeleteOk/Error",
                actual: MessageKind::Pong
            }
        );
        assert!(expect_pong_reply(MessageKind::Pong).is_ok());

        assert_eq!(
            decode_delete_ok_affected(&encode_delete_ok_payload(7)).unwrap(),
            7
        );
    }

    // The legacy binary BulkOk body is an 8-byte count; shorter input is rejected.
    #[test]
    fn bulk_ok_count_payload_round_trips_legacy_binary_shape() {
        let payload = encode_bulk_ok_count_payload(42);
        assert_eq!(payload.len(), 8);
        assert_eq!(decode_bulk_ok_count_payload(&payload).unwrap(), 42);
        assert_eq!(
            decode_bulk_ok_count_payload(&payload[..7])
                .unwrap_err()
                .to_string(),
            "BulkOk truncated: expected 8-byte count"
        );
    }
}