// rns_core/resource/advertisement.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use crate::msgpack::{self, Value};
5use crate::constants::{RESOURCE_HASHMAP_MAX_LEN, RESOURCE_MAPHASH_LEN};
6use super::types::{AdvFlags, ResourceError};
7
/// Resource advertisement data, corresponding to Python's ResourceAdvertisement.
///
/// Every field travels under a single-letter msgpack map key (noted on each
/// field below); see [`ResourceAdvertisement::pack`] for the encoder and
/// [`ResourceAdvertisement::unpack`] for the decoder.
#[derive(Debug, Clone)]
pub struct ResourceAdvertisement {
    /// Transfer size (encrypted data size). Map key `"t"`.
    pub transfer_size: u64,
    /// Total uncompressed data size (including metadata overhead). Map key `"d"`.
    pub data_size: u64,
    /// Number of parts. Map key `"n"`. Also drives
    /// [`ResourceAdvertisement::hashmap_segments`].
    pub num_parts: u64,
    /// Resource hash (full 32 bytes). Map key `"h"`.
    /// The length is not checked when unpacking.
    pub resource_hash: Vec<u8>,
    /// Random hash (4 bytes). Map key `"r"`.
    /// The length is not checked when unpacking.
    pub random_hash: Vec<u8>,
    /// Original hash (first segment, 32 bytes). Map key `"o"`.
    /// The length is not checked when unpacking.
    pub original_hash: Vec<u8>,
    /// Hashmap segment (concatenated 4-byte part hashes). Map key `"m"`.
    /// `pack` emits only the slice that belongs to the requested segment;
    /// `unpack` stores whatever bytes were received.
    pub hashmap: Vec<u8>,
    /// Flags byte. Map key `"f"`, serialised through `AdvFlags::to_byte` and
    /// restored through `AdvFlags::from_byte`.
    pub flags: AdvFlags,
    /// Segment index (1-based). Map key `"i"`.
    pub segment_index: u64,
    /// Total segments. Map key `"l"`.
    pub total_segments: u64,
    /// Request ID (optional). Map key `"q"`; `None` is encoded as msgpack nil.
    pub request_id: Option<Vec<u8>>,
}
34
35impl ResourceAdvertisement {
36    /// Pack the advertisement to msgpack bytes.
37    /// `segment` controls which hashmap segment to include (0-based).
38    pub fn pack(&self, segment: usize) -> Vec<u8> {
39        let hashmap_start = segment * RESOURCE_HASHMAP_MAX_LEN * RESOURCE_MAPHASH_LEN;
40        let max_end = (segment + 1) * RESOURCE_HASHMAP_MAX_LEN * RESOURCE_MAPHASH_LEN;
41        let hashmap_end = core::cmp::min(max_end, self.hashmap.len());
42        let hashmap_segment = if hashmap_start < self.hashmap.len() {
43            &self.hashmap[hashmap_start..hashmap_end]
44        } else {
45            &[]
46        };
47
48        let q_value = match &self.request_id {
49            Some(id) => Value::Bin(id.clone()),
50            None => Value::Nil,
51        };
52
53        // Match Python's key order: t, d, n, h, r, o, i, l, q, f, m
54        let entries: Vec<(&str, Value)> = vec![
55            ("t", Value::UInt(self.transfer_size)),
56            ("d", Value::UInt(self.data_size)),
57            ("n", Value::UInt(self.num_parts)),
58            ("h", Value::Bin(self.resource_hash.clone())),
59            ("r", Value::Bin(self.random_hash.clone())),
60            ("o", Value::Bin(self.original_hash.clone())),
61            ("i", Value::UInt(self.segment_index)),
62            ("l", Value::UInt(self.total_segments)),
63            ("q", q_value),
64            ("f", Value::UInt(self.flags.to_byte() as u64)),
65            ("m", Value::Bin(hashmap_segment.to_vec())),
66        ];
67
68        msgpack::pack_str_map(&entries)
69    }
70
71    /// Unpack an advertisement from msgpack bytes.
72    pub fn unpack(data: &[u8]) -> Result<Self, ResourceError> {
73        let value = msgpack::unpack_exact(data).map_err(|_| ResourceError::InvalidAdvertisement)?;
74
75        let t = value.map_get("t")
76            .and_then(|v| v.as_uint())
77            .ok_or(ResourceError::InvalidAdvertisement)?;
78        let d = value.map_get("d")
79            .and_then(|v| v.as_uint())
80            .ok_or(ResourceError::InvalidAdvertisement)?;
81        let n = value.map_get("n")
82            .and_then(|v| v.as_uint())
83            .ok_or(ResourceError::InvalidAdvertisement)?;
84        let h = value.map_get("h")
85            .and_then(|v| v.as_bin())
86            .ok_or(ResourceError::InvalidAdvertisement)?
87            .to_vec();
88        let r = value.map_get("r")
89            .and_then(|v| v.as_bin())
90            .ok_or(ResourceError::InvalidAdvertisement)?
91            .to_vec();
92        let o = value.map_get("o")
93            .and_then(|v| v.as_bin())
94            .ok_or(ResourceError::InvalidAdvertisement)?
95            .to_vec();
96        let m = value.map_get("m")
97            .and_then(|v| v.as_bin())
98            .ok_or(ResourceError::InvalidAdvertisement)?
99            .to_vec();
100        let f = value.map_get("f")
101            .and_then(|v| v.as_uint())
102            .ok_or(ResourceError::InvalidAdvertisement)? as u8;
103        let i = value.map_get("i")
104            .and_then(|v| v.as_uint())
105            .ok_or(ResourceError::InvalidAdvertisement)?;
106        let l = value.map_get("l")
107            .and_then(|v| v.as_uint())
108            .ok_or(ResourceError::InvalidAdvertisement)?;
109
110        let q_val = value.map_get("q").ok_or(ResourceError::InvalidAdvertisement)?;
111        let request_id = if q_val.is_nil() {
112            None
113        } else {
114            Some(q_val.as_bin().ok_or(ResourceError::InvalidAdvertisement)?.to_vec())
115        };
116
117        Ok(ResourceAdvertisement {
118            transfer_size: t,
119            data_size: d,
120            num_parts: n,
121            resource_hash: h,
122            random_hash: r,
123            original_hash: o,
124            hashmap: m,
125            flags: AdvFlags::from_byte(f),
126            segment_index: i,
127            total_segments: l,
128            request_id,
129        })
130    }
131
132    /// Check if this advertisement is a request.
133    pub fn is_request(&self) -> bool {
134        self.request_id.is_some() && self.flags.is_request
135    }
136
137    /// Check if this advertisement is a response.
138    pub fn is_response(&self) -> bool {
139        self.request_id.is_some() && self.flags.is_response
140    }
141
142    /// Get the number of hashmap segments needed.
143    pub fn hashmap_segments(&self) -> usize {
144        let total_hashes = self.num_parts as usize;
145        if total_hashes == 0 {
146            return 1;
147        }
148        (total_hashes + RESOURCE_HASHMAP_MAX_LEN - 1) / RESOURCE_HASHMAP_MAX_LEN
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags with only `encrypted` set. Tests flip individual bits with
    /// struct-update syntax instead of repeating all six fields.
    fn base_flags() -> AdvFlags {
        AdvFlags {
            encrypted: true,
            compressed: false,
            split: false,
            is_request: false,
            is_response: false,
            has_metadata: false,
        }
    }

    /// Small single-segment advertisement (3 parts, 12-byte hashmap) used by most tests.
    fn make_adv(flags: AdvFlags) -> ResourceAdvertisement {
        ResourceAdvertisement {
            transfer_size: 1000,
            data_size: 950,
            num_parts: 3,
            resource_hash: vec![0x11; 32],
            random_hash: vec![0xAA, 0xBB, 0xCC, 0xDD],
            original_hash: vec![0x22; 32],
            hashmap: vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C],
            flags,
            segment_index: 1,
            total_segments: 1,
            request_id: None,
        }
    }

    /// Pack hashmap segment 0 of `adv` and decode the result again.
    fn roundtrip(adv: &ResourceAdvertisement) -> ResourceAdvertisement {
        ResourceAdvertisement::unpack(&adv.pack(0)).unwrap()
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        let flags = base_flags();
        let adv = make_adv(flags);
        let unpacked = roundtrip(&adv);

        assert_eq!(unpacked.transfer_size, 1000);
        assert_eq!(unpacked.data_size, 950);
        assert_eq!(unpacked.num_parts, 3);
        assert_eq!(unpacked.resource_hash, vec![0x11; 32]);
        assert_eq!(unpacked.random_hash, vec![0xAA, 0xBB, 0xCC, 0xDD]);
        assert_eq!(unpacked.original_hash, vec![0x22; 32]);
        // The whole 12-byte hashmap fits in segment 0 and must survive intact.
        assert_eq!(unpacked.hashmap, adv.hashmap);
        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.segment_index, 1);
        assert_eq!(unpacked.total_segments, 1);
        assert!(unpacked.request_id.is_none());
    }

    #[test]
    fn test_flags_encrypted_compressed() {
        let adv = make_adv(AdvFlags { compressed: true, ..base_flags() });
        let unpacked = roundtrip(&adv);
        assert!(unpacked.flags.encrypted);
        assert!(unpacked.flags.compressed);
        assert!(!unpacked.flags.split);
    }

    #[test]
    fn test_flags_with_metadata() {
        let adv = make_adv(AdvFlags { has_metadata: true, ..base_flags() });
        let unpacked = roundtrip(&adv);
        assert!(unpacked.flags.has_metadata);
    }

    #[test]
    fn test_multi_segment() {
        let mut adv = make_adv(AdvFlags { split: true, ..base_flags() });
        adv.segment_index = 2;
        adv.total_segments = 5;
        let unpacked = roundtrip(&adv);
        assert!(unpacked.flags.split);
        assert_eq!(unpacked.segment_index, 2);
        assert_eq!(unpacked.total_segments, 5);
    }

    #[test]
    fn test_with_request_id() {
        let mut adv = make_adv(AdvFlags { is_request: true, ..base_flags() });
        adv.request_id = Some(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        let unpacked = roundtrip(&adv);
        assert!(unpacked.is_request());
        assert!(!unpacked.is_response());
        assert_eq!(unpacked.request_id, Some(vec![0xDE, 0xAD, 0xBE, 0xEF]));
    }

    #[test]
    fn test_is_response() {
        let mut adv = make_adv(AdvFlags { is_response: true, ..base_flags() });
        adv.request_id = Some(vec![0x42; 16]);
        assert!(adv.is_response());
        assert!(!adv.is_request());
    }

    #[test]
    fn test_nil_request_id() {
        let adv = make_adv(base_flags());
        let unpacked = roundtrip(&adv);
        assert!(unpacked.request_id.is_none());
        assert!(!unpacked.is_request());
        assert!(!unpacked.is_response());
    }

    #[test]
    fn test_hashmap_segmentation() {
        // Create a large hashmap with > HASHMAP_MAX_LEN(74) hashes
        let num_hashes = 100;
        let hashmap: Vec<u8> = (0..num_hashes).flat_map(|i| vec![i as u8; 4]).collect();

        let adv = ResourceAdvertisement {
            transfer_size: 50000,
            data_size: 48000,
            num_parts: num_hashes,
            resource_hash: vec![0x11; 32],
            random_hash: vec![0xAA; 4],
            original_hash: vec![0x22; 32],
            hashmap: hashmap.clone(),
            flags: base_flags(),
            segment_index: 1,
            total_segments: 1,
            request_id: None,
        };

        // Segment 0: first 74 hashes = 296 bytes
        let unpacked0 = ResourceAdvertisement::unpack(&adv.pack(0)).unwrap();
        assert_eq!(unpacked0.hashmap.len(), 74 * 4);
        assert_eq!(&unpacked0.hashmap[..], &hashmap[..74 * 4]);

        // Segment 1: remaining 26 hashes = 104 bytes
        let unpacked1 = ResourceAdvertisement::unpack(&adv.pack(1)).unwrap();
        assert_eq!(unpacked1.hashmap.len(), 26 * 4);
        assert_eq!(&unpacked1.hashmap[..], &hashmap[74 * 4..]);

        // Segment 2: past the end of the hashmap, so the payload is empty
        let unpacked2 = ResourceAdvertisement::unpack(&adv.pack(2)).unwrap();
        assert!(unpacked2.hashmap.is_empty());
    }

    #[test]
    fn test_hashmap_segments_count() {
        let mut adv = make_adv(base_flags());

        adv.num_parts = 0; // no parts still needs one (empty) segment
        assert_eq!(adv.hashmap_segments(), 1);

        adv.num_parts = 74; // exactly HASHMAP_MAX_LEN
        assert_eq!(adv.hashmap_segments(), 1);

        adv.num_parts = 75;
        assert_eq!(adv.hashmap_segments(), 2);

        adv.num_parts = 148;
        assert_eq!(adv.hashmap_segments(), 2);

        adv.num_parts = 149;
        assert_eq!(adv.hashmap_segments(), 3);
    }

    #[test]
    fn test_unpack_invalid_data() {
        assert!(ResourceAdvertisement::unpack(&[]).is_err());
        assert!(ResourceAdvertisement::unpack(&[0xc0]).is_err()); // nil
        assert!(ResourceAdvertisement::unpack(&[0x01, 0x02]).is_err()); // fixint plus a trailing byte, not a map
        assert!(ResourceAdvertisement::unpack(&[0x80]).is_err()); // empty map: every required key is missing
    }
}