Skip to main content

mx_core/
shard.rs

1//! Shard computation utilities.
2//!
3//! This module determines which shard an address belongs to using the same
4//! algorithm as mx-chain-go's multiShardCoordinator.
5
6use crate::constants::METACHAIN_SHARD_ID;
7use crate::error::CoreError;
8
9/// Determines the shard ID for a given bech32 address.
10///
11/// This implementation matches the Go reference: mx-chain-go/sharding/multiShardCoordinator.go
12///
13/// # Arguments
14/// * `address` - A bech32-encoded `MultiversX` address
15/// * `num_shards` - The total number of shards in the network
16///
17/// # Returns
18/// The shard ID (0 to num_shards-1) or `METACHAIN_SHARD_ID` for metachain
19///
20/// # Errors
21/// Returns an error if the address is not a valid bech32 string
22pub fn shard_of(address: &str, num_shards: u32) -> Result<u32, CoreError> {
23    let (_hrp, raw) =
24        bech32::decode(address).map_err(|e| CoreError::InvalidBech32(e.to_string()))?;
25    shard_of_address_bytes(&raw, num_shards)
26}
27
28/// Determines shard ID from raw public key bytes (no bech32 decoding).
29///
30/// This implementation matches the Go reference: mx-chain-go/sharding/multiShardCoordinator.go
31/// Uses the same mask-based algorithm as `ComputeShardID`.
32///
33/// # Arguments
34/// * `address` - Raw 32-byte public key
35/// * `num_shards` - The total number of shards in the network
36///
37/// # Returns
38/// The shard ID (0 to num_shards-1) or `METACHAIN_SHARD_ID` for metachain
39pub fn shard_of_address_bytes(address: &[u8], num_shards: u32) -> Result<u32, CoreError> {
40    if address.is_empty() {
41        return Err(CoreError::EmptyAddress);
42    }
43
44    // Calculate masks (same as Go's calculateMasks)
45    let (mask_high, mask_low) = calculate_masks(num_shards);
46
47    // Compute shard ID (same as Go's computeIdBasedOfNrOfShardAndMasks)
48    Ok(compute_shard_id(address, num_shards, mask_high, mask_low))
49}
50
/// Calculates the two bit masks used by the shard computation.
///
/// With `n = ceil(log2(num_shards))`, the result is
/// `(mask_high, mask_low) = (2^n - 1, 2^(n-1) - 1)`; `mask_low` is 0 when
/// `n` is 0. `num_shards` values of 0 and 1 both yield `(0, 0)`.
///
/// The computation is done with integer bit operations so that shard counts
/// above `2^31` (where `n` reaches 32) produce an all-ones `mask_high`
/// instead of overflowing the shift.
fn calculate_masks(num_shards: u32) -> (u32, u32) {
    // ceil(log2(num_shards)) equals the number of bits needed to represent
    // `num_shards - 1`; 0 and 1 need no bits at all.
    let bits = match num_shards {
        0 | 1 => 0,
        n => u32::BITS - (n - 1).leading_zeros(),
    };

    // Mask with the lowest `count` bits set; a full-width request (count == 32)
    // cannot be expressed as `(1 << count) - 1`, so it maps to all ones.
    let low_bits = |count: u32| 1u32.checked_shl(count).map_or(u32::MAX, |bit| bit - 1);

    let mask_high = low_bits(bits);
    let mask_low = low_bits(bits.saturating_sub(1));
    (mask_high, mask_low)
}
59
60/// Computes shard ID from address bytes using two-stage masking algorithm.
61/// Returns `METACHAIN_SHARD_ID` for smart contracts on metachain, otherwise computed shard ID.
62fn compute_shard_id(address: &[u8], num_shards: u32, mask_high: u32, mask_low: u32) -> u32 {
63    // Determine how many bytes we need based on number of shards
64    let bytes_needed = if num_shards <= 256 {
65        1
66    } else if num_shards <= 65536 {
67        2
68    } else if num_shards <= 16_777_216 {
69        3
70    } else {
71        4
72    };
73
74    // Get the trailing bytes needed for shard calculation
75    let starting_index = if address.len() > bytes_needed {
76        address.len() - bytes_needed
77    } else {
78        0
79    };
80
81    let buff_needed = &address[starting_index..];
82
83    // Check for smart contract on metachain
84    if is_smart_contract_on_metachain(buff_needed, address) {
85        return METACHAIN_SHARD_ID;
86    }
87
88    // Convert bytes to u32 (big-endian)
89    let mut addr: u32 = 0;
90    for &byte in buff_needed {
91        addr = (addr << 8) + u32::from(byte);
92    }
93
94    // Apply masks to get shard ID
95    let mut shard = addr & mask_high;
96    if shard > num_shards - 1 {
97        shard = addr & mask_low;
98    }
99
100    shard
101}
102
/// Reports whether `address` is treated as a smart contract living on the metachain.
///
/// Intended counterpart of Go's `core.IsSmartContractOnMetachain(buffNeeded, address)`.
/// NOTE(review): the first parameter is accepted for signature parity only and
/// is not consulted here — confirm the check below against the Go reference.
///
/// The answer is `true` only when both conditions hold:
/// - the address has at least 8 bytes and its first 8 bytes are all zero
///   (the smart-contract marker), and
/// - the second-to-last byte of the address (byte 30 of a 32-byte address)
///   equals `0xFF`.
fn is_smart_contract_on_metachain(_buff_needed: &[u8], address: &[u8]) -> bool {
    const SC_PREFIX_LEN: usize = 8;

    // Addresses shorter than the marker cannot be smart contracts.
    let Some(prefix) = address.get(..SC_PREFIX_LEN) else {
        return false;
    };
    let has_sc_prefix = prefix.iter().all(|&byte| byte == 0);

    // The length is at least 8 at this point, so `len - 2` is always in bounds.
    has_sc_prefix && address[address.len() - 2] == 0xFF
}
126
/// Maps the last byte of an address to a shard number.
///
/// For exactly 3 shards the result is the low two bits of `last_byte`
/// (a bit mask instead of a modulo); for every other shard count it is
/// `last_byte % num_shards`.
///
/// NOTE(review): on the 3-shard path a byte ending in `0b11` yields 3, which
/// is outside `0..num_shards` — confirm callers expect this.
///
/// # Arguments
/// * `last_byte` - The last byte of the address
/// * `num_shards` - The total number of shards in the network
///
/// # Returns
/// `last_byte & 0x03` when `num_shards == 3`, otherwise `last_byte % num_shards`.
///
/// # Panics
/// Panics if `num_shards` is 0 (remainder by zero).
#[inline]
pub fn select_shard(last_byte: u8, num_shards: u32) -> u32 {
    match num_shards {
        3 => u32::from(last_byte & 0b11),
        count => u32::from(last_byte) % count,
    }
}
145
146/// Determines the shard ID from raw address bytes (simpler version).
147///
148/// This is a simpler alternative to `shard_of_address_bytes` that uses
149/// a direct last-byte approach instead of the full mask-based algorithm.
150///
151/// # Arguments
152/// * `raw` - Raw address bytes
153/// * `num_shards` - The total number of shards in the network
154///
155/// # Returns
156/// - `Some(0xFF)` for metachain addresses (last byte == 0xFF)
157/// - `Some(shard_id)` for regular shards
158/// - `None` if the address is empty
159#[inline]
160pub fn shard_of_bytes(raw: &[u8], num_shards: u32) -> Option<u32> {
161    let last = *raw.last()?;
162    if last == 0xFF {
163        Some(0xFF) // metachain
164    } else {
165        Some(select_shard(last, num_shards))
166    }
167}
168
169/// Extracts the embedded receiver address from ESDT/NFT transfer data payloads.
170///
171/// Handles three built-in transfer functions:
172/// - `ESDTNFTTransfer@token@nonce@qty@dest` → dest at position 3
173/// - `MultiESDTNFTTransfer@dest@numTokens@...` → dest at position 0
174/// - `ESDTTransfer@token@amount@dest` → dest at position 2
175///
176/// # Arguments
177/// * `data` - The transaction data payload bytes
178///
179/// # Returns
180/// The decoded receiver address bytes if found, None otherwise
181pub fn decode_embedded_receiver(data: &[u8]) -> Option<Vec<u8>> {
182    /// Parses the transfer function call format to extract destination address.
183    fn parse(txt: &str) -> Option<Vec<u8>> {
184        let mut parts = txt.split('@');
185        let func = parts.next()?.to_ascii_lowercase();
186        match func.as_str() {
187            // ESDTNFTTransfer@token@nonce@qty@dest
188            "esdtnfttransfer" => parts.nth(3).and_then(|h| hex::decode(h).ok()),
189            // MultiESDTNFTTransfer@dest@numTokens@...
190            "multiesdtnfttransfer" => parts.next().and_then(|h| hex::decode(h).ok()),
191            // ESDTTransfer@token@amount@dest
192            "esdttransfer" => parts.nth(2).and_then(|h| hex::decode(h).ok()),
193            _ => None,
194        }
195    }
196
197    // First try raw bytes as UTF-8 (on-chain data is plain ASCII for built-ins).
198    if let Ok(txt) = std::str::from_utf8(data)
199        && let Some(res) = parse(txt)
200    {
201        return Some(res);
202    }
203
204    // Some gateways expose base64 for data; fall back to decoding then parsing.
205    if let Ok(decoded) = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, data)
206        && let Ok(txt) = std::str::from_utf8(&decoded)
207    {
208        return parse(txt);
209    }
210    None
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Receiver bytes encoded as `aabbccdd` in the transfer fixtures below.
    const RECEIVER: [u8; 4] = [0xaa, 0xbb, 0xcc, 0xdd];

    /// A known bech32 address resolves to its expected shard on a 3-shard network.
    #[test]
    fn test_shard_calculation() {
        let known = [(
            "erd1qyu5wthldzr8wx5c9ucg8kjagg0jfs53s8nr3zpz3hypefsdd8ssycr6th",
            1,
        )];

        for (addr, expected_shard) in known {
            let shard = shard_of(addr, 3).unwrap();
            assert_eq!(
                shard, expected_shard,
                "address {addr} expected shard {expected_shard}"
            );
        }
    }

    /// 3 and 4 shards both need two bits: mask_high = 0b11, mask_low = 0b01.
    #[test]
    fn test_mask_calculation() {
        for num_shards in [3, 4] {
            assert_eq!(calculate_masks(num_shards), (0b11, 0b01));
        }
    }

    /// An empty byte slice is rejected with an error.
    #[test]
    fn test_empty_address() {
        assert!(shard_of_address_bytes(&[], 3).is_err());
    }

    /// `select_shard` masks the low two bits for 3 shards and uses modulo otherwise.
    #[test]
    fn test_select_shard() {
        // (last_byte, num_shards, expected)
        let cases = [
            // 3-shard path: low two bits of the byte
            (0b0000_0000, 3, 0),
            (0b0000_0001, 3, 1),
            (0b0000_0010, 3, 2),
            (0b0000_0011, 3, 3),
            (0b1111_1100, 3, 0),
            // Other shard counts: plain modulo
            (5, 4, 1),  // 5 % 4 = 1
            (10, 7, 3), // 10 % 7 = 3
        ];

        for (last_byte, num_shards, expected) in cases {
            assert_eq!(select_shard(last_byte, num_shards), expected);
        }
    }

    /// `shard_of_bytes` looks only at the last byte and special-cases 0xFF and empty input.
    #[test]
    fn test_shard_of_bytes() {
        let cases: [(&[u8], Option<u32>); 7] = [
            // Regular shard addresses (3 shards)
            (&[0x00], Some(0)),
            (&[0x01], Some(1)),
            (&[0x02], Some(2)),
            (&[0x03], Some(3)),
            // Metachain address
            (&[0xFF], Some(0xFF)),
            // Empty address
            (&[], None),
            // Multi-byte address (uses last byte)
            (&[0xAA, 0xBB, 0x02], Some(2)),
        ];

        for (raw, expected) in cases {
            assert_eq!(shard_of_bytes(raw, 3), expected);
        }
    }

    /// ESDTNFTTransfer@token@nonce@qty@dest
    #[test]
    fn test_decode_embedded_receiver_esdt_nft_transfer() {
        let payload = b"ESDTNFTTransfer@544f4b454e@01@0a@aabbccdd";
        assert_eq!(decode_embedded_receiver(payload), Some(RECEIVER.to_vec()));
    }

    /// MultiESDTNFTTransfer@dest@numTokens@...
    #[test]
    fn test_decode_embedded_receiver_multi_esdt_nft_transfer() {
        let payload = b"MultiESDTNFTTransfer@aabbccdd@02";
        assert_eq!(decode_embedded_receiver(payload), Some(RECEIVER.to_vec()));
    }

    /// ESDTTransfer@token@amount@dest
    #[test]
    fn test_decode_embedded_receiver_esdt_transfer() {
        let payload = b"ESDTTransfer@544f4b454e@0a@aabbccdd";
        assert_eq!(decode_embedded_receiver(payload), Some(RECEIVER.to_vec()));
    }

    /// Lowercase and uppercase function names decode to the same result.
    #[test]
    fn test_decode_embedded_receiver_case_insensitive() {
        let lower = decode_embedded_receiver(b"esdttransfer@544f4b454e@0a@aabbccdd");
        let upper = decode_embedded_receiver(b"ESDTTRANSFER@544f4b454e@0a@aabbccdd");
        assert_eq!(lower, upper);
    }

    /// Base64 of "ESDTTransfer@544f4b454e@0a@aabbccdd" is decoded before parsing.
    #[test]
    fn test_decode_embedded_receiver_base64() {
        let payload = b"RVNEVFRyYW5zZmVyQDU0NGY0YjQ1NGVAMGFAYWFiYmNjZGQ=";
        assert_eq!(decode_embedded_receiver(payload), Some(RECEIVER.to_vec()));
    }

    /// Unknown functions, missing arguments, bad hex and empty data all yield `None`.
    #[test]
    fn test_decode_embedded_receiver_invalid() {
        let rejected: [&[u8]; 4] = [
            // Invalid function name
            b"SomeOtherFunction@arg",
            // Not enough arguments
            b"ESDTTransfer@token",
            // Invalid hex
            b"ESDTTransfer@token@amount@zzzz",
            // Empty data
            b"",
        ];

        for payload in rejected {
            assert_eq!(decode_embedded_receiver(payload), None);
        }
    }
}