aranya_runtime/sync/responder.rs

1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10    SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13    StorageError, SyncType,
14    command::{Address, CmdId, Command as _},
15    storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20    heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24    pub const fn new() -> Self {
25        Self { heads: Vec::new() }
26    }
27
28    pub fn heads(&self) -> &[Address] {
29        &self.heads
30    }
31
32    pub fn add_command<S>(
33        &mut self,
34        storage: &S,
35        command: Address,
36        cmd_loc: Location,
37    ) -> Result<(), StorageError>
38    where
39        S: Storage,
40    {
41        let mut add_command = true;
42
43        let mut retain_head = |request_head: &Address, new_head: Location| {
44            let new_head_seg = storage.get_segment(new_head)?;
45            let req_head_loc = storage
46                .get_location(*request_head)?
47                .assume("location must exist")?;
48            let req_head_seg = storage.get_segment(req_head_loc)?;
49            if request_head.id
50                == new_head_seg
51                    .get_command(new_head)
52                    .assume("location must exist")?
53                    .address()?
54                    .id
55            {
56                add_command = false;
57            }
58            // If the new head is an ancestor of the request head, don't add it
59            if (new_head.same_segment(req_head_loc) && new_head.command <= req_head_loc.command)
60                || storage.is_ancestor(new_head, &req_head_seg)?
61            {
62                add_command = false;
63            }
64            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
65        };
66        self.heads
67            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
68        if add_command && !self.heads.is_full() {
69            self.heads
70                .push(command)
71                .ok()
72                .assume("command locations should not be full")?;
73        }
74
75        Ok(())
76    }
77}
78
79// TODO: Use compile-time args. This initial definition results in this clippy warning:
80// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
81// As the buffer consts will be compile-time variables in the future, we will be
82// able to tune these buffers for smaller footprints. Right now, this enum is not
83// suitable for small devices (`SyncResponse` is 14448 bytes).
84/// Messages sent from the responder to the requester.
85#[derive(Serialize, Deserialize, Debug)]
86#[allow(clippy::large_enum_variant)]
87pub enum SyncResponseMessage {
88    /// Sent in response to a `SyncRequest`
89    SyncResponse {
90        /// A random-value produced by a cryptographically secure RNG.
91        session_id: u128,
92        /// If the responder intends to send a value of command bytes
93        /// greater than the responder's configured maximum, the responder
94        /// will send more than one `SyncResponse`. The first message has an
95        /// index of 1, and each following is incremented.
96        response_index: u64,
97        /// Commands that the responder believes the requester does not have.
98        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
99    },
100
101    /// End a sync session if `SyncRequest.max_bytes` has been reached or
102    /// there are no remaining commands to send.
103    SyncEnd {
104        /// A random-value produced by a cryptographically secure RNG
105        /// corresponding to the `session_id` in the initial `SyncRequest`.
106        session_id: u128,
107        /// Largest index of any `SyncResponse`
108        max_index: u64,
109        /// Set `true` if this message was sent due to reaching the `max_bytes`
110        /// budget.
111        remaining: bool,
112    },
113
114    /// Message sent by a responder after a sync has been completed, but before
115    /// the session has ended, if it has new commands in it's graph. If a
116    /// requester wishes to respond to this message, it should do so with a
117    /// new `SyncRequest`. This message may use the existing `session_id`.
118    Offer {
119        /// A random-value produced by a cryptographically secure RNG
120        /// corresponding to the `session_id` in the initial `SyncRequest`.
121        session_id: u128,
122        /// Head of the branch the responder wishes to send.
123        head: CmdId,
124    },
125
126    /// Message sent by either requester or responder to indicate the session
127    /// has been terminated or the `session_id` is no longer valid.
128    EndSession { session_id: u128 },
129}
130
131impl SyncResponseMessage {
132    pub fn session_id(&self) -> u128 {
133        match self {
134            Self::SyncResponse { session_id, .. } => *session_id,
135            Self::SyncEnd { session_id, .. } => *session_id,
136            Self::Offer { session_id, .. } => *session_id,
137            Self::EndSession { session_id, .. } => *session_id,
138        }
139    }
140}
141
/// Where a [`SyncResponder`] is in its session, as driven by `receive` and `poll`.
#[derive(Default, Debug)]
enum SyncResponderState {
    /// Nothing received yet; `poll` reports not ready.
    #[default]
    New,
    /// A `SyncRequest` arrived; the next `poll` works out what to send.
    Start,
    /// Sending batches; each `poll` emits the next batch, or `SyncEnd` when done.
    Send,
    /// Everything queued has been sent; `poll` reports not ready.
    Idle,
    /// The session must be torn down; the next `poll` emits `EndSession`.
    Reset,
    /// The session is over; `poll` reports not ready.
    Stopped,
}
152
153#[derive(Default)]
154pub struct SyncResponder<A> {
155    session_id: Option<u128>,
156    storage_id: Option<GraphId>,
157    state: SyncResponderState,
158    bytes_sent: u64,
159    next_send: usize,
160    message_index: usize,
161    has: Vec<Address, COMMAND_SAMPLE_MAX>,
162    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
163    server_address: A,
164}
165
166impl<A: Serialize + Clone> SyncResponder<A> {
167    /// Create a new [`SyncResponder`].
168    pub fn new(server_address: A) -> Self {
169        Self {
170            session_id: None,
171            storage_id: None,
172            state: SyncResponderState::New,
173            bytes_sent: 0,
174            next_send: 0,
175            message_index: 0,
176            has: Vec::new(),
177            to_send: Vec::new(),
178            server_address,
179        }
180    }
181
182    /// Returns true if [`Self::poll`] would produce a message.
183    pub fn ready(&self) -> bool {
184        use SyncResponderState::*;
185        match self.state {
186            Reset | Start | Send => true, // TODO(chip): For Send, check whether to_send has anything to send
187            New | Idle | Stopped => false,
188        }
189    }
190
191    /// Write a sync message in to the target buffer. Returns the number
192    /// of bytes written.
193    pub fn poll(
194        &mut self,
195        target: &mut [u8],
196        provider: &mut impl StorageProvider,
197        response_cache: &mut PeerCache,
198    ) -> Result<usize, SyncError> {
199        // TODO(chip): return a status enum instead of usize
200        use SyncResponderState as S;
201        let length = match self.state {
202            S::New | S::Idle | S::Stopped => {
203                return Err(SyncError::NotReady); // TODO(chip): return Ok(NotReady)
204            }
205            S::Start => {
206                let Some(storage_id) = self.storage_id else {
207                    self.state = S::Reset;
208                    bug!("poll called before storage_id was set");
209                };
210
211                let storage = match provider.get_storage(storage_id) {
212                    Ok(s) => s,
213                    Err(e) => {
214                        self.state = S::Reset;
215                        return Err(e.into());
216                    }
217                };
218
219                self.state = S::Send;
220                for command in &self.has {
221                    // We only need to check commands that are a part of our graph.
222                    if let Some(cmd_loc) = storage.get_location(*command)? {
223                        response_cache.add_command(storage, *command, cmd_loc)?;
224                    }
225                }
226                self.to_send = Self::find_needed_segments(&self.has, storage)?;
227
228                self.get_next(target, provider)?
229            }
230            S::Send => self.get_next(target, provider)?,
231            S::Reset => {
232                self.state = S::Stopped;
233                let message = SyncResponseMessage::EndSession {
234                    session_id: self.session_id()?,
235                };
236                Self::write(target, message)?
237            }
238        };
239
240        Ok(length)
241    }
242
243    /// Receive a sync message. Updates the responders state for later polling.
244    pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
245        if self.session_id.is_none() {
246            self.session_id = Some(message.session_id());
247        }
248        if self.session_id != Some(message.session_id()) {
249            return Err(SyncError::SessionMismatch);
250        }
251
252        match message {
253            SyncRequestMessage::SyncRequest {
254                storage_id,
255                max_bytes,
256                commands,
257                ..
258            } => {
259                self.state = SyncResponderState::Start;
260                self.storage_id = Some(storage_id);
261                self.bytes_sent = max_bytes;
262                self.to_send = Vec::new();
263                self.has = commands;
264                self.next_send = 0;
265                return Ok(());
266            }
267            SyncRequestMessage::RequestMissing { .. } => {
268                todo!()
269            }
270            SyncRequestMessage::SyncResume { .. } => {
271                todo!()
272            }
273            SyncRequestMessage::EndSession { .. } => {
274                self.state = SyncResponderState::Stopped;
275            }
276        }
277
278        Ok(())
279    }
280
281    fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
282        Ok(postcard::to_slice(&msg, target)?.len())
283    }
284
285    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
286        Ok(postcard::to_slice(&msg, target)?.len())
287    }
288
289    /// This (probably) returns a Vec of segment addresses where the head of each segment is
290    /// not the ancestor of any samples we have been sent. If that is longer than
291    /// SEGMENT_BUFFER_MAX, it contains the oldest segment heads where that holds.
292    fn find_needed_segments(
293        commands: &[Address],
294        storage: &impl Storage,
295    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
296        let mut have_locations = vec::Vec::new(); //BUG: not constant size
297        for &addr in commands {
298            // Note: We could use things we don't have as a hint to
299            // know we should perform a sync request.
300            if let Some(location) = storage.get_location(addr)? {
301                have_locations.push(location);
302            }
303        }
304
305        // Filter out locations that are ancestors of other locations in the list.
306        // If location A is an ancestor of location B, we only need to keep B since
307        // having B implies having A and all its ancestors.
308        // Iterate backwards so we can safely remove items
309        for i in (0..have_locations.len()).rev() {
310            let location_a = have_locations[i];
311            let mut is_ancestor_of_other = false;
312            for &location_b in &have_locations {
313                if location_a != location_b {
314                    let segment_b = storage.get_segment(location_b)?;
315                    if location_a.same_segment(location_b)
316                        && location_a.command <= location_b.command
317                        || storage.is_ancestor(location_a, &segment_b)?
318                    {
319                        is_ancestor_of_other = true;
320                        break;
321                    }
322                }
323            }
324            if is_ancestor_of_other {
325                have_locations.remove(i);
326            }
327        }
328
329        let mut heads = vec::Vec::new();
330        heads.push(storage.get_head()?);
331
332        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
333
334        while !heads.is_empty() {
335            let current = mem::take(&mut heads);
336            'heads: for head in current {
337                let segment = storage.get_segment(head)?;
338                // If the segment is already in the result, skip it
339                if segment.contains_any(&result) {
340                    continue 'heads;
341                }
342
343                // Check if the current segment head is an ancestor of any location in have_locations.
344                // If so, stop traversing backward from this point since the requester already has
345                // this command and all its ancestors.
346                for &have_location in &have_locations {
347                    let have_segment = storage.get_segment(have_location)?;
348                    if storage.is_ancestor(head, &have_segment)? {
349                        continue 'heads;
350                    }
351                }
352
353                // If the requester has any commands in this segment, send from the next command
354                if let Some(latest_loc) = have_locations
355                    .iter()
356                    .filter(|&&location| segment.contains(location))
357                    .max_by_key(|&&location| location.command)
358                {
359                    let next_command = latest_loc
360                        .command
361                        .checked_add(1)
362                        .assume("command + 1 mustn't overflow")?;
363                    let next_location = Location {
364                        segment: head.segment,
365                        command: next_command,
366                    };
367
368                    let head_loc = segment.head_location();
369                    if next_location.command > head_loc.command {
370                        continue 'heads;
371                    }
372                    if result.is_full() {
373                        result.pop_back();
374                    }
375                    result
376                        .push_front(next_location)
377                        .ok()
378                        .assume("too many segments")?;
379                    continue 'heads;
380                }
381
382                heads.extend(segment.prior());
383
384                if result.is_full() {
385                    result.pop_back();
386                }
387
388                let location = segment.first_location();
389                result
390                    .push_front(location)
391                    .ok()
392                    .assume("too many segments")?;
393            }
394        }
395        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
396        for l in result {
397            r.push(l).ok().assume("too many segments")?;
398        }
399        // Order segments to ensure that a segment isn't received before its
400        // ancestor segments.
401        r.sort();
402        Ok(r)
403    }
404
405    fn get_next(
406        &mut self,
407        target: &mut [u8],
408        provider: &mut impl StorageProvider,
409    ) -> Result<usize, SyncError> {
410        if self.next_send >= self.to_send.len() {
411            self.state = SyncResponderState::Idle;
412            let message = SyncResponseMessage::SyncEnd {
413                session_id: self.session_id()?,
414                max_index: self.message_index as u64,
415                remaining: false,
416            };
417            let length = Self::write(target, message)?;
418            return Ok(length);
419        }
420
421        let (commands, command_data, next_send) = self.get_commands(provider)?;
422
423        let message = SyncResponseMessage::SyncResponse {
424            session_id: self.session_id()?,
425            response_index: self.message_index as u64,
426            commands,
427        };
428        self.message_index = self
429            .message_index
430            .checked_add(1)
431            .assume("message_index overflow")?;
432        self.next_send = next_send;
433
434        let length = Self::write(target, message)?;
435        let total_length = length
436            .checked_add(command_data.len())
437            .assume("length + command_data_length mustn't overflow")?;
438        target
439            .get_mut(length..total_length)
440            .assume("sync message fits in target")?
441            .copy_from_slice(&command_data);
442        Ok(total_length)
443    }
444
445    /// Writes a sync push message to target for the peer. The message will
446    /// contain any commands that are after the commands in response_cache.
447    pub fn push(
448        &mut self,
449        target: &mut [u8],
450        provider: &mut impl StorageProvider,
451    ) -> Result<usize, SyncError> {
452        use SyncResponderState as S;
453        let Some(storage_id) = self.storage_id else {
454            self.state = S::Reset;
455            bug!("poll called before storage_id was set");
456        };
457
458        let storage = match provider.get_storage(storage_id) {
459            Ok(s) => s,
460            Err(e) => {
461                self.state = S::Reset;
462                return Err(e.into());
463            }
464        };
465        self.to_send = Self::find_needed_segments(&self.has, storage)?;
466        let (commands, command_data, next_send) = self.get_commands(provider)?;
467        let mut length = 0;
468        if !commands.is_empty() {
469            let message = SyncType::Push {
470                message: SyncResponseMessage::SyncResponse {
471                    session_id: self.session_id()?,
472                    response_index: self.message_index as u64,
473                    commands,
474                },
475                storage_id: self.storage_id.assume("storage id must exist")?,
476                address: self.server_address.clone(),
477            };
478            self.message_index = self
479                .message_index
480                .checked_add(1)
481                .assume("message_index increment overflow")?;
482            self.next_send = next_send;
483
484            length = Self::write_sync_type(target, message)?;
485            let total_length = length
486                .checked_add(command_data.len())
487                .assume("length + command_data_length mustn't overflow")?;
488            target
489                .get_mut(length..total_length)
490                .assume("sync message fits in target")?
491                .copy_from_slice(&command_data);
492            length = total_length;
493        }
494        Ok(length)
495    }
496
497    fn get_commands(
498        &mut self,
499        provider: &mut impl StorageProvider,
500    ) -> Result<
501        (
502            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
503            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
504            usize,
505        ),
506        SyncError,
507    > {
508        let Some(storage_id) = self.storage_id.as_ref() else {
509            self.state = SyncResponderState::Reset;
510            bug!("get_next called before storage_id was set");
511        };
512        let storage = match provider.get_storage(*storage_id) {
513            Ok(s) => s,
514            Err(e) => {
515                self.state = SyncResponderState::Reset;
516                return Err(e.into());
517            }
518        };
519        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
520        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
521        let mut index = self.next_send;
522        for i in self.next_send..self.to_send.len() {
523            if commands.is_full() {
524                break;
525            }
526            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
527            let Some(&location) = self.to_send.get(i) else {
528                self.state = SyncResponderState::Reset;
529                bug!("send index OOB");
530            };
531
532            let segment = storage
533                .get_segment(location)
534                .inspect_err(|_| self.state = SyncResponderState::Reset)?;
535
536            let found = segment.get_from(location);
537
538            for command in &found {
539                let mut policy_length = 0;
540
541                if let Some(policy) = command.policy() {
542                    policy_length = policy.len();
543                    command_data
544                        .extend_from_slice(policy)
545                        .ok()
546                        .assume("command_data is too large")?;
547                }
548
549                let bytes = command.bytes();
550                command_data
551                    .extend_from_slice(bytes)
552                    .ok()
553                    .assume("command_data is too large")?;
554
555                let max_cut = command.max_cut()?;
556                let meta = CommandMeta {
557                    id: command.id(),
558                    priority: command.priority(),
559                    parent: command.parent(),
560                    policy_length: policy_length as u32,
561                    length: bytes.len() as u32,
562                    max_cut,
563                };
564
565                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
566                commands
567                    .push(meta)
568                    .ok()
569                    .assume("too many commands in segment")?;
570                if commands.is_full() {
571                    break;
572                }
573            }
574        }
575        Ok((commands, command_data, index))
576    }
577
578    fn session_id(&self) -> Result<u128, SyncError> {
579        Ok(self.session_id.assume("session id is set")?)
580    }
581}