Skip to main content

modelvault_core/record/
scalar.rs

1//! Typed scalar values for record payloads (v1).
2
3use crate::error::{DbError, FormatError};
4use crate::file_format::check_field_bytes_len;
5use crate::schema::Type;
6
7#[derive(Debug, Clone, PartialEq)]
8pub enum ScalarValue {
9    Bool(bool),
10    Int64(i64),
11    Uint64(u64),
12    Float64(f64),
13    String(String),
14    Bytes(Vec<u8>),
15    Uuid([u8; 16]),
16    /// Unix microseconds (same convention as elsewhere in ModelVault).
17    Timestamp(i64),
18}
19
20impl ScalarValue {
21    /// Canonical bytes for indexing (last insert wins per key).
22    pub fn canonical_key_bytes(&self) -> Vec<u8> {
23        match self {
24            ScalarValue::Bool(b) => vec![0, if *b { 1 } else { 0 }],
25            ScalarValue::Int64(v) => v.to_le_bytes().to_vec(),
26            ScalarValue::Uint64(v) => v.to_le_bytes().to_vec(),
27            ScalarValue::Float64(v) => {
28                let n = if *v == 0.0 { 0.0f64 } else { *v };
29                n.to_le_bytes().to_vec()
30            }
31            ScalarValue::String(s) => s.as_bytes().to_vec(),
32            ScalarValue::Bytes(b) => b.clone(),
33            ScalarValue::Uuid(u) => u.to_vec(),
34            ScalarValue::Timestamp(t) => t.to_le_bytes().to_vec(),
35        }
36    }
37
38    pub fn ty_matches(&self, ty: &Type) -> bool {
39        matches!(
40            (self, ty),
41            (ScalarValue::Bool(_), Type::Bool)
42                | (ScalarValue::Int64(_), Type::Int64)
43                | (ScalarValue::Uint64(_), Type::Uint64)
44                | (ScalarValue::Float64(_), Type::Float64)
45                | (ScalarValue::String(_), Type::String)
46                | (ScalarValue::Bytes(_), Type::Bytes)
47                | (ScalarValue::Uuid(_), Type::Uuid)
48                | (ScalarValue::Timestamp(_), Type::Timestamp)
49        )
50    }
51}
52
53/// Encode a scalar with a leading type tag (must match `ty`).
54pub fn encode_tagged_scalar(out: &mut Vec<u8>, v: &ScalarValue, ty: &Type) -> Result<(), DbError> {
55    match (v, ty) {
56        (ScalarValue::Bool(b), Type::Bool) => {
57            out.push(0);
58            out.push(if *b { 1 } else { 0 });
59        }
60        (ScalarValue::Int64(n), Type::Int64) => {
61            out.push(1);
62            out.extend_from_slice(&n.to_le_bytes());
63        }
64        (ScalarValue::Uint64(n), Type::Uint64) => {
65            out.push(2);
66            out.extend_from_slice(&n.to_le_bytes());
67        }
68        (ScalarValue::Float64(n), Type::Float64) => {
69            out.push(3);
70            out.extend_from_slice(&n.to_le_bytes());
71        }
72        (ScalarValue::String(s), Type::String) => {
73            out.push(4);
74            let b = s.as_bytes();
75            check_field_bytes_len(b.len())?;
76            out.extend_from_slice(&(b.len() as u32).to_le_bytes());
77            out.extend_from_slice(b);
78        }
79        (ScalarValue::Bytes(b), Type::Bytes) => {
80            out.push(5);
81            check_field_bytes_len(b.len())?;
82            out.extend_from_slice(&(b.len() as u32).to_le_bytes());
83            out.extend_from_slice(b);
84        }
85        (ScalarValue::Uuid(u), Type::Uuid) => {
86            out.push(6);
87            out.extend_from_slice(u);
88        }
89        (ScalarValue::Timestamp(t), Type::Timestamp) => {
90            out.push(7);
91            out.extend_from_slice(&t.to_le_bytes());
92        }
93        _ => return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch)),
94    }
95    Ok(())
96}
97
98pub struct Cursor<'a> {
99    pub bytes: &'a [u8],
100    pub pos: usize,
101}
102
103impl<'a> Cursor<'a> {
104    pub fn new(bytes: &'a [u8]) -> Self {
105        Self { bytes, pos: 0 }
106    }
107
108    pub fn remaining(&self) -> usize {
109        self.bytes.len().saturating_sub(self.pos)
110    }
111
112    pub fn take_u16(&mut self) -> Result<u16, DbError> {
113        if self.remaining() < 2 {
114            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
115        }
116        let v = u16::from_le_bytes([self.bytes[self.pos], self.bytes[self.pos + 1]]);
117        self.pos += 2;
118        Ok(v)
119    }
120
121    pub fn take_u8(&mut self) -> Result<u8, DbError> {
122        if self.pos >= self.bytes.len() {
123            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
124        }
125        let b = self.bytes[self.pos];
126        self.pos += 1;
127        Ok(b)
128    }
129
130    pub fn take_u32(&mut self) -> Result<u32, DbError> {
131        if self.remaining() < 4 {
132            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
133        }
134        let v = u32::from_le_bytes([
135            self.bytes[self.pos],
136            self.bytes[self.pos + 1],
137            self.bytes[self.pos + 2],
138            self.bytes[self.pos + 3],
139        ]);
140        self.pos += 4;
141        Ok(v)
142    }
143
144    pub fn take_i64(&mut self) -> Result<i64, DbError> {
145        if self.remaining() < 8 {
146            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
147        }
148        let v = i64::from_le_bytes([
149            self.bytes[self.pos],
150            self.bytes[self.pos + 1],
151            self.bytes[self.pos + 2],
152            self.bytes[self.pos + 3],
153            self.bytes[self.pos + 4],
154            self.bytes[self.pos + 5],
155            self.bytes[self.pos + 6],
156            self.bytes[self.pos + 7],
157        ]);
158        self.pos += 8;
159        Ok(v)
160    }
161
162    pub fn take_u64(&mut self) -> Result<u64, DbError> {
163        if self.remaining() < 8 {
164            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
165        }
166        let v = u64::from_le_bytes([
167            self.bytes[self.pos],
168            self.bytes[self.pos + 1],
169            self.bytes[self.pos + 2],
170            self.bytes[self.pos + 3],
171            self.bytes[self.pos + 4],
172            self.bytes[self.pos + 5],
173            self.bytes[self.pos + 6],
174            self.bytes[self.pos + 7],
175        ]);
176        self.pos += 8;
177        Ok(v)
178    }
179
180    pub fn take_f64(&mut self) -> Result<f64, DbError> {
181        Ok(f64::from_bits(self.take_u64()?))
182    }
183
184    pub fn take_bytes(&mut self, n: usize) -> Result<Vec<u8>, DbError> {
185        check_field_bytes_len(n)?;
186        if self.remaining() < n {
187            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
188        }
189        let slice = &self.bytes[self.pos..self.pos + n];
190        self.pos += n;
191        Ok(slice.to_vec())
192    }
193}
194
195pub fn decode_tagged_scalar(cur: &mut Cursor<'_>, ty: &Type) -> Result<ScalarValue, DbError> {
196    let tag = cur.take_u8()?;
197    Ok(match ty {
198        Type::Bool => {
199            if tag != 0 {
200                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
201            }
202            let b = cur.take_u8()?;
203            ScalarValue::Bool(b != 0)
204        }
205        Type::Int64 => {
206            if tag != 1 {
207                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
208            }
209            ScalarValue::Int64(cur.take_i64()?)
210        }
211        Type::Uint64 => {
212            if tag != 2 {
213                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
214            }
215            ScalarValue::Uint64(cur.take_u64()?)
216        }
217        Type::Float64 => {
218            if tag != 3 {
219                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
220            }
221            ScalarValue::Float64(cur.take_f64()?)
222        }
223        Type::String => {
224            if tag != 4 {
225                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
226            }
227            let n = cur.take_u32()? as usize;
228            let b = cur.take_bytes(n)?;
229            ScalarValue::String(
230                String::from_utf8(b)
231                    .map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))?,
232            )
233        }
234        Type::Bytes => {
235            if tag != 5 {
236                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
237            }
238            let n = cur.take_u32()? as usize;
239            ScalarValue::Bytes(cur.take_bytes(n)?)
240        }
241        Type::Uuid => {
242            if tag != 6 {
243                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
244            }
245            let b = cur.take_bytes(16)?;
246            let mut a = [0u8; 16];
247            a.copy_from_slice(&b);
248            ScalarValue::Uuid(a)
249        }
250        Type::Timestamp => {
251            if tag != 7 {
252                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
253            }
254            ScalarValue::Timestamp(cur.take_i64()?)
255        }
256        _ => return Err(DbError::Format(FormatError::RecordPayloadUnsupportedType)),
257    })
258}
259
260pub fn decode_tagged_string(cur: &mut Cursor<'_>) -> Result<String, DbError> {
261    let tag = cur.take_u8()?;
262    if tag != 4 {
263        return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
264    }
265    let n = cur.take_u32()? as usize;
266    let b = cur.take_bytes(n)?;
267    String::from_utf8(b).map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))
268}
269
// Unit tests are textually included from `tests/unit/src_record_scalar_tests.rs`, resolved
// relative to the crate manifest directory, and are only compiled for test builds.
#[cfg(test)]
mod tests {
    include!(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/unit/src_record_scalar_tests.rs"));
}