1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10 SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13 StorageError, SyncType,
14 command::{Address, CmdId, Command as _},
15 storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20 heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24 pub const fn new() -> Self {
25 Self { heads: Vec::new() }
26 }
27
28 pub fn heads(&self) -> &[Address] {
29 &self.heads
30 }
31
32 pub fn add_command<S>(
33 &mut self,
34 storage: &mut S,
35 command: Address,
36 cmd_loc: Location,
37 ) -> Result<(), StorageError>
38 where
39 S: Storage,
40 {
41 let mut add_command = true;
42 let mut retain_head = |request_head: &Address, new_head: Location| {
43 let new_head_seg = storage.get_segment(new_head)?;
44 let req_head_loc = storage
45 .get_location(*request_head)?
46 .assume("location must exist")?;
47 let req_head_seg = storage.get_segment(req_head_loc)?;
48 if let Some(new_head_command) = new_head_seg.get_command(new_head) {
49 if request_head.id == new_head_command.address()?.id {
50 add_command = false;
51 }
52 }
53 if storage.is_ancestor(new_head, &req_head_seg)? {
54 add_command = false;
55 }
56 Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
57 };
58 self.heads
59 .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
60 if add_command && !self.heads.is_full() {
61 self.heads
62 .push(command)
63 .ok()
64 .assume("command locations should not be full")?;
65 }
66 Ok(())
67 }
68}
69
/// Messages sent from the responding side of a sync session to the requester.
///
/// NOTE(review): these are serialized with `postcard` (see
/// `SyncResponder::write`). Postcard encodes enum variants by index and
/// struct fields in declaration order, so reordering variants or fields is
/// presumably a wire-format break. Confirm before changing the layout.
#[derive(Serialize, Deserialize, Debug)]
// `SyncResponse` holds an inline fixed-capacity vector of up to
// `COMMAND_RESPONSE_MAX` entries, which makes it much larger than the other
// variants. Presumably it is kept unboxed to avoid heap allocation; confirm.
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// One batch of command metadata.
    ///
    /// When written by `SyncResponder`, the raw policy and command bytes
    /// described by `commands` are placed immediately after the serialized
    /// message in the same output buffer.
    SyncResponse {
        /// Session this message belongs to.
        session_id: u128,
        /// Position of this response in the session. `SyncResponder` numbers
        /// responses consecutively from its `message_index`.
        response_index: u64,
        /// Metadata for each command, in the same order as the command bytes
        /// that follow the message.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// Marks the end of the responses for the current request.
    SyncEnd {
        /// Session this message belongs to.
        session_id: u128,
        /// `SyncResponder` sets this to its `message_index`, i.e. one past the
        /// last `response_index` it has sent (0 if none).
        max_index: u64,
        /// Whether more data is pending. `SyncResponder` currently always
        /// sends `false`.
        remaining: bool,
    },

    /// Advertises a head command. Not produced by `SyncResponder` in this
    /// file; exact semantics unconfirmed.
    Offer {
        /// Session this message belongs to.
        session_id: u128,
        /// Id of the offered command (presumably the sender's head; confirm).
        head: CmdId,
    },

    /// Terminates the session. `SyncResponder` emits this from its reset state.
    EndSession { session_id: u128 },
}
121
122impl SyncResponseMessage {
123 pub fn session_id(&self) -> u128 {
124 match self {
125 Self::SyncResponse { session_id, .. } => *session_id,
126 Self::SyncEnd { session_id, .. } => *session_id,
127 Self::Offer { session_id, .. } => *session_id,
128 Self::EndSession { session_id, .. } => *session_id,
129 }
130 }
131}
132
/// Lifecycle of a `SyncResponder`.
#[derive(Default, Debug)]
enum SyncResponderState {
    /// Freshly constructed; no request has been received yet.
    #[default]
    New,
    /// A `SyncRequest` was received; the next `poll` prepares the response.
    Start,
    /// Responses are being emitted, one per `poll`.
    Send,
    /// All segments for the current request were sent (a `SyncEnd` was
    /// produced); `poll` reports not-ready until another request arrives.
    Idle,
    /// An error was detected; the next `poll` emits `EndSession` and stops.
    Reset,
    /// The session is over, either because `EndSession` was emitted or
    /// because one was received.
    Stopped,
}
143
144#[derive(Default)]
145pub struct SyncResponder<A> {
146 session_id: Option<u128>,
147 storage_id: Option<GraphId>,
148 state: SyncResponderState,
149 bytes_sent: u64,
150 next_send: usize,
151 message_index: usize,
152 has: Vec<Address, COMMAND_SAMPLE_MAX>,
153 to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
154 server_address: A,
155}
156
157impl<A: Serialize + Clone> SyncResponder<A> {
158 pub fn new(server_address: A) -> Self {
160 Self {
161 session_id: None,
162 storage_id: None,
163 state: SyncResponderState::New,
164 bytes_sent: 0,
165 next_send: 0,
166 message_index: 0,
167 has: Vec::new(),
168 to_send: Vec::new(),
169 server_address,
170 }
171 }
172
173 pub fn ready(&self) -> bool {
175 use SyncResponderState::*;
176 match self.state {
177 Reset | Start | Send => true, New | Idle | Stopped => false,
179 }
180 }
181
182 pub fn poll(
185 &mut self,
186 target: &mut [u8],
187 provider: &mut impl StorageProvider,
188 response_cache: &mut PeerCache,
189 ) -> Result<usize, SyncError> {
190 use SyncResponderState as S;
192 let length = match self.state {
193 S::New | S::Idle | S::Stopped => {
194 return Err(SyncError::NotReady); }
196 S::Start => {
197 let Some(storage_id) = self.storage_id else {
198 self.state = S::Reset;
199 bug!("poll called before storage_id was set");
200 };
201
202 let storage = match provider.get_storage(storage_id) {
203 Ok(s) => s,
204 Err(e) => {
205 self.state = S::Reset;
206 return Err(e.into());
207 }
208 };
209
210 self.state = S::Send;
211 for command in &self.has {
212 if let Some(cmd_loc) = storage.get_location(*command)? {
214 response_cache.add_command(storage, *command, cmd_loc)?;
215 }
216 }
217 self.to_send = Self::find_needed_segments(&self.has, storage)?;
218
219 self.get_next(target, provider)?
220 }
221 S::Send => self.get_next(target, provider)?,
222 S::Reset => {
223 self.state = S::Stopped;
224 let message = SyncResponseMessage::EndSession {
225 session_id: self.session_id()?,
226 };
227 Self::write(target, message)?
228 }
229 };
230
231 Ok(length)
232 }
233
234 pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
236 if self.session_id.is_none() {
237 self.session_id = Some(message.session_id());
238 }
239 if self.session_id != Some(message.session_id()) {
240 return Err(SyncError::SessionMismatch);
241 }
242
243 match message {
244 SyncRequestMessage::SyncRequest {
245 storage_id,
246 max_bytes,
247 commands,
248 ..
249 } => {
250 self.state = SyncResponderState::Start;
251 self.storage_id = Some(storage_id);
252 self.bytes_sent = max_bytes;
253 self.to_send = Vec::new();
254 self.has = commands;
255 self.next_send = 0;
256 return Ok(());
257 }
258 SyncRequestMessage::RequestMissing { .. } => {
259 todo!()
260 }
261 SyncRequestMessage::SyncResume { .. } => {
262 todo!()
263 }
264 SyncRequestMessage::EndSession { .. } => {
265 self.state = SyncResponderState::Stopped;
266 }
267 }
268
269 Ok(())
270 }
271
272 fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
273 Ok(postcard::to_slice(&msg, target)?.len())
274 }
275
276 fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
277 Ok(postcard::to_slice(&msg, target)?.len())
278 }
279
/// Computes the locations to send so the peer can catch up with the local
/// graph, given the command addresses the peer reported having.
///
/// Addresses not found in local storage are ignored. The walk starts at the
/// local head and moves backwards through each segment's `prior` segments,
/// one generation at a time:
/// - A segment that already contains any queued location is skipped.
/// - A segment containing one of the peer's known locations ends that branch
///   of the walk (its priors are not visited). The known location itself is
///   queued unless it is the segment's head location.
/// - Any other segment is queued from its first location, and its priors are
///   visited next.
///
/// At most `SEGMENT_BUFFER_MAX` locations are kept. When the buffer is full,
/// the entry at the back (the one queued earliest, i.e. nearest the local
/// head) is evicted before queuing another. The result is sorted by
/// `Location`'s ordering.
///
/// NOTE(review): the queued "known" location is the peer's own command. If
/// `get_from` starts at that location inclusively, that command is re-sent.
/// Confirm whether this is intended.
fn find_needed_segments(
    commands: &[Address],
    storage: &impl Storage,
) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
    // Local locations of the commands the peer says it has.
    let mut have_locations = vec::Vec::new();
    for &addr in commands {
        let Some(location) = storage.get_location(addr)? else {
            continue;
        };

        have_locations.push(location);
    }

    // Frontier of segment locations to inspect in the current generation.
    let mut heads = vec::Vec::new();
    heads.push(storage.get_head()?);

    let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

    while !heads.is_empty() {
        let current = mem::take(&mut heads);
        'heads: for head in current {
            let segment = storage.get_segment(head)?;
            // Already reached through another branch.
            if segment.contains_any(&result) {
                continue 'heads;
            }

            for &location in &have_locations {
                if segment.contains(location) {
                    // If the peer has the segment's head, it has the whole
                    // segment; otherwise send from its known location on.
                    if location != segment.head_location() {
                        if result.is_full() {
                            result.pop_back();
                        }
                        result
                            .push_front(location)
                            .ok()
                            .assume("too many segments")?;
                    }
                    // The peer has everything before this point: stop here.
                    continue 'heads;
                }
            }
            heads.extend(segment.prior());

            // Make room by evicting the entry queued earliest.
            if result.is_full() {
                result.pop_back();
            }

            let location = segment.first_location();
            result
                .push_front(location)
                .ok()
                .assume("too many segments")?;
        }
    }
    let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
    for l in result {
        r.push(l).ok().assume("too many segments")?;
    }
    r.sort();
    Ok(r)
}
348
349 fn get_next(
350 &mut self,
351 target: &mut [u8],
352 provider: &mut impl StorageProvider,
353 ) -> Result<usize, SyncError> {
354 if self.next_send >= self.to_send.len() {
355 self.state = SyncResponderState::Idle;
356 let message = SyncResponseMessage::SyncEnd {
357 session_id: self.session_id()?,
358 max_index: self.message_index as u64,
359 remaining: false,
360 };
361 let length = Self::write(target, message)?;
362 return Ok(length);
363 }
364
365 let (commands, command_data, next_send) = self.get_commands(provider)?;
366
367 let message = SyncResponseMessage::SyncResponse {
368 session_id: self.session_id()?,
369 response_index: self.message_index as u64,
370 commands,
371 };
372 self.message_index = self
373 .message_index
374 .checked_add(1)
375 .assume("message_index overflow")?;
376 self.next_send = next_send;
377
378 let length = Self::write(target, message)?;
379 let total_length = length
380 .checked_add(command_data.len())
381 .assume("length + command_data_length mustn't overflow")?;
382 target
383 .get_mut(length..total_length)
384 .assume("sync message fits in target")?
385 .copy_from_slice(&command_data);
386 Ok(total_length)
387 }
388
389 pub fn push(
392 &mut self,
393 target: &mut [u8],
394 provider: &mut impl StorageProvider,
395 ) -> Result<usize, SyncError> {
396 use SyncResponderState as S;
397 let Some(storage_id) = self.storage_id else {
398 self.state = S::Reset;
399 bug!("poll called before storage_id was set");
400 };
401
402 let storage = match provider.get_storage(storage_id) {
403 Ok(s) => s,
404 Err(e) => {
405 self.state = S::Reset;
406 return Err(e.into());
407 }
408 };
409 self.to_send = Self::find_needed_segments(&self.has, storage)?;
410 let (commands, command_data, next_send) = self.get_commands(provider)?;
411 let mut length = 0;
412 if !commands.is_empty() {
413 let message = SyncType::Push {
414 message: SyncResponseMessage::SyncResponse {
415 session_id: self.session_id()?,
416 response_index: self.message_index as u64,
417 commands,
418 },
419 storage_id: self.storage_id.assume("storage id must exist")?,
420 address: self.server_address.clone(),
421 };
422 self.message_index = self
423 .message_index
424 .checked_add(1)
425 .assume("message_index increment overflow")?;
426 self.next_send = next_send;
427
428 length = Self::write_sync_type(target, message)?;
429 let total_length = length
430 .checked_add(command_data.len())
431 .assume("length + command_data_length mustn't overflow")?;
432 target
433 .get_mut(length..total_length)
434 .assume("sync message fits in target")?
435 .copy_from_slice(&command_data);
436 length = total_length;
437 }
438 Ok(length)
439 }
440
/// Collects the next batch of commands to send, starting at
/// `to_send[next_send]`.
///
/// For each queued location, the commands from that location onward in its
/// segment (`get_from`) are gathered, until `COMMAND_RESPONSE_MAX` metadata
/// entries have been collected.
///
/// Returns a tuple of:
/// - the `CommandMeta` for each command;
/// - the concatenated payload: for each command, its policy bytes (if any)
///   followed by its command bytes;
/// - the `to_send` index from which the next call should continue.
///
/// This does not modify `next_send`; callers commit the returned index.
///
/// Moves to `Reset` when the storage id is missing, the storage or a segment
/// cannot be loaded, or an index is out of bounds. Other errors (payload
/// overflow, `max_cut` failure) are returned without changing the state.
///
/// NOTE(review): `index` is advanced when a segment is *started*. If
/// `commands` fills up part-way through a segment, the returned index
/// already points past that segment, and its remaining commands are not
/// returned by later calls that start from it. Confirm this is intended.
///
/// NOTE(review): if the cumulative payload exceeds `MAX_SYNC_MESSAGE_SIZE`,
/// this fails with "command_data is too large" instead of stopping the batch
/// early.
fn get_commands(
    &mut self,
    provider: &mut impl StorageProvider,
) -> Result<
    (
        Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
        Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
        usize,
    ),
    SyncError,
> {
    let Some(storage_id) = self.storage_id.as_ref() else {
        self.state = SyncResponderState::Reset;
        bug!("get_next called before storage_id was set");
    };
    let storage = match provider.get_storage(*storage_id) {
        Ok(s) => s,
        Err(e) => {
            self.state = SyncResponderState::Reset;
            return Err(e.into());
        }
    };
    let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
    let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
    // One past the last `to_send` entry started in this batch.
    let mut index = self.next_send;
    for i in self.next_send..self.to_send.len() {
        if commands.is_full() {
            break;
        }
        index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
        let Some(&location) = self.to_send.get(i) else {
            self.state = SyncResponderState::Reset;
            bug!("send index OOB");
        };

        let segment = storage
            .get_segment(location)
            .inspect_err(|_| self.state = SyncResponderState::Reset)?;

        let found = segment.get_from(location);

        for command in &found {
            let mut policy_length = 0;

            // Payload layout per command: [policy bytes][command bytes].
            if let Some(policy) = command.policy() {
                policy_length = policy.len();
                command_data
                    .extend_from_slice(policy)
                    .ok()
                    .assume("command_data is too large")?;
            }

            let bytes = command.bytes();
            command_data
                .extend_from_slice(bytes)
                .ok()
                .assume("command_data is too large")?;

            // Both lengths already fit in `command_data`, so the `as u32`
            // casts cannot truncate as long as MAX_SYNC_MESSAGE_SIZE fits in
            // u32 (presumably true; confirm).
            let meta = CommandMeta {
                id: command.id(),
                priority: command.priority(),
                parent: command.parent(),
                policy_length: policy_length as u32,
                length: bytes.len() as u32,
                max_cut: command.max_cut()?,
            };

            commands
                .push(meta)
                .ok()
                .assume("too many commands in segment")?;
            if commands.is_full() {
                break;
            }
        }
    }
    Ok((commands, command_data, index))
}
520
521 fn session_id(&self) -> Result<u128, SyncError> {
522 Ok(self.session_id.assume("session id is set")?)
523 }
524}