1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10 SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13 StorageError, SyncType,
14 command::{Address, CmdId, Command},
15 storage::{GraphId, Location, Segment, Storage, StorageProvider},
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20 heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24 pub const fn new() -> Self {
25 PeerCache { heads: Vec::new() }
26 }
27
28 pub fn heads(&self) -> &[Address] {
29 &self.heads
30 }
31
32 pub fn add_command<S>(
33 &mut self,
34 storage: &mut S,
35 command: Address,
36 cmd_loc: Location,
37 ) -> Result<(), StorageError>
38 where
39 S: Storage,
40 {
41 let mut add_command = true;
42 let mut retain_head = |request_head: &Address, new_head: Location| {
43 let new_head_seg = storage.get_segment(new_head)?;
44 let req_head_loc = storage
45 .get_location(*request_head)?
46 .assume("location must exist")?;
47 let req_head_seg = storage.get_segment(req_head_loc)?;
48 if let Some(new_head_command) = new_head_seg.get_command(new_head) {
49 if request_head.id == new_head_command.address()?.id {
50 add_command = false;
51 }
52 }
53 if storage.is_ancestor(new_head, &req_head_seg)? {
54 add_command = false;
55 }
56 Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
57 };
58 self.heads
59 .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
60 if add_command && !self.heads.is_full() {
61 self.heads
62 .push(command)
63 .ok()
64 .assume("command locations should not be full")?;
65 };
66 Ok(())
67 }
68}
69
70#[derive(Serialize, Deserialize, Debug)]
77#[allow(clippy::large_enum_variant)]
78pub enum SyncResponseMessage {
79 SyncResponse {
81 session_id: u128,
83 response_index: u64,
88 commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
90 },
91
92 SyncEnd {
95 session_id: u128,
98 max_index: u64,
100 remaining: bool,
103 },
104
105 Offer {
110 session_id: u128,
113 head: CmdId,
115 },
116
117 EndSession { session_id: u128 },
120}
121
122impl SyncResponseMessage {
123 pub fn session_id(&self) -> u128 {
124 match self {
125 Self::SyncResponse { session_id, .. } => *session_id,
126 Self::SyncEnd { session_id, .. } => *session_id,
127 Self::Offer { session_id, .. } => *session_id,
128 Self::EndSession { session_id, .. } => *session_id,
129 }
130 }
131}
132
/// Protocol state of a [`SyncResponder`].
///
/// `Default` is derived, with `New` marked as the default variant. This
/// replaces a hand-written impl that did the same thing.
#[derive(Debug, Default)]
enum SyncResponderState {
    /// No `SyncRequest` has been received yet (the default).
    #[default]
    New,
    /// A `SyncRequest` was received; the next `poll` decides what to send.
    Start,
    /// Responses are being sent.
    Send,
    /// All responses, including `SyncEnd`, have been produced.
    Idle,
    /// The session must be torn down; the next `poll` sends `EndSession`.
    Reset,
    /// The session has ended.
    Stopped,
}
148
149#[derive(Default)]
150pub struct SyncResponder<A> {
151 session_id: Option<u128>,
152 storage_id: Option<GraphId>,
153 state: SyncResponderState,
154 bytes_sent: u64,
155 next_send: usize,
156 message_index: usize,
157 has: Vec<Address, COMMAND_SAMPLE_MAX>,
158 to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
159 server_address: A,
160}
161
162impl<A: Serialize + Clone> SyncResponder<A> {
163 pub fn new(server_address: A) -> Self {
165 SyncResponder {
166 session_id: None,
167 storage_id: None,
168 state: SyncResponderState::New,
169 bytes_sent: 0,
170 next_send: 0,
171 message_index: 0,
172 has: Vec::new(),
173 to_send: Vec::new(),
174 server_address,
175 }
176 }
177
178 pub fn ready(&self) -> bool {
180 use SyncResponderState::*;
181 match self.state {
182 Reset | Start | Send => true, New | Idle | Stopped => false,
184 }
185 }
186
187 pub fn poll(
190 &mut self,
191 target: &mut [u8],
192 provider: &mut impl StorageProvider,
193 response_cache: &mut PeerCache,
194 ) -> Result<usize, SyncError> {
195 use SyncResponderState as S;
197 let length = match self.state {
198 S::New | S::Idle | S::Stopped => {
199 return Err(SyncError::NotReady); }
201 S::Start => {
202 let Some(storage_id) = self.storage_id else {
203 self.state = S::Reset;
204 bug!("poll called before storage_id was set");
205 };
206
207 let storage = match provider.get_storage(storage_id) {
208 Ok(s) => s,
209 Err(e) => {
210 self.state = S::Reset;
211 return Err(e.into());
212 }
213 };
214
215 self.state = S::Send;
216 for command in &self.has {
217 if let Some(cmd_loc) = storage.get_location(*command)? {
219 response_cache.add_command(storage, *command, cmd_loc)?;
220 }
221 }
222 self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;
223
224 self.get_next(target, provider)?
225 }
226 S::Send => self.get_next(target, provider)?,
227 S::Reset => {
228 self.state = S::Stopped;
229 let message = SyncResponseMessage::EndSession {
230 session_id: self.session_id()?,
231 };
232 Self::write(target, message)?
233 }
234 };
235
236 Ok(length)
237 }
238
239 pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
241 if self.session_id.is_none() {
242 self.session_id = Some(message.session_id());
243 }
244 if self.session_id != Some(message.session_id()) {
245 return Err(SyncError::SessionMismatch);
246 }
247
248 match message {
249 SyncRequestMessage::SyncRequest {
250 storage_id,
251 max_bytes,
252 commands,
253 ..
254 } => {
255 self.state = SyncResponderState::Start;
256 self.storage_id = Some(storage_id);
257 self.bytes_sent = max_bytes;
258 self.to_send = Vec::new();
259 self.has = commands;
260 self.next_send = 0;
261 return Ok(());
262 }
263 SyncRequestMessage::RequestMissing { .. } => {
264 todo!()
265 }
266 SyncRequestMessage::SyncResume { .. } => {
267 todo!()
268 }
269 SyncRequestMessage::EndSession { .. } => {
270 self.state = SyncResponderState::Stopped;
271 }
272 };
273
274 Ok(())
275 }
276
277 fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
278 Ok(postcard::to_slice(&msg, target)?.len())
279 }
280
281 fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
282 Ok(postcard::to_slice(&msg, target)?.len())
283 }
284
285 fn find_needed_segments(
289 commands: &[Address],
290 storage: &impl Storage,
291 ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
292 let mut have_locations = vec::Vec::new(); for &addr in commands {
294 let Some(location) = storage.get_location(addr)? else {
295 continue;
299 };
300
301 have_locations.push(location);
302 }
303
304 let mut heads = vec::Vec::new();
305 heads.push(storage.get_head()?);
306
307 let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
308
309 while !heads.is_empty() {
310 let current = mem::take(&mut heads);
311 'heads: for head in current {
312 let segment = storage.get_segment(head)?;
313 if segment.contains_any(&result) {
314 continue 'heads;
315 }
316
317 for &location in &have_locations {
318 if segment.contains(location) {
319 if location != segment.head_location() {
320 if result.is_full() {
321 result.pop_back();
322 }
323 result
324 .push_front(location)
325 .ok()
326 .assume("too many segments")?;
327 }
328 continue 'heads;
329 }
330 }
331 heads.extend(segment.prior());
332
333 if result.is_full() {
334 result.pop_back();
335 }
336
337 let location = segment.first_location();
338 result
339 .push_front(location)
340 .ok()
341 .assume("too many segments")?;
342 }
343 }
344 let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
345 for l in result {
346 r.push(l).ok().assume("too many segments")?;
347 }
348 r.sort();
351 Ok(r)
352 }
353
354 fn get_next(
355 &mut self,
356 target: &mut [u8],
357 provider: &mut impl StorageProvider,
358 ) -> Result<usize, SyncError> {
359 if self.next_send >= self.to_send.len() {
360 self.state = SyncResponderState::Idle;
361 let message = SyncResponseMessage::SyncEnd {
362 session_id: self.session_id()?,
363 max_index: self.message_index as u64,
364 remaining: false,
365 };
366 let length = Self::write(target, message)?;
367 return Ok(length);
368 }
369
370 let (commands, command_data, next_send) = self.get_commands(provider)?;
371
372 let message = SyncResponseMessage::SyncResponse {
373 session_id: self.session_id()?,
374 response_index: self.message_index as u64,
375 commands,
376 };
377 self.message_index = self
378 .message_index
379 .checked_add(1)
380 .assume("message_index overflow")?;
381 self.next_send = next_send;
382
383 let length = Self::write(target, message)?;
384 let total_length = length
385 .checked_add(command_data.len())
386 .assume("length + command_data_length mustn't overflow")?;
387 target
388 .get_mut(length..total_length)
389 .assume("sync message fits in target")?
390 .copy_from_slice(&command_data);
391 Ok(total_length)
392 }
393
394 pub fn push(
397 &mut self,
398 target: &mut [u8],
399 provider: &mut impl StorageProvider,
400 ) -> Result<usize, SyncError> {
401 use SyncResponderState as S;
402 let Some(storage_id) = self.storage_id else {
403 self.state = S::Reset;
404 bug!("poll called before storage_id was set");
405 };
406
407 let storage = match provider.get_storage(storage_id) {
408 Ok(s) => s,
409 Err(e) => {
410 self.state = S::Reset;
411 return Err(e.into());
412 }
413 };
414 self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;
415 let (commands, command_data, next_send) = self.get_commands(provider)?;
416 let mut length = 0;
417 if !commands.is_empty() {
418 let message = SyncType::Push {
419 message: SyncResponseMessage::SyncResponse {
420 session_id: self.session_id()?,
421 response_index: self.message_index as u64,
422 commands,
423 },
424 storage_id: self.storage_id.assume("storage id must exist")?,
425 address: self.server_address.clone(),
426 };
427 self.message_index = self
428 .message_index
429 .checked_add(1)
430 .assume("message_index increment overflow")?;
431 self.next_send = next_send;
432
433 length = Self::write_sync_type(target, message)?;
434 let total_length = length
435 .checked_add(command_data.len())
436 .assume("length + command_data_length mustn't overflow")?;
437 target
438 .get_mut(length..total_length)
439 .assume("sync message fits in target")?
440 .copy_from_slice(&command_data);
441 length = total_length;
442 }
443 Ok(length)
444 }
445
446 fn get_commands(
447 &mut self,
448 provider: &mut impl StorageProvider,
449 ) -> Result<
450 (
451 Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
452 Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
453 usize,
454 ),
455 SyncError,
456 > {
457 let Some(storage_id) = self.storage_id.as_ref() else {
458 self.state = SyncResponderState::Reset;
459 bug!("get_next called before storage_id was set");
460 };
461 let storage = match provider.get_storage(*storage_id) {
462 Ok(s) => s,
463 Err(e) => {
464 self.state = SyncResponderState::Reset;
465 return Err(e.into());
466 }
467 };
468 let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
469 let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
470 let mut index = self.next_send;
471 for i in self.next_send..self.to_send.len() {
472 if commands.is_full() {
473 break;
474 }
475 index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
476 let Some(&location) = self.to_send.get(i) else {
477 self.state = SyncResponderState::Reset;
478 bug!("send index OOB");
479 };
480
481 let segment = storage
482 .get_segment(location)
483 .inspect_err(|_| self.state = SyncResponderState::Reset)?;
484
485 let found = segment.get_from(location);
486
487 for command in &found {
488 let mut policy_length = 0;
489
490 if let Some(policy) = command.policy() {
491 policy_length = policy.len();
492 command_data
493 .extend_from_slice(policy)
494 .ok()
495 .assume("command_data is too large")?;
496 }
497
498 let bytes = command.bytes();
499 command_data
500 .extend_from_slice(bytes)
501 .ok()
502 .assume("command_data is too large")?;
503
504 let meta = CommandMeta {
505 id: command.id(),
506 priority: command.priority(),
507 parent: command.parent(),
508 policy_length: policy_length as u32,
509 length: bytes.len() as u32,
510 max_cut: command.max_cut()?,
511 };
512
513 commands
515 .push(meta)
516 .ok()
517 .assume("too many commands in segment")?;
518 if commands.is_full() {
519 break;
520 }
521 }
522 }
523 Ok((commands, command_data, index))
524 }
525
526 fn session_id(&self) -> Result<u128, SyncError> {
527 Ok(self.session_id.assume("session id is set")?)
528 }
529}