1use core::{
2 mem::{align_of, size_of},
3 slice,
4};
5use std::{cmp::Ordering, ops::Range};
6
7use pinocchio::error::ProgramError;
8use static_assertions::const_assert;
9
10use crate::error::DlpError;
11
12#[derive(Debug, Clone, Copy)]
13pub enum SizeChanged {
14 Expanded(usize),
15 Shrunk(usize),
16}
17
18#[repr(C)]
19#[derive(Debug, Clone, Copy)]
20pub struct OffsetPair {
21 pub offset_in_diff: u32,
22 pub offset_in_data: u32,
23}
24
25const_assert!(align_of::<OffsetPair>() == align_of::<u32>());
26const_assert!(size_of::<OffsetPair>() == 8);
27
28pub type OffsetInData = Range<usize>;
30
31pub const SIZE_OF_CHANGED_LEN: usize = size_of::<u32>();
32pub const SIZE_OF_NUM_OFFSET_PAIRS: usize = size_of::<u32>();
33pub const SIZE_OF_SINGLE_OFFSET_PAIR: usize = size_of::<OffsetPair>();
34
35pub struct DiffSet<'a> {
36 buf: *const u8,
37 buflen: usize,
38 changed_len: usize,
39 segments_count: usize,
40 offset_pairs: &'a [OffsetPair],
41 concat_diff: &'a [u8],
42}
43
44impl<'a> DiffSet<'a> {
45 pub fn try_new(diff: &'a [u8]) -> Result<Self, ProgramError> {
46 if diff.len() < (SIZE_OF_CHANGED_LEN + SIZE_OF_NUM_OFFSET_PAIRS) {
51 return Err(DlpError::InvalidDiff.into());
52 } else if diff.as_ptr().align_offset(align_of::<u32>()) != 0 {
53 return Err(DlpError::InvalidDiffAlignment.into());
54 }
55
56 let buf = diff.as_ptr();
60 let buflen = diff.len();
61 let changed_len = unsafe { *(buf as *const u32) as usize };
62 let segments_count = unsafe { *(buf.add(4) as *const u32) as usize };
63
64 let mut this = Self {
65 buf,
66 buflen,
67 changed_len,
68 segments_count,
69 offset_pairs: &[],
70 concat_diff: b"",
71 };
72
73 let header_len = SIZE_OF_CHANGED_LEN
74 + SIZE_OF_NUM_OFFSET_PAIRS
75 + segments_count * SIZE_OF_SINGLE_OFFSET_PAIR;
76
77 match diff.len().cmp(&header_len) {
78 Ordering::Equal => {
79 if this.segments_count() != 0 {
82 return Err(DlpError::InvalidDiff.into());
83 }
84 }
85 Ordering::Less => {
86 return Err(DlpError::InvalidDiff.into());
88 }
89 Ordering::Greater => {
90 this.offset_pairs = unsafe {
95 let raw_pairs = buf
96 .add(SIZE_OF_CHANGED_LEN + SIZE_OF_NUM_OFFSET_PAIRS)
97 as *const OffsetPair;
98 slice::from_raw_parts(raw_pairs, segments_count)
99 };
100 this.concat_diff = &diff[header_len..];
101 }
102 }
103
104 Ok(this)
105 }
106
107 pub fn try_new_from_borsh_vec(
108 vec_buffer: &'a [u8],
109 ) -> Result<Self, ProgramError> {
110 if vec_buffer.len() < 4 {
111 return Err(ProgramError::InvalidInstructionData);
112 }
113 Self::try_new(&vec_buffer[4..])
114 }
115
116 pub fn raw_diff(&self) -> &'a [u8] {
117 unsafe { slice::from_raw_parts(self.buf, self.buflen) }
120 }
121
122 pub fn changed_len(&self) -> usize {
125 self.changed_len
126 }
127
128 pub fn segments_count(&self) -> usize {
130 self.segments_count
131 }
132
133 pub fn offset_pairs(&self) -> &'a [OffsetPair] {
135 self.offset_pairs
136 }
137
138 pub fn diff_segment_at(
143 &self,
144 index: usize,
145 ) -> Result<Option<(&'a [u8], OffsetInData)>, ProgramError> {
146 let offsets = self.offset_pairs();
147 if index >= offsets.len() {
148 return Ok(None);
149 }
150
151 let OffsetPair {
152 offset_in_diff: segment_begin,
153 offset_in_data,
154 } = offsets[index];
155
156 let segment_end = if index + 1 < offsets.len() {
157 offsets[index + 1].offset_in_diff
158 } else {
159 self.concat_diff.len() as u32
160 };
161
162 if segment_end > self.concat_diff.len() as u32
164 || segment_begin >= segment_end
165 || offset_in_data >= self.changed_len() as u32
166 {
167 return Err(DlpError::InvalidDiff.into());
168 }
169
170 let segment =
171 &self.concat_diff[segment_begin as usize..segment_end as usize];
172 let range = offset_in_data as usize
173 ..(offset_in_data + segment_end - segment_begin) as usize;
174
175 if range.end > self.changed_len() {
176 return Err(DlpError::InvalidDiff.into());
177 }
178
179 Ok(Some((segment, range)))
180 }
181
182 pub fn iter(
184 &self,
185 ) -> impl Iterator<Item = Result<(&'a [u8], OffsetInData), ProgramError>> + '_
186 {
187 (0..self.segments_count).map(|index| {
188 self.diff_segment_at(index)
189 .map(|val| val.expect("impossible: index can never be greater than segments_count"))
190 })
191 }
192}