use cid::Cid;
use serde::Deserialize;
use sha2::{Digest, Sha256};
/// A repository commit object, decoded via serde.
///
/// NOTE(review): this type only carries the decoded fields; nothing here
/// validates or interprets them (e.g. the signature is not checked).
#[derive(Debug, Deserialize)]
pub struct Commit {
/// Account identifier string (presumably the repo owner's DID — confirm).
pub did: String,
/// Commit format version number.
pub version: u64,
/// CID stored under `data`; presumably the MST root node — confirm against callers.
pub data: Cid,
/// Revision string.
pub rev: String,
/// Optional CID, presumably linking to the previous commit — confirm.
pub prev: Option<Cid>,
/// Signature, decoded as a raw byte string via `serde_bytes`.
#[serde(with = "serde_bytes")]
pub sig: serde_bytes::ByteBuf,
}
use serde::de::{self, Deserializer, MapAccess, Unexpected, Visitor};
use std::fmt;
/// MST layer ("depth") of a key, as computed by [`atproto_mst_depth`].
pub type Depth = u32;

/// Computes the atproto MST depth of `key`.
///
/// The depth is the number of leading zero bits in the SHA-256 digest of the
/// key's UTF-8 bytes, divided by two and rounded down (i.e. leading zeros
/// counted in 2-bit chunks).
///
/// The entire 256-bit digest is scanned, so the result stays exact even if
/// the first 128 bits of the digest happen to be zero (the previous
/// implementation only looked at the first 128 bits and capped at 64).
#[inline]
pub fn atproto_mst_depth(key: &str) -> Depth {
    let digest = Sha256::digest(key.as_bytes());
    let mut zero_bits: Depth = 0;
    for &byte in digest.iter() {
        zero_bits += byte.leading_zeros();
        // The first non-zero byte ends the run of leading zeros.
        if byte != 0 {
            break;
        }
    }
    zero_bits / 2
}
/// A decoded MST node.
///
/// `depth` is the MST depth shared by every key in the node, or `None` when
/// the node has no entries.
///
/// `things` holds the node's children (left link, each entry's value, each
/// entry's subtree link). They are stored in reverse order with the left link
/// last. Popping from the end therefore yields the left link first, then each
/// entry's value followed by its subtree, in encoded order.
#[derive(Debug)]
pub(crate) struct MstNode {
pub depth: Option<Depth>, pub things: Vec<NodeThing>,
}
/// One child of an MST node: a CID together with what kind of thing it links to.
#[derive(Debug)]
pub(crate) struct NodeThing {
/// Link to the child (subtree node or record value).
pub(crate) cid: Cid,
/// Whether `cid` points at a subtree or at a keyed value.
pub(crate) kind: ThingKind,
}
/// Distinguishes the two kinds of MST node children.
#[derive(Debug)]
pub(crate) enum ThingKind {
/// A subtree node (from the node's `l` field or an entry's `t` field).
Tree,
/// A value stored under `rkey`, the entry's full key rebuilt from its
/// shared prefix and suffix.
Value { rkey: String },
}
impl<'de> Deserialize<'de> for MstNode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct NodeVisitor;
impl<'de> Visitor<'de> for NodeVisitor {
type Value = MstNode;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct MstNode")
}
fn visit_map<V>(self, mut map: V) -> Result<MstNode, V::Error>
where
V: MapAccess<'de>,
{
let mut found_left = false;
let mut left = None;
let mut found_entries = false;
let mut things = Vec::new();
let mut depth = None;
while let Some(key) = map.next_key()? {
match key {
"l" => {
if found_left {
return Err(de::Error::duplicate_field("l"));
}
found_left = true;
if let Some(cid) = map.next_value()? {
left = Some(NodeThing {
cid,
kind: ThingKind::Tree,
});
}
}
"e" => {
if found_entries {
return Err(de::Error::duplicate_field("e"));
}
found_entries = true;
let mut prefix: Vec<u8> = vec![];
for entry in map.next_value::<Vec<Entry>>()? {
let mut rkey: Vec<u8> = vec![];
let pre_checked =
prefix.get(..entry.prefix_len).ok_or_else(|| {
de::Error::invalid_value(
Unexpected::Bytes(&prefix),
&"a prefix at least as long as the prefix_len",
)
})?;
rkey.extend_from_slice(pre_checked);
rkey.extend_from_slice(&entry.keysuffix);
let rkey_s = String::from_utf8(rkey.clone()).map_err(|_| {
de::Error::invalid_value(
Unexpected::Bytes(&rkey),
&"a valid utf-8 rkey",
)
})?;
let key_depth = atproto_mst_depth(&rkey_s);
if depth.is_none() {
depth = Some(key_depth);
} else if Some(key_depth) != depth {
return Err(de::Error::invalid_value(
Unexpected::Bytes(&prefix),
&"all rkeys to have equal MST depth",
));
}
things.push(NodeThing {
cid: entry.value,
kind: ThingKind::Value { rkey: rkey_s },
});
if let Some(cid) = entry.tree {
things.push(NodeThing {
cid,
kind: ThingKind::Tree,
});
}
prefix = rkey;
}
}
f => return Err(de::Error::unknown_field(f, NODE_FIELDS)),
}
}
if !found_left {
return Err(de::Error::missing_field("l"));
}
if !found_entries {
return Err(de::Error::missing_field("e"));
}
things.reverse();
if let Some(l) = left {
things.push(l);
}
Ok(MstNode { depth, things })
}
}
const NODE_FIELDS: &[&str] = &["l", "e"];
deserializer.deserialize_struct("MstNode", NODE_FIELDS, NodeVisitor)
}
}
impl MstNode {
    /// Returns `true` when the node has no children at all: no left link,
    /// no entry values and no entry subtrees.
    pub(crate) fn is_empty(&self) -> bool {
        self.things.is_empty()
    }

    /// Cheap byte-level pre-check: could `bytes` be an encoded MST node?
    ///
    /// Accepts input that begins with the following byte sequence:
    /// - `0xA2`: a CBOR map header with two pairs;
    /// - `0x61`, `b'e'`: the one-byte text key `"e"`;
    /// - a byte whose top three bits are `0b100` (CBOR major type 4, an
    ///   array header).
    ///
    /// Anything shorter than four bytes is rejected. A `true` result is only
    /// a hint; full decoding can still fail.
    #[inline(always)]
    pub(crate) fn could_be(bytes: impl AsRef<[u8]>) -> bool {
        matches!(
            bytes.as_ref(),
            [0xA2, 0x61, b'e', array_header, ..] if array_header & 0b1110_0000 == 0x80
        )
    }
}
/// One prefix-compressed entry from an MST node's `e` list.
///
/// Fields are serialized under single-letter keys (`p`, `k`, `v`, `t`).
/// Unknown keys are rejected.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub(crate) struct Entry {
/// Number of leading bytes this entry's key shares with the previous
/// entry's key (serialized as `p`).
#[serde(rename = "p")]
pub prefix_len: usize,
/// Key bytes appended after the shared prefix (serialized as `k`).
#[serde(rename = "k")]
pub keysuffix: serde_bytes::ByteBuf,
/// CID of the value stored under this entry's key (serialized as `v`).
#[serde(rename = "v")]
pub value: Cid,
/// Optional subtree link (serialized as `t`). The node decoder places it
/// directly after this entry's value.
#[serde(rename = "t")]
pub tree: Option<Cid>,
}