use common_types::{encoded, header::Header};
use vapory_types::H256;
use light::request::{HashOrNumber, CompleteHeadersRequest as HeadersRequest};
use tetsy_rlp::DecoderError;
use std::fmt;
/// Reasons a header response can fail the basic checks performed by `verify`.
#[derive(Debug, PartialEq)]
pub enum BasicError {
/// Gap between two consecutive headers did not match the requested skip:
/// (expected skip, observed skip). The observed skip is `None` when the two
/// block numbers were not strictly increasing in the requested direction.
WrongSkip(u64, Option<u64>),
/// The first header's number differs from the requested start: (expected, got).
WrongStartNumber(u64, u64),
/// The first header's hash differs from the requested start: (expected, got).
WrongStartHash(H256, H256),
/// The response held more headers than requested: (max, got).
TooManyHeaders(usize, usize),
/// One of the encoded headers failed to decode.
Decoder(DecoderError),
}
impl From<DecoderError> for BasicError {
fn from(err: DecoderError) -> Self {
BasicError::Decoder(err)
}
}
impl fmt::Display for BasicError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Header response verification error: ")?;
match *self {
BasicError::WrongSkip(ref exp, ref got)
=> write!(f, "wrong skip (expected {}, got {:?})", exp, got),
BasicError::WrongStartNumber(ref exp, ref got)
=> write!(f, "wrong start number (expected {}, got {})", exp, got),
BasicError::WrongStartHash(ref exp, ref got)
=> write!(f, "wrong start hash (expected {}, got {})", exp, got),
BasicError::TooManyHeaders(ref max, ref got)
=> write!(f, "too many headers (max {}, got {})", max, got),
BasicError::Decoder(ref err)
=> write!(f, "{}", err),
}
}
}
pub trait Constraint {
type Error;
fn verify(&self, headers: &[Header], reverse: bool) -> Result<(), Self::Error>;
}
/// Decodes a header response and checks it against the request that produced it.
///
/// The checks run in this order, and the first failure is returned:
/// 1. the response contains at most `request.max` headers;
/// 2. the first header matches `request.start` (by number or by hash);
/// 3. consecutive headers are exactly `request.skip` blocks apart, in the
///    direction given by `request.reverse`.
///
/// An empty response passes every check.
///
/// # Errors
/// `BasicError::Decoder` if any header fails to decode. Otherwise, the variant
/// describing the first check that failed.
pub fn verify(headers: &[encoded::Header], request: &HeadersRequest) -> Result<Vec<Header>, BasicError> {
    // `?` turns a `DecoderError` into `BasicError::Decoder` through the `From` impl.
    let headers = headers.iter().map(|h| h.decode()).collect::<Result<Vec<_>, _>>()?;
    let reverse = request.reverse;

    // Saturate instead of truncating. Where `usize` is narrower than `u64`, a plain
    // `as usize` would wrap a large limit into a small one (possibly 0), and valid
    // responses would then be rejected as `TooManyHeaders`.
    let max = request.max.min(usize::max_value() as u64) as usize;
    Max(max).verify(&headers, reverse)?;

    match request.start {
        HashOrNumber::Number(ref num) => StartsAtNumber(*num).verify(&headers, reverse)?,
        HashOrNumber::Hash(ref hash) => StartsAtHash(*hash).verify(&headers, reverse)?,
    }

    SkipsBetween(request.skip).verify(&headers, reverse)?;

    Ok(headers)
}
/// Constraint: the first header (if any) has this block number.
struct StartsAtNumber(u64);
/// Constraint: the first header (if any) has this hash.
struct StartsAtHash(H256);
/// Constraint: consecutive headers are separated by exactly this many skipped blocks.
struct SkipsBetween(u64);
/// Constraint: the response contains at most this many headers.
struct Max(usize);
impl Constraint for StartsAtNumber {
    type Error = BasicError;

    /// Accepts an empty batch. Otherwise requires the first header's number to equal `self.0`.
    fn verify(&self, headers: &[Header], _reverse: bool) -> Result<(), BasicError> {
        let expected = self.0;
        match headers.first().map(|first| first.number()) {
            Some(actual) if actual != expected => Err(BasicError::WrongStartNumber(expected, actual)),
            _ => Ok(()),
        }
    }
}
impl Constraint for StartsAtHash {
    type Error = BasicError;

    /// Accepts an empty batch. Otherwise requires the first header's hash to equal `self.0`.
    fn verify(&self, headers: &[Header], _reverse: bool) -> Result<(), BasicError> {
        let first = match headers.first() {
            Some(first) => first,
            None => return Ok(()),
        };

        let actual = first.hash();
        if actual != self.0 {
            return Err(BasicError::WrongStartHash(self.0, actual));
        }
        Ok(())
    }
}
impl Constraint for SkipsBetween {
    type Error = BasicError;

    /// Requires every adjacent pair of headers to be exactly `self.0` skipped
    /// blocks apart. With `reverse`, numbers must descend through the batch;
    /// otherwise they must ascend. A pair that is not strictly ordered that way
    /// is reported as `WrongSkip(self.0, None)`.
    fn verify(&self, headers: &[Header], reverse: bool) -> Result<(), BasicError> {
        headers.windows(2).try_for_each(|pair| {
            let (first, second) = (&pair[0], &pair[1]);
            let (lower, upper) = if reverse { (second, first) } else { (first, second) };

            // `None` when `upper <= lower`, i.e. the pair is out of order or duplicated.
            let gap = upper.number()
                .checked_sub(lower.number())
                .and_then(|distance| distance.checked_sub(1));

            match gap {
                None => Err(BasicError::WrongSkip(self.0, None)),
                Some(observed) if observed != self.0 => Err(BasicError::WrongSkip(self.0, Some(observed))),
                Some(_) => Ok(()),
            }
        })
    }
}
impl Constraint for Max {
    type Error = BasicError;

    /// Rejects the batch if it contains more than `self.0` headers.
    fn verify(&self, headers: &[Header], _reverse: bool) -> Result<(), BasicError> {
        let count = headers.len();
        if count > self.0 {
            Err(BasicError::TooManyHeaders(self.0, count))
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use common_types::encoded;
    use common_types::header::Header;
    use light::request::CompleteHeadersRequest as HeadersRequest;
    use super::*;

    /// Builds one header per number. Each header's parent hash is set to the
    /// hash of the header built just before it (in iteration order).
    fn linked_headers<I: IntoIterator<Item = u64>>(numbers: I) -> Vec<Header> {
        let mut parent_hash = None;
        numbers.into_iter().map(|x| {
            let mut header = Header::default();
            header.set_number(x);
            if let Some(parent_hash) = parent_hash {
                header.set_parent_hash(parent_hash);
            }
            parent_hash = Some(header.hash());
            header
        }).collect()
    }

    /// Builds one header per number, leaving the parent hash at its default.
    fn unlinked_headers<I: IntoIterator<Item = u64>>(numbers: I) -> Vec<Header> {
        numbers.into_iter().map(|x| {
            let mut header = Header::default();
            header.set_number(x);
            header
        }).collect()
    }

    /// RLP-encodes headers into the form `verify` receives.
    fn encode_all(headers: &[Header]) -> Vec<encoded::Header> {
        headers.iter().map(|h| encoded::Header::new(::tetsy_rlp::encode(h))).collect()
    }

    #[test]
    fn sequential_forward() {
        let request = HeadersRequest {
            start: 10.into(),
            max: 30,
            skip: 0,
            reverse: false,
        };
        let headers = encode_all(&linked_headers(10u64..35));
        assert!(verify(&headers, &request).is_ok());
    }

    #[test]
    fn sequential_backward() {
        let request = HeadersRequest {
            start: 34.into(),
            max: 30,
            skip: 0,
            reverse: true,
        };
        let headers = encode_all(&linked_headers((10u64..35).rev()));
        assert!(verify(&headers, &request).is_ok());
    }

    #[test]
    fn too_many() {
        let request = HeadersRequest {
            start: 10.into(),
            max: 20,
            skip: 0,
            reverse: false,
        };
        let headers = encode_all(&linked_headers(10u64..35));
        assert_eq!(verify(&headers, &request), Err(BasicError::TooManyHeaders(20, 25)));
    }

    #[test]
    fn wrong_skip() {
        let request = HeadersRequest {
            start: 10.into(),
            max: 30,
            skip: 5,
            reverse: false,
        };
        let headers = encode_all(&unlinked_headers((0u64..25).map(|x| x * 3 + 10)));
        assert_eq!(verify(&headers, &request), Err(BasicError::WrongSkip(5, Some(2))));
    }

    #[test]
    fn empty_response_is_ok() {
        let request = HeadersRequest {
            start: 10.into(),
            max: 30,
            skip: 0,
            reverse: false,
        };
        assert_eq!(verify(&[], &request), Ok(Vec::new()));
    }

    #[test]
    fn wrong_start_number() {
        let request = HeadersRequest {
            start: 10.into(),
            max: 30,
            skip: 0,
            reverse: false,
        };
        let headers = encode_all(&linked_headers(11u64..20));
        assert_eq!(verify(&headers, &request), Err(BasicError::WrongStartNumber(10, 11)));
    }

    #[test]
    fn wrong_start_hash() {
        let chain = linked_headers(10u64..20);
        let expected = chain[1].hash();
        let got = chain[0].hash();
        let request = HeadersRequest {
            start: HashOrNumber::Hash(expected),
            max: 30,
            skip: 0,
            reverse: false,
        };
        assert_eq!(
            verify(&encode_all(&chain), &request),
            Err(BasicError::WrongStartHash(expected, got))
        );
    }

    #[test]
    fn out_of_order_numbers() {
        let request = HeadersRequest {
            start: 10.into(),
            max: 30,
            skip: 0,
            reverse: false,
        };
        let headers = encode_all(&unlinked_headers(vec![10, 9]));
        assert_eq!(verify(&headers, &request), Err(BasicError::WrongSkip(0, None)));
    }
}