use {
super::Error,
crate::{
NetworkId,
StreamId,
groups::GroupId,
network::PeerId,
primitives::*,
tickets::{Expiration, InvalidTicket, Ticket, TicketValidator},
},
chrono::{DateTime, Utc},
core::fmt,
derive_more::{AsRef, Debug, Deref, Into},
iroh::{EndpointAddr, SecretKey, Signature},
semver::Version,
serde::{Deserialize, Deserializer, Serialize, de},
std::collections::{BTreeMap, BTreeSet},
};
/// Version stamp of a [`PeerEntry`], stored as two Unix timestamps in
/// milliseconds: `(started_at, updated_at)`.
///
/// NOTE: field order is significant. The derived `Ord`/`PartialOrd` compare
/// the start component first and the update component second, so a version
/// with a later start always compares greater regardless of its update
/// component. The serialized layout also follows this order.
#[derive(
Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
pub struct PeerEntryVersion(pub(crate) i64, pub(crate) i64);
impl Default for PeerEntryVersion {
    /// A fresh version whose start and update components are both the current
    /// wall-clock time in Unix milliseconds.
    ///
    /// The clock is read exactly once. Reading it twice (as before) could
    /// straddle a millisecond boundary and produce `started_at != updated_at`
    /// for a brand-new version.
    fn default() -> Self {
        let now = Utc::now().timestamp_millis();
        Self(now, now)
    }
}
impl core::fmt::Debug for PeerEntryVersion {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}.{}", self.0, self.1)
}
}
impl core::fmt::Display for PeerEntryVersion {
    /// Human-readable form: `<started_ms>.<updated_ms>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (started, updated) = (self.0, self.1);
        write!(f, "{started}.{updated}")
    }
}
impl core::fmt::Display for Short<PeerEntryVersion> {
    /// Compact label: the update timestamp (Unix millis) reduced modulo one
    /// week.
    ///
    /// `rem_euclid` keeps the label in `0..MILLIS_PER_WEEK` even when an entry
    /// (e.g. one received from a remote peer) carries a negative timestamp.
    /// Plain `%` would print a negative label there. For non-negative
    /// timestamps the output is unchanged.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        const MILLIS_PER_WEEK: i64 = 7 * 86_400_000;
        write!(f, "{}", self.0.1.rem_euclid(MILLIS_PER_WEEK))
    }
}
impl PeerEntryVersion {
#[must_use]
pub(crate) fn increment(self) -> Self {
let last_version = self.1.max(Utc::now().timestamp_millis());
Self(self.0, last_version.saturating_add(1))
}
pub fn updated_at(&self) -> DateTime<Utc> {
DateTime::<Utc>::from_timestamp_millis(self.1).unwrap_or_default()
}
pub fn started_at(&self) -> DateTime<Utc> {
DateTime::<Utc>::from_timestamp_millis(self.0).unwrap_or_default()
}
}
/// A peer's advertisement. It carries the protocol version, network, endpoint
/// address, and the tags, streams, groups and tickets the peer publishes,
/// plus a [`PeerEntryVersion`] that the builder methods bump when they change
/// the entry.
///
/// NOTE: field order is significant. `digest()` hashes the serialized struct
/// (serde derive serializes fields in declaration order), and `sign()` signs
/// that digest. Reordering, adding, or removing fields therefore changes every
/// digest. It would presumably also break signature verification with peers
/// using the old layout.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct PeerEntry {
/// Protocol version; `new` sets it from this crate's `CARGO_PKG_VERSION`.
protocol: Version,
/// Network this entry belongs to.
network: NetworkId,
/// Version stamp; compared by `is_newer_than`.
version: PeerEntryVersion,
/// Endpoint address; its `id` is the peer id returned by `id()`.
address: EndpointAddr,
/// Advertised tags.
tags: BTreeSet<Tag>,
/// Advertised stream ids.
streams: BTreeSet<StreamId>,
/// Advertised group ids.
groups: BTreeSet<GroupId>,
/// Attached tickets, keyed by `Ticket::id()` (see `add_ticket`).
tickets: BTreeMap<UniqueId, Ticket>,
}
impl PeerEntry {
pub const fn id(&self) -> &PeerId {
&self.address.id
}
pub const fn network_id(&self) -> &NetworkId {
&self.network
}
pub const fn protocol_version(&self) -> &Version {
&self.protocol
}
pub const fn address(&self) -> &EndpointAddr {
&self.address
}
pub const fn tags(&self) -> &BTreeSet<Tag> {
&self.tags
}
pub const fn streams(&self) -> &BTreeSet<StreamId> {
&self.streams
}
pub const fn groups(&self) -> &BTreeSet<GroupId> {
&self.groups
}
pub const fn tickets(&self) -> &BTreeMap<UniqueId, Ticket> {
&self.tickets
}
pub fn tickets_of(&self, class: UniqueId) -> impl Iterator<Item = &Ticket> {
self
.tickets
.iter()
.filter_map(move |(_, v)| (v.class == class).then_some(v))
}
pub fn valid_tickets(
&self,
class: UniqueId,
validator: impl Fn(&[u8]) -> bool,
) -> impl Iterator<Item = &Ticket> {
self.tickets_of(class).filter(move |t| validator(&t.data))
}
pub fn has_valid_ticket(
&self,
class: UniqueId,
validator: impl Fn(&[u8]) -> bool,
) -> bool {
self.tickets_of(class).any(move |t| validator(&t.data))
}
pub fn validate_ticket(
&self,
validator: &dyn TicketValidator,
) -> Result<Expiration, InvalidTicket> {
self
.tickets_of(validator.class())
.filter_map(|t| validator.validate(&t.data, self).ok())
.max()
.ok_or(InvalidTicket)
}
pub fn validate_tickets(
&self,
validators: &[impl AsRef<dyn TicketValidator>],
) -> Result<Option<Expiration>, InvalidTicket> {
if validators.is_empty() {
return Ok(None);
}
let mut earliest = Expiration::Never;
for v in validators {
let expiration = self.validate_ticket(v.as_ref())?;
earliest = earliest.min(expiration);
}
Ok(Some(earliest))
}
pub const fn update_version(&self) -> PeerEntryVersion {
self.version
}
pub fn updated_at(&self) -> DateTime<Utc> {
self.version.updated_at()
}
pub fn started_at(&self) -> DateTime<Utc> {
self.version.started_at()
}
pub fn uptime(&self) -> core::time::Duration {
(Utc::now() - self.version.started_at())
.to_std()
.unwrap_or_default()
}
pub fn digest(&self) -> blake3::Hash {
let mut hasher = blake3::Hasher::new();
serialize_to_writer(self, &mut hasher);
hasher.finalize()
}
pub fn is_newer_than(&self, other: &Self) -> bool {
self.version > other.version
}
}
impl PeerEntry {
pub(crate) fn new(network: NetworkId, address: EndpointAddr) -> Self {
Self {
network,
address,
tags: BTreeSet::new(),
streams: BTreeSet::new(),
groups: BTreeSet::new(),
tickets: BTreeMap::new(),
version: PeerEntryVersion::default(),
protocol: env!("CARGO_PKG_VERSION")
.parse()
.expect("Invalid CARGO_PKG_VERSION for mosaik"),
}
}
pub fn update_address(
mut self,
address: EndpointAddr,
) -> Result<Self, Error> {
if address.id != *self.id() {
return Err(Error::PeerIdChanged(*self.id(), address.id));
}
self.address = address;
self.version = self.version.increment();
Ok(self)
}
#[must_use]
pub fn add_streams<V>(
mut self,
streams: impl IntoIterOrSingle<StreamId, V>,
) -> Self {
let count = self.streams.len();
self.streams.extend(streams.iterator());
if count != self.streams.len() {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn remove_streams<V>(
mut self,
streams: impl IntoIterOrSingle<StreamId, V>,
) -> Self {
let mut was_present = false;
for stream in streams.iterator() {
was_present |= self.streams.remove(&stream);
}
if was_present {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn add_groups<V>(
mut self,
groups: impl IntoIterOrSingle<GroupId, V>,
) -> Self {
let count = self.groups.len();
self.groups.extend(groups.iterator());
if count != self.groups.len() {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn remove_groups<V>(
mut self,
groups: impl IntoIterOrSingle<GroupId, V>,
) -> Self {
let mut was_present = false;
for group in groups.iterator() {
was_present |= self.groups.remove(&group);
}
if was_present {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn add_tags<V>(mut self, tags: impl IntoIterOrSingle<Tag, V>) -> Self {
let count = self.tags.len();
self.tags.extend(tags.iterator());
if count != self.tags.len() {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn remove_tags<V>(mut self, tags: impl IntoIterOrSingle<Tag, V>) -> Self {
let mut was_present = false;
for tag in tags.iterator() {
was_present |= self.tags.remove(&tag);
}
if was_present {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn add_ticket(mut self, ticket: Ticket) -> Self {
let id = ticket.id();
if self.tickets.insert(id, ticket).is_none() {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn remove_ticket(mut self, ticket_id: UniqueId) -> Self {
if self.tickets.remove(&ticket_id).is_some() {
self.version = self.version.increment();
}
self
}
#[must_use]
pub fn remove_tickets_of(mut self, class: UniqueId) -> Self {
let mut was_present = false;
self.tickets.retain(|_, v| {
if v.class == class {
was_present = true;
false
} else {
true
}
});
if was_present {
self.version = self.version.increment();
}
self
}
#[must_use]
pub(crate) fn increment_version(mut self) -> Self {
self.version = self.version.increment();
self
}
pub fn sign(self, secret: &SecretKey) -> Result<SignedPeerEntry, Error> {
let actual_id: PeerId = *self.id();
let expected_id: PeerId = secret.public();
if actual_id != expected_id {
return Err(Error::InvalidSecretKey(expected_id, actual_id));
}
let digest = self.digest();
let signature = secret.sign(digest.as_bytes());
Ok(SignedPeerEntry(self, signature))
}
}
impl From<&PeerEntry> for PeerId {
    /// Copies out the peer id stored in the entry's endpoint address.
    fn from(entry: &PeerEntry) -> Self {
        entry.address.id
    }
}
impl fmt::Display for Short<&PeerEntry> {
    /// One-line summary: short version label, short peer id, and the number
    /// of tags, streams, tickets and groups.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let entry = self.0;
        let version = Short(entry.update_version());
        let id = Short(entry.id());
        write!(
            f,
            "PeerEntry[#{version}]({id}, tags: {}, streams: {}, tickets: {}, groups: {})",
            entry.tags.len(),
            entry.streams.len(),
            entry.tickets.len(),
            entry.groups.len(),
        )
    }
}
impl fmt::Debug for Pretty<'_, PeerEntry> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "PeerEntry:")?;
writeln!(f, " id: {}", self.id())?;
writeln!(f, " network: {}", self.network_id())?;
writeln!(
f,
" ips: {:?}",
&self.address.ip_addrs().collect::<Vec<_>>()
)?;
writeln!(
f,
" relays: {:?}",
&self
.address
.relay_urls()
.map(|r| r.as_str())
.collect::<Vec<_>>()
)?;
writeln!(f, " tags: {}", FmtIter::<Short<_>, _>::new(&self.tags))?;
writeln!(f, " groups: {}", FmtIter::<Short<_>, _>::new(&self.groups))?;
writeln!(
f,
" streams: {}",
FmtIter::<Short<_>, _>::new(&self.streams)
)?;
writeln!(
f,
" tickets: {}",
FmtIter::<Short<_>, _>::new(self.tickets.values().map(|v| v.class))
)?;
writeln!(f, " update: {}", self.update_version())?;
writeln!(f, " protocol: {}", self.protocol)
}
}
/// A [`PeerEntry`] paired with a signature over its [`PeerEntry::digest`],
/// as produced by [`PeerEntry::sign`].
///
/// Derefs to the inner [`PeerEntry`]. Serialization writes the entry followed
/// by the signature. The hand-written `Deserialize` impl in this file rejects
/// values whose signature does not verify against the entry's own peer id.
/// Equality compares both the entry and the signature.
#[derive(Debug, Clone, Serialize, Deref, AsRef, Into, PartialEq, Eq)]
pub struct SignedPeerEntry(
#[deref] PeerEntry,
// Debug output abbreviates the signature bytes via `Abbreviated::<16, _>`
// (exact truncation semantics are defined in `primitives` — confirm there).
#[debug("signature: {}", Abbreviated::<16, _>(_1.to_bytes()))] Signature,
);
impl SignedPeerEntry {
pub fn into_unsigned(self) -> PeerEntry {
self.0
}
}
impl SignedPeerEntry {
fn is_signature_valid(&self) -> bool {
let digest = self.0.digest();
self.0.id().verify(digest.as_bytes(), &self.1).is_ok()
}
fn verify_signature(&self) -> Result<(), Error> {
self
.is_signature_valid()
.then_some(())
.ok_or(Error::InvalidSignature)
}
}
impl From<&SignedPeerEntry> for PeerId {
    /// Peer id of the signed entry; delegates to `From<&PeerEntry>`.
    fn from(entry: &SignedPeerEntry) -> Self {
        Self::from(&entry.0)
    }
}
impl<'de> Deserialize<'de> for SignedPeerEntry {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let (entry, signature) =
<(PeerEntry, Signature)>::deserialize(deserializer)?;
let signed = Self(entry, signature);
signed.verify_signature().map_err(de::Error::custom)?;
Ok(signed)
}
}
impl From<SignedPeerEntry> for PeerEntry {
fn from(signed: SignedPeerEntry) -> Self {
signed.0
}
}
impl From<&SignedPeerEntry> for PeerEntry {
    /// Clones the inner entry, dropping the signature.
    ///
    /// Only the entry is cloned. Cloning the whole `SignedPeerEntry` first
    /// would also copy the signature just to throw it away.
    fn from(signed: &SignedPeerEntry) -> Self {
        signed.0.clone()
    }
}
impl fmt::Debug for Pretty<'_, SignedPeerEntry> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Signed{:?}", Pretty(&self.0.0))?;
writeln!(
f,
" signature: {}",
Abbreviated::<16, _>(self.1.to_bytes())
)
}
}
impl fmt::Display for Short<&SignedPeerEntry> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"SignedPeerEntry[#{}]({}, tags: {}, streams: {}, groups: {})",
Short(self.0.update_version()),
Short(self.0.id()),
FmtIter::<Short<_>, _>::new(&self.0.tags),
FmtIter::<Short<_>, _>::new(&self.0.streams),
FmtIter::<Short<_>, _>::new(&self.0.groups),
)
}
}
impl fmt::Display for Pretty<'_, EndpointAddr> {
    /// Shows the address's `addrs` field using its `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let addrs = &self.0.addrs;
        write!(f, "{addrs:?}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh secret key and the endpoint address derived from its
    /// public key.
    fn identity() -> (SecretKey, EndpointAddr) {
        let secret = SecretKey::generate(&mut rand::rng());
        let address = EndpointAddr::from(secret.public());
        (secret, address)
    }

    #[test]
    fn signed_peer_entry_with_invalid_signature_fails_to_deserialize() {
        let (secret, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        let signed = entry.sign(&secret).unwrap();
        let mut tampered = serialize(&signed).to_vec();
        // The signature is serialized after the entry. Assuming the encoding
        // adds no trailer, flipping the final byte corrupts the signature.
        let last = tampered.len() - 1;
        tampered[last] ^= 0xFF;
        let result: Result<SignedPeerEntry, _> = deserialize(&tampered);
        assert!(
            result.is_err(),
            "Expected deserialization to fail with invalid signature"
        );
    }

    #[test]
    fn signed_peer_entry_with_modified_entry_fails_to_deserialize() {
        let network_id = NetworkId::random();
        let (secret, address) = identity();
        let entry = PeerEntry::new(network_id, address.clone())
            .add_tags(Tag::from("original"));
        let signed = entry.sign(&secret).unwrap();
        let signature = signed.1;
        let modified_entry =
            PeerEntry::new(network_id, address).add_tags(Tag::from("modified"));
        let invalid_signed = SignedPeerEntry(modified_entry, signature);
        let serialized = serialize(&invalid_signed);
        let result: Result<SignedPeerEntry, _> = deserialize(&serialized);
        assert!(
            result.is_err(),
            "Expected deserialization to fail with modified entry"
        );
    }

    #[test]
    fn valid_signed_peer_entry_deserializes_successfully() {
        let (secret, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        let signed = entry.sign(&secret).unwrap();
        let serialized = serialize(&signed);
        let deserialized: SignedPeerEntry = deserialize(&serialized)
            .expect("Failed to deserialize valid SignedPeerEntry");
        assert_eq!(signed, deserialized);
    }

    #[test]
    fn version_increments_on_add_tags() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        let initial_version = entry.update_version();
        let updated = entry.add_tags(Tag::from("test"));
        assert!(
            updated.update_version() > initial_version,
            "Version should increment after add_tags"
        );
    }

    #[test]
    fn version_increments_on_remove_tags() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address).add_tags(Tag::from("test"));
        let initial_version = entry.update_version();
        let updated = entry.remove_tags("test");
        assert!(
            updated.update_version() > initial_version,
            "Version should increment after remove_tags"
        );
    }

    #[test]
    fn version_increments_on_add_streams() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        let initial_version = entry.update_version();
        let updated = entry.add_streams(StreamId::from("test-stream"));
        assert!(
            updated.update_version() > initial_version,
            "Version should increment after add_streams"
        );
    }

    #[test]
    fn version_increments_on_remove_streams() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address)
            .add_streams(StreamId::from("test-stream"));
        let initial_version = entry.update_version();
        let updated = entry.remove_streams(StreamId::from("test-stream"));
        assert!(
            updated.update_version() > initial_version,
            "Version should increment after remove_streams"
        );
    }

    #[test]
    fn version_increments_on_update_address() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address.clone());
        let initial_version = entry.update_version();
        let updated = entry.update_address(address).unwrap();
        assert!(
            updated.update_version() > initial_version,
            "Version should increment after update_address"
        );
    }

    #[test]
    fn version_increments_monotonically_on_multiple_changes() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address.clone());
        let v0 = entry.update_version();
        let entry = entry.add_tags(Tag::from("tag1"));
        let v1 = entry.update_version();
        assert!(v1 > v0, "Version should increment after first change");
        let entry = entry.add_streams(StreamId::from("stream1"));
        let v2 = entry.update_version();
        assert!(v2 > v1, "Version should increment after second change");
        let entry = entry.remove_tags(Tag::from("tag1"));
        let v3 = entry.update_version();
        assert!(v3 > v2, "Version should increment after third change");
        let entry = entry.update_address(address).unwrap();
        let v4 = entry.update_version();
        assert!(v4 > v3, "Version should increment after fourth change");
    }

    #[test]
    fn is_newer_than_returns_correct_result() {
        let (_, address) = identity();
        let entry1 = PeerEntry::new(NetworkId::random(), address);
        let entry2 = entry1.clone().add_tags(Tag::from("test"));
        assert!(
            entry2.is_newer_than(&entry1),
            "Updated entry should be newer than original"
        );
        assert!(
            !entry1.is_newer_than(&entry2),
            "Original entry should not be newer than updated"
        );
        assert!(
            !entry1.is_newer_than(&entry1),
            "Entry should not be newer than itself"
        );
    }

    #[test]
    fn version_unchanged_when_adding_existing_tag() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address).add_tags(Tag::from("dup"));
        let before = entry.update_version();
        let after = entry.add_tags(Tag::from("dup"));
        assert_eq!(
            after.update_version(),
            before,
            "Re-adding an existing tag must not bump the version"
        );
    }

    #[test]
    fn version_unchanged_when_removing_absent_tag() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        let before = entry.update_version();
        let after = entry.remove_tags(Tag::from("absent"));
        assert_eq!(
            after.update_version(),
            before,
            "Removing an absent tag must not bump the version"
        );
    }

    #[test]
    fn version_unchanged_when_removing_absent_stream() {
        let (_, address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        let before = entry.update_version();
        let after = entry.remove_streams(StreamId::from("absent"));
        assert_eq!(
            after.update_version(),
            before,
            "Removing an absent stream must not bump the version"
        );
    }

    #[test]
    fn update_address_rejects_different_peer_id() {
        let (_, address) = identity();
        let (_, foreign_address) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        assert!(
            entry.update_address(foreign_address).is_err(),
            "update_address must reject an address with a different peer id"
        );
    }

    #[test]
    fn sign_rejects_foreign_secret_key() {
        let (_, address) = identity();
        let (foreign_secret, _) = identity();
        let entry = PeerEntry::new(NetworkId::random(), address);
        assert!(
            entry.sign(&foreign_secret).is_err(),
            "sign must reject a key that does not match the entry's peer id"
        );
    }
}