1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt as _;
5use heapless::Vec;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, PeerCache, REQUEST_MISSING_MAX,
10 SyncCommand, SyncError, dispatcher::SyncType, responder::SyncResponseMessage,
11};
12use crate::{
13 Address, Command as _, GraphId, Location,
14 storage::{Segment as _, Storage as _, StorageError, StorageProvider},
15};
16
/// Messages written by a [`SyncRequester`] when it talks to a sync peer.
///
/// NOTE(review): the type derives `Serialize`/`Deserialize`, so the order of
/// variants and fields is presumably part of the encoded format — confirm
/// before reordering anything here.
#[derive(Serialize, Deserialize, Debug)]
// NOTE(review): presumably silenced because `SyncRequest` embeds a fixed
// capacity vector and is much larger than the other variants — confirm.
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Initial request of a session; built by `SyncRequester::start`.
    SyncRequest {
        /// Identifier of the session this message belongs to.
        session_id: u128,
        /// Identifier of the graph the requester wants to sync.
        storage_id: GraphId,
        /// Byte limit forwarded from the requester.
        /// NOTE(review): presumably caps the size of the response — confirm
        /// against the responder.
        max_bytes: u64,
        /// Sample of command addresses taken from the requester's peer cache
        /// and local graph; holds at most `COMMAND_SAMPLE_MAX` entries.
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Carries a list of indexes, at most `REQUEST_MISSING_MAX` of them.
    /// This variant is not constructed anywhere in the visible requester code.
    RequestMissing {
        /// Identifier of the session this message belongs to.
        session_id: u128,
        /// NOTE(review): presumably the indexes of responses that were not
        /// received — confirm against the responder.
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Asks to continue a session; built by `SyncRequester::resume`.
    SyncResume {
        /// Identifier of the session this message belongs to.
        session_id: u128,
        /// The requester fills this with its next expected response index
        /// minus one.
        response_index: u64,
        /// Byte limit forwarded from the requester.
        max_bytes: u64,
    },

    /// Announces that the requester ends the session; built by
    /// `SyncRequester::end_session`.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72 pub fn session_id(&self) -> u128 {
73 match self {
74 Self::SyncRequest { session_id, .. } => *session_id,
75 Self::RequestMissing { session_id, .. } => *session_id,
76 Self::SyncResume { session_id, .. } => *session_id,
77 Self::EndSession { session_id, .. } => *session_id,
78 }
79 }
80}
81
/// States a [`SyncRequester`] moves through during a session.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SyncRequesterState {
    /// Initial state set by `SyncRequester::new`; the next `poll` writes the
    /// initial sync request.
    New,
    /// Set by `poll`/`start` when the initial sync request is written;
    /// `SyncResponse` and `SyncEnd` messages are accepted in this state.
    Start,
    /// Set after a `SyncResponse` was accepted or a resume request was
    /// written, and used as the initial state by
    /// `SyncRequester::new_session_id`; `SyncResponse` and `SyncEnd` messages
    /// are accepted in this state.
    Waiting,
    /// The only state in which an `Offer` message is accepted. Nothing in the
    /// visible code assigns this state.
    Idle,
    /// Set when the peer sends `EndSession`, or when `poll` writes the
    /// requester's own end-of-session message.
    Closed,
    /// Set when a response arrives with an unexpected index or when an
    /// `Offer` is accepted; the next `poll` writes a resume request.
    Resync,
    /// Set when a `SyncEnd` with the expected index is received.
    PartialSync,
    /// Set by `start` when it is called from an unexpected state; the next
    /// `poll` writes an end-of-session message.
    Reset,
}
101
/// Requester side of a sync session with a single peer.
///
/// `A` is the type used to address the peer; the methods require it to be
/// cloneable and (de)serializable.
pub struct SyncRequester<A> {
    /// Identifier of the session; incoming messages carrying a different id
    /// are rejected.
    session_id: u128,
    /// Identifier of the graph being synced.
    storage_id: GraphId,
    /// Current position in the session state machine.
    state: SyncRequesterState,
    /// Byte limit forwarded in sync and resume requests. Both constructors
    /// initialize it to `0`.
    max_bytes: u64,
    /// Index the next `SyncResponse` must carry; incremented after each
    /// accepted response.
    next_message_index: u64,
    /// Address of the peer, embedded in every outgoing message.
    server_address: A,
}
110
111impl<A: DeserializeOwned + Serialize + Clone> SyncRequester<A> {
112 pub fn new<R: Csprng>(storage_id: GraphId, rng: &mut R, server_address: A) -> Self {
114 let mut dst = [0u8; 16];
116 rng.fill_bytes(&mut dst);
117 let session_id = u128::from_le_bytes(dst);
118
119 Self {
120 session_id,
121 storage_id,
122 state: SyncRequesterState::New,
123 max_bytes: 0,
124 next_message_index: 0,
125 server_address,
126 }
127 }
128
129 pub fn new_session_id(storage_id: GraphId, session_id: u128, server_address: A) -> Self {
131 Self {
132 session_id,
133 storage_id,
134 state: SyncRequesterState::Waiting,
135 max_bytes: 0,
136 next_message_index: 0,
137 server_address,
138 }
139 }
140
141 pub fn server_addr(&self) -> A {
143 self.server_address.clone()
144 }
145
146 pub fn ready(&self) -> bool {
148 use SyncRequesterState as S;
149 match self.state {
150 S::New | S::Resync | S::Reset => true,
151 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
152 }
153 }
154
155 pub fn poll(
158 &mut self,
159 target: &mut [u8],
160 provider: &mut impl StorageProvider,
161 heads: &mut PeerCache,
162 ) -> Result<(usize, usize), SyncError> {
163 use SyncRequesterState as S;
164 let result = match self.state {
165 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
166 return Err(SyncError::NotReady);
167 }
168 S::New => {
169 self.state = S::Start;
170 self.start(self.max_bytes, target, provider, heads)?
171 }
172 S::Resync => self.resume(self.max_bytes, target)?,
173 S::Reset => {
174 self.state = S::Closed;
175 self.end_session(target)?
176 }
177 };
178
179 Ok(result)
180 }
181
182 pub fn receive<'a>(
184 &mut self,
185 data: &'a [u8],
186 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
187 let (message, remaining): (SyncResponseMessage, &'a [u8]) =
188 postcard::take_from_bytes(data)?;
189
190 self.get_sync_commands(message, remaining)
191 }
192
193 pub fn get_sync_commands<'a>(
195 &mut self,
196 message: SyncResponseMessage,
197 remaining: &'a [u8],
198 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
199 if message.session_id() != self.session_id {
200 return Err(SyncError::SessionMismatch);
201 }
202
203 let result = match message {
204 SyncResponseMessage::SyncResponse {
205 response_index,
206 commands,
207 ..
208 } => {
209 if !matches!(
210 self.state,
211 SyncRequesterState::Start | SyncRequesterState::Waiting
212 ) {
213 return Err(SyncError::SessionState);
214 }
215
216 if response_index != self.next_message_index {
217 self.state = SyncRequesterState::Resync;
218 return Err(SyncError::MissingSyncResponse);
219 }
220 self.next_message_index = self
221 .next_message_index
222 .checked_add(1)
223 .assume("next_message_index + 1 mustn't overflow")?;
224 self.state = SyncRequesterState::Waiting;
225
226 let mut result = Vec::new();
227 let mut start: usize = 0;
228 for meta in commands {
229 let policy_len = meta.policy_length as usize;
230
231 let policy = match policy_len == 0 {
232 true => None,
233 false => {
234 let end = start
235 .checked_add(policy_len)
236 .assume("start + policy_len mustn't overflow")?;
237 let policy = &remaining[start..end];
238 start = end;
239 Some(policy)
240 }
241 };
242
243 let len = meta.length as usize;
244 let end = start
245 .checked_add(len)
246 .assume("start + len mustn't overflow")?;
247 let payload = &remaining[start..end];
248 start = end;
249
250 let command = SyncCommand {
251 id: meta.id,
252 priority: meta.priority,
253 parent: meta.parent,
254 policy,
255 data: payload,
256 max_cut: meta.max_cut,
257 };
258
259 result
260 .push(command)
261 .ok()
262 .assume("commands is not larger than result")?;
263 }
264
265 Some(result)
266 }
267
268 SyncResponseMessage::SyncEnd { max_index, .. } => {
269 if !matches!(
270 self.state,
271 SyncRequesterState::Start | SyncRequesterState::Waiting
272 ) {
273 return Err(SyncError::SessionState);
274 }
275
276 if max_index != self.next_message_index {
277 self.state = SyncRequesterState::Resync;
278 return Err(SyncError::MissingSyncResponse);
279 }
280
281 self.state = SyncRequesterState::PartialSync;
282
283 None
284 }
285
286 SyncResponseMessage::Offer { .. } => {
287 if self.state != SyncRequesterState::Idle {
288 return Err(SyncError::SessionState);
289 }
290 self.state = SyncRequesterState::Resync;
291
292 None
293 }
294
295 SyncResponseMessage::EndSession { .. } => {
296 self.state = SyncRequesterState::Closed;
297 None
298 }
299 };
300
301 Ok(result)
302 }
303
304 fn write(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
305 Ok(postcard::to_slice(&msg, target)?.len())
306 }
307
308 fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
309 Ok((
310 Self::write(
311 target,
312 SyncType::Poll {
313 request: SyncRequestMessage::EndSession {
314 session_id: self.session_id,
315 },
316 address: self.server_address.clone(),
317 },
318 )?,
319 0,
320 ))
321 }
322
323 fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
324 if !matches!(
325 self.state,
326 SyncRequesterState::Resync | SyncRequesterState::Idle
327 ) {
328 return Err(SyncError::SessionState);
329 }
330
331 self.state = SyncRequesterState::Waiting;
332 let message = SyncType::Poll {
333 request: SyncRequestMessage::SyncResume {
334 session_id: self.session_id,
335 response_index: self
336 .next_message_index
337 .checked_sub(1)
338 .assume("next_index must be positive")?,
339 max_bytes,
340 },
341 address: self.server_address.clone(),
342 };
343
344 Ok((Self::write(target, message)?, 0))
345 }
346
347 fn get_commands(
348 &self,
349 provider: &mut impl StorageProvider,
350 peer_cache: &mut PeerCache,
351 ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
352 let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
353
354 match provider.get_storage(self.storage_id) {
355 Err(StorageError::NoSuchStorage) => (),
356 Err(err) => {
357 return Err(SyncError::Storage(err));
358 }
359 Ok(storage) => {
360 let mut command_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
361 for address in peer_cache.heads() {
362 command_locations
363 .push(
364 storage
365 .get_location(*address)?
366 .assume("location must exist")?,
367 )
368 .ok()
369 .assume("command locations should not be full")?;
370 if commands.len() < COMMAND_SAMPLE_MAX {
371 commands
372 .push(*address)
373 .map_err(|_| SyncError::CommandOverflow)?;
374 }
375 }
376 let head = storage.get_head()?;
377
378 let mut current = vec![head];
379
380 while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
385 let mut next = vec::Vec::new(); 'current: for &location in ¤t {
388 let segment = storage.get_segment(location)?;
389
390 let head = segment.head()?;
391 let head_address = head.address()?;
392 for loc in &command_locations {
393 if loc.segment == location.segment {
394 continue 'current;
395 }
396 }
397 commands
399 .push(head_address)
400 .map_err(|_| SyncError::CommandOverflow)?;
401 next.extend(segment.prior());
402 if commands.len() >= COMMAND_SAMPLE_MAX {
403 break 'current;
404 }
405 }
406
407 current.clone_from(&next);
408 }
409 }
410 }
411 Ok(commands)
412 }
413
414 pub fn subscribe(
416 &mut self,
417 target: &mut [u8],
418 provider: &mut impl StorageProvider,
419 heads: &mut PeerCache,
420 remain_open: u64,
421 max_bytes: u64,
422 ) -> Result<usize, SyncError> {
423 let commands = self.get_commands(provider, heads)?;
424 let message = SyncType::Subscribe {
425 remain_open,
426 max_bytes,
427 commands,
428 address: self.server_address.clone(),
429 storage_id: self.storage_id,
430 };
431
432 Self::write(target, message)
433 }
434
435 pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
437 let message = SyncType::Unsubscribe {
438 address: self.server_address.clone(),
439 };
440
441 Self::write(target, message)
442 }
443
444 fn start(
445 &mut self,
446 max_bytes: u64,
447 target: &mut [u8],
448 provider: &mut impl StorageProvider,
449 heads: &mut PeerCache,
450 ) -> Result<(usize, usize), SyncError> {
451 if !matches!(
452 self.state,
453 SyncRequesterState::Start | SyncRequesterState::New
454 ) {
455 self.state = SyncRequesterState::Reset;
456 return Err(SyncError::SessionState);
457 }
458
459 self.state = SyncRequesterState::Start;
460 self.max_bytes = max_bytes;
461
462 let command_sample = self.get_commands(provider, heads)?;
463
464 let sent = command_sample.len();
465 let message = SyncType::Poll {
466 request: SyncRequestMessage::SyncRequest {
467 session_id: self.session_id,
468 storage_id: self.storage_id,
469 max_bytes,
470 commands: command_sample,
471 },
472 address: self.server_address.clone(),
473 };
474
475 Ok((Self::write(target, message)?, sent))
476 }
477}