use crate::shared::AssociationId;
use bytes::Bytes;
use crc::{Crc, CRC_32_ISCSI};
use std::time::Duration;
/// Folds the byte-wise XOR of two equal-length slices into a single byte.
///
/// Returns `0` when every byte pair matches and a non-zero value otherwise.
/// Every byte pair is visited regardless of where the first difference is;
/// there is no early exit.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline(never)]
fn constant_time_ne(a: &[u8], b: &[u8]) -> u8 {
    assert!(a.len() == b.len());
    // OR together all XOR differences; any differing bit sticks in the result.
    a.iter()
        .zip(b.iter())
        .fold(0, |diff, (x, y)| diff | (x ^ y))
}
/// Compares two byte slices for equality without exiting early on the first
/// differing byte.
///
/// Slices of different lengths compare unequal immediately (the length check
/// is not hidden); equal-length slices are compared via `constant_time_ne`.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    constant_time_ne(a, b) == 0
}
/// Source of [`AssociationId`] values.
///
/// Implementors must be `Send + Sync`.
pub trait AssociationIdGenerator: Send + Sync {
/// Produces an association ID. Takes `&mut self`, so implementations may
/// keep and update internal state between calls.
fn generate_aid(&mut self) -> AssociationId;
/// Returns the lifetime associated with generated IDs, or `None` if no
/// lifetime is configured.
// NOTE(review): presumably the period after which an ID should be retired;
// confirm against the callers of this trait.
fn aid_lifetime(&self) -> Option<Duration>;
}
/// [`AssociationIdGenerator`] whose IDs come from `rand::random`.
///
/// The derived `Default` leaves the lifetime unset (`None`).
#[derive(Default, Debug, Clone, Copy)]
pub struct RandomAssociationIdGenerator {
/// Optional lifetime reported by `aid_lifetime`; `None` until
/// `set_lifetime` is called.
lifetime: Option<Duration>,
}
impl RandomAssociationIdGenerator {
pub fn new() -> Self {
RandomAssociationIdGenerator::default()
}
pub fn set_lifetime(&mut self, d: Duration) -> &mut Self {
self.lifetime = Some(d);
self
}
}
/// Random-ID implementation of [`AssociationIdGenerator`].
impl AssociationIdGenerator for RandomAssociationIdGenerator {
/// Returns a fresh random `u32` from `rand::random`; no generator state is
/// read or modified.
fn generate_aid(&mut self) -> AssociationId {
rand::random::<u32>()
}
/// Returns the lifetime stored by `set_lifetime`, or `None` if it was never set.
fn aid_lifetime(&self) -> Option<Duration> {
self.lifetime
}
}
/// Alignment, in bytes, that padded lengths are rounded up to.
const PADDING_MULTIPLE: usize = 4;

/// Returns how many padding bytes must follow `len` bytes so that the total
/// is a multiple of `PADDING_MULTIPLE` (0 when `len` is already aligned).
pub(crate) fn get_padding_size(len: usize) -> usize {
    match len % PADDING_MULTIPLE {
        0 => 0,
        remainder => PADDING_MULTIPLE - remainder,
    }
}
pub(crate) static FOUR_ZEROES: Bytes = Bytes::from_static(&[0, 0, 0, 0]);
pub(crate) fn generate_packet_checksum(raw: &Bytes) -> u32 {
let hasher = Crc::<u32>::new(&CRC_32_ISCSI);
let mut digest = hasher.digest();
digest.update(&raw[0..8]);
digest.update(&FOUR_ZEROES[..]);
digest.update(&raw[12..]);
digest.finalize()
}
/// A [`BytesSource`] over a mutably borrowed slice of `Bytes` chunks.
pub struct BytesArray<'a> {
/// The wrapped chunks. Popping takes a whole chunk out of the slice (leaving
/// an empty `Bytes` behind) or splits bytes off the front of one.
chunks: &'a mut [Bytes],
/// Index into `chunks` of the next chunk to pop; advanced by one each time
/// a chunk is taken in full.
consumed: usize,
/// Byte counter initialised by `from_chunks` to the sum of all chunk
/// lengths; read by `remaining` and `has_remaining`.
length: usize,
}
impl<'a> BytesArray<'a> {
pub fn from_chunks(chunks: &'a mut [Bytes]) -> Self {
let mut length = 0;
for chunk in chunks.iter() {
length += chunk.len();
}
Self {
chunks,
consumed: 0,
length,
}
}
}
impl<'a> BytesSource for BytesArray<'a> {
    /// Pops the next non-empty chunk, yielding at most `limit` bytes.
    ///
    /// Returns the popped bytes and the number of chunks of the underlying
    /// slice that were consumed in full by this call (empty chunks that were
    /// skipped are included in that count). A chunk longer than `limit` is
    /// split: its first `limit` bytes are returned, the rest stays in place
    /// and no chunk is counted as consumed. The returned `Bytes` is empty if
    /// `limit` is zero or no data is left.
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
        let mut chunks_consumed = 0;
        // The loop exists to skip empty chunks while still counting them as
        // consumed.
        while self.consumed < self.chunks.len() {
            let chunk = &mut self.chunks[self.consumed];
            if chunk.len() <= limit {
                let chunk = std::mem::take(chunk);
                self.consumed += 1;
                chunks_consumed += 1;
                if chunk.is_empty() {
                    continue;
                }
                // Keep `length` equal to the number of bytes still unpopped.
                self.length -= chunk.len();
                return (chunk, chunks_consumed);
            } else if limit > 0 {
                let chunk = chunk.split_to(limit);
                self.length -= limit;
                return (chunk, chunks_consumed);
            } else {
                break;
            }
        }
        (Bytes::new(), chunks_consumed)
    }

    /// Returns `true` while at least one byte has not been popped yet.
    fn has_remaining(&self) -> bool {
        // `length` is a byte count and `consumed` is a chunk index; comparing
        // the two mixes units, so only the byte counter is consulted.
        self.length > 0
    }

    /// Returns the number of bytes that have not been popped yet.
    fn remaining(&self) -> usize {
        // `pop_chunk` subtracts every byte it hands out, so `length` is the
        // remaining byte count and this can never underflow.
        self.length
    }
}
/// A [`BytesSource`] over a borrowed byte slice.
pub struct ByteSlice<'a> {
/// The bytes not yet popped; `pop_chunk` advances this slice past whatever
/// it returns.
data: &'a [u8],
}
impl<'a> ByteSlice<'a> {
pub fn from_slice(data: &'a [u8]) -> Self {
Self { data }
}
}
impl<'a> BytesSource for ByteSlice<'a> {
    /// Copies up to `limit` bytes off the front of the slice into a new `Bytes`.
    ///
    /// The second tuple element is `1` when this call emptied the slice and
    /// `0` otherwise. If `limit` is zero or the slice is already empty, an
    /// empty `Bytes` and `0` are returned.
    fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize) {
        let data = self.data;
        let take = data.len().min(limit);
        if take == 0 {
            return (Bytes::new(), 0);
        }
        let (head, tail) = data.split_at(take);
        self.data = tail;
        // The whole slice counts as a single chunk, consumed once it is drained.
        let chunks_consumed = usize::from(tail.is_empty());
        (Bytes::copy_from_slice(head), chunks_consumed)
    }

    /// Returns `true` while the slice still holds unpopped bytes.
    fn has_remaining(&self) -> bool {
        !self.data.is_empty()
    }

    /// Returns the number of unpopped bytes.
    fn remaining(&self) -> usize {
        self.data.len()
    }
}
/// A source of byte data that hands out owned `Bytes` chunks.
pub trait BytesSource {
/// Pops the next chunk, consuming it from the source.
///
/// The implementations in this file return at most `limit` bytes, and an
/// empty `Bytes` when `limit` is zero or no data is left. The second tuple
/// element is the number of complete chunks of the underlying source that
/// this call consumed.
fn pop_chunk(&mut self, limit: usize) -> (Bytes, usize);
/// Returns `true` while the source still has data to pop.
fn has_remaining(&self) -> bool;
/// Returns the number of bytes still available from the source.
fn remaining(&self) -> usize;
}
/// Serial-number `i1 < i2` for 32-bit values, tolerant of wrap-around.
///
/// True when `i2` is ahead of `i1` by 1 to 2^31 - 1 (modulo 2^32). Equal
/// values, and values exactly 2^31 apart, are not "less than".
#[inline]
pub(crate) fn sna32lt(i1: u32, i2: u32) -> bool {
    use std::cmp::Ordering;

    const HALF_RANGE: u32 = 1 << 31;
    match i1.cmp(&i2) {
        Ordering::Less => i2 - i1 < HALF_RANGE,
        Ordering::Greater => i1 - i2 > HALF_RANGE,
        Ordering::Equal => false,
    }
}
/// Serial-number `i1 <= i2` for 32-bit values: equal, or `sna32lt(i1, i2)`.
#[inline]
pub(crate) fn sna32lte(i1: u32, i2: u32) -> bool {
    if i1 == i2 {
        return true;
    }
    sna32lt(i1, i2)
}
/// Serial-number `i1 > i2` for 32-bit values, tolerant of wrap-around.
///
/// True when `i1` is ahead of `i2` by 1 to 2^31 (modulo 2^32). A distance of
/// exactly 2^31 therefore counts as "greater than" in both directions.
#[inline]
pub(crate) fn sna32gt(i1: u32, i2: u32) -> bool {
    use std::cmp::Ordering;

    const HALF_RANGE: u32 = 1 << 31;
    match i1.cmp(&i2) {
        Ordering::Less => i2 - i1 >= HALF_RANGE,
        Ordering::Greater => i1 - i2 <= HALF_RANGE,
        Ordering::Equal => false,
    }
}
/// Serial-number `i1 >= i2` for 32-bit values: equal, or `sna32gt(i1, i2)`.
#[inline]
pub(crate) fn sna32gte(i1: u32, i2: u32) -> bool {
    if i1 == i2 {
        return true;
    }
    sna32gt(i1, i2)
}
/// Serial-number equality for 32-bit values; plain `==`, provided for symmetry
/// with the other `sna32*` comparisons.
#[inline]
pub(crate) fn sna32eq(i1: u32, i2: u32) -> bool {
i1 == i2
}
/// Serial-number `i1 < i2` for 16-bit values, tolerant of wrap-around.
///
/// True when `i2` is ahead of `i1` by 1 to 2^15 - 1 (modulo 2^16). Equal
/// values, and values exactly 2^15 apart, are not "less than".
#[inline]
pub(crate) fn sna16lt(i1: u16, i2: u16) -> bool {
    use std::cmp::Ordering;

    const HALF_RANGE: u16 = 1 << 15;
    match i1.cmp(&i2) {
        Ordering::Less => i2 - i1 < HALF_RANGE,
        Ordering::Greater => i1 - i2 > HALF_RANGE,
        Ordering::Equal => false,
    }
}
/// Serial-number `i1 <= i2` for 16-bit values: equal, or `sna16lt(i1, i2)`.
#[inline]
pub(crate) fn sna16lte(i1: u16, i2: u16) -> bool {
    if i1 == i2 {
        return true;
    }
    sna16lt(i1, i2)
}
/// Serial-number `i1 > i2` for 16-bit values, tolerant of wrap-around.
///
/// True when `i1` is ahead of `i2` by 1 to 2^15 (modulo 2^16). A distance of
/// exactly 2^15 therefore counts as "greater than" in both directions.
#[inline]
pub(crate) fn sna16gt(i1: u16, i2: u16) -> bool {
    use std::cmp::Ordering;

    const HALF_RANGE: u16 = 1 << 15;
    match i1.cmp(&i2) {
        Ordering::Less => i2 - i1 >= HALF_RANGE,
        Ordering::Greater => i1 - i2 <= HALF_RANGE,
        Ordering::Equal => false,
    }
}
/// Serial-number `i1 >= i2` for 16-bit values: equal, or `sna16gt(i1, i2)`.
#[inline]
pub(crate) fn sna16gte(i1: u16, i2: u16) -> bool {
    if i1 == i2 {
        return true;
    }
    sna16gt(i1, i2)
}
/// Serial-number equality for 16-bit values; plain `==`, provided for symmetry
/// with the other `sna16*` comparisons.
#[inline]
pub(crate) fn sna16eq(i1: u16, i2: u16) -> bool {
i1 == i2
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::error::Result;

    /// Number of evenly spaced starting points sampled across the serial space.
    const DIV: isize = 16;

    /// Checks the 32-bit serial-number comparisons at 16 starting points spread
    /// over the whole `u32` range.
    ///
    /// Offsets are applied with wrapping arithmetic so that starting points in
    /// the upper half of the range exercise wrap-around instead of being
    /// skipped. The forward offset is the largest distance that still compares
    /// as "less than" (2^31 - 1); the backward offset is 2^31.
    #[test]
    fn test_serial_number_arithmetic32bit() -> Result<()> {
        const SERIAL_BITS: u32 = 32;
        const INTERVAL: u32 = ((1u64 << (SERIAL_BITS as u64)) / (DIV as u64)) as u32;
        const MAX_FORWARD_DISTANCE: u32 = (1 << (SERIAL_BITS - 1)) - 1;
        const MAX_BACKWARD_DISTANCE: u32 = 1 << (SERIAL_BITS - 1);

        for i in 0..DIV as u32 {
            let s1 = i * INTERVAL;
            let s2f = s1.wrapping_add(MAX_FORWARD_DISTANCE);
            let s2b = s1.wrapping_add(MAX_BACKWARD_DISTANCE);

            assert!(
                sna32lt(s1, s2f),
                "s1 < s2 should be true: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                !sna32lt(s1, s2b),
                "s1 < s2 should be false: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                !sna32gt(s1, s2f),
                "s1 > s2 should be false: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                sna32gt(s1, s2b),
                "s1 > s2 should be true: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                sna32lte(s1, s2f),
                "s1 <= s2 should be true: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                !sna32lte(s1, s2b),
                "s1 <= s2 should be false: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                !sna32gte(s1, s2f),
                "s1 >= s2 should be fales: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                sna32gte(s1, s2b),
                "s1 >= s2 should be true: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                sna32eq(s2b, s2b),
                "s2 == s2 should be true: s2={} s2={}",
                s2b,
                s2b
            );
            assert!(
                sna32lte(s2b, s2b),
                "s2 == s2 should be true: s2={} s2={}",
                s2b,
                s2b
            );
            assert!(
                sna32gte(s2b, s2b),
                "s2 == s2 should be true: s2={} s2={}",
                s2b,
                s2b
            );

            // Neighbours on either side, wrapping at the ends of the range.
            let s1add1 = s1.wrapping_add(1);
            assert!(
                !sna32eq(s1, s1add1),
                "s1 == s1+1 should be false: s1={} s1+1={}",
                s1,
                s1add1
            );
            let s1sub1 = s1.wrapping_sub(1);
            assert!(
                !sna32eq(s1, s1sub1),
                "s1 == s1-1 hould be false: s1={} s1-1={}",
                s1,
                s1sub1
            );

            assert!(
                sna32eq(s1, s1),
                "s1 == s1 should be true: s1={} s2={}",
                s1,
                s1
            );
            assert!(
                sna32lte(s1, s1),
                "s1 == s1 should be true: s1={} s2={}",
                s1,
                s1
            );
            assert!(
                sna32gte(s1, s1),
                "s1 == s1 should be true: s1={} s2={}",
                s1,
                s1
            );
        }
        Ok(())
    }

    /// Checks the 16-bit serial-number comparisons at 16 starting points spread
    /// over the whole `u16` range.
    ///
    /// Offsets are applied with wrapping arithmetic so that starting points in
    /// the upper half of the range exercise wrap-around instead of being
    /// skipped. The forward offset is the largest distance that still compares
    /// as "less than" (2^15 - 1); the backward offset is 2^15.
    #[test]
    fn test_serial_number_arithmetic16bit() -> Result<()> {
        const SERIAL_BITS: u16 = 16;
        const INTERVAL: u16 = ((1u64 << (SERIAL_BITS as u64)) / (DIV as u64)) as u16;
        const MAX_FORWARD_DISTANCE: u16 = (1 << (SERIAL_BITS - 1)) - 1;
        const MAX_BACKWARD_DISTANCE: u16 = 1 << (SERIAL_BITS - 1);

        for i in 0..DIV as u16 {
            let s1 = i * INTERVAL;
            let s2f = s1.wrapping_add(MAX_FORWARD_DISTANCE);
            let s2b = s1.wrapping_add(MAX_BACKWARD_DISTANCE);

            assert!(
                sna16lt(s1, s2f),
                "s1 < s2 should be true: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                !sna16lt(s1, s2b),
                "s1 < s2 should be false: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                !sna16gt(s1, s2f),
                "s1 > s2 should be fales: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                sna16gt(s1, s2b),
                "s1 > s2 should be true: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                sna16lte(s1, s2f),
                "s1 <= s2 should be true: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                !sna16lte(s1, s2b),
                "s1 <= s2 should be false: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                !sna16gte(s1, s2f),
                "s1 >= s2 should be fales: s1={} s2={}",
                s1,
                s2f
            );
            assert!(
                sna16gte(s1, s2b),
                "s1 >= s2 should be true: s1={} s2={}",
                s1,
                s2b
            );
            assert!(
                sna16eq(s2b, s2b),
                "s2 == s2 should be true: s2={} s2={}",
                s2b,
                s2b
            );
            assert!(
                sna16lte(s2b, s2b),
                "s2 == s2 should be true: s2={} s2={}",
                s2b,
                s2b
            );
            assert!(
                sna16gte(s2b, s2b),
                "s2 == s2 should be true: s2={} s2={}",
                s2b,
                s2b
            );

            assert!(
                sna16eq(s1, s1),
                "s1 == s1 should be true: s1={} s2={}",
                s1,
                s1
            );

            // Neighbours on either side, wrapping at the ends of the range.
            let s1add1 = s1.wrapping_add(1);
            assert!(
                !sna16eq(s1, s1add1),
                "s1 == s1+1 should be false: s1={} s1+1={}",
                s1,
                s1add1
            );
            let s1sub1 = s1.wrapping_sub(1);
            assert!(
                !sna16eq(s1, s1sub1),
                "s1 == s1-1 hould be false: s1={} s1-1={}",
                s1,
                s1sub1
            );

            assert!(
                sna16lte(s1, s1),
                "s1 == s1 should be true: s1={} s2={}",
                s1,
                s1
            );
            assert!(
                sna16gte(s1, s1),
                "s1 == s1 should be true: s1={} s2={}",
                s1,
                s1
            );
        }
        Ok(())
    }
}