1use alloc::vec;
2use core::mem;
3
4use buggy::{bug, BugExt};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 requester::SyncRequestMessage, CommandMeta, SyncError, COMMAND_RESPONSE_MAX,
10 COMMAND_SAMPLE_MAX, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX, SEGMENT_BUFFER_MAX,
11};
12use crate::{
13 command::{Address, Command, CommandId},
14 storage::{GraphId, Location, Segment, Storage, StorageProvider},
15 StorageError, SyncType,
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20 heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24 pub fn new() -> Self {
25 PeerCache { heads: Vec::new() }
26 }
27
28 pub fn heads(&self) -> &[Address] {
29 &self.heads
30 }
31
32 pub fn add_command<S>(
33 &mut self,
34 storage: &mut S,
35 command: Address,
36 cmd_loc: Location,
37 ) -> Result<(), StorageError>
38 where
39 S: Storage,
40 {
41 let mut add_command = true;
42 let mut retain_head = |request_head: &Address, new_head: Location| {
43 let new_head_seg = storage.get_segment(new_head)?;
44 let req_head_loc = storage
45 .get_location(*request_head)?
46 .assume("location must exist")?;
47 let req_head_seg = storage.get_segment(req_head_loc)?;
48 if let Some(new_head_command) = new_head_seg.get_command(new_head) {
49 if request_head.id == new_head_command.address()?.id {
50 add_command = false;
51 }
52 }
53 if storage.is_ancestor(new_head, &req_head_seg)? {
54 add_command = false;
55 }
56 Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
57 };
58 self.heads
59 .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
60 if add_command && !self.heads.is_full() {
61 self.heads
62 .push(command)
63 .ok()
64 .assume("command locations should not be full")?;
65 };
66 Ok(())
67 }
68}
69
70#[derive(Serialize, Deserialize, Debug)]
77#[allow(clippy::large_enum_variant)]
78pub enum SyncResponseMessage {
79 SyncResponse {
81 session_id: u128,
83 index: u64,
88 commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
90 },
91
92 SyncEnd {
95 session_id: u128,
98 max_index: u64,
100 remaining: bool,
103 },
104
105 Offer {
110 session_id: u128,
113 head: CommandId,
115 },
116
117 EndSession { session_id: u128 },
120}
121
122impl SyncResponseMessage {
123 pub fn session_id(&self) -> u128 {
124 match self {
125 Self::SyncResponse { session_id, .. } => *session_id,
126 Self::SyncEnd { session_id, .. } => *session_id,
127 Self::Offer { session_id, .. } => *session_id,
128 Self::EndSession { session_id, .. } => *session_id,
129 }
130 }
131}
132
/// Progress of a sync responder through a session.
#[derive(Debug, Default)]
enum SyncResponderState {
    /// Freshly created; no request has been received yet.
    #[default]
    New,
    /// A sync request was received; the next poll works out what to send.
    Start,
    /// Segments are queued and being sent batch by batch.
    Send,
    /// Everything queued has been sent.
    Idle,
    /// Something went wrong; the next poll emits an `EndSession` message.
    Reset,
    /// The session is over.
    Stopped,
}
148
149#[derive(Default)]
150pub struct SyncResponder<A> {
151 session_id: Option<u128>,
152 storage_id: Option<GraphId>,
153 state: SyncResponderState,
154 bytes_sent: u64,
155 next_send: usize,
156 has: Vec<Address, COMMAND_SAMPLE_MAX>,
157 to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
158 server_address: A,
159}
160
161impl<A: Serialize + Clone> SyncResponder<A> {
162 pub fn new(server_address: A) -> Self {
164 SyncResponder {
165 session_id: None,
166 storage_id: None,
167 state: SyncResponderState::New,
168 bytes_sent: 0,
169 next_send: 0,
170 has: Vec::new(),
171 to_send: Vec::new(),
172 server_address,
173 }
174 }
175
176 pub fn ready(&self) -> bool {
178 use SyncResponderState::*;
179 match self.state {
180 Reset | Start | Send => true,
181 New | Idle | Stopped => false,
182 }
183 }
184
185 pub fn poll(
188 &mut self,
189 target: &mut [u8],
190 provider: &mut impl StorageProvider,
191 response_cache: &mut PeerCache,
192 ) -> Result<usize, SyncError> {
193 use SyncResponderState as S;
194 let length = match self.state {
195 S::New | S::Idle | S::Stopped => {
196 return Err(SyncError::NotReady);
197 }
198 S::Start => {
199 let Some(storage_id) = self.storage_id else {
200 self.state = S::Reset;
201 bug!("poll called before storage_id was set");
202 };
203
204 let storage = match provider.get_storage(storage_id) {
205 Ok(s) => s,
206 Err(e) => {
207 self.state = S::Reset;
208 return Err(e.into());
209 }
210 };
211
212 self.state = S::Send;
213 for command in &self.has {
214 if let Some(cmd_loc) = storage.get_location(*command)? {
216 response_cache.add_command(storage, *command, cmd_loc)?;
217 }
218 }
219 self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;
220
221 self.get_next(target, provider)?
222 }
223 S::Send => self.get_next(target, provider)?,
224 S::Reset => {
225 self.state = S::Stopped;
226 let message = SyncResponseMessage::EndSession {
227 session_id: self.session_id()?,
228 };
229 Self::write(target, message)?
230 }
231 };
232
233 Ok(length)
234 }
235
236 pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
238 if self.session_id.is_none() {
239 self.session_id = Some(message.session_id());
240 }
241 if self.session_id != Some(message.session_id()) {
242 return Err(SyncError::SessionMismatch);
243 }
244
245 match message {
246 SyncRequestMessage::SyncRequest {
247 storage_id,
248 max_bytes,
249 commands,
250 ..
251 } => {
252 self.state = SyncResponderState::Start;
253 self.storage_id = Some(storage_id);
254 self.bytes_sent = max_bytes;
255 self.to_send = Vec::new();
256 self.has = commands;
257 self.next_send = 0;
258 return Ok(());
259 }
260 SyncRequestMessage::RequestMissing { .. } => {
261 todo!()
262 }
263 SyncRequestMessage::SyncResume { .. } => {
264 todo!()
265 }
266 SyncRequestMessage::EndSession { .. } => {
267 self.state = SyncResponderState::Stopped;
268 }
269 };
270
271 Ok(())
272 }
273
274 fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
275 Ok(postcard::to_slice(&msg, target)?.len())
276 }
277
278 fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
279 Ok(postcard::to_slice(&msg, target)?.len())
280 }
281
282 fn find_needed_segments(
283 commands: &[Address],
284 storage: &impl Storage,
285 ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
286 let mut have_locations = vec::Vec::new(); for &addr in commands {
288 let Some(location) = storage.get_location(addr)? else {
289 continue;
293 };
294
295 have_locations.push(location);
296 }
297
298 let mut heads = vec::Vec::new();
299 heads.push(storage.get_head()?);
300
301 let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
302
303 while !heads.is_empty() {
304 let current = mem::take(&mut heads);
305 'heads: for head in current {
306 let segment = storage.get_segment(head)?;
307 if segment.contains_any(&result) {
308 continue 'heads;
309 }
310
311 for &location in &have_locations {
312 if segment.contains(location) {
313 if location != segment.head_location() {
314 if result.is_full() {
315 result.pop_back();
316 }
317 result
318 .push_front(location)
319 .ok()
320 .assume("too many segments")?;
321 }
322 continue 'heads;
323 }
324 }
325 heads.extend(segment.prior());
326
327 if result.is_full() {
328 result.pop_back();
329 }
330
331 let location = segment.first_location();
332 result
333 .push_front(location)
334 .ok()
335 .assume("too many segments")?;
336 }
337 }
338 let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
339 for l in result {
340 r.push(l).ok().assume("too many segments")?;
341 }
342 r.sort();
345 Ok(r)
346 }
347
348 fn get_next(
349 &mut self,
350 target: &mut [u8],
351 provider: &mut impl StorageProvider,
352 ) -> Result<usize, SyncError> {
353 if self.next_send >= self.to_send.len() {
354 self.state = SyncResponderState::Idle;
355 return Ok(0);
356 }
357 let (commands, command_data, index) = self.get_commands(provider)?;
358
359 let message = SyncResponseMessage::SyncResponse {
360 session_id: self.session_id()?,
361 index: self.next_send as u64,
362 commands,
363 };
364
365 self.next_send = index;
366
367 let length = Self::write(target, message)?;
368 let total_length = length
369 .checked_add(command_data.len())
370 .assume("length + command_data_length mustn't overflow")?;
371 target
372 .get_mut(length..total_length)
373 .assume("sync message fits in target")?
374 .copy_from_slice(&command_data);
375 Ok(total_length)
376 }
377
378 pub fn push(
381 &mut self,
382 target: &mut [u8],
383 provider: &mut impl StorageProvider,
384 ) -> Result<usize, SyncError> {
385 use SyncResponderState as S;
386 let Some(storage_id) = self.storage_id else {
387 self.state = S::Reset;
388 bug!("poll called before storage_id was set");
389 };
390
391 let storage = match provider.get_storage(storage_id) {
392 Ok(s) => s,
393 Err(e) => {
394 self.state = S::Reset;
395 return Err(e.into());
396 }
397 };
398 self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;
399 let (commands, command_data, index) = self.get_commands(provider)?;
400 let mut length = 0;
401 if !commands.is_empty() {
402 let message = SyncType::Push {
403 message: SyncResponseMessage::SyncResponse {
404 session_id: self.session_id()?,
405 index: self.next_send as u64,
406 commands,
407 },
408 storage_id: self.storage_id.assume("storage id must exist")?,
409 address: self.server_address.clone(),
410 };
411 self.next_send = index;
412
413 length = Self::write_sync_type(target, message)?;
414 let total_length = length
415 .checked_add(command_data.len())
416 .assume("length + command_data_length mustn't overflow")?;
417 target
418 .get_mut(length..total_length)
419 .assume("sync message fits in target")?
420 .copy_from_slice(&command_data);
421 length = total_length;
422 }
423 Ok(length)
424 }
425
426 fn get_commands(
427 &mut self,
428 provider: &mut impl StorageProvider,
429 ) -> Result<
430 (
431 Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
432 Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
433 usize,
434 ),
435 SyncError,
436 > {
437 let Some(storage_id) = self.storage_id.as_ref() else {
438 self.state = SyncResponderState::Reset;
439 bug!("get_next called before storage_id was set");
440 };
441 let storage = match provider.get_storage(*storage_id) {
442 Ok(s) => s,
443 Err(e) => {
444 self.state = SyncResponderState::Reset;
445 return Err(e.into());
446 }
447 };
448 let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
449 let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
450 let mut index = self.next_send;
451 for i in self.next_send..self.to_send.len() {
452 if commands.is_full() {
453 break;
454 }
455 index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
456 let Some(&location) = self.to_send.get(i) else {
457 self.state = SyncResponderState::Reset;
458 bug!("send index OOB");
459 };
460
461 let segment = storage
462 .get_segment(location)
463 .inspect_err(|_| self.state = SyncResponderState::Reset)?;
464
465 let found = segment.get_from(location);
466
467 for command in &found {
468 let mut policy_length = 0;
469
470 if let Some(policy) = command.policy() {
471 policy_length = policy.len();
472 command_data
473 .extend_from_slice(policy)
474 .ok()
475 .assume("command_data is too large")?;
476 }
477
478 let bytes = command.bytes();
479 command_data
480 .extend_from_slice(bytes)
481 .ok()
482 .assume("command_data is too large")?;
483
484 let meta = CommandMeta {
485 id: command.id(),
486 priority: command.priority(),
487 parent: command.parent(),
488 policy_length: policy_length as u32,
489 length: bytes.len() as u32,
490 max_cut: command.max_cut()?,
491 };
492
493 commands
495 .push(meta)
496 .ok()
497 .assume("too many commands in segment")?;
498 if commands.is_full() {
499 break;
500 }
501 }
502 }
503 Ok((commands, command_data, index))
504 }
505
506 fn session_id(&self) -> Result<u128, SyncError> {
507 Ok(self.session_id.assume("session id is set")?)
508 }
509}