use std::fmt;
use alloy_primitives::{U256, keccak256};
use crate::error::CowError;
/// A refund request for a single order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderRefund {
    /// The order UID exactly as supplied to [`OrderRefund::new`]; it is stored
    /// verbatim and not validated here.
    pub order_uid: String,
    /// Which refund path applies to this order.
    pub refund_type: RefundType,
}

impl OrderRefund {
    /// Creates a refund request from an order UID and a refund type.
    #[must_use]
    pub const fn new(order_uid: String, refund_type: RefundType) -> Self {
        Self { order_uid, refund_type }
    }
}

impl fmt::Display for OrderRefund {
    /// Formats as `Refund(<order_uid>, <refund_type>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use `RefundType`'s `Display` (not `Debug`) so user-facing output stays
        // consistent with how the refund type renders on its own.
        write!(f, "Refund({}, {})", self.order_uid, self.refund_type)
    }
}
/// The kind of refund path an order goes through.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefundType {
    /// Settlement-based refund.
    Settlement,
    /// ETH-flow refund.
    EthFlow,
}

impl RefundType {
    /// Returns `true` only for [`RefundType::Settlement`].
    #[must_use]
    pub const fn is_settlement(&self) -> bool {
        match self {
            Self::Settlement => true,
            Self::EthFlow => false,
        }
    }

    /// Returns `true` only for [`RefundType::EthFlow`].
    #[must_use]
    pub const fn is_eth_flow(&self) -> bool {
        match self {
            Self::EthFlow => true,
            Self::Settlement => false,
        }
    }
}

impl fmt::Display for RefundType {
    /// Writes the variant name (`Settlement` or `EthFlow`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Settlement => "Settlement",
            Self::EthFlow => "EthFlow",
        };
        f.write_str(label)
    }
}
/// Builds calldata for `freeFilledAmountStorage(bytes)` with the decoded order UID
/// as the single `bytes` argument.
///
/// NOTE(review): the selector is derived from the literal signature
/// `freeFilledAmountStorage(bytes)`. Confirm it matches the target contract's ABI;
/// to my knowledge GPv2Settlement declares this function with a `bytes[]` parameter.
///
/// # Errors
///
/// Returns an error if `order_uid` (after an optional `0x` prefix) is not valid hex.
pub fn settlement_refund_calldata(order_uid: &str) -> Result<Vec<u8>, CowError> {
    encode_single_bytes_call("freeFilledAmountStorage(bytes)", order_uid)
}

/// Builds calldata for `invalidateOrder(bytes)` with the decoded order UID
/// as the single `bytes` argument.
///
/// NOTE(review): the selector is derived from the literal signature
/// `invalidateOrder(bytes)`. Confirm it matches the ABI of the contract this
/// calldata is sent to.
///
/// # Errors
///
/// Returns an error if `order_uid` (after an optional `0x` prefix) is not valid hex.
pub fn ethflow_refund_calldata(order_uid: &str) -> Result<Vec<u8>, CowError> {
    encode_single_bytes_call("invalidateOrder(bytes)", order_uid)
}

/// ABI-encodes a call to `signature`, whose only argument is the dynamic `bytes`
/// value obtained by hex-decoding `order_uid`.
///
/// Layout: 4-byte selector, 32-byte head word holding the offset of the `bytes`
/// payload (always 32, because it is the only argument), 32-byte length word,
/// then the raw bytes right-padded with zeros to a 32-byte boundary.
fn encode_single_bytes_call(signature: &str, order_uid: &str) -> Result<Vec<u8>, CowError> {
    let uid_bytes = decode_uid(order_uid)?;
    let mut buf = Vec::with_capacity(4 + 32 + 32 + padded32(uid_bytes.len()));
    buf.extend_from_slice(&selector(signature));
    buf.extend_from_slice(&u256_be(32));
    buf.extend_from_slice(&u256_be(uid_bytes.len() as u64));
    buf.extend_from_slice(&uid_bytes);
    pad_to(&mut buf, uid_bytes.len());
    Ok(buf)
}
/// Reports whether part of the order is still unfilled, i.e. whether
/// `total_amount` is strictly greater than `filled_amount`.
#[must_use]
pub fn is_refundable(filled_amount: U256, total_amount: U256) -> bool {
    total_amount > filled_amount
}
#[must_use]
pub const fn refund_amount(filled_amount: U256, total_amount: U256) -> U256 {
total_amount.saturating_sub(filled_amount)
}
/// Returns the 4-byte function selector: the leading four bytes of `keccak256(sig)`.
fn selector(sig: &str) -> [u8; 4] {
    let digest = keccak256(sig.as_bytes());
    let mut out = [0u8; 4];
    out.copy_from_slice(&digest[..4]);
    out
}
/// Encodes `v` as a 32-byte big-endian word: 24 zero bytes followed by the
/// eight big-endian bytes of `v`.
fn u256_be(v: u64) -> [u8; 32] {
    let mut word = [0u8; 32];
    let (_, low) = word.split_at_mut(32 - 8);
    low.copy_from_slice(&v.to_be_bytes());
    word
}
/// Appends zero bytes to `buf` so that a payload of `written` bytes is extended
/// to the next 32-byte boundary. Nothing is appended when `written` is already
/// a multiple of 32.
fn pad_to(buf: &mut Vec<u8>, written: usize) {
    // (32 - r) % 32 maps a zero remainder to zero padding.
    let shortfall = (32 - written % 32) % 32;
    buf.resize(buf.len() + shortfall, 0);
}
/// Rounds `n` up to the nearest multiple of 32; exact multiples (including 0)
/// are returned unchanged.
const fn padded32(n: usize) -> usize {
    n.div_ceil(32) * 32
}
/// Hex-decodes an order UID after removing at most one leading `0x`.
///
/// NOTE(review): `alloy_primitives::hex::decode` may itself tolerate an optional
/// `0x` prefix; confirm against the alloy/const-hex docs.
///
/// # Errors
///
/// Returns `CowError::Api` with `status: 0` and body `invalid orderUid: <uid>`
/// when the input is not valid hex.
fn decode_uid(uid: &str) -> Result<Vec<u8>, CowError> {
    // `strip_prefix` removes a single `0x`. `trim_start_matches` would remove every
    // repetition and silently accept malformed input such as `0x0x0xab…`.
    let stripped = uid.strip_prefix("0x").unwrap_or(uid);
    alloy_primitives::hex::decode(stripped)
        .map_err(|_| CowError::Api { status: 0, body: format!("invalid orderUid: {uid}") })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 56-byte order UID (`0xabab…`) as a `0x`-prefixed hex string.
    fn dummy_uid_56() -> String {
        "0x".to_owned() + &"ab".repeat(56)
    }

    /// Asserts the ABI layout of single-`bytes`-argument calldata: selector,
    /// offset word (32), length word, payload, and zero padding to 32 bytes.
    fn assert_bytes_call_layout(data: &[u8], sig: &str, payload: &[u8]) {
        let padded = padded32(payload.len());
        assert_eq!(data.len(), 4 + 32 + 32 + padded);
        assert_eq!(&data[..4], &selector(sig));
        assert_eq!(&data[4..36], &u256_be(32));
        assert_eq!(&data[36..68], &u256_be(payload.len() as u64));
        assert_eq!(&data[68..68 + payload.len()], payload);
        assert!(data[68 + payload.len()..].iter().all(|&b| b == 0));
    }

    #[test]
    fn order_refund_new() {
        let refund = OrderRefund::new("0xdead".to_owned(), RefundType::Settlement);
        assert_eq!(refund.order_uid, "0xdead");
        assert_eq!(refund.refund_type, RefundType::Settlement);
    }

    #[test]
    fn order_refund_display() {
        let refund = OrderRefund::new("0xbeef".to_owned(), RefundType::EthFlow);
        let s = format!("{refund}");
        assert!(s.contains("0xbeef"));
        assert!(s.contains("EthFlow"));
    }

    #[test]
    fn order_refund_clone_eq() {
        let a = OrderRefund::new("0xaa".to_owned(), RefundType::Settlement);
        let b = a.clone();
        assert_eq!(a, b);
    }

    #[test]
    fn refund_type_is_settlement() {
        assert!(RefundType::Settlement.is_settlement());
        assert!(!RefundType::Settlement.is_eth_flow());
    }

    #[test]
    fn refund_type_is_eth_flow() {
        assert!(RefundType::EthFlow.is_eth_flow());
        assert!(!RefundType::EthFlow.is_settlement());
    }

    #[test]
    fn refund_type_display() {
        assert_eq!(format!("{}", RefundType::Settlement), "Settlement");
        assert_eq!(format!("{}", RefundType::EthFlow), "EthFlow");
    }

    #[test]
    fn refund_type_copy() {
        let a = RefundType::Settlement;
        let b = a;
        assert_eq!(a, b);
    }

    #[test]
    fn settlement_refund_calldata_valid() {
        let uid = dummy_uid_56();
        let data = settlement_refund_calldata(&uid).unwrap();
        assert_eq!(data.len(), 132);
        assert_eq!(&data[..4], &selector("freeFilledAmountStorage(bytes)"));
    }

    #[test]
    fn settlement_refund_calldata_abi_layout() {
        let data = settlement_refund_calldata(&dummy_uid_56()).unwrap();
        assert_bytes_call_layout(&data, "freeFilledAmountStorage(bytes)", &[0xab; 56]);
    }

    #[test]
    fn settlement_refund_calldata_invalid_hex() {
        assert!(settlement_refund_calldata("0xZZZZ").is_err());
    }

    #[test]
    fn settlement_refund_calldata_without_prefix() {
        let uid = "ab".repeat(56);
        let data = settlement_refund_calldata(&uid).unwrap();
        assert_eq!(data.len(), 132);
    }

    #[test]
    fn settlement_refund_calldata_empty_uid() {
        let data = settlement_refund_calldata("0x").unwrap();
        assert_eq!(data.len(), 68);
    }

    #[test]
    fn ethflow_refund_calldata_valid() {
        let uid = dummy_uid_56();
        let data = ethflow_refund_calldata(&uid).unwrap();
        assert_eq!(data.len(), 132);
        assert_eq!(&data[..4], &selector("invalidateOrder(bytes)"));
    }

    #[test]
    fn ethflow_refund_calldata_abi_layout() {
        let uid = "0x".to_owned() + &"cd".repeat(56);
        let data = ethflow_refund_calldata(&uid).unwrap();
        assert_bytes_call_layout(&data, "invalidateOrder(bytes)", &[0xcd; 56]);
    }

    #[test]
    fn ethflow_refund_calldata_invalid_hex() {
        assert!(ethflow_refund_calldata("not_hex_gg").is_err());
    }

    #[test]
    fn ethflow_refund_calldata_without_prefix() {
        let uid = "cd".repeat(56);
        let data = ethflow_refund_calldata(&uid).unwrap();
        assert_eq!(data.len(), 132);
    }

    #[test]
    fn is_refundable_zero_filled() {
        assert!(is_refundable(U256::ZERO, U256::from(1000)));
    }

    #[test]
    fn is_refundable_partial_filled() {
        assert!(is_refundable(U256::from(500), U256::from(1000)));
    }

    #[test]
    fn is_refundable_fully_filled() {
        assert!(!is_refundable(U256::from(1000), U256::from(1000)));
    }

    #[test]
    fn is_refundable_zero_total() {
        assert!(!is_refundable(U256::ZERO, U256::ZERO));
    }

    #[test]
    fn is_refundable_matches_nonzero_refund_amount() {
        for (filled, total) in [(0u64, 0u64), (0, 5), (4, 5), (5, 5), (6, 5)] {
            let (f, t) = (U256::from(filled), U256::from(total));
            assert_eq!(is_refundable(f, t), refund_amount(f, t) > U256::ZERO);
        }
    }

    #[test]
    fn refund_amount_partial() {
        assert_eq!(refund_amount(U256::from(300), U256::from(1000)), U256::from(700));
    }

    #[test]
    fn refund_amount_fully_filled() {
        assert_eq!(refund_amount(U256::from(1000), U256::from(1000)), U256::ZERO);
    }

    #[test]
    fn refund_amount_zero_filled() {
        assert_eq!(refund_amount(U256::ZERO, U256::from(500)), U256::from(500));
    }

    #[test]
    fn refund_amount_overfilled_saturates() {
        assert_eq!(refund_amount(U256::from(2000), U256::from(1000)), U256::ZERO);
    }

    #[test]
    fn padded32_rounds_up() {
        assert_eq!(padded32(0), 0);
        assert_eq!(padded32(1), 32);
        assert_eq!(padded32(31), 32);
        assert_eq!(padded32(32), 32);
        assert_eq!(padded32(33), 64);
    }

    #[test]
    fn u256_be_is_right_aligned_big_endian() {
        let w = u256_be(0x0102);
        assert!(w[..30].iter().all(|&b| b == 0));
        assert_eq!(&w[30..], &[0x01u8, 0x02][..]);
    }

    #[test]
    fn pad_to_pads_relative_to_written() {
        let mut buf = vec![1u8; 10];
        pad_to(&mut buf, 33);
        assert_eq!(buf.len(), 41);
        let mut aligned = vec![1u8; 32];
        pad_to(&mut aligned, 32);
        assert_eq!(aligned.len(), 32);
    }
}