1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt as _;
5use heapless::Vec;
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, PeerCache, REQUEST_MISSING_MAX,
10 SyncCommand, SyncError, dispatcher::SyncType, responder::SyncResponseMessage,
11};
12use crate::{
13 Address, GraphId, Location,
14 storage::{Segment as _, Storage as _, StorageError, StorageProvider},
15};
16
/// Messages sent from a sync requester to a sync responder.
///
/// Serialized with `postcard` via the `serde` derives, so the variant and
/// field order here is part of the wire format and must not be changed.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Opens a sync session for a graph.
    SyncRequest {
        /// Identifies this session; every response must echo it.
        session_id: u128,
        /// The graph to synchronize.
        graph_id: GraphId,
        /// Byte budget for the responder.
        /// NOTE(review): presumably a total budget for the session — confirm
        /// against the responder implementation.
        max_bytes: u64,
        /// Sample of command addresses the requester already holds (built by
        /// `SyncRequester::get_commands` from cached peer heads plus a walk
        /// back from the storage head).
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Asks the responder to resend specific responses.
    RequestMissing {
        /// Session this request belongs to.
        session_id: u128,
        /// Indexes of the responses being requested again.
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Resumes a session after a missed or dropped response.
    SyncResume {
        /// Session this request belongs to.
        session_id: u128,
        /// Index of the last response the requester actually received.
        response_index: u64,
        /// Byte budget for the responder (see `SyncRequest::max_bytes`).
        max_bytes: u64,
    },

    /// Terminates the session.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72 pub fn session_id(&self) -> u128 {
73 match self {
74 Self::SyncRequest { session_id, .. } => *session_id,
75 Self::RequestMissing { session_id, .. } => *session_id,
76 Self::SyncResume { session_id, .. } => *session_id,
77 Self::EndSession { session_id, .. } => *session_id,
78 }
79 }
80}
81
/// Internal state machine for [`SyncRequester`].
#[derive(Clone, Debug, PartialEq, Eq)]
enum SyncRequesterState {
    /// Freshly created; `poll` will emit the opening `SyncRequest`.
    New,
    /// The opening request was sent; awaiting the first response.
    Start,
    /// A response was consumed (or a resume was sent); expecting further
    /// responses in order.
    Waiting,
    /// No outstanding request; an `Offer` from the peer moves to `Resync`,
    /// and `resume` may be called from here.
    Idle,
    /// The session is finished; no further messages are produced.
    Closed,
    /// A response index mismatch was detected; `poll` will emit `SyncResume`.
    Resync,
    /// The responder sent `SyncEnd` for the expected index.
    PartialSync,
    /// `start` was invoked from an invalid state; `poll` will emit
    /// `EndSession` and transition to `Closed`.
    Reset,
}
101
/// Drives the requesting side of a graph sync session, tracking protocol
/// state and the next expected response index.
pub struct SyncRequester {
    // Identifies this session; checked against every incoming response.
    session_id: u128,
    // The graph being synchronized.
    graph_id: GraphId,
    // Current protocol state; gates what `poll` emits and what responses
    // `get_sync_commands` accepts.
    state: SyncRequesterState,
    // Byte budget echoed into requests; 0 until `start` records it.
    max_bytes: u64,
    // Index expected on the next `SyncResponse`; a mismatch forces `Resync`.
    next_message_index: u64,
}
109
110impl SyncRequester {
111 pub fn new<R: Csprng>(graph_id: GraphId, rng: R) -> Self {
113 let mut dst = [0u8; 16];
115 rng.fill_bytes(&mut dst);
116 let session_id = u128::from_le_bytes(dst);
117
118 Self {
119 session_id,
120 graph_id,
121 state: SyncRequesterState::New,
122 max_bytes: 0,
123 next_message_index: 0,
124 }
125 }
126
127 pub fn new_session_id(graph_id: GraphId, session_id: u128) -> Self {
129 Self {
130 session_id,
131 graph_id,
132 state: SyncRequesterState::Waiting,
133 max_bytes: 0,
134 next_message_index: 0,
135 }
136 }
137
138 pub fn ready(&self) -> bool {
140 use SyncRequesterState as S;
141 match self.state {
142 S::New | S::Resync | S::Reset => true,
143 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
144 }
145 }
146
147 pub fn poll(
150 &mut self,
151 target: &mut [u8],
152 provider: &mut impl StorageProvider,
153 heads: &mut PeerCache,
154 ) -> Result<(usize, usize), SyncError> {
155 use SyncRequesterState as S;
156 let result = match self.state {
157 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
158 return Err(SyncError::NotReady);
159 }
160 S::New => {
161 self.state = S::Start;
162 self.start(self.max_bytes, target, provider, heads)?
163 }
164 S::Resync => self.resume(self.max_bytes, target)?,
165 S::Reset => {
166 self.state = S::Closed;
167 self.end_session(target)?
168 }
169 };
170
171 Ok(result)
172 }
173
174 pub fn receive<'a>(
176 &mut self,
177 data: &'a [u8],
178 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
179 let (message, remaining): (SyncResponseMessage, &'a [u8]) =
180 postcard::take_from_bytes(data)?;
181
182 self.get_sync_commands(message, remaining)
183 }
184
    /// Processes a decoded response, advancing the session state machine.
    ///
    /// `remaining` holds the raw bytes that followed the message on the
    /// wire: for a `SyncResponse`, each command's optional policy bytes and
    /// payload, laid out back to back in command order.
    ///
    /// Returns `Ok(Some(commands))` for a `SyncResponse` whose index matched
    /// the expected one, `Ok(None)` for control messages, and an error for
    /// session mismatches, out-of-state messages, or a missed response
    /// (which flips the state to `Resync`).
    pub fn get_sync_commands<'a>(
        &mut self,
        message: SyncResponseMessage,
        remaining: &'a [u8],
    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
        // Every message must belong to this session.
        if message.session_id() != self.session_id {
            return Err(SyncError::SessionMismatch);
        }

        let result = match message {
            SyncResponseMessage::SyncResponse {
                response_index,
                commands,
                ..
            } => {
                // Responses are only valid while a request is outstanding.
                if !matches!(
                    self.state,
                    SyncRequesterState::Start | SyncRequesterState::Waiting
                ) {
                    return Err(SyncError::SessionState);
                }

                // Out-of-order response: remember to resynchronize.
                if response_index != self.next_message_index {
                    self.state = SyncRequesterState::Resync;
                    return Err(SyncError::MissingSyncResponse);
                }
                self.next_message_index = self
                    .next_message_index
                    .checked_add(1)
                    .assume("next_message_index + 1 mustn't overflow")?;
                self.state = SyncRequesterState::Waiting;

                // Walk `remaining`, carving out each command's optional
                // policy bytes followed by its payload, in metadata order.
                // NOTE(review): the `&remaining[start..end]` slices below
                // panic if the peer sends fewer payload bytes than the
                // metadata claims — consider `.get(..)` with an error;
                // confirm whether the transport guarantees framing.
                let mut result = Vec::new();
                let mut start: usize = 0;
                for meta in commands {
                    let policy_len = meta.policy_length as usize;

                    // A zero policy length means "no policy bytes present".
                    let policy = match policy_len == 0 {
                        true => None,
                        false => {
                            let end = start
                                .checked_add(policy_len)
                                .assume("start + policy_len mustn't overflow")?;
                            let policy = &remaining[start..end];
                            start = end;
                            Some(policy)
                        }
                    };

                    let len = meta.length as usize;
                    let end = start
                        .checked_add(len)
                        .assume("start + len mustn't overflow")?;
                    let payload = &remaining[start..end];
                    start = end;

                    // Commands borrow directly from `remaining`; no copies.
                    let command = SyncCommand {
                        id: meta.id,
                        priority: meta.priority,
                        parent: meta.parent,
                        policy,
                        data: payload,
                        max_cut: meta.max_cut,
                    };

                    result
                        .push(command)
                        .ok()
                        .assume("commands is not larger than result")?;
                }

                Some(result)
            }

            SyncResponseMessage::SyncEnd { max_index, .. } => {
                if !matches!(
                    self.state,
                    SyncRequesterState::Start | SyncRequesterState::Waiting
                ) {
                    return Err(SyncError::SessionState);
                }

                // SyncEnd must line up with the next expected index,
                // otherwise a response was missed and we must resync.
                if max_index != self.next_message_index {
                    self.state = SyncRequesterState::Resync;
                    return Err(SyncError::MissingSyncResponse);
                }

                self.state = SyncRequesterState::PartialSync;

                None
            }

            SyncResponseMessage::Offer { .. } => {
                // An offer is only meaningful when idle; it prompts a resume.
                if self.state != SyncRequesterState::Idle {
                    return Err(SyncError::SessionState);
                }
                self.state = SyncRequesterState::Resync;

                None
            }

            SyncResponseMessage::EndSession { .. } => {
                // Peer closed the session; accepted from any state.
                self.state = SyncRequesterState::Closed;
                None
            }
        };

        Ok(result)
    }
295
296 fn write(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
297 Ok(postcard::to_slice(&msg, target)?.len())
298 }
299
300 fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
301 Ok((
302 Self::write(
303 target,
304 SyncType::Poll {
305 request: SyncRequestMessage::EndSession {
306 session_id: self.session_id,
307 },
308 },
309 )?,
310 0,
311 ))
312 }
313
314 fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
315 if !matches!(
316 self.state,
317 SyncRequesterState::Resync | SyncRequesterState::Idle
318 ) {
319 return Err(SyncError::SessionState);
320 }
321
322 self.state = SyncRequesterState::Waiting;
323 let message = SyncType::Poll {
324 request: SyncRequestMessage::SyncResume {
325 session_id: self.session_id,
326 response_index: self
327 .next_message_index
328 .checked_sub(1)
329 .assume("next_index must be positive")?,
330 max_bytes,
331 },
332 };
333
334 Ok((Self::write(target, message)?, 0))
335 }
336
337 fn get_commands(
338 &self,
339 provider: &mut impl StorageProvider,
340 peer_cache: &mut PeerCache,
341 ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
342 let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
343
344 match provider.get_storage(self.graph_id) {
345 Err(StorageError::NoSuchStorage) => (),
346 Err(err) => {
347 return Err(SyncError::Storage(err));
348 }
349 Ok(storage) => {
350 let mut cache_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
351 for address in peer_cache.heads() {
352 let loc = storage
353 .get_location(*address)?
354 .assume("location must exist")?;
355 cache_locations
356 .push(loc)
357 .ok()
358 .assume("command locations should not be full")?;
359 if commands.len() < COMMAND_SAMPLE_MAX {
360 commands
361 .push(*address)
362 .map_err(|_| SyncError::CommandOverflow)?;
363 }
364 }
365 let head = storage.get_head()?;
367 let mut current = vec![head];
368
369 while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
374 let mut next = vec::Vec::new(); 'current: for &location in ¤t {
377 for &peer_cache_loc in &cache_locations {
379 if location == peer_cache_loc {
381 continue 'current;
382 }
383 let peer_cache_segment = storage.get_segment(peer_cache_loc)?;
386 if (peer_cache_loc.same_segment(location)
387 && location.max_cut <= peer_cache_loc.max_cut)
388 || storage.is_ancestor(location, &peer_cache_segment)?
389 {
390 continue 'current;
391 }
392 }
393
394 let segment = storage.get_segment(location)?;
395 commands
396 .push(segment.head_address()?)
397 .map_err(|_| SyncError::CommandOverflow)?;
398 next.extend(segment.prior());
399 if commands.len() >= COMMAND_SAMPLE_MAX {
400 break 'current;
401 }
402 }
403
404 current.clone_from(&next);
405 }
406 }
407 }
408
409 Ok(commands)
410 }
411
412 pub fn subscribe(
414 &mut self,
415 target: &mut [u8],
416 provider: &mut impl StorageProvider,
417 heads: &mut PeerCache,
418 remain_open: u64,
419 max_bytes: u64,
420 ) -> Result<usize, SyncError> {
421 let commands = self.get_commands(provider, heads)?;
422 let message = SyncType::Subscribe {
423 remain_open,
424 max_bytes,
425 commands,
426 graph_id: self.graph_id,
427 };
428
429 Self::write(target, message)
430 }
431
432 pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
434 let message = SyncType::Unsubscribe {
435 graph_id: self.graph_id,
436 };
437
438 Self::write(target, message)
439 }
440
441 fn start(
442 &mut self,
443 max_bytes: u64,
444 target: &mut [u8],
445 provider: &mut impl StorageProvider,
446 heads: &mut PeerCache,
447 ) -> Result<(usize, usize), SyncError> {
448 if !matches!(
449 self.state,
450 SyncRequesterState::Start | SyncRequesterState::New
451 ) {
452 self.state = SyncRequesterState::Reset;
453 return Err(SyncError::SessionState);
454 }
455
456 self.state = SyncRequesterState::Start;
457 self.max_bytes = max_bytes;
458
459 let command_sample = self.get_commands(provider, heads)?;
460
461 let sent = command_sample.len();
462 let message = SyncType::Poll {
463 request: SyncRequestMessage::SyncRequest {
464 session_id: self.session_id,
465 graph_id: self.graph_id,
466 max_bytes,
467 commands: command_sample,
468 },
469 };
470
471 Ok((Self::write(target, message)?, sent))
472 }
473}