1use alloc::vec;
2
3use aranya_buggy::BugExt;
4use aranya_crypto::Csprng;
5use heapless::Vec;
6use serde::{Deserialize, Serialize};
7
8use super::{
9 responder::SyncResponseMessage, PeerCache, SyncCommand, SyncError, COMMAND_RESPONSE_MAX,
10 COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, REQUEST_MISSING_MAX,
11};
12use crate::{
13 storage::{Segment, Storage, StorageError, StorageProvider},
14 Address, Command, GraphId, Location,
15};
16
/// Messages sent by the sync requester to the responder.
///
/// Serialized with postcard (see `SyncRequester::write`). Postcard is not
/// self-describing: variants are encoded by position and fields in declaration
/// order. Reordering variants or fields therefore changes the wire format.
#[derive(Serialize, Deserialize, Debug)]
// `SyncRequest` stores its address sample inline in a fixed-capacity
// `heapless::Vec`, which makes it larger than the other variants (hence the lint).
// NOTE(review): presumably left unboxed to avoid a heap allocation — confirm
// before boxing it.
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Opens a session; written by `SyncRequester::start`.
    SyncRequest {
        /// Random identifier chosen by the requester for this session.
        session_id: u128,
        /// Graph to synchronize.
        storage_id: GraphId,
        /// The requester's `max_bytes` setting.
        /// NOTE(review): presumably a cap on response size — confirm against the responder.
        max_bytes: u64,
        /// Sample of command addresses the requester already has (at most
        /// `COMMAND_SAMPLE_MAX`): cached peer heads first, then heads of
        /// segments reachable from the local graph head.
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Requests specific responses by index.
    /// NOTE(review): not constructed in this file; presumably `indexes` lists
    /// the indexes of missing `SyncResponse`s — confirm against the responder.
    RequestMissing {
        /// Session this request belongs to.
        session_id: u128,
        /// Response indexes being requested (at most `REQUEST_MISSING_MAX`).
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Asks the responder to continue an existing session; written by
    /// `SyncRequester::resume`.
    SyncResume {
        /// Session being resumed.
        session_id: u128,
        /// Index of the last `SyncResponse` the requester accepted
        /// (`next_index - 1` at the time of sending).
        response_index: u64,
        /// The requester's `max_bytes` setting.
        max_bytes: u64,
    },

    /// Ends the session; written by `SyncRequester::end_session`.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72 pub fn session_id(&self) -> u128 {
73 match self {
74 Self::SyncRequest { session_id, .. } => *session_id,
75 Self::RequestMissing { session_id, .. } => *session_id,
76 Self::SyncResume { session_id, .. } => *session_id,
77 Self::EndSession { session_id, .. } => *session_id,
78 }
79 }
80}
81
/// Lifecycle of a `SyncRequester` session.
///
/// `SyncRequester::poll` only has a message to write in `New`, `Resync` and
/// `Reset`, which are the states for which `SyncRequester::ready` returns
/// `true`. The type is a fieldless enum, so it is `Copy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SyncRequesterState {
    /// Freshly created; the next `poll` moves to `Start` and writes the
    /// initial `SyncRequest`.
    New,
    /// `SyncRequest` written; waiting for the first `SyncResponse`.
    Start,
    /// At least one `SyncResponse` accepted, or a `SyncResume` was just
    /// written; waiting for further responses.
    Waiting,
    /// State in which an `Offer` or a resume is accepted.
    /// NOTE(review): nothing in the requester's own code enters this state —
    /// confirm where it is set.
    Idle,
    /// Session finished: `EndSession` was written or received.
    Closed,
    /// A `SyncResponse`/`SyncEnd` index did not match, or an `Offer` arrived;
    /// the next `poll` writes a `SyncResume`.
    Resync,
    /// `SyncEnd` received with the expected final index.
    PartialSync,
    /// `start` was called in an unexpected state; the next `poll` moves to
    /// `Closed` and writes `EndSession`.
    Reset,
}
93
94const OOO_LEN: usize = 4;
96pub struct SyncRequester<'a> {
97 session_id: u128,
98 storage_id: GraphId,
99 state: SyncRequesterState,
100 max_bytes: u64,
101 next_index: u64,
102 #[allow(unused)] ooo_buffer: [Option<&'a [u8]>; OOO_LEN],
104}
105
106impl SyncRequester<'_> {
107 pub fn new<R: Csprng>(storage_id: GraphId, rng: &mut R) -> Self {
109 let mut dst = [0u8; 16];
111 rng.fill_bytes(&mut dst);
112 let session_id = u128::from_le_bytes(dst);
113
114 SyncRequester {
115 session_id,
116 storage_id,
117 state: SyncRequesterState::New,
118 max_bytes: 0,
119 next_index: 0,
120 ooo_buffer: core::array::from_fn(|_| None),
121 }
122 }
123
124 pub fn ready(&self) -> bool {
126 use SyncRequesterState as S;
127 match self.state {
128 S::New | S::Resync | S::Reset => true,
129 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
130 }
131 }
132
133 pub fn poll(
136 &mut self,
137 target: &mut [u8],
138 provider: &mut impl StorageProvider,
139 heads: &mut PeerCache,
140 ) -> Result<(usize, usize), SyncError> {
141 use SyncRequesterState as S;
142 let result = match self.state {
143 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
144 return Err(SyncError::NotReady)
145 }
146 S::New => {
147 self.state = S::Start;
148 self.start(self.max_bytes, target, provider, heads)?
149 }
150 S::Resync => self.resume(self.max_bytes, target)?,
151 S::Reset => {
152 self.state = S::Closed;
153 self.end_session(target)?
154 }
155 };
156
157 Ok(result)
158 }
159
160 pub fn receive<'a>(
162 &mut self,
163 data: &'a [u8],
164 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
165 let (message, remaining): (SyncResponseMessage, &'a [u8]) =
166 postcard::take_from_bytes(data)?;
167
168 if message.session_id() != self.session_id {
169 return Err(SyncError::SessionMismatch);
170 }
171
172 let result = match message {
173 SyncResponseMessage::SyncResponse {
174 index, commands, ..
175 } => {
176 if !matches!(
177 self.state,
178 SyncRequesterState::Start | SyncRequesterState::Waiting
179 ) {
180 return Err(SyncError::SessionState);
181 }
182
183 if index != self.next_index {
184 self.state = SyncRequesterState::Resync;
185 return Err(SyncError::MissingSyncResponse);
186 }
187 self.next_index = self
188 .next_index
189 .checked_add(1)
190 .assume("next_index + 1 mustn't overflow")?;
191 self.state = SyncRequesterState::Waiting;
192
193 let mut result = Vec::new();
194 let mut start: usize = 0;
195 for meta in commands {
196 let policy_len = meta.policy_length as usize;
197
198 let policy = match policy_len == 0 {
199 true => None,
200 false => {
201 let end = start
202 .checked_add(policy_len)
203 .assume("start + policy_len mustn't overflow")?;
204 let policy = &remaining[start..end];
205 start = end;
206 Some(policy)
207 }
208 };
209
210 let len = meta.length as usize;
211 let end = start
212 .checked_add(len)
213 .assume("start + len mustn't overflow")?;
214 let payload = &remaining[start..end];
215 start = end;
216
217 let command = SyncCommand {
218 id: meta.id,
219 priority: meta.priority,
220 parent: meta.parent,
221 policy,
222 data: payload,
223 max_cut: meta.max_cut,
224 };
225
226 result
227 .push(command)
228 .ok()
229 .assume("commands is not larger than result")?;
230 }
231
232 Some(result)
233 }
234
235 SyncResponseMessage::SyncEnd { max_index, .. } => {
236 if !matches!(
237 self.state,
238 SyncRequesterState::Start | SyncRequesterState::Waiting
239 ) {
240 return Err(SyncError::SessionState);
241 }
242
243 if max_index
244 != self
245 .next_index
246 .checked_sub(1)
247 .assume("next_index must be positive")?
248 {
249 self.state = SyncRequesterState::Resync;
250 return Err(SyncError::MissingSyncResponse);
251 }
252
253 self.state = SyncRequesterState::PartialSync;
254
255 None
256 }
257
258 SyncResponseMessage::Offer { .. } => {
259 if self.state != SyncRequesterState::Idle {
260 return Err(SyncError::SessionState);
261 }
262 self.state = SyncRequesterState::Resync;
263
264 None
265 }
266
267 SyncResponseMessage::EndSession { .. } => {
268 self.state = SyncRequesterState::Closed;
269 None
270 }
271 };
272
273 Ok(result)
274 }
275
276 fn write(target: &mut [u8], msg: SyncRequestMessage) -> Result<usize, SyncError> {
277 Ok(postcard::to_slice(&msg, target)?.len())
278 }
279
280 fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
281 Ok((
282 Self::write(
283 target,
284 SyncRequestMessage::EndSession {
285 session_id: self.session_id,
286 },
287 )?,
288 0,
289 ))
290 }
291
292 fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
293 if !matches!(
294 self.state,
295 SyncRequesterState::Resync | SyncRequesterState::Idle
296 ) {
297 return Err(SyncError::SessionState);
298 }
299
300 self.state = SyncRequesterState::Waiting;
301 let message = SyncRequestMessage::SyncResume {
302 session_id: self.session_id,
303 response_index: self
304 .next_index
305 .checked_sub(1)
306 .assume("next_index must be positive")?,
307 max_bytes,
308 };
309
310 Ok((Self::write(target, message)?, 0))
311 }
312
313 fn start(
314 &mut self,
315 max_bytes: u64,
316 target: &mut [u8],
317 provider: &mut impl StorageProvider,
318 heads: &mut PeerCache,
319 ) -> Result<(usize, usize), SyncError> {
320 if !matches!(
321 self.state,
322 SyncRequesterState::Start | SyncRequesterState::New
323 ) {
324 self.state = SyncRequesterState::Reset;
325 return Err(SyncError::SessionState);
326 }
327
328 self.state = SyncRequesterState::Start;
329 self.max_bytes = max_bytes;
330
331 let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
332
333 match provider.get_storage(self.storage_id) {
334 Err(StorageError::NoSuchStorage) => (),
335 Err(err) => {
336 return Err(SyncError::Storage(err));
337 }
338 Ok(storage) => {
339 let mut command_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
340 for address in heads.heads() {
341 command_locations
342 .push(
343 storage
344 .get_location(*address)?
345 .assume("location must exist")?,
346 )
347 .ok()
348 .assume("command locations should not be full")?;
349 if commands.len() < COMMAND_SAMPLE_MAX {
350 commands
351 .push(*address)
352 .map_err(|_| SyncError::CommandOverflow)?;
353 }
354 }
355 let head = storage.get_head()?;
356
357 let mut current = vec![head];
358
359 while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
364 let mut next = vec::Vec::new(); 'current: for &location in ¤t {
367 let segment = storage.get_segment(location)?;
368
369 let head = segment.head()?;
370 let head_address = head.address()?;
371 for loc in &command_locations {
372 if loc.segment == location.segment {
373 continue 'current;
374 }
375 }
376 commands
377 .push(head_address)
378 .map_err(|_| SyncError::CommandOverflow)?;
379 next.extend(segment.prior());
380 if commands.len() >= COMMAND_SAMPLE_MAX {
381 break 'current;
382 }
383 }
384
385 current = next.to_vec();
386 }
387 }
388 }
389
390 let sent = commands.len();
391 let message = SyncRequestMessage::SyncRequest {
392 session_id: self.session_id,
393 storage_id: self.storage_id,
394 max_bytes,
395 commands,
396 };
397
398 Ok((Self::write(target, message)?, sent))
399 }
400}