1use core::{
2 mem::{align_of, size_of},
3 slice,
4};
5use std::{cmp::Ordering, ops::Range};
6
7use pinocchio::error::ProgramError;
8use static_assertions::const_assert;
9
10use crate::{error::DlpError, require_eq, require_ge, require_le, require_lt};
11
12#[derive(Debug, Clone, Copy)]
13pub enum SizeChanged {
14 Expanded(usize),
15 Shrunk(usize),
16}
17
18#[repr(C)]
19#[derive(Debug, Clone, Copy)]
20pub struct OffsetPair {
21 pub offset_in_diff: u32,
22 pub offset_in_data: u32,
23}
24
25const_assert!(align_of::<OffsetPair>() == align_of::<u32>());
26const_assert!(size_of::<OffsetPair>() == 8);
27
28pub type OffsetInData = Range<usize>;
30
31pub const SIZE_OF_CHANGED_LEN: usize = size_of::<u32>();
32pub const SIZE_OF_NUM_OFFSET_PAIRS: usize = size_of::<u32>();
33pub const SIZE_OF_SINGLE_OFFSET_PAIR: usize = size_of::<OffsetPair>();
34
35pub struct DiffSet<'a> {
36 buf: *const u8,
37 buflen: usize,
38 changed_len: usize,
39 segments_count: usize,
40 offset_pairs: &'a [OffsetPair],
41 concat_diff: &'a [u8],
42}
43
44impl<'a> DiffSet<'a> {
45 pub fn try_new(diff: &'a [u8]) -> Result<Self, ProgramError> {
46 require_ge!(
51 diff.len(),
52 SIZE_OF_CHANGED_LEN + SIZE_OF_NUM_OFFSET_PAIRS,
53 DlpError::InvalidDiff
54 );
55 require_eq!(
56 diff.as_ptr().align_offset(align_of::<u32>()),
57 0,
58 DlpError::InvalidDiffAlignment
59 );
60
61 let buf = diff.as_ptr();
65 let buflen = diff.len();
66 let changed_len = unsafe { *(buf as *const u32) as usize };
67 let segments_count = unsafe { *(buf.add(4) as *const u32) as usize };
68
69 let mut this = Self {
70 buf,
71 buflen,
72 changed_len,
73 segments_count,
74 offset_pairs: &[],
75 concat_diff: b"",
76 };
77
78 let header_len = SIZE_OF_CHANGED_LEN
79 + SIZE_OF_NUM_OFFSET_PAIRS
80 + segments_count * SIZE_OF_SINGLE_OFFSET_PAIR;
81
82 match diff.len().cmp(&header_len) {
83 Ordering::Equal => {
84 require_eq!(this.segments_count(), 0, DlpError::InvalidDiff);
87 }
88 Ordering::Less => {
89 pinocchio_log::log!(
90 "segments_count {} is invalid, or diff {} is truncated",
91 this.segments_count(),
92 diff.len()
93 );
94 return Err(DlpError::InvalidDiff.into());
95 }
96 Ordering::Greater => {
97 this.offset_pairs = unsafe {
102 let raw_pairs = buf
103 .add(SIZE_OF_CHANGED_LEN + SIZE_OF_NUM_OFFSET_PAIRS)
104 as *const OffsetPair;
105 slice::from_raw_parts(raw_pairs, segments_count)
106 };
107 this.concat_diff = &diff[header_len..];
108 }
109 }
110
111 Ok(this)
112 }
113
114 pub fn try_new_from_borsh_vec(
115 vec_buffer: &'a [u8],
116 ) -> Result<Self, ProgramError> {
117 if vec_buffer.len() < 4 {
118 return Err(ProgramError::InvalidInstructionData);
119 }
120 Self::try_new(&vec_buffer[4..])
121 }
122
123 pub fn raw_diff(&self) -> &'a [u8] {
124 unsafe { slice::from_raw_parts(self.buf, self.buflen) }
127 }
128
129 pub fn changed_len(&self) -> usize {
132 self.changed_len
133 }
134
135 pub fn segments_count(&self) -> usize {
137 self.segments_count
138 }
139
140 pub fn offset_pairs(&self) -> &'a [OffsetPair] {
142 self.offset_pairs
143 }
144
145 pub fn diff_segment_at(
150 &self,
151 index: usize,
152 ) -> Result<Option<(&'a [u8], OffsetInData)>, ProgramError> {
153 let offsets = self.offset_pairs();
154 if index >= offsets.len() {
155 return Ok(None);
156 }
157
158 let OffsetPair {
159 offset_in_diff: segment_begin,
160 offset_in_data,
161 } = offsets[index];
162
163 let segment_end = if index + 1 < offsets.len() {
164 offsets[index + 1].offset_in_diff
165 } else {
166 self.concat_diff.len() as u32
167 };
168
169 require_lt!(segment_begin, segment_end, DlpError::InvalidDiff);
171 require_le!(
172 segment_end as usize,
173 self.concat_diff.len(),
174 DlpError::InvalidDiff
175 );
176 require_lt!(
177 offset_in_data as usize,
178 self.changed_len(),
179 DlpError::InvalidDiff
180 );
181
182 let segment =
183 &self.concat_diff[segment_begin as usize..segment_end as usize];
184 let range = offset_in_data as usize
185 ..(offset_in_data + segment_end - segment_begin) as usize;
186
187 require_le!(range.end, self.changed_len(), DlpError::InvalidDiff);
188
189 Ok(Some((segment, range)))
190 }
191
192 pub fn iter(
194 &self,
195 ) -> impl Iterator<Item = Result<(&'a [u8], OffsetInData), ProgramError>> + '_
196 {
197 (0..self.segments_count).map(|index| {
198 self.diff_segment_at(index).and_then(|maybe_value| {
199 maybe_value.ok_or_else(|| {
200 pinocchio_log::log!(
201 "index can never be greater than segments_count"
202 );
203 DlpError::InfallibleError.into()
204 })
205 })
206 })
207 }
208}