1use prikk_error::{PrikkError, Result};
4
5use crate::canonical::{WireType, is_strictly_sorted};
6use crate::{CanonicalEncode, CanonicalWriter, ObjectId};
7
/// Kind of reference recorded in a ref-state payload.
///
/// The `u16` discriminant of each variant is its wire code (see
/// `RefKind::code` / `RefKind::from_code`).
#[repr(u16)]
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum RefKind {
    /// Branch reference; wire code 1.
    Branch = 1,
    /// Tag reference; wire code 2.
    Tag = 2,
}
17
18impl RefKind {
19 #[must_use]
21 pub const fn code(self) -> u16 {
22 self as u16
23 }
24
25 pub fn from_code(code: u32) -> Result<Self> {
27 match code {
28 1 => Ok(Self::Branch),
29 2 => Ok(Self::Tag),
30 other => Err(PrikkError::MalformedData(format!(
31 "unknown ref kind code: {other}"
32 ))),
33 }
34 }
35}
36
37#[derive(Debug, Clone, PartialEq, Eq)]
39pub struct RefStatePayload {
40 pub ref_name: String,
42 pub kind: RefKind,
44 pub target_object_id: ObjectId,
46 pub update_seq: u64,
48 pub previous_ref_state_id: Option<ObjectId>,
50 pub required_attestation_ids: Vec<ObjectId>,
52}
53
54impl RefStatePayload {
55 pub fn decode_canonical(bytes: &[u8]) -> Result<Self> {
57 let mut cursor = CanonicalCursor::new(bytes);
58 let mut ref_name = None;
59 let mut kind = None;
60 let mut target_object_id = None;
61 let mut update_seq = None;
62 let mut previous_ref_state_id = None;
63 let mut required_attestation_ids = Vec::new();
64 while let Some(field) = cursor.next_field()? {
65 match field.tag {
66 1 => ref_name = Some(field.read_string()?),
67 2 => target_object_id = Some(field.read_object_id()?),
68 3 => update_seq = Some(field.read_u64()?),
69 4 => previous_ref_state_id = Some(field.read_object_id()?),
70 5 => required_attestation_ids.push(field.read_object_id()?),
71 6 => kind = Some(RefKind::from_code(u32::from(field.read_enum_u16()?))?),
72 other => {
73 return Err(PrikkError::MalformedData(format!(
74 "unknown RefState field tag: {other}"
75 )));
76 }
77 }
78 }
79 let payload = Self {
80 ref_name: ref_name.ok_or_else(|| {
81 PrikkError::MalformedData("RefState missing ref_name".to_string())
82 })?,
83 kind: kind
84 .ok_or_else(|| PrikkError::MalformedData("RefState missing kind".to_string()))?,
85 target_object_id: target_object_id.ok_or_else(|| {
86 PrikkError::MalformedData("RefState missing target_object_id".to_string())
87 })?,
88 update_seq: update_seq.ok_or_else(|| {
89 PrikkError::MalformedData("RefState missing update_seq".to_string())
90 })?,
91 previous_ref_state_id,
92 required_attestation_ids,
93 };
94 if !is_strictly_sorted(&payload.required_attestation_ids) {
95 return Err(PrikkError::MalformedData(
96 "RefState attestation IDs are not sorted and unique".to_string(),
97 ));
98 }
99 Ok(payload)
100 }
101}
102
103impl CanonicalEncode for RefStatePayload {
104 fn encode_canonical(&self, writer: &mut CanonicalWriter) -> Result<()> {
105 if !is_strictly_sorted(&self.required_attestation_ids) {
106 return Err(PrikkError::CanonicalEncoding(
107 "required_attestation_ids must be sorted and unique".to_string(),
108 ));
109 }
110 writer.field_string(1, &self.ref_name)?;
111 writer.field_object_id(2, &self.target_object_id)?;
112 writer.field_u64(3, self.update_seq)?;
113 if let Some(previous) = self.previous_ref_state_id {
114 writer.field_object_id(4, &previous)?;
115 }
116 writer.repeated_object_id(5, &self.required_attestation_ids)?;
117 writer.field_enum_u16(6, self.kind.code())?;
118 Ok(())
119 }
120}
121
122#[derive(Debug, Clone, PartialEq, Eq)]
124pub struct RefUpdatePayload {
125 pub ref_name: String,
127 pub old_ref_state_id: Option<ObjectId>,
129 pub new_ref_state_id: ObjectId,
131 pub new_target_object_id: ObjectId,
133 pub update_seq: u64,
135 pub created_at: u64,
137 pub author_key_id: String,
139}
140
141impl CanonicalEncode for RefUpdatePayload {
142 fn encode_canonical(&self, writer: &mut CanonicalWriter) -> Result<()> {
143 writer.field_string(1, &self.ref_name)?;
144 if let Some(old) = self.old_ref_state_id {
145 writer.field_object_id(2, &old)?;
146 }
147 writer.field_object_id(3, &self.new_ref_state_id)?;
148 writer.field_object_id(4, &self.new_target_object_id)?;
149 writer.field_u64(5, self.update_seq)?;
150 writer.field_u64(6, self.created_at)?;
151 writer.field_string(7, &self.author_key_id)?;
152 Ok(())
153 }
154}
155
156impl RefUpdatePayload {
157 pub fn decode_canonical(bytes: &[u8]) -> Result<Self> {
159 let mut cursor = CanonicalCursor::new(bytes);
160 let mut ref_name = None;
161 let mut old_ref_state_id = None;
162 let mut new_ref_state_id = None;
163 let mut new_target_object_id = None;
164 let mut update_seq = None;
165 let mut created_at = None;
166 let mut author_key_id = None;
167 while let Some(field) = cursor.next_field()? {
168 match field.tag {
169 1 => ref_name = Some(field.read_string()?),
170 2 => old_ref_state_id = Some(field.read_object_id()?),
171 3 => new_ref_state_id = Some(field.read_object_id()?),
172 4 => new_target_object_id = Some(field.read_object_id()?),
173 5 => update_seq = Some(field.read_u64()?),
174 6 => created_at = Some(field.read_u64()?),
175 7 => author_key_id = Some(field.read_string()?),
176 other => {
177 return Err(PrikkError::MalformedData(format!(
178 "unknown RefUpdate field tag: {other}"
179 )));
180 }
181 }
182 }
183 Ok(Self {
184 ref_name: ref_name.ok_or_else(|| {
185 PrikkError::MalformedData("RefUpdate missing ref_name".to_string())
186 })?,
187 old_ref_state_id,
188 new_ref_state_id: new_ref_state_id.ok_or_else(|| {
189 PrikkError::MalformedData("RefUpdate missing new_ref_state_id".to_string())
190 })?,
191 new_target_object_id: new_target_object_id.ok_or_else(|| {
192 PrikkError::MalformedData("RefUpdate missing new_target_object_id".to_string())
193 })?,
194 update_seq: update_seq.ok_or_else(|| {
195 PrikkError::MalformedData("RefUpdate missing update_seq".to_string())
196 })?,
197 created_at: created_at.ok_or_else(|| {
198 PrikkError::MalformedData("RefUpdate missing created_at".to_string())
199 })?,
200 author_key_id: author_key_id.ok_or_else(|| {
201 PrikkError::MalformedData("RefUpdate missing author_key_id".to_string())
202 })?,
203 })
204 }
205}
206
/// Forward-only reader over the field records of a canonical payload.
struct CanonicalCursor<'a> {
    /// Tag of the most recently read field; `None` before the first field.
    last_tag: Option<u16>,
    /// Offset of the next unread byte in `bytes`.
    pos: usize,
    /// The complete encoded payload.
    bytes: &'a [u8],
}
212
213impl<'a> CanonicalCursor<'a> {
214 const fn new(bytes: &'a [u8]) -> Self {
215 Self {
216 bytes,
217 pos: 0,
218 last_tag: None,
219 }
220 }
221
222 fn next_field(&mut self) -> Result<Option<CanonicalField<'a>>> {
223 if self.pos == self.bytes.len() {
224 return Ok(None);
225 }
226 let tag = u16::from_be_bytes(self.read_array::<2>()?);
227 if tag == 0 {
228 return Err(PrikkError::MalformedData(
229 "field tag 0 is reserved".to_string(),
230 ));
231 }
232 if let Some(last) = self.last_tag {
233 if tag < last {
234 return Err(PrikkError::MalformedData(format!(
235 "field tag order violation: {tag} after {last}"
236 )));
237 }
238 }
239 self.last_tag = Some(tag);
240 let wire_type = self.read_u8()?;
241 let len = usize::try_from(u64::from_be_bytes(self.read_array::<8>()?)).map_err(|_| {
242 PrikkError::MalformedData("canonical field length does not fit usize".to_string())
243 })?;
244 let value = self.read_exact(len)?;
245 Ok(Some(CanonicalField {
246 tag,
247 wire_type,
248 value,
249 }))
250 }
251
252 fn read_u8(&mut self) -> Result<u8> {
253 let value = self.read_exact(1)?;
254 let Some(byte) = value.first() else {
255 return Err(PrikkError::MalformedData(
256 "unexpected empty byte".to_string(),
257 ));
258 };
259 Ok(*byte)
260 }
261
262 fn read_array<const N: usize>(&mut self) -> Result<[u8; N]> {
263 let bytes = self.read_exact(N)?;
264 let mut out = [0_u8; N];
265 out.copy_from_slice(bytes);
266 Ok(out)
267 }
268
269 fn read_exact(&mut self, len: usize) -> Result<&'a [u8]> {
270 let end = self
271 .pos
272 .checked_add(len)
273 .ok_or_else(|| PrikkError::MalformedData("canonical range overflow".to_string()))?;
274 let Some(slice) = self.bytes.get(self.pos..end) else {
275 return Err(PrikkError::MalformedData(
276 "unexpected end of canonical payload".to_string(),
277 ));
278 };
279 self.pos = end;
280 Ok(slice)
281 }
282}
283
/// One field record read by `CanonicalCursor`, borrowing its value bytes from
/// the payload.
struct CanonicalField<'a> {
    /// Raw value bytes (exactly the record's declared length).
    value: &'a [u8],
    /// Raw wire-type byte, checked against [`WireType`] by the typed readers.
    wire_type: u8,
    /// Field tag; non-zero when produced by the cursor.
    tag: u16,
}
289
290impl<'a> CanonicalField<'a> {
291 fn read_string(&self) -> Result<String> {
292 self.require_wire(WireType::String)?;
293 String::from_utf8(self.value.to_vec())
294 .map_err(|err| PrikkError::MalformedData(format!("invalid UTF-8 string: {err}")))
295 }
296
297 fn read_u64(&self) -> Result<u64> {
298 self.require_wire(WireType::U64)?;
299 Ok(u64::from_be_bytes(self.read_array::<8>()?))
300 }
301
302 fn read_object_id(&self) -> Result<ObjectId> {
303 self.require_wire(WireType::ObjectId)?;
304 Ok(ObjectId::from_bytes(self.read_array::<32>()?))
305 }
306
307 fn read_enum_u16(&self) -> Result<u16> {
308 self.require_wire(WireType::EnumU16)?;
309 Ok(u16::from_be_bytes(self.read_array::<2>()?))
310 }
311
312 fn require_wire(&self, expected: WireType) -> Result<()> {
313 if self.wire_type == expected as u8 {
314 return Ok(());
315 }
316 Err(PrikkError::MalformedData(format!(
317 "field {} has wrong wire type: expected {}, got {}",
318 self.tag, expected as u8, self.wire_type
319 )))
320 }
321
322 fn read_array<const N: usize>(&self) -> Result<[u8; N]> {
323 if self.value.len() != N {
324 return Err(PrikkError::MalformedData(format!(
325 "field {} expected {N} bytes, got {}",
326 self.tag,
327 self.value.len()
328 )));
329 }
330 let mut out = [0_u8; N];
331 out.copy_from_slice(self.value);
332 Ok(out)
333 }
334}