1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt;
5use heapless::Vec;
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
8use super::{
9 dispatcher::SyncType, responder::SyncResponseMessage, PeerCache, SyncCommand, SyncError,
10 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, REQUEST_MISSING_MAX,
11};
12use crate::{
13 storage::{Segment, Storage, StorageError, StorageProvider},
14 Address, Command, GraphId, Location,
15};
16
/// Requests sent by a [`SyncRequester`] to its peer.
///
/// These are serialized with postcard (nested inside `SyncType::Poll`, see
/// `SyncRequester::write`). Postcard is not self-describing, so the order of
/// variants and fields is part of the wire format.
#[derive(Serialize, Deserialize, Debug)]
// `SyncRequest` stores an inline `heapless::Vec` of up to `COMMAND_SAMPLE_MAX`
// addresses, which makes it much larger than the other variants. The lint is
// silenced rather than boxing that variant.
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Opens a session for `storage_id` and describes what the requester
    /// already has.
    SyncRequest {
        /// Random per-session identifier chosen by the requester.
        session_id: u128,
        /// Graph being synced.
        storage_id: GraphId,
        /// Byte budget supplied by the requester. NOTE(review): how the
        /// responder interprets it (e.g. whether 0 means "no limit") is not
        /// visible here.
        max_bytes: u64,
        /// Sample of command addresses known locally: the peer's cached heads,
        /// plus segment heads walked back from the local graph head.
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Re-requests specific responses in an existing session.
    /// NOTE(review): not constructed in this file; `indexes` presumably are
    /// response indexes. Confirm against the responder.
    RequestMissing {
        /// Session this request belongs to.
        session_id: u128,
        /// Indexes being requested.
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Asks the responder to continue an existing session.
    SyncResume {
        /// Session being resumed.
        session_id: u128,
        /// `SyncRequester::resume` sets this to the index of the last
        /// `SyncResponse` it accepted (`next_index - 1`).
        response_index: u64,
        /// Byte budget for the resumed responses.
        max_bytes: u64,
    },

    /// Closes the session identified by `session_id`.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72 pub fn session_id(&self) -> u128 {
73 match self {
74 Self::SyncRequest { session_id, .. } => *session_id,
75 Self::RequestMissing { session_id, .. } => *session_id,
76 Self::SyncResume { session_id, .. } => *session_id,
77 Self::EndSession { session_id, .. } => *session_id,
78 }
79 }
80}
81
82#[derive(Clone, Debug, PartialEq, Eq)]
83pub enum SyncRequesterState {
84 New,
85 Start,
86 Waiting,
87 Idle,
88 Closed,
89 Resync,
90 PartialSync,
91 Reset,
92}
93
94const OOO_LEN: usize = 4;
96pub struct SyncRequester<'a, A> {
97 session_id: u128,
98 storage_id: GraphId,
99 state: SyncRequesterState,
100 max_bytes: u64,
101 next_index: u64,
102 #[allow(unused)] ooo_buffer: [Option<&'a [u8]>; OOO_LEN],
104 server_address: A,
105}
106
107impl<A: DeserializeOwned + Serialize + Clone> SyncRequester<'_, A> {
108 pub fn new<R: Csprng>(storage_id: GraphId, rng: &mut R, server_address: A) -> Self {
110 let mut dst = [0u8; 16];
112 rng.fill_bytes(&mut dst);
113 let session_id = u128::from_le_bytes(dst);
114
115 SyncRequester {
116 session_id,
117 storage_id,
118 state: SyncRequesterState::New,
119 max_bytes: 0,
120 next_index: 0,
121 ooo_buffer: core::array::from_fn(|_| None),
122 server_address,
123 }
124 }
125
126 pub fn new_session_id(storage_id: GraphId, session_id: u128, server_address: A) -> Self {
128 SyncRequester {
129 session_id,
130 storage_id,
131 state: SyncRequesterState::Waiting,
132 max_bytes: 0,
133 next_index: 0,
134 ooo_buffer: core::array::from_fn(|_| None),
135 server_address,
136 }
137 }
138
139 pub fn server_addr(&self) -> A {
141 self.server_address.clone()
142 }
143
144 pub fn ready(&self) -> bool {
146 use SyncRequesterState as S;
147 match self.state {
148 S::New | S::Resync | S::Reset => true,
149 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
150 }
151 }
152
153 pub fn poll(
156 &mut self,
157 target: &mut [u8],
158 provider: &mut impl StorageProvider,
159 heads: &mut PeerCache,
160 ) -> Result<(usize, usize), SyncError> {
161 use SyncRequesterState as S;
162 let result = match self.state {
163 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
164 return Err(SyncError::NotReady)
165 }
166 S::New => {
167 self.state = S::Start;
168 self.start(self.max_bytes, target, provider, heads)?
169 }
170 S::Resync => self.resume(self.max_bytes, target)?,
171 S::Reset => {
172 self.state = S::Closed;
173 self.end_session(target)?
174 }
175 };
176
177 Ok(result)
178 }
179
180 pub fn receive<'a>(
182 &mut self,
183 data: &'a [u8],
184 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
185 let (message, remaining): (SyncResponseMessage, &'a [u8]) =
186 postcard::take_from_bytes(data)?;
187
188 self.get_sync_commands(message, remaining)
189 }
190
191 pub fn get_sync_commands<'a>(
193 &mut self,
194 message: SyncResponseMessage,
195 remaining: &'a [u8],
196 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_SAMPLE_MAX>>, SyncError> {
197 if message.session_id() != self.session_id {
198 return Err(SyncError::SessionMismatch);
199 }
200
201 let result = match message {
202 SyncResponseMessage::SyncResponse {
203 index, commands, ..
204 } => {
205 if !matches!(
206 self.state,
207 SyncRequesterState::Start | SyncRequesterState::Waiting
208 ) {
209 return Err(SyncError::SessionState);
210 }
211
212 if index != self.next_index {
213 self.state = SyncRequesterState::Resync;
214 return Err(SyncError::MissingSyncResponse);
215 }
216 self.next_index = self
217 .next_index
218 .checked_add(1)
219 .assume("next_index + 1 mustn't overflow")?;
220 self.state = SyncRequesterState::Waiting;
221
222 let mut result = Vec::new();
223 let mut start: usize = 0;
224 for meta in commands {
225 let policy_len = meta.policy_length as usize;
226
227 let policy = match policy_len == 0 {
228 true => None,
229 false => {
230 let end = start
231 .checked_add(policy_len)
232 .assume("start + policy_len mustn't overflow")?;
233 let policy = &remaining[start..end];
234 start = end;
235 Some(policy)
236 }
237 };
238
239 let len = meta.length as usize;
240 let end = start
241 .checked_add(len)
242 .assume("start + len mustn't overflow")?;
243 let payload = &remaining[start..end];
244 start = end;
245
246 let command = SyncCommand {
247 id: meta.id,
248 priority: meta.priority,
249 parent: meta.parent,
250 policy,
251 data: payload,
252 max_cut: meta.max_cut,
253 };
254
255 result
256 .push(command)
257 .ok()
258 .assume("commands is not larger than result")?;
259 }
260
261 Some(result)
262 }
263
264 SyncResponseMessage::SyncEnd { max_index, .. } => {
265 if !matches!(
266 self.state,
267 SyncRequesterState::Start | SyncRequesterState::Waiting
268 ) {
269 return Err(SyncError::SessionState);
270 }
271
272 if max_index
273 != self
274 .next_index
275 .checked_sub(1)
276 .assume("next_index must be positive")?
277 {
278 self.state = SyncRequesterState::Resync;
279 return Err(SyncError::MissingSyncResponse);
280 }
281
282 self.state = SyncRequesterState::PartialSync;
283
284 None
285 }
286
287 SyncResponseMessage::Offer { .. } => {
288 if self.state != SyncRequesterState::Idle {
289 return Err(SyncError::SessionState);
290 }
291 self.state = SyncRequesterState::Resync;
292
293 None
294 }
295
296 SyncResponseMessage::EndSession { .. } => {
297 self.state = SyncRequesterState::Closed;
298 None
299 }
300 };
301
302 Ok(result)
303 }
304
305 fn write(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
306 Ok(postcard::to_slice(&msg, target)?.len())
307 }
308
309 fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
310 Ok((
311 Self::write(
312 target,
313 SyncType::Poll {
314 request: SyncRequestMessage::EndSession {
315 session_id: self.session_id,
316 },
317 address: self.server_address.clone(),
318 },
319 )?,
320 0,
321 ))
322 }
323
324 fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
325 if !matches!(
326 self.state,
327 SyncRequesterState::Resync | SyncRequesterState::Idle
328 ) {
329 return Err(SyncError::SessionState);
330 }
331
332 self.state = SyncRequesterState::Waiting;
333 let message = SyncType::Poll {
334 request: SyncRequestMessage::SyncResume {
335 session_id: self.session_id,
336 response_index: self
337 .next_index
338 .checked_sub(1)
339 .assume("next_index must be positive")?,
340 max_bytes,
341 },
342 address: self.server_address.clone(),
343 };
344
345 Ok((Self::write(target, message)?, 0))
346 }
347
348 fn get_commands(
349 &self,
350 provider: &mut impl StorageProvider,
351 heads: &mut PeerCache,
352 ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
353 let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
354
355 match provider.get_storage(self.storage_id) {
356 Err(StorageError::NoSuchStorage) => (),
357 Err(err) => {
358 return Err(SyncError::Storage(err));
359 }
360 Ok(storage) => {
361 let mut command_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
362 for address in heads.heads() {
363 command_locations
364 .push(
365 storage
366 .get_location(*address)?
367 .assume("location must exist")?,
368 )
369 .ok()
370 .assume("command locations should not be full")?;
371 if commands.len() < COMMAND_SAMPLE_MAX {
372 commands
373 .push(*address)
374 .map_err(|_| SyncError::CommandOverflow)?;
375 }
376 }
377 let head = storage.get_head()?;
378
379 let mut current = vec![head];
380
381 while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
386 let mut next = vec::Vec::new(); 'current: for &location in ¤t {
389 let segment = storage.get_segment(location)?;
390
391 let head = segment.head()?;
392 let head_address = head.address()?;
393 for loc in &command_locations {
394 if loc.segment == location.segment {
395 continue 'current;
396 }
397 }
398 commands
399 .push(head_address)
400 .map_err(|_| SyncError::CommandOverflow)?;
401 next.extend(segment.prior());
402 if commands.len() >= COMMAND_SAMPLE_MAX {
403 break 'current;
404 }
405 }
406
407 current = next.to_vec();
408 }
409 }
410 }
411 Ok(commands)
412 }
413
414 pub fn subscribe(
416 &mut self,
417 target: &mut [u8],
418 provider: &mut impl StorageProvider,
419 heads: &mut PeerCache,
420 remain_open: u64,
421 max_bytes: u64,
422 ) -> Result<usize, SyncError> {
423 let commands = self.get_commands(provider, heads)?;
424 let message = SyncType::Subscribe {
425 remain_open,
426 max_bytes,
427 commands,
428 address: self.server_address.clone(),
429 storage_id: self.storage_id,
430 };
431
432 Self::write(target, message)
433 }
434
435 pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
437 let message = SyncType::Unsubscribe {
438 address: self.server_address.clone(),
439 };
440
441 Self::write(target, message)
442 }
443
444 fn start(
445 &mut self,
446 max_bytes: u64,
447 target: &mut [u8],
448 provider: &mut impl StorageProvider,
449 heads: &mut PeerCache,
450 ) -> Result<(usize, usize), SyncError> {
451 if !matches!(
452 self.state,
453 SyncRequesterState::Start | SyncRequesterState::New
454 ) {
455 self.state = SyncRequesterState::Reset;
456 return Err(SyncError::SessionState);
457 }
458
459 self.state = SyncRequesterState::Start;
460 self.max_bytes = max_bytes;
461
462 let commands = self.get_commands(provider, heads)?;
463
464 let sent = commands.len();
465 let message = SyncType::Poll {
466 request: SyncRequestMessage::SyncRequest {
467 session_id: self.session_id,
468 storage_id: self.storage_id,
469 max_bytes,
470 commands,
471 },
472 address: self.server_address.clone(),
473 };
474
475 Ok((Self::write(target, message)?, sent))
476 }
477}