// modelvault_core/record/scalar.rs
use crate::error::{DbError, FormatError};
use crate::schema::Type;

6#[derive(Debug, Clone, PartialEq)]
7pub enum ScalarValue {
8 Bool(bool),
9 Int64(i64),
10 Uint64(u64),
11 Float64(f64),
12 String(String),
13 Bytes(Vec<u8>),
14 Uuid([u8; 16]),
15 Timestamp(i64),
17}
18
19impl ScalarValue {
20 pub fn canonical_key_bytes(&self) -> Vec<u8> {
22 match self {
23 ScalarValue::Bool(b) => vec![0, if *b { 1 } else { 0 }],
24 ScalarValue::Int64(v) => v.to_le_bytes().to_vec(),
25 ScalarValue::Uint64(v) => v.to_le_bytes().to_vec(),
26 ScalarValue::Float64(v) => v.to_le_bytes().to_vec(),
27 ScalarValue::String(s) => s.as_bytes().to_vec(),
28 ScalarValue::Bytes(b) => b.clone(),
29 ScalarValue::Uuid(u) => u.to_vec(),
30 ScalarValue::Timestamp(t) => t.to_le_bytes().to_vec(),
31 }
32 }
33
34 pub fn ty_matches(&self, ty: &Type) -> bool {
35 matches!(
36 (self, ty),
37 (ScalarValue::Bool(_), Type::Bool)
38 | (ScalarValue::Int64(_), Type::Int64)
39 | (ScalarValue::Uint64(_), Type::Uint64)
40 | (ScalarValue::Float64(_), Type::Float64)
41 | (ScalarValue::String(_), Type::String)
42 | (ScalarValue::Bytes(_), Type::Bytes)
43 | (ScalarValue::Uuid(_), Type::Uuid)
44 | (ScalarValue::Timestamp(_), Type::Timestamp)
45 )
46 }
47}
48
49pub fn encode_tagged_scalar(out: &mut Vec<u8>, v: &ScalarValue, ty: &Type) -> Result<(), DbError> {
51 match (v, ty) {
52 (ScalarValue::Bool(b), Type::Bool) => {
53 out.push(0);
54 out.push(if *b { 1 } else { 0 });
55 }
56 (ScalarValue::Int64(n), Type::Int64) => {
57 out.push(1);
58 out.extend_from_slice(&n.to_le_bytes());
59 }
60 (ScalarValue::Uint64(n), Type::Uint64) => {
61 out.push(2);
62 out.extend_from_slice(&n.to_le_bytes());
63 }
64 (ScalarValue::Float64(n), Type::Float64) => {
65 out.push(3);
66 out.extend_from_slice(&n.to_le_bytes());
67 }
68 (ScalarValue::String(s), Type::String) => {
69 out.push(4);
70 let b = s.as_bytes();
71 out.extend_from_slice(&(b.len() as u32).to_le_bytes());
72 out.extend_from_slice(b);
73 }
74 (ScalarValue::Bytes(b), Type::Bytes) => {
75 out.push(5);
76 out.extend_from_slice(&(b.len() as u32).to_le_bytes());
77 out.extend_from_slice(b);
78 }
79 (ScalarValue::Uuid(u), Type::Uuid) => {
80 out.push(6);
81 out.extend_from_slice(u);
82 }
83 (ScalarValue::Timestamp(t), Type::Timestamp) => {
84 out.push(7);
85 out.extend_from_slice(&t.to_le_bytes());
86 }
87 _ => return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch)),
88 }
89 Ok(())
90}
91
92pub struct Cursor<'a> {
93 pub bytes: &'a [u8],
94 pub pos: usize,
95}
96
97impl<'a> Cursor<'a> {
98 pub fn new(bytes: &'a [u8]) -> Self {
99 Self { bytes, pos: 0 }
100 }
101
102 pub fn remaining(&self) -> usize {
103 self.bytes.len().saturating_sub(self.pos)
104 }
105
106 pub fn take_u16(&mut self) -> Result<u16, DbError> {
107 if self.remaining() < 2 {
108 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
109 }
110 let v = u16::from_le_bytes([self.bytes[self.pos], self.bytes[self.pos + 1]]);
111 self.pos += 2;
112 Ok(v)
113 }
114
115 pub fn take_u8(&mut self) -> Result<u8, DbError> {
116 if self.pos >= self.bytes.len() {
117 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
118 }
119 let b = self.bytes[self.pos];
120 self.pos += 1;
121 Ok(b)
122 }
123
124 pub fn take_u32(&mut self) -> Result<u32, DbError> {
125 if self.remaining() < 4 {
126 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
127 }
128 let v = u32::from_le_bytes([
129 self.bytes[self.pos],
130 self.bytes[self.pos + 1],
131 self.bytes[self.pos + 2],
132 self.bytes[self.pos + 3],
133 ]);
134 self.pos += 4;
135 Ok(v)
136 }
137
138 pub fn take_i64(&mut self) -> Result<i64, DbError> {
139 if self.remaining() < 8 {
140 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
141 }
142 let v = i64::from_le_bytes([
143 self.bytes[self.pos],
144 self.bytes[self.pos + 1],
145 self.bytes[self.pos + 2],
146 self.bytes[self.pos + 3],
147 self.bytes[self.pos + 4],
148 self.bytes[self.pos + 5],
149 self.bytes[self.pos + 6],
150 self.bytes[self.pos + 7],
151 ]);
152 self.pos += 8;
153 Ok(v)
154 }
155
156 pub fn take_u64(&mut self) -> Result<u64, DbError> {
157 if self.remaining() < 8 {
158 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
159 }
160 let v = u64::from_le_bytes([
161 self.bytes[self.pos],
162 self.bytes[self.pos + 1],
163 self.bytes[self.pos + 2],
164 self.bytes[self.pos + 3],
165 self.bytes[self.pos + 4],
166 self.bytes[self.pos + 5],
167 self.bytes[self.pos + 6],
168 self.bytes[self.pos + 7],
169 ]);
170 self.pos += 8;
171 Ok(v)
172 }
173
174 pub fn take_f64(&mut self) -> Result<f64, DbError> {
175 Ok(f64::from_bits(self.take_u64()?))
176 }
177
178 pub fn take_bytes(&mut self, n: usize) -> Result<Vec<u8>, DbError> {
179 if self.remaining() < n {
180 return Err(DbError::Format(FormatError::TruncatedRecordPayload));
181 }
182 let slice = &self.bytes[self.pos..self.pos + n];
183 self.pos += n;
184 Ok(slice.to_vec())
185 }
186}
187
188pub fn decode_tagged_scalar(cur: &mut Cursor<'_>, ty: &Type) -> Result<ScalarValue, DbError> {
189 let tag = cur.take_u8()?;
190 Ok(match ty {
191 Type::Bool => {
192 if tag != 0 {
193 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
194 }
195 let b = cur.take_u8()?;
196 ScalarValue::Bool(b != 0)
197 }
198 Type::Int64 => {
199 if tag != 1 {
200 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
201 }
202 ScalarValue::Int64(cur.take_i64()?)
203 }
204 Type::Uint64 => {
205 if tag != 2 {
206 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
207 }
208 ScalarValue::Uint64(cur.take_u64()?)
209 }
210 Type::Float64 => {
211 if tag != 3 {
212 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
213 }
214 ScalarValue::Float64(cur.take_f64()?)
215 }
216 Type::String => {
217 if tag != 4 {
218 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
219 }
220 let n = cur.take_u32()? as usize;
221 let b = cur.take_bytes(n)?;
222 ScalarValue::String(
223 String::from_utf8(b)
224 .map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))?,
225 )
226 }
227 Type::Bytes => {
228 if tag != 5 {
229 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
230 }
231 let n = cur.take_u32()? as usize;
232 ScalarValue::Bytes(cur.take_bytes(n)?)
233 }
234 Type::Uuid => {
235 if tag != 6 {
236 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
237 }
238 let b = cur.take_bytes(16)?;
239 let mut a = [0u8; 16];
240 a.copy_from_slice(&b);
241 ScalarValue::Uuid(a)
242 }
243 Type::Timestamp => {
244 if tag != 7 {
245 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
246 }
247 ScalarValue::Timestamp(cur.take_i64()?)
248 }
249 _ => return Err(DbError::Format(FormatError::RecordPayloadUnsupportedType)),
250 })
251}
252
253pub fn decode_tagged_string(cur: &mut Cursor<'_>) -> Result<String, DbError> {
254 let tag = cur.take_u8()?;
255 if tag != 4 {
256 return Err(DbError::Format(FormatError::RecordPayloadTypeMismatch));
257 }
258 let n = cur.take_u32()? as usize;
259 let b = cur.take_bytes(n)?;
260 String::from_utf8(b).map_err(|_| DbError::Format(FormatError::InvalidRecordUtf8))
261}
262
// Unit tests for this file are kept in `tests/unit/src_record_scalar_tests.rs` under the
// crate root (`CARGO_MANIFEST_DIR`). `include!` splices that file's contents into this
// `tests` module at compile time, and `cfg(test)` keeps the module out of non-test builds.
#[cfg(test)]
mod tests {
    include!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/unit/src_record_scalar_tests.rs"
    ));
}