1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10 SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13 StorageError, SyncType,
14 command::{Address, CmdId, Command as _},
15 storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20 heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24 pub const fn new() -> Self {
25 Self { heads: Vec::new() }
26 }
27
28 pub fn heads(&self) -> &[Address] {
29 &self.heads
30 }
31
32 pub fn add_command<S>(
33 &mut self,
34 storage: &S,
35 command: Address,
36 cmd_loc: Location,
37 ) -> Result<(), StorageError>
38 where
39 S: Storage,
40 {
41 let mut add_command = true;
42
43 let mut retain_head = |request_head: &Address, new_head: Location| {
44 let new_head_seg = storage.get_segment(new_head)?;
45 let req_head_loc = storage
46 .get_location(*request_head)?
47 .assume("location must exist")?;
48 let req_head_seg = storage.get_segment(req_head_loc)?;
49 if request_head.id
50 == new_head_seg
51 .get_command(new_head)
52 .assume("location must exist")?
53 .address()?
54 .id
55 {
56 add_command = false;
57 }
58 if (new_head.same_segment(req_head_loc) && new_head.max_cut <= req_head_loc.max_cut)
60 || storage.is_ancestor(new_head, &req_head_seg)?
61 {
62 add_command = false;
63 }
64 Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
65 };
66 self.heads
67 .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
68 if add_command && !self.heads.is_full() {
69 self.heads
70 .push(command)
71 .ok()
72 .assume("command locations should not be full")?;
73 }
74
75 Ok(())
76 }
77}
78
79#[derive(Serialize, Deserialize, Debug)]
86#[allow(clippy::large_enum_variant)]
87pub enum SyncResponseMessage {
88 SyncResponse {
90 session_id: u128,
92 response_index: u64,
97 commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
99 },
100
101 SyncEnd {
104 session_id: u128,
107 max_index: u64,
109 remaining: bool,
112 },
113
114 Offer {
119 session_id: u128,
122 head: CmdId,
124 },
125
126 EndSession { session_id: u128 },
129}
130
131impl SyncResponseMessage {
132 pub fn session_id(&self) -> u128 {
133 match self {
134 Self::SyncResponse { session_id, .. } => *session_id,
135 Self::SyncEnd { session_id, .. } => *session_id,
136 Self::Offer { session_id, .. } => *session_id,
137 Self::EndSession { session_id, .. } => *session_id,
138 }
139 }
140}
141
/// Lifecycle of a [`SyncResponder`].
#[derive(Default, Debug)]
enum SyncResponderState {
    /// No request has been received yet.
    #[default]
    New,
    /// A sync request was received; the next poll computes what to send.
    Start,
    /// Responses are being produced by successive polls.
    Send,
    /// Everything queued has been sent and a `SyncEnd` was emitted.
    Idle,
    /// The session must be torn down; the next poll emits `EndSession`.
    Reset,
    /// The session is over; polling is not allowed.
    Stopped,
}
152
153#[derive(Default)]
154pub struct SyncResponder {
155 session_id: Option<u128>,
156 graph_id: Option<GraphId>,
157 state: SyncResponderState,
158 bytes_sent: u64,
159 next_send: usize,
160 message_index: usize,
161 has: Vec<Address, COMMAND_SAMPLE_MAX>,
162 to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
163}
164
165impl SyncResponder {
166 pub fn new() -> Self {
168 Self {
169 session_id: None,
170 graph_id: None,
171 state: SyncResponderState::New,
172 bytes_sent: 0,
173 next_send: 0,
174 message_index: 0,
175 has: Vec::new(),
176 to_send: Vec::new(),
177 }
178 }
179
180 pub fn ready(&self) -> bool {
182 use SyncResponderState::*;
183 match self.state {
184 Reset | Start | Send => true, New | Idle | Stopped => false,
186 }
187 }
188
189 pub fn poll(
192 &mut self,
193 target: &mut [u8],
194 provider: &mut impl StorageProvider,
195 response_cache: &mut PeerCache,
196 ) -> Result<usize, SyncError> {
197 use SyncResponderState as S;
199 let length = match self.state {
200 S::New | S::Idle | S::Stopped => {
201 return Err(SyncError::NotReady); }
203 S::Start => {
204 let Some(graph_id) = self.graph_id else {
205 self.state = S::Reset;
206 bug!("poll called before graph_id was set");
207 };
208
209 let storage = match provider.get_storage(graph_id) {
210 Ok(s) => s,
211 Err(e) => {
212 self.state = S::Reset;
213 return Err(e.into());
214 }
215 };
216
217 self.state = S::Send;
218 for command in &self.has {
219 if let Some(cmd_loc) = storage.get_location(*command)? {
221 response_cache.add_command(storage, *command, cmd_loc)?;
222 }
223 }
224 self.to_send = Self::find_needed_segments(&self.has, storage)?;
225
226 self.get_next(target, provider)?
227 }
228 S::Send => self.get_next(target, provider)?,
229 S::Reset => {
230 self.state = S::Stopped;
231 let message = SyncResponseMessage::EndSession {
232 session_id: self.session_id()?,
233 };
234 Self::write(target, message)?
235 }
236 };
237
238 Ok(length)
239 }
240
241 pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
243 if self.session_id.is_none() {
244 self.session_id = Some(message.session_id());
245 }
246 if self.session_id != Some(message.session_id()) {
247 return Err(SyncError::SessionMismatch);
248 }
249
250 match message {
251 SyncRequestMessage::SyncRequest {
252 graph_id,
253 max_bytes,
254 commands,
255 ..
256 } => {
257 self.state = SyncResponderState::Start;
258 self.graph_id = Some(graph_id);
259 self.bytes_sent = max_bytes;
260 self.to_send = Vec::new();
261 self.has = commands;
262 self.next_send = 0;
263 return Ok(());
264 }
265 SyncRequestMessage::RequestMissing { .. } => {
266 todo!()
267 }
268 SyncRequestMessage::SyncResume { .. } => {
269 todo!()
270 }
271 SyncRequestMessage::EndSession { .. } => {
272 self.state = SyncResponderState::Stopped;
273 }
274 }
275
276 Ok(())
277 }
278
279 fn write_sync_type(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
280 Ok(postcard::to_slice(&msg, target)?.len())
281 }
282
283 fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
284 Ok(postcard::to_slice(&msg, target)?.len())
285 }
286
287 fn find_needed_segments(
291 commands: &[Address],
292 storage: &impl Storage,
293 ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
294 let mut have_locations = vec::Vec::new(); for &addr in commands {
296 if let Some(location) = storage.get_location(addr)? {
299 have_locations.push(location);
300 }
301 }
302
303 for i in (0..have_locations.len()).rev() {
308 let location_a = have_locations[i];
309 let mut is_ancestor_of_other = false;
310 for &location_b in &have_locations {
311 if location_a != location_b {
312 let segment_b = storage.get_segment(location_b)?;
313 if location_a.same_segment(location_b)
314 && location_a.max_cut <= location_b.max_cut
315 || storage.is_ancestor(location_a, &segment_b)?
316 {
317 is_ancestor_of_other = true;
318 break;
319 }
320 }
321 }
322 if is_ancestor_of_other {
323 have_locations.remove(i);
324 }
325 }
326
327 let mut heads = vec::Vec::new();
328 heads.push(storage.get_head()?);
329
330 let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
331
332 while !heads.is_empty() {
333 let current = mem::take(&mut heads);
334 'heads: for head in current {
335 if result.iter().any(|loc| loc.same_segment(head)) {
337 continue 'heads;
338 }
339
340 for &have_location in &have_locations {
344 let have_segment = storage.get_segment(have_location)?;
345 if storage.is_ancestor(head, &have_segment)? {
346 continue 'heads;
347 }
348 }
349
350 let segment = storage.get_segment(head)?;
351
352 if let Some(latest_loc) = have_locations
354 .iter()
355 .filter(|loc| loc.same_segment(head))
356 .max_by_key(|&&location| location.max_cut)
357 {
358 let next_max_cut = latest_loc
359 .max_cut
360 .checked_add(1)
361 .assume("command + 1 mustn't overflow")?;
362 let next_location = Location {
363 segment: head.segment,
364 max_cut: next_max_cut,
365 };
366
367 let head_loc = segment.head_location()?;
368 if next_location.max_cut > head_loc.max_cut {
369 continue 'heads;
370 }
371 if result.is_full() {
372 result.pop_back();
373 }
374 result
375 .push_front(next_location)
376 .ok()
377 .assume("too many segments")?;
378 continue 'heads;
379 }
380
381 heads.extend(segment.prior());
382
383 if result.is_full() {
384 result.pop_back();
385 }
386
387 let location = segment.first_location();
388 result
389 .push_front(location)
390 .ok()
391 .assume("too many segments")?;
392 }
393 }
394 let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
395 for l in result {
396 r.push(l).ok().assume("too many segments")?;
397 }
398 r.sort();
401 Ok(r)
402 }
403
404 fn get_next(
405 &mut self,
406 target: &mut [u8],
407 provider: &mut impl StorageProvider,
408 ) -> Result<usize, SyncError> {
409 if self.next_send >= self.to_send.len() {
410 self.state = SyncResponderState::Idle;
411 let message = SyncResponseMessage::SyncEnd {
412 session_id: self.session_id()?,
413 max_index: self.message_index as u64,
414 remaining: false,
415 };
416 let length = Self::write(target, message)?;
417 return Ok(length);
418 }
419
420 let (commands, command_data, next_send) = self.get_commands(provider)?;
421
422 let message = SyncResponseMessage::SyncResponse {
423 session_id: self.session_id()?,
424 response_index: self.message_index as u64,
425 commands,
426 };
427 self.message_index = self
428 .message_index
429 .checked_add(1)
430 .assume("message_index overflow")?;
431 self.next_send = next_send;
432
433 let length = Self::write(target, message)?;
434 let total_length = length
435 .checked_add(command_data.len())
436 .assume("length + command_data_length mustn't overflow")?;
437 target
438 .get_mut(length..total_length)
439 .assume("sync message fits in target")?
440 .copy_from_slice(&command_data);
441 Ok(total_length)
442 }
443
444 pub fn push(
447 &mut self,
448 target: &mut [u8],
449 provider: &mut impl StorageProvider,
450 ) -> Result<usize, SyncError> {
451 use SyncResponderState as S;
452 let Some(graph_id) = self.graph_id else {
453 self.state = S::Reset;
454 bug!("poll called before graph_id was set");
455 };
456
457 let storage = match provider.get_storage(graph_id) {
458 Ok(s) => s,
459 Err(e) => {
460 self.state = S::Reset;
461 return Err(e.into());
462 }
463 };
464 self.to_send = Self::find_needed_segments(&self.has, storage)?;
465 let (commands, command_data, next_send) = self.get_commands(provider)?;
466 let mut length = 0;
467 if !commands.is_empty() {
468 let message = SyncType::Push {
469 message: SyncResponseMessage::SyncResponse {
470 session_id: self.session_id()?,
471 response_index: self.message_index as u64,
472 commands,
473 },
474 graph_id: self.graph_id.assume("graph id must exist")?,
475 };
476 self.message_index = self
477 .message_index
478 .checked_add(1)
479 .assume("message_index increment overflow")?;
480 self.next_send = next_send;
481
482 length = Self::write_sync_type(target, message)?;
483 let total_length = length
484 .checked_add(command_data.len())
485 .assume("length + command_data_length mustn't overflow")?;
486 target
487 .get_mut(length..total_length)
488 .assume("sync message fits in target")?
489 .copy_from_slice(&command_data);
490 length = total_length;
491 }
492 Ok(length)
493 }
494
495 fn get_commands(
496 &mut self,
497 provider: &mut impl StorageProvider,
498 ) -> Result<
499 (
500 Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
501 Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
502 usize,
503 ),
504 SyncError,
505 > {
506 let Some(graph_id) = self.graph_id.as_ref() else {
507 self.state = SyncResponderState::Reset;
508 bug!("get_next called before graph_id was set");
509 };
510 let storage = match provider.get_storage(*graph_id) {
511 Ok(s) => s,
512 Err(e) => {
513 self.state = SyncResponderState::Reset;
514 return Err(e.into());
515 }
516 };
517 let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
518 let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
519 let mut index = self.next_send;
520 for i in self.next_send..self.to_send.len() {
521 if commands.is_full() {
522 break;
523 }
524 index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
525 let Some(&location) = self.to_send.get(i) else {
526 self.state = SyncResponderState::Reset;
527 bug!("send index OOB");
528 };
529
530 let segment = storage
531 .get_segment(location)
532 .inspect_err(|_| self.state = SyncResponderState::Reset)?;
533
534 let found = segment.get_from(location);
535
536 for command in &found {
537 let mut policy_length = 0;
538
539 if let Some(policy) = command.policy() {
540 policy_length = policy.len();
541 command_data
542 .extend_from_slice(policy)
543 .ok()
544 .assume("command_data is too large")?;
545 }
546
547 let bytes = command.bytes();
548 command_data
549 .extend_from_slice(bytes)
550 .ok()
551 .assume("command_data is too large")?;
552
553 let max_cut = command.max_cut()?;
554 let meta = CommandMeta {
555 id: command.id(),
556 priority: command.priority(),
557 parent: command.parent(),
558 policy_length: policy_length as u32,
559 length: bytes.len() as u32,
560 max_cut,
561 };
562
563 commands
565 .push(meta)
566 .ok()
567 .assume("too many commands in segment")?;
568 if commands.is_full() {
569 break;
570 }
571 }
572 }
573 Ok((commands, command_data, index))
574 }
575
576 fn session_id(&self) -> Result<u128, SyncError> {
577 Ok(self.session_id.assume("session id is set")?)
578 }
579}