mx-core 0.1.0

Core utilities for MultiversX Rust services.
Documentation
//! Shard computation utilities.
//!
//! This module determines which shard an address belongs to using the same
//! algorithm as mx-chain-go's multiShardCoordinator.

use crate::constants::METACHAIN_SHARD_ID;
use crate::error::CoreError;

/// Resolves the shard that owns a bech32-encoded address.
///
/// The bech32 human-readable prefix is discarded; only the decoded payload is used,
/// and it is handed to [`shard_of_address_bytes`], which follows the Go reference
/// mx-chain-go/sharding/multiShardCoordinator.go.
///
/// # Arguments
/// * `address` - A bech32-encoded `MultiversX` address
/// * `num_shards` - The total number of shards in the network
///
/// # Returns
/// The shard ID (0 to num_shards-1) or `METACHAIN_SHARD_ID` for metachain
///
/// # Errors
/// * [`CoreError::InvalidBech32`] if `address` is not a valid bech32 string
/// * [`CoreError::EmptyAddress`] if the decoded payload is empty
pub fn shard_of(address: &str, num_shards: u32) -> Result<u32, CoreError> {
    let payload = match bech32::decode(address) {
        Ok((_, bytes)) => bytes,
        Err(err) => return Err(CoreError::InvalidBech32(err.to_string())),
    };
    shard_of_address_bytes(&payload, num_shards)
}

/// Determines the shard ID from raw address bytes (no bech32 decoding).
///
/// Follows the mask-based algorithm of mx-chain-go/sharding/multiShardCoordinator.go:
/// the two masks are derived from `num_shards` by `calculate_masks`, then applied to the
/// trailing bytes of `address` by `compute_shard_id`.
///
/// # Arguments
/// * `address` - Raw address bytes (normally a 32-byte public key)
/// * `num_shards` - The total number of shards in the network (should be at least 1)
///
/// # Returns
/// The shard ID (0 to num_shards-1) or `METACHAIN_SHARD_ID` for metachain
///
/// # Errors
/// Returns [`CoreError::EmptyAddress`] when `address` is empty.
pub fn shard_of_address_bytes(address: &[u8], num_shards: u32) -> Result<u32, CoreError> {
    match address {
        [] => Err(CoreError::EmptyAddress),
        bytes => {
            let (high, low) = calculate_masks(num_shards);
            Ok(compute_shard_id(bytes, num_shards, high, low))
        }
    }
}

/// Calculates the two bit masks used by the shard-assignment algorithm.
///
/// With `n = ceil(log2(num_shards))`, returns (`mask_high`, `mask_low`) where `mask_high`
/// has the low `n` bits set and `mask_low` the low `n - 1` bits set. Both are `0` when
/// `n == 0` (one shard, and also the degenerate `num_shards == 0`).
///
/// `n` is computed with exact integer arithmetic rather than `f64::log2`. When `n == 32`
/// (more than 2^31 shards), `mask_high` saturates to `u32::MAX` instead of overflowing
/// the shift.
fn calculate_masks(num_shards: u32) -> (u32, u32) {
    /// Returns a `u32` with the low `bits` bits set (`u32::MAX` for 32 or more bits).
    fn low_bits(bits: u32) -> u32 {
        1u32.checked_shl(bits).map_or(u32::MAX, |v| v - 1)
    }

    // For x >= 1, ceil(log2(x)) is the bit length of (x - 1). `saturating_sub` keeps
    // x == 0 at n == 0 instead of underflowing.
    let n = u32::BITS - num_shards.saturating_sub(1).leading_zeros();
    let mask_high = low_bits(n);
    let mask_low = n.checked_sub(1).map_or(0, low_bits);
    (mask_high, mask_low)
}

/// Computes the shard ID of `address` using the two-stage masking algorithm.
///
/// How the result is derived:
/// 1. Only the trailing bytes that can influence the result are inspected: 1 byte for up
///    to 256 shards, 2 for up to 65 536, 3 for up to 16 777 216, otherwise 4. Shorter
///    addresses are used whole.
/// 2. Those bytes are read as a big-endian integer `addr`, and the shard is
///    `addr & mask_high`.
/// 3. When that value is not a valid shard index, the shard is `addr & mask_low` instead.
///
/// Returns `METACHAIN_SHARD_ID` when `is_smart_contract_on_metachain` reports the address
/// as a metachain smart contract.
///
/// The validity check is written as `shard >= num_shards` (rather than
/// `shard > num_shards - 1`), so a `num_shards` of 0 cannot underflow.
fn compute_shard_id(address: &[u8], num_shards: u32, mask_high: u32, mask_low: u32) -> u32 {
    // Number of trailing bytes needed to represent any shard index.
    let bytes_needed: usize = match num_shards {
        0..=256 => 1,
        257..=65_536 => 2,
        65_537..=16_777_216 => 3,
        _ => 4,
    };

    // Addresses no longer than `bytes_needed` are used in full.
    let starting_index = address.len().saturating_sub(bytes_needed);
    let buff_needed = &address[starting_index..];

    if is_smart_contract_on_metachain(buff_needed, address) {
        return METACHAIN_SHARD_ID;
    }

    // Big-endian; at most 4 bytes are folded, so the value always fits in a u32.
    let addr = buff_needed
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));

    let shard = addr & mask_high;
    if shard >= num_shards {
        addr & mask_low
    } else {
        shard
    }
}

/// Check if address is a smart contract on metachain.
/// Mirrors Go: `core.IsSmartContractOnMetachain(buffNeeded, address)` (mx-chain-core-go).
///
/// An address is a metachain smart contract when all of the following hold:
/// - it is longer than 25 bytes (10-byte SC prefix + 15-byte metachain marker);
/// - its first 8 bytes are zero (smart-contract prefix; the next 2 bytes are the VM type);
/// - bytes 10..25 are all zero;
/// - the last byte of `buff_needed` (the shard-identifier byte) is `0xFF`.
///
/// User-deployed contracts carry deployer-derived, generally non-zero bytes in 10..25.
/// Requiring that range to be zero keeps them in their regular shard, even when their
/// trailing bytes happen to contain `0xFF`.
fn is_smart_contract_on_metachain(buff_needed: &[u8], address: &[u8]) -> bool {
    // Leading zero bytes that mark an address as a smart contract.
    const SC_ZERO_PREFIX_LEN: usize = 8;
    // Full SC prefix length: the zero bytes plus the 2-byte VM type.
    const SC_PREFIX_LEN: usize = 10;
    // Zero bytes after the SC prefix that mark a metachain contract.
    const METACHAIN_MARKER_LEN: usize = 15;
    // Shard-identifier byte reserved for the metachain.
    const METACHAIN_SHARD_BYTE: u8 = 0xFF;

    if address.len() <= SC_PREFIX_LEN + METACHAIN_MARKER_LEN {
        return false;
    }

    let all_zero = |bytes: &[u8]| bytes.iter().all(|&b| b == 0);
    all_zero(&address[..SC_ZERO_PREFIX_LEN])
        && all_zero(&address[SC_PREFIX_LEN..SC_PREFIX_LEN + METACHAIN_MARKER_LEN])
        && buff_needed.last() == Some(&METACHAIN_SHARD_BYTE)
}

/// Computes the shard ID from the last byte of an address.
///
/// For the common 3-shard configuration this applies the two-stage mask used by
/// [`shard_of_address_bytes`] (`mask_high = 0b11`, `mask_low = 0b01`). The low two bits
/// pick the shard, and the out-of-range value `0b11` falls back to the lowest bit
/// (shard 1). A plain `& 0x03` would yield shard `3`, which does not exist in a 3-shard
/// network.
///
/// Other shard counts use `last_byte % num_shards`. This can differ from the mask-based
/// algorithm when `num_shards` is not a power of two.
///
/// # Arguments
/// * `last_byte` - The last byte of the address
/// * `num_shards` - The total number of shards in the network
///
/// # Returns
/// The shard ID (0 to num_shards-1)
///
/// # Panics
/// Panics if `num_shards` is 0 (remainder by zero).
#[inline]
pub fn select_shard(last_byte: u8, num_shards: u32) -> u32 {
    if num_shards == 3 {
        let shard = last_byte & 0b11;
        // 0b11 is not a valid index for 3 shards; retry with the narrower mask.
        u32::from(if shard > 2 { last_byte & 0b01 } else { shard })
    } else {
        u32::from(last_byte) % num_shards
    }
}

/// Determines the shard ID from raw address bytes (simpler version).
///
/// This is a cheaper alternative to [`shard_of_address_bytes`]. It looks only at the last
/// byte: `0xFF` is treated as the metachain marker, and everything else is delegated to
/// [`select_shard`].
///
/// [`shard_of_address_bytes`] only reports metachain for smart-contract addresses (first 8
/// bytes zero). By contrast, this heuristic treats *any* address ending in `0xFF` as
/// metachain.
///
/// NOTE(review): the metachain value returned here is the literal `0xFF`, whereas
/// [`shard_of_address_bytes`] returns `METACHAIN_SHARD_ID`; confirm callers expect that.
///
/// # Arguments
/// * `raw` - Raw address bytes
/// * `num_shards` - The total number of shards in the network
///
/// # Returns
/// - `Some(0xFF)` for metachain addresses (last byte == 0xFF)
/// - `Some(shard_id)` for regular shards
/// - `None` if the address is empty
///
/// # Panics
/// Panics if `num_shards` is 0 and the last byte is not `0xFF` (see [`select_shard`]).
#[inline]
pub fn shard_of_bytes(raw: &[u8], num_shards: u32) -> Option<u32> {
    let last = *raw.last()?;
    if last == 0xFF {
        Some(0xFF) // metachain
    } else {
        Some(select_shard(last, num_shards))
    }
}

/// Extracts the embedded receiver address from ESDT/NFT transfer data payloads.
///
/// Handles three built-in transfer functions (function name matched ASCII
/// case-insensitively):
/// - `ESDTNFTTransfer@token@nonce@qty@dest` → dest at position 3
/// - `MultiESDTNFTTransfer@dest@numTokens@...` → dest at position 0
/// - `ESDTTransfer@token@amount@dest` → dest at position 2
///
/// The payload is first parsed as UTF-8 text. If that fails, or yields no receiver, it is
/// base64-decoded and parsed again.
///
/// NOTE(review): in the MultiversX built-in call format, the argument after `amount` in
/// `ESDTTransfer` is usually the name of a smart-contract function to invoke (the receiver
/// being the transaction's own receiver). This branch may therefore return the bytes of
/// a function name rather than an address — confirm with callers.
///
/// # Arguments
/// * `data` - The transaction data payload bytes
///
/// # Returns
/// The decoded receiver address bytes if found, None otherwise
pub fn decode_embedded_receiver(data: &[u8]) -> Option<Vec<u8>> {
    /// Parses `Function@arg0@arg1@...` and hex-decodes the argument holding the destination.
    fn parse(txt: &str) -> Option<Vec<u8>> {
        let mut parts = txt.split('@');
        let func = parts.next()?;
        // Position of the destination among the remaining arguments. The function name is
        // compared case-insensitively without allocating a lowercased copy.
        let dest_index = if func.eq_ignore_ascii_case("ESDTNFTTransfer") {
            3 // ESDTNFTTransfer@token@nonce@qty@dest
        } else if func.eq_ignore_ascii_case("MultiESDTNFTTransfer") {
            0 // MultiESDTNFTTransfer@dest@numTokens@...
        } else if func.eq_ignore_ascii_case("ESDTTransfer") {
            2 // ESDTTransfer@token@amount@dest
        } else {
            return None;
        };
        parts.nth(dest_index).and_then(|h| hex::decode(h).ok())
    }

    // First try raw bytes as UTF-8 (on-chain data is plain ASCII for built-ins).
    if let Ok(txt) = std::str::from_utf8(data)
        && let Some(res) = parse(txt)
    {
        return Some(res);
    }

    // Some gateways expose base64 for data; fall back to decoding then parsing.
    if let Ok(decoded) = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, data)
        && let Ok(txt) = std::str::from_utf8(&decoded)
    {
        return parse(txt);
    }
    None
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_calculation() {
        // Known addresses and their shards for 3-shard network
        let cases = [(
            "erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th",
            1,
        )];

        for (addr, expected_shard) in cases {
            let shard = shard_of(addr, 3).unwrap();
            assert_eq!(
                shard, expected_shard,
                "address {addr} expected shard {expected_shard}"
            );
        }
    }

    #[test]
    fn test_mask_calculation() {
        // For 3 shards: ceil(log2(3)) = 2, so mask_high = 0b11, mask_low = 0b01
        let (mask_high, mask_low) = calculate_masks(3);
        assert_eq!(mask_high, 0b11);
        assert_eq!(mask_low, 0b01);

        // For 4 shards: ceil(log2(4)) = 2, so mask_high = 0b11, mask_low = 0b01
        let (mask_high, mask_low) = calculate_masks(4);
        assert_eq!(mask_high, 0b11);
        assert_eq!(mask_low, 0b01);
    }

    #[test]
    fn test_empty_address() {
        let result = shard_of_address_bytes(&[], 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_shard() {
        // 3 shards: low two bits select the shard; 0b11 falls back to the lowest bit
        assert_eq!(select_shard(0b0000_0000, 3), 0); // last 2 bits: 00 -> shard 0
        assert_eq!(select_shard(0b0000_0001, 3), 1); // last 2 bits: 01 -> shard 1
        assert_eq!(select_shard(0b0000_0010, 3), 2); // last 2 bits: 10 -> shard 2
        assert_eq!(select_shard(0b0000_0011, 3), 1); // last 2 bits: 11 -> & 0b01 -> shard 1
        assert_eq!(select_shard(0b1111_1100, 3), 0); // last 2 bits: 00 -> shard 0
        assert_eq!(select_shard(0b1111_1111, 3), 1); // last 2 bits: 11 -> shard 1

        // Every byte maps to a valid shard and agrees with the mask-based path
        for byte in 0..=u8::MAX {
            let shard = select_shard(byte, 3);
            assert!(shard < 3, "byte {byte:#04x} mapped to invalid shard {shard}");
            assert_eq!(shard, shard_of_address_bytes(&[byte], 3).unwrap());
        }

        // Test other shard counts (modulo)
        assert_eq!(select_shard(5, 4), 1); // 5 % 4 = 1
        assert_eq!(select_shard(10, 7), 3); // 10 % 7 = 3
    }

    #[test]
    fn test_shard_of_bytes() {
        // Regular shard addresses (3 shards)
        assert_eq!(shard_of_bytes(&[0x00], 3), Some(0));
        assert_eq!(shard_of_bytes(&[0x01], 3), Some(1));
        assert_eq!(shard_of_bytes(&[0x02], 3), Some(2));
        assert_eq!(shard_of_bytes(&[0x03], 3), Some(1));

        // Metachain address
        assert_eq!(shard_of_bytes(&[0xFF], 3), Some(0xFF));

        // Empty address
        assert_eq!(shard_of_bytes(&[], 3), None);

        // Multi-byte address (uses last byte)
        assert_eq!(shard_of_bytes(&[0xAA, 0xBB, 0x02], 3), Some(2));
    }

    #[test]
    fn test_decode_embedded_receiver_esdt_nft_transfer() {
        // ESDTNFTTransfer@token@nonce@qty@dest
        let data = b"ESDTNFTTransfer@544f4b454e@01@0a@aabbccdd";
        let result = decode_embedded_receiver(data);
        assert_eq!(result, Some(vec![0xaa, 0xbb, 0xcc, 0xdd]));
    }

    #[test]
    fn test_decode_embedded_receiver_multi_esdt_nft_transfer() {
        // MultiESDTNFTTransfer@dest@numTokens@...
        let data = b"MultiESDTNFTTransfer@aabbccdd@02";
        let result = decode_embedded_receiver(data);
        assert_eq!(result, Some(vec![0xaa, 0xbb, 0xcc, 0xdd]));
    }

    #[test]
    fn test_decode_embedded_receiver_esdt_transfer() {
        // ESDTTransfer@token@amount@dest
        let data = b"ESDTTransfer@544f4b454e@0a@aabbccdd";
        let result = decode_embedded_receiver(data);
        assert_eq!(result, Some(vec![0xaa, 0xbb, 0xcc, 0xdd]));
    }

    #[test]
    fn test_decode_embedded_receiver_case_insensitive() {
        // Test lowercase and uppercase variations
        let data1 = b"esdttransfer@544f4b454e@0a@aabbccdd";
        let data2 = b"ESDTTRANSFER@544f4b454e@0a@aabbccdd";
        assert_eq!(
            decode_embedded_receiver(data1),
            decode_embedded_receiver(data2)
        );
        assert_eq!(
            decode_embedded_receiver(data1),
            Some(vec![0xaa, 0xbb, 0xcc, 0xdd])
        );
    }

    #[test]
    fn test_decode_embedded_receiver_base64() {
        // Base64 encoded: "ESDTTransfer@544f4b454e@0a@aabbccdd"
        let data = b"RVNEVFRyYW5zZmVyQDU0NGY0YjQ1NGVAMGFAYWFiYmNjZGQ=";
        let result = decode_embedded_receiver(data);
        assert_eq!(result, Some(vec![0xaa, 0xbb, 0xcc, 0xdd]));
    }

    #[test]
    fn test_decode_embedded_receiver_invalid() {
        // Invalid function name
        assert_eq!(decode_embedded_receiver(b"SomeOtherFunction@arg"), None);

        // Not enough arguments
        assert_eq!(decode_embedded_receiver(b"ESDTTransfer@token"), None);

        // Invalid hex
        assert_eq!(
            decode_embedded_receiver(b"ESDTTransfer@token@amount@zzzz"),
            None
        );

        // Empty data
        assert_eq!(decode_embedded_receiver(b""), None);
    }
}