use prikk_error::{PrikkError, Result};
/// One-byte type marker written immediately after each field's tag.
///
/// The explicit discriminants are emitted verbatim (`wire_type as u8`) into
/// every encoded field, so changing any of them changes the canonical bytes.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WireType {
    // --- fixed-width scalars -------------------------------------------
    /// One byte: `1` for `true`, `0` for `false`.
    Bool = 0x01,
    /// Big-endian `u16`.
    U16 = 0x02,
    /// Big-endian `u32`.
    U32 = 0x03,
    /// Big-endian `u64`.
    U64 = 0x04,
    /// Big-endian `u16` carrying an enum value.
    EnumU16 = 0x05,

    // --- variable-length byte payloads ---------------------------------
    /// UTF-8 bytes of a string.
    String = 0x10,
    /// Arbitrary raw bytes.
    Bytes = 0x11,
    /// The bytes of an object identifier.
    ObjectId = 0x12,
    /// UTF-8 bytes of a repository path.
    RepoPath = 0x13,

    // --- nested records ------------------------------------------------
    /// A nested record encoded with its own writer.
    Record = 0x20,
    /// A nested record that is one element of a record list.
    RecordListItem = 0x21,
}
/// Types that can serialize themselves into the canonical field encoding.
pub trait CanonicalEncode {
    /// Writes this value's fields into `writer`.
    ///
    /// The writer rejects decreasing tags, so implementations must emit fields
    /// in non-decreasing tag order.
    ///
    /// # Errors
    /// Returns any error reported by `writer`.
    fn encode_canonical(&self, writer: &mut CanonicalWriter) -> Result<()>;

    /// Encodes this value into a fresh [`CanonicalWriter`] and returns the bytes.
    ///
    /// # Errors
    /// Returns whatever error [`CanonicalEncode::encode_canonical`] reports.
    fn to_canonical_bytes(&self) -> Result<Vec<u8>> {
        let mut out = CanonicalWriter::new();
        self.encode_canonical(&mut out).map(|()| out.finish())
    }
}
/// Append-only encoder producing tag / wire-type / length / value fields.
///
/// Field tags must never decrease within one writer; `last_tag` remembers the
/// most recently written tag so out-of-order writes can be rejected.
#[derive(Default, Debug)]
pub struct CanonicalWriter {
    /// Encoded output accumulated so far.
    bytes: Vec<u8>,
    /// Tag of the most recently written field; `None` before the first field.
    last_tag: Option<u16>,
}
impl CanonicalWriter {
    /// Size of every field header: 2-byte tag + 1-byte wire type + 8-byte length.
    const FIELD_HEADER_LEN: usize = 2 + 1 + 8;

    /// Creates an empty writer with no fields written.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Consumes the writer and returns everything encoded so far.
    #[must_use]
    pub fn finish(self) -> Vec<u8> {
        self.bytes
    }

    /// Writes the UTF-8 bytes of `value` as a [`WireType::String`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_string(&mut self, tag: u16, value: &str) -> Result<()> {
        self.field_raw(tag, WireType::String, value.as_bytes())
    }

    /// Writes a string field when `value` is `Some`; writes nothing for `None`.
    ///
    /// An absent value does not update the tag-order state, so it places no
    /// constraint on later fields.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_string_opt(&mut self, tag: u16, value: Option<&str>) -> Result<()> {
        match value {
            Some(value) => self.field_string(tag, value),
            None => Ok(()),
        }
    }

    /// Writes `value` verbatim as a [`WireType::Bytes`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_bytes(&mut self, tag: u16, value: &[u8]) -> Result<()> {
        self.field_raw(tag, WireType::Bytes, value)
    }

    /// Writes `value` big-endian as a [`WireType::U32`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_u32(&mut self, tag: u16, value: u32) -> Result<()> {
        self.field_raw(tag, WireType::U32, &value.to_be_bytes())
    }

    /// Writes `value` big-endian as a [`WireType::U64`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_u64(&mut self, tag: u16, value: u64) -> Result<()> {
        self.field_raw(tag, WireType::U64, &value.to_be_bytes())
    }

    /// Writes `value` as a one-byte [`WireType::Bool`] field (`1` or `0`).
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_bool(&mut self, tag: u16, value: bool) -> Result<()> {
        self.field_raw(tag, WireType::Bool, &[u8::from(value)])
    }

    /// Writes `value` big-endian as a [`WireType::U16`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_u16(&mut self, tag: u16, value: u16) -> Result<()> {
        self.field_raw(tag, WireType::U16, &value.to_be_bytes())
    }

    /// Writes an enum discriminant big-endian as a [`WireType::EnumU16`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_enum_u16(&mut self, tag: u16, value: u16) -> Result<()> {
        self.field_raw(tag, WireType::EnumU16, &value.to_be_bytes())
    }

    /// Writes the bytes of `value` as a [`WireType::ObjectId`] field.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_object_id(&mut self, tag: u16, value: &crate::ObjectId) -> Result<()> {
        self.field_raw(tag, WireType::ObjectId, value.as_bytes())
    }

    /// Writes the UTF-8 bytes of `value` as a [`WireType::RepoPath`] field.
    ///
    /// The path is not validated here.
    ///
    /// # Errors
    /// Same as [`CanonicalWriter::field_raw`].
    pub fn field_repo_path(&mut self, tag: u16, value: &str) -> Result<()> {
        self.field_raw(tag, WireType::RepoPath, value.as_bytes())
    }

    /// Encodes `value` as a nested [`WireType::Record`] field.
    ///
    /// # Errors
    /// Returns any error from encoding `value`, or from
    /// [`CanonicalWriter::field_raw`].
    pub fn field_record<T: CanonicalEncode>(&mut self, tag: u16, value: &T) -> Result<()> {
        self.field_nested(tag, WireType::Record, value)
    }

    /// Writes every element of `values` as a [`WireType::Record`] field under
    /// the same `tag`. An empty slice writes nothing.
    ///
    /// # Errors
    /// Stops at the first failing element. Elements before it have already
    /// been written.
    pub fn repeated_record<T: CanonicalEncode>(&mut self, tag: u16, values: &[T]) -> Result<()> {
        values
            .iter()
            .try_for_each(|value| self.field_record(tag, value))
    }

    /// Encodes `value` as a nested [`WireType::RecordListItem`] field.
    ///
    /// # Errors
    /// Returns any error from encoding `value`, or from
    /// [`CanonicalWriter::field_raw`].
    pub fn field_record_list_item<T: CanonicalEncode>(
        &mut self,
        tag: u16,
        value: &T,
    ) -> Result<()> {
        self.field_nested(tag, WireType::RecordListItem, value)
    }

    /// Writes every element of `values` as a [`WireType::RecordListItem`]
    /// field under the same `tag`. An empty slice writes nothing.
    ///
    /// # Errors
    /// Stops at the first failing element. Elements before it have already
    /// been written.
    pub fn repeated_record_list<T: CanonicalEncode>(
        &mut self,
        tag: u16,
        values: &[T],
    ) -> Result<()> {
        values
            .iter()
            .try_for_each(|value| self.field_record_list_item(tag, value))
    }

    /// Writes every string in `values` as a [`WireType::String`] field under
    /// the same `tag`. An empty slice writes nothing.
    ///
    /// # Errors
    /// Stops at the first failing element. Elements before it have already
    /// been written.
    pub fn repeated_string(&mut self, tag: u16, values: &[String]) -> Result<()> {
        values
            .iter()
            .try_for_each(|value| self.field_string(tag, value))
    }

    /// Writes every id in `values` as a [`WireType::ObjectId`] field under the
    /// same `tag`. An empty slice writes nothing.
    ///
    /// # Errors
    /// Stops at the first failing element. Elements before it have already
    /// been written.
    pub fn repeated_object_id(&mut self, tag: u16, values: &[crate::ObjectId]) -> Result<()> {
        values
            .iter()
            .try_for_each(|value| self.field_object_id(tag, value))
    }

    /// Appends one field to the output.
    ///
    /// The field layout is:
    /// 1. the tag as a big-endian `u16`;
    /// 2. the wire-type byte;
    /// 3. the payload length as a big-endian `u64`;
    /// 4. the payload bytes.
    ///
    /// Equal consecutive tags are allowed, because that is how repeated
    /// fields are encoded. A tag must never be smaller than the previous one.
    ///
    /// # Errors
    /// Returns [`PrikkError::CanonicalEncoding`] in two cases:
    /// - `tag` is 0, which is reserved;
    /// - `tag` is smaller than the previously written tag.
    ///
    /// On error nothing is written and the tag-order state is unchanged.
    pub fn field_raw(&mut self, tag: u16, wire_type: WireType, value: &[u8]) -> Result<()> {
        if tag == 0 {
            return Err(PrikkError::CanonicalEncoding(
                "field tag 0 is reserved".to_string(),
            ));
        }
        if let Some(last) = self.last_tag.filter(|&last| tag < last) {
            return Err(PrikkError::CanonicalEncoding(format!(
                "field tag order violation: {tag} after {last}"
            )));
        }
        // Lossless: `usize` is at most 64 bits on every target Rust supports.
        let len = value.len() as u64;

        self.last_tag = Some(tag);
        // One reservation for header + payload instead of up to four regrowths.
        self.bytes.reserve(Self::FIELD_HEADER_LEN + value.len());
        self.bytes.extend_from_slice(&tag.to_be_bytes());
        self.bytes.push(wire_type as u8);
        self.bytes.extend_from_slice(&len.to_be_bytes());
        self.bytes.extend_from_slice(value);
        Ok(())
    }

    /// Encodes `value` with its own writer and appends the result as one field.
    ///
    /// The separate writer means the nested record's tags are ordered
    /// independently of this writer's tags.
    fn field_nested<T: CanonicalEncode>(
        &mut self,
        tag: u16,
        wire_type: WireType,
        value: &T,
    ) -> Result<()> {
        let mut nested = CanonicalWriter::new();
        value.encode_canonical(&mut nested)?;
        self.field_raw(tag, wire_type, &nested.finish())
    }
}
/// Returns `true` when every element is strictly less than the one after it.
///
/// Empty and single-element slices are trivially strictly sorted. Any pair of
/// equal neighbours makes the result `false`.
#[must_use]
pub fn is_strictly_sorted<T: Ord>(values: &[T]) -> bool {
    // `windows(2)` always yields slices of exactly two elements.
    values.windows(2).all(|pair| pair[0] < pair[1])
}
/// Returns `true` when `values` is exactly `1, 2, 3, …, values.len()`.
///
/// An empty slice is considered contiguous.
#[must_use]
pub fn is_contiguous_op_seq(values: &[u32]) -> bool {
    // Compare in `u64` so the check is lossless. The previous
    // `value as usize` could truncate on targets with a 16-bit `usize`.
    values
        .iter()
        .zip(1_u64..)
        .all(|(&value, expected)| u64::from(value) == expected)
}
#[cfg(test)]
mod tests {
    use super::{
        is_contiguous_op_seq, is_strictly_sorted, CanonicalEncode, CanonicalWriter, Result,
    };

    /// Minimal record with a single `u16` field (tag 1, value 9), used to
    /// exercise nested encoding.
    struct Leaf;

    impl CanonicalEncode for Leaf {
        fn encode_canonical(&self, writer: &mut CanonicalWriter) -> Result<()> {
            writer.field_u16(1, 9)
        }
    }

    #[test]
    fn rejects_decreasing_tags() {
        let mut writer = CanonicalWriter::new();
        assert!(writer.field_u32(2, 1).is_ok());
        assert!(writer.field_u32(1, 1).is_err());
    }

    #[test]
    fn rejects_reserved_tag_zero_without_writing() {
        let mut writer = CanonicalWriter::new();
        assert!(writer.field_u32(0, 1).is_err());
        // A rejected field must not leave partial output behind.
        assert!(writer.finish().is_empty());
    }

    #[test]
    fn accepts_repeated_equal_tags() {
        // Repeated fields reuse one tag, so equal tags must be accepted.
        let mut writer = CanonicalWriter::new();
        let names = ["a".to_string(), "b".to_string()];
        assert!(writer.repeated_string(3, &names).is_ok());
        assert!(writer.field_u32(3, 1).is_ok());
    }

    #[test]
    fn encodes_header_and_big_endian_payload() {
        let mut writer = CanonicalWriter::new();
        assert!(writer.field_u32(1, 7).is_ok());
        assert_eq!(
            writer.finish(),
            vec![0, 1, 0x03, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 7]
        );
    }

    #[test]
    fn nested_records_have_independent_tag_order() {
        let leaf: Vec<u8> = vec![0, 1, 0x02, 0, 0, 0, 0, 0, 0, 0, 2, 0, 9];
        assert_eq!(Leaf.to_canonical_bytes().ok(), Some(leaf.clone()));

        let mut writer = CanonicalWriter::new();
        assert!(writer.field_u32(5, 0).is_ok());
        // The leaf uses tag 1 internally, which is fine inside its own record.
        assert!(writer.field_record(6, &Leaf).is_ok());
        let bytes = writer.finish();
        assert_eq!(bytes.len(), 15 + 11 + leaf.len());
        assert_eq!(&bytes[15..26], &[0, 6, 0x20, 0, 0, 0, 0, 0, 0, 0, 13]);
        assert_eq!(&bytes[26..], leaf.as_slice());
    }

    #[test]
    fn strict_sorting_rejects_duplicates() {
        assert!(is_strictly_sorted::<u8>(&[]));
        assert!(is_strictly_sorted(&[1, 2, 3]));
        assert!(!is_strictly_sorted(&[1, 1]));
        assert!(!is_strictly_sorted(&[2, 1]));
    }

    #[test]
    fn op_seq_must_start_at_one_without_gaps() {
        assert!(is_contiguous_op_seq(&[]));
        assert!(is_contiguous_op_seq(&[1, 2, 3]));
        assert!(!is_contiguous_op_seq(&[0, 1]));
        assert!(!is_contiguous_op_seq(&[1, 3]));
    }
}