1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10 SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13 StorageError, SyncType,
14 command::{Address, CmdId, Command as _},
15 storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
/// Tracks the graph heads a peer is known to hold, so a responder can avoid
/// re-sending commands the peer already has.
#[derive(Default, Debug)]
pub struct PeerCache {
    // Bounded set of known head addresses for the peer (at most PEER_HEAD_MAX).
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}
22
impl PeerCache {
    /// Creates an empty cache with no known heads.
    pub const fn new() -> Self {
        Self { heads: Vec::new() }
    }

    /// Returns the addresses currently cached as the peer's heads.
    pub fn heads(&self) -> &[Address] {
        &self.heads
    }

    /// Records that the peer holds `command` (stored at `cmd_loc` in
    /// `storage`), updating the cached head set.
    ///
    /// Existing heads that are ancestors of the new command are evicted, and
    /// the command itself is only added when no surviving head already covers
    /// it (same command id, a same-segment position at or past it, or a
    /// descendant elsewhere in the graph).
    pub fn add_command<S>(
        &mut self,
        storage: &S,
        command: Address,
        cmd_loc: Location,
    ) -> Result<(), StorageError>
    where
        S: Storage,
    {
        let mut add_command = true;

        // Decides whether an existing head should be kept. As a side effect it
        // clears `add_command` when the new command is already covered by that
        // head.
        let mut retain_head = |request_head: &Address, new_head: Location| {
            let new_head_seg = storage.get_segment(new_head)?;
            let req_head_loc = storage
                .get_location(*request_head)?
                .assume("location must exist")?;
            let req_head_seg = storage.get_segment(req_head_loc)?;
            // The new command is the same command as an existing head.
            if request_head.id
                == new_head_seg
                    .get_command(new_head)
                    .assume("location must exist")?
                    .address()?
                    .id
            {
                add_command = false;
            }
            // The new command sits at or behind the head within the same
            // segment, or is an ancestor of the head's segment.
            if (new_head.same_segment(req_head_loc) && new_head.command <= req_head_loc.command)
                || storage.is_ancestor(new_head, &req_head_seg)?
            {
                add_command = false;
            }
            // Keep the head only when it is NOT an ancestor of the new command.
            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
        };
        self.heads
            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
        if add_command && !self.heads.is_full() {
            self.heads
                .push(command)
                .ok()
                .assume("command locations should not be full")?;
        }

        Ok(())
    }
}
78
/// Messages a sync responder sends back to the requesting peer.
// Serialized with postcard (see `SyncResponder::write`), which encodes enums
// by variant index — do not reorder variants.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// A batch of command metadata. The raw command bytes are appended
    /// directly after the serialized message in the same buffer.
    SyncResponse {
        /// Identifies the sync session this response belongs to.
        session_id: u128,
        /// Monotonic index of this response within the session.
        response_index: u64,
        /// Metadata describing each command in the appended byte stream.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// Signals that the responder has sent everything it intends to.
    SyncEnd {
        session_id: u128,
        /// Index of the last response message in the session.
        max_index: u64,
        /// True when further data exists beyond what was sent.
        remaining: bool,
    },

    /// Advertises the responder's current head to the peer.
    Offer {
        session_id: u128,
        /// Id of the responder's head command.
        head: CmdId,
    },

    /// Terminates the session.
    EndSession { session_id: u128 },
}
130
131impl SyncResponseMessage {
132 pub fn session_id(&self) -> u128 {
133 match self {
134 Self::SyncResponse { session_id, .. } => *session_id,
135 Self::SyncEnd { session_id, .. } => *session_id,
136 Self::Offer { session_id, .. } => *session_id,
137 Self::EndSession { session_id, .. } => *session_id,
138 }
139 }
140}
141
/// Lifecycle states of a [`SyncResponder`].
#[derive(Debug, Default)]
enum SyncResponderState {
    /// No request received yet; `poll` returns `NotReady`.
    #[default]
    New,
    /// A sync request arrived; the next `poll` computes what to send.
    Start,
    /// Actively streaming response messages.
    Send,
    /// Finished sending; `poll` returns `NotReady` until a new request.
    Idle,
    /// A failure occurred; the next `poll` emits `EndSession` and stops.
    Reset,
    /// The session ended; `poll` returns `NotReady`.
    Stopped,
}
152
/// Server side of the sync protocol: answers a requester's command sample
/// with the segments the requester appears to be missing.
#[derive(Default)]
pub struct SyncResponder<A> {
    // Session id adopted from the first request received.
    session_id: Option<u128>,
    // Graph being synced; set by the sync request.
    storage_id: Option<GraphId>,
    // Current position in the responder state machine.
    state: SyncResponderState,
    // Set from the request's `max_bytes`.
    // NOTE(review): written in `receive` but never read in this file —
    // confirm whether it is used elsewhere or dead.
    bytes_sent: u64,
    // Index into `to_send` of the next segment location to transmit.
    next_send: usize,
    // Index assigned to the next outgoing response message.
    message_index: usize,
    // Command sample received from the requester.
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    // Segment locations still to be sent, sorted ascending.
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
    // Address included in pushed messages so the peer can reply.
    server_address: A,
}
165
166impl<A: Serialize + Clone> SyncResponder<A> {
167 pub fn new(server_address: A) -> Self {
169 Self {
170 session_id: None,
171 storage_id: None,
172 state: SyncResponderState::New,
173 bytes_sent: 0,
174 next_send: 0,
175 message_index: 0,
176 has: Vec::new(),
177 to_send: Vec::new(),
178 server_address,
179 }
180 }
181
182 pub fn ready(&self) -> bool {
184 use SyncResponderState::*;
185 match self.state {
186 Reset | Start | Send => true, New | Idle | Stopped => false,
188 }
189 }
190
    /// Writes the next sync message into `target`, returning the number of
    /// bytes written.
    ///
    /// Behavior depends on the responder's state:
    /// - `New`/`Idle`/`Stopped`: returns [`SyncError::NotReady`].
    /// - `Start`: seeds `response_cache` from the peer's command sample,
    ///   computes the segments to send, then emits the first response.
    /// - `Send`: emits the next response message.
    /// - `Reset`: emits an `EndSession` message and transitions to `Stopped`.
    pub fn poll(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
        response_cache: &mut PeerCache,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let length = match self.state {
            S::New | S::Idle | S::Stopped => {
                return Err(SyncError::NotReady);
            }
            S::Start => {
                let Some(storage_id) = self.storage_id else {
                    self.state = S::Reset;
                    bug!("poll called before storage_id was set");
                };

                let storage = match provider.get_storage(storage_id) {
                    Ok(s) => s,
                    Err(e) => {
                        // Storage lookup failed: end the session on next poll.
                        self.state = S::Reset;
                        return Err(e.into());
                    }
                };

                self.state = S::Send;
                // Record each sampled command the peer claims to have so the
                // cache reflects the peer's known heads.
                for command in &self.has {
                    if let Some(cmd_loc) = storage.get_location(*command)? {
                        response_cache.add_command(storage, *command, cmd_loc)?;
                    }
                }
                self.to_send = Self::find_needed_segments(&self.has, storage)?;

                self.get_next(target, provider)?
            }
            S::Send => self.get_next(target, provider)?,
            S::Reset => {
                // Tell the peer the session is over, then stop for good.
                self.state = S::Stopped;
                let message = SyncResponseMessage::EndSession {
                    session_id: self.session_id()?,
                };
                Self::write(target, message)?
            }
        };

        Ok(length)
    }
242
243 pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
245 if self.session_id.is_none() {
246 self.session_id = Some(message.session_id());
247 }
248 if self.session_id != Some(message.session_id()) {
249 return Err(SyncError::SessionMismatch);
250 }
251
252 match message {
253 SyncRequestMessage::SyncRequest {
254 storage_id,
255 max_bytes,
256 commands,
257 ..
258 } => {
259 self.state = SyncResponderState::Start;
260 self.storage_id = Some(storage_id);
261 self.bytes_sent = max_bytes;
262 self.to_send = Vec::new();
263 self.has = commands;
264 self.next_send = 0;
265 return Ok(());
266 }
267 SyncRequestMessage::RequestMissing { .. } => {
268 todo!()
269 }
270 SyncRequestMessage::SyncResume { .. } => {
271 todo!()
272 }
273 SyncRequestMessage::EndSession { .. } => {
274 self.state = SyncResponderState::Stopped;
275 }
276 }
277
278 Ok(())
279 }
280
281 fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
282 Ok(postcard::to_slice(&msg, target)?.len())
283 }
284
285 fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
286 Ok(postcard::to_slice(&msg, target)?.len())
287 }
288
    /// Computes which segment locations must be sent so the peer can catch
    /// up, given the sample of `commands` the peer claims to have.
    ///
    /// Walks backward from the graph head, skipping anything that is an
    /// ancestor of a command the peer already holds, and returns at most
    /// `SEGMENT_BUFFER_MAX` locations sorted in ascending order.
    fn find_needed_segments(
        commands: &[Address],
        storage: &impl Storage,
    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
        // Resolve the peer's sampled commands to locations we actually store;
        // unknown addresses are simply ignored.
        let mut have_locations = vec::Vec::new();
        for &addr in commands {
            if let Some(location) = storage.get_location(addr)? {
                have_locations.push(location);
            }
        }

        // Drop any location that is an ancestor of another sampled location:
        // only the latest points the peer holds matter. Iterating in reverse
        // keeps the remaining indices valid across `remove`.
        for i in (0..have_locations.len()).rev() {
            let location_a = have_locations[i];
            let mut is_ancestor_of_other = false;
            for &location_b in &have_locations {
                if location_a != location_b {
                    let segment_b = storage.get_segment(location_b)?;
                    if location_a.same_segment(location_b)
                        && location_a.command <= location_b.command
                        || storage.is_ancestor(location_a, &segment_b)?
                    {
                        is_ancestor_of_other = true;
                        break;
                    }
                }
            }
            if is_ancestor_of_other {
                have_locations.remove(i);
            }
        }

        // Breadth-first walk backward from the graph head.
        let mut heads = vec::Vec::new();
        heads.push(storage.get_head()?);

        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

        while !heads.is_empty() {
            let current = mem::take(&mut heads);
            'heads: for head in current {
                let segment = storage.get_segment(head)?;
                // Something from this segment is already scheduled.
                if segment.contains_any(&result) {
                    continue 'heads;
                }

                // If this head is an ancestor of a location the peer holds,
                // the peer already has this segment.
                for &have_location in &have_locations {
                    let have_segment = storage.get_segment(have_location)?;
                    if storage.is_ancestor(head, &have_segment)? {
                        continue 'heads;
                    }
                }

                // The peer has part of this segment: send only the commands
                // after the latest location it holds.
                if let Some(latest_loc) = have_locations
                    .iter()
                    .filter(|&&location| segment.contains(location))
                    .max_by_key(|&&location| location.command)
                {
                    let next_command = latest_loc
                        .command
                        .checked_add(1)
                        .assume("command + 1 mustn't overflow")?;
                    let next_location = Location {
                        segment: head.segment,
                        command: next_command,
                    };

                    let head_loc = segment.head_location();
                    if next_location.command > head_loc.command {
                        // The peer already holds the entire segment.
                        continue 'heads;
                    }
                    if result.is_full() {
                        // Bounded buffer: evict the oldest entry to make room.
                        result.pop_back();
                    }
                    result
                        .push_front(next_location)
                        .ok()
                        .assume("too many segments")?;
                    continue 'heads;
                }

                // The peer has none of this segment: send it from the start
                // and keep walking into its parent segments.
                heads.extend(segment.prior());

                if result.is_full() {
                    result.pop_back();
                }

                let location = segment.first_location();
                result
                    .push_front(location)
                    .ok()
                    .assume("too many segments")?;
            }
        }
        // Copy into a bounded Vec and sort so earlier locations come first.
        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
        for l in result {
            r.push(l).ok().assume("too many segments")?;
        }
        r.sort();
        Ok(r)
    }
404
    /// Builds the next `SyncResponse` (or a final `SyncEnd` when nothing
    /// remains) into `target`, returning the number of bytes written.
    fn get_next(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        // Nothing left to send: report completion and go idle.
        if self.next_send >= self.to_send.len() {
            self.state = SyncResponderState::Idle;
            let message = SyncResponseMessage::SyncEnd {
                session_id: self.session_id()?,
                max_index: self.message_index as u64,
                remaining: false,
            };
            let length = Self::write(target, message)?;
            return Ok(length);
        }

        let (commands, command_data, next_send) = self.get_commands(provider)?;

        let message = SyncResponseMessage::SyncResponse {
            session_id: self.session_id()?,
            response_index: self.message_index as u64,
            commands,
        };
        self.message_index = self
            .message_index
            .checked_add(1)
            .assume("message_index overflow")?;
        self.next_send = next_send;

        // The serialized message is followed directly by the raw command
        // bytes; the per-command lengths in the metadata let the receiver
        // split them apart.
        let length = Self::write(target, message)?;
        let total_length = length
            .checked_add(command_data.len())
            .assume("length + command_data_length mustn't overflow")?;
        target
            .get_mut(length..total_length)
            .assume("sync message fits in target")?
            .copy_from_slice(&command_data);
        Ok(total_length)
    }
444
    /// Proactively pushes commands the peer is missing, wrapped in a
    /// `SyncType::Push` envelope, returning the number of bytes written to
    /// `target` (zero when the peer is already up to date).
    pub fn push(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let Some(storage_id) = self.storage_id else {
            self.state = S::Reset;
            bug!("poll called before storage_id was set");
        };

        let storage = match provider.get_storage(storage_id) {
            Ok(s) => s,
            Err(e) => {
                // Storage lookup failed: end the session on next poll.
                self.state = S::Reset;
                return Err(e.into());
            }
        };
        // Recompute what the peer is missing from the last known sample.
        self.to_send = Self::find_needed_segments(&self.has, storage)?;
        let (commands, command_data, next_send) = self.get_commands(provider)?;
        let mut length = 0;
        if !commands.is_empty() {
            let message = SyncType::Push {
                message: SyncResponseMessage::SyncResponse {
                    session_id: self.session_id()?,
                    response_index: self.message_index as u64,
                    commands,
                },
                storage_id: self.storage_id.assume("storage id must exist")?,
                // Lets the peer know where to send its own requests.
                address: self.server_address.clone(),
            };
            self.message_index = self
                .message_index
                .checked_add(1)
                .assume("message_index increment overflow")?;
            self.next_send = next_send;

            // Envelope first, then the raw command bytes appended after it.
            length = Self::write_sync_type(target, message)?;
            let total_length = length
                .checked_add(command_data.len())
                .assume("length + command_data_length mustn't overflow")?;
            target
                .get_mut(length..total_length)
                .assume("sync message fits in target")?
                .copy_from_slice(&command_data);
            length = total_length;
        }
        Ok(length)
    }
496
    /// Serializes commands from the pending segments into a metadata list
    /// plus a flat byte buffer, returning `(metadata, bytes, next_send)`.
    ///
    /// Stops once `COMMAND_RESPONSE_MAX` metadata entries have been
    /// collected; the returned index tells the caller where to resume in
    /// `to_send`.
    fn get_commands(
        &mut self,
        provider: &mut impl StorageProvider,
    ) -> Result<
        (
            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
            usize,
        ),
        SyncError,
    > {
        let Some(storage_id) = self.storage_id.as_ref() else {
            self.state = SyncResponderState::Reset;
            bug!("get_next called before storage_id was set");
        };
        let storage = match provider.get_storage(*storage_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = SyncResponderState::Reset;
                return Err(e.into());
            }
        };
        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
        let mut index = self.next_send;
        for i in self.next_send..self.to_send.len() {
            if commands.is_full() {
                break;
            }
            // Count this segment as consumed before reading it.
            // NOTE(review): if the inner loop fills `commands` mid-segment,
            // the rest of that segment's commands are skipped on resume —
            // confirm this is intended.
            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
            let Some(&location) = self.to_send.get(i) else {
                self.state = SyncResponderState::Reset;
                bug!("send index OOB");
            };

            let segment = storage
                .get_segment(location)
                .inspect_err(|_| self.state = SyncResponderState::Reset)?;

            // Every command in the segment from `location` onward.
            let found = segment.get_from(location);

            for command in &found {
                let mut policy_length = 0;

                // Policy bytes (when present) are written before the command
                // bytes; `policy_length` lets the receiver split them.
                if let Some(policy) = command.policy() {
                    policy_length = policy.len();
                    command_data
                        .extend_from_slice(policy)
                        .ok()
                        .assume("command_data is too large")?;
                }

                let bytes = command.bytes();
                command_data
                    .extend_from_slice(bytes)
                    .ok()
                    .assume("command_data is too large")?;

                let max_cut = command.max_cut()?;
                // NOTE(review): `as u32` truncates silently if a policy or
                // command exceeds u32::MAX bytes — presumably bounded by
                // MAX_SYNC_MESSAGE_SIZE; confirm.
                let meta = CommandMeta {
                    id: command.id(),
                    priority: command.priority(),
                    parent: command.parent(),
                    policy_length: policy_length as u32,
                    length: bytes.len() as u32,
                    max_cut,
                };

                commands
                    .push(meta)
                    .ok()
                    .assume("too many commands in segment")?;
                if commands.is_full() {
                    break;
                }
            }
        }
        Ok((commands, command_data, index))
    }
577
578 fn session_id(&self) -> Result<u128, SyncError> {
579 Ok(self.session_id.assume("session id is set")?)
580 }
581}