1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10 SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13 StorageError, SyncType,
14 command::{Address, CmdId, Command as _},
15 storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
/// Bounded cache of command addresses recorded as a peer's heads.
///
/// The responder feeds it the addresses from a peer's sync request (see
/// `SyncResponder::poll`), and `PeerCache::add_command` prunes entries that
/// become redundant. At most `PEER_HEAD_MAX` addresses are kept.
#[derive(Default, Debug)]
pub struct PeerCache {
    /// Recorded head addresses; capacity is `PEER_HEAD_MAX`.
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}
22
23impl PeerCache {
24 pub const fn new() -> Self {
25 Self { heads: Vec::new() }
26 }
27
28 pub fn heads(&self) -> &[Address] {
29 &self.heads
30 }
31
32 pub fn add_command<S>(
33 &mut self,
34 storage: &S,
35 command: Address,
36 cmd_loc: Location,
37 ) -> Result<(), StorageError>
38 where
39 S: Storage,
40 {
41 let mut add_command = true;
42
43 let mut retain_head = |request_head: &Address, new_head: Location| {
44 let new_head_seg = storage.get_segment(new_head)?;
45 let req_head_loc = storage
46 .get_location(*request_head)?
47 .assume("location must exist")?;
48 let req_head_seg = storage.get_segment(req_head_loc)?;
49 if request_head.id
50 == new_head_seg
51 .get_command(new_head)
52 .assume("location must exist")?
53 .address()?
54 .id
55 {
56 add_command = false;
57 }
58 if (new_head.same_segment(req_head_loc) && new_head.command <= req_head_loc.command)
60 || storage.is_ancestor(new_head, &req_head_seg)?
61 {
62 add_command = false;
63 }
64 Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
65 };
66 self.heads
67 .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
68 if add_command && !self.heads.is_full() {
69 self.heads
70 .push(command)
71 .ok()
72 .assume("command locations should not be full")?;
73 }
74
75 Ok(())
76 }
77}
78
/// Messages a sync responder sends back to the requesting peer.
///
/// Messages are encoded with `postcard` (see `SyncResponder::write`).
/// postcard identifies variants by index and fields by position, so the
/// order of variants and fields is part of the wire format.
#[derive(Serialize, Deserialize, Debug)]
// NOTE(review): no reason is recorded for this allow; presumably the inline
// `commands` buffer in `SyncResponse` is intentionally left unboxed — confirm.
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// A batch of command metadata.
    ///
    /// When produced by `SyncResponder`, the raw bytes for each listed command
    /// (its policy bytes, if any, followed by the command bytes) are written
    /// immediately after this serialized message.
    SyncResponse {
        /// Session this response belongs to.
        session_id: u128,
        /// The responder's message counter at the time this response was
        /// built; the counter is incremented after each `SyncResponse`.
        response_index: u64,
        /// Per-command metadata: id, priority, parent, byte lengths and max cut.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// Sent once the responder has nothing more queued for this request.
    SyncEnd {
        /// Session this message belongs to.
        session_id: u128,
        /// The responder's message counter when the sync ended, i.e. the
        /// number of `SyncResponse` messages it has produced so far.
        max_index: u64,
        /// Always `false` where this file builds the message; presumably
        /// indicates further data is available — confirm.
        remaining: bool,
    },

    /// Not constructed in the code shown here; presumably advertises the
    /// sender's head command — confirm.
    Offer {
        /// Session this message belongs to.
        session_id: u128,
        /// Command id being offered.
        head: CmdId,
    },

    /// Terminates the session; sent when the responder is in its reset state.
    EndSession { session_id: u128 },
}
130
131impl SyncResponseMessage {
132 pub fn session_id(&self) -> u128 {
133 match self {
134 Self::SyncResponse { session_id, .. } => *session_id,
135 Self::SyncEnd { session_id, .. } => *session_id,
136 Self::Offer { session_id, .. } => *session_id,
137 Self::EndSession { session_id, .. } => *session_id,
138 }
139 }
140}
141
/// Lifecycle state of a `SyncResponder`.
#[derive(Debug, Default)]
enum SyncResponderState {
    /// Freshly created; `poll` returns `SyncError::NotReady`.
    #[default]
    New,
    /// A `SyncRequest` was received; the next `poll` works out what to send.
    Start,
    /// Batches are being sent; each `poll` emits the next one.
    Send,
    /// Everything queued was sent and `SyncEnd` was emitted; `poll` returns
    /// `SyncError::NotReady`.
    Idle,
    /// Entered after a failure; the next `poll` emits `EndSession` and moves
    /// to `Stopped`.
    Reset,
    /// The session is over (after sending or receiving `EndSession`); `poll`
    /// returns `SyncError::NotReady`.
    Stopped,
}
152
/// Responder side of a sync session.
///
/// It answers a peer's `SyncRequestMessage`s with the commands the peer
/// appears to be missing.
#[derive(Default)]
pub struct SyncResponder {
    /// Session id, fixed by the first message passed to `receive`.
    session_id: Option<u128>,
    /// Graph being synced; set by a `SyncRequest`.
    graph_id: Option<GraphId>,
    /// Current lifecycle state.
    state: SyncResponderState,
    /// NOTE(review): despite the name, `receive` stores the request's
    /// `max_bytes` here, and nothing in the code shown reads it — confirm intent.
    bytes_sent: u64,
    /// Index into `to_send` of the next segment location to send from.
    next_send: usize,
    /// Counter used as `response_index`; incremented after each `SyncResponse`.
    message_index: usize,
    /// Command addresses the requester reported (the request's `commands`).
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    /// Locations to send from, computed by `find_needed_segments`.
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
}
164
165impl SyncResponder {
166 pub fn new() -> Self {
168 Self {
169 session_id: None,
170 graph_id: None,
171 state: SyncResponderState::New,
172 bytes_sent: 0,
173 next_send: 0,
174 message_index: 0,
175 has: Vec::new(),
176 to_send: Vec::new(),
177 }
178 }
179
180 pub fn ready(&self) -> bool {
182 use SyncResponderState::*;
183 match self.state {
184 Reset | Start | Send => true, New | Idle | Stopped => false,
186 }
187 }
188
189 pub fn poll(
192 &mut self,
193 target: &mut [u8],
194 provider: &mut impl StorageProvider,
195 response_cache: &mut PeerCache,
196 ) -> Result<usize, SyncError> {
197 use SyncResponderState as S;
199 let length = match self.state {
200 S::New | S::Idle | S::Stopped => {
201 return Err(SyncError::NotReady); }
203 S::Start => {
204 let Some(graph_id) = self.graph_id else {
205 self.state = S::Reset;
206 bug!("poll called before graph_id was set");
207 };
208
209 let storage = match provider.get_storage(graph_id) {
210 Ok(s) => s,
211 Err(e) => {
212 self.state = S::Reset;
213 return Err(e.into());
214 }
215 };
216
217 self.state = S::Send;
218 for command in &self.has {
219 if let Some(cmd_loc) = storage.get_location(*command)? {
221 response_cache.add_command(storage, *command, cmd_loc)?;
222 }
223 }
224 self.to_send = Self::find_needed_segments(&self.has, storage)?;
225
226 self.get_next(target, provider)?
227 }
228 S::Send => self.get_next(target, provider)?,
229 S::Reset => {
230 self.state = S::Stopped;
231 let message = SyncResponseMessage::EndSession {
232 session_id: self.session_id()?,
233 };
234 Self::write(target, message)?
235 }
236 };
237
238 Ok(length)
239 }
240
241 pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
243 if self.session_id.is_none() {
244 self.session_id = Some(message.session_id());
245 }
246 if self.session_id != Some(message.session_id()) {
247 return Err(SyncError::SessionMismatch);
248 }
249
250 match message {
251 SyncRequestMessage::SyncRequest {
252 graph_id,
253 max_bytes,
254 commands,
255 ..
256 } => {
257 self.state = SyncResponderState::Start;
258 self.graph_id = Some(graph_id);
259 self.bytes_sent = max_bytes;
260 self.to_send = Vec::new();
261 self.has = commands;
262 self.next_send = 0;
263 return Ok(());
264 }
265 SyncRequestMessage::RequestMissing { .. } => {
266 todo!()
267 }
268 SyncRequestMessage::SyncResume { .. } => {
269 todo!()
270 }
271 SyncRequestMessage::EndSession { .. } => {
272 self.state = SyncResponderState::Stopped;
273 }
274 }
275
276 Ok(())
277 }
278
279 fn write_sync_type(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
280 Ok(postcard::to_slice(&msg, target)?.len())
281 }
282
283 fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
284 Ok(postcard::to_slice(&msg, target)?.len())
285 }
286
287 fn find_needed_segments(
291 commands: &[Address],
292 storage: &impl Storage,
293 ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
294 let mut have_locations = vec::Vec::new(); for &addr in commands {
296 if let Some(location) = storage.get_location(addr)? {
299 have_locations.push(location);
300 }
301 }
302
303 for i in (0..have_locations.len()).rev() {
308 let location_a = have_locations[i];
309 let mut is_ancestor_of_other = false;
310 for &location_b in &have_locations {
311 if location_a != location_b {
312 let segment_b = storage.get_segment(location_b)?;
313 if location_a.same_segment(location_b)
314 && location_a.command <= location_b.command
315 || storage.is_ancestor(location_a, &segment_b)?
316 {
317 is_ancestor_of_other = true;
318 break;
319 }
320 }
321 }
322 if is_ancestor_of_other {
323 have_locations.remove(i);
324 }
325 }
326
327 let mut heads = vec::Vec::new();
328 heads.push(storage.get_head()?);
329
330 let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
331
332 while !heads.is_empty() {
333 let current = mem::take(&mut heads);
334 'heads: for head in current {
335 let segment = storage.get_segment(head)?;
336 if segment.contains_any(&result) {
338 continue 'heads;
339 }
340
341 for &have_location in &have_locations {
345 let have_segment = storage.get_segment(have_location)?;
346 if storage.is_ancestor(head, &have_segment)? {
347 continue 'heads;
348 }
349 }
350
351 if let Some(latest_loc) = have_locations
353 .iter()
354 .filter(|&&location| segment.contains(location))
355 .max_by_key(|&&location| location.command)
356 {
357 let next_command = latest_loc
358 .command
359 .checked_add(1)
360 .assume("command + 1 mustn't overflow")?;
361 let next_location = Location {
362 segment: head.segment,
363 command: next_command,
364 };
365
366 let head_loc = segment.head_location();
367 if next_location.command > head_loc.command {
368 continue 'heads;
369 }
370 if result.is_full() {
371 result.pop_back();
372 }
373 result
374 .push_front(next_location)
375 .ok()
376 .assume("too many segments")?;
377 continue 'heads;
378 }
379
380 heads.extend(segment.prior());
381
382 if result.is_full() {
383 result.pop_back();
384 }
385
386 let location = segment.first_location();
387 result
388 .push_front(location)
389 .ok()
390 .assume("too many segments")?;
391 }
392 }
393 let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
394 for l in result {
395 r.push(l).ok().assume("too many segments")?;
396 }
397 r.sort();
400 Ok(r)
401 }
402
403 fn get_next(
404 &mut self,
405 target: &mut [u8],
406 provider: &mut impl StorageProvider,
407 ) -> Result<usize, SyncError> {
408 if self.next_send >= self.to_send.len() {
409 self.state = SyncResponderState::Idle;
410 let message = SyncResponseMessage::SyncEnd {
411 session_id: self.session_id()?,
412 max_index: self.message_index as u64,
413 remaining: false,
414 };
415 let length = Self::write(target, message)?;
416 return Ok(length);
417 }
418
419 let (commands, command_data, next_send) = self.get_commands(provider)?;
420
421 let message = SyncResponseMessage::SyncResponse {
422 session_id: self.session_id()?,
423 response_index: self.message_index as u64,
424 commands,
425 };
426 self.message_index = self
427 .message_index
428 .checked_add(1)
429 .assume("message_index overflow")?;
430 self.next_send = next_send;
431
432 let length = Self::write(target, message)?;
433 let total_length = length
434 .checked_add(command_data.len())
435 .assume("length + command_data_length mustn't overflow")?;
436 target
437 .get_mut(length..total_length)
438 .assume("sync message fits in target")?
439 .copy_from_slice(&command_data);
440 Ok(total_length)
441 }
442
443 pub fn push(
446 &mut self,
447 target: &mut [u8],
448 provider: &mut impl StorageProvider,
449 ) -> Result<usize, SyncError> {
450 use SyncResponderState as S;
451 let Some(graph_id) = self.graph_id else {
452 self.state = S::Reset;
453 bug!("poll called before graph_id was set");
454 };
455
456 let storage = match provider.get_storage(graph_id) {
457 Ok(s) => s,
458 Err(e) => {
459 self.state = S::Reset;
460 return Err(e.into());
461 }
462 };
463 self.to_send = Self::find_needed_segments(&self.has, storage)?;
464 let (commands, command_data, next_send) = self.get_commands(provider)?;
465 let mut length = 0;
466 if !commands.is_empty() {
467 let message = SyncType::Push {
468 message: SyncResponseMessage::SyncResponse {
469 session_id: self.session_id()?,
470 response_index: self.message_index as u64,
471 commands,
472 },
473 graph_id: self.graph_id.assume("graph id must exist")?,
474 };
475 self.message_index = self
476 .message_index
477 .checked_add(1)
478 .assume("message_index increment overflow")?;
479 self.next_send = next_send;
480
481 length = Self::write_sync_type(target, message)?;
482 let total_length = length
483 .checked_add(command_data.len())
484 .assume("length + command_data_length mustn't overflow")?;
485 target
486 .get_mut(length..total_length)
487 .assume("sync message fits in target")?
488 .copy_from_slice(&command_data);
489 length = total_length;
490 }
491 Ok(length)
492 }
493
494 fn get_commands(
495 &mut self,
496 provider: &mut impl StorageProvider,
497 ) -> Result<
498 (
499 Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
500 Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
501 usize,
502 ),
503 SyncError,
504 > {
505 let Some(graph_id) = self.graph_id.as_ref() else {
506 self.state = SyncResponderState::Reset;
507 bug!("get_next called before graph_id was set");
508 };
509 let storage = match provider.get_storage(*graph_id) {
510 Ok(s) => s,
511 Err(e) => {
512 self.state = SyncResponderState::Reset;
513 return Err(e.into());
514 }
515 };
516 let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
517 let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
518 let mut index = self.next_send;
519 for i in self.next_send..self.to_send.len() {
520 if commands.is_full() {
521 break;
522 }
523 index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
524 let Some(&location) = self.to_send.get(i) else {
525 self.state = SyncResponderState::Reset;
526 bug!("send index OOB");
527 };
528
529 let segment = storage
530 .get_segment(location)
531 .inspect_err(|_| self.state = SyncResponderState::Reset)?;
532
533 let found = segment.get_from(location);
534
535 for command in &found {
536 let mut policy_length = 0;
537
538 if let Some(policy) = command.policy() {
539 policy_length = policy.len();
540 command_data
541 .extend_from_slice(policy)
542 .ok()
543 .assume("command_data is too large")?;
544 }
545
546 let bytes = command.bytes();
547 command_data
548 .extend_from_slice(bytes)
549 .ok()
550 .assume("command_data is too large")?;
551
552 let max_cut = command.max_cut()?;
553 let meta = CommandMeta {
554 id: command.id(),
555 priority: command.priority(),
556 parent: command.parent(),
557 policy_length: policy_length as u32,
558 length: bytes.len() as u32,
559 max_cut,
560 };
561
562 commands
564 .push(meta)
565 .ok()
566 .assume("too many commands in segment")?;
567 if commands.is_full() {
568 break;
569 }
570 }
571 }
572 Ok((commands, command_data, index))
573 }
574
575 fn session_id(&self) -> Result<u128, SyncError> {
576 Ok(self.session_id.assume("session id is set")?)
577 }
578}