use ark_babyjubjub::Fq;
use ark_ff::Zero;
use sha3::{Digest, Sha3_256};
use crate::{FieldElement, PrimitiveError};
/// Number of input bytes packed into each field element before absorption.
const CHUNK_SIZE_BYTES: usize = 31;
/// Number of field elements absorbed per permutation (the sponge rate).
const RATE_ELEMENTS: usize = 15;
/// Prefix added to the byte length to form an absorb word of the SAFE IO pattern.
const IO_ABSORB_PREFIX: u32 = 1 << 31;
/// Prefix added to the byte length to form a squeeze word of the SAFE IO pattern.
const IO_SQUEEZE_PREFIX: u32 = 0;
/// Byte length recorded for the single squeeze in the IO pattern.
const IO_SQUEEZE_LEN_BYTES: u32 = 32;
/// Hashes `data` to a single field element with the Poseidon2 (t = 16, rate = 15) sponge,
/// domain-separated by `ds_tag`.
///
/// # Errors
///
/// Returns [`PrimitiveError::InvalidInput`] with attribute `"associated_data"` if `data` is
/// empty or if its length cannot be encoded in the sponge's IO pattern.
pub fn hash_bytes_to_field_element(
    ds_tag: &[u8],
    data: &[u8],
) -> Result<FieldElement, PrimitiveError> {
    // All input validation (non-empty, length range) lives in the sponge itself, which reports
    // failures under the attribute name passed here. Keeping it in one place avoids the two
    // copies drifting apart.
    hash_bytes_with_poseidon2_t16_r15(data, ds_tag, "associated_data")
}
/// Converts `data` into field elements, one per `chunk_size`-byte chunk.
///
/// Each chunk is read as a big-endian integer and reduced modulo the field order. The last
/// chunk is shorter than `chunk_size` whenever `data.len()` is not a multiple of it. Because
/// chunks are big-endian, leading zero bytes do not change a chunk's value (e.g. `[0, 1]` and
/// `[1]` map to the same element). The encoding is therefore unambiguous only when the caller
/// also commits to the input length. Chunks whose value reaches the field modulus are reduced
/// and may collide.
///
/// # Panics
///
/// Panics if `chunk_size` is zero (inherited from [`slice::chunks`]).
#[must_use]
pub fn bytes_to_field_elements(chunk_size: usize, data: &[u8]) -> Vec<Fq> {
    let chunk_to_fq = |chunk: &[u8]| -> Fq { *FieldElement::from_be_bytes_mod_order(chunk) };
    data.chunks(chunk_size).map(chunk_to_fq).collect()
}
/// Runtime tracker for a SAFE-style IO pattern.
///
/// Every absorb/squeeze must match the next expected 32-bit word, in order, and `finish`
/// requires every expected word to have been recorded.
struct IoPattern<'a> {
    /// Encoded IO words, in the order they must be recorded.
    expected: Vec<u32>,
    /// Position of the next word to match in `expected`.
    idx: usize,
    /// Attribute name reported in any error produced by this tracker.
    attr: &'a str,
}

impl<'a> IoPattern<'a> {
    /// Creates a tracker that expects exactly the words in `expected`, in order.
    // NOTE(review): the reason for suppressing `missing_const_for_fn` is not recorded here.
    #[allow(clippy::missing_const_for_fn)]
    #[must_use]
    fn new(attr: &'a str, expected: Vec<u32>) -> Self {
        Self { expected, idx: 0, attr }
    }

    /// Records an absorb of `len_bytes` bytes against the next expected word.
    fn record_absorb(&mut self, len_bytes: u32) -> Result<(), PrimitiveError> {
        let word = IO_ABSORB_PREFIX.wrapping_add(len_bytes);
        self.check(word, "absorb")
    }

    /// Records a squeeze of `len_bytes` bytes against the next expected word.
    fn record_squeeze(&mut self, len_bytes: u32) -> Result<(), PrimitiveError> {
        let word = IO_SQUEEZE_PREFIX.wrapping_add(len_bytes);
        self.check(word, "squeeze")
    }

    /// Succeeds only if every expected word has been recorded.
    fn finish(self) -> Result<(), PrimitiveError> {
        if self.idx == self.expected.len() {
            Ok(())
        } else {
            Err(self.invalid("SAFE IO pattern not fully consumed".to_string()))
        }
    }

    /// Matches `word` against the next expected word and advances on success.
    /// The position is left unchanged when the word is missing or different.
    fn check(&mut self, word: u32, label: &str) -> Result<(), PrimitiveError> {
        if self.expected.get(self.idx) != Some(&word) {
            return Err(self.invalid(format!("SAFE IO pattern violated during {label}")));
        }
        self.idx += 1;
        Ok(())
    }

    /// Builds an `InvalidInput` error tagged with this tracker's attribute name.
    fn invalid(&self, reason: String) -> PrimitiveError {
        PrimitiveError::InvalidInput {
            attribute: self.attr.to_string(),
            reason,
        }
    }
}
/// Derives the SAFE domain-separation tag for an absorb-then-squeeze IO pattern.
///
/// The tag is `SHA3-256(absorb_word || squeeze_word || domain_separator)` reduced into the
/// field. Each word is its prefix plus the given length, encoded as 4 big-endian bytes.
#[must_use]
fn derive_safe_tag(
    absorb_len_bytes: u32,
    squeeze_len_bytes: u32,
    domain_separator: &[u8],
) -> FieldElement {
    // Feeding the pieces incrementally hashes the same byte string as concatenating them first.
    let mut hasher = Sha3_256::new();
    hasher.update(IO_ABSORB_PREFIX.wrapping_add(absorb_len_bytes).to_be_bytes());
    hasher.update(IO_SQUEEZE_PREFIX.wrapping_add(squeeze_len_bytes).to_be_bytes());
    hasher.update(domain_separator);
    FieldElement::from_be_bytes_mod_order(&hasher.finalize())
}
/// Absorbs `data` into a Poseidon2 sponge (width 16, rate 15) and squeezes one field element.
///
/// The steps are:
/// - `data` is split into `CHUNK_SIZE_BYTES`-byte big-endian chunks.
/// - The chunks are absorbed `RATE_ELEMENTS` at a time into state lanes `0..RATE_ELEMENTS`.
/// - The remaining lane is seeded with a SAFE tag binding the input length, the output length
///   and `domain_separator`.
/// - The result is lane 0 after the final permutation.
///
/// # Errors
///
/// Returns [`PrimitiveError::InvalidInput`] tagged with `attr` in these cases:
/// - `data` is empty.
/// - Its length does not fit in a `u32`.
/// - Its length does not fit in the 31-bit length field of a SAFE absorb word.
/// - The internal IO-pattern bookkeeping is violated.
fn hash_bytes_with_poseidon2_t16_r15(
    data: &[u8],
    domain_separator: &[u8],
    attr: &str,
) -> Result<FieldElement, PrimitiveError> {
    if data.is_empty() {
        return Err(PrimitiveError::InvalidInput {
            attribute: attr.to_string(),
            reason: "data cannot be empty".to_string(),
        });
    }
    let data_len_u32 = u32::try_from(data.len()).map_err(|_| PrimitiveError::InvalidInput {
        attribute: attr.to_string(),
        reason: "data length exceeds supported range (u32::MAX)".to_string(),
    })?;
    // A SAFE IO word is a 1-bit absorb/squeeze flag plus a 31-bit length. A length with the
    // top bit set would overflow through the flag bit in `IO_ABSORB_PREFIX.wrapping_add(..)`,
    // yielding an "absorb" word indistinguishable from a squeeze word.
    if data_len_u32 >= IO_ABSORB_PREFIX {
        return Err(PrimitiveError::InvalidInput {
            attribute: attr.to_string(),
            reason: "data length exceeds SAFE IO word range (2^31 - 1)".to_string(),
        });
    }

    let mut io_pattern = IoPattern::new(
        attr,
        vec![
            IO_ABSORB_PREFIX.wrapping_add(data_len_u32),
            IO_SQUEEZE_PREFIX.wrapping_add(IO_SQUEEZE_LEN_BYTES),
        ],
    );
    io_pattern.record_absorb(data_len_u32)?;

    let mut state: [Fq; 16] = [Fq::zero(); 16];
    // Lanes 0..RATE_ELEMENTS are the rate. The single remaining lane (index 15) is the
    // capacity, which carries the domain-separation tag.
    let tag_fe: Fq = *derive_safe_tag(data_len_u32, IO_SQUEEZE_LEN_BYTES, domain_separator);
    state[RATE_ELEMENTS] += tag_fe;

    let field_elements = bytes_to_field_elements(CHUNK_SIZE_BYTES, data);
    for batch in field_elements.chunks(RATE_ELEMENTS) {
        // A short final batch only touches its leading lanes. The rest keep their current value.
        for (lane, &elem) in state.iter_mut().zip(batch) {
            *lane += elem;
        }
        poseidon2::bn254::t16::permutation_in_place(&mut state);
    }

    io_pattern.record_squeeze(IO_SQUEEZE_LEN_BYTES)?;
    io_pattern.finish()?;
    Ok(FieldElement::from(state[0]))
}
#[cfg(test)]
mod tests {
    use super::{
        bytes_to_field_elements, derive_safe_tag, hash_bytes_to_field_element,
        hash_bytes_with_poseidon2_t16_r15, IoPattern, CHUNK_SIZE_BYTES, IO_ABSORB_PREFIX,
        IO_SQUEEZE_LEN_BYTES, IO_SQUEEZE_PREFIX, RATE_ELEMENTS,
    };
    use crate::{FieldElement, PrimitiveError};

    const TEST_DS_TAG: &[u8] = b"TEST_DS_TAG";

    #[test]
    fn derive_tag_stable() {
        let tag = derive_safe_tag(10, IO_SQUEEZE_LEN_BYTES, b"DS");
        let again = derive_safe_tag(10, IO_SQUEEZE_LEN_BYTES, b"DS");
        assert_eq!(tag, again);
    }

    #[test]
    fn hash_bytes_rejects_empty() {
        let res = hash_bytes_with_poseidon2_t16_r15(&[], b"DS", "test");
        assert!(matches!(
            res,
            Err(PrimitiveError::InvalidInput { attribute, .. }) if attribute == "test"
        ));
    }

    #[test]
    fn hash_bytes_deterministic_nonzero() {
        let data = vec![1u8, 2, 3, 4];
        let h1 = hash_bytes_with_poseidon2_t16_r15(&data, b"DS", "test").unwrap();
        let h2 = hash_bytes_with_poseidon2_t16_r15(&data, b"DS", "test").unwrap();
        assert_eq!(h1, h2);
        assert_ne!(h1, FieldElement::ZERO);
    }

    #[test]
    fn test_hash_bytes_to_field_element_basic() {
        let data = vec![1u8, 2, 3, 4, 5];
        let result = hash_bytes_to_field_element(TEST_DS_TAG, &data);
        assert!(result.is_ok());
        let hash = result.unwrap();
        assert_ne!(hash, FieldElement::ZERO);
    }

    #[test]
    fn test_hash_bytes_to_field_element_deterministic() {
        let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let result1 = hash_bytes_to_field_element(TEST_DS_TAG, &data).unwrap();
        let result2 = hash_bytes_to_field_element(TEST_DS_TAG, &data).unwrap();
        assert_eq!(result1, result2);
        assert_ne!(result1, FieldElement::ZERO);
    }

    #[test]
    fn test_hash_bytes_to_field_element_different_inputs() {
        let data1 = vec![1u8, 2, 3, 4, 5];
        let data2 = vec![5u8, 4, 3, 2, 1];
        let data3 = vec![1u8, 2, 3, 4, 5, 6];
        let hash1 = hash_bytes_to_field_element(TEST_DS_TAG, &data1).unwrap();
        let hash2 = hash_bytes_to_field_element(TEST_DS_TAG, &data2).unwrap();
        let hash3 = hash_bytes_to_field_element(TEST_DS_TAG, &data3).unwrap();
        assert_ne!(hash1, hash2);
        assert_ne!(hash1, hash3);
        assert_ne!(hash2, hash3);
    }

    #[test]
    fn test_hash_bytes_to_field_element_domain_separation() {
        // The same data under different domain separators must hash differently.
        let data = vec![1u8, 2, 3, 4, 5];
        let hash_a = hash_bytes_to_field_element(b"DS_A", &data).unwrap();
        let hash_b = hash_bytes_to_field_element(b"DS_B", &data).unwrap();
        assert_ne!(hash_a, hash_b);
    }

    #[test]
    fn test_hash_bytes_to_field_element_empty_error() {
        let data: Vec<u8> = vec![];
        let result = hash_bytes_to_field_element(TEST_DS_TAG, &data);
        assert!(result.is_err());
        if let Err(PrimitiveError::InvalidInput { attribute, reason }) = result {
            assert_eq!(attribute, "associated_data");
            assert!(reason.contains("empty"));
        } else {
            panic!("Expected InvalidInput error");
        }
    }

    #[test]
    fn test_hash_bytes_to_field_element_large_input() {
        let data = vec![42u8; 10 * 1024];
        let result = hash_bytes_to_field_element(TEST_DS_TAG, &data);
        assert!(result.is_ok());
        let hash = result.unwrap();
        assert_ne!(hash, FieldElement::ZERO);
    }

    #[test]
    fn test_hash_bytes_to_field_element_length_domain_separation() {
        // Same zero bytes, different lengths: only the length-bound tag tells them apart.
        let data1 = vec![0u8; 10];
        let data2 = vec![0u8; 11];
        let hash1 = hash_bytes_to_field_element(TEST_DS_TAG, &data1).unwrap();
        let hash2 = hash_bytes_to_field_element(TEST_DS_TAG, &data2).unwrap();
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_hash_bytes_chunk_boundaries_and_batches() {
        let batch_bytes = RATE_ELEMENTS * CHUNK_SIZE_BYTES;
        let sizes = [
            1usize,
            CHUNK_SIZE_BYTES,
            CHUNK_SIZE_BYTES + 1,
            CHUNK_SIZE_BYTES + 2,
            batch_bytes,
            batch_bytes + 1,
            2 * batch_bytes,
        ];
        for size in sizes {
            let data = vec![42u8; size];
            let h1 = hash_bytes_to_field_element(TEST_DS_TAG, &data).unwrap();
            let h2 = hash_bytes_to_field_element(TEST_DS_TAG, &data).unwrap();
            assert_ne!(
                h1,
                FieldElement::ZERO,
                "size {size} should not hash to zero"
            );
            assert_eq!(h1, h2, "hash should be deterministic for size {size}");
        }
    }

    #[test]
    fn bytes_to_field_elements_chunking() {
        assert_eq!(bytes_to_field_elements(CHUNK_SIZE_BYTES, &[7u8; 31]).len(), 1);
        assert_eq!(bytes_to_field_elements(CHUNK_SIZE_BYTES, &[7u8; 32]).len(), 2);
        // Chunks are big-endian, so a leading zero byte does not change the value.
        assert_eq!(
            bytes_to_field_elements(CHUNK_SIZE_BYTES, &[0, 1]),
            bytes_to_field_elements(CHUNK_SIZE_BYTES, &[1])
        );
    }

    #[test]
    fn io_pattern_enforces_order_and_completion() {
        let expected = vec![IO_ABSORB_PREFIX + 4, IO_SQUEEZE_PREFIX + IO_SQUEEZE_LEN_BYTES];

        let mut io = IoPattern::new("io", expected.clone());
        // Out-of-order call is rejected and does not advance the pattern.
        assert!(io.record_squeeze(IO_SQUEEZE_LEN_BYTES).is_err());
        io.record_absorb(4).unwrap();
        io.record_squeeze(IO_SQUEEZE_LEN_BYTES).unwrap();
        assert!(io.finish().is_ok());

        // A partially consumed pattern is rejected on finish.
        let mut partial = IoPattern::new("io", expected);
        partial.record_absorb(4).unwrap();
        assert!(matches!(
            partial.finish(),
            Err(PrimitiveError::InvalidInput { attribute, .. }) if attribute == "io"
        ));

        // Recording beyond the end of the pattern is rejected.
        let mut empty = IoPattern::new("io", Vec::new());
        assert!(empty.record_absorb(0).is_err());
    }
}