1use alloc::vec;
2use core::mem;
3
4use aranya_buggy::{bug, BugExt};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9 requester::SyncRequestMessage, CommandMeta, SyncError, COMMAND_RESPONSE_MAX,
10 COMMAND_SAMPLE_MAX, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX, SEGMENT_BUFFER_MAX,
11};
12use crate::{
13 command::{Address, Command, CommandId},
14 storage::{GraphId, Location, Segment, Storage, StorageProvider},
15 StorageError,
16};
17
/// Per-peer record of which graph heads the peer is known to hold, used to
/// avoid re-sending commands the peer already has. Bounded at
/// `PEER_HEAD_MAX` entries.
#[derive(Default, Debug)]
pub struct PeerCache {
    // Addresses of commands believed to be current heads on the peer's side.
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}
22
impl PeerCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        PeerCache { heads: Vec::new() }
    }

    /// Returns the cached head addresses for this peer.
    pub fn heads(&self) -> &[Address] {
        &self.heads
    }

    /// Records that the peer now has `command` (stored at `cmd_loc`),
    /// updating the cached head set.
    ///
    /// Existing cached heads that are ancestors of the new command's segment
    /// are evicted (the new command supersedes them). The new command is then
    /// added unless it matches an existing head's ID, is an ancestor of an
    /// existing head, or the cache is full.
    pub fn add_command<S>(
        &mut self,
        storage: &mut S,
        command: Address,
        cmd_loc: Location,
    ) -> Result<(), StorageError>
    where
        S: Storage,
    {
        // Whether `command` still needs to be inserted after the retain pass.
        let mut add_command = true;
        // Decides whether an existing cached head (`request_head`) survives
        // once the peer has the command at `new_head`; also clears
        // `add_command` when the new command is already implied by a head.
        let mut retain_head = |request_head: &Address, new_head: Location| {
            let new_head_seg = storage.get_segment(new_head)?;
            let req_head_loc = storage
                .get_location(*request_head)?
                .assume("location must exist")?;
            let req_head_seg = storage.get_segment(req_head_loc)?;
            // The new command is already one of the cached heads.
            if let Some(new_head_command) = new_head_seg.get_command(new_head) {
                if request_head.id == new_head_command.address()?.id {
                    add_command = false;
                }
            }
            // The new command is an ancestor of this cached head, so the
            // head already implies the peer has it.
            if storage.is_ancestor(new_head, &req_head_seg)? {
                add_command = false;
            }
            // Keep this head only if it is NOT an ancestor of the new
            // command (i.e. not superseded by it).
            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
        };
        // NOTE(review): a storage error inside the closure evicts the head
        // (`unwrap_or(false)`) — confirm drop-on-error is the intended policy.
        self.heads
            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
        if add_command && !self.heads.is_full() {
            self.heads
                .push(command)
                .ok()
                .assume("command locations should not be full")?;
        };
        Ok(())
    }
}
69
/// Messages sent by the responder half of the sync protocol
/// (postcard-encoded on the wire).
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// A batch of commands for the requester. The serialized message carries
    /// only metadata; the raw command/policy bytes are appended to the
    /// transmit buffer directly after the encoded message.
    SyncResponse {
        // Session this response belongs to.
        session_id: u128,
        // Responder's send-queue position when this batch was produced.
        index: u64,
        // Metadata describing each command in the appended payload.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// End of the response stream.
    /// NOTE(review): not produced in this file — field semantics inferred
    /// from names; confirm against the requester side.
    SyncEnd {
        session_id: u128,
        // Presumably the highest batch index that was sent.
        max_index: u64,
        // Presumably true if more data remains beyond what was sent.
        remaining: bool,
    },

    /// Offer of a head to the peer.
    /// NOTE(review): not produced in this file — confirm usage.
    Offer {
        session_id: u128,
        // ID of the offered head command.
        head: CommandId,
    },

    /// Terminates the session; sent from the `Reset` state.
    EndSession { session_id: u128 },
}
121
122impl SyncResponseMessage {
123 pub fn session_id(&self) -> u128 {
124 match self {
125 Self::SyncResponse { session_id, .. } => *session_id,
126 Self::SyncEnd { session_id, .. } => *session_id,
127 Self::Offer { session_id, .. } => *session_id,
128 Self::EndSession { session_id, .. } => *session_id,
129 }
130 }
131}
132
/// States of the responder's protocol state machine.
#[derive(Debug)]
enum SyncResponderState {
    // No request received yet; `poll` returns `NotReady`.
    New,
    // A `SyncRequest` arrived; the next `poll` computes what to send.
    Start,
    // Mid-stream: `poll` keeps emitting `SyncResponse` batches.
    Send,
    // Send queue drained; `poll` returns `NotReady` until a new request.
    Idle,
    // An error occurred; the next `poll` emits `EndSession` then stops.
    Reset,
    // Session is over; the responder is inert.
    Stopped,
}
142
143impl Default for SyncResponderState {
144 fn default() -> Self {
145 Self::New
146 }
147}
148
/// Server side of the sync protocol: answers a requester's `SyncRequest` by
/// streaming the commands the requester is missing, one batch per `poll`.
#[derive(Default)]
pub struct SyncResponder {
    // Session ID pinned by the first message received; all later messages
    // must carry the same ID.
    session_id: Option<u128>,
    // Graph to sync, taken from the `SyncRequest`.
    storage_id: Option<GraphId>,
    // Current protocol state (starts at `New`).
    state: SyncResponderState,
    // NOTE(review): set to the requester's `max_bytes` in `receive` and
    // never read in this file — looks like an unenforced byte budget;
    // confirm intent.
    bytes_sent: u64,
    // Index into `to_send` of the next segment location to serialize.
    next_send: usize,
    // Command addresses the requester sampled as already having.
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    // Segment locations (sorted ascending) still to be sent.
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
}
159
impl SyncResponder {
    /// Creates an idle responder; it is not `ready()` until `receive`
    /// handles a `SyncRequest`.
    pub fn new() -> Self {
        SyncResponder {
            session_id: None,
            storage_id: None,
            state: SyncResponderState::New,
            bytes_sent: 0,
            next_send: 0,
            has: Vec::new(),
            to_send: Vec::new(),
        }
    }

    /// Returns true when `poll` can produce a message: a response stream to
    /// start or continue (`Start`/`Send`), or a teardown to emit (`Reset`).
    pub fn ready(&self) -> bool {
        use SyncResponderState::*;
        match self.state {
            Reset | Start | Send => true,
            New | Idle | Stopped => false,
        }
    }

    /// Writes the next sync message into `target` and returns the number of
    /// bytes written.
    ///
    /// State machine:
    /// - `New`/`Idle`/`Stopped`: nothing to do; returns `SyncError::NotReady`.
    /// - `Start`: seeds `response_cache` with the requester's sampled
    ///   commands, computes the segments to send, then emits the first batch.
    /// - `Send`: emits the next batch via `get_next`.
    /// - `Reset`: writes `EndSession` and transitions to `Stopped`.
    ///
    /// # Errors
    /// Storage and serialization failures are propagated; the storage-lookup
    /// error paths first move the state to `Reset` so the session is torn
    /// down on the next `poll`.
    pub fn poll(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
        response_cache: &mut PeerCache,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let length = match self.state {
            S::New | S::Idle | S::Stopped => {
                return Err(SyncError::NotReady);
            }
            S::Start => {
                let Some(storage_id) = self.storage_id else {
                    self.state = S::Reset;
                    bug!("poll called before storage_id was set");
                };

                let storage = match provider.get_storage(storage_id) {
                    Ok(s) => s,
                    Err(e) => {
                        self.state = S::Reset;
                        return Err(e.into());
                    }
                };

                self.state = S::Send;
                // Every sampled command is one the requester already has;
                // record them so redundant data isn't sent later. Samples
                // we cannot locate in our own storage are skipped.
                for command in &self.has {
                    if let Some(cmd_loc) = storage.get_location(*command)? {
                        response_cache.add_command(storage, *command, cmd_loc)?;
                    }
                }
                self.to_send = SyncResponder::find_needed_segments(&self.has, storage)?;

                // Fall straight through to sending the first batch.
                self.get_next(target, provider)?
            }
            S::Send => self.get_next(target, provider)?,
            S::Reset => {
                // Tear the session down: tell the peer, then go inert.
                self.state = S::Stopped;
                let message = SyncResponseMessage::EndSession {
                    session_id: self.session_id()?,
                };
                Self::write(target, message)?
            }
        };

        Ok(length)
    }

    /// Handles one postcard-serialized `SyncRequestMessage` from the peer.
    ///
    /// The first message pins the session ID; later messages with a
    /// different ID are rejected with `SyncError::SessionMismatch`.
    ///
    /// # Errors
    /// Deserialization failures and session-ID mismatches.
    ///
    /// # Panics
    /// `RequestMissing` and `SyncResume` are not yet implemented (`todo!`).
    pub fn receive(&mut self, data: &[u8]) -> Result<(), SyncError> {
        let message: SyncRequestMessage = postcard::from_bytes(data)?;
        if self.session_id.is_none() {
            self.session_id = Some(message.session_id());
        }
        if self.session_id != Some(message.session_id()) {
            return Err(SyncError::SessionMismatch);
        }

        match message {
            SyncRequestMessage::SyncRequest {
                storage_id,
                max_bytes,
                commands,
                ..
            } => {
                // A fresh request resets the entire send plan.
                self.state = SyncResponderState::Start;
                self.storage_id = Some(storage_id);
                // NOTE(review): the requester's `max_bytes` budget is stored
                // in `bytes_sent` and never read in this file — confirm
                // whether the budget is meant to be enforced in `get_next`.
                self.bytes_sent = max_bytes;
                self.to_send = Vec::new();
                self.has = commands;
                self.next_send = 0;
                return Ok(());
            }
            SyncRequestMessage::RequestMissing { .. } => {
                todo!()
            }
            SyncRequestMessage::SyncResume { .. } => {
                todo!()
            }
            SyncRequestMessage::EndSession { .. } => {
                self.state = SyncResponderState::Stopped;
            }
        };

        Ok(())
    }

    /// Postcard-serializes `msg` into the front of `target`, returning the
    /// number of bytes written.
    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    /// Computes the segment locations the requester is missing, walking the
    /// graph backwards from the storage head and stopping at anything
    /// covered by `commands` (addresses the requester claims to have).
    ///
    /// Returns at most `SEGMENT_BUFFER_MAX` locations, sorted ascending.
    ///
    /// NOTE(review): when the deque is full, `pop_back` silently evicts an
    /// entry to make room, so some needed segments can be omitted — confirm
    /// the requester recovers anything dropped on a later sync round.
    fn find_needed_segments(
        commands: &[Address],
        storage: &impl Storage,
    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
        // Resolve the requester's sampled addresses to local locations;
        // addresses we don't have are ignored.
        let mut have_locations = vec::Vec::new();
        for &addr in commands {
            let Some(location) = storage.get_location(addr)? else {
                continue;
            };

            have_locations.push(location);
        }

        // Frontier of the backwards walk, seeded with our current head.
        let mut heads = vec::Vec::new();
        heads.push(storage.get_head()?);

        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

        // Breadth-first walk over segment ancestry, one frontier generation
        // per outer iteration.
        while !heads.is_empty() {
            let current = mem::take(&mut heads);
            'heads: for head in current {
                let segment = storage.get_segment(head)?;
                // Already scheduled (or partially scheduled) — skip.
                if segment.contains_any(&result) {
                    continue 'heads;
                }

                // If the requester has a command inside this segment, send
                // only the part after it (from `location` onward) and do not
                // walk past this segment.
                for &location in &have_locations {
                    if segment.contains(location) {
                        if location != segment.head_location() {
                            if result.is_full() {
                                result.pop_back();
                            }
                            result
                                .push_front(location)
                                .ok()
                                .assume("too many segments")?;
                        }
                        continue 'heads;
                    }
                }
                // Requester has nothing in this segment: send the whole
                // segment and keep walking through its parents.
                heads.extend(segment.prior());

                if result.is_full() {
                    result.pop_back();
                }

                let location = segment.first_location();
                result
                    .push_front(location)
                    .ok()
                    .assume("too many segments")?;
            }
        }
        // Flatten the deque into a heapless Vec and sort so segments are
        // sent in ascending (ancestor-first) location order.
        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
        for l in result {
            r.push(l).ok().assume("too many segments")?;
        }
        r.sort();
        Ok(r)
    }

    /// Serializes the next `SyncResponse` batch into `target`.
    ///
    /// Wire layout: the postcard-encoded `SyncResponse` header (command
    /// metadata) followed immediately by the raw policy+command bytes, in
    /// the same order as the metadata entries.
    ///
    /// Returns 0 and transitions to `Idle` once `to_send` is exhausted.
    ///
    /// NOTE(review): `index` advances one whole segment location per
    /// iteration; if `commands` fills up partway through a segment, the
    /// remainder of that segment is skipped on the next call — confirm
    /// `COMMAND_RESPONSE_MAX` always covers a full segment's commands.
    fn get_next(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        let Some(storage_id) = self.storage_id else {
            self.state = SyncResponderState::Reset;
            bug!("get_next called before storage_id was set");
        };

        let storage = match provider.get_storage(storage_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = SyncResponderState::Reset;
                return Err(e.into());
            }
        };

        // Nothing left to send: go idle and report an empty write.
        if self.next_send >= self.to_send.len() {
            self.state = SyncResponderState::Idle;
            return Ok(0);
        }

        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
        // `index` tracks how far through `to_send` this batch consumed;
        // it becomes the new `next_send` below.
        let mut index = self.next_send;

        for i in self.next_send..self.to_send.len() {
            if commands.is_full() {
                break;
            }
            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
            let Some(&location) = self.to_send.get(i) else {
                self.state = SyncResponderState::Reset;
                bug!("send index OOB");
            };

            let segment = storage
                .get_segment(location)
                .inspect_err(|_| self.state = SyncResponderState::Reset)?;

            // All commands in the segment from `location` to its head.
            let found = segment.get_from(location);

            for command in &found {
                let mut policy_length = 0;

                // Policy bytes (if any) precede the command bytes in the
                // payload; `policy_length` lets the receiver split them.
                if let Some(policy) = command.policy() {
                    policy_length = policy.len();
                    command_data
                        .extend_from_slice(policy)
                        .ok()
                        .assume("command_data is too large")?;
                }

                let bytes = command.bytes();
                command_data
                    .extend_from_slice(bytes)
                    .ok()
                    .assume("command_data is too large")?;

                let meta = CommandMeta {
                    id: command.id(),
                    priority: command.priority(),
                    parent: command.parent(),
                    policy_length: policy_length as u32,
                    length: bytes.len() as u32,
                    max_cut: command.max_cut()?,
                };

                commands
                    .push(meta)
                    .ok()
                    .assume("too many commands in segment")?;
                if commands.is_full() {
                    break;
                }
            }
        }

        let message = SyncResponseMessage::SyncResponse {
            session_id: self.session_id()?,
            // The batch is labeled with the queue position it started at.
            index: self.next_send as u64,
            commands,
        };

        self.next_send = index;

        // Header first, then the raw payload appended right after it.
        let mut length = Self::write(target, message)?;
        let total_length = length
            .checked_add(command_data.len())
            .assume("length + command_data_length mustn't overflow")?;
        target
            .get_mut(length..total_length)
            .assume("sync message fits in target")?
            .copy_from_slice(&command_data);
        length = total_length;
        Ok(length)
    }

    /// Returns the pinned session ID, or a bug error if `receive` has not
    /// set one yet.
    fn session_id(&self) -> Result<u128, SyncError> {
        Ok(self.session_id.assume("session id is set")?)
    }
}