use std::collections::{HashMap, HashSet};
use std::fmt::{Debug, Formatter};
use std::sync::atomic::AtomicU64;
use std::sync::{atomic, Arc};
use std::time::SystemTime;
use crate::error::SignatureError;
use crate::protocol::{
MavSha256, MavTimestamp, MessageId, SecretKey, Sign, SignedLinkId, Signer, SigningConf,
};
use crate::prelude::*;
pub use builder::FrameSignerBuilder;
use builder::{NoLinkId, NoSecretKey};
/// Validates, signs, or strips frame signatures according to per-direction
/// [`SignStrategy`] settings.
///
/// Construct with [`FrameSigner::new`] or [`FrameSigner::builder`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "specta", derive(specta::Type))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrameSigner {
    /// Link ID written into frames this signer signs; its key is looked up in `links`.
    link_id: SignedLinkId,
    /// Strategy applied by `process_incoming`.
    incoming: SignStrategy,
    /// Strategy applied by `process_outgoing` (also consulted by `process_new`).
    outgoing: SignStrategy,
    /// Policy for signatures whose link ID has no entry in `links`.
    unknown_links: SignStrategy,
    /// Known links and their secret keys. Skipped when serializing so keys are not emitted.
    /// NOTE(review): deserialization is not skipped, so deserializing presumably requires
    /// `links` to be supplied (and to contain `link_id`, which `key()` relies on) — confirm.
    #[cfg_attr(feature = "serde", serde(skip_serializing))]
    links: HashMap<SignedLinkId, SecretKey>,
    /// Timestamp counter used when signing; wraps an `Arc`, so clones of this signer share it.
    last_timestamp: UniqueMavTimestamp,
    /// Message IDs that `process_for_strategy` passes through untouched.
    exclude: HashSet<MessageId>,
}
/// How a [`FrameSigner`] treats frame signatures, either for one direction
/// (`incoming` / `outgoing`) or, as the `unknown_links` policy, for signatures whose link ID
/// has no registered key.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "specta", derive(specta::Type))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum SignStrategy {
    /// Default. Rejects frames whose existing signature is invalid; signs unsigned frames and,
    /// when the unknown-links policy is `Sign` or `ReSign`, frames signed by unknown links.
    /// Frames already signed by a known link are left as they are.
    ///
    /// As the unknown-links policy: such signatures are validated against the primary key.
    #[default]
    Sign,
    /// Like `Sign`, but also re-signs frames already signed by a known link; frames signed by
    /// unknown links are re-signed only when the unknown-links policy is `ReSign`.
    ///
    /// As the unknown-links policy: same validation as `Sign` (primary key).
    ReSign,
    /// Rejects unsigned frames and frames with an invalid signature; never modifies frames.
    ///
    /// As the unknown-links policy (the builder's default): signatures from unknown links are
    /// treated as invalid.
    Strict,
    /// Neither validates nor modifies frames.
    ///
    /// As the unknown-links policy: signatures from unknown links are accepted unchecked.
    Proxy,
    /// Removes any signature without validating it.
    ///
    /// As the unknown-links policy: same as `Proxy`.
    Strip,
}
/// Conversion into a [`FrameSigner`]; this module implements it for `FrameSigner` itself and
/// for a fully configured `FrameSignerBuilder`.
pub trait IntoFrameSigner {
    /// Consumes `self` and returns the resulting [`FrameSigner`].
    fn into_message_signer(self) -> FrameSigner;
}
/// Counter that hands out a new [`MavTimestamp`] on every `next()` call.
///
/// The raw value of the last timestamp lives in an `Arc<AtomicU64>`, so cloning this type
/// shares the counter rather than copying it.
#[derive(Clone)]
#[cfg_attr(feature = "specta", derive(specta::Type))]
pub struct UniqueMavTimestamp(Arc<AtomicU64>);
impl FrameSigner {
    /// Creates a signer for `link_id` with secret `key` and default strategies.
    ///
    /// Shorthand for `FrameSigner::builder().link_id(link_id).key(key).build()`.
    pub fn new<K: Into<SecretKey>>(link_id: SignedLinkId, key: K) -> Self {
        Self::builder().link_id(link_id).key(key).build()
    }

    /// Returns a builder with neither the link ID nor the secret key set yet.
    pub fn builder() -> FrameSignerBuilder<NoLinkId, NoSecretKey> {
        FrameSignerBuilder::new()
    }

    /// Link ID written into the signatures this signer produces.
    pub fn link_id(&self) -> SignedLinkId {
        self.link_id
    }

    /// Secret key of the primary [`link_id`](Self::link_id).
    ///
    /// # Panics
    ///
    /// Panics if `links` has no entry for `link_id`. `FrameSignerBuilder::build` always inserts
    /// that entry; a signer obtained another way (e.g. deserialized) might not have it.
    pub fn key(&self) -> &SecretKey {
        self.links
            .get(&self.link_id)
            .expect("FrameSigner invariant violated: no key stored for the primary link_id")
    }

    /// Strategy applied by [`process_incoming`](Self::process_incoming).
    pub fn incoming(&self) -> SignStrategy {
        self.incoming
    }

    /// Strategy applied by [`process_outgoing`](Self::process_outgoing).
    pub fn outgoing(&self) -> SignStrategy {
        self.outgoing
    }

    /// Policy for signatures whose link ID has no registered key.
    pub fn unknown_links(&self) -> SignStrategy {
        self.unknown_links
    }

    /// Known links and their secret keys (signers built by the builder include the primary link).
    pub fn links(&self) -> impl Iterator<Item = (SignedLinkId, &SecretKey)> {
        self.links.iter().map(|(&link_id, key)| (link_id, key))
    }

    /// Message IDs that are never validated, signed, or stripped.
    ///
    /// The iterator owns a copy of the set, so it does not borrow `self`.
    pub fn exclude(&self) -> impl Iterator<Item = MessageId> {
        self.exclude.clone().into_iter()
    }

    /// Validates and (re-)signs or strips an incoming frame per the `incoming` strategy.
    ///
    /// # Errors
    ///
    /// Returns [`SignatureError`] when validation under that strategy fails.
    #[inline]
    pub fn process_incoming<V: MaybeVersioned>(
        &self,
        frame: &mut Frame<V>,
    ) -> core::result::Result<(), SignatureError> {
        self.process_for_strategy(frame, self.incoming)
    }

    /// Validates and (re-)signs or strips an outgoing frame per the `outgoing` strategy.
    ///
    /// # Errors
    ///
    /// Returns [`SignatureError`] when validation under that strategy fails.
    #[inline]
    pub fn process_outgoing<V: MaybeVersioned>(
        &self,
        frame: &mut Frame<V>,
    ) -> core::result::Result<(), SignatureError> {
        self.process_for_strategy(frame, self.outgoing)
    }

    /// Prepares a locally created frame.
    ///
    /// With a `Strict` outgoing strategy the frame is signed here, because strict outgoing
    /// processing rejects unsigned frames and never signs them itself. With any other strategy
    /// the frame is left for [`process_outgoing`](Self::process_outgoing).
    pub fn process_new<V: MaybeVersioned>(&self, frame: &mut Frame<V>) {
        if self.outgoing == SignStrategy::Strict {
            self.sign_frame(frame);
        }
    }

    /// Validates `frame` and then signs or strips it according to `strategy`.
    ///
    /// Frames whose message ID is in the exclusion set are returned untouched.
    ///
    /// # Errors
    ///
    /// Returns [`SignatureError`] when [`validate_for_strategy`](Self::validate_for_strategy)
    /// fails; the frame is not modified in that case.
    pub fn process_for_strategy<V: MaybeVersioned>(
        &self,
        frame: &mut Frame<V>,
        strategy: SignStrategy,
    ) -> core::result::Result<(), SignatureError> {
        if self.exclude.contains(&frame.message_id()) {
            return Ok(());
        }
        self.validate_for_strategy(frame, strategy)?;
        self.sign_for_strategy(frame, strategy);
        Ok(())
    }

    /// Checks `frame`'s signature against `strategy` without modifying it.
    ///
    /// - `Proxy`, `Strip`: always accepted.
    /// - `Strict`: unsigned frames are rejected.
    /// - `Sign`, `ReSign`, `Strict`: signed frames must pass
    ///   [`has_valid_signature`](Self::has_valid_signature).
    ///
    /// # Errors
    ///
    /// Returns [`SignatureError`] when the frame is rejected.
    pub fn validate_for_strategy<V: MaybeVersioned>(
        &self,
        frame: &Frame<V>,
        strategy: SignStrategy,
    ) -> core::result::Result<(), SignatureError> {
        match strategy {
            SignStrategy::Proxy | SignStrategy::Strip => Ok(()),
            SignStrategy::Strict if !frame.is_signed() => Err(SignatureError),
            SignStrategy::Sign | SignStrategy::ReSign | SignStrategy::Strict => {
                if frame.is_signed() && !self.has_valid_signature(frame) {
                    Err(SignatureError)
                } else {
                    Ok(())
                }
            }
        }
    }

    /// Signs `frame` with the primary link ID and key, consuming a new timestamp.
    pub fn sign_frame<V: MaybeVersioned>(&self, frame: &mut Frame<V>) {
        let signature_conf = self.to_signature_conf();
        signature_conf.apply(frame, &mut self.signer());
    }

    /// Returns whether `frame` carries a signature acceptable to this signer.
    ///
    /// Unsigned frames yield `false`. Signatures from a known link are validated with that
    /// link's key. For unknown links the `unknown_links` policy decides: `Sign`/`ReSign`
    /// validate against the primary key, `Strict` rejects, and `Proxy`/`Strip` accept without
    /// checking.
    pub fn has_valid_signature<V: MaybeVersioned>(&self, frame: &Frame<V>) -> bool {
        let signature = match frame.signature() {
            Some(signature) => signature,
            None => return false,
        };
        let key = match self.links.get(&signature.link_id) {
            Some(key) => key,
            None => match self.unknown_links {
                SignStrategy::Sign | SignStrategy::ReSign => self.key(),
                SignStrategy::Strict => return false,
                SignStrategy::Proxy | SignStrategy::Strip => return true,
            },
        };
        let mut hasher = self.signer();
        let mut signer = Signer::new(&mut hasher);
        signer.validate(frame, signature, key)
    }

    /// Returns a fresh `MavSha256` signing state.
    pub fn signer(&self) -> impl Sign {
        MavSha256::default()
    }

    /// Builds a [`SigningConf`] for the primary link, advancing the timestamp counter.
    pub fn to_signature_conf(&self) -> SigningConf {
        SigningConf {
            link_id: self.link_id,
            timestamp: self.next_timestamp(),
            secret: self.key().clone(),
        }
    }

    /// Takes the next timestamp from the counter shared by this signer and its clones.
    pub fn next_timestamp(&self) -> MavTimestamp {
        self.last_timestamp.next()
    }

    /// Applies the post-validation part of `strategy`: (re-)sign, strip, or leave untouched.
    fn sign_for_strategy<V: MaybeVersioned>(&self, frame: &mut Frame<V>, strategy: SignStrategy) {
        match strategy {
            SignStrategy::Sign if self.should_sign(frame) => self.sign_frame(frame),
            SignStrategy::ReSign if self.should_re_sign(frame) => self.sign_frame(frame),
            SignStrategy::Strip => {
                frame.remove_signature();
            }
            SignStrategy::Sign
            | SignStrategy::ReSign
            | SignStrategy::Strict
            | SignStrategy::Proxy => {}
        }
    }

    /// `Sign`: sign unsigned frames, and frames signed by an unknown link when the
    /// unknown-links policy is `Sign` or `ReSign`.
    fn should_sign<V: MaybeVersioned>(&self, frame: &Frame<V>) -> bool {
        match frame.signature() {
            Some(signature) => {
                !self.links.contains_key(&signature.link_id)
                    && matches!(self.unknown_links, SignStrategy::Sign | SignStrategy::ReSign)
            }
            None => true,
        }
    }

    /// `ReSign`: sign everything except frames signed by an unknown link, which are re-signed
    /// only when the unknown-links policy is `ReSign`.
    fn should_re_sign<V: MaybeVersioned>(&self, frame: &Frame<V>) -> bool {
        match frame.signature() {
            Some(signature) if !self.links.contains_key(&signature.link_id) => {
                self.unknown_links == SignStrategy::ReSign
            }
            _ => true,
        }
    }
}
impl UniqueMavTimestamp {
    /// Creates a counter seeded one tick before the current time, so the first `next()` call
    /// returns a timestamp no earlier than the current time.
    pub fn new() -> Self {
        // `saturating_sub` prevents an underflow (panic in debug, wrap to `u64::MAX` in
        // release) if the clock converts to a raw timestamp of 0.
        let seed = MavTimestamp::from(SystemTime::now())
            .as_raw_u64()
            .saturating_sub(1);
        Self(Arc::new(AtomicU64::new(seed)))
    }

    /// Creates a counter whose last issued timestamp is `timestamp`; the next one will be later.
    pub fn init(timestamp: MavTimestamp) -> Self {
        Self(Arc::new(AtomicU64::new(timestamp.as_raw_u64())))
    }

    /// Returns the most recently stored timestamp.
    pub fn last(&self) -> MavTimestamp {
        MavTimestamp::from_raw_u64(self.0.load(atomic::Ordering::Acquire))
    }

    /// Returns a timestamp strictly greater than every value previously returned by this
    /// counter (or any clone of it), and not earlier than the time sampled at the call start.
    pub fn next(&self) -> MavTimestamp {
        let now = MavTimestamp::from(SystemTime::now()).as_raw_u64();
        let advance = |last: u64| now.max(last.saturating_add(1));
        // A single compare-and-swap loop decides and publishes the new value atomically, so
        // concurrent callers can never receive the same timestamp and the stored value never
        // moves backwards. (A separate `fetch_add` + `store` allows both.)
        let previous = self
            .0
            .fetch_update(atomic::Ordering::AcqRel, atomic::Ordering::Acquire, |last| {
                Some(advance(last))
            })
            .expect("update closure always returns Some");
        MavTimestamp::from_raw_u64(advance(previous))
    }
}
impl Default for UniqueMavTimestamp {
fn default() -> Self {
Self::new()
}
}
/// Formats as `UniqueMavTimestamp(<value of last()>)` rather than exposing the atomic.
impl Debug for UniqueMavTimestamp {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let current = self.last();
        let mut tuple = f.debug_tuple("UniqueMavTimestamp");
        tuple.field(&current);
        tuple.finish()
    }
}
/// Serializes as a bare `u64`: the raw value of [`UniqueMavTimestamp::last`].
#[cfg(feature = "serde")]
impl serde::Serialize for UniqueMavTimestamp {
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let raw = self.last().as_raw_u64();
        serializer.serialize_u64(raw)
    }
}
/// Reads a bare `u64` and uses it as the counter's last timestamp (see
/// [`UniqueMavTimestamp::init`]).
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for UniqueMavTimestamp {
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
        u64::deserialize(d).map(|raw| UniqueMavTimestamp::init(MavTimestamp::from_raw_u64(raw)))
    }
}
/// Identity conversion: a [`FrameSigner`] is returned unchanged.
impl IntoFrameSigner for FrameSigner {
    fn into_message_signer(self) -> FrameSigner {
        self
    }
}
pub mod builder {
    //! Type-state builder for [`FrameSigner`](super::FrameSigner).
    //!
    //! `L` records whether the primary link ID has been set and `K` whether its secret key has;
    //! `build` exists only for `FrameSignerBuilder<HasLinkId, HasSecretKey>`.
    use super::*;

    /// State of the builder's secret-key slot.
    pub trait MaybeSecretKey: Clone + Debug {}

    /// No secret key supplied yet.
    #[derive(Clone, Debug)]
    pub struct NoSecretKey;
    impl MaybeSecretKey for NoSecretKey {}

    /// Secret key of the primary link.
    #[derive(Clone, Debug)]
    pub struct HasSecretKey(SecretKey);
    impl MaybeSecretKey for HasSecretKey {}

    /// State of the builder's link-ID slot.
    pub trait MaybeLinkId: Copy + Clone + Debug {}

    /// No link ID supplied yet.
    #[derive(Copy, Clone, Debug)]
    pub struct NoLinkId;
    impl MaybeLinkId for NoLinkId {}

    /// Primary link ID.
    #[derive(Copy, Clone, Debug)]
    pub struct HasLinkId(SignedLinkId);
    impl MaybeLinkId for HasLinkId {}

    /// Builder for [`FrameSigner`]; see the module docs for the type-state parameters.
    #[derive(Clone, Debug)]
    pub struct FrameSignerBuilder<L: MaybeLinkId, K: MaybeSecretKey> {
        link_id: L,
        key: K,
        incoming: Option<SignStrategy>,
        outgoing: Option<SignStrategy>,
        unknown_links: Option<SignStrategy>,
        /// Extra known links registered with `add_link`.
        links: HashMap<SignedLinkId, SecretKey>,
        exclude: HashSet<MessageId>,
    }

    impl FrameSignerBuilder<NoLinkId, NoSecretKey> {
        /// Creates an empty builder: no link ID, no key, all strategies unset.
        pub fn new() -> Self {
            Self {
                link_id: NoLinkId,
                key: NoSecretKey,
                incoming: None,
                outgoing: None,
                unknown_links: None,
                links: Default::default(),
                exclude: Default::default(),
            }
        }
    }

    impl<K: MaybeSecretKey> FrameSignerBuilder<NoLinkId, K> {
        /// Sets the primary link ID used when signing.
        pub fn link_id(self, link_id: SignedLinkId) -> FrameSignerBuilder<HasLinkId, K> {
            FrameSignerBuilder {
                link_id: HasLinkId(link_id),
                key: self.key,
                incoming: self.incoming,
                outgoing: self.outgoing,
                unknown_links: self.unknown_links,
                links: self.links,
                exclude: self.exclude,
            }
        }
    }

    impl<L: MaybeLinkId> FrameSignerBuilder<L, NoSecretKey> {
        /// Sets the secret key of the primary link.
        pub fn key<K: Into<SecretKey>>(self, key: K) -> FrameSignerBuilder<L, HasSecretKey> {
            FrameSignerBuilder {
                link_id: self.link_id,
                key: HasSecretKey(key.into()),
                incoming: self.incoming,
                outgoing: self.outgoing,
                unknown_links: self.unknown_links,
                links: self.links,
                exclude: self.exclude,
            }
        }
    }

    impl<L: MaybeLinkId, K: MaybeSecretKey> FrameSignerBuilder<L, K> {
        /// Sets the strategy for incoming frames (default: `SignStrategy::default()`).
        pub fn incoming(self, strategy: SignStrategy) -> Self {
            Self {
                incoming: Some(strategy),
                ..self
            }
        }

        /// Sets the strategy for outgoing frames (default: `SignStrategy::default()`).
        pub fn outgoing(self, strategy: SignStrategy) -> Self {
            Self {
                outgoing: Some(strategy),
                ..self
            }
        }

        /// Sets the policy for signatures from unknown links (default: `SignStrategy::Strict`).
        #[cfg(feature = "unstable")]
        pub fn unknown_links(self, strategy: SignStrategy) -> Self {
            Self {
                unknown_links: Some(strategy),
                ..self
            }
        }

        /// Replaces (does not extend) the set of message IDs the signer passes through untouched.
        pub fn exclude(self, message_ids: &[MessageId]) -> Self {
            Self {
                exclude: HashSet::from_iter(message_ids.iter().copied()),
                ..self
            }
        }
    }

    impl FrameSignerBuilder<HasLinkId, HasSecretKey> {
        /// Registers an additional known link and its secret key.
        ///
        /// Signatures from registered links are validated with their own key instead of going
        /// through the unknown-links policy. Passing the primary link ID replaces the primary
        /// key. The primary link ID itself is never changed by this method.
        pub fn add_link<K: Into<SecretKey>>(mut self, link_id: SignedLinkId, key: K) -> Self {
            let key = key.into();
            if self.link_id.0 == link_id {
                // Keep the primary key in sync: `build` re-inserts it under `link_id` and would
                // otherwise overwrite the key supplied here.
                self.key.0 = key.clone();
            }
            self.links.insert(link_id, key);
            // Do not reassign `self.link_id` here: `build` would then store the primary key
            // under the added ID, overwriting the key just supplied and losing the original
            // primary link.
            self
        }

        /// Finishes the builder.
        ///
        /// Unset `incoming`/`outgoing` strategies default to `SignStrategy::default()`, an unset
        /// unknown-links policy defaults to `SignStrategy::Strict`, the primary link and key are
        /// added to the known links, and a fresh timestamp counter is created.
        pub fn build(mut self) -> FrameSigner {
            self.links.insert(self.link_id.0, self.key.0.clone());
            FrameSigner {
                link_id: self.link_id.0,
                incoming: self.incoming.unwrap_or_default(),
                outgoing: self.outgoing.unwrap_or_default(),
                unknown_links: self.unknown_links.unwrap_or(SignStrategy::Strict),
                links: self.links,
                last_timestamp: Default::default(),
                exclude: self.exclude,
            }
        }
    }

    impl Default for FrameSignerBuilder<NoLinkId, NoSecretKey> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl From<FrameSignerBuilder<HasLinkId, HasSecretKey>> for FrameSigner {
        #[inline]
        fn from(value: FrameSignerBuilder<HasLinkId, HasSecretKey>) -> Self {
            value.build()
        }
    }

    impl IntoFrameSigner for FrameSignerBuilder<HasLinkId, HasSecretKey> {
        fn into_message_signer(self) -> FrameSigner {
            self.build()
        }
    }
}