1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt as _;
5use heapless::Vec;
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, PeerCache, REQUEST_MISSING_MAX,
10 SyncCommand, SyncError, dispatcher::SyncType, responder::SyncResponseMessage,
11};
12use crate::{
13 Address, Command as _, GraphId, Location,
14 storage::{Segment as _, Storage as _, StorageError, StorageProvider},
15};
16
/// A message sent from a sync requester to a sync responder.
#[derive(Serialize, Deserialize, Debug)]
// Variants are intentionally kept inline; boxing the large variant would
// change the serialized representation.
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Initiates a new sync session.
    SyncRequest {
        /// Identifies this session on both peers.
        session_id: u128,
        /// The graph to sync.
        graph_id: GraphId,
        /// Upper bound on bytes the responder should send.
        /// NOTE(review): 0 appears to mean "no limit" — confirm against the responder.
        max_bytes: u64,
        /// Sample of command addresses the requester already holds, used by
        /// the responder to locate common ancestors.
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Asks the responder to re-send specific responses.
    RequestMissing {
        /// Identifies this session on both peers.
        session_id: u128,
        /// Presumably the response indexes to re-send — verify against the responder.
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Resumes an interrupted session from a known response index.
    SyncResume {
        /// Identifies this session on both peers.
        session_id: u128,
        /// Index of the last response successfully received.
        response_index: u64,
        /// Upper bound on bytes the responder should send.
        max_bytes: u64,
    },

    /// Terminates the session identified by `session_id`.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72 pub fn session_id(&self) -> u128 {
73 match self {
74 Self::SyncRequest { session_id, .. } => *session_id,
75 Self::RequestMissing { session_id, .. } => *session_id,
76 Self::SyncResume { session_id, .. } => *session_id,
77 Self::EndSession { session_id, .. } => *session_id,
78 }
79 }
80}
81
/// State machine for [`SyncRequester`].
///
/// `ready()` reports `New`, `Resync`, and `Reset` as pollable; `poll()`
/// drives `New -> Start` (send request), `Resync -> Waiting` (send resume),
/// and `Reset -> Closed` (send end-session).
#[derive(Clone, Debug, PartialEq, Eq)]
enum SyncRequesterState {
    /// Session not started; next poll sends a `SyncRequest`.
    New,
    /// Request sent; no response observed yet.
    Start,
    /// At least one response received; expecting more.
    Waiting,
    /// Quiescent; an `Offer` from the peer moves this to `Resync`.
    Idle,
    /// Session terminated; no further messages will be produced.
    Closed,
    /// A response was missed (index gap); next poll sends a `SyncResume`.
    Resync,
    /// Responder signaled `SyncEnd`; the sync completed up to `max_index`.
    PartialSync,
    /// Protocol violation detected; next poll ends the session.
    Reset,
}
101
/// The requesting side of the sync protocol: produces poll/subscribe
/// messages and consumes [`SyncResponseMessage`]s for one graph.
pub struct SyncRequester {
    // Random (or caller-supplied) id shared with the responder.
    session_id: u128,
    // The graph being synced.
    graph_id: GraphId,
    // Current position in the protocol state machine.
    state: SyncRequesterState,
    // Byte budget communicated to the responder; set by `start`.
    max_bytes: u64,
    // Index the next in-order response must carry.
    next_message_index: u64,
}
109
110impl SyncRequester {
111 pub fn new<R: Csprng>(graph_id: GraphId, rng: &mut R) -> Self {
113 let mut dst = [0u8; 16];
115 rng.fill_bytes(&mut dst);
116 let session_id = u128::from_le_bytes(dst);
117
118 Self {
119 session_id,
120 graph_id,
121 state: SyncRequesterState::New,
122 max_bytes: 0,
123 next_message_index: 0,
124 }
125 }
126
127 pub fn new_session_id(graph_id: GraphId, session_id: u128) -> Self {
129 Self {
130 session_id,
131 graph_id,
132 state: SyncRequesterState::Waiting,
133 max_bytes: 0,
134 next_message_index: 0,
135 }
136 }
137
138 pub fn ready(&self) -> bool {
140 use SyncRequesterState as S;
141 match self.state {
142 S::New | S::Resync | S::Reset => true,
143 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
144 }
145 }
146
147 pub fn poll(
150 &mut self,
151 target: &mut [u8],
152 provider: &mut impl StorageProvider,
153 heads: &mut PeerCache,
154 ) -> Result<(usize, usize), SyncError> {
155 use SyncRequesterState as S;
156 let result = match self.state {
157 S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
158 return Err(SyncError::NotReady);
159 }
160 S::New => {
161 self.state = S::Start;
162 self.start(self.max_bytes, target, provider, heads)?
163 }
164 S::Resync => self.resume(self.max_bytes, target)?,
165 S::Reset => {
166 self.state = S::Closed;
167 self.end_session(target)?
168 }
169 };
170
171 Ok(result)
172 }
173
174 pub fn receive<'a>(
176 &mut self,
177 data: &'a [u8],
178 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
179 let (message, remaining): (SyncResponseMessage, &'a [u8]) =
180 postcard::take_from_bytes(data)?;
181
182 self.get_sync_commands(message, remaining)
183 }
184
185 pub fn get_sync_commands<'a>(
187 &mut self,
188 message: SyncResponseMessage,
189 remaining: &'a [u8],
190 ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
191 if message.session_id() != self.session_id {
192 return Err(SyncError::SessionMismatch);
193 }
194
195 let result = match message {
196 SyncResponseMessage::SyncResponse {
197 response_index,
198 commands,
199 ..
200 } => {
201 if !matches!(
202 self.state,
203 SyncRequesterState::Start | SyncRequesterState::Waiting
204 ) {
205 return Err(SyncError::SessionState);
206 }
207
208 if response_index != self.next_message_index {
209 self.state = SyncRequesterState::Resync;
210 return Err(SyncError::MissingSyncResponse);
211 }
212 self.next_message_index = self
213 .next_message_index
214 .checked_add(1)
215 .assume("next_message_index + 1 mustn't overflow")?;
216 self.state = SyncRequesterState::Waiting;
217
218 let mut result = Vec::new();
219 let mut start: usize = 0;
220 for meta in commands {
221 let policy_len = meta.policy_length as usize;
222
223 let policy = match policy_len == 0 {
224 true => None,
225 false => {
226 let end = start
227 .checked_add(policy_len)
228 .assume("start + policy_len mustn't overflow")?;
229 let policy = &remaining[start..end];
230 start = end;
231 Some(policy)
232 }
233 };
234
235 let len = meta.length as usize;
236 let end = start
237 .checked_add(len)
238 .assume("start + len mustn't overflow")?;
239 let payload = &remaining[start..end];
240 start = end;
241
242 let command = SyncCommand {
243 id: meta.id,
244 priority: meta.priority,
245 parent: meta.parent,
246 policy,
247 data: payload,
248 max_cut: meta.max_cut,
249 };
250
251 result
252 .push(command)
253 .ok()
254 .assume("commands is not larger than result")?;
255 }
256
257 Some(result)
258 }
259
260 SyncResponseMessage::SyncEnd { max_index, .. } => {
261 if !matches!(
262 self.state,
263 SyncRequesterState::Start | SyncRequesterState::Waiting
264 ) {
265 return Err(SyncError::SessionState);
266 }
267
268 if max_index != self.next_message_index {
269 self.state = SyncRequesterState::Resync;
270 return Err(SyncError::MissingSyncResponse);
271 }
272
273 self.state = SyncRequesterState::PartialSync;
274
275 None
276 }
277
278 SyncResponseMessage::Offer { .. } => {
279 if self.state != SyncRequesterState::Idle {
280 return Err(SyncError::SessionState);
281 }
282 self.state = SyncRequesterState::Resync;
283
284 None
285 }
286
287 SyncResponseMessage::EndSession { .. } => {
288 self.state = SyncRequesterState::Closed;
289 None
290 }
291 };
292
293 Ok(result)
294 }
295
296 fn write(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
297 Ok(postcard::to_slice(&msg, target)?.len())
298 }
299
300 fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
301 Ok((
302 Self::write(
303 target,
304 SyncType::Poll {
305 request: SyncRequestMessage::EndSession {
306 session_id: self.session_id,
307 },
308 },
309 )?,
310 0,
311 ))
312 }
313
314 fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
315 if !matches!(
316 self.state,
317 SyncRequesterState::Resync | SyncRequesterState::Idle
318 ) {
319 return Err(SyncError::SessionState);
320 }
321
322 self.state = SyncRequesterState::Waiting;
323 let message = SyncType::Poll {
324 request: SyncRequestMessage::SyncResume {
325 session_id: self.session_id,
326 response_index: self
327 .next_message_index
328 .checked_sub(1)
329 .assume("next_index must be positive")?,
330 max_bytes,
331 },
332 };
333
334 Ok((Self::write(target, message)?, 0))
335 }
336
337 fn get_commands(
338 &self,
339 provider: &mut impl StorageProvider,
340 peer_cache: &mut PeerCache,
341 ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
342 let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
343
344 match provider.get_storage(self.graph_id) {
345 Err(StorageError::NoSuchStorage) => (),
346 Err(err) => {
347 return Err(SyncError::Storage(err));
348 }
349 Ok(storage) => {
350 let mut cache_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
351 for address in peer_cache.heads() {
352 let loc = storage
353 .get_location(*address)?
354 .assume("location must exist")?;
355 cache_locations
356 .push(loc)
357 .ok()
358 .assume("command locations should not be full")?;
359 if commands.len() < COMMAND_SAMPLE_MAX {
360 commands
361 .push(*address)
362 .map_err(|_| SyncError::CommandOverflow)?;
363 }
364 }
365 let head = storage.get_head()?;
367 let mut current = vec![head];
368
369 while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
374 let mut next = vec::Vec::new(); 'current: for &location in ¤t {
377 let segment = storage.get_segment(location)?;
378
379 let head = segment.head()?;
380 let head_address = head.address()?;
381
382 for &peer_cache_loc in &cache_locations {
384 if location == peer_cache_loc {
386 continue 'current;
387 }
388 let peer_cache_segment = storage.get_segment(peer_cache_loc)?;
391 if (peer_cache_loc.same_segment(location)
392 && location.command <= peer_cache_loc.command)
393 || storage.is_ancestor(location, &peer_cache_segment)?
394 {
395 continue 'current;
396 }
397 }
398
399 commands
400 .push(head_address)
401 .map_err(|_| SyncError::CommandOverflow)?;
402 next.extend(segment.prior());
403 if commands.len() >= COMMAND_SAMPLE_MAX {
404 break 'current;
405 }
406 }
407
408 current.clone_from(&next);
409 }
410 }
411 }
412
413 Ok(commands)
414 }
415
416 pub fn subscribe(
418 &mut self,
419 target: &mut [u8],
420 provider: &mut impl StorageProvider,
421 heads: &mut PeerCache,
422 remain_open: u64,
423 max_bytes: u64,
424 ) -> Result<usize, SyncError> {
425 let commands = self.get_commands(provider, heads)?;
426 let message = SyncType::Subscribe {
427 remain_open,
428 max_bytes,
429 commands,
430 graph_id: self.graph_id,
431 };
432
433 Self::write(target, message)
434 }
435
436 pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
438 let message = SyncType::Unsubscribe {
439 graph_id: self.graph_id,
440 };
441
442 Self::write(target, message)
443 }
444
445 fn start(
446 &mut self,
447 max_bytes: u64,
448 target: &mut [u8],
449 provider: &mut impl StorageProvider,
450 heads: &mut PeerCache,
451 ) -> Result<(usize, usize), SyncError> {
452 if !matches!(
453 self.state,
454 SyncRequesterState::Start | SyncRequesterState::New
455 ) {
456 self.state = SyncRequesterState::Reset;
457 return Err(SyncError::SessionState);
458 }
459
460 self.state = SyncRequesterState::Start;
461 self.max_bytes = max_bytes;
462
463 let command_sample = self.get_commands(provider, heads)?;
464
465 let sent = command_sample.len();
466 let message = SyncType::Poll {
467 request: SyncRequestMessage::SyncRequest {
468 session_id: self.session_id,
469 graph_id: self.graph_id,
470 max_bytes,
471 commands: command_sample,
472 },
473 };
474
475 Ok((Self::write(target, message)?, sent))
476 }
477}