use std::{
fmt::{Debug, Display, Error, Formatter},
time::Duration,
};
use derive_more::derive::Display;
use primitives::{
algebra::elliptic_curve::curve::PointAtInfinityError,
errors::PrimitiveError,
types::identifiers::PeerId,
};
use serde::{Deserialize, Serialize};
use tokio::{sync::mpsc::error::SendError, task::JoinError};
use wincode::{SchemaRead, SchemaWrite, WriteError};
use crate::{circuit::GateIndex, key_recovery::KeyRecoveryError};
/// Payload of an [`AbortError`]: why the computation was aborted and which
/// peer the abort is attributed to.
///
/// Only `PartialEq` is derived (not `Eq`), because [`AbortReason`] carries an
/// `f64` in one of its variants.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, SchemaRead, SchemaWrite)]
pub struct AbortErrorInner {
/// The reason for the abort.
reason: AbortReason,
/// The peer the abort is attributed to.
faulty_peer: FaultyPeer,
}
/// Error describing an aborted computation; a boxed [`AbortErrorInner`].
///
/// `#[serde(transparent)]` makes serde (de)serialize this wrapper exactly as
/// the value it wraps. `Clone`, `PartialEq`, `Eq`, `Debug` and `Display` are
/// implemented by hand elsewhere in this file rather than derived.
// NOTE(review): the box presumably exists to keep `Result<_, AbortError>` small;
// the motivation is not stated in this file — confirm before unboxing.
#[derive(Serialize, Deserialize, SchemaRead, SchemaWrite)]
#[serde(transparent)]
pub struct AbortError(Box<AbortErrorInner>);
/// Identifies the party an abort is attributed to.
///
/// `Display` is derived through `derive_more`; it is what the
/// "caused by ..." part of an [`AbortError`] message prints.
// NOTE(review): the derived serializers presumably identify variants by
// position, so reordering variants may change the encoded form — confirm.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, SchemaRead, SchemaWrite, Display,
)]
pub enum FaultyPeer {
/// The abort is attributed to the local side.
Local,
/// The abort is attributed to the remote peer with this id.
Foreign(PeerId),
}
/// The reason a computation was aborted.
///
/// The `Display` text of every variant is generated by `thiserror` from its
/// `#[error(...)]` attribute. `Vec<u8>` payloads are rendered with `{:?}`.
/// The two `#[from]` fields additionally generate `From` conversions into
/// this enum.
// NOTE(review): the derived serializers presumably identify variants by
// position, so reordering or removing variants may change the encoded form —
// confirm before editing the variant list.
#[derive(
Debug, Clone, PartialEq, Serialize, Deserialize, SchemaRead, SchemaWrite, thiserror::Error,
)]
pub enum AbortReason {
/// A MAC check failed; the payload is the textual form of what was received.
#[error("Invalid MAC: got {0}")]
InvalidMAC(String),
/// A sent share was expected; the payload is the data received instead.
#[error("Expected sent share, got {0:?}")]
ExpectedSentShare(Vec<u8>),
/// A field element was expected; the payload is the data received instead.
#[error("Expected field element, got {0:?}")]
ExpectedFieldElement(Vec<u8>),
/// An abort message was expected; the payload is the data received instead.
#[error("Expected abort, got {0:?}")]
ExpectedAbort(Vec<u8>),
/// The payload is the data that was considered malformed.
#[error("Malformed data: {0:?}")]
MalformedData(Vec<u8>),
/// Generic failure without further detail.
#[error("Computation failed.")]
ComputationFailed,
/// Internal failure described by a free-form message.
#[error("Internal error: {0}")]
InternalError(String),
/// Preprocessing stream failure described by a free-form message.
#[error("Preprocessing stream error: {0}")]
PreprocessingStreamError(String),
/// Division by zero; the payload is the gate index reported as the label.
#[error("Division by 0 for label {0}")]
DivisionByZero(GateIndex),
/// No signature was available.
#[error("No signature")]
NoSignature,
/// The payload is the signature that was rejected.
#[error("Invalid signature: {0:?}")]
InvalidSignature(Vec<u8>),
/// Wraps a `PrimitiveError`; `#[from]` generates `From<PrimitiveError>`.
#[error("Primitive error: {0}")]
PrimitiveError(#[from] PrimitiveError),
/// A batch had `received` elements where `expected` were required.
#[error("Invalid batch length. Expected {expected}, got {received}")]
InvalidBatchLength { expected: usize, received: usize },
/// Quadratic non-residue; the payload is the gate index reported as the label.
#[error("Quadratic non-residue for label {0}")]
QuadraticNonResidue(GateIndex),
/// A value expected to be 0 or 1 was something else; carries the gate index
/// and the offending data.
#[error(
"Bit conversion error: expected field element to be 0 or 1 for label {0}, got {1:?} instead."
)]
BitConversionError(GateIndex, Vec<u8>),
/// A channel was closed.
#[error("Channel closed.")]
ChannelClosed,
/// A timeout elapsed while listening for data; the payload is the timeout in
/// seconds. This `f64` payload prevents deriving `Eq` for the enum.
#[error("Timeout elapsed while listening for data after {0} seconds.")]
TimeoutElapsed(f64),
/// Wraps a `KeyRecoveryError`; `#[from]` generates `From<KeyRecoveryError>`.
#[error("Key recovery error: {0}")]
KeyRecoveryError(#[from] KeyRecoveryError),
}
/// Gives shared access to the boxed [`AbortErrorInner`], so its fields can be
/// read directly through an `AbortError`.
impl std::ops::Deref for AbortError {
    type Target = AbortErrorInner;

    fn deref(&self) -> &AbortErrorInner {
        // Explicitly step through the `Box` instead of relying on coercion.
        &*self.0
    }
}
impl std::ops::DerefMut for AbortError {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl Clone for AbortError {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl PartialEq for AbortError {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
// `Eq` is asserted by hand: `AbortErrorInner` only derives `PartialEq`, since
// `AbortReason::TimeoutElapsed` carries an `f64`.
// NOTE(review): a NaN timeout (e.g. arriving through deserialization) would
// compare unequal to itself and break the `Eq` contract — confirm this cannot
// occur.
impl Eq for AbortError {}
impl Debug for AbortError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// Constructors and accessors for [`AbortError`].
///
/// Apart from [`AbortError::new`], every constructor selects one
/// [`AbortReason`] variant and either attributes it to a foreign peer, to the
/// local side, or to whichever [`FaultyPeer`] the caller passes in.
impl AbortError {
    /// Creates an abort error from its two components, boxing the payload.
    pub fn new(reason: AbortReason, faulty_peer: FaultyPeer) -> Self {
        let inner = AbortErrorInner {
            reason,
            faulty_peer,
        };
        Self(Box::new(inner))
    }

    /// Private shorthand: abort with `reason`, attributed to the remote `peer`.
    fn blame_foreign(reason: AbortReason, peer: PeerId) -> Self {
        Self::new(reason, FaultyPeer::Foreign(peer))
    }

    /// Private shorthand: abort with `reason`, attributed to [`FaultyPeer::Local`].
    fn blame_local(reason: AbortReason) -> Self {
        Self::new(reason, FaultyPeer::Local)
    }

    /// [`AbortReason::InvalidMAC`] carrying `received`, attributed to `peer`.
    pub fn invalid_mac(peer: PeerId, received: String) -> Self {
        Self::blame_foreign(AbortReason::InvalidMAC(received), peer)
    }

    /// [`AbortReason::ExpectedSentShare`] carrying `data`, attributed to `peer`.
    pub fn expected_sent_share(data: Vec<u8>, peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::ExpectedSentShare(data), peer)
    }

    /// [`AbortReason::ExpectedFieldElement`] carrying `data`, attributed to `peer`.
    pub fn expected_field_element(data: Vec<u8>, peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::ExpectedFieldElement(data), peer)
    }

    /// [`AbortReason::ExpectedAbort`] carrying `data`, attributed to `peer`.
    pub fn expected_abort(data: Vec<u8>, peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::ExpectedAbort(data), peer)
    }

    /// [`AbortReason::MalformedData`] carrying `data`, attributed to `faulty_peer`.
    pub fn malformed_data(data: Vec<u8>, faulty_peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::MalformedData(data), faulty_peer)
    }

    /// [`AbortReason::InternalError`] with an owned copy of `message`,
    /// attributed to the local side.
    pub fn internal_error(message: &str) -> Self {
        Self::blame_local(AbortReason::InternalError(String::from(message)))
    }

    /// [`AbortReason::DivisionByZero`] for `label`; the caller decides who is
    /// at fault.
    pub fn division_by_zero(label: GateIndex, faulty_peer: FaultyPeer) -> Self {
        let reason = AbortReason::DivisionByZero(label);
        Self::new(reason, faulty_peer)
    }

    /// [`AbortReason::NoSignature`], attributed to `peer`.
    pub fn no_signature(peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::NoSignature, peer)
    }

    /// [`AbortReason::InvalidSignature`] carrying `signature`, attributed to `peer`.
    pub fn invalid_signature(signature: Vec<u8>, peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::InvalidSignature(signature), peer)
    }

    /// [`AbortReason::InvalidBatchLength`] with the given lengths, attributed
    /// to `faulty_peer`.
    pub fn invalid_batch_length(expected: usize, received: usize, faulty_peer: PeerId) -> Self {
        let reason = AbortReason::InvalidBatchLength { expected, received };
        Self::blame_foreign(reason, faulty_peer)
    }

    /// [`AbortReason::PreprocessingStreamError`] with an owned copy of
    /// `message`, attributed to the local side.
    pub fn preprocessing_stream_error(message: &str) -> Self {
        Self::blame_local(AbortReason::PreprocessingStreamError(String::from(message)))
    }

    /// [`AbortReason::QuadraticNonResidue`] for `label`; the caller decides
    /// who is at fault.
    pub fn quadratic_non_residue(label: GateIndex, faulty_peer: FaultyPeer) -> Self {
        let reason = AbortReason::QuadraticNonResidue(label);
        Self::new(reason, faulty_peer)
    }

    /// [`AbortReason::BitConversionError`] for `label` and the offending
    /// `data`, attributed to the local side.
    pub fn bit_conversion_error(label: GateIndex, data: Vec<u8>) -> Self {
        Self::blame_local(AbortReason::BitConversionError(label, data))
    }

    /// [`AbortReason::ChannelClosed`], attributed to `faulty_peer`.
    pub fn channel_closed(faulty_peer: PeerId) -> Self {
        Self::blame_foreign(AbortReason::ChannelClosed, faulty_peer)
    }

    /// [`AbortReason::TimeoutElapsed`], attributed to `faulty_peer`.
    ///
    /// The timeout is stored as fractional seconds (`Duration::as_secs_f64`).
    pub fn timeout_elapsed(timeout: Duration, faulty_peer: PeerId) -> Self {
        let seconds = timeout.as_secs_f64();
        Self::blame_foreign(AbortReason::TimeoutElapsed(seconds), faulty_peer)
    }

    /// Borrows the reason stored in this error.
    pub fn get_reason(&self) -> &AbortReason {
        &self.0.reason
    }

    /// Returns a copy of the peer this abort is attributed to.
    pub fn get_faulty_peer(&self) -> FaultyPeer {
        self.0.faulty_peer
    }
}
impl Display for AbortError {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
write!(
f,
"AbortError: {} caused by {}",
self.reason, self.faulty_peer
)
}
}
/// Marks `AbortError` as a standard error type.
///
/// `source()` keeps its default (`None`); the reason, including the text of
/// any wrapped error, is rendered by the `Display` implementation instead.
impl std::error::Error for AbortError {}
impl From<JoinError> for AbortError {
fn from(value: JoinError) -> Self {
AbortError::internal_error(&format!("JoinError: {value}"))
}
}
impl<T> From<SendError<T>> for AbortError {
fn from(value: SendError<T>) -> Self {
AbortError::internal_error(&format!("Channel error. Unable to send request: {value}"))
}
}
/// A boxed bincode error becomes a locally attributed internal error.
// NOTE(review): bincode appears to use this same error type for decoding, so
// the "serialize" wording may be misleading when `?` is applied to a
// deserialization call — confirm before relying on the message.
impl From<Box<bincode::ErrorKind>> for AbortError {
    fn from(value: Box<bincode::ErrorKind>) -> Self {
        let message = format!("Unable to serialize data: {value}");
        Self::internal_error(&message)
    }
}
impl From<WriteError> for AbortError {
fn from(value: WriteError) -> Self {
AbortError::internal_error(&format!("Wincode write error: {value}"))
}
}
impl From<std::io::Error> for AbortError {
fn from(value: std::io::Error) -> Self {
AbortError::internal_error(&format!("IO error: {value}"))
}
}
/// A `PrimitiveError` is kept as structured data inside
/// [`AbortReason::PrimitiveError`] and attributed to the local side.
impl From<PrimitiveError> for AbortError {
    fn from(value: PrimitiveError) -> Self {
        let reason = AbortReason::PrimitiveError(value);
        Self::new(reason, FaultyPeer::Local)
    }
}
impl From<PointAtInfinityError> for AbortError {
fn from(value: PointAtInfinityError) -> Self {
AbortError::internal_error(&format!("PointAtInfinity: {value}"))
}
}
/// A `KeyRecoveryError` is kept as structured data inside
/// [`AbortReason::KeyRecoveryError`] and attributed to the local side.
impl From<KeyRecoveryError> for AbortError {
    fn from(value: KeyRecoveryError) -> Self {
        let reason = AbortReason::KeyRecoveryError(value);
        Self::new(reason, FaultyPeer::Local)
    }
}