1use std::collections::BTreeMap;
17use std::io::Write;
18
19use anyhow::{anyhow, bail, Context, Result};
20use chrono::{DateTime, TimeZone, Utc};
21use edn::symbols::Keyword;
22use rmpv::Value;
23use uuid::Uuid;
24
25use crate::ops::{DataType, EntityRef, QueryArg, TxOp};
26use crate::protocol::ColumnDescription;
27use crate::transaction::TxKey;
28
/// MessagePack extension type tag for timestamp payloads (4, 8 or 12 bytes).
pub const EXT_TIMESTAMP: i8 = -1;
/// Extension type tag for `DataType::BigInt`: a 16-byte big-endian `i128`.
pub const EXT_BIGINT: i8 = 1;
/// Extension type tag for `DataType::Uuid`: the 16 raw UUID bytes.
pub const EXT_UUID: i8 = 2;
/// Extension type tag for `DataType::Keyword`: the UTF-8 wire form
/// (`namespace/name`, or just `name` when there is no namespace).
pub const EXT_KEYWORD: i8 = 3;
37
38fn keyword_to_wire(kw: &Keyword) -> String {
43 match kw.namespace() {
44 Some(ns) => format!("{}/{}", ns, kw.name()),
45 None => kw.name().to_string(),
46 }
47}
48
49fn keyword_from_wire(s: &str) -> Result<Keyword> {
50 match s.split_once('/') {
51 Some((ns, name)) if !ns.is_empty() && !name.is_empty() => Ok(Keyword::namespaced(ns, name)),
52 Some(_) => bail!("invalid keyword wire format: {:?}", s),
53 None if s.is_empty() => bail!("empty keyword"),
54 None => Ok(Keyword::plain(s)),
55 }
56}
57
58fn write_ext<W: Write>(w: &mut W, ty: i8, payload: &[u8]) -> Result<()> {
63 rmp::encode::write_ext_meta(w, payload.len() as u32, ty)?;
64 w.write_all(payload)?;
65 Ok(())
66}
67
68fn write_timestamp<W: Write>(w: &mut W, dt: &DateTime<Utc>) -> Result<()> {
69 let secs = dt.timestamp();
70 let nanos = dt.timestamp_subsec_nanos();
71 if nanos == 0 && secs >= 0 && secs <= u32::MAX as i64 {
72 let bytes = (secs as u32).to_be_bytes();
73 write_ext(w, EXT_TIMESTAMP, &bytes)
74 } else if secs >= 0 && (secs as u64) < (1u64 << 34) {
75 let data: u64 = ((nanos as u64) << 34) | (secs as u64);
76 let bytes = data.to_be_bytes();
77 write_ext(w, EXT_TIMESTAMP, &bytes)
78 } else {
79 let mut bytes = [0u8; 12];
80 bytes[..4].copy_from_slice(&nanos.to_be_bytes());
81 bytes[4..].copy_from_slice(&secs.to_be_bytes());
82 write_ext(w, EXT_TIMESTAMP, &bytes)
83 }
84}
85
86fn read_timestamp(payload: &[u8]) -> Result<DateTime<Utc>> {
87 let (secs, nanos): (i64, u32) = match payload.len() {
88 4 => {
89 let secs = u32::from_be_bytes(payload.try_into().unwrap());
90 (secs as i64, 0)
91 }
92 8 => {
93 let data = u64::from_be_bytes(payload.try_into().unwrap());
94 let nanos = (data >> 34) as u32;
95 let secs = (data & 0x0003_ffff_ffff) as i64;
96 (secs, nanos)
97 }
98 12 => {
99 let nanos = u32::from_be_bytes(payload[..4].try_into().unwrap());
100 let secs = i64::from_be_bytes(payload[4..].try_into().unwrap());
101 (secs, nanos)
102 }
103 n => bail!("invalid msgpack Timestamp payload length {n}"),
104 };
105 Utc.timestamp_opt(secs, nanos)
106 .single()
107 .context("invalid timestamp value")
108}
109
110fn read_bigint(payload: &[u8]) -> Result<i128> {
111 let bytes: [u8; 16] = payload
112 .try_into()
113 .map_err(|_| anyhow!("BigInt ext payload must be 16 bytes"))?;
114 Ok(i128::from_be_bytes(bytes))
115}
116
117fn read_uuid(payload: &[u8]) -> Result<Uuid> {
118 let bytes: [u8; 16] = payload
119 .try_into()
120 .map_err(|_| anyhow!("Uuid ext payload must be 16 bytes"))?;
121 Ok(Uuid::from_bytes(bytes))
122}
123
124fn read_keyword(payload: &[u8]) -> Result<Keyword> {
125 let s = std::str::from_utf8(payload).context("keyword payload is not valid UTF-8")?;
126 keyword_from_wire(s)
127}
128
129pub fn write_data_type<W: Write>(w: &mut W, dt: &DataType) -> Result<()> {
134 match dt {
135 DataType::Boolean(v) => {
136 rmp::encode::write_bool(w, *v)?;
137 }
138 DataType::Long(v) => {
139 rmp::encode::write_sint(w, *v)?;
140 }
141 DataType::Float(v) => {
142 rmp::encode::write_f32(w, *v)?;
143 }
144 DataType::Double(v) => {
145 rmp::encode::write_f64(w, *v)?;
146 }
147 DataType::String(v) => {
148 rmp::encode::write_str(w, v)?;
149 }
150 DataType::Bytes(v) => {
151 rmp::encode::write_bin(w, v)?;
152 }
153 DataType::Vector(v) => {
154 rmp::encode::write_array_len(w, v.len() as u32)?;
155 for item in v {
156 write_data_type(w, item)?;
157 }
158 }
159 DataType::Map(m) => {
160 rmp::encode::write_map_len(w, m.len() as u32)?;
161 for (k, v) in m {
162 rmp::encode::write_str(w, k)?;
163 write_data_type(w, v)?;
164 }
165 }
166 DataType::BigInt(v) => write_ext(w, EXT_BIGINT, &v.to_be_bytes())?,
167 DataType::Uuid(v) => write_ext(w, EXT_UUID, v.as_bytes())?,
168 DataType::Keyword(v) => {
169 let s = keyword_to_wire(v);
170 write_ext(w, EXT_KEYWORD, s.as_bytes())?;
171 }
172 DataType::Instant(v) => write_timestamp(w, v)?,
173 }
174 Ok(())
175}
176
177pub fn data_type_from_value(v: Value) -> Result<DataType> {
178 match v {
179 Value::Boolean(b) => Ok(DataType::Boolean(b)),
180 Value::Integer(n) => n
181 .as_i64()
182 .map(DataType::Long)
183 .ok_or_else(|| anyhow!("integer out of i64 range: {n}")),
184 Value::F32(f) => Ok(DataType::Float(f)),
185 Value::F64(f) => Ok(DataType::Double(f)),
186 Value::String(s) => s
187 .into_str()
188 .map(DataType::String)
189 .ok_or_else(|| anyhow!("string is not valid UTF-8")),
190 Value::Binary(b) => Ok(DataType::Bytes(b)),
191 Value::Array(items) => {
192 let mut out = Vec::with_capacity(items.len());
193 for item in items {
194 out.push(data_type_from_value(item)?);
195 }
196 Ok(DataType::Vector(out))
197 }
198 Value::Map(entries) => {
199 let mut m = BTreeMap::new();
200 for (k, v) in entries {
201 let key = match k {
202 Value::String(s) => s
203 .into_str()
204 .ok_or_else(|| anyhow!("map key is not valid UTF-8"))?,
205 other => bail!("DataType::Map keys must be strings, got {other:?}"),
206 };
207 m.insert(key, data_type_from_value(v)?);
208 }
209 Ok(DataType::Map(m))
210 }
211 Value::Ext(ty, payload) => match ty {
212 EXT_TIMESTAMP => Ok(DataType::Instant(read_timestamp(&payload)?)),
213 EXT_BIGINT => Ok(DataType::BigInt(read_bigint(&payload)?)),
214 EXT_UUID => Ok(DataType::Uuid(read_uuid(&payload)?)),
215 EXT_KEYWORD => Ok(DataType::Keyword(read_keyword(&payload)?)),
216 _ => bail!("unknown msgpack ext type {ty}"),
217 },
218 Value::Nil => bail!("DataType cannot be nil"),
219 }
220}
221
222pub fn read_data_type(buf: &[u8]) -> Result<(DataType, &[u8])> {
223 let mut cursor = buf;
224 let value =
225 rmpv::decode::read_value(&mut cursor).map_err(|e| anyhow!("msgpack decode error: {e}"))?;
226 let dt = data_type_from_value(value)?;
227 Ok((dt, cursor))
228}
229
230fn map_from_value(v: Value) -> Result<BTreeMap<String, Value>> {
235 let entries = match v {
236 Value::Map(entries) => entries,
237 other => bail!("expected map, got {other:?}"),
238 };
239 let mut out = BTreeMap::new();
240 for (k, v) in entries {
241 let key = match k {
242 Value::String(s) => s
243 .into_str()
244 .ok_or_else(|| anyhow!("map key is not valid UTF-8"))?,
245 other => bail!("map key must be string, got {other:?}"),
246 };
247 out.insert(key, v);
248 }
249 Ok(out)
250}
251
252fn take_field(map: &mut BTreeMap<String, Value>, name: &str) -> Result<Value> {
253 map.remove(name)
254 .ok_or_else(|| anyhow!("missing field {name:?}"))
255}
256
257fn take_string(map: &mut BTreeMap<String, Value>, name: &str) -> Result<String> {
258 match take_field(map, name)? {
259 Value::String(s) => s
260 .into_str()
261 .ok_or_else(|| anyhow!("field {name:?} is not valid UTF-8")),
262 other => bail!("field {name:?} expected string, got {other:?}"),
263 }
264}
265
266fn take_i64(map: &mut BTreeMap<String, Value>, name: &str) -> Result<i64> {
267 match take_field(map, name)? {
268 Value::Integer(n) => n
269 .as_i64()
270 .ok_or_else(|| anyhow!("field {name:?} integer out of i64 range")),
271 other => bail!("field {name:?} expected integer, got {other:?}"),
272 }
273}
274
275fn take_data_type(map: &mut BTreeMap<String, Value>, name: &str) -> Result<DataType> {
276 data_type_from_value(take_field(map, name)?)
277}
278
279fn write_str_field<W: Write>(w: &mut W, name: &str, value: &str) -> Result<()> {
280 rmp::encode::write_str(w, name)?;
281 rmp::encode::write_str(w, value)?;
282 Ok(())
283}
284
285pub fn write_entity_ref<W: Write>(w: &mut W, er: &EntityRef) -> Result<()> {
290 match er {
291 EntityRef::Id(id) => {
292 rmp::encode::write_map_len(w, 2)?;
293 write_str_field(w, "kind", "id")?;
294 rmp::encode::write_str(w, "id")?;
295 rmp::encode::write_sint(w, *id)?;
296 }
297 EntityRef::TempId(s) => {
298 rmp::encode::write_map_len(w, 2)?;
299 write_str_field(w, "kind", "temp")?;
300 write_str_field(w, "temp", s)?;
301 }
302 EntityRef::Ident(kw) => {
303 rmp::encode::write_map_len(w, 2)?;
304 write_str_field(w, "kind", "ident")?;
305 write_str_field(w, "ident", &keyword_to_wire(kw))?;
306 }
307 EntityRef::LookupRef(attr, value) => {
308 rmp::encode::write_map_len(w, 3)?;
309 write_str_field(w, "kind", "lookup")?;
310 write_str_field(w, "attr", &keyword_to_wire(attr))?;
311 rmp::encode::write_str(w, "value")?;
312 write_data_type(w, value)?;
313 }
314 }
315 Ok(())
316}
317
318pub fn entity_ref_from_value(v: Value) -> Result<EntityRef> {
319 let mut map = map_from_value(v)?;
320 let kind = take_string(&mut map, "kind")?;
321 match kind.as_str() {
322 "id" => Ok(EntityRef::Id(take_i64(&mut map, "id")?)),
323 "temp" => Ok(EntityRef::TempId(take_string(&mut map, "temp")?)),
324 "ident" => Ok(EntityRef::Ident(keyword_from_wire(&take_string(
325 &mut map, "ident",
326 )?)?)),
327 "lookup" => {
328 let attr = keyword_from_wire(&take_string(&mut map, "attr")?)?;
329 let value = take_data_type(&mut map, "value")?;
330 Ok(EntityRef::LookupRef(attr, value))
331 }
332 other => bail!("unknown EntityRef kind: {other:?}"),
333 }
334}
335
336pub fn write_tx_op<W: Write>(w: &mut W, op: &TxOp) -> Result<()> {
341 match op {
342 TxOp::Put(doc) => {
343 rmp::encode::write_map_len(w, 2)?;
344 write_str_field(w, "kind", "put")?;
345 rmp::encode::write_str(w, "doc")?;
346 rmp::encode::write_map_len(w, doc.len() as u32)?;
347 for (k, v) in doc {
348 rmp::encode::write_str(w, &keyword_to_wire(k))?;
349 write_data_type(w, v)?;
350 }
351 }
352 TxOp::Add {
353 entity,
354 attribute,
355 value,
356 } => {
357 rmp::encode::write_map_len(w, 4)?;
358 write_str_field(w, "kind", "add")?;
359 rmp::encode::write_str(w, "entity")?;
360 write_entity_ref(w, entity)?;
361 write_str_field(w, "attr", &keyword_to_wire(attribute))?;
362 rmp::encode::write_str(w, "value")?;
363 write_data_type(w, value)?;
364 }
365 TxOp::Retract {
366 entity,
367 attribute,
368 value,
369 } => {
370 rmp::encode::write_map_len(w, 4)?;
371 write_str_field(w, "kind", "retract")?;
372 rmp::encode::write_str(w, "entity")?;
373 write_entity_ref(w, entity)?;
374 write_str_field(w, "attr", &keyword_to_wire(attribute))?;
375 rmp::encode::write_str(w, "value")?;
376 write_data_type(w, value)?;
377 }
378 TxOp::Delete(entity) => {
379 rmp::encode::write_map_len(w, 2)?;
380 write_str_field(w, "kind", "delete")?;
381 rmp::encode::write_str(w, "entity")?;
382 write_entity_ref(w, entity)?;
383 }
384 TxOp::Erase(entity) => {
385 rmp::encode::write_map_len(w, 2)?;
386 write_str_field(w, "kind", "erase")?;
387 rmp::encode::write_str(w, "entity")?;
388 write_entity_ref(w, entity)?;
389 }
390 }
391 Ok(())
392}
393
394pub fn tx_op_from_value(v: Value) -> Result<TxOp> {
395 let mut map = map_from_value(v)?;
396 let kind = take_string(&mut map, "kind")?;
397 match kind.as_str() {
398 "put" => {
399 let doc_value = take_field(&mut map, "doc")?;
400 let entries = match doc_value {
401 Value::Map(entries) => entries,
402 other => bail!("Put.doc must be a map, got {other:?}"),
403 };
404 let mut doc = BTreeMap::new();
405 for (k, v) in entries {
406 let key_str = match k {
407 Value::String(s) => s
408 .into_str()
409 .ok_or_else(|| anyhow!("Put.doc key is not valid UTF-8"))?,
410 other => bail!("Put.doc key must be string, got {other:?}"),
411 };
412 doc.insert(keyword_from_wire(&key_str)?, data_type_from_value(v)?);
413 }
414 Ok(TxOp::Put(doc))
415 }
416 "add" => {
417 let entity = entity_ref_from_value(take_field(&mut map, "entity")?)?;
418 let attribute = keyword_from_wire(&take_string(&mut map, "attr")?)?;
419 let value = take_data_type(&mut map, "value")?;
420 Ok(TxOp::Add {
421 entity,
422 attribute,
423 value,
424 })
425 }
426 "retract" => {
427 let entity = entity_ref_from_value(take_field(&mut map, "entity")?)?;
428 let attribute = keyword_from_wire(&take_string(&mut map, "attr")?)?;
429 let value = take_data_type(&mut map, "value")?;
430 Ok(TxOp::Retract {
431 entity,
432 attribute,
433 value,
434 })
435 }
436 "delete" => Ok(TxOp::Delete(entity_ref_from_value(take_field(
437 &mut map, "entity",
438 )?)?)),
439 "erase" => Ok(TxOp::Erase(entity_ref_from_value(take_field(
440 &mut map, "entity",
441 )?)?)),
442 other => bail!("unknown TxOp kind: {other:?}"),
443 }
444}
445
446pub fn write_query_arg<W: Write>(w: &mut W, arg: &QueryArg) -> Result<()> {
451 match arg {
452 QueryArg::Scalar(dt) => {
453 rmp::encode::write_map_len(w, 2)?;
454 write_str_field(w, "kind", "scalar")?;
455 rmp::encode::write_str(w, "value")?;
456 write_data_type(w, dt)?;
457 }
458 QueryArg::Collection(items) => {
459 rmp::encode::write_map_len(w, 2)?;
460 write_str_field(w, "kind", "collection")?;
461 rmp::encode::write_str(w, "values")?;
462 rmp::encode::write_array_len(w, items.len() as u32)?;
463 for item in items {
464 write_data_type(w, item)?;
465 }
466 }
467 QueryArg::Tuple(items) => {
468 rmp::encode::write_map_len(w, 2)?;
469 write_str_field(w, "kind", "tuple")?;
470 rmp::encode::write_str(w, "values")?;
471 rmp::encode::write_array_len(w, items.len() as u32)?;
472 for item in items {
473 write_data_type(w, item)?;
474 }
475 }
476 QueryArg::Relation(rows) => {
477 rmp::encode::write_map_len(w, 2)?;
478 write_str_field(w, "kind", "relation")?;
479 rmp::encode::write_str(w, "rows")?;
480 rmp::encode::write_array_len(w, rows.len() as u32)?;
481 for row in rows {
482 rmp::encode::write_array_len(w, row.len() as u32)?;
483 for v in row {
484 write_data_type(w, v)?;
485 }
486 }
487 }
488 }
489 Ok(())
490}
491
492pub fn query_arg_from_value(v: Value) -> Result<QueryArg> {
493 let mut map = map_from_value(v)?;
494 let kind = take_string(&mut map, "kind")?;
495 match kind.as_str() {
496 "scalar" => Ok(QueryArg::Scalar(take_data_type(&mut map, "value")?)),
497 "collection" => Ok(QueryArg::Collection(take_data_type_array(
498 &mut map, "values",
499 )?)),
500 "tuple" => Ok(QueryArg::Tuple(take_data_type_array(&mut map, "values")?)),
501 "relation" => {
502 let rows_value = take_field(&mut map, "rows")?;
503 let rows = match rows_value {
504 Value::Array(arr) => arr,
505 other => bail!("Relation.rows must be an array, got {other:?}"),
506 };
507 let mut out = Vec::with_capacity(rows.len());
508 for row in rows {
509 let row = match row {
510 Value::Array(arr) => arr,
511 other => bail!("Relation row must be an array, got {other:?}"),
512 };
513 let mut typed = Vec::with_capacity(row.len());
514 for v in row {
515 typed.push(data_type_from_value(v)?);
516 }
517 out.push(typed);
518 }
519 Ok(QueryArg::Relation(out))
520 }
521 other => bail!("unknown QueryArg kind: {other:?}"),
522 }
523}
524
525fn take_data_type_array(map: &mut BTreeMap<String, Value>, name: &str) -> Result<Vec<DataType>> {
526 let v = take_field(map, name)?;
527 let arr = match v {
528 Value::Array(arr) => arr,
529 other => bail!("field {name:?} expected array, got {other:?}"),
530 };
531 let mut out = Vec::with_capacity(arr.len());
532 for item in arr {
533 out.push(data_type_from_value(item)?);
534 }
535 Ok(out)
536}
537
538fn take_optional_string(map: &mut BTreeMap<String, Value>, name: &str) -> Result<Option<String>> {
539 let Some(v) = map.remove(name) else {
540 return Ok(None);
541 };
542 match v {
543 Value::Nil => Ok(None),
544 Value::String(s) => s
545 .into_str()
546 .map(Some)
547 .ok_or_else(|| anyhow!("field {name:?} is not valid UTF-8")),
548 other => bail!("field {name:?} expected string or nil, got {other:?}"),
549 }
550}
551
552fn take_optional_i64(map: &mut BTreeMap<String, Value>, name: &str) -> Result<Option<i64>> {
553 let Some(v) = map.remove(name) else {
554 return Ok(None);
555 };
556 match v {
557 Value::Nil => Ok(None),
558 Value::Integer(n) => n
559 .as_i64()
560 .map(Some)
561 .ok_or_else(|| anyhow!("field {name:?} integer out of i64 range")),
562 other => bail!("field {name:?} expected integer or nil, got {other:?}"),
563 }
564}
565
566fn take_optional_timestamp(
567 map: &mut BTreeMap<String, Value>,
568 name: &str,
569) -> Result<Option<DateTime<Utc>>> {
570 let Some(v) = map.remove(name) else {
571 return Ok(None);
572 };
573 match v {
574 Value::Nil => Ok(None),
575 Value::Ext(EXT_TIMESTAMP, payload) => Ok(Some(read_timestamp(&payload)?)),
576 other => bail!("field {name:?} expected timestamp ext or nil, got {other:?}"),
577 }
578}
579
580fn take_timestamp(map: &mut BTreeMap<String, Value>, name: &str) -> Result<DateTime<Utc>> {
581 match take_field(map, name)? {
582 Value::Ext(EXT_TIMESTAMP, payload) => read_timestamp(&payload),
583 other => bail!("field {name:?} expected timestamp ext, got {other:?}"),
584 }
585}
586
587fn write_optional_string<W: Write>(w: &mut W, opt: &Option<String>) -> Result<()> {
588 match opt {
589 Some(s) => {
590 rmp::encode::write_str(w, s)?;
591 }
592 None => {
593 rmp::encode::write_nil(w)?;
594 }
595 }
596 Ok(())
597}
598
599fn write_optional_i64<W: Write>(w: &mut W, opt: Option<i64>) -> Result<()> {
600 match opt {
601 Some(v) => {
602 rmp::encode::write_sint(w, v)?;
603 }
604 None => {
605 rmp::encode::write_nil(w)?;
606 }
607 }
608 Ok(())
609}
610
611fn write_optional_timestamp<W: Write>(w: &mut W, opt: Option<DateTime<Utc>>) -> Result<()> {
612 match opt {
613 Some(t) => write_timestamp(w, &t)?,
614 None => {
615 rmp::encode::write_nil(w)?;
616 }
617 }
618 Ok(())
619}
620
621fn read_body_value(data: &[u8]) -> Result<Value> {
622 let mut cursor = data;
623 let v =
624 rmpv::decode::read_value(&mut cursor).map_err(|e| anyhow!("msgpack decode error: {e}"))?;
625 if !cursor.is_empty() {
626 bail!("trailing bytes after msgpack body");
627 }
628 Ok(v)
629}
630
/// Body of an open-database request.
///
/// Encoded as a two-entry map with keys `"tx_id"` and `"system_time"`. Both
/// are optional: `None` is written as nil, and an absent or nil entry decodes
/// to `None`.
#[derive(Debug, Clone, PartialEq)]
pub struct OpenDbRequest {
    /// Optional transaction id.
    pub tx_id: Option<i64>,
    /// Optional system time, carried as a timestamp extension value.
    pub system_time: Option<DateTime<Utc>>,
}
640
/// Body of a query request.
///
/// Encoded as a three-entry map with keys `"tx_key"`, `"query"` and `"args"`.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryRequest {
    /// Transaction key, encoded as a nested map.
    pub tx_key: TxKey,
    /// Query text, encoded as a string.
    pub query: String,
    /// Query arguments, encoded as an array of tagged maps.
    pub args: Vec<QueryArg>,
}
647
/// Body of a query response.
///
/// Encoded as a two-entry map with keys `"columns"` and `"rows"`.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryResponse {
    /// Column descriptions, encoded as an array of maps.
    pub columns: Vec<ColumnDescription>,
    /// Result rows, encoded as an array of arrays of values.
    pub rows: Vec<Vec<DataType>>,
}
653
/// Body of an execute request.
///
/// Encoded as a one-entry map with key `"ops"`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExecuteRequest {
    /// Transaction operations, encoded as an array of tagged maps.
    pub ops: Vec<TxOp>,
}
658
/// Body of a transaction-result response.
///
/// Encoded as a four-entry map with keys `"status"`, `"tx_id"`,
/// `"system_time"` and `"error_message"`.
#[derive(Debug, Clone, PartialEq)]
pub struct TxResultResponse {
    /// Status byte; decoding rejects values outside the `u8` range.
    pub status: u8,
    /// Transaction id.
    pub tx_id: i64,
    /// System time, carried as a timestamp extension value.
    pub system_time: DateTime<Utc>,
    /// Optional error message; `None` is written as nil.
    pub error_message: Option<String>,
}
666
/// Body of an error response.
///
/// Encoded as a five-entry map with keys `"severity"`, `"code"`,
/// `"message"`, `"detail"` and `"hint"`.
#[derive(Debug, Clone, PartialEq)]
pub struct ErrorResponseBody {
    /// Severity byte. Only `b'E'` and `b'F'` can be encoded; they travel on
    /// the wire as the strings `"E"` and `"F"`.
    pub severity: u8,
    /// Error code; decoding rejects values outside the `u16` range.
    pub code: u16,
    /// Error message text.
    pub message: String,
    /// Optional detail text; `None` is written as nil.
    pub detail: Option<String>,
    /// Optional hint text; `None` is written as nil.
    pub hint: Option<String>,
}
675
676pub fn encode_open_db_request(req: &OpenDbRequest) -> Result<Vec<u8>> {
678 let mut buf = Vec::new();
679 rmp::encode::write_map_len(&mut buf, 2)?;
680 rmp::encode::write_str(&mut buf, "tx_id")?;
681 write_optional_i64(&mut buf, req.tx_id)?;
682 rmp::encode::write_str(&mut buf, "system_time")?;
683 write_optional_timestamp(&mut buf, req.system_time)?;
684 Ok(buf)
685}
686
687pub fn decode_open_db_request(data: &[u8]) -> Result<OpenDbRequest> {
688 let mut map = map_from_value(read_body_value(data)?)?;
689 let tx_id = take_optional_i64(&mut map, "tx_id")?;
690 let system_time = take_optional_timestamp(&mut map, "system_time")?;
691 Ok(OpenDbRequest { tx_id, system_time })
692}
693
694pub fn encode_tx_key(tx_key: &TxKey) -> Result<Vec<u8>> {
697 let mut buf = Vec::new();
698 write_tx_key(&mut buf, tx_key)?;
699 Ok(buf)
700}
701
702pub fn decode_tx_key(data: &[u8]) -> Result<TxKey> {
703 tx_key_from_value(read_body_value(data)?)
704}
705
706fn write_tx_key<W: Write>(w: &mut W, tx_key: &TxKey) -> Result<()> {
707 rmp::encode::write_map_len(w, 2)?;
708 rmp::encode::write_str(w, "tx_id")?;
709 rmp::encode::write_sint(w, tx_key.tx_id)?;
710 rmp::encode::write_str(w, "system_time")?;
711 write_timestamp(w, &tx_key.system_time)?;
712 Ok(())
713}
714
715fn tx_key_from_value(value: Value) -> Result<TxKey> {
716 let mut map = map_from_value(value)?;
717 Ok(TxKey {
718 tx_id: take_i64(&mut map, "tx_id")?,
719 system_time: take_timestamp(&mut map, "system_time")?,
720 })
721}
722
723pub fn encode_query_request(req: &QueryRequest) -> Result<Vec<u8>> {
725 let mut buf = Vec::new();
726 rmp::encode::write_map_len(&mut buf, 3)?;
727 rmp::encode::write_str(&mut buf, "tx_key")?;
728 write_tx_key(&mut buf, &req.tx_key)?;
729 rmp::encode::write_str(&mut buf, "query")?;
730 rmp::encode::write_str(&mut buf, &req.query)?;
731 rmp::encode::write_str(&mut buf, "args")?;
732 rmp::encode::write_array_len(&mut buf, req.args.len() as u32)?;
733 for arg in &req.args {
734 write_query_arg(&mut buf, arg)?;
735 }
736 Ok(buf)
737}
738
739pub fn decode_query_request(data: &[u8]) -> Result<QueryRequest> {
740 let mut map = map_from_value(read_body_value(data)?)?;
741 let tx_key = tx_key_from_value(take_field(&mut map, "tx_key")?)?;
742 let query = take_string(&mut map, "query")?;
743 let arr = match take_field(&mut map, "args")? {
744 Value::Array(arr) => arr,
745 other => bail!("field \"args\" expected array, got {other:?}"),
746 };
747 let mut args = Vec::with_capacity(arr.len());
748 for item in arr {
749 args.push(query_arg_from_value(item)?);
750 }
751 Ok(QueryRequest {
752 tx_key,
753 query,
754 args,
755 })
756}
757
758pub fn encode_query_response(resp: &QueryResponse) -> Result<Vec<u8>> {
760 let mut buf = Vec::new();
761 rmp::encode::write_map_len(&mut buf, 2)?;
762 rmp::encode::write_str(&mut buf, "columns")?;
763 rmp::encode::write_array_len(&mut buf, resp.columns.len() as u32)?;
764 for col in &resp.columns {
765 write_column_description(&mut buf, col)?;
766 }
767 rmp::encode::write_str(&mut buf, "rows")?;
768 rmp::encode::write_array_len(&mut buf, resp.rows.len() as u32)?;
769 for row in &resp.rows {
770 rmp::encode::write_array_len(&mut buf, row.len() as u32)?;
771 for v in row {
772 write_data_type(&mut buf, v)?;
773 }
774 }
775 Ok(buf)
776}
777
778pub fn decode_query_response(data: &[u8]) -> Result<QueryResponse> {
779 let mut map = map_from_value(read_body_value(data)?)?;
780 let cols_arr = match take_field(&mut map, "columns")? {
781 Value::Array(arr) => arr,
782 other => bail!("field \"columns\" expected array, got {other:?}"),
783 };
784 let mut columns = Vec::with_capacity(cols_arr.len());
785 for item in cols_arr {
786 columns.push(column_description_from_value(item)?);
787 }
788 let rows_arr = match take_field(&mut map, "rows")? {
789 Value::Array(arr) => arr,
790 other => bail!("field \"rows\" expected array, got {other:?}"),
791 };
792 let mut rows = Vec::with_capacity(rows_arr.len());
793 for row in rows_arr {
794 let row_arr = match row {
795 Value::Array(arr) => arr,
796 other => bail!("row expected array, got {other:?}"),
797 };
798 let mut typed = Vec::with_capacity(row_arr.len());
799 for v in row_arr {
800 typed.push(data_type_from_value(v)?);
801 }
802 rows.push(typed);
803 }
804 Ok(QueryResponse { columns, rows })
805}
806
807fn write_column_description<W: Write>(w: &mut W, col: &ColumnDescription) -> Result<()> {
808 let len = if col.members.is_some() { 3 } else { 2 };
809 rmp::encode::write_map_len(w, len)?;
810 rmp::encode::write_str(w, "name")?;
811 rmp::encode::write_str(w, &col.name)?;
812 rmp::encode::write_str(w, "type")?;
813 rmp::encode::write_uint(w, col.data_type as u64)?;
814 if let Some(members) = &col.members {
815 rmp::encode::write_str(w, "members")?;
816 rmp::encode::write_array_len(w, members.len() as u32)?;
817 for m in members {
818 rmp::encode::write_uint(w, *m as u64)?;
819 }
820 }
821 Ok(())
822}
823
824fn column_description_from_value(v: Value) -> Result<ColumnDescription> {
825 let mut map = map_from_value(v)?;
826 let name = take_string(&mut map, "name")?;
827 let data_type_i64 = take_i64(&mut map, "type")?;
828 if data_type_i64 < 0 || data_type_i64 > u8::MAX as i64 {
829 bail!("column type tag out of u8 range: {data_type_i64}");
830 }
831 let data_type = data_type_i64 as u8;
832 let members = match map.remove("members") {
833 None => None,
834 Some(Value::Nil) => None,
835 Some(Value::Array(arr)) => {
836 let mut tags = Vec::with_capacity(arr.len());
837 for item in arr {
838 let tag = match item {
839 Value::Integer(n) => n
840 .as_u64()
841 .ok_or_else(|| anyhow!("union member tag not unsigned"))?,
842 other => bail!("union member must be integer, got {other:?}"),
843 };
844 if tag > u8::MAX as u64 {
845 bail!("union member tag out of u8 range: {tag}");
846 }
847 tags.push(tag as u8);
848 }
849 Some(tags)
850 }
851 Some(other) => bail!("\"members\" expected array, got {other:?}"),
852 };
853 Ok(ColumnDescription {
854 name,
855 data_type,
856 members,
857 })
858}
859
860pub fn encode_execute_request(req: &ExecuteRequest) -> Result<Vec<u8>> {
862 let mut buf = Vec::new();
863 rmp::encode::write_map_len(&mut buf, 1)?;
864 rmp::encode::write_str(&mut buf, "ops")?;
865 rmp::encode::write_array_len(&mut buf, req.ops.len() as u32)?;
866 for op in &req.ops {
867 write_tx_op(&mut buf, op)?;
868 }
869 Ok(buf)
870}
871
872pub fn decode_execute_request(data: &[u8]) -> Result<ExecuteRequest> {
873 let mut map = map_from_value(read_body_value(data)?)?;
874 let arr = match take_field(&mut map, "ops")? {
875 Value::Array(arr) => arr,
876 other => bail!("field \"ops\" expected array, got {other:?}"),
877 };
878 let mut ops = Vec::with_capacity(arr.len());
879 for item in arr {
880 ops.push(tx_op_from_value(item)?);
881 }
882 Ok(ExecuteRequest { ops })
883}
884
885pub fn encode_tx_result_response(resp: &TxResultResponse) -> Result<Vec<u8>> {
887 let mut buf = Vec::new();
888 rmp::encode::write_map_len(&mut buf, 4)?;
889 rmp::encode::write_str(&mut buf, "status")?;
890 rmp::encode::write_uint(&mut buf, resp.status as u64)?;
891 rmp::encode::write_str(&mut buf, "tx_id")?;
892 rmp::encode::write_sint(&mut buf, resp.tx_id)?;
893 rmp::encode::write_str(&mut buf, "system_time")?;
894 write_timestamp(&mut buf, &resp.system_time)?;
895 rmp::encode::write_str(&mut buf, "error_message")?;
896 write_optional_string(&mut buf, &resp.error_message)?;
897 Ok(buf)
898}
899
900pub fn decode_tx_result_response(data: &[u8]) -> Result<TxResultResponse> {
901 let mut map = map_from_value(read_body_value(data)?)?;
902 let status_i64 = take_i64(&mut map, "status")?;
903 if status_i64 < 0 || status_i64 > u8::MAX as i64 {
904 bail!("status out of u8 range: {status_i64}");
905 }
906 let tx_id = take_i64(&mut map, "tx_id")?;
907 let system_time = take_timestamp(&mut map, "system_time")?;
908 let error_message = take_optional_string(&mut map, "error_message")?;
909 Ok(TxResultResponse {
910 status: status_i64 as u8,
911 tx_id,
912 system_time,
913 error_message,
914 })
915}
916
917pub fn encode_error_body(resp: &ErrorResponseBody) -> Result<Vec<u8>> {
919 let severity_str = match resp.severity {
920 b'E' => "E",
921 b'F' => "F",
922 other => bail!("invalid severity byte: {other:#x}"),
923 };
924 let mut buf = Vec::new();
925 rmp::encode::write_map_len(&mut buf, 5)?;
926 rmp::encode::write_str(&mut buf, "severity")?;
927 rmp::encode::write_str(&mut buf, severity_str)?;
928 rmp::encode::write_str(&mut buf, "code")?;
929 rmp::encode::write_uint(&mut buf, resp.code as u64)?;
930 rmp::encode::write_str(&mut buf, "message")?;
931 rmp::encode::write_str(&mut buf, &resp.message)?;
932 rmp::encode::write_str(&mut buf, "detail")?;
933 write_optional_string(&mut buf, &resp.detail)?;
934 rmp::encode::write_str(&mut buf, "hint")?;
935 write_optional_string(&mut buf, &resp.hint)?;
936 Ok(buf)
937}
938
939pub fn decode_error_body(data: &[u8]) -> Result<ErrorResponseBody> {
940 let mut map = map_from_value(read_body_value(data)?)?;
941 let severity_str = take_string(&mut map, "severity")?;
942 let severity = match severity_str.as_str() {
943 "E" => b'E',
944 "F" => b'F',
945 other => bail!("invalid severity string: {other:?}"),
946 };
947 let code_i64 = take_i64(&mut map, "code")?;
948 if code_i64 < 0 || code_i64 > u16::MAX as i64 {
949 bail!("code out of u16 range: {code_i64}");
950 }
951 let message = take_string(&mut map, "message")?;
952 let detail = take_optional_string(&mut map, "detail")?;
953 let hint = take_optional_string(&mut map, "hint")?;
954 Ok(ErrorResponseBody {
955 severity,
956 code: code_i64 as u16,
957 message,
958 detail,
959 hint,
960 })
961}
962
/// Body of a subscribe request.
///
/// Encoded as a three-entry map with keys `"tx_key"`, `"query"` and `"args"`.
#[derive(Debug, Clone, PartialEq)]
pub struct SubscribeRequest {
    /// Optional transaction key; `None` is written as nil, and an absent or
    /// nil entry decodes to `None`.
    pub tx_key: Option<TxKey>,
    /// Query text, encoded as a string.
    pub query: String,
    /// Query arguments, encoded as an array of tagged maps.
    pub args: Vec<QueryArg>,
}
973
/// A frame sent on a subscription, encoded as a map tagged by a `"kind"`
/// entry (`"open"`, `"delta"` or `"error"`).
#[derive(Debug, Clone, PartialEq)]
pub enum SubscriptionFrame {
    /// Encoded with `kind = "open"`, a `"tx_key"` and a `"columns"` array.
    Open {
        /// Transaction key carried by the frame.
        tx_key: TxKey,
        /// Column descriptions carried by the frame.
        columns: Vec<ColumnDescription>,
    },
    /// Encoded with `kind = "delta"`, a `"tx_key"` and a `"rows"` array.
    Delta {
        /// Transaction key carried by the frame.
        tx_key: TxKey,
        /// Rows, each encoded as a two-element array `[values, weight]`.
        rows: Vec<(Vec<DataType>, i64)>,
    },
    /// Encoded with `kind = "error"` followed by the error body's fields.
    Error(ErrorResponseBody),
}
991
992fn write_optional_tx_key<W: Write>(w: &mut W, tx_key: &Option<TxKey>) -> Result<()> {
993 match tx_key {
994 Some(k) => write_tx_key(w, k)?,
995 None => rmp::encode::write_nil(w)?,
996 }
997 Ok(())
998}
999
1000fn take_optional_tx_key(map: &mut BTreeMap<String, Value>, name: &str) -> Result<Option<TxKey>> {
1001 match map.remove(name) {
1002 None | Some(Value::Nil) => Ok(None),
1003 Some(v) => Ok(Some(tx_key_from_value(v)?)),
1004 }
1005}
1006
1007fn delta_row_from_value(v: Value) -> Result<(Vec<DataType>, i64)> {
1008 let entry = match v {
1009 Value::Array(arr) => arr,
1010 other => bail!("delta row expected [values, weight] array, got {other:?}"),
1011 };
1012 if entry.len() != 2 {
1013 bail!("delta row expected 2 elements, got {}", entry.len());
1014 }
1015 let mut it = entry.into_iter();
1016 let values = match it.next().unwrap() {
1017 Value::Array(arr) => {
1018 let mut out = Vec::with_capacity(arr.len());
1019 for v in arr {
1020 out.push(data_type_from_value(v)?);
1021 }
1022 out
1023 }
1024 other => bail!("delta row values expected array, got {other:?}"),
1025 };
1026 let weight = match it.next().unwrap() {
1027 Value::Integer(n) => n
1028 .as_i64()
1029 .ok_or_else(|| anyhow!("delta row weight out of i64 range"))?,
1030 other => bail!("delta row weight expected integer, got {other:?}"),
1031 };
1032 Ok((values, weight))
1033}
1034
1035pub fn encode_subscribe_request(req: &SubscribeRequest) -> Result<Vec<u8>> {
1037 let mut buf = Vec::new();
1038 rmp::encode::write_map_len(&mut buf, 3)?;
1039 rmp::encode::write_str(&mut buf, "tx_key")?;
1040 write_optional_tx_key(&mut buf, &req.tx_key)?;
1041 rmp::encode::write_str(&mut buf, "query")?;
1042 rmp::encode::write_str(&mut buf, &req.query)?;
1043 rmp::encode::write_str(&mut buf, "args")?;
1044 rmp::encode::write_array_len(&mut buf, req.args.len() as u32)?;
1045 for arg in &req.args {
1046 write_query_arg(&mut buf, arg)?;
1047 }
1048 Ok(buf)
1049}
1050
1051pub fn decode_subscribe_request(data: &[u8]) -> Result<SubscribeRequest> {
1052 let mut map = map_from_value(read_body_value(data)?)?;
1053 let tx_key = take_optional_tx_key(&mut map, "tx_key")?;
1054 let query = take_string(&mut map, "query")?;
1055 let arr = match take_field(&mut map, "args")? {
1056 Value::Array(arr) => arr,
1057 other => bail!("field \"args\" expected array, got {other:?}"),
1058 };
1059 let mut args = Vec::with_capacity(arr.len());
1060 for item in arr {
1061 args.push(query_arg_from_value(item)?);
1062 }
1063 Ok(SubscribeRequest {
1064 tx_key,
1065 query,
1066 args,
1067 })
1068}
1069
1070pub fn encode_subscription_frame(frame: &SubscriptionFrame) -> Result<Vec<u8>> {
1072 let mut buf = Vec::new();
1073 match frame {
1074 SubscriptionFrame::Open { tx_key, columns } => {
1075 rmp::encode::write_map_len(&mut buf, 3)?;
1076 write_str_field(&mut buf, "kind", "open")?;
1077 rmp::encode::write_str(&mut buf, "tx_key")?;
1078 write_tx_key(&mut buf, tx_key)?;
1079 rmp::encode::write_str(&mut buf, "columns")?;
1080 rmp::encode::write_array_len(&mut buf, columns.len() as u32)?;
1081 for col in columns {
1082 write_column_description(&mut buf, col)?;
1083 }
1084 }
1085 SubscriptionFrame::Delta { tx_key, rows } => {
1086 rmp::encode::write_map_len(&mut buf, 3)?;
1087 write_str_field(&mut buf, "kind", "delta")?;
1088 rmp::encode::write_str(&mut buf, "tx_key")?;
1089 write_tx_key(&mut buf, tx_key)?;
1090 rmp::encode::write_str(&mut buf, "rows")?;
1091 rmp::encode::write_array_len(&mut buf, rows.len() as u32)?;
1092 for (values, weight) in rows {
1093 rmp::encode::write_array_len(&mut buf, 2)?;
1094 rmp::encode::write_array_len(&mut buf, values.len() as u32)?;
1095 for v in values {
1096 write_data_type(&mut buf, v)?;
1097 }
1098 rmp::encode::write_sint(&mut buf, *weight)?;
1099 }
1100 }
1101 SubscriptionFrame::Error(err) => {
1102 let severity_str = match err.severity {
1103 b'E' => "E",
1104 b'F' => "F",
1105 other => bail!("invalid severity byte: {other:#x}"),
1106 };
1107 rmp::encode::write_map_len(&mut buf, 6)?;
1108 write_str_field(&mut buf, "kind", "error")?;
1109 rmp::encode::write_str(&mut buf, "severity")?;
1110 rmp::encode::write_str(&mut buf, severity_str)?;
1111 rmp::encode::write_str(&mut buf, "code")?;
1112 rmp::encode::write_uint(&mut buf, err.code as u64)?;
1113 rmp::encode::write_str(&mut buf, "message")?;
1114 rmp::encode::write_str(&mut buf, &err.message)?;
1115 rmp::encode::write_str(&mut buf, "detail")?;
1116 write_optional_string(&mut buf, &err.detail)?;
1117 rmp::encode::write_str(&mut buf, "hint")?;
1118 write_optional_string(&mut buf, &err.hint)?;
1119 }
1120 }
1121 Ok(buf)
1122}
1123
1124pub fn decode_subscription_frame(data: &[u8]) -> Result<SubscriptionFrame> {
1126 subscription_frame_from_value(read_body_value(data)?)
1127}
1128
1129pub fn subscription_frame_from_value(v: Value) -> Result<SubscriptionFrame> {
1132 let mut map = map_from_value(v)?;
1133 let kind = take_string(&mut map, "kind")?;
1134 match kind.as_str() {
1135 "open" => {
1136 let tx_key = tx_key_from_value(take_field(&mut map, "tx_key")?)?;
1137 let cols_arr = match take_field(&mut map, "columns")? {
1138 Value::Array(arr) => arr,
1139 other => bail!("field \"columns\" expected array, got {other:?}"),
1140 };
1141 let mut columns = Vec::with_capacity(cols_arr.len());
1142 for item in cols_arr {
1143 columns.push(column_description_from_value(item)?);
1144 }
1145 Ok(SubscriptionFrame::Open { tx_key, columns })
1146 }
1147 "delta" => {
1148 let tx_key = tx_key_from_value(take_field(&mut map, "tx_key")?)?;
1149 let rows_arr = match take_field(&mut map, "rows")? {
1150 Value::Array(arr) => arr,
1151 other => bail!("field \"rows\" expected array, got {other:?}"),
1152 };
1153 let mut rows = Vec::with_capacity(rows_arr.len());
1154 for entry in rows_arr {
1155 rows.push(delta_row_from_value(entry)?);
1156 }
1157 Ok(SubscriptionFrame::Delta { tx_key, rows })
1158 }
1159 "error" => {
1160 let severity_str = take_string(&mut map, "severity")?;
1161 let severity = match severity_str.as_str() {
1162 "E" => b'E',
1163 "F" => b'F',
1164 other => bail!("invalid severity string: {other:?}"),
1165 };
1166 let code_i64 = take_i64(&mut map, "code")?;
1167 if code_i64 < 0 || code_i64 > u16::MAX as i64 {
1168 bail!("code out of u16 range: {code_i64}");
1169 }
1170 let message = take_string(&mut map, "message")?;
1171 let detail = take_optional_string(&mut map, "detail")?;
1172 let hint = take_optional_string(&mut map, "hint")?;
1173 Ok(SubscriptionFrame::Error(ErrorResponseBody {
1174 severity,
1175 code: code_i64 as u16,
1176 message,
1177 detail,
1178 hint,
1179 }))
1180 }
1181 other => bail!("unknown subscription frame kind: {other}"),
1182 }
1183}
1184
// Unit tests for the wire codec: round-trips for scalar values, tagged
// unions and whole message bodies, followed by checks that malformed input
// is rejected.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use edn::kw;

    // Fixed transaction key shared by the subscription tests below.
    fn sample_tx_key() -> TxKey {
        TxKey {
            tx_id: 7,
            system_time: Utc.timestamp_opt(1_700_000_000, 0).unwrap(),
        }
    }

    // A subscribe request survives encode -> decode both without and with a
    // transaction key.
    #[test]
    fn subscribe_request_round_trip() {
        for db in [None, Some(sample_tx_key())] {
            let req = SubscribeRequest {
                tx_key: db,
                query: "[:find ?n :where [?e :name ?n]]".to_string(),
                args: vec![QueryArg::Scalar(DataType::Long(42))],
            };
            let bytes = encode_subscribe_request(&req).expect("encode");
            assert_eq!(decode_subscribe_request(&bytes).expect("decode"), req);
        }
    }

    // An `Open` frame with a single column round-trips unchanged.
    #[test]
    fn open_frame_round_trip() {
        let frame = SubscriptionFrame::Open {
            tx_key: sample_tx_key(),
            columns: vec![ColumnDescription {
                name: "n".to_string(),
                data_type: 255,
                members: None,
            }],
        };
        let bytes = encode_subscription_frame(&frame).expect("encode");
        assert_eq!(decode_subscription_frame(&bytes).expect("decode"), frame);
    }

    // A `Delta` frame round-trips, including a negative row weight.
    #[test]
    fn delta_frame_round_trip() {
        let frame = SubscriptionFrame::Delta {
            tx_key: sample_tx_key(),
            rows: vec![
                (vec![DataType::String("Ivan".to_string())], 1),
                (vec![DataType::String("Petr".to_string())], -2),
            ],
        };
        let bytes = encode_subscription_frame(&frame).expect("encode");
        assert_eq!(decode_subscription_frame(&bytes).expect("decode"), frame);
    }

    // An `Error` frame round-trips with one optional field set and one unset.
    #[test]
    fn error_frame_round_trip() {
        let frame = SubscriptionFrame::Error(ErrorResponseBody {
            severity: b'F',
            code: 4000,
            message: "boom".to_string(),
            detail: Some("detail".to_string()),
            hint: None,
        });
        let bytes = encode_subscription_frame(&frame).expect("encode");
        assert_eq!(decode_subscription_frame(&bytes).expect("decode"), frame);
    }

    // A frame whose "kind" is not open/delta/error is rejected with a message
    // naming the offending kind.
    #[test]
    fn unknown_subscription_frame_kind_errors() {
        let mut buf = Vec::new();
        rmp::encode::write_map_len(&mut buf, 1).unwrap();
        write_str_field(&mut buf, "kind", "heartbeat").unwrap();
        let err = decode_subscription_frame(&buf).unwrap_err();
        assert!(err
            .to_string()
            .contains("unknown subscription frame kind: heartbeat"));
    }

    // Hand-builds a delta frame with "rows" first and "kind" in the middle,
    // i.e. not the order the encoder emits, and checks it still decodes.
    #[test]
    fn delta_frame_decodes_regardless_of_key_order() {
        let mut buf = Vec::new();
        rmp::encode::write_map_len(&mut buf, 3).unwrap();
        rmp::encode::write_str(&mut buf, "rows").unwrap();
        // One row ...
        rmp::encode::write_array_len(&mut buf, 1).unwrap();
        // ... which is the pair [values, weight] ...
        rmp::encode::write_array_len(&mut buf, 2).unwrap();
        // ... with a single value and weight 1.
        rmp::encode::write_array_len(&mut buf, 1).unwrap();
        write_data_type(&mut buf, &DataType::Long(5)).unwrap();
        rmp::encode::write_sint(&mut buf, 1).unwrap();
        write_str_field(&mut buf, "kind", "delta").unwrap();
        rmp::encode::write_str(&mut buf, "tx_key").unwrap();
        write_tx_key(&mut buf, &sample_tx_key()).unwrap();
        assert_eq!(
            decode_subscription_frame(&buf).expect("decode"),
            SubscriptionFrame::Delta {
                tx_key: sample_tx_key(),
                rows: vec![(vec![DataType::Long(5)], 1)],
            }
        );
    }

    // Encodes a single `DataType`, decodes it again and asserts that the
    // decoder consumed every byte.
    fn round_trip(dt: &DataType) -> DataType {
        let mut buf = Vec::new();
        write_data_type(&mut buf, dt).expect("encode");
        let (decoded, rest) = read_data_type(&buf).expect("decode");
        assert!(rest.is_empty(), "unexpected trailing bytes");
        decoded
    }

    #[test]
    fn round_trip_boolean() {
        assert_eq!(
            round_trip(&DataType::Boolean(true)),
            DataType::Boolean(true)
        );
        assert_eq!(
            round_trip(&DataType::Boolean(false)),
            DataType::Boolean(false)
        );
    }

    // Covers zero, the i64/i32 extremes, and 127/128 and -32/-33, which
    // straddle the MessagePack positive and negative fixint ranges.
    #[test]
    fn round_trip_long() {
        for v in [
            0i64,
            1,
            -1,
            i64::MAX,
            i64::MIN,
            127,
            128,
            -32,
            -33,
            i32::MAX as i64,
            i32::MIN as i64,
        ] {
            assert_eq!(round_trip(&DataType::Long(v)), DataType::Long(v), "v = {v}");
        }
    }

    // Float and Double must keep their own width on the wire.
    #[test]
    fn round_trip_float_and_double_distinct() {
        let f = DataType::Float(1.5_f32);
        let d = DataType::Double(1.5_f64);
        assert_eq!(round_trip(&f), f);
        assert_eq!(round_trip(&d), d);
        let mut fb = Vec::new();
        write_data_type(&mut fb, &f).unwrap();
        let mut db = Vec::new();
        write_data_type(&mut db, &d).unwrap();
        // MessagePack float32 is a 1-byte marker plus 4 bytes; float64 is
        // a 1-byte marker plus 8 bytes.
        assert_eq!(fb.len(), 5);
        assert_eq!(db.len(), 9);
    }

    // Includes the empty string, multi-byte UTF-8, the empty byte vector and
    // a vector holding every possible byte value.
    #[test]
    fn round_trip_string_and_bytes() {
        for s in ["", "a", "hello", "𝄞 unicode 漢字"] {
            assert_eq!(
                round_trip(&DataType::String(s.into())),
                DataType::String(s.into())
            );
        }
        for b in [vec![], vec![0u8], vec![0, 1, 255], (0u8..=255).collect()] {
            assert_eq!(round_trip(&DataType::Bytes(b.clone())), DataType::Bytes(b));
        }
    }

    // Covers the i128 extremes and the first value above i64::MAX.
    #[test]
    fn round_trip_bigint() {
        for v in [0i128, 1, -1, i128::MAX, i128::MIN, i64::MAX as i128 + 1] {
            assert_eq!(
                round_trip(&DataType::BigInt(v)),
                DataType::BigInt(v),
                "v = {v}"
            );
        }
    }

    #[test]
    fn round_trip_uuid() {
        let u = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap();
        assert_eq!(round_trip(&DataType::Uuid(u)), DataType::Uuid(u));
        assert_eq!(
            round_trip(&DataType::Uuid(Uuid::nil())),
            DataType::Uuid(Uuid::nil())
        );
    }

    // Both plain and namespaced keywords round-trip.
    #[test]
    fn round_trip_keyword() {
        let plain = DataType::Keyword(kw!(:foo));
        let ns = DataType::Keyword(kw!(:person/name));
        assert_eq!(round_trip(&plain), plain);
        assert_eq!(round_trip(&ns), ns);
    }

    // One instant per branch of `write_timestamp`, plus a boundary case.
    #[test]
    fn round_trip_instant() {
        // Whole seconds within u32 range: 4-byte payload.
        let a = Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        // Non-zero nanoseconds with seconds below 2^34: 8-byte payload.
        let b = Utc.timestamp_opt(1_700_000_000, 123_456_789).unwrap();
        // Negative seconds: 12-byte payload.
        let c = Utc.timestamp_opt(-1_000_000_000, 500_000_000).unwrap();
        // Exactly 2^34 seconds: too large for both compact forms, so 12-byte.
        let d = Utc.timestamp_opt(1u64.wrapping_shl(34) as i64, 0).unwrap();
        for instant in [a, b, c, d] {
            assert_eq!(
                round_trip(&DataType::Instant(instant)),
                DataType::Instant(instant),
                "instant = {instant}"
            );
        }
    }

    // Empty, heterogeneous and nested vectors.
    #[test]
    fn round_trip_vector() {
        let empty = DataType::Vector(vec![]);
        assert_eq!(round_trip(&empty), empty);
        let v = DataType::Vector(vec![
            DataType::Long(1),
            DataType::String("two".into()),
            DataType::Boolean(true),
        ]);
        assert_eq!(round_trip(&v), v);
        let nested = DataType::Vector(vec![DataType::Vector(vec![DataType::Long(42)])]);
        assert_eq!(round_trip(&nested), nested);
    }

    // Populated, empty and nested maps.
    #[test]
    fn round_trip_map() {
        let mut m = BTreeMap::new();
        m.insert("a".to_string(), DataType::Long(1));
        m.insert("b".to_string(), DataType::String("x".into()));
        let dt = DataType::Map(m);
        assert_eq!(round_trip(&dt), dt);
        assert_eq!(
            round_trip(&DataType::Map(BTreeMap::new())),
            DataType::Map(BTreeMap::new())
        );
        let mut inner = BTreeMap::new();
        inner.insert("k".to_string(), DataType::Long(99));
        let mut outer = BTreeMap::new();
        outer.insert("nested".to_string(), DataType::Map(inner));
        let dt = DataType::Map(outer);
        assert_eq!(round_trip(&dt), dt);
    }

    // Decodes one generic MessagePack value and returns it together with the
    // bytes that were not consumed.
    fn read_value_from_buf(buf: &[u8]) -> (Value, &[u8]) {
        let mut cursor = buf;
        let v = rmpv::decode::read_value(&mut cursor).unwrap();
        (v, cursor)
    }

    // Encode -> generic value -> decode helper for `EntityRef`.
    fn round_trip_entity_ref(er: &EntityRef) -> EntityRef {
        let mut buf = Vec::new();
        write_entity_ref(&mut buf, er).unwrap();
        let (value, rest) = read_value_from_buf(&buf);
        assert!(rest.is_empty());
        entity_ref_from_value(value).unwrap()
    }

    // Encode -> generic value -> decode helper for `TxOp`.
    fn round_trip_tx_op(op: &TxOp) -> TxOp {
        let mut buf = Vec::new();
        write_tx_op(&mut buf, op).unwrap();
        let (value, rest) = read_value_from_buf(&buf);
        assert!(rest.is_empty());
        tx_op_from_value(value).unwrap()
    }

    // Encode -> generic value -> decode helper for `QueryArg`.
    fn round_trip_query_arg(arg: &QueryArg) -> QueryArg {
        let mut buf = Vec::new();
        write_query_arg(&mut buf, arg).unwrap();
        let (value, rest) = read_value_from_buf(&buf);
        assert!(rest.is_empty());
        query_arg_from_value(value).unwrap()
    }

    // One case per `EntityRef` variant used here, including a negative id and
    // both plain and namespaced idents.
    #[test]
    fn round_trip_entity_ref_variants() {
        let cases = vec![
            EntityRef::Id(42),
            EntityRef::Id(-1),
            EntityRef::TempId("tempid-1".into()),
            EntityRef::Ident(kw!(:person/name)),
            EntityRef::Ident(kw!(:plain)),
            EntityRef::LookupRef(kw!(:user/email), DataType::String("a@b.c".into())),
        ];
        for er in cases {
            assert_eq!(round_trip_entity_ref(&er), er);
        }
    }

    // Put, Add, Retract, Delete and Erase operations all round-trip.
    #[test]
    fn round_trip_tx_op_variants() {
        let cases = vec![
            TxOp::put(vec![
                (kw!(:db/id), DataType::Long(1)),
                (kw!(:person/name), DataType::String("alice".into())),
            ]),
            TxOp::Add {
                entity: EntityRef::Id(1),
                attribute: kw!(:person/age),
                value: DataType::Long(30),
            },
            TxOp::Retract {
                entity: EntityRef::Ident(kw!(:db/ident)),
                attribute: kw!(:db/doc),
                value: DataType::String("doc".into()),
            },
            TxOp::Delete(EntityRef::Id(99)),
            TxOp::Erase(EntityRef::Id(100)),
        ];
        for op in cases {
            assert_eq!(round_trip_tx_op(&op), op);
        }
    }

    // Scalar, Collection, Tuple and Relation arguments all round-trip.
    #[test]
    fn round_trip_query_arg_variants() {
        let cases = vec![
            QueryArg::Scalar(DataType::String("alice".into())),
            QueryArg::Scalar(DataType::Long(7)),
            QueryArg::Collection(vec![
                DataType::Long(1),
                DataType::Long(2),
                DataType::Long(3),
            ]),
            QueryArg::Tuple(vec![DataType::String("x".into()), DataType::Long(99)]),
            QueryArg::Relation(vec![
                vec![DataType::Long(1), DataType::String("a".into())],
                vec![DataType::Long(2), DataType::String("b".into())],
            ]),
        ];
        for arg in cases {
            assert_eq!(round_trip_query_arg(&arg), arg);
        }
    }

    // Open-db requests round-trip with both fields absent, both present, and
    // with a negative tx id next to a sub-second timestamp.
    #[test]
    fn round_trip_open_db_request_bodies() {
        for (tx_id, system_time) in [
            (None, None),
            (
                Some(42i64),
                Some(Utc.timestamp_opt(1_700_000_000, 0).unwrap()),
            ),
            (Some(-1), Some(Utc.timestamp_opt(0, 1).unwrap())),
        ] {
            let request = OpenDbRequest { tx_id, system_time };
            let buf = encode_open_db_request(&request).unwrap();
            assert_eq!(decode_open_db_request(&buf).unwrap(), request);
        }
    }

    #[test]
    fn round_trip_tx_key_body() {
        let tx_key = TxKey {
            tx_id: 7,
            system_time: Utc.timestamp_opt(1_700_000_001, 0).unwrap(),
        };
        let buf = encode_tx_key(&tx_key).unwrap();
        assert_eq!(decode_tx_key(&buf).unwrap(), tx_key);
    }

    // A query request with a scalar and a collection argument round-trips.
    #[test]
    fn round_trip_query_request_body() {
        let q = "{:find [?n] :where [[?e :name ?n]]}";
        let db = TxKey {
            tx_id: 42,
            system_time: Utc.timestamp_opt(1_700_000_002, 0).unwrap(),
        };
        let args = vec![
            QueryArg::Scalar(DataType::Long(7)),
            QueryArg::Collection(vec![
                DataType::String("a".into()),
                DataType::String("b".into()),
            ]),
        ];
        let request = QueryRequest {
            tx_key: db,
            query: q.into(),
            args,
        };
        let buf = encode_query_request(&request).unwrap();
        assert_eq!(decode_query_request(&buf).unwrap(), request);
    }

    // A query response round-trips with one column that has no members and
    // one that does, and with rows whose second cell differs in type.
    #[test]
    fn round_trip_query_response_body() {
        let columns = vec![
            ColumnDescription {
                name: "?e".into(),
                data_type: 7,
                members: None,
            },
            ColumnDescription {
                name: "?val".into(),
                data_type: 127,
                members: Some(vec![7, 9]),
            },
        ];
        let rows = vec![
            vec![DataType::Long(1), DataType::String("x".into())],
            vec![DataType::Long(2), DataType::Long(99)],
        ];
        let response = QueryResponse { columns, rows };
        let buf = encode_query_response(&response).unwrap();
        assert_eq!(decode_query_response(&buf).unwrap(), response);
    }

    #[test]
    fn round_trip_execute_request_body() {
        let ops = vec![
            TxOp::put(vec![(kw!(:name), DataType::String("alice".into()))]),
            TxOp::Add {
                entity: EntityRef::Id(42),
                attribute: kw!(:age),
                value: DataType::Long(30),
            },
        ];
        let request = ExecuteRequest { ops };
        let buf = encode_execute_request(&request).unwrap();
        assert_eq!(decode_execute_request(&buf).unwrap(), request);
    }

    // Round-trips a result with status 0 and no error message, and one with
    // status 1 and an error message.
    #[test]
    fn round_trip_tx_result_response_body() {
        let now = Utc.timestamp_opt(1_700_000_000, 0).unwrap();
        for (status, err) in [(0u8, None), (1u8, Some("boom".to_string()))] {
            let response = TxResultResponse {
                status,
                tx_id: 7,
                system_time: now,
                error_message: err,
            };
            let buf = encode_tx_result_response(&response).unwrap();
            assert_eq!(decode_tx_result_response(&buf).unwrap(), response);
        }
    }

    // Error bodies round-trip for both severities, with and without detail.
    #[test]
    fn round_trip_error_body() {
        let response = ErrorResponseBody {
            severity: b'E',
            code: 2001,
            message: "parse error".into(),
            detail: Some("near token X".into()),
            hint: None,
        };
        let buf = encode_error_body(&response).unwrap();
        assert_eq!(decode_error_body(&buf).unwrap(), response);

        let response = ErrorResponseBody {
            severity: b'F',
            code: 1000,
            message: "fatal".into(),
            detail: None,
            hint: None,
        };
        let buf = encode_error_body(&response).unwrap();
        assert_eq!(decode_error_body(&buf).unwrap(), response);
    }

    // Pins the exact bytes of an encoded keyword: the payload is
    // "namespace/name" with no leading ':'.
    #[test]
    fn keyword_wire_format_strips_leading_colon() {
        let ns = kw!(:person/name);
        let mut buf = Vec::new();
        write_data_type(&mut buf, &DataType::Keyword(ns)).unwrap();
        // 0xc7 is the MessagePack "ext 8" marker.
        assert_eq!(buf[0], 0xc7);
        // Payload length: "person/name" is 11 bytes.
        assert_eq!(buf[1], 11);
        // Extension type byte.
        assert_eq!(buf[2] as i8, EXT_KEYWORD);
        assert_eq!(&buf[3..], b"person/name");
    }

    // Serializes an arbitrary generic MessagePack value; used to craft
    // malformed or unusual bodies for the rejection tests below.
    fn pack(v: Value) -> Vec<u8> {
        let mut buf = Vec::new();
        rmpv::encode::write_value(&mut buf, &v).unwrap();
        buf
    }

    // A bare integer is not a valid request body.
    #[test]
    fn decode_open_db_rejects_non_map_body() {
        assert!(decode_open_db_request(&pack(Value::Integer(1.into()))).is_err());
    }

    // "tx_id" is a string here; "system_time" is a 4-byte timestamp payload,
    // so presumably only the tx_id type is at fault.
    #[test]
    fn decode_tx_key_rejects_wrong_type() {
        let body = pack(Value::Map(vec![
            (Value::String("tx_id".into()), Value::String("nope".into())),
            (
                Value::String("system_time".into()),
                Value::Ext(EXT_TIMESTAMP, vec![0, 0, 0, 0]),
            ),
        ]));
        assert!(decode_tx_key(&body).is_err());
    }

    // Only "tx_id" is present.
    #[test]
    fn decode_tx_key_rejects_missing_system_time() {
        let body = pack(Value::Map(vec![(
            Value::String("tx_id".into()),
            Value::Integer(1.into()),
        )]));
        assert!(decode_tx_key(&body).is_err());
    }

    // 256 is one past u8::MAX, so it cannot be a valid status.
    #[test]
    fn decode_tx_result_rejects_status_overflow() {
        let body = pack(Value::Map(vec![
            (Value::String("status".into()), Value::Integer(256.into())),
            (Value::String("tx_id".into()), Value::Integer(1.into())),
            (
                Value::String("system_time".into()),
                Value::Ext(EXT_TIMESTAMP, vec![0, 0, 0, 0]),
            ),
            (Value::String("error_message".into()), Value::Nil),
        ]));
        assert!(decode_tx_result_response(&body).is_err());
    }

    // "Q" is neither "E" nor "F".
    #[test]
    fn decode_error_rejects_invalid_severity() {
        let body = pack(Value::Map(vec![
            (Value::String("severity".into()), Value::String("Q".into())),
            (Value::String("code".into()), Value::Integer(1.into())),
            (Value::String("message".into()), Value::String("x".into())),
            (Value::String("detail".into()), Value::Nil),
            (Value::String("hint".into()), Value::Nil),
        ]));
        assert!(decode_error_body(&body).is_err());
    }

    // The encoder must refuse a severity byte other than b'E' / b'F'.
    #[test]
    fn encode_error_rejects_invalid_severity() {
        let r = encode_error_body(&ErrorResponseBody {
            severity: b'X',
            code: 1,
            message: "x".into(),
            detail: None,
            hint: None,
        });
        assert!(r.is_err(), "expected bail on invalid severity");
    }

    // Extension type 99 is not one of the EXT_* constants declared in this
    // file.
    #[test]
    fn decode_unknown_msgpack_ext_fails() {
        let body = pack(Value::Ext(99, vec![0; 4]));
        assert!(read_data_type(&body).is_err());
    }

    #[test]
    fn decode_data_type_rejects_nil() {
        let body = pack(Value::Nil);
        assert!(read_data_type(&body).is_err());
    }

    // Despite the name this also covers UUIDs: an 8-byte bigint payload and a
    // 4-byte UUID payload must both be rejected.
    #[test]
    fn decode_bigint_rejects_wrong_payload_length() {
        let body = pack(Value::Ext(EXT_BIGINT, vec![0; 8]));
        assert!(read_data_type(&body).is_err());
        let body = pack(Value::Ext(EXT_UUID, vec![0; 4]));
        assert!(read_data_type(&body).is_err());
    }

    // An empty payload and a payload with an empty namespace ("/name") are
    // both invalid keywords.
    #[test]
    fn decode_keyword_rejects_empty_or_partial() {
        let body = pack(Value::Ext(EXT_KEYWORD, b"".to_vec()));
        assert!(read_data_type(&body).is_err());
        let body = pack(Value::Ext(EXT_KEYWORD, b"/name".to_vec()));
        assert!(read_data_type(&body).is_err());
    }

    // The same {"kind": "xyzzy"} map must be rejected by all three
    // tagged-union decoders.
    #[test]
    fn decode_tagged_union_rejects_unknown_kind() {
        let body = pack(Value::Map(vec![(
            Value::String("kind".into()),
            Value::String("xyzzy".into()),
        )]));
        assert!(entity_ref_from_value(rmpv::decode::read_value(&mut &body[..]).unwrap()).is_err());
        assert!(tx_op_from_value(rmpv::decode::read_value(&mut &body[..]).unwrap()).is_err());
        assert!(query_arg_from_value(rmpv::decode::read_value(&mut &body[..]).unwrap()).is_err());
    }

    // Two valid bodies back to back: the bytes after the first must cause a
    // decode error.
    #[test]
    fn decode_body_rejects_trailing_bytes() {
        let one = encode_open_db_request(&OpenDbRequest {
            tx_id: None,
            system_time: None,
        })
        .unwrap();
        let mut two = one.clone();
        two.extend_from_slice(&one);
        assert!(decode_open_db_request(&two).is_err());
    }

    // The payload key "id" comes before the "kind" discriminator here.
    #[test]
    fn decode_tagged_union_accepts_any_key_order() {
        let body = pack(Value::Map(vec![
            (Value::String("id".into()), Value::Integer(42.into())),
            (Value::String("kind".into()), Value::String("id".into())),
        ]));
        let er = entity_ref_from_value(rmpv::decode::read_value(&mut &body[..]).unwrap()).unwrap();
        assert_eq!(er, EntityRef::Id(42));
    }
}