use alloc::vec::Vec;
use crypto_bigint::U256;
/// DER-encodes a signature pair as `SEQUENCE { INTEGER r, INTEGER s }`.
///
/// Each component is written as a minimal, non-negative DER `INTEGER`
/// (leading zero octets stripped, a `0x00` pad added when the high bit of
/// the first remaining octet is set).
///
/// # Panics
///
/// Only if the sequence body reached 65 536 bytes, which cannot happen here:
/// each 256-bit component encodes to at most 35 bytes.
#[must_use]
pub fn encode_sig(r: &U256, s: &U256) -> Vec<u8> {
    let components = [
        encode_integer(&r.to_be_bytes()),
        encode_integer(&s.to_be_bytes()),
    ];
    let content_len: usize = components.iter().map(Vec::len).sum();

    // Tag + length header is a handful of bytes; reserve a little slack for it.
    let mut der = Vec::with_capacity(content_len + 8);
    der.push(0x30);
    push_length(&mut der, content_len);
    for part in &components {
        der.extend_from_slice(part);
    }
    der
}
/// Parses a DER `SEQUENCE { INTEGER r, INTEGER s }` signature.
///
/// The whole of `input` must be consumed. The outer length must exactly match
/// the remaining bytes, and nothing may follow the second `INTEGER`.
/// Returns `None` on any structural or encoding error. Integer-level rules are
/// enforced by `read_integer`.
#[must_use]
pub fn decode_sig(input: &[u8]) -> Option<(U256, U256)> {
    let body = match input {
        [0x30, after_tag @ ..] => {
            let (declared_len, content) = read_length(after_tag)?;
            if content.len() != declared_len {
                return None;
            }
            content
        }
        _ => return None,
    };

    let (r, after_r) = read_integer(body)?;
    let (s, leftover) = read_integer(after_r)?;

    // Trailing bytes inside the SEQUENCE are not permitted.
    if leftover.is_empty() {
        Some((r, s))
    } else {
        None
    }
}
/// Encodes a big-endian unsigned magnitude as a DER `INTEGER` TLV (tag `0x02`).
///
/// Leading zero octets are stripped to reach the minimal form. A single `0x00`
/// octet is prepended when the first remaining octet has its high bit set, so
/// the two's-complement reading stays non-negative. An all-zero or empty input
/// encodes as `02 01 00`, the canonical DER form of zero. The previous
/// `len() - 1` loop bound underflowed and panicked on an empty slice.
fn encode_integer(value_be: &[u8]) -> Vec<u8> {
    // Keep at least one content octet: zero is encoded as a single 0x00.
    let trimmed: &[u8] = match value_be.iter().position(|&b| b != 0) {
        Some(first_nonzero) => &value_be[first_nonzero..],
        None => &[0x00],
    };

    // DER INTEGER is two's complement: a set high bit would read as negative.
    let needs_pad = (trimmed[0] & 0x80) != 0;
    let int_len = trimmed.len() + usize::from(needs_pad);

    let mut out = Vec::with_capacity(int_len + 4);
    out.push(0x02);
    push_length(&mut out, int_len);
    if needs_pad {
        out.push(0x00);
    }
    out.extend_from_slice(trimmed);
    out
}
/// Parses one DER `INTEGER` TLV from the front of `input`.
///
/// Accepts only non-negative values in minimal two's-complement form that fit
/// in 256 bits. Returns the value together with the bytes that follow the TLV.
///
/// Returns `None` for any of:
/// - a tag other than `0x02`;
/// - a bad or truncated length;
/// - empty content;
/// - a negative encoding (first content octet has its high bit set);
/// - a `0x00` lead octet that does not shield a high bit;
/// - a magnitude longer than 32 bytes.
///
/// NOTE(review): a lone `0x00` content octet (`02 01 00`, the canonical DER
/// encoding of zero) is rejected here. `encode_integer` emits exactly that for
/// a zero value, so a zero component does not round-trip through
/// `encode_sig`/`decode_sig`. This is presumably intentional if these are ECDSA
/// scalars, where zero is never valid — confirm.
fn read_integer(input: &[u8]) -> Option<(U256, &[u8])> {
    let after_tag = match input {
        [0x02, tail @ ..] => tail,
        _ => return None,
    };
    let (content_len, after_len) = read_length(after_tag)?;
    let content = after_len.get(..content_len)?;
    let remainder = &after_len[content_len..];

    let magnitude: &[u8] = match content {
        // Zero-length INTEGER content is malformed.
        [] => return None,
        // A high bit on the lead octet means a negative value.
        [lead, ..] if *lead & 0x80 != 0 => return None,
        // A 0x00 lead is only allowed when it shields a high bit in the next octet.
        [0x00, tail @ ..] => match tail {
            [next, ..] if *next & 0x80 != 0 => tail,
            _ => return None,
        },
        _ => content,
    };

    // Right-align the magnitude in a 32-byte buffer; more than 32 bytes cannot fit.
    let mut padded = [0u8; 32];
    let offset = padded.len().checked_sub(magnitude.len())?;
    padded[offset..].copy_from_slice(magnitude);
    Some((U256::from_be_slice(&padded), remainder))
}
/// Appends a DER definite-length header for `len` to `out`.
///
/// Uses the short form for `0..=127`, `0x81 NN` for `128..=255`, and
/// `0x82 HH LL` for `256..=65_535`.
///
/// # Panics
///
/// Panics if `len` is 65 536 or more.
fn push_length(out: &mut Vec<u8>, len: usize) {
    // The two least-significant octets, extracted without truncating casts.
    let [.., hi, lo] = len.to_be_bytes();
    match len {
        0..=0x7f => out.push(lo),
        0x80..=0xff => out.extend_from_slice(&[0x81, lo]),
        0x100..=0xffff => out.extend_from_slice(&[0x82, hi, lo]),
        _ => panic!("signature DER length overflow"),
    }
}
/// Parses a DER definite length from the front of `input`.
///
/// Returns the length and the bytes after the length octets. Accepts the
/// short form (`0x00..=0x7f`) and the long forms `0x81 NN` and `0x82 HH LL`,
/// but only when minimal: `0x81` must carry at least 128 and `0x82` at least
/// 256. The indefinite form `0x80`, longer forms, and truncated input all
/// yield `None`.
fn read_length(input: &[u8]) -> Option<(usize, &[u8])> {
    match input {
        // Short form: the octet itself is the length.
        [short, tail @ ..] if *short < 0x80 => Some((usize::from(*short), tail)),
        // One-octet long form; smaller values must use the short form.
        [0x81, octet, tail @ ..] => {
            if *octet < 0x80 {
                None
            } else {
                Some((usize::from(*octet), tail))
            }
        }
        // Two-octet long form; smaller values must use a shorter form.
        [0x82, hi, lo, tail @ ..] => {
            let len = usize::from(u16::from_be_bytes([*hi, *lo]));
            if len < 256 {
                None
            } else {
                Some((len, tail))
            }
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trip_small() {
        let r = U256::from_u64(0x1234);
        let s = U256::from_u64(0x5678);
        let der = encode_sig(&r, &s);
        let (r2, s2) = decode_sig(&der).expect("round-trip");
        assert_eq!(r2, r);
        assert_eq!(s2, s);
    }

    #[test]
    fn round_trip_large_with_high_bit() {
        let r =
            U256::from_be_hex("FF00000000000000000000000000000000000000000000000000000000000001");
        let s =
            U256::from_be_hex("8000000000000000000000000000000000000000000000000000000000000002");
        let der = encode_sig(&r, &s);
        let (r2, s2) = decode_sig(&der).expect("round-trip");
        assert_eq!(r2, r);
        assert_eq!(s2, s);
    }

    #[test]
    fn encodes_small_values_exactly() {
        let der = encode_sig(&U256::from_u64(0x1234), &U256::from_u64(0x5678));
        assert_eq!(
            der,
            [0x30_u8, 0x08, 0x02, 0x02, 0x12, 0x34, 0x02, 0x02, 0x56, 0x78]
        );
    }

    #[test]
    fn pads_integer_with_high_bit_set() {
        let r =
            U256::from_be_hex("8000000000000000000000000000000000000000000000000000000000000002");
        let der = encode_sig(&r, &U256::from_u64(1));
        // 2-byte SEQUENCE header + 35-byte padded r + 3-byte s.
        assert_eq!(der.len(), 40);
        assert_eq!(&der[..6], &[0x30_u8, 0x26, 0x02, 0x21, 0x00, 0x80]);
        assert_eq!(&der[37..], &[0x02_u8, 0x01, 0x01]);
    }

    #[test]
    fn malformed_returns_none() {
        assert!(decode_sig(&[]).is_none());
        assert!(decode_sig(&[0x30]).is_none());
        assert!(decode_sig(&[0x31, 0x00]).is_none());
        assert!(decode_sig(&[0x30, 0x05, 0x02, 0x01, 0x01]).is_none());
    }

    #[test]
    fn accepts_minimal_valid_signature() {
        let ok = [0x30, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01];
        let (r, s) = decode_sig(&ok).expect("valid signature");
        assert_eq!(r, U256::from_u64(1));
        assert_eq!(s, U256::from_u64(1));
    }

    #[test]
    fn rejects_trailing_bytes() {
        let outer = [0x30, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x00];
        assert!(
            decode_sig(&outer).is_none(),
            "bytes after the SEQUENCE must be rejected"
        );
        let inner = [0x30, 0x07, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x00];
        assert!(
            decode_sig(&inner).is_none(),
            "bytes after the second INTEGER must be rejected"
        );
    }

    #[test]
    fn rejects_non_minimal_and_indefinite_length() {
        let long_form_small = [0x30, 0x81, 0x06, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01];
        assert!(
            decode_sig(&long_form_small).is_none(),
            "0x81 form for a length < 128 must be rejected"
        );
        let indefinite = [0x30, 0x80, 0x02, 0x01, 0x01, 0x02, 0x01, 0x01, 0x00, 0x00];
        assert!(
            decode_sig(&indefinite).is_none(),
            "indefinite length must be rejected"
        );
    }

    #[test]
    fn rejects_integer_wider_than_256_bits() {
        // r is 33 content bytes with a non-zero lead, so it cannot fit in a U256.
        let mut bad = [0u8; 40];
        bad[..5].copy_from_slice(&[0x30, 0x26, 0x02, 0x21, 0x01]);
        bad[37..].copy_from_slice(&[0x02, 0x01, 0x01]);
        assert!(
            decode_sig(&bad).is_none(),
            "INTEGER magnitude over 32 bytes must be rejected"
        );
    }

    #[test]
    fn rejects_non_canonical_leading_zero() {
        let bad = [0x30, 0x07, 0x02, 0x02, 0x00, 0x01, 0x02, 0x01, 0x01];
        assert!(
            decode_sig(&bad).is_none(),
            "non-canonical 00-pad on small int must be rejected"
        );
    }

    #[test]
    fn rejects_negative_integer_encoding() {
        let bad = [0x30, 0x06, 0x02, 0x01, 0x80, 0x02, 0x01, 0x01];
        assert!(
            decode_sig(&bad).is_none(),
            "high-bit-set first byte without 00 pad must be rejected"
        );
    }

    #[test]
    fn rejects_empty_integer() {
        let bad = [0x30, 0x05, 0x02, 0x00, 0x02, 0x01, 0x01];
        assert!(
            decode_sig(&bad).is_none(),
            "empty INTEGER content must be rejected"
        );
    }
}