Skip to main content

modelvault_core/record/
scalar.rs

//! Typed scalar values for record payloads (v1): the in-memory [`ScalarValue`] type and
//! its tagged binary encoding/decoding.

3use crate::error::{DbError, FormatError};
4use crate::schema::Type;
5
6#[derive(Debug, Clone, PartialEq)]
7pub enum ScalarValue {
8    Bool(bool),
9    Int64(i64),
10    Uint64(u64),
11    Float64(f64),
12    String(String),
13    Bytes(Vec<u8>),
14    Uuid([u8; 16]),
15    /// Unix microseconds (same convention as elsewhere in ModelVault).
16    Timestamp(i64),
17}
18
19impl ScalarValue {
20    /// Canonical bytes for indexing (last insert wins per key).
21    pub fn canonical_key_bytes(&self) -> Vec<u8> {
22        match self {
23            ScalarValue::Bool(b) => vec![0, if *b { 1 } else { 0 }],
24            ScalarValue::Int64(v) => v.to_le_bytes().to_vec(),
25            ScalarValue::Uint64(v) => v.to_le_bytes().to_vec(),
26            ScalarValue::Float64(v) => v.to_le_bytes().to_vec(),
27            ScalarValue::String(s) => s.as_bytes().to_vec(),
28            ScalarValue::Bytes(b) => b.clone(),
29            ScalarValue::Uuid(u) => u.to_vec(),
30            ScalarValue::Timestamp(t) => t.to_le_bytes().to_vec(),
31        }
32    }
33
34    pub fn ty_matches(&self, ty: &Type) -> bool {
35        matches!(
36            (self, ty),
37            (ScalarValue::Bool(_), Type::Bool)
38                | (ScalarValue::Int64(_), Type::Int64)
39                | (ScalarValue::Uint64(_), Type::Uint64)
40                | (ScalarValue::Float64(_), Type::Float64)
41                | (ScalarValue::String(_), Type::String)
42                | (ScalarValue::Bytes(_), Type::Bytes)
43                | (ScalarValue::Uuid(_), Type::Uuid)
44                | (ScalarValue::Timestamp(_), Type::Timestamp)
45        )
46    }
47}
48
49/// Encode a scalar with a leading type tag (must match `ty`).
50pub fn encode_tagged_scalar(out: &mut Vec<u8>, v: &ScalarValue, ty: &Type) -> Result<(), DbError> {
51    match (v, ty) {
52        (ScalarValue::Bool(b), Type::Bool) => {
53            out.push(0);
54            out.push(if *b { 1 } else { 0 });
55        }
56        (ScalarValue::Int64(n), Type::Int64) => {
57            out.push(1);
58            out.extend_from_slice(&n.to_le_bytes());
59        }
60        (ScalarValue::Uint64(n), Type::Uint64) => {
61            out.push(2);
62            out.extend_from_slice(&n.to_le_bytes());
63        }
64        (ScalarValue::Float64(n), Type::Float64) => {
65            out.push(3);
66            out.extend_from_slice(&n.to_le_bytes());
67        }
68        (ScalarValue::String(s), Type::String) => {
69            out.push(4);
70            let b = s.as_bytes();
71            out.extend_from_slice(&(b.len() as u32).to_le_bytes());
72            out.extend_from_slice(b);
73        }
74        (ScalarValue::Bytes(b), Type::Bytes) => {
75            out.push(5);
76            out.extend_from_slice(&(b.len() as u32).to_le_bytes());
77            out.extend_from_slice(b);
78        }
79        (ScalarValue::Uuid(u), Type::Uuid) => {
80            out.push(6);
81            out.extend_from_slice(u);
82        }
83        (ScalarValue::Timestamp(t), Type::Timestamp) => {
84            out.push(7);
85            out.extend_from_slice(&t.to_le_bytes());
86        }
87        _ => return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch)),
88    }
89    Ok(())
90}
91
92pub struct Cursor<'a> {
93    pub bytes: &'a [u8],
94    pub pos: usize,
95}
96
97impl<'a> Cursor<'a> {
98    pub fn new(bytes: &'a [u8]) -> Self {
99        Self { bytes, pos: 0 }
100    }
101
102    pub fn remaining(&self) -> usize {
103        self.bytes.len().saturating_sub(self.pos)
104    }
105
106    pub fn take_u16(&mut self) -> Result<u16, DbError> {
107        if self.remaining() < 2 {
108            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
109        }
110        let v = u16::from_le_bytes([self.bytes[self.pos], self.bytes[self.pos + 1]]);
111        self.pos += 2;
112        Ok(v)
113    }
114
115    pub fn take_u8(&mut self) -> Result<u8, DbError> {
116        if self.pos >= self.bytes.len() {
117            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
118        }
119        let b = self.bytes[self.pos];
120        self.pos += 1;
121        Ok(b)
122    }
123
124    pub fn take_u32(&mut self) -> Result<u32, DbError> {
125        if self.remaining() < 4 {
126            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
127        }
128        let v = u32::from_le_bytes([
129            self.bytes[self.pos],
130            self.bytes[self.pos + 1],
131            self.bytes[self.pos + 2],
132            self.bytes[self.pos + 3],
133        ]);
134        self.pos += 4;
135        Ok(v)
136    }
137
138    pub fn take_i64(&mut self) -> Result<i64, DbError> {
139        if self.remaining() < 8 {
140            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
141        }
142        let v = i64::from_le_bytes([
143            self.bytes[self.pos],
144            self.bytes[self.pos + 1],
145            self.bytes[self.pos + 2],
146            self.bytes[self.pos + 3],
147            self.bytes[self.pos + 4],
148            self.bytes[self.pos + 5],
149            self.bytes[self.pos + 6],
150            self.bytes[self.pos + 7],
151        ]);
152        self.pos += 8;
153        Ok(v)
154    }
155
156    pub fn take_u64(&mut self) -> Result<u64, DbError> {
157        if self.remaining() < 8 {
158            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
159        }
160        let v = u64::from_le_bytes([
161            self.bytes[self.pos],
162            self.bytes[self.pos + 1],
163            self.bytes[self.pos + 2],
164            self.bytes[self.pos + 3],
165            self.bytes[self.pos + 4],
166            self.bytes[self.pos + 5],
167            self.bytes[self.pos + 6],
168            self.bytes[self.pos + 7],
169        ]);
170        self.pos += 8;
171        Ok(v)
172    }
173
174    pub fn take_f64(&mut self) -> Result<f64, DbError> {
175        Ok(f64::from_bits(self.take_u64()?))
176    }
177
178    pub fn take_bytes(&mut self, n: usize) -> Result<Vec<u8>, DbError> {
179        if self.remaining() < n {
180            return Err(DbError::Format(FormatError::TruncatedRecordPayload));
181        }
182        let slice = &self.bytes[self.pos..self.pos + n];
183        self.pos += n;
184        Ok(slice.to_vec())
185    }
186}
187
188pub fn decode_tagged_scalar(cur: &mut Cursor<'_>, ty: &Type) -> Result<ScalarValue, DbError> {
189    let tag = cur.take_u8()?;
190    Ok(match ty {
191        Type::Bool => {
192            if tag != 0 {
193                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
194            }
195            let b = cur.take_u8()?;
196            ScalarValue::Bool(b != 0)
197        }
198        Type::Int64 => {
199            if tag != 1 {
200                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
201            }
202            ScalarValue::Int64(cur.take_i64()?)
203        }
204        Type::Uint64 => {
205            if tag != 2 {
206                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
207            }
208            ScalarValue::Uint64(cur.take_u64()?)
209        }
210        Type::Float64 => {
211            if tag != 3 {
212                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
213            }
214            ScalarValue::Float64(cur.take_f64()?)
215        }
216        Type::String => {
217            if tag != 4 {
218                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
219            }
220            let n = cur.take_u32()? as usize;
221            let b = cur.take_bytes(n)?;
222            ScalarValue::String(
223                String::from_utf8(b)
224                    .map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))?,
225            )
226        }
227        Type::Bytes => {
228            if tag != 5 {
229                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
230            }
231            let n = cur.take_u32()? as usize;
232            ScalarValue::Bytes(cur.take_bytes(n)?)
233        }
234        Type::Uuid => {
235            if tag != 6 {
236                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
237            }
238            let b = cur.take_bytes(16)?;
239            let mut a = [0u8; 16];
240            a.copy_from_slice(&b);
241            ScalarValue::Uuid(a)
242        }
243        Type::Timestamp => {
244            if tag != 7 {
245                return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
246            }
247            ScalarValue::Timestamp(cur.take_i64()?)
248        }
249        _ => return Err(DbError::Format(FormatError::RecordPayloadUnsupportedType)),
250    })
251}
252
253pub fn decode_tagged_string(cur: &mut Cursor<'_>) -> Result<String, DbError> {
254    let tag = cur.take_u8()?;
255    if tag != 4 {
256        return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
257    }
258    let n = cur.take_u32()? as usize;
259    let b = cur.take_bytes(n)?;
260    String::from_utf8(b).map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))
261}

// Unit tests live in a separate file under `tests/unit/`; `include!` pastes that file's
// contents into this child module, so it is compiled as `scalar::tests`.
#[cfg(test)]
mod tests {
    include!(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/unit/src_record_scalar_tests.rs"));
}