1use ipfrs_core::Cid;
9use serde::{Deserialize, Deserializer, Serialize, Serializer};
10use std::ops::Range;
11use thiserror::Error;
12
13fn serialize_cid<S>(cid: &Cid, serializer: S) -> Result<S::Ok, S::Error>
15where
16 S: Serializer,
17{
18 serializer.serialize_str(&cid.to_string())
19}
20
21fn deserialize_cid<'de, D>(deserializer: D) -> Result<Cid, D::Error>
23where
24 D: Deserializer<'de>,
25{
26 let s = String::deserialize(deserializer)?;
27 s.parse().map_err(serde::de::Error::custom)
28}
29
30#[derive(Error, Debug)]
32pub enum RangeError {
33 #[error("Invalid range: {0}")]
34 InvalidRange(String),
35 #[error("Range out of bounds: requested {requested}, available {available}")]
36 OutOfBounds { requested: u64, available: u64 },
37 #[error("Block not found: {0}")]
38 BlockNotFound(Cid),
39 #[error("Unsatisfiable range")]
40 Unsatisfiable,
41}
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45pub enum ByteRange {
46 FromTo { start: u64, end: u64 },
48 From(u64),
50 Suffix(u64),
52 All,
54}
55
56impl ByteRange {
57 pub fn from_to(start: u64, end: u64) -> Result<Self, RangeError> {
59 if start > end {
60 return Err(RangeError::InvalidRange(format!(
61 "start ({}) > end ({})",
62 start, end
63 )));
64 }
65 Ok(ByteRange::FromTo { start, end })
66 }
67
68 pub fn from(start: u64) -> Self {
70 ByteRange::From(start)
71 }
72
73 pub fn suffix(count: u64) -> Self {
75 ByteRange::Suffix(count)
76 }
77
78 pub fn to_range(&self, total_size: u64) -> Result<Range<u64>, RangeError> {
80 match self {
81 ByteRange::FromTo { start, end } => {
82 if *end >= total_size {
83 return Err(RangeError::OutOfBounds {
84 requested: *end,
85 available: total_size,
86 });
87 }
88 Ok(*start..*end + 1)
89 }
90 ByteRange::From(start) => {
91 if *start >= total_size {
92 return Err(RangeError::OutOfBounds {
93 requested: *start,
94 available: total_size,
95 });
96 }
97 Ok(*start..total_size)
98 }
99 ByteRange::Suffix(count) => {
100 if *count > total_size {
101 Ok(0..total_size)
102 } else {
103 Ok(total_size - count..total_size)
104 }
105 }
106 ByteRange::All => Ok(0..total_size),
107 }
108 }
109
110 pub fn overlaps(&self, other: &ByteRange, total_size: u64) -> bool {
112 if let (Ok(r1), Ok(r2)) = (self.to_range(total_size), other.to_range(total_size)) {
113 r1.start < r2.end && r2.start < r1.end
114 } else {
115 false
116 }
117 }
118
119 pub fn merge(&self, other: &ByteRange, total_size: u64) -> Option<ByteRange> {
121 if let (Ok(r1), Ok(r2)) = (self.to_range(total_size), other.to_range(total_size)) {
122 if r1.start <= r2.end && r2.start <= r1.end {
123 let start = r1.start.min(r2.start);
124 let end = (r1.end - 1).max(r2.end - 1);
125 Some(ByteRange::FromTo { start, end })
126 } else {
127 None
128 }
129 } else {
130 None
131 }
132 }
133
134 pub fn size(&self, total_size: u64) -> u64 {
136 self.to_range(total_size)
137 .map(|r| r.end - r.start)
138 .unwrap_or(0)
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct RangeRequest {
145 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
147 pub cid: Cid,
148 pub range: ByteRange,
150 pub priority: i32,
152}
153
154impl RangeRequest {
155 pub fn new(cid: Cid, range: ByteRange) -> Self {
157 Self {
158 cid,
159 range,
160 priority: 0,
161 }
162 }
163
164 pub fn with_priority(cid: Cid, range: ByteRange, priority: i32) -> Self {
166 Self {
167 cid,
168 range,
169 priority,
170 }
171 }
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct RangeResponse {
177 #[serde(serialize_with = "serialize_cid", deserialize_with = "deserialize_cid")]
179 pub cid: Cid,
180 pub range: Range<u64>,
182 pub data: Vec<u8>,
184 pub total_size: u64,
186}
187
188impl RangeResponse {
189 pub fn new(cid: Cid, range: Range<u64>, data: Vec<u8>, total_size: u64) -> Self {
191 Self {
192 cid,
193 range,
194 data,
195 total_size,
196 }
197 }
198
199 pub fn satisfies(&self, request: &RangeRequest) -> bool {
201 if self.cid != request.cid {
202 return false;
203 }
204 if let Ok(req_range) = request.range.to_range(self.total_size) {
205 self.range.start <= req_range.start && self.range.end >= req_range.end
206 } else {
207 false
208 }
209 }
210
211 pub fn extract_range(&self, range: &Range<u64>) -> Result<Vec<u8>, RangeError> {
213 if range.start < self.range.start || range.end > self.range.end {
214 return Err(RangeError::OutOfBounds {
215 requested: range.end,
216 available: self.range.end,
217 });
218 }
219
220 let offset = (range.start - self.range.start) as usize;
221 let len = (range.end - range.start) as usize;
222
223 if offset + len > self.data.len() {
224 return Err(RangeError::OutOfBounds {
225 requested: (offset + len) as u64,
226 available: self.data.len() as u64,
227 });
228 }
229
230 Ok(self.data[offset..offset + len].to_vec())
231 }
232}
233
234pub struct RangeAssembler {
236 cid: Cid,
238 total_size: u64,
240 received: Vec<(Range<u64>, Vec<u8>)>,
242}
243
244impl RangeAssembler {
245 pub fn new(cid: Cid, total_size: u64) -> Self {
247 Self {
248 cid,
249 total_size,
250 received: Vec::new(),
251 }
252 }
253
254 pub fn add_range(&mut self, response: RangeResponse) -> Result<(), RangeError> {
256 if response.cid != self.cid {
257 return Err(RangeError::InvalidRange("CID mismatch".to_string()));
258 }
259
260 if response.total_size != self.total_size {
261 return Err(RangeError::InvalidRange("Total size mismatch".to_string()));
262 }
263
264 self.received.push((response.range, response.data));
265 Ok(())
266 }
267
268 pub fn is_complete(&self) -> bool {
270 let mut covered = vec![false; self.total_size as usize];
271
272 for (range, _) in &self.received {
273 for i in range.start..range.end {
274 if (i as usize) < covered.len() {
275 covered[i as usize] = true;
276 }
277 }
278 }
279
280 covered.iter().all(|&x| x)
281 }
282
283 pub fn missing_ranges(&self) -> Vec<Range<u64>> {
285 let mut covered = vec![false; self.total_size as usize];
286
287 for (range, _) in &self.received {
288 for i in range.start..range.end {
289 if (i as usize) < covered.len() {
290 covered[i as usize] = true;
291 }
292 }
293 }
294
295 let mut missing = Vec::new();
296 let mut start = None;
297
298 for (i, &is_covered) in covered.iter().enumerate() {
299 if !is_covered && start.is_none() {
300 start = Some(i as u64);
301 } else if is_covered && start.is_some() {
302 missing.push(start.unwrap()..i as u64);
303 start = None;
304 }
305 }
306
307 if let Some(s) = start {
308 missing.push(s..self.total_size);
309 }
310
311 missing
312 }
313
314 pub fn assemble(&self) -> Result<Vec<u8>, RangeError> {
316 if !self.is_complete() {
317 return Err(RangeError::InvalidRange("Block incomplete".to_string()));
318 }
319
320 let mut data = vec![0u8; self.total_size as usize];
321
322 for (range, chunk) in &self.received {
323 let start = range.start as usize;
324 let end = range.end as usize;
325 let len = end - start;
326
327 if chunk.len() != len {
328 return Err(RangeError::InvalidRange("Chunk size mismatch".to_string()));
329 }
330
331 data[start..end].copy_from_slice(chunk);
332 }
333
334 Ok(data)
335 }
336
337 pub fn completion_percentage(&self) -> f64 {
339 let mut covered = vec![false; self.total_size as usize];
340
341 for (range, _) in &self.received {
342 for i in range.start..range.end {
343 if (i as usize) < covered.len() {
344 covered[i as usize] = true;
345 }
346 }
347 }
348
349 let covered_count = covered.iter().filter(|&&x| x).count();
350 (covered_count as f64 / self.total_size as f64) * 100.0
351 }
352}
353
354#[cfg(test)]
355mod tests {
356 use super::*;
357
358 fn test_cid() -> Cid {
359 "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
360 .parse()
361 .unwrap()
362 }
363
364 #[test]
365 fn test_byte_range_from_to() {
366 let range = ByteRange::from_to(0, 99).unwrap();
367 assert_eq!(range.to_range(1000).unwrap(), 0..100);
368 }
369
370 #[test]
371 fn test_byte_range_from() {
372 let range = ByteRange::from(500);
373 assert_eq!(range.to_range(1000).unwrap(), 500..1000);
374 }
375
376 #[test]
377 fn test_byte_range_suffix() {
378 let range = ByteRange::suffix(100);
379 assert_eq!(range.to_range(1000).unwrap(), 900..1000);
380 }
381
382 #[test]
383 fn test_byte_range_all() {
384 let range = ByteRange::All;
385 assert_eq!(range.to_range(1000).unwrap(), 0..1000);
386 }
387
388 #[test]
389 fn test_byte_range_out_of_bounds() {
390 let range = ByteRange::from_to(0, 1500).unwrap();
391 assert!(range.to_range(1000).is_err());
392 }
393
394 #[test]
395 fn test_byte_range_invalid() {
396 assert!(ByteRange::from_to(100, 50).is_err());
397 }
398
399 #[test]
400 fn test_byte_range_overlaps() {
401 let range1 = ByteRange::from_to(0, 99).unwrap();
402 let range2 = ByteRange::from_to(50, 149).unwrap();
403 assert!(range1.overlaps(&range2, 1000));
404
405 let range3 = ByteRange::from_to(200, 299).unwrap();
406 assert!(!range1.overlaps(&range3, 1000));
407 }
408
409 #[test]
410 fn test_byte_range_merge() {
411 let range1 = ByteRange::from_to(0, 99).unwrap();
412 let range2 = ByteRange::from_to(50, 149).unwrap();
413 let merged = range1.merge(&range2, 1000).unwrap();
414 assert_eq!(merged.to_range(1000).unwrap(), 0..150);
415 }
416
417 #[test]
418 fn test_byte_range_size() {
419 let range = ByteRange::from_to(100, 199).unwrap();
420 assert_eq!(range.size(1000), 100);
421 }
422
423 #[test]
424 fn test_range_request() {
425 let cid = test_cid();
426 let range = ByteRange::from_to(0, 99).unwrap();
427 let req = RangeRequest::new(cid, range);
428 assert_eq!(req.priority, 0);
429
430 let req2 = RangeRequest::with_priority(cid, range, 10);
431 assert_eq!(req2.priority, 10);
432 }
433
434 #[test]
435 fn test_range_response_satisfies() {
436 let cid = test_cid();
437 let range = ByteRange::from_to(0, 99).unwrap();
438 let req = RangeRequest::new(cid, range);
439
440 let response = RangeResponse::new(cid, 0..100, vec![0u8; 100], 1000);
441 assert!(response.satisfies(&req));
442
443 let response2 = RangeResponse::new(cid, 50..150, vec![0u8; 100], 1000);
444 assert!(!response2.satisfies(&req));
445 }
446
447 #[test]
448 fn test_range_response_extract() {
449 let cid = test_cid();
450 let data = (0..100).collect::<Vec<u8>>();
451 let response = RangeResponse::new(cid, 0..100, data.clone(), 1000);
452
453 let extracted = response.extract_range(&(10..20)).unwrap();
454 assert_eq!(extracted, &data[10..20]);
455 }
456
457 #[test]
458 fn test_range_assembler() {
459 let cid = test_cid();
460 let mut assembler = RangeAssembler::new(cid, 100);
461
462 assert!(!assembler.is_complete());
463 assert_eq!(assembler.completion_percentage(), 0.0);
464
465 let resp1 = RangeResponse::new(cid, 0..50, vec![1u8; 50], 100);
466 assembler.add_range(resp1).unwrap();
467 assert_eq!(assembler.completion_percentage(), 50.0);
468
469 let resp2 = RangeResponse::new(cid, 50..100, vec![2u8; 50], 100);
470 assembler.add_range(resp2).unwrap();
471 assert!(assembler.is_complete());
472 assert_eq!(assembler.completion_percentage(), 100.0);
473
474 let data = assembler.assemble().unwrap();
475 assert_eq!(data.len(), 100);
476 assert_eq!(&data[0..50], &vec![1u8; 50][..]);
477 assert_eq!(&data[50..100], &vec![2u8; 50][..]);
478 }
479
480 #[test]
481 fn test_range_assembler_missing_ranges() {
482 let cid = test_cid();
483 let mut assembler = RangeAssembler::new(cid, 100);
484
485 let resp1 = RangeResponse::new(cid, 0..25, vec![0u8; 25], 100);
486 assembler.add_range(resp1).unwrap();
487
488 let resp2 = RangeResponse::new(cid, 75..100, vec![0u8; 25], 100);
489 assembler.add_range(resp2).unwrap();
490
491 let missing = assembler.missing_ranges();
492 assert_eq!(missing, vec![25..75]);
493 }
494
495 #[test]
496 fn test_range_assembler_overlapping() {
497 let cid = test_cid();
498 let mut assembler = RangeAssembler::new(cid, 100);
499
500 let resp1 = RangeResponse::new(cid, 0..60, vec![1u8; 60], 100);
501 assembler.add_range(resp1).unwrap();
502
503 let resp2 = RangeResponse::new(cid, 40..100, vec![2u8; 60], 100);
504 assembler.add_range(resp2).unwrap();
505
506 assert!(assembler.is_complete());
507 }
508
509 #[test]
510 fn test_range_assembler_incomplete() {
511 let cid = test_cid();
512 let mut assembler = RangeAssembler::new(cid, 100);
513
514 let resp = RangeResponse::new(cid, 0..50, vec![0u8; 50], 100);
515 assembler.add_range(resp).unwrap();
516
517 assert!(assembler.assemble().is_err());
518 }
519}