Skip to main content

aranya_runtime/sync/
responder.rs

1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt as _, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10    SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13    StorageError, SyncType,
14    command::{Address, CmdId, Command as _},
15    storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20    heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24    pub const fn new() -> Self {
25        Self { heads: Vec::new() }
26    }
27
28    pub fn heads(&self) -> &[Address] {
29        &self.heads
30    }
31
32    pub fn add_command<S>(
33        &mut self,
34        storage: &S,
35        command: Address,
36        cmd_loc: Location,
37    ) -> Result<(), StorageError>
38    where
39        S: Storage,
40    {
41        let mut add_command = true;
42
43        let mut retain_head = |request_head: &Address, new_head: Location| {
44            let new_head_seg = storage.get_segment(new_head)?;
45            let req_head_loc = storage
46                .get_location(*request_head)?
47                .assume("location must exist")?;
48            let req_head_seg = storage.get_segment(req_head_loc)?;
49            if request_head.id
50                == new_head_seg
51                    .get_command(new_head)
52                    .assume("location must exist")?
53                    .address()?
54                    .id
55            {
56                add_command = false;
57            }
58            // If the new head is an ancestor of the request head, don't add it
59            if (new_head.same_segment(req_head_loc) && new_head.max_cut <= req_head_loc.max_cut)
60                || storage.is_ancestor(new_head, &req_head_seg)?
61            {
62                add_command = false;
63            }
64            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
65        };
66        self.heads
67            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
68        if add_command && !self.heads.is_full() {
69            self.heads
70                .push(command)
71                .ok()
72                .assume("command locations should not be full")?;
73        }
74
75        Ok(())
76    }
77}
78
79// TODO: Use compile-time args. This initial definition results in this clippy warning:
80// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
81// As the buffer consts will be compile-time variables in the future, we will be
82// able to tune these buffers for smaller footprints. Right now, this enum is not
83// suitable for small devices (`SyncResponse` is 14448 bytes).
84/// Messages sent from the responder to the requester.
85#[derive(Serialize, Deserialize, Debug)]
86#[allow(clippy::large_enum_variant)]
87pub enum SyncResponseMessage {
88    /// Sent in response to a `SyncRequest`
89    SyncResponse {
90        /// A random-value produced by a cryptographically secure RNG.
91        session_id: u128,
92        /// If the responder intends to send a value of command bytes
93        /// greater than the responder's configured maximum, the responder
94        /// will send more than one `SyncResponse`. The first message has an
95        /// index of 1, and each following is incremented.
96        response_index: u64,
97        /// Commands that the responder believes the requester does not have.
98        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
99    },
100
101    /// End a sync session if `SyncRequest.max_bytes` has been reached or
102    /// there are no remaining commands to send.
103    SyncEnd {
104        /// A random-value produced by a cryptographically secure RNG
105        /// corresponding to the `session_id` in the initial `SyncRequest`.
106        session_id: u128,
107        /// Largest index of any `SyncResponse`
108        max_index: u64,
109        /// Set `true` if this message was sent due to reaching the `max_bytes`
110        /// budget.
111        remaining: bool,
112    },
113
114    /// Message sent by a responder after a sync has been completed, but before
115    /// the session has ended, if it has new commands in it's graph. If a
116    /// requester wishes to respond to this message, it should do so with a
117    /// new `SyncRequest`. This message may use the existing `session_id`.
118    Offer {
119        /// A random-value produced by a cryptographically secure RNG
120        /// corresponding to the `session_id` in the initial `SyncRequest`.
121        session_id: u128,
122        /// Head of the branch the responder wishes to send.
123        head: CmdId,
124    },
125
126    /// Message sent by either requester or responder to indicate the session
127    /// has been terminated or the `session_id` is no longer valid.
128    EndSession { session_id: u128 },
129}
130
131impl SyncResponseMessage {
132    pub fn session_id(&self) -> u128 {
133        match self {
134            Self::SyncResponse { session_id, .. } => *session_id,
135            Self::SyncEnd { session_id, .. } => *session_id,
136            Self::Offer { session_id, .. } => *session_id,
137            Self::EndSession { session_id, .. } => *session_id,
138        }
139    }
140}
141
/// Position of a `SyncResponder` in its request/response state machine.
#[derive(Default, Debug)]
enum SyncResponderState {
    /// Freshly created; no request has been received yet.
    #[default]
    New,
    /// A `SyncRequest` was received; the next poll works out what to send.
    Start,
    /// Sending `SyncResponse` messages.
    Send,
    /// Everything queued has been sent and `SyncEnd` was written.
    Idle,
    /// Set after a failure; the next poll writes `EndSession`.
    Reset,
    /// The session is over; nothing more will be written.
    Stopped,
}
152
153#[derive(Default)]
154pub struct SyncResponder {
155    session_id: Option<u128>,
156    graph_id: Option<GraphId>,
157    state: SyncResponderState,
158    bytes_sent: u64,
159    next_send: usize,
160    message_index: usize,
161    has: Vec<Address, COMMAND_SAMPLE_MAX>,
162    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
163}
164
165impl SyncResponder {
166    /// Create a new [`SyncResponder`].
167    pub fn new() -> Self {
168        Self {
169            session_id: None,
170            graph_id: None,
171            state: SyncResponderState::New,
172            bytes_sent: 0,
173            next_send: 0,
174            message_index: 0,
175            has: Vec::new(),
176            to_send: Vec::new(),
177        }
178    }
179
180    /// Returns true if [`Self::poll`] would produce a message.
181    pub fn ready(&self) -> bool {
182        use SyncResponderState::*;
183        match self.state {
184            Reset | Start | Send => true, // TODO(chip): For Send, check whether to_send has anything to send
185            New | Idle | Stopped => false,
186        }
187    }
188
189    /// Write a sync message in to the target buffer. Returns the number
190    /// of bytes written.
191    pub fn poll(
192        &mut self,
193        target: &mut [u8],
194        provider: &mut impl StorageProvider,
195        response_cache: &mut PeerCache,
196    ) -> Result<usize, SyncError> {
197        // TODO(chip): return a status enum instead of usize
198        use SyncResponderState as S;
199        let length = match self.state {
200            S::New | S::Idle | S::Stopped => {
201                return Err(SyncError::NotReady); // TODO(chip): return Ok(NotReady)
202            }
203            S::Start => {
204                let Some(graph_id) = self.graph_id else {
205                    self.state = S::Reset;
206                    bug!("poll called before graph_id was set");
207                };
208
209                let storage = match provider.get_storage(graph_id) {
210                    Ok(s) => s,
211                    Err(e) => {
212                        self.state = S::Reset;
213                        return Err(e.into());
214                    }
215                };
216
217                self.state = S::Send;
218                for command in &self.has {
219                    // We only need to check commands that are a part of our graph.
220                    if let Some(cmd_loc) = storage.get_location(*command)? {
221                        response_cache.add_command(storage, *command, cmd_loc)?;
222                    }
223                }
224                self.to_send = Self::find_needed_segments(&self.has, storage)?;
225
226                self.get_next(target, provider)?
227            }
228            S::Send => self.get_next(target, provider)?,
229            S::Reset => {
230                self.state = S::Stopped;
231                let message = SyncResponseMessage::EndSession {
232                    session_id: self.session_id()?,
233                };
234                Self::write(target, message)?
235            }
236        };
237
238        Ok(length)
239    }
240
241    /// Receive a sync message. Updates the responders state for later polling.
242    pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
243        if self.session_id.is_none() {
244            self.session_id = Some(message.session_id());
245        }
246        if self.session_id != Some(message.session_id()) {
247            return Err(SyncError::SessionMismatch);
248        }
249
250        match message {
251            SyncRequestMessage::SyncRequest {
252                graph_id,
253                max_bytes,
254                commands,
255                ..
256            } => {
257                self.state = SyncResponderState::Start;
258                self.graph_id = Some(graph_id);
259                self.bytes_sent = max_bytes;
260                self.to_send = Vec::new();
261                self.has = commands;
262                self.next_send = 0;
263                return Ok(());
264            }
265            SyncRequestMessage::RequestMissing { .. } => {
266                todo!()
267            }
268            SyncRequestMessage::SyncResume { .. } => {
269                todo!()
270            }
271            SyncRequestMessage::EndSession { .. } => {
272                self.state = SyncResponderState::Stopped;
273            }
274        }
275
276        Ok(())
277    }
278
279    fn write_sync_type(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
280        Ok(postcard::to_slice(&msg, target)?.len())
281    }
282
283    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
284        Ok(postcard::to_slice(&msg, target)?.len())
285    }
286
287    /// This (probably) returns a Vec of segment addresses where the head of each segment is
288    /// not the ancestor of any samples we have been sent. If that is longer than
289    /// SEGMENT_BUFFER_MAX, it contains the oldest segment heads where that holds.
290    fn find_needed_segments(
291        commands: &[Address],
292        storage: &impl Storage,
293    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
294        let mut have_locations = vec::Vec::new(); //BUG: not constant size
295        for &addr in commands {
296            // Note: We could use things we don't have as a hint to
297            // know we should perform a sync request.
298            if let Some(location) = storage.get_location(addr)? {
299                have_locations.push(location);
300            }
301        }
302
303        // Filter out locations that are ancestors of other locations in the list.
304        // If location A is an ancestor of location B, we only need to keep B since
305        // having B implies having A and all its ancestors.
306        // Iterate backwards so we can safely remove items
307        for i in (0..have_locations.len()).rev() {
308            let location_a = have_locations[i];
309            let mut is_ancestor_of_other = false;
310            for &location_b in &have_locations {
311                if location_a != location_b {
312                    let segment_b = storage.get_segment(location_b)?;
313                    if location_a.same_segment(location_b)
314                        && location_a.max_cut <= location_b.max_cut
315                        || storage.is_ancestor(location_a, &segment_b)?
316                    {
317                        is_ancestor_of_other = true;
318                        break;
319                    }
320                }
321            }
322            if is_ancestor_of_other {
323                have_locations.remove(i);
324            }
325        }
326
327        let mut heads = vec::Vec::new();
328        heads.push(storage.get_head()?);
329
330        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
331
332        while !heads.is_empty() {
333            let current = mem::take(&mut heads);
334            'heads: for head in current {
335                // If the segment is already in the result, skip it
336                if result.iter().any(|loc| loc.same_segment(head)) {
337                    continue 'heads;
338                }
339
340                // Check if the current segment head is an ancestor of any location in have_locations.
341                // If so, stop traversing backward from this point since the requester already has
342                // this command and all its ancestors.
343                for &have_location in &have_locations {
344                    let have_segment = storage.get_segment(have_location)?;
345                    if storage.is_ancestor(head, &have_segment)? {
346                        continue 'heads;
347                    }
348                }
349
350                let segment = storage.get_segment(head)?;
351
352                // If the requester has any commands in this segment, send from the next command
353                if let Some(latest_loc) = have_locations
354                    .iter()
355                    .filter(|loc| loc.same_segment(head))
356                    .max_by_key(|&&location| location.max_cut)
357                {
358                    let next_max_cut = latest_loc
359                        .max_cut
360                        .checked_add(1)
361                        .assume("command + 1 mustn't overflow")?;
362                    let next_location = Location {
363                        segment: head.segment,
364                        max_cut: next_max_cut,
365                    };
366
367                    let head_loc = segment.head_location()?;
368                    if next_location.max_cut > head_loc.max_cut {
369                        continue 'heads;
370                    }
371                    if result.is_full() {
372                        result.pop_back();
373                    }
374                    result
375                        .push_front(next_location)
376                        .ok()
377                        .assume("too many segments")?;
378                    continue 'heads;
379                }
380
381                heads.extend(segment.prior());
382
383                if result.is_full() {
384                    result.pop_back();
385                }
386
387                let location = segment.first_location();
388                result
389                    .push_front(location)
390                    .ok()
391                    .assume("too many segments")?;
392            }
393        }
394        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
395        for l in result {
396            r.push(l).ok().assume("too many segments")?;
397        }
398        // Order segments to ensure that a segment isn't received before its
399        // ancestor segments.
400        r.sort();
401        Ok(r)
402    }
403
404    fn get_next(
405        &mut self,
406        target: &mut [u8],
407        provider: &mut impl StorageProvider,
408    ) -> Result<usize, SyncError> {
409        if self.next_send >= self.to_send.len() {
410            self.state = SyncResponderState::Idle;
411            let message = SyncResponseMessage::SyncEnd {
412                session_id: self.session_id()?,
413                max_index: self.message_index as u64,
414                remaining: false,
415            };
416            let length = Self::write(target, message)?;
417            return Ok(length);
418        }
419
420        let (commands, command_data, next_send) = self.get_commands(provider)?;
421
422        let message = SyncResponseMessage::SyncResponse {
423            session_id: self.session_id()?,
424            response_index: self.message_index as u64,
425            commands,
426        };
427        self.message_index = self
428            .message_index
429            .checked_add(1)
430            .assume("message_index overflow")?;
431        self.next_send = next_send;
432
433        let length = Self::write(target, message)?;
434        let total_length = length
435            .checked_add(command_data.len())
436            .assume("length + command_data_length mustn't overflow")?;
437        target
438            .get_mut(length..total_length)
439            .assume("sync message fits in target")?
440            .copy_from_slice(&command_data);
441        Ok(total_length)
442    }
443
444    /// Writes a sync push message to target for the peer. The message will
445    /// contain any commands that are after the commands in response_cache.
446    pub fn push(
447        &mut self,
448        target: &mut [u8],
449        provider: &mut impl StorageProvider,
450    ) -> Result<usize, SyncError> {
451        use SyncResponderState as S;
452        let Some(graph_id) = self.graph_id else {
453            self.state = S::Reset;
454            bug!("poll called before graph_id was set");
455        };
456
457        let storage = match provider.get_storage(graph_id) {
458            Ok(s) => s,
459            Err(e) => {
460                self.state = S::Reset;
461                return Err(e.into());
462            }
463        };
464        self.to_send = Self::find_needed_segments(&self.has, storage)?;
465        let (commands, command_data, next_send) = self.get_commands(provider)?;
466        let mut length = 0;
467        if !commands.is_empty() {
468            let message = SyncType::Push {
469                message: SyncResponseMessage::SyncResponse {
470                    session_id: self.session_id()?,
471                    response_index: self.message_index as u64,
472                    commands,
473                },
474                graph_id: self.graph_id.assume("graph id must exist")?,
475            };
476            self.message_index = self
477                .message_index
478                .checked_add(1)
479                .assume("message_index increment overflow")?;
480            self.next_send = next_send;
481
482            length = Self::write_sync_type(target, message)?;
483            let total_length = length
484                .checked_add(command_data.len())
485                .assume("length + command_data_length mustn't overflow")?;
486            target
487                .get_mut(length..total_length)
488                .assume("sync message fits in target")?
489                .copy_from_slice(&command_data);
490            length = total_length;
491        }
492        Ok(length)
493    }
494
495    fn get_commands(
496        &mut self,
497        provider: &mut impl StorageProvider,
498    ) -> Result<
499        (
500            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
501            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
502            usize,
503        ),
504        SyncError,
505    > {
506        let Some(graph_id) = self.graph_id.as_ref() else {
507            self.state = SyncResponderState::Reset;
508            bug!("get_next called before graph_id was set");
509        };
510        let storage = match provider.get_storage(*graph_id) {
511            Ok(s) => s,
512            Err(e) => {
513                self.state = SyncResponderState::Reset;
514                return Err(e.into());
515            }
516        };
517        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
518        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
519        let mut index = self.next_send;
520        for i in self.next_send..self.to_send.len() {
521            if commands.is_full() {
522                break;
523            }
524            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
525            let Some(&location) = self.to_send.get(i) else {
526                self.state = SyncResponderState::Reset;
527                bug!("send index OOB");
528            };
529
530            let segment = storage
531                .get_segment(location)
532                .inspect_err(|_| self.state = SyncResponderState::Reset)?;
533
534            let found = segment.get_from(location);
535
536            for command in &found {
537                let mut policy_length = 0;
538
539                if let Some(policy) = command.policy() {
540                    policy_length = policy.len();
541                    command_data
542                        .extend_from_slice(policy)
543                        .ok()
544                        .assume("command_data is too large")?;
545                }
546
547                let bytes = command.bytes();
548                command_data
549                    .extend_from_slice(bytes)
550                    .ok()
551                    .assume("command_data is too large")?;
552
553                let max_cut = command.max_cut()?;
554                let meta = CommandMeta {
555                    id: command.id(),
556                    priority: command.priority(),
557                    parent: command.parent(),
558                    policy_length: policy_length as u32,
559                    length: bytes.len() as u32,
560                    max_cut,
561                };
562
563                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
564                commands
565                    .push(meta)
566                    .ok()
567                    .assume("too many commands in segment")?;
568                if commands.is_full() {
569                    break;
570                }
571            }
572        }
573        Ok((commands, command_data, index))
574    }
575
576    fn session_id(&self) -> Result<u128, SyncError> {
577        Ok(self.session_id.assume("session id is set")?)
578    }
579}