aranya_runtime/sync/
responder.rs

1use alloc::vec;
2use core::mem;
3
4use buggy::{BugExt, bug};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
10    SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
11};
12use crate::{
13    StorageError, SyncType,
14    command::{Address, CmdId, Command},
15    storage::{GraphId, Location, Segment, Storage, StorageProvider},
16};
17
/// Tracks head commands a sync peer is known to hold, bounded by
/// [`PEER_HEAD_MAX`].
#[derive(Default, Debug)]
pub struct PeerCache {
    /// Addresses of the peer's known heads; entries superseded by newer
    /// commands are pruned in [`PeerCache::add_command`].
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}
22
impl PeerCache {
    /// Creates an empty cache with no tracked heads.
    pub const fn new() -> Self {
        PeerCache { heads: Vec::new() }
    }

    /// Returns the addresses currently tracked as the peer's heads.
    pub fn heads(&self) -> &[Address] {
        &self.heads
    }

    /// Records that the peer has `command` (located at `cmd_loc` in
    /// `storage`), pruning any tracked heads the new command supersedes.
    ///
    /// A tracked head is dropped when it is an ancestor of the new
    /// command's segment. The command itself is only appended when no
    /// existing head already covers it (same command, or the command is an
    /// ancestor of a head) and the cache has room.
    pub fn add_command<S>(
        &mut self,
        storage: &mut S,
        command: Address,
        cmd_loc: Location,
    ) -> Result<(), StorageError>
    where
        S: Storage,
    {
        // Cleared if an existing head already covers `command`.
        let mut add_command = true;
        // Decides whether `request_head` stays in `heads` given the newly
        // observed command at `new_head`; also updates `add_command`.
        let mut retain_head = |request_head: &Address, new_head: Location| {
            let new_head_seg = storage.get_segment(new_head)?;
            let req_head_loc = storage
                .get_location(*request_head)?
                .assume("location must exist")?;
            let req_head_seg = storage.get_segment(req_head_loc)?;
            // Same command as an existing head: nothing new to record.
            if let Some(new_head_command) = new_head_seg.get_command(new_head) {
                if request_head.id == new_head_command.address()?.id {
                    add_command = false;
                }
            }
            // New command is an ancestor of this head, so the head already
            // covers it.
            if storage.is_ancestor(new_head, &req_head_seg)? {
                add_command = false;
            }
            // Keep this head only if it is not an ancestor of the new
            // command (i.e. not superseded by it).
            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
        };
        // NOTE(review): on a storage error the head is conservatively
        // dropped (`unwrap_or(false)`) and the error is swallowed — confirm
        // that is intended.
        self.heads
            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
        if add_command && !self.heads.is_full() {
            self.heads
                .push(command)
                .ok()
                .assume("command locations should not be full")?;
        };
        Ok(())
    }
}
69
70// TODO: Use compile-time args. This initial definition results in this clippy warning:
71// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
72// As the buffer consts will be compile-time variables in the future, we will be
73// able to tune these buffers for smaller footprints. Right now, this enum is not
74// suitable for small devices (`SyncResponse` is 14448 bytes).
/// Messages sent from the responder to the requester.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// Sent in response to a `SyncRequest`
    SyncResponse {
        /// A random-value produced by a cryptographically secure RNG.
        session_id: u128,
        /// If the responder intends to send a value of command bytes
        /// greater than the responder's configured maximum, the responder
        /// will send more than one `SyncResponse`. The first message has an
        /// index of 1, and each following is incremented.
        response_index: u64,
        /// Commands that the responder believes the requester does not have.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// End a sync session if `SyncRequest.max_bytes` has been reached or
    /// there are no remaining commands to send.
    SyncEnd {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Largest index of any `SyncResponse`
        max_index: u64,
        /// Set `true` if this message was sent due to reaching the `max_bytes`
        /// budget.
        remaining: bool,
    },

    /// Message sent by a responder after a sync has been completed, but before
    /// the session has ended, if it has new commands in its graph. If a
    /// requester wishes to respond to this message, it should do so with a
    /// new `SyncRequest`. This message may use the existing `session_id`.
    Offer {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Head of the branch the responder wishes to send.
        head: CmdId,
    },

    /// Message sent by either requester or responder to indicate the session
    /// has been terminated or the `session_id` is no longer valid.
    EndSession { session_id: u128 },
}
121
122impl SyncResponseMessage {
123    pub fn session_id(&self) -> u128 {
124        match self {
125            Self::SyncResponse { session_id, .. } => *session_id,
126            Self::SyncEnd { session_id, .. } => *session_id,
127            Self::Offer { session_id, .. } => *session_id,
128            Self::EndSession { session_id, .. } => *session_id,
129        }
130    }
131}
132
/// Internal state machine for [`SyncResponder`].
#[derive(Debug)]
enum SyncResponderState {
    /// No request received yet; polling is not ready.
    New,
    /// A `SyncRequest` was received; the next poll computes what to send.
    Start,
    /// Actively streaming `SyncResponse` messages.
    Send,
    /// Send queue exhausted; waiting for a further request.
    Idle,
    /// An error occurred; the next poll emits `EndSession` and stops.
    Reset,
    /// Session terminated; no further messages will be produced.
    Stopped,
}
142
143impl Default for SyncResponderState {
144    fn default() -> Self {
145        Self::New
146    }
147}
148
/// Responder side of a sync session: answers `SyncRequest`s with the
/// commands the requester appears to be missing.
#[derive(Default)]
pub struct SyncResponder<A> {
    /// Session id pinned by the first request received; `None` until then.
    session_id: Option<u128>,
    /// Graph being synced; set when a `SyncRequest` arrives.
    storage_id: Option<GraphId>,
    /// Current position in the responder state machine.
    state: SyncResponderState,
    /// Assigned the requester's `max_bytes` budget on `SyncRequest`.
    /// NOTE(review): the name suggests a running total of bytes written,
    /// but it is never updated after being set — confirm intent.
    bytes_sent: u64,
    /// Index into `to_send` of the next segment to transmit.
    next_send: usize,
    /// Number of `SyncResponse` messages produced so far.
    message_index: usize,
    /// Command sample supplied by the requester.
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    /// Segment locations queued for sending, in ancestor-first order.
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
    /// Address embedded in outgoing `SyncType::Push` messages.
    server_address: A,
}
161
162impl<A: Serialize + Clone> SyncResponder<A> {
163    /// Create a new [`SyncResponder`].
164    pub fn new(server_address: A) -> Self {
165        SyncResponder {
166            session_id: None,
167            storage_id: None,
168            state: SyncResponderState::New,
169            bytes_sent: 0,
170            next_send: 0,
171            message_index: 0,
172            has: Vec::new(),
173            to_send: Vec::new(),
174            server_address,
175        }
176    }
177
178    /// Returns true if [`Self::poll`] would produce a message.
179    pub fn ready(&self) -> bool {
180        use SyncResponderState::*;
181        match self.state {
182            Reset | Start | Send => true, // TODO(chip): For Send, check whether to_send has anything to send
183            New | Idle | Stopped => false,
184        }
185    }
186
    /// Write a sync message in to the target buffer. Returns the number
    /// of bytes written.
    ///
    /// Drives the state machine:
    /// - `New`/`Idle`/`Stopped`: nothing to send; returns `NotReady`.
    /// - `Start`: seeds `response_cache` and the send queue from the
    ///   requester's sample, then writes the first response.
    /// - `Send`: writes the next batch of commands (or `SyncEnd`).
    /// - `Reset`: writes `EndSession` and transitions to `Stopped`.
    pub fn poll(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
        response_cache: &mut PeerCache,
    ) -> Result<usize, SyncError> {
        // TODO(chip): return a status enum instead of usize
        use SyncResponderState as S;
        let length = match self.state {
            S::New | S::Idle | S::Stopped => {
                return Err(SyncError::NotReady); // TODO(chip): return Ok(NotReady)
            }
            S::Start => {
                let Some(storage_id) = self.storage_id else {
                    self.state = S::Reset;
                    bug!("poll called before storage_id was set");
                };

                let storage = match provider.get_storage(storage_id) {
                    Ok(s) => s,
                    Err(e) => {
                        // Poison the session so the next poll sends
                        // `EndSession`.
                        self.state = S::Reset;
                        return Err(e.into());
                    }
                };

                self.state = S::Send;
                for command in &self.has {
                    // We only need to check commands that are a part of our graph.
                    if let Some(cmd_loc) = storage.get_location(*command)? {
                        response_cache.add_command(storage, *command, cmd_loc)?;
                    }
                }
                // Queue the segments the requester appears to be missing.
                self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;

                self.get_next(target, provider)?
            }
            S::Send => self.get_next(target, provider)?,
            S::Reset => {
                self.state = S::Stopped;
                let message = SyncResponseMessage::EndSession {
                    session_id: self.session_id()?,
                };
                Self::write(target, message)?
            }
        };

        Ok(length)
    }
238
    /// Receive a sync message. Updates the responder's state for later polling.
    ///
    /// The first message received pins `session_id`; any later message with
    /// a different id is rejected with [`SyncError::SessionMismatch`].
    pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
        if self.session_id.is_none() {
            self.session_id = Some(message.session_id());
        }
        if self.session_id != Some(message.session_id()) {
            return Err(SyncError::SessionMismatch);
        }

        match message {
            SyncRequestMessage::SyncRequest {
                storage_id,
                max_bytes,
                commands,
                ..
            } => {
                // Reset send bookkeeping for a fresh request.
                self.state = SyncResponderState::Start;
                self.storage_id = Some(storage_id);
                // NOTE(review): the requester's `max_bytes` budget is stored
                // in `bytes_sent` — confirm this field is meant to hold the
                // budget rather than a running total.
                self.bytes_sent = max_bytes;
                self.to_send = Vec::new();
                self.has = commands;
                self.next_send = 0;
                return Ok(());
            }
            SyncRequestMessage::RequestMissing { .. } => {
                todo!()
            }
            SyncRequestMessage::SyncResume { .. } => {
                todo!()
            }
            SyncRequestMessage::EndSession { .. } => {
                self.state = SyncResponderState::Stopped;
            }
        };

        Ok(())
    }
276
277    fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
278        Ok(postcard::to_slice(&msg, target)?.len())
279    }
280
281    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
282        Ok(postcard::to_slice(&msg, target)?.len())
283    }
284
    /// This (probably) returns a Vec of segment addresses where the head of each segment is
    /// not the ancestor of any samples we have been sent. If that is longer than
    /// SEGMENT_BUFFER_MAX, it contains the oldest segment heads where that holds.
    fn find_needed_segments(
        commands: &[Address],
        storage: &impl Storage,
    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
        // Locations of the sampled commands that exist in our graph.
        let mut have_locations = vec::Vec::new(); //BUG: not constant size
        for &addr in commands {
            let Some(location) = storage.get_location(addr)? else {
                // Note: We could use things we don't
                // have as a hint to know we should
                // perform a sync request.
                continue;
            };

            have_locations.push(location);
        }

        // Breadth-first walk backwards from the graph head.
        let mut heads = vec::Vec::new();
        heads.push(storage.get_head()?);

        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

        while !heads.is_empty() {
            // Take this wave of locations; parents discovered below form the
            // next wave.
            let current = mem::take(&mut heads);
            'heads: for head in current {
                let segment = storage.get_segment(head)?;
                // Already queued part of this segment; skip it.
                if segment.contains_any(&result) {
                    continue 'heads;
                }

                for &location in &have_locations {
                    if segment.contains(location) {
                        // The peer has part of this segment; queue only the
                        // portion after what they have (unless they already
                        // hold the segment head), and don't walk past it.
                        if location != segment.head_location() {
                            if result.is_full() {
                                // Evict from the back (pushed earliest, i.e.
                                // nearest the graph head) to keep the oldest
                                // segments, per the doc comment above.
                                result.pop_back();
                            }
                            result
                                .push_front(location)
                                .ok()
                                .assume("too many segments")?;
                        }
                        continue 'heads;
                    }
                }
                // Peer has nothing in this segment: queue it from its first
                // location and keep walking its parents.
                heads.extend(segment.prior());

                if result.is_full() {
                    result.pop_back();
                }

                let location = segment.first_location();
                result
                    .push_front(location)
                    .ok()
                    .assume("too many segments")?;
            }
        }
        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
        for l in result {
            r.push(l).ok().assume("too many segments")?;
        }
        // Order segments to ensure that a segment isn't received before its
        // ancestor segments.
        r.sort();
        Ok(r)
    }
353
    /// Serializes the next `SyncResponse` into `target`, or a terminal
    /// `SyncEnd` when the send queue is exhausted. Returns bytes written.
    ///
    /// Wire layout: the postcard-encoded message first, followed
    /// immediately by the raw policy/command bytes it describes.
    fn get_next(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        // Nothing left to send: report completion and go idle.
        if self.next_send >= self.to_send.len() {
            self.state = SyncResponderState::Idle;
            let message = SyncResponseMessage::SyncEnd {
                session_id: self.session_id()?,
                max_index: self.message_index as u64,
                remaining: false,
            };
            let length = Self::write(target, message)?;
            return Ok(length);
        }

        let (commands, command_data, next_send) = self.get_commands(provider)?;

        let message = SyncResponseMessage::SyncResponse {
            session_id: self.session_id()?,
            response_index: self.message_index as u64,
            commands,
        };
        self.message_index = self
            .message_index
            .checked_add(1)
            .assume("message_index overflow")?;
        self.next_send = next_send;

        let length = Self::write(target, message)?;
        // Append the raw command bytes directly after the serialized header.
        let total_length = length
            .checked_add(command_data.len())
            .assume("length + command_data_length mustn't overflow")?;
        target
            .get_mut(length..total_length)
            .assume("sync message fits in target")?
            .copy_from_slice(&command_data);
        Ok(total_length)
    }
393
394    /// Writes a sync push message to target for the peer. The message will
395    /// contain any commands that are after the commands in response_cache.
396    pub fn push(
397        &mut self,
398        target: &mut [u8],
399        provider: &mut impl StorageProvider,
400    ) -> Result<usize, SyncError> {
401        use SyncResponderState as S;
402        let Some(storage_id) = self.storage_id else {
403            self.state = S::Reset;
404            bug!("poll called before storage_id was set");
405        };
406
407        let storage = match provider.get_storage(storage_id) {
408            Ok(s) => s,
409            Err(e) => {
410                self.state = S::Reset;
411                return Err(e.into());
412            }
413        };
414        self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;
415        let (commands, command_data, next_send) = self.get_commands(provider)?;
416        let mut length = 0;
417        if !commands.is_empty() {
418            let message = SyncType::Push {
419                message: SyncResponseMessage::SyncResponse {
420                    session_id: self.session_id()?,
421                    response_index: self.message_index as u64,
422                    commands,
423                },
424                storage_id: self.storage_id.assume("storage id must exist")?,
425                address: self.server_address.clone(),
426            };
427            self.message_index = self
428                .message_index
429                .checked_add(1)
430                .assume("message_index increment overflow")?;
431            self.next_send = next_send;
432
433            length = Self::write_sync_type(target, message)?;
434            let total_length = length
435                .checked_add(command_data.len())
436                .assume("length + command_data_length mustn't overflow")?;
437            target
438                .get_mut(length..total_length)
439                .assume("sync message fits in target")?
440                .copy_from_slice(&command_data);
441            length = total_length;
442        }
443        Ok(length)
444    }
445
446    fn get_commands(
447        &mut self,
448        provider: &mut impl StorageProvider,
449    ) -> Result<
450        (
451            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
452            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
453            usize,
454        ),
455        SyncError,
456    > {
457        let Some(storage_id) = self.storage_id.as_ref() else {
458            self.state = SyncResponderState::Reset;
459            bug!("get_next called before storage_id was set");
460        };
461        let storage = match provider.get_storage(*storage_id) {
462            Ok(s) => s,
463            Err(e) => {
464                self.state = SyncResponderState::Reset;
465                return Err(e.into());
466            }
467        };
468        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
469        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
470        let mut index = self.next_send;
471        for i in self.next_send..self.to_send.len() {
472            if commands.is_full() {
473                break;
474            }
475            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
476            let Some(&location) = self.to_send.get(i) else {
477                self.state = SyncResponderState::Reset;
478                bug!("send index OOB");
479            };
480
481            let segment = storage
482                .get_segment(location)
483                .inspect_err(|_| self.state = SyncResponderState::Reset)?;
484
485            let found = segment.get_from(location);
486
487            for command in &found {
488                let mut policy_length = 0;
489
490                if let Some(policy) = command.policy() {
491                    policy_length = policy.len();
492                    command_data
493                        .extend_from_slice(policy)
494                        .ok()
495                        .assume("command_data is too large")?;
496                }
497
498                let bytes = command.bytes();
499                command_data
500                    .extend_from_slice(bytes)
501                    .ok()
502                    .assume("command_data is too large")?;
503
504                let meta = CommandMeta {
505                    id: command.id(),
506                    priority: command.priority(),
507                    parent: command.parent(),
508                    policy_length: policy_length as u32,
509                    length: bytes.len() as u32,
510                    max_cut: command.max_cut()?,
511                };
512
513                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
514                commands
515                    .push(meta)
516                    .ok()
517                    .assume("too many commands in segment")?;
518                if commands.is_full() {
519                    break;
520                }
521            }
522        }
523        Ok((commands, command_data, index))
524    }
525
526    fn session_id(&self) -> Result<u128, SyncError> {
527        Ok(self.session_id.assume("session id is set")?)
528    }
529}