//! Sync responder implementation (aranya_runtime/sync/responder.rs).

1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10    SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13    StorageError, SyncType,
14    command::{Address, CmdId, Command as _},
15    storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
/// Tracks the heads a sync peer is known to already hold, so the responder
/// can avoid re-sending commands the peer has.
#[derive(Default, Debug)]
pub struct PeerCache {
    /// Bounded set of known peer heads, capped at `PEER_HEAD_MAX`.
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}
22
impl PeerCache {
    /// Creates an empty cache with no known heads.
    pub const fn new() -> Self {
        Self { heads: Vec::new() }
    }

    /// Returns the addresses this peer is known to already have.
    pub fn heads(&self) -> &[Address] {
        &self.heads
    }

    /// Records that the peer holds `command` (located at `cmd_loc` in
    /// `storage`), maintaining the invariant that no cached head is an
    /// ancestor of another.
    ///
    /// Existing heads that are ancestors of the new command are evicted;
    /// the command itself is only added if it is not equal to, or an
    /// ancestor of, a retained head. If the cache is already full the new
    /// command is silently dropped.
    pub fn add_command<S>(
        &mut self,
        storage: &S,
        command: Address,
        cmd_loc: Location,
    ) -> Result<(), StorageError>
    where
        S: Storage,
    {
        let mut add_command = true;

        // Decides whether an existing head should be retained; as a side
        // effect, clears `add_command` when the new command is redundant.
        let mut retain_head = |request_head: &Address, new_head: Location| {
            let new_head_seg = storage.get_segment(new_head)?;
            let req_head_loc = storage
                .get_location(*request_head)?
                .assume("location must exist")?;
            let req_head_seg = storage.get_segment(req_head_loc)?;
            // The new command is already cached: don't add a duplicate.
            if request_head.id
                == new_head_seg
                    .get_command(new_head)
                    .assume("location must exist")?
                    .address()?
                    .id
            {
                add_command = false;
            }
            // If the new head is an ancestor of the request head, don't add it
            if (new_head.same_segment(req_head_loc) && new_head.command <= req_head_loc.command)
                || storage.is_ancestor(new_head, &req_head_seg)?
            {
                add_command = false;
            }
            // Keep the existing head only if it is NOT an ancestor of the
            // new command's segment (otherwise the new command supersedes it).
            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
        };
        // Storage errors inside the closure cause the head to be dropped.
        self.heads
            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
        if add_command && !self.heads.is_full() {
            self.heads
                .push(command)
                .ok()
                .assume("command locations should not be full")?;
        }

        Ok(())
    }
}
78
// TODO: Use compile-time args. This initial definition results in this clippy warning:
// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
// As the buffer consts will be compile-time variables in the future, we will be
// able to tune these buffers for smaller footprints. Right now, this enum is not
// suitable for small devices (`SyncResponse` is 14448 bytes).
/// Messages sent from the responder to the requester.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// Sent in response to a `SyncRequest`
    SyncResponse {
        /// A random-value produced by a cryptographically secure RNG.
        session_id: u128,
        /// If the responder intends to send a value of command bytes
        /// greater than the responder's configured maximum, the responder
        /// will send more than one `SyncResponse`. The first message has an
        /// index of 1, and each following is incremented.
        response_index: u64,
        /// Commands that the responder believes the requester does not have.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// End a sync session if `SyncRequest.max_bytes` has been reached or
    /// there are no remaining commands to send.
    SyncEnd {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Largest index of any `SyncResponse`
        max_index: u64,
        /// Set `true` if this message was sent due to reaching the `max_bytes`
        /// budget.
        remaining: bool,
    },

    /// Message sent by a responder after a sync has been completed, but before
    /// the session has ended, if it has new commands in its graph. If a
    /// requester wishes to respond to this message, it should do so with a
    /// new `SyncRequest`. This message may use the existing `session_id`.
    Offer {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Head of the branch the responder wishes to send.
        head: CmdId,
    },

    /// Message sent by either requester or responder to indicate the session
    /// has been terminated or the `session_id` is no longer valid.
    EndSession {
        /// The session being terminated or invalidated.
        session_id: u128,
    },
}
130
131impl SyncResponseMessage {
132    pub fn session_id(&self) -> u128 {
133        match self {
134            Self::SyncResponse { session_id, .. } => *session_id,
135            Self::SyncEnd { session_id, .. } => *session_id,
136            Self::Offer { session_id, .. } => *session_id,
137            Self::EndSession { session_id, .. } => *session_id,
138        }
139    }
140}
141
/// Lifecycle of the responder's internal state machine
/// (driven by `receive` and advanced by `poll`).
#[derive(Debug, Default)]
enum SyncResponderState {
    /// Initial state; nothing received yet. `poll` returns `NotReady`.
    #[default]
    New,
    /// A `SyncRequest` was received; the next `poll` computes the needed
    /// segments and emits the first response.
    Start,
    /// Mid-session; `poll` keeps emitting `SyncResponse` messages.
    Send,
    /// All queued segments were sent and `SyncEnd` was emitted.
    Idle,
    /// An error occurred; the next `poll` emits `EndSession` and stops.
    Reset,
    /// Session over (terminated or ended); `poll` returns `NotReady`.
    Stopped,
}
152
/// Responder half of the sync protocol: answers `SyncRequestMessage`s by
/// streaming the commands the requester appears to be missing.
#[derive(Default)]
pub struct SyncResponder {
    /// Session id captured from the first message received.
    session_id: Option<u128>,
    /// Graph being synced, taken from the `SyncRequest`.
    graph_id: Option<GraphId>,
    /// Current position in the responder state machine.
    state: SyncResponderState,
    /// NOTE(review): assigned from `SyncRequest.max_bytes` in `receive` but
    /// never read in this file — confirm intended use as a byte budget.
    bytes_sent: u64,
    /// Index into `to_send` of the next segment to serialize.
    next_send: usize,
    /// Count of `SyncResponse` messages produced so far in this session.
    message_index: usize,
    /// Command addresses sampled by the requester in its `SyncRequest`.
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    /// Segment start locations queued for sending, in ancestor-first order.
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
}
164
impl SyncResponder {
    /// Create a new [`SyncResponder`].
    pub fn new() -> Self {
        Self {
            session_id: None,
            graph_id: None,
            state: SyncResponderState::New,
            bytes_sent: 0,
            next_send: 0,
            message_index: 0,
            has: Vec::new(),
            to_send: Vec::new(),
        }
    }

    /// Returns true if [`Self::poll`] would produce a message.
    pub fn ready(&self) -> bool {
        use SyncResponderState::*;
        match self.state {
            Reset | Start | Send => true, // TODO(chip): For Send, check whether to_send has anything to send
            New | Idle | Stopped => false,
        }
    }

    /// Write a sync message in to the target buffer. Returns the number
    /// of bytes written.
    ///
    /// Behavior depends on the current state (set by [`Self::receive`]):
    /// - `Start`: seeds `response_cache` and `to_send` from the requester's
    ///   command sample, then emits the first `SyncResponse`.
    /// - `Send`: emits the next `SyncResponse` (or `SyncEnd` when done).
    /// - `Reset`: emits `EndSession` and transitions to `Stopped`.
    /// - `New`/`Idle`/`Stopped`: returns [`SyncError::NotReady`].
    pub fn poll(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
        response_cache: &mut PeerCache,
    ) -> Result<usize, SyncError> {
        // TODO(chip): return a status enum instead of usize
        use SyncResponderState as S;
        let length = match self.state {
            S::New | S::Idle | S::Stopped => {
                return Err(SyncError::NotReady); // TODO(chip): return Ok(NotReady)
            }
            S::Start => {
                let Some(graph_id) = self.graph_id else {
                    self.state = S::Reset;
                    bug!("poll called before graph_id was set");
                };

                // Any storage failure aborts the session via the Reset state.
                let storage = match provider.get_storage(graph_id) {
                    Ok(s) => s,
                    Err(e) => {
                        self.state = S::Reset;
                        return Err(e.into());
                    }
                };

                self.state = S::Send;
                for command in &self.has {
                    // We only need to check commands that are a part of our graph.
                    if let Some(cmd_loc) = storage.get_location(*command)? {
                        response_cache.add_command(storage, *command, cmd_loc)?;
                    }
                }
                self.to_send = Self::find_needed_segments(&self.has, storage)?;

                self.get_next(target, provider)?
            }
            S::Send => self.get_next(target, provider)?,
            S::Reset => {
                self.state = S::Stopped;
                let message = SyncResponseMessage::EndSession {
                    session_id: self.session_id()?,
                };
                Self::write(target, message)?
            }
        };

        Ok(length)
    }

    /// Receive a sync message. Updates the responder's state for later polling.
    ///
    /// The first message seen fixes the session id; any later message with a
    /// different id is rejected with [`SyncError::SessionMismatch`].
    pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
        if self.session_id.is_none() {
            self.session_id = Some(message.session_id());
        }
        if self.session_id != Some(message.session_id()) {
            return Err(SyncError::SessionMismatch);
        }

        match message {
            SyncRequestMessage::SyncRequest {
                graph_id,
                max_bytes,
                commands,
                ..
            } => {
                // (Re)start the session from the requester's command sample.
                self.state = SyncResponderState::Start;
                self.graph_id = Some(graph_id);
                // NOTE(review): stores the request's budget but it is never
                // read elsewhere in this file — confirm intended use.
                self.bytes_sent = max_bytes;
                self.to_send = Vec::new();
                self.has = commands;
                self.next_send = 0;
                return Ok(());
            }
            SyncRequestMessage::RequestMissing { .. } => {
                todo!()
            }
            SyncRequestMessage::SyncResume { .. } => {
                todo!()
            }
            SyncRequestMessage::EndSession { .. } => {
                self.state = SyncResponderState::Stopped;
            }
        }

        Ok(())
    }

    /// Serializes a [`SyncType`] into `target`; returns the bytes written.
    fn write_sync_type(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    /// Serializes a [`SyncResponseMessage`] into `target`; returns the bytes written.
    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    /// This (probably) returns a Vec of segment addresses where the head of each segment is
    /// not the ancestor of any samples we have been sent. If that is longer than
    /// SEGMENT_BUFFER_MAX, it contains the oldest segment heads where that holds.
    fn find_needed_segments(
        commands: &[Address],
        storage: &impl Storage,
    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
        // Locations of the sampled commands that exist in our graph.
        let mut have_locations = vec::Vec::new(); //BUG: not constant size
        for &addr in commands {
            // Note: We could use things we don't have as a hint to
            // know we should perform a sync request.
            if let Some(location) = storage.get_location(addr)? {
                have_locations.push(location);
            }
        }

        // Filter out locations that are ancestors of other locations in the list.
        // If location A is an ancestor of location B, we only need to keep B since
        // having B implies having A and all its ancestors.
        // Iterate backwards so we can safely remove items
        for i in (0..have_locations.len()).rev() {
            let location_a = have_locations[i];
            let mut is_ancestor_of_other = false;
            for &location_b in &have_locations {
                if location_a != location_b {
                    let segment_b = storage.get_segment(location_b)?;
                    if location_a.same_segment(location_b)
                        && location_a.command <= location_b.command
                        || storage.is_ancestor(location_a, &segment_b)?
                    {
                        is_ancestor_of_other = true;
                        break;
                    }
                }
            }
            if is_ancestor_of_other {
                have_locations.remove(i);
            }
        }

        // Walk the graph backwards, breadth-first, starting from our head.
        let mut heads = vec::Vec::new();
        heads.push(storage.get_head()?);

        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

        while !heads.is_empty() {
            let current = mem::take(&mut heads);
            'heads: for head in current {
                let segment = storage.get_segment(head)?;
                // If the segment is already in the result, skip it
                if segment.contains_any(&result) {
                    continue 'heads;
                }

                // Check if the current segment head is an ancestor of any location in have_locations.
                // If so, stop traversing backward from this point since the requester already has
                // this command and all its ancestors.
                for &have_location in &have_locations {
                    let have_segment = storage.get_segment(have_location)?;
                    if storage.is_ancestor(head, &have_segment)? {
                        continue 'heads;
                    }
                }

                // If the requester has any commands in this segment, send from the next command
                if let Some(latest_loc) = have_locations
                    .iter()
                    .filter(|&&location| segment.contains(location))
                    .max_by_key(|&&location| location.command)
                {
                    let next_command = latest_loc
                        .command
                        .checked_add(1)
                        .assume("command + 1 mustn't overflow")?;
                    let next_location = Location {
                        segment: head.segment,
                        command: next_command,
                    };

                    // Requester already has the whole segment; nothing to send.
                    let head_loc = segment.head_location();
                    if next_location.command > head_loc.command {
                        continue 'heads;
                    }
                    // When full, drop the newest entry to keep the oldest ones.
                    if result.is_full() {
                        result.pop_back();
                    }
                    result
                        .push_front(next_location)
                        .ok()
                        .assume("too many segments")?;
                    continue 'heads;
                }

                // Requester has nothing in this segment: queue its parents
                // for traversal and send the segment from its first command.
                heads.extend(segment.prior());

                if result.is_full() {
                    result.pop_back();
                }

                let location = segment.first_location();
                result
                    .push_front(location)
                    .ok()
                    .assume("too many segments")?;
            }
        }
        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
        for l in result {
            r.push(l).ok().assume("too many segments")?;
        }
        // Order segments to ensure that a segment isn't received before its
        // ancestor segments.
        r.sort();
        Ok(r)
    }

    /// Serializes the next batch of queued commands into `target`.
    ///
    /// If everything in `to_send` was already sent, emits `SyncEnd` and goes
    /// `Idle`. Otherwise emits a `SyncResponse` message followed immediately
    /// by the raw command bytes appended after the serialized message.
    fn get_next(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        if self.next_send >= self.to_send.len() {
            self.state = SyncResponderState::Idle;
            let message = SyncResponseMessage::SyncEnd {
                session_id: self.session_id()?,
                max_index: self.message_index as u64,
                remaining: false,
            };
            let length = Self::write(target, message)?;
            return Ok(length);
        }

        let (commands, command_data, next_send) = self.get_commands(provider)?;

        let message = SyncResponseMessage::SyncResponse {
            session_id: self.session_id()?,
            response_index: self.message_index as u64,
            commands,
        };
        self.message_index = self
            .message_index
            .checked_add(1)
            .assume("message_index overflow")?;
        self.next_send = next_send;

        // Layout: serialized message first, then the raw command data.
        let length = Self::write(target, message)?;
        let total_length = length
            .checked_add(command_data.len())
            .assume("length + command_data_length mustn't overflow")?;
        target
            .get_mut(length..total_length)
            .assume("sync message fits in target")?
            .copy_from_slice(&command_data);
        Ok(total_length)
    }

    /// Writes a sync push message to target for the peer. The message will
    /// contain any commands that are after the commands in response_cache.
    ///
    /// NOTE(review): segments are computed from `self.has` — the mention of
    /// `response_cache` above doesn't match a parameter here; confirm intent.
    /// Returns 0 without writing when there is nothing new to push.
    pub fn push(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let Some(graph_id) = self.graph_id else {
            self.state = S::Reset;
            bug!("poll called before graph_id was set");
        };

        let storage = match provider.get_storage(graph_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = S::Reset;
                return Err(e.into());
            }
        };
        self.to_send = Self::find_needed_segments(&self.has, storage)?;
        let (commands, command_data, next_send) = self.get_commands(provider)?;
        let mut length = 0;
        if !commands.is_empty() {
            let message = SyncType::Push {
                message: SyncResponseMessage::SyncResponse {
                    session_id: self.session_id()?,
                    response_index: self.message_index as u64,
                    commands,
                },
                graph_id: self.graph_id.assume("graph id must exist")?,
            };
            self.message_index = self
                .message_index
                .checked_add(1)
                .assume("message_index increment overflow")?;
            self.next_send = next_send;

            // Layout: serialized message first, then the raw command data.
            length = Self::write_sync_type(target, message)?;
            let total_length = length
                .checked_add(command_data.len())
                .assume("length + command_data_length mustn't overflow")?;
            target
                .get_mut(length..total_length)
                .assume("sync message fits in target")?
                .copy_from_slice(&command_data);
            length = total_length;
        }
        Ok(length)
    }

    /// Collects up to `COMMAND_RESPONSE_MAX` command metadata entries,
    /// starting at segment index `self.next_send`, along with their
    /// concatenated policy + command bytes.
    ///
    /// Returns `(metadata, raw bytes, index of the next unsent segment)`.
    /// Each command's policy bytes (if any) precede its command bytes in the
    /// raw buffer; `CommandMeta.policy_length`/`length` delimit them.
    fn get_commands(
        &mut self,
        provider: &mut impl StorageProvider,
    ) -> Result<
        (
            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
            usize,
        ),
        SyncError,
    > {
        let Some(graph_id) = self.graph_id.as_ref() else {
            self.state = SyncResponderState::Reset;
            bug!("get_next called before graph_id was set");
        };
        let storage = match provider.get_storage(*graph_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = SyncResponderState::Reset;
                return Err(e.into());
            }
        };
        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
        let mut index = self.next_send;
        for i in self.next_send..self.to_send.len() {
            if commands.is_full() {
                break;
            }
            // NOTE(review): `index` advances per segment before the segment
            // is fully serialized, so filling `commands` mid-segment skips
            // that segment's remaining commands (see FIXME below).
            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
            let Some(&location) = self.to_send.get(i) else {
                self.state = SyncResponderState::Reset;
                bug!("send index OOB");
            };

            let segment = storage
                .get_segment(location)
                .inspect_err(|_| self.state = SyncResponderState::Reset)?;

            let found = segment.get_from(location);

            for command in &found {
                let mut policy_length = 0;

                if let Some(policy) = command.policy() {
                    policy_length = policy.len();
                    command_data
                        .extend_from_slice(policy)
                        .ok()
                        .assume("command_data is too large")?;
                }

                let bytes = command.bytes();
                command_data
                    .extend_from_slice(bytes)
                    .ok()
                    .assume("command_data is too large")?;

                let max_cut = command.max_cut()?;
                let meta = CommandMeta {
                    id: command.id(),
                    priority: command.priority(),
                    parent: command.parent(),
                    policy_length: policy_length as u32,
                    length: bytes.len() as u32,
                    max_cut,
                };

                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
                commands
                    .push(meta)
                    .ok()
                    .assume("too many commands in segment")?;
                if commands.is_full() {
                    break;
                }
            }
        }
        Ok((commands, command_data, index))
    }

    /// Returns the session id, or a bug error if no message was ever received.
    fn session_id(&self) -> Result<u128, SyncError> {
        Ok(self.session_id.assume("session id is set")?)
    }
}