cozo_ce/data/
value.rs

1/*
2 * Copyright 2022, The Cozo Project Authors.
3 *
4 * This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
5 * If a copy of the MPL was not distributed with this file,
6 * You can obtain one at https://mozilla.org/MPL/2.0/.
7 */
8
9use base64::engine::general_purpose::STANDARD;
10use base64::Engine;
11use ndarray::Array1;
12use std::cmp::{Ordering, Reverse};
13use std::collections::BTreeSet;
14use std::fmt::{Debug, Display, Formatter};
15use std::hash::{Hash, Hasher};
16use std::ops::Deref;
17
18use crate::data::json::JsonValue;
19use crate::data::relation::VecElementType;
20use ordered_float::OrderedFloat;
21use regex::Regex;
22use serde::de::{SeqAccess, Visitor};
23use serde::ser::SerializeTuple;
24use serde::{Deserialize, Deserializer, Serialize, Serializer};
25use sha2::digest::FixedOutput;
26use sha2::{Digest, Sha256};
27use smartstring::{LazyCompact, SmartString};
28use uuid::Uuid;
29
30/// UUID value in the database
31#[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)]
32pub struct UuidWrapper(pub Uuid);
33
34impl PartialOrd<Self> for UuidWrapper {
35    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
36        Some(self.cmp(other))
37    }
38}
39
40impl Ord for UuidWrapper {
41    fn cmp(&self, other: &Self) -> Ordering {
42        let (s_l, s_m, s_h, s_rest) = self.0.as_fields();
43        let (o_l, o_m, o_h, o_rest) = other.0.as_fields();
44        s_h.cmp(&o_h)
45            .then_with(|| s_m.cmp(&o_m))
46            .then_with(|| s_l.cmp(&o_l))
47            .then_with(|| s_rest.cmp(o_rest))
48    }
49}
50
51/// A Regex in the database. Used internally in functions.
52#[derive(Clone)]
53pub struct RegexWrapper(pub Regex);
54
55impl Hash for RegexWrapper {
56    fn hash<H: Hasher>(&self, state: &mut H) {
57        self.0.as_str().hash(state)
58    }
59}
60
61impl Serialize for RegexWrapper {
62    fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
63    where
64        S: serde::Serializer,
65    {
66        panic!("serializing regex");
67    }
68}
69
70impl<'de> Deserialize<'de> for RegexWrapper {
71    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
72    where
73        D: Deserializer<'de>,
74    {
75        panic!("deserializing regex");
76    }
77}
78
79impl PartialEq for RegexWrapper {
80    fn eq(&self, other: &Self) -> bool {
81        self.0.as_str() == other.0.as_str()
82    }
83}
84
85impl Eq for RegexWrapper {}
86
87impl Ord for RegexWrapper {
88    fn cmp(&self, other: &Self) -> Ordering {
89        self.0.as_str().cmp(other.0.as_str())
90    }
91}
92
93impl PartialOrd for RegexWrapper {
94    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
95        self.0.as_str().partial_cmp(other.0.as_str())
96    }
97}
98
99/// Timestamp part of validity
100#[derive(
101    Copy,
102    Clone,
103    Eq,
104    PartialEq,
105    Ord,
106    PartialOrd,
107    serde_derive::Deserialize,
108    serde_derive::Serialize,
109    Hash,
110    Debug,
111)]
112pub struct ValidityTs(pub Reverse<i64>);
113
114/// Validity for time travel
115#[derive(
116    Copy,
117    Clone,
118    Eq,
119    PartialEq,
120    Ord,
121    PartialOrd,
122    serde_derive::Deserialize,
123    serde_derive::Serialize,
124    Hash,
125)]
126pub struct Validity {
127    /// Timestamp, sorted descendingly
128    pub timestamp: ValidityTs,
129    /// Whether this validity is an assertion, sorted descendingly
130    pub is_assert: Reverse<bool>,
131}
132
133impl From<(i64, bool)> for Validity {
134    fn from(value: (i64, bool)) -> Self {
135        Self {
136            timestamp: ValidityTs(Reverse(value.0)),
137            is_assert: Reverse(value.1),
138        }
139    }
140}
141
142/// A Value in the database
143#[derive(
144    Clone, PartialEq, Eq, PartialOrd, Ord, serde_derive::Deserialize, serde_derive::Serialize, Hash,
145)]
146pub enum DataValue {
147    /// null
148    Null,
149    /// boolean
150    Bool(bool),
151    /// number, may be int or float
152    Num(Num),
153    /// string
154    Str(SmartString<LazyCompact>),
155    /// bytes
156    #[serde(with = "serde_bytes")]
157    Bytes(Vec<u8>),
158    /// UUID
159    Uuid(UuidWrapper),
160    /// Regex, used internally only
161    Regex(RegexWrapper),
162    /// list
163    List(Vec<DataValue>),
164    /// set, used internally only
165    Set(BTreeSet<DataValue>),
166    /// Array, mainly for proximity search
167    Vec(Vector),
168    /// Json
169    Json(JsonData),
170    /// validity,
171    Validity(Validity),
172    /// bottom type, used internally only
173    Bot,
174}
175
176/// Wrapper for JsonValue
177#[derive(Clone, PartialEq, Eq, serde_derive::Deserialize, serde_derive::Serialize)]
178pub struct JsonData(pub JsonValue);
179
180impl PartialOrd<Self> for JsonData {
181    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
182        Some(self.cmp(other))
183    }
184}
185
186impl Ord for JsonData {
187    fn cmp(&self, other: &Self) -> Ordering {
188        self.0.to_string().cmp(&other.0.to_string())
189    }
190}
191
192impl Deref for JsonData {
193    type Target = JsonValue;
194
195    fn deref(&self) -> &Self::Target {
196        &self.0
197    }
198}
199
200impl Hash for JsonData {
201    fn hash<H: Hasher>(&self, state: &mut H) {
202        self.0.to_string().hash(state)
203    }
204}
205
206/// Vector of floating numbers
207#[derive(Debug, Clone)]
208pub enum Vector {
209    /// 32-bit float array
210    F32(Array1<f32>),
211    /// 64-bit float array
212    F64(Array1<f64>),
213}
214
215struct VecBytes<'a>(&'a [u8]);
216
217impl serde::Serialize for VecBytes<'_> {
218    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
219    where
220        S: Serializer,
221    {
222        serializer.serialize_bytes(self.0)
223    }
224}
225
226impl serde::Serialize for Vector {
227    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
228    where
229        S: Serializer,
230    {
231        let mut state = serializer.serialize_tuple(2)?;
232        match self {
233            Vector::F32(a) => {
234                state.serialize_element(&0u8)?;
235                let arr = a.as_slice().unwrap();
236                let len = std::mem::size_of_val(arr);
237                let ptr = arr.as_ptr() as *const u8;
238                let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
239                state.serialize_element(&VecBytes(bytes))?;
240            }
241            Vector::F64(a) => {
242                state.serialize_element(&1u8)?;
243                let arr = a.as_slice().unwrap();
244                let len = std::mem::size_of_val(arr);
245                let ptr = arr.as_ptr() as *const u8;
246                let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
247                state.serialize_element(&VecBytes(bytes))?;
248            }
249        }
250        state.end()
251    }
252}
253
254impl<'de> serde::Deserialize<'de> for Vector {
255    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
256    where
257        D: Deserializer<'de>,
258    {
259        deserializer.deserialize_tuple(2, VectorVisitor)
260    }
261}
262
263struct VectorVisitor;
264
265impl<'de> Visitor<'de> for VectorVisitor {
266    type Value = Vector;
267
268    fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
269        formatter.write_str("vector representation")
270    }
271    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
272    where
273        A: SeqAccess<'de>,
274    {
275        let tag: u8 = seq
276            .next_element()?
277            .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
278        let bytes: &[u8] = seq
279            .next_element()?
280            .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
281        match tag {
282            0u8 => {
283                let len = bytes.len() / std::mem::size_of::<f32>();
284                let mut v = vec![];
285                v.reserve_exact(len);
286                let ptr = v.as_mut_ptr() as *mut u8;
287                unsafe {
288                    std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len());
289                    v.set_len(len);
290                }
291                Ok(Vector::F32(Array1::from(v)))
292            }
293            1u8 => {
294                let len = bytes.len() / std::mem::size_of::<f64>();
295                let mut v = vec![];
296                v.reserve_exact(len);
297                let ptr = v.as_mut_ptr() as *mut u8;
298                unsafe {
299                    std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len());
300                    v.set_len(len);
301                }
302                Ok(Vector::F64(Array1::from(v)))
303            }
304            _ => Err(serde::de::Error::invalid_value(
305                serde::de::Unexpected::Unsigned(tag as u64),
306                &self,
307            )),
308        }
309    }
310}
311
312impl Vector {
313    /// Get the length of the vector
314    pub fn len(&self) -> usize {
315        match self {
316            Vector::F32(v) => v.len(),
317            Vector::F64(v) => v.len(),
318        }
319    }
320    /// Check if the vector is empty
321    pub fn is_empty(&self) -> bool {
322        match self {
323            Vector::F32(v) => v.is_empty(),
324            Vector::F64(v) => v.is_empty(),
325        }
326    }
327    pub(crate) fn el_type(&self) -> VecElementType {
328        match self {
329            Vector::F32(_) => VecElementType::F32,
330            Vector::F64(_) => VecElementType::F64,
331        }
332    }
333    pub(crate) fn get_hash(&self) -> impl AsRef<[u8]> {
334        let mut hasher = Sha256::new();
335        match self {
336            Vector::F32(v) => {
337                for e in v.iter() {
338                    hasher.update(e.to_le_bytes());
339                }
340            }
341            Vector::F64(v) => {
342                for e in v.iter() {
343                    hasher.update(e.to_le_bytes());
344                }
345            }
346        }
347        hasher.finalize_fixed()
348    }
349}
350
351impl PartialEq<Self> for Vector {
352    fn eq(&self, other: &Self) -> bool {
353        match (self, other) {
354            (Vector::F32(l), Vector::F32(r)) => {
355                if l.len() != r.len() {
356                    return false;
357                }
358                for (le, re) in l.iter().zip(r) {
359                    if !OrderedFloat(*le).eq(&OrderedFloat(*re)) {
360                        return false;
361                    }
362                }
363                true
364            }
365            (Vector::F64(l), Vector::F64(r)) => {
366                if l.len() != r.len() {
367                    return false;
368                }
369                for (le, re) in l.iter().zip(r) {
370                    if !OrderedFloat(*le).eq(&OrderedFloat(*re)) {
371                        return false;
372                    }
373                }
374                true
375            }
376            _ => false,
377        }
378    }
379}
380
381impl Eq for Vector {}
382
383impl PartialOrd for Vector {
384    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
385        Some(self.cmp(other))
386    }
387}
388
389impl Ord for Vector {
390    fn cmp(&self, other: &Self) -> Ordering {
391        match (self, other) {
392            (Vector::F32(l), Vector::F32(r)) => {
393                match l.len().cmp(&r.len()) {
394                    Ordering::Equal => (),
395                    o => return o,
396                }
397                for (le, re) in l.iter().zip(r) {
398                    match OrderedFloat(*le).cmp(&OrderedFloat(*re)) {
399                        Ordering::Equal => continue,
400                        o => return o,
401                    }
402                }
403                Ordering::Equal
404            }
405            (Vector::F32(_), Vector::F64(_)) => Ordering::Less,
406            (Vector::F64(l), Vector::F64(r)) => {
407                match l.len().cmp(&r.len()) {
408                    Ordering::Equal => (),
409                    o => return o,
410                }
411                for (le, re) in l.iter().zip(r) {
412                    match OrderedFloat(*le).cmp(&OrderedFloat(*re)) {
413                        Ordering::Equal => continue,
414                        o => return o,
415                    }
416                }
417                Ordering::Equal
418            }
419            (Vector::F64(_), Vector::F32(_)) => Ordering::Greater,
420        }
421    }
422}
423
424impl Hash for Vector {
425    fn hash<H: Hasher>(&self, state: &mut H) {
426        match self {
427            Vector::F32(a) => {
428                for el in a {
429                    OrderedFloat(*el).hash(state)
430                }
431            }
432            Vector::F64(a) => {
433                for el in a {
434                    OrderedFloat(*el).hash(state)
435                }
436            }
437        }
438    }
439}
440
441impl From<bool> for DataValue {
442    fn from(value: bool) -> Self {
443        DataValue::Bool(value)
444    }
445}
446
447impl From<i64> for DataValue {
448    fn from(v: i64) -> Self {
449        DataValue::Num(Num::Int(v))
450    }
451}
452
453impl From<i32> for DataValue {
454    fn from(v: i32) -> Self {
455        DataValue::Num(Num::Int(v as i64))
456    }
457}
458
459impl From<f64> for DataValue {
460    fn from(v: f64) -> Self {
461        DataValue::Num(Num::Float(v))
462    }
463}
464
465impl From<&str> for DataValue {
466    fn from(v: &str) -> Self {
467        DataValue::Str(SmartString::from(v))
468    }
469}
470
471impl From<String> for DataValue {
472    fn from(v: String) -> Self {
473        DataValue::Str(SmartString::from(v))
474    }
475}
476
477impl From<Vec<u8>> for DataValue {
478    fn from(v: Vec<u8>) -> Self {
479        DataValue::Bytes(v)
480    }
481}
482
483impl<T: Into<DataValue>> From<Vec<T>> for DataValue {
484    fn from(v: Vec<T>) -> Self
485    where
486        T: Into<DataValue>,
487    {
488        DataValue::List(v.into_iter().map(Into::into).collect())
489    }
490}
491
492/// Representing a number
493#[derive(Copy, Clone, serde_derive::Deserialize, serde_derive::Serialize)]
494pub enum Num {
495    /// intger number
496    Int(i64),
497    /// float number
498    Float(f64),
499}
500
501impl Hash for Num {
502    fn hash<H: Hasher>(&self, state: &mut H) {
503        match self {
504            Num::Int(i) => i.hash(state),
505            Num::Float(f) => OrderedFloat(*f).hash(state),
506        }
507    }
508}
509
510impl Num {
511    pub(crate) fn get_int(&self) -> Option<i64> {
512        match self {
513            Num::Int(i) => Some(*i),
514            Num::Float(f) => {
515                if f.round() == *f {
516                    Some(*f as i64)
517                } else {
518                    None
519                }
520            }
521        }
522    }
523    pub(crate) fn get_float(&self) -> f64 {
524        match self {
525            Num::Int(i) => *i as f64,
526            Num::Float(f) => *f,
527        }
528    }
529}
530
531impl PartialEq for Num {
532    fn eq(&self, other: &Self) -> bool {
533        self.cmp(other) == Ordering::Equal
534    }
535}
536
537impl Eq for Num {}
538
539impl Display for Num {
540    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
541        match self {
542            Num::Int(i) => write!(f, "{i}"),
543            Num::Float(n) => {
544                if n.is_nan() {
545                    write!(f, r#"to_float("NAN")"#)
546                } else if n.is_infinite() {
547                    if n.is_sign_negative() {
548                        write!(f, r#"to_float("NEG_INF")"#)
549                    } else {
550                        write!(f, r#"to_float("INF")"#)
551                    }
552                } else {
553                    write!(f, "{n}")
554                }
555            }
556        }
557    }
558}
559
560impl Debug for Num {
561    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
562        match self {
563            Num::Int(i) => write!(f, "{i}"),
564            Num::Float(n) => write!(f, "{n}"),
565        }
566    }
567}
568
569impl PartialOrd for Num {
570    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
571        Some(self.cmp(other))
572    }
573}
574
575impl Ord for Num {
576    fn cmp(&self, other: &Self) -> Ordering {
577        match (self, other) {
578            (Num::Int(i), Num::Float(r)) => {
579                let l = *i as f64;
580                match l.total_cmp(r) {
581                    Ordering::Less => Ordering::Less,
582                    Ordering::Equal => Ordering::Less,
583                    Ordering::Greater => Ordering::Greater,
584                }
585            }
586            (Num::Float(l), Num::Int(i)) => {
587                let r = *i as f64;
588                match l.total_cmp(&r) {
589                    Ordering::Less => Ordering::Less,
590                    Ordering::Equal => Ordering::Greater,
591                    Ordering::Greater => Ordering::Greater,
592                }
593            }
594            (Num::Int(l), Num::Int(r)) => l.cmp(r),
595            (Num::Float(l), Num::Float(r)) => l.total_cmp(r),
596        }
597    }
598}
599
600impl Debug for DataValue {
601    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
602        write!(f, "{self}")
603    }
604}
605
606impl Display for DataValue {
607    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
608        match self {
609            DataValue::Null => f.write_str("null"),
610            DataValue::Bool(b) => write!(f, "{b}"),
611            DataValue::Num(n) => write!(f, "{n}"),
612            DataValue::Str(s) => write!(f, "{s:?}"),
613            DataValue::Bytes(b) => {
614                let bs = STANDARD.encode(b);
615                write!(f, "decode_base64({bs:?})")
616            }
617            DataValue::Uuid(u) => {
618                let us = u.0.to_string();
619                write!(f, "to_uuid({us:?})")
620            }
621            DataValue::Regex(rx) => {
622                write!(f, "regex({:?})", rx.0.as_str())
623            }
624            DataValue::List(ls) => f.debug_list().entries(ls).finish(),
625            DataValue::Set(s) => f.debug_list().entries(s).finish(),
626            DataValue::Bot => write!(f, "null"),
627            DataValue::Validity(v) => f
628                .debug_struct("Validity")
629                .field("timestamp", &v.timestamp.0)
630                .field("retracted", &v.is_assert)
631                .finish(),
632            DataValue::Vec(a) => match a {
633                Vector::F32(a) => {
634                    write!(f, "vec({:?})", a.to_vec())
635                }
636                Vector::F64(a) => {
637                    write!(f, "vec({:?}, \"F64\")", a.to_vec())
638                }
639            },
640            DataValue::Json(j) => {
641                if j.is_object() {
642                    write!(f, "{}", j.0)
643                } else {
644                    write!(f, "json({})", j.0)
645                }
646            }
647        }
648    }
649}
650
651impl DataValue {
652    /// Returns a slice of bytes if this one is a Bytes
653    pub fn get_bytes(&self) -> Option<&[u8]> {
654        match self {
655            DataValue::Bytes(b) => Some(b),
656            _ => None,
657        }
658    }
659    /// Returns a slice of DataValues if this one is a List
660    pub fn get_slice(&self) -> Option<&[DataValue]> {
661        match self {
662            DataValue::List(l) => Some(l),
663            _ => None,
664        }
665    }
666    /// Returns the raw str if this one is a Str
667    pub fn get_str(&self) -> Option<&str> {
668        match self {
669            DataValue::Str(s) => Some(s),
670            _ => None,
671        }
672    }
673    /// Returns int if this one is an int
674    pub fn get_int(&self) -> Option<i64> {
675        match self {
676            DataValue::Num(n) => n.get_int(),
677            _ => None,
678        }
679    }
680    pub(crate) fn get_non_neg_int(&self) -> Option<u64> {
681        match self {
682            DataValue::Num(n) => n
683                .get_int()
684                .and_then(|i| if i < 0 { None } else { Some(i as u64) }),
685            _ => None,
686        }
687    }
688    /// Returns float if this one is.
689    pub fn get_float(&self) -> Option<f64> {
690        match self {
691            DataValue::Num(n) => Some(n.get_float()),
692            _ => None,
693        }
694    }
695    /// Returns bool if this one is.
696    pub fn get_bool(&self) -> Option<bool> {
697        match self {
698            DataValue::Bool(b) => Some(*b),
699            _ => None,
700        }
701    }
702    pub(crate) fn uuid(uuid: Uuid) -> Self {
703        Self::Uuid(UuidWrapper(uuid))
704    }
705    pub(crate) fn get_uuid(&self) -> Option<Uuid> {
706        match self {
707            DataValue::Uuid(UuidWrapper(uuid)) => Some(*uuid),
708            DataValue::Str(s) => uuid::Uuid::try_parse(s).ok(),
709            _ => None,
710        }
711    }
712}
713
/// The largest Unicode scalar value, U+10FFFF.
pub(crate) const LARGEST_UTF_CHAR: char = char::MAX;