use std::io::{Read, Write};
use crate::error::Result;
pub fn hash(s: &[u8]) -> u32 {
let mut h: u32 = 2166136261;
for &byte in s {
h ^= byte as i8 as i32 as u32;
h = h.wrapping_mul(16777619);
}
h
}
pub fn read_i32<R: Read>(reader: &mut R) -> Result<i32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(i32::from_le_bytes(buf))
}
pub fn read_i64<R: Read>(reader: &mut R) -> Result<i64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(i64::from_le_bytes(buf))
}
pub fn read_f32<R: Read>(reader: &mut R) -> Result<f32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
Ok(f32::from_le_bytes(buf))
}
pub fn read_f64<R: Read>(reader: &mut R) -> Result<f64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf)?;
Ok(f64::from_le_bytes(buf))
}
pub fn write_i32<W: Write>(writer: &mut W, value: i32) -> Result<()> {
writer.write_all(&value.to_le_bytes())?;
Ok(())
}
pub fn write_i64<W: Write>(writer: &mut W, value: i64) -> Result<()> {
writer.write_all(&value.to_le_bytes())?;
Ok(())
}
pub fn write_f32<W: Write>(writer: &mut W, value: f32) -> Result<()> {
writer.write_all(&value.to_le_bytes())?;
Ok(())
}
pub fn write_f64<W: Write>(writer: &mut W, value: f64) -> Result<()> {
writer.write_all(&value.to_le_bytes())?;
Ok(())
}
pub fn read_bool<R: Read>(reader: &mut R) -> Result<bool> {
let mut buf = [0u8; 1];
reader.read_exact(&mut buf)?;
Ok(buf[0] != 0)
}
pub fn write_bool<W: Write>(writer: &mut W, value: bool) -> Result<()> {
writer.write_all(&[value as u8])?;
Ok(())
}
pub fn softmax_in_place(data: &mut [f32]) {
let max = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut z = 0.0f32;
for v in data.iter_mut() {
*v = (*v - max).exp();
z += *v;
}
if z > 0.0 {
let inv_z = 1.0 / z;
for v in data.iter_mut() {
*v *= inv_z;
}
}
}
pub fn l2_normalize(v: &mut [f32]) {
let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
if norm < 1e-8 {
return;
}
let inv = 1.0 / norm;
for x in v.iter_mut() {
*x *= inv;
}
}
pub fn cpp_default_format(val: f64, sig_digits: usize) -> String {
let sci = format!("{:.prec$e}", val, prec = sig_digits.saturating_sub(1));
let e_pos = sci.find('e').unwrap();
let exp: i32 = sci[e_pos + 1..].parse().unwrap();
if exp >= -4 && exp < sig_digits as i32 {
let decimal_places = if exp >= 0 {
sig_digits.saturating_sub(1 + exp as usize)
} else {
sig_digits + (-exp as usize) - 1
};
let fixed = format!("{:.prec$}", val, prec = decimal_places);
if fixed.contains('.') {
fixed
.trim_end_matches('0')
.trim_end_matches('.')
.to_string()
} else {
fixed
}
} else {
let mantissa = sci[..e_pos].trim_end_matches('0').trim_end_matches('.');
let exp_sign = if exp < 0 { "-" } else { "+" };
let exp_abs = exp.unsigned_abs();
if exp_abs < 10 {
format!("{}e{}{:02}", mantissa, exp_sign, exp_abs)
} else {
format!("{}e{}{}", mantissa, exp_sign, exp_abs)
}
}
}
#[derive(Clone, Copy, PartialEq)]
pub struct OrdF32(pub f32);
impl Eq for OrdF32 {}
impl PartialOrd for OrdF32 {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrdF32 {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0
.partial_cmp(&other.0)
.unwrap_or(std::cmp::Ordering::Equal)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
fn roundtrip_i32(val: i32) -> i32 {
let mut buf = Vec::new();
write_i32(&mut buf, val).unwrap();
let mut cursor = Cursor::new(&buf);
read_i32(&mut cursor).unwrap()
}
fn roundtrip_i64(val: i64) -> i64 {
let mut buf = Vec::new();
write_i64(&mut buf, val).unwrap();
let mut cursor = Cursor::new(&buf);
read_i64(&mut cursor).unwrap()
}
fn roundtrip_f32(val: f32) -> f32 {
let mut buf = Vec::new();
write_f32(&mut buf, val).unwrap();
let mut cursor = Cursor::new(&buf);
read_f32(&mut cursor).unwrap()
}
fn roundtrip_f64(val: f64) -> f64 {
let mut buf = Vec::new();
write_f64(&mut buf, val).unwrap();
let mut cursor = Cursor::new(&buf);
read_f64(&mut cursor).unwrap()
}
#[test]
fn test_binary_io_roundtrip_i32() {
assert_eq!(roundtrip_i32(0), 0);
assert_eq!(roundtrip_i32(-1), -1);
assert_eq!(roundtrip_i32(1), 1);
assert_eq!(roundtrip_i32(i32::MAX), i32::MAX);
assert_eq!(roundtrip_i32(i32::MIN), i32::MIN);
assert_eq!(roundtrip_i32(42), 42);
assert_eq!(roundtrip_i32(-42), -42);
assert_eq!(roundtrip_i32(793712314), 793712314); }
#[test]
fn test_binary_io_roundtrip_i64() {
assert_eq!(roundtrip_i64(0), 0);
assert_eq!(roundtrip_i64(-1), -1);
assert_eq!(roundtrip_i64(1), 1);
assert_eq!(roundtrip_i64(i64::MAX), i64::MAX);
assert_eq!(roundtrip_i64(i64::MIN), i64::MIN);
assert_eq!(roundtrip_i64(42), 42);
assert_eq!(roundtrip_i64(-42), -42);
}
#[test]
fn test_binary_io_roundtrip_f32() {
assert_eq!(roundtrip_f32(0.0), 0.0);
assert_eq!(roundtrip_f32(-0.0).to_bits(), (-0.0_f32).to_bits());
assert_eq!(roundtrip_f32(1.0), 1.0);
assert_eq!(roundtrip_f32(-1.0), -1.0);
assert_eq!(roundtrip_f32(f32::MAX), f32::MAX);
assert_eq!(roundtrip_f32(f32::MIN), f32::MIN);
assert_eq!(roundtrip_f32(f32::INFINITY), f32::INFINITY);
assert_eq!(roundtrip_f32(f32::NEG_INFINITY), f32::NEG_INFINITY);
assert_eq!(roundtrip_f32(f32::MIN_POSITIVE), f32::MIN_POSITIVE);
assert!(roundtrip_f32(f32::NAN).is_nan());
let subnormal: f32 = 1.0e-40;
assert_eq!(roundtrip_f32(subnormal), subnormal);
}
#[test]
fn test_binary_io_roundtrip_f64() {
assert_eq!(roundtrip_f64(0.0), 0.0);
assert_eq!(roundtrip_f64(-0.0).to_bits(), (-0.0_f64).to_bits());
assert_eq!(roundtrip_f64(1.0), 1.0);
assert_eq!(roundtrip_f64(-1.0), -1.0);
assert_eq!(roundtrip_f64(f64::MAX), f64::MAX);
assert_eq!(roundtrip_f64(f64::MIN), f64::MIN);
assert_eq!(roundtrip_f64(f64::INFINITY), f64::INFINITY);
assert_eq!(roundtrip_f64(f64::NEG_INFINITY), f64::NEG_INFINITY);
assert_eq!(roundtrip_f64(f64::MIN_POSITIVE), f64::MIN_POSITIVE);
assert!(roundtrip_f64(f64::NAN).is_nan());
let subnormal: f64 = 5.0e-324;
assert_eq!(roundtrip_f64(subnormal), subnormal);
assert_eq!(roundtrip_f64(1e-4), 1e-4);
}
#[test]
fn test_binary_io_little_endian_layout() {
let mut buf = Vec::new();
write_i32(&mut buf, 0x04030201).unwrap();
assert_eq!(buf, vec![0x01, 0x02, 0x03, 0x04]);
let mut buf = Vec::new();
write_i64(&mut buf, 0x0807060504030201).unwrap();
assert_eq!(buf, vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
}
#[test]
#[allow(clippy::approx_constant)]
fn test_binary_io_write_read_sequence() {
let mut buf = Vec::new();
write_i32(&mut buf, 100).unwrap();
write_i64(&mut buf, 200).unwrap();
write_f32(&mut buf, 3.14).unwrap();
write_f64(&mut buf, 2.71828).unwrap();
let mut cursor = Cursor::new(&buf);
assert_eq!(read_i32(&mut cursor).unwrap(), 100);
assert_eq!(read_i64(&mut cursor).unwrap(), 200);
assert!((read_f32(&mut cursor).unwrap() - 3.14).abs() < 1e-6);
assert!((read_f64(&mut cursor).unwrap() - 2.71828).abs() < 1e-10);
}
#[test]
fn test_binary_io_read_truncated() {
let buf = vec![0u8; 2];
let mut cursor = Cursor::new(&buf);
assert!(read_i32(&mut cursor).is_err());
let buf = vec![0u8; 4];
let mut cursor = Cursor::new(&buf);
assert!(read_i64(&mut cursor).is_err());
let buf = vec![0u8; 2];
let mut cursor = Cursor::new(&buf);
assert!(read_f32(&mut cursor).is_err());
let buf = vec![0u8; 4];
let mut cursor = Cursor::new(&buf);
assert!(read_f64(&mut cursor).is_err());
}
#[test]
fn test_binary_io_roundtrip_primitives() {
for &val in &[0i32, -1, 1, i32::MAX, i32::MIN, 42, -42, 793712314] {
assert_eq!(roundtrip_i32(val), val, "i32 roundtrip failed for {}", val);
}
for &val in &[0i64, -1, 1, i64::MAX, i64::MIN, 42, -42] {
assert_eq!(roundtrip_i64(val), val, "i64 roundtrip failed for {}", val);
}
for &val in &[0.0f32, -1.0, 1.0, f32::MAX, f32::MIN, f32::MIN_POSITIVE] {
assert_eq!(roundtrip_f32(val), val, "f32 roundtrip failed for {}", val);
}
assert_eq!(roundtrip_f32(f32::INFINITY), f32::INFINITY);
assert_eq!(roundtrip_f32(f32::NEG_INFINITY), f32::NEG_INFINITY);
assert!(roundtrip_f32(f32::NAN).is_nan());
for &val in &[0.0f64, -1.0, 1.0, f64::MAX, f64::MIN, f64::MIN_POSITIVE] {
assert_eq!(roundtrip_f64(val), val, "f64 roundtrip failed for {}", val);
}
assert_eq!(roundtrip_f64(f64::INFINITY), f64::INFINITY);
assert_eq!(roundtrip_f64(f64::NEG_INFINITY), f64::NEG_INFINITY);
assert!(roundtrip_f64(f64::NAN).is_nan());
let subnormal_f32: f32 = 1.0e-40;
assert_eq!(roundtrip_f32(subnormal_f32), subnormal_f32);
let subnormal_f64: f64 = 5.0e-324;
assert_eq!(roundtrip_f64(subnormal_f64), subnormal_f64);
}
#[test]
fn test_fnv1a_reference_vectors() {
assert_eq!(hash(b""), 2166136261, "hash('') should be FNV offset basis");
assert_eq!(hash(b"hello"), 1335831723);
assert_eq!(hash(b"a"), 3826002220);
assert_eq!(hash(b"</s>"), 3617362777);
}
#[test]
fn test_fnv1a_signed_byte_extension() {
fn hash_unsigned(s: &[u8]) -> u32 {
let mut h: u32 = 2166136261;
for &byte in s {
h ^= byte as u32; h = h.wrapping_mul(16777619);
}
h
}
let h_signed = hash(&[0x80]);
let h_unsigned = hash_unsigned(&[0x80]);
assert_ne!(
h_signed, h_unsigned,
"0x80 should differ between signed and unsigned paths"
);
assert_eq!(h_signed, 83079839, "hash([0x80]) with signed extension");
let h_signed = hash(&[0xFF]);
let h_unsigned = hash_unsigned(&[0xFF]);
assert_ne!(
h_signed, h_unsigned,
"0xFF should differ between signed and unsigned paths"
);
assert_eq!(h_signed, 4193493326, "hash([0xFF]) with signed extension");
let h_signed = hash(&[0xFE]);
let h_unsigned = hash_unsigned(&[0xFE]);
assert_ne!(
h_signed, h_unsigned,
"0xFE should differ between signed and unsigned paths"
);
assert_eq!(h_signed, 4210270945, "hash([0xFE]) with signed extension");
let h_signed = hash(&[0xC0]);
let h_unsigned = hash_unsigned(&[0xC0]);
assert_ne!(
h_signed, h_unsigned,
"0xC0 should differ between signed and unsigned paths"
);
assert_eq!(h_signed, 3304279519, "hash([0xC0]) with signed extension");
}
#[test]
fn test_fnv1a_utf8_multibyte() {
assert_eq!(hash("日本語".as_bytes()), 308035559);
assert_eq!(hash("café".as_bytes()), 1970454601);
}
#[test]
fn test_fnv1a_empty_string() {
assert_eq!(hash(b""), 2166136261);
}
#[test]
fn test_fnv1a_single_byte_values() {
fn hash_unsigned(s: &[u8]) -> u32 {
let mut h: u32 = 2166136261;
for &byte in s {
h ^= byte as u32;
h = h.wrapping_mul(16777619);
}
h
}
for b in 0u8..=127 {
assert_eq!(
hash(&[b]),
hash_unsigned(&[b]),
"byte {} (0x{:02X}) should match between signed and unsigned for values 0-127",
b,
b
);
}
for b in 128u8..=255 {
assert_ne!(
hash(&[b]),
hash_unsigned(&[b]),
"byte {} (0x{:02X}) should differ between signed and unsigned for values 128-255",
b,
b
);
}
}
#[test]
fn test_fnv1a_longer_strings() {
let h1 = hash(b"test");
let h2 = hash(b"test");
assert_eq!(h1, h2, "Same input must produce same hash");
let h3 = hash(b"tset");
assert_ne!(h1, h3, "Different inputs should produce different hashes");
let h4 = hash(b"the quick brown fox");
let h5 = hash(b"the quick brown fox");
assert_eq!(h4, h5, "Same longer input must produce same hash");
}
#[test]
fn test_fnv1a_hash_str_convenience() {
let s = "hello";
assert_eq!(hash(s.as_bytes()), 1335831723);
}
}