Skip to main content

rns_core/resource/
advertisement.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use super::types::{AdvFlags, ResourceError};
5use crate::constants::{RESOURCE_HASHMAP_MAX_LEN, RESOURCE_MAPHASH_LEN};
6use crate::msgpack::{self, Value};
7
8/// Resource advertisement data, corresponding to Python's ResourceAdvertisement.
9#[derive(Debug, Clone)]
10pub struct ResourceAdvertisement {
11    /// Transfer size (encrypted data size)
12    pub transfer_size: u64,
13    /// Total uncompressed data size (including metadata overhead)
14    pub data_size: u64,
15    /// Number of parts
16    pub num_parts: u64,
17    /// Resource hash (full 32 bytes)
18    pub resource_hash: Vec<u8>,
19    /// Random hash (4 bytes)
20    pub random_hash: Vec<u8>,
21    /// Original hash (first segment, 32 bytes)
22    pub original_hash: Vec<u8>,
23    /// Hashmap segment (concatenated 4-byte part hashes)
24    pub hashmap: Vec<u8>,
25    /// Flags byte
26    pub flags: AdvFlags,
27    /// Segment index (1-based)
28    pub segment_index: u64,
29    /// Total segments
30    pub total_segments: u64,
31    /// Request ID (optional)
32    pub request_id: Option<Vec<u8>>,
33}
34
35impl ResourceAdvertisement {
36    /// Pack the advertisement to msgpack bytes.
37    /// `segment` controls which hashmap segment to include (0-based).
38    pub fn pack(&self, segment: usize) -> Vec<u8> {
39        let hashmap_start = segment * RESOURCE_HASHMAP_MAX_LEN * RESOURCE_MAPHASH_LEN;
40        let max_end = (segment + 1) * RESOURCE_HASHMAP_MAX_LEN * RESOURCE_MAPHASH_LEN;
41        let hashmap_end = core::cmp::min(max_end, self.hashmap.len());
42        let hashmap_segment = if hashmap_start < self.hashmap.len() {
43            &self.hashmap[hashmap_start..hashmap_end]
44        } else {
45            &[]
46        };
47
48        let q_value = match &self.request_id {
49            Some(id) => Value::Bin(id.clone()),
50            None => Value::Nil,
51        };
52
53        // Match Python's key order: t, d, n, h, r, o, i, l, q, f, m
54        let entries: Vec<(&str, Value)> = vec![
55            ("t", Value::UInt(self.transfer_size)),
56            ("d", Value::UInt(self.data_size)),
57            ("n", Value::UInt(self.num_parts)),
58            ("h", Value::Bin(self.resource_hash.clone())),
59            ("r", Value::Bin(self.random_hash.clone())),
60            ("o", Value::Bin(self.original_hash.clone())),
61            ("i", Value::UInt(self.segment_index)),
62            ("l", Value::UInt(self.total_segments)),
63            ("q", q_value),
64            ("f", Value::UInt(self.flags.to_byte() as u64)),
65            ("m", Value::Bin(hashmap_segment.to_vec())),
66        ];
67
68        msgpack::pack_str_map(&entries)
69    }
70
71    /// Unpack an advertisement from msgpack bytes.
72    pub fn unpack(data: &[u8]) -> Result<Self, ResourceError> {
73        let value = msgpack::unpack_exact(data).map_err(|_| ResourceError::InvalidAdvertisement)?;
74
75        let t = value
76            .map_get("t")
77            .and_then(|v| v.as_uint())
78            .ok_or(ResourceError::InvalidAdvertisement)?;
79        let d = value
80            .map_get("d")
81            .and_then(|v| v.as_uint())
82            .ok_or(ResourceError::InvalidAdvertisement)?;
83        let n = value
84            .map_get("n")
85            .and_then(|v| v.as_uint())
86            .ok_or(ResourceError::InvalidAdvertisement)?;
87        let h = value
88            .map_get("h")
89            .and_then(|v| v.as_bin())
90            .ok_or(ResourceError::InvalidAdvertisement)?
91            .to_vec();
92        let r = value
93            .map_get("r")
94            .and_then(|v| v.as_bin())
95            .ok_or(ResourceError::InvalidAdvertisement)?
96            .to_vec();
97        let o = value
98            .map_get("o")
99            .and_then(|v| v.as_bin())
100            .ok_or(ResourceError::InvalidAdvertisement)?
101            .to_vec();
102        let m = value
103            .map_get("m")
104            .and_then(|v| v.as_bin())
105            .ok_or(ResourceError::InvalidAdvertisement)?
106            .to_vec();
107        let f = value
108            .map_get("f")
109            .and_then(|v| v.as_uint())
110            .ok_or(ResourceError::InvalidAdvertisement)? as u8;
111        let i = value
112            .map_get("i")
113            .and_then(|v| v.as_uint())
114            .ok_or(ResourceError::InvalidAdvertisement)?;
115        let l = value
116            .map_get("l")
117            .and_then(|v| v.as_uint())
118            .ok_or(ResourceError::InvalidAdvertisement)?;
119
120        let q_val = value
121            .map_get("q")
122            .ok_or(ResourceError::InvalidAdvertisement)?;
123        let request_id = if q_val.is_nil() {
124            None
125        } else {
126            Some(
127                q_val
128                    .as_bin()
129                    .ok_or(ResourceError::InvalidAdvertisement)?
130                    .to_vec(),
131            )
132        };
133
134        Ok(ResourceAdvertisement {
135            transfer_size: t,
136            data_size: d,
137            num_parts: n,
138            resource_hash: h,
139            random_hash: r,
140            original_hash: o,
141            hashmap: m,
142            flags: AdvFlags::from_byte(f),
143            segment_index: i,
144            total_segments: l,
145            request_id,
146        })
147    }
148
149    /// Check if this advertisement is a request.
150    pub fn is_request(&self) -> bool {
151        self.request_id.is_some() && self.flags.is_request
152    }
153
154    /// Check if this advertisement is a response.
155    pub fn is_response(&self) -> bool {
156        self.request_id.is_some() && self.flags.is_response
157    }
158
159    /// Get the number of hashmap segments needed.
160    pub fn hashmap_segments(&self) -> usize {
161        let total_hashes = self.num_parts as usize;
162        if total_hashes == 0 {
163            return 1;
164        }
165        total_hashes.div_ceil(RESOURCE_HASHMAP_MAX_LEN)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline flags: encrypted only, every other bit clear.
    fn base_flags() -> AdvFlags {
        AdvFlags {
            encrypted: true,
            compressed: false,
            split: false,
            is_request: false,
            is_response: false,
            has_metadata: false,
        }
    }

    /// A small single-segment advertisement with three 4-byte part hashes.
    fn make_adv(flags: AdvFlags) -> ResourceAdvertisement {
        ResourceAdvertisement {
            transfer_size: 1000,
            data_size: 950,
            num_parts: 3,
            resource_hash: vec![0x11; 32],
            random_hash: vec![0xAA, 0xBB, 0xCC, 0xDD],
            original_hash: vec![0x22; 32],
            hashmap: vec![
                0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
            ],
            flags,
            segment_index: 1,
            total_segments: 1,
            request_id: None,
        }
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        let flags = base_flags();
        let adv = make_adv(flags);
        let unpacked = ResourceAdvertisement::unpack(&adv.pack(0)).unwrap();

        assert_eq!(unpacked.transfer_size, 1000);
        assert_eq!(unpacked.data_size, 950);
        assert_eq!(unpacked.num_parts, 3);
        assert_eq!(unpacked.resource_hash, vec![0x11; 32]);
        assert_eq!(unpacked.random_hash, vec![0xAA, 0xBB, 0xCC, 0xDD]);
        assert_eq!(unpacked.original_hash, vec![0x22; 32]);
        // The whole (single-segment) hashmap must survive the roundtrip byte for byte.
        assert_eq!(unpacked.hashmap, adv.hashmap);
        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.segment_index, 1);
        assert_eq!(unpacked.total_segments, 1);
        assert!(unpacked.request_id.is_none());
    }

    #[test]
    fn test_flags_encrypted_compressed() {
        let flags = AdvFlags {
            compressed: true,
            ..base_flags()
        };
        let unpacked = ResourceAdvertisement::unpack(&make_adv(flags).pack(0)).unwrap();
        assert!(unpacked.flags.encrypted);
        assert!(unpacked.flags.compressed);
        assert!(!unpacked.flags.split);
    }

    #[test]
    fn test_flags_with_metadata() {
        let flags = AdvFlags {
            has_metadata: true,
            ..base_flags()
        };
        let unpacked = ResourceAdvertisement::unpack(&make_adv(flags).pack(0)).unwrap();
        assert!(unpacked.flags.has_metadata);
    }

    #[test]
    fn test_multi_segment() {
        let flags = AdvFlags {
            split: true,
            ..base_flags()
        };
        let mut adv = make_adv(flags);
        adv.segment_index = 2;
        adv.total_segments = 5;
        let unpacked = ResourceAdvertisement::unpack(&adv.pack(0)).unwrap();
        assert!(unpacked.flags.split);
        assert_eq!(unpacked.segment_index, 2);
        assert_eq!(unpacked.total_segments, 5);
    }

    #[test]
    fn test_with_request_id() {
        let flags = AdvFlags {
            is_request: true,
            ..base_flags()
        };
        let mut adv = make_adv(flags);
        adv.request_id = Some(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let unpacked = ResourceAdvertisement::unpack(&adv.pack(0)).unwrap();
        assert!(unpacked.is_request());
        assert!(!unpacked.is_response());
        assert_eq!(unpacked.request_id, Some(vec![0xDE, 0xAD, 0xBE, 0xEF]));
    }

    #[test]
    fn test_is_response() {
        let flags = AdvFlags {
            is_response: true,
            ..base_flags()
        };
        let mut adv = make_adv(flags);
        adv.request_id = Some(vec![0x42; 16]);
        assert!(adv.is_response());
        assert!(!adv.is_request());
    }

    #[test]
    fn test_nil_request_id() {
        let unpacked = ResourceAdvertisement::unpack(&make_adv(base_flags()).pack(0)).unwrap();
        assert!(unpacked.request_id.is_none());
        assert!(!unpacked.is_request());
        assert!(!unpacked.is_response());
    }

    #[test]
    fn test_hashmap_segmentation() {
        // Create a large hashmap with > HASHMAP_MAX_LEN(74) hashes
        let num_hashes = 100;
        let hashmap: Vec<u8> = (0..num_hashes).flat_map(|i| vec![i as u8; 4]).collect();
        let first_segment_bytes = 74 * 4;

        let adv = ResourceAdvertisement {
            transfer_size: 50000,
            data_size: 48000,
            num_parts: num_hashes,
            resource_hash: vec![0x11; 32],
            random_hash: vec![0xAA; 4],
            original_hash: vec![0x22; 32],
            hashmap: hashmap.clone(),
            flags: base_flags(),
            segment_index: 1,
            total_segments: 1,
            request_id: None,
        };

        // Segment 0: first 74 hashes = 296 bytes
        let unpacked0 = ResourceAdvertisement::unpack(&adv.pack(0)).unwrap();
        assert_eq!(unpacked0.hashmap.len(), 74 * 4);
        assert_eq!(&unpacked0.hashmap[..], &hashmap[..first_segment_bytes]);

        // Segment 1: remaining 26 hashes = 104 bytes
        let unpacked1 = ResourceAdvertisement::unpack(&adv.pack(1)).unwrap();
        assert_eq!(unpacked1.hashmap.len(), 26 * 4);
        assert_eq!(&unpacked1.hashmap[..], &hashmap[first_segment_bytes..]);
    }

    #[test]
    fn test_pack_segment_past_end_is_empty() {
        // The 12-byte hashmap fits entirely in segment 0; later segments are empty.
        let adv = make_adv(base_flags());
        let unpacked = ResourceAdvertisement::unpack(&adv.pack(1)).unwrap();
        assert!(unpacked.hashmap.is_empty());
    }

    #[test]
    fn test_hashmap_segments_count() {
        let mut adv = make_adv(base_flags());

        adv.num_parts = 0; // degenerate resource still gets one segment
        assert_eq!(adv.hashmap_segments(), 1);

        adv.num_parts = 74; // exactly HASHMAP_MAX_LEN
        assert_eq!(adv.hashmap_segments(), 1);

        adv.num_parts = 75;
        assert_eq!(adv.hashmap_segments(), 2);

        adv.num_parts = 148;
        assert_eq!(adv.hashmap_segments(), 2);

        adv.num_parts = 149;
        assert_eq!(adv.hashmap_segments(), 3);
    }

    #[test]
    fn test_unpack_invalid_data() {
        assert!(ResourceAdvertisement::unpack(&[]).is_err());
        assert!(ResourceAdvertisement::unpack(&[0xc0]).is_err()); // nil
        assert!(ResourceAdvertisement::unpack(&[0x01, 0x02]).is_err()); // not a map
    }

    #[test]
    fn test_unpack_missing_keys() {
        // A well-formed map that lacks every required key except `t`.
        let entries: Vec<(&str, Value)> = vec![("t", Value::UInt(1))];
        let packed = msgpack::pack_str_map(&entries);
        assert!(ResourceAdvertisement::unpack(&packed).is_err());
    }
}