modelvault_core/record/
scalar.rs1use crate::error::{DbError, FormatError};
4use crate::file_format::check_field_bytes_len;
5use crate::schema::Type;
6
/// A single scalar value.
///
/// Each variant pairs with the schema `Type` variant of the same name
/// (see `ScalarValue::ty_matches`).
#[derive(Debug, Clone, PartialEq)]
pub enum ScalarValue {
    /// Boolean flag.
    Bool(bool),
    /// Signed 64-bit integer.
    Int64(i64),
    /// Unsigned 64-bit integer.
    Uint64(u64),
    /// IEEE-754 double-precision float.
    Float64(f64),
    /// UTF-8 text.
    String(String),
    /// Arbitrary binary payload.
    Bytes(Vec<u8>),
    /// UUID held as 16 raw bytes.
    Uuid([u8; 16]),
    /// Timestamp held as a signed 64-bit count.
    // NOTE(review): the unit/epoch is not defined in this file — confirm against the schema docs.
    Timestamp(i64),
}
19
20impl ScalarValue {
21 pub fn canonical_key_bytes(&self) -> Vec<u8> {
23 match self {
24 ScalarValue::Bool(b) => vec![0, if *b { 1 } else { 0 }],
25 ScalarValue::Int64(v) => v.to_le_bytes().to_vec(),
26 ScalarValue::Uint64(v) => v.to_le_bytes().to_vec(),
27 ScalarValue::Float64(v) => {
28 let n = if *v == 0.0 { 0.0f64 } else { *v };
29 n.to_le_bytes().to_vec()
30 }
31 ScalarValue::String(s) => s.as_bytes().to_vec(),
32 ScalarValue::Bytes(b) => b.clone(),
33 ScalarValue::Uuid(u) => u.to_vec(),
34 ScalarValue::Timestamp(t) => t.to_le_bytes().to_vec(),
35 }
36 }
37
38 pub fn ty_matches(&self, ty: &Type) -> bool {
39 matches!(
40 (self, ty),
41 (ScalarValue::Bool(_), Type::Bool)
42 | (ScalarValue::Int64(_), Type::Int64)
43 | (ScalarValue::Uint64(_), Type::Uint64)
44 | (ScalarValue::Float64(_), Type::Float64)
45 | (ScalarValue::String(_), Type::String)
46 | (ScalarValue::Bytes(_), Type::Bytes)
47 | (ScalarValue::Uuid(_), Type::Uuid)
48 | (ScalarValue::Timestamp(_), Type::Timestamp)
49 )
50 }
51}
52
53pub fn encode_tagged_scalar(out: &mut Vec<u8>, v: &ScalarValue, ty: &Type) -> Result<(), DbError> {
55 match (v, ty) {
56 (ScalarValue::Bool(b), Type::Bool) => {
57 out.push(0);
58 out.push(if *b { 1 } else { 0 });
59 }
60 (ScalarValue::Int64(n), Type::Int64) => {
61 out.push(1);
62 out.extend_from_slice(&n.to_le_bytes());
63 }
64 (ScalarValue::Uint64(n), Type::Uint64) => {
65 out.push(2);
66 out.extend_from_slice(&n.to_le_bytes());
67 }
68 (ScalarValue::Float64(n), Type::Float64) => {
69 out.push(3);
70 out.extend_from_slice(&n.to_le_bytes());
71 }
72 (ScalarValue::String(s), Type::String) => {
73 out.push(4);
74 let b = s.as_bytes();
75 check_field_bytes_len(b.len())?;
76 out.extend_from_slice(&(b.len() as u32).to_le_bytes());
77 out.extend_from_slice(b);
78 }
79 (ScalarValue::Bytes(b), Type::Bytes) => {
80 out.push(5);
81 check_field_bytes_len(b.len())?;
82 out.extend_from_slice(&(b.len() as u32).to_le_bytes());
83 out.extend_from_slice(b);
84 }
85 (ScalarValue::Uuid(u), Type::Uuid) => {
86 out.push(6);
87 out.extend_from_slice(u);
88 }
89 (ScalarValue::Timestamp(t), Type::Timestamp) => {
90 out.push(7);
91 out.extend_from_slice(&t.to_le_bytes());
92 }
93 _ => return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch)),
94 }
95 Ok(())
96}
97
/// Position-tracking reader over a borrowed byte slice, used by the tagged-scalar decoders.
///
/// Both fields are public, so callers can inspect or reposition the cursor directly.
pub struct Cursor<'a> {
    /// The complete input being read.
    pub bytes: &'a [u8],
    /// Index into `bytes` of the next unread byte.
    pub pos: usize,
}
102
103impl<'a> Cursor<'a> {
104 pub fn new(bytes: &'a [u8]) -> Self {
105 Self { bytes, pos: 0 }
106 }
107
108 pub fn remaining(&self) -> usize {
109 self.bytes.len().saturating_sub(self.pos)
110 }
111
112 pub fn take_u16(&mut self) -> Result<u16, DbError> {
113 if self.remaining() < 2 {
114 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
115 }
116 let v = u16::from_le_bytes([self.bytes[self.pos], self.bytes[self.pos + 1]]);
117 self.pos += 2;
118 Ok(v)
119 }
120
121 pub fn take_u8(&mut self) -> Result<u8, DbError> {
122 if self.pos >= self.bytes.len() {
123 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
124 }
125 let b = self.bytes[self.pos];
126 self.pos += 1;
127 Ok(b)
128 }
129
130 pub fn take_u32(&mut self) -> Result<u32, DbError> {
131 if self.remaining() < 4 {
132 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
133 }
134 let v = u32::from_le_bytes([
135 self.bytes[self.pos],
136 self.bytes[self.pos + 1],
137 self.bytes[self.pos + 2],
138 self.bytes[self.pos + 3],
139 ]);
140 self.pos += 4;
141 Ok(v)
142 }
143
144 pub fn take_i64(&mut self) -> Result<i64, DbError> {
145 if self.remaining() < 8 {
146 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
147 }
148 let v = i64::from_le_bytes([
149 self.bytes[self.pos],
150 self.bytes[self.pos + 1],
151 self.bytes[self.pos + 2],
152 self.bytes[self.pos + 3],
153 self.bytes[self.pos + 4],
154 self.bytes[self.pos + 5],
155 self.bytes[self.pos + 6],
156 self.bytes[self.pos + 7],
157 ]);
158 self.pos += 8;
159 Ok(v)
160 }
161
162 pub fn take_u64(&mut self) -> Result<u64, DbError> {
163 if self.remaining() < 8 {
164 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
165 }
166 let v = u64::from_le_bytes([
167 self.bytes[self.pos],
168 self.bytes[self.pos + 1],
169 self.bytes[self.pos + 2],
170 self.bytes[self.pos + 3],
171 self.bytes[self.pos + 4],
172 self.bytes[self.pos + 5],
173 self.bytes[self.pos + 6],
174 self.bytes[self.pos + 7],
175 ]);
176 self.pos += 8;
177 Ok(v)
178 }
179
180 pub fn take_f64(&mut self) -> Result<f64, DbError> {
181 Ok(f64::from_bits(self.take_u64()?))
182 }
183
184 pub fn take_bytes(&mut self, n: usize) -> Result<Vec<u8>, DbError> {
185 check_field_bytes_len(n)?;
186 if self.remaining() < n {
187 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
188 }
189 let slice = &self.bytes[self.pos..self.pos + n];
190 self.pos += n;
191 Ok(slice.to_vec())
192 }
193}
194
195pub fn decode_tagged_scalar(cur: &mut Cursor<'_>, ty: &Type) -> Result<ScalarValue, DbError> {
196 let tag = cur.take_u8()?;
197 Ok(match ty {
198 Type::Bool => {
199 if tag != 0 {
200 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
201 }
202 let b = cur.take_u8()?;
203 ScalarValue::Bool(b != 0)
204 }
205 Type::Int64 => {
206 if tag != 1 {
207 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
208 }
209 ScalarValue::Int64(cur.take_i64()?)
210 }
211 Type::Uint64 => {
212 if tag != 2 {
213 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
214 }
215 ScalarValue::Uint64(cur.take_u64()?)
216 }
217 Type::Float64 => {
218 if tag != 3 {
219 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
220 }
221 ScalarValue::Float64(cur.take_f64()?)
222 }
223 Type::String => {
224 if tag != 4 {
225 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
226 }
227 let n = cur.take_u32()? as usize;
228 let b = cur.take_bytes(n)?;
229 ScalarValue::String(
230 String::from_utf8(b)
231 .map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))?,
232 )
233 }
234 Type::Bytes => {
235 if tag != 5 {
236 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
237 }
238 let n = cur.take_u32()? as usize;
239 ScalarValue::Bytes(cur.take_bytes(n)?)
240 }
241 Type::Uuid => {
242 if tag != 6 {
243 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
244 }
245 let b = cur.take_bytes(16)?;
246 let mut a = [0u8; 16];
247 a.copy_from_slice(&b);
248 ScalarValue::Uuid(a)
249 }
250 Type::Timestamp => {
251 if tag != 7 {
252 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
253 }
254 ScalarValue::Timestamp(cur.take_i64()?)
255 }
256 _ => return Err(DbError::Format(FormatError::RecordPayloadUnsupportedType)),
257 })
258}
259
260pub fn decode_tagged_string(cur: &mut Cursor<'_>) -> Result<String, DbError> {
261 let tag = cur.take_u8()?;
262 if tag != 4 {
263 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
264 }
265 let n = cur.take_u32()? as usize;
266 let b = cur.take_bytes(n)?;
267 String::from_utf8(b).map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))
268}
269
/// Unit tests for this module. They live in `tests/unit/` and are textually included here,
/// so they compile as a child module of this file.
#[cfg(test)]
mod tests {
    include!(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/unit/src_record_scalar_tests.rs"));
}