use std::collections::BTreeMap;
use ipld_core::ipld::Ipld;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::id::{Cid, NodeId};
use crate::objects::tombstone::Tombstone;
/// What a named ref in a [`View`] points at.
///
/// Serialized as an internally tagged map: the `kind` field carries the
/// lowercase variant name (`"normal"` or `"conflicted"`) alongside the
/// variant's own fields.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
pub enum RefTarget {
/// The ref resolves to exactly one id.
Normal {
/// The id the ref points at.
target: Cid,
},
/// The ref is in a conflicted state, recorded as two lists of ids.
///
/// NOTE(review): presumably `adds` holds the competing candidate targets and
/// `removes` the targets they replaced -- confirm against the merge logic.
Conflicted {
/// Ids on the "add" side of the conflict.
adds: Vec<Cid>,
/// Ids on the "remove" side of the conflict.
removes: Vec<Cid>,
},
}
impl RefTarget {
    /// Builds a [`RefTarget::Normal`] that points at `target`.
    #[must_use]
    pub const fn normal(target: Cid) -> Self {
        Self::Normal { target }
    }

    /// Builds a [`RefTarget::Conflicted`] from the two sides of a conflict.
    ///
    /// Each list is sorted and de-duplicated before being stored, so the
    /// result does not depend on the order or multiplicity of the ids passed
    /// in.
    #[must_use]
    pub fn conflicted(mut adds: Vec<Cid>, mut removes: Vec<Cid>) -> Self {
        // Apply the same normalisation to both sides.
        for side in [&mut adds, &mut removes] {
            side.sort();
            side.dedup();
        }
        Self::Conflicted { adds, removes }
    }
}
/// A view: heads, named refs, optional per-remote tracking refs, an optional
/// `wc_commit`, tombstones and any extra fields.
///
/// `Serialize` / `Deserialize` are implemented by hand in this file through a
/// private wire struct rather than derived.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct View {
/// Head ids, kept in insertion order (`with_head` appends).
pub heads: Vec<Cid>,
/// Named refs, keyed by ref name.
pub refs: BTreeMap<String, RefTarget>,
/// Tracking refs keyed by remote name, then by ref name. `None` until the
/// first tracking ref is added; omitted from the wire form while `None`.
pub remote_refs: Option<BTreeMap<String, BTreeMap<String, RefTarget>>>,
/// Optional id, omitted from the wire form while `None`.
/// NOTE(review): presumably the working-copy commit -- confirm.
pub wc_commit: Option<Cid>,
/// Tombstones keyed by the node id they belong to.
pub tombstones: BTreeMap<NodeId, Tombstone>,
/// Additional fields; flattened into the top-level map of the wire form.
pub extra: BTreeMap<String, Ipld>,
}
impl Default for View {
fn default() -> Self {
Self::new()
}
}
impl View {
    /// Value written to, and required from, the `_kind` field of the wire form.
    pub const KIND: &'static str = "view";

    /// Creates an empty view: no heads, no refs, no tracking refs, no
    /// `wc_commit`, no tombstones and no extra fields.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            heads: Vec::new(),
            refs: BTreeMap::new(),
            remote_refs: None,
            wc_commit: None,
            tombstones: BTreeMap::new(),
            extra: BTreeMap::new(),
        }
    }

    /// Returns the view with `head` appended to `heads`.
    #[must_use]
    pub fn with_head(mut self, head: Cid) -> Self {
        self.heads.push(head);
        self
    }

    /// Returns the view with the ref `name` set to `target`, replacing any
    /// previous target stored under that name.
    #[must_use]
    pub fn with_ref(mut self, name: impl Into<String>, target: RefTarget) -> Self {
        let key: String = name.into();
        self.refs.insert(key, target);
        self
    }

    /// Returns the view with `remote`'s tracking ref `ref_name` set to a
    /// [`RefTarget::Normal`] pointing at `target`.
    ///
    /// The outer `remote_refs` map and the per-remote map are created on
    /// demand; an existing entry for the same remote and ref name is replaced.
    #[must_use]
    pub fn with_tracking_ref(
        mut self,
        remote: impl Into<String>,
        ref_name: impl Into<String>,
        target: Cid,
    ) -> Self {
        let remote_key: String = remote.into();
        let ref_key: String = ref_name.into();
        let tracked = RefTarget::normal(target);
        self.remote_refs
            .get_or_insert_with(BTreeMap::new)
            .entry(remote_key)
            .or_default()
            .insert(ref_key, tracked);
        self
    }

    /// Looks up the tracking ref `ref_name` of `remote`.
    ///
    /// Returns `None` when there are no tracking refs at all, when the remote
    /// is unknown, or when the remote has no ref with that name.
    #[must_use]
    pub fn tracking_ref(&self, remote: &str, ref_name: &str) -> Option<&RefTarget> {
        self.remote_refs
            .as_ref()
            .and_then(|by_remote| by_remote.get(remote))
            .and_then(|by_name| by_name.get(ref_name))
    }
}
/// Wire form of a single tombstone: the node id it belongs to plus the
/// tombstone's own fields, flattened into the same map.
#[derive(Serialize, Deserialize)]
struct TombstoneEntry {
/// Id of the node the tombstone belongs to.
node_id: NodeId,
/// The tombstone itself; its fields sit next to `node_id` on the wire.
#[serde(flatten)]
tombstone: Tombstone,
}
/// Serde-derived wire representation of a `View`.
///
/// The hand-written `Serialize` / `Deserialize` impls for `View` convert to
/// and from this struct.
#[derive(Serialize, Deserialize)]
struct ViewWire {
/// Object-kind discriminator, stored under the key `_kind`.
#[serde(rename = "_kind")]
kind: String,
/// Head ids.
heads: Vec<Cid>,
/// Named refs keyed by ref name.
refs: BTreeMap<String, RefTarget>,
/// Tracking refs (remote name, then ref name). Not written when `None`;
/// `None` when the key is absent on input.
#[serde(default, skip_serializing_if = "Option::is_none")]
remote_refs: Option<BTreeMap<String, BTreeMap<String, RefTarget>>>,
/// Not written when `None`; `None` when the key is absent on input.
#[serde(default, skip_serializing_if = "Option::is_none")]
wc_commit: Option<Cid>,
/// Tombstones as a list of entries. Not written when empty; empty when the
/// key is absent on input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tombstones: Vec<TombstoneEntry>,
/// Extra fields, flattened into the top-level map; on input this receives
/// the keys not claimed by the named fields above.
#[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
extra: BTreeMap<String, Ipld>,
}
impl Serialize for View {
    /// Encodes the view through its wire form.
    ///
    /// The tombstone map becomes a list of entries in the map's key order,
    /// and the `_kind` field is always set to [`View::KIND`].
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut tombstones = Vec::with_capacity(self.tombstones.len());
        for (node_id, tombstone) in &self.tombstones {
            tombstones.push(TombstoneEntry {
                node_id: *node_id,
                tombstone: tombstone.clone(),
            });
        }

        let wire = ViewWire {
            kind: Self::KIND.to_owned(),
            heads: self.heads.clone(),
            refs: self.refs.clone(),
            remote_refs: self.remote_refs.clone(),
            wc_commit: self.wc_commit.clone(),
            tombstones,
            extra: self.extra.clone(),
        };
        wire.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for View {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let w = ViewWire::deserialize(deserializer)?;
if w.kind != Self::KIND {
return Err(serde::de::Error::custom(format!(
"expected _kind='{}', got '{}'",
Self::KIND,
w.kind
)));
}
let mut tombstones = BTreeMap::new();
for entry in w.tombstones {
tombstones.insert(entry.node_id, entry.tombstone);
}
Ok(Self {
heads: w.heads,
refs: w.refs,
remote_refs: w.remote_refs,
wc_commit: w.wc_commit,
tombstones,
extra: w.extra,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, to_canonical_bytes};
    use crate::id::{CODEC_RAW, Multihash};

    /// Deterministic test id: `CODEC_RAW` over `Multihash::sha2_256` of `n`'s
    /// big-endian bytes.
    fn raw(n: u32) -> Cid {
        Cid::new(CODEC_RAW, Multihash::sha2_256(&n.to_be_bytes()))
    }

    /// Returns `true` when `needle` occurs as a contiguous run in `haystack`.
    fn contains_subslice(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|window| window == needle)
    }

    #[test]
    fn empty_view_round_trip() {
        let original = View::new();
        let bytes = to_canonical_bytes(&original).unwrap();
        let decoded: View = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn view_with_heads_and_refs_round_trip() {
        let v = View::new()
            .with_head(raw(1))
            .with_ref("refs/heads/main", RefTarget::normal(raw(1)))
            .with_ref(
                "refs/heads/feature",
                RefTarget::conflicted(vec![raw(2), raw(3)], vec![raw(1)]),
            );
        let bytes = to_canonical_bytes(&v).unwrap();
        let decoded: View = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(v, decoded);
    }

    #[test]
    fn conflicted_ref_sorts_adds_and_removes() {
        let r = RefTarget::conflicted(vec![raw(3), raw(1), raw(2)], vec![raw(5), raw(4)]);
        match r {
            RefTarget::Conflicted { adds, removes } => {
                // Strictly increasing means sorted *and* free of duplicates.
                assert!(adds.windows(2).all(|w| w[0] < w[1]));
                assert!(removes.windows(2).all(|w| w[0] < w[1]));
            }
            RefTarget::Normal { .. } => {
                panic!("RefTarget::conflicted must build the Conflicted variant")
            }
        }
    }

    #[test]
    fn ref_target_normal_round_trip() {
        let r = RefTarget::normal(raw(42));
        let bytes = to_canonical_bytes(&r).unwrap();
        let decoded: RefTarget = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(r, decoded);
    }

    #[test]
    fn view_with_tracking_refs_round_trip() {
        let v = View::new()
            .with_head(raw(1))
            .with_ref("refs/heads/main", RefTarget::normal(raw(1)))
            .with_tracking_ref("origin", "refs/heads/main", raw(10))
            .with_tracking_ref("origin", "refs/heads/feature", raw(11))
            .with_tracking_ref("backup", "refs/heads/main", raw(20));
        let bytes = to_canonical_bytes(&v).unwrap();
        let decoded: View = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(v, decoded);
        assert_eq!(
            decoded.tracking_ref("origin", "refs/heads/main"),
            Some(&RefTarget::normal(raw(10))),
        );
        assert_eq!(
            decoded.tracking_ref("backup", "refs/heads/main"),
            Some(&RefTarget::normal(raw(20))),
        );
        assert!(decoded.tracking_ref("unknown", "refs/heads/main").is_none());
        assert!(
            decoded
                .tracking_ref("origin", "refs/heads/missing")
                .is_none()
        );
    }

    /// A view that never had tracking refs must not emit a `remote_refs` key.
    ///
    /// The earlier version of this test serialized two identically built
    /// views and compared the bytes, which could never fail. This version
    /// inspects the encoded bytes for the key itself and uses a view with a
    /// tracking ref as a positive control, so the absence check is known to
    /// be able to fail.
    #[test]
    fn view_without_tracking_refs_stays_backward_compatible() {
        const KEY: &[u8] = b"remote_refs";

        let v_without = View::new()
            .with_head(raw(1))
            .with_ref("refs/heads/main", RefTarget::normal(raw(1)));
        assert!(v_without.remote_refs.is_none());

        let bytes = to_canonical_bytes(&v_without).unwrap();
        assert!(
            !contains_subslice(&bytes, KEY),
            "absent remote_refs must not be written to the wire form"
        );

        // Decoding bytes without the key yields `None`, and re-encoding
        // reproduces the exact same bytes.
        let decoded: View = from_canonical_bytes(&bytes).unwrap();
        assert!(decoded.remote_refs.is_none());
        assert_eq!(to_canonical_bytes(&decoded).unwrap(), bytes);

        // Positive control: once a tracking ref exists, the key is present.
        let v_with = v_without.with_tracking_ref("origin", "refs/heads/main", raw(10));
        let bytes_with = to_canonical_bytes(&v_with).unwrap();
        assert!(
            contains_subslice(&bytes_with, KEY),
            "remote_refs key must appear once a tracking ref is set"
        );
    }

    #[test]
    fn view_kind_rejection() {
        // Encode a wire struct with the wrong `_kind`; decoding as a View
        // must fail with an error that mentions the field.
        let w = ViewWire {
            kind: "commit".into(),
            heads: Vec::new(),
            refs: BTreeMap::new(),
            remote_refs: None,
            wc_commit: None,
            tombstones: Vec::new(),
            extra: BTreeMap::new(),
        };
        let bytes = serde_ipld_dagcbor::to_vec(&w).unwrap();
        let err = serde_ipld_dagcbor::from_slice::<View>(&bytes).unwrap_err();
        assert!(err.to_string().contains("_kind"));
    }
}