aranya_runtime/sync/responder.rs

use alloc::vec;
use core::mem;

use buggy::{BugExt as _, bug};
use heapless::{Deque, Vec};
use serde::{Deserialize, Serialize};

use super::{
    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, CommandMeta, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX,
    SEGMENT_BUFFER_MAX, SyncError, requester::SyncRequestMessage,
};
use crate::{
    StorageError, SyncType,
    command::{Address, CmdId, Command as _},
    storage::{GraphId, Location, Segment as _, Storage, StorageProvider},
};

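/// Tracks the most recent heads a peer is known to have, bounded by
/// `PEER_HEAD_MAX`, so the responder can avoid re-sending commands the peer
/// already holds.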
#[derive(Default, Debug)]
pub struct PeerCache {
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}

impl PeerCache {
    pub const fn new() -> Self {
        Self { heads: Vec::new() }
    }

    pub fn heads(&self) -> &[Address] {
        &self.heads
    }

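    /// Records that the peer holds `command`, pruning any cached heads the
    /// new command supersedes.
    ///
    /// A minimal usage sketch (not compiled here), assuming a concrete
    /// [`Storage`] value and a command already resolved to a [`Location`]:
    ///
    /// ```ignore
    /// let mut cache = PeerCache::new();
    /// if let Some(loc) = storage.get_location(addr)? {
    ///     cache.add_command(&mut storage, addr, loc)?;
    /// }
    /// // `addr` is now a cached head unless an existing head already covers it.
    /// ```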
    pub fn add_command<S>(
        &mut self,
        storage: &mut S,
        command: Address,
        cmd_loc: Location,
    ) -> Result<(), StorageError>
    where
        S: Storage,
    {
        let mut add_command = true;
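        // Keep a cached head only if it is not an ancestor of the new
        // command's segment; skip adding the new command when it matches a
        // cached head or is itself an ancestor of one.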
        let mut retain_head = |request_head: &Address, new_head: Location| {
            let new_head_seg = storage.get_segment(new_head)?;
            let req_head_loc = storage
                .get_location(*request_head)?
                .assume("location must exist")?;
            let req_head_seg = storage.get_segment(req_head_loc)?;
            if let Some(new_head_command) = new_head_seg.get_command(new_head) {
                if request_head.id == new_head_command.address()?.id {
                    add_command = false;
                }
            }
            if storage.is_ancestor(new_head, &req_head_seg)? {
                add_command = false;
            }
            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
        };
        self.heads
            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
        if add_command && !self.heads.is_full() {
            self.heads
                .push(command)
                .ok()
                .assume("command locations should not be full")?;
        }
        Ok(())
    }
}

// TODO: Use compile-time args. This initial definition results in this clippy warning:
// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
// As the buffer consts will be compile-time variables in the future, we will be
// able to tune these buffers for smaller footprints. Right now, this enum is not
// suitable for small devices (`SyncResponse` is 14448 bytes).
/// Messages sent from the responder to the requester.
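///
/// Variants serialize with `postcard`, as in this sketch (buffer size and
/// error handling are illustrative):
///
/// ```ignore
/// let msg = SyncResponseMessage::SyncEnd {
///     session_id: 7,
///     max_index: 3,
///     remaining: false,
/// };
/// let mut buf = [0u8; MAX_SYNC_MESSAGE_SIZE];
/// let used = postcard::to_slice(&msg, &mut buf)?.len();
/// ```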
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// Sent in response to a `SyncRequest`.
    SyncResponse {
        /// A random value produced by a cryptographically secure RNG.
        session_id: u128,
        /// If the responder intends to send more command bytes than its
        /// configured maximum, it sends more than one `SyncResponse`. The
        /// first message has an index of 1, and each subsequent index is
        /// incremented.
        response_index: u64,
        /// Commands that the responder believes the requester does not have.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// Ends a sync session when `SyncRequest.max_bytes` has been reached or
    /// there are no remaining commands to send.
    SyncEnd {
        /// A random value produced by a cryptographically secure RNG,
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Largest index of any `SyncResponse`.
        max_index: u64,
        /// Set to `true` if this message was sent due to reaching the
        /// `max_bytes` budget.
        remaining: bool,
    },

    /// Sent by a responder after a sync has completed, but before the
    /// session has ended, if it has new commands in its graph. If a
    /// requester wishes to respond to this message, it should do so with a
    /// new `SyncRequest`. This message may use the existing `session_id`.
    Offer {
        /// A random value produced by a cryptographically secure RNG,
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Head of the branch the responder wishes to send.
        head: CmdId,
    },

    /// Sent by either requester or responder to indicate that the session
    /// has been terminated or the `session_id` is no longer valid.
    EndSession { session_id: u128 },
}

impl SyncResponseMessage {
    pub fn session_id(&self) -> u128 {
        match self {
            Self::SyncResponse { session_id, .. }
            | Self::SyncEnd { session_id, .. }
            | Self::Offer { session_id, .. }
            | Self::EndSession { session_id, .. } => *session_id,
        }
    }
}

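/// Responder lifecycle: `New` until a request arrives, `Start` then `Send`
/// while streaming responses, `Idle` once everything queued has been sent,
/// `Reset` to emit an `EndSession`, and `Stopped` when the session is over.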
#[derive(Debug, Default)]
enum SyncResponderState {
    #[default]
    New,
    Start,
    Send,
    Idle,
    Reset,
    Stopped,
}

#[derive(Default)]
pub struct SyncResponder<A> {
    session_id: Option<u128>,
    storage_id: Option<GraphId>,
    state: SyncResponderState,
    bytes_sent: u64,
    next_send: usize,
    message_index: usize,
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
    server_address: A,
}

impl<A: Serialize + Clone> SyncResponder<A> {
    /// Create a new [`SyncResponder`].
    pub fn new(server_address: A) -> Self {
        Self {
            session_id: None,
            storage_id: None,
            state: SyncResponderState::New,
            bytes_sent: 0,
            next_send: 0,
            message_index: 0,
            has: Vec::new(),
            to_send: Vec::new(),
            server_address,
        }
    }

    /// Returns true if [`Self::poll`] would produce a message.
    pub fn ready(&self) -> bool {
        use SyncResponderState::*;
        match self.state {
            Reset | Start | Send => true, // TODO(chip): For Send, check whether to_send has anything to send
            New | Idle | Stopped => false,
        }
    }

    /// Writes a sync message into the target buffer. Returns the number
    /// of bytes written.
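    ///
    /// A typical receive/poll loop, sketched under the assumption of a
    /// concrete [`StorageProvider`] and a transport (names here are
    /// illustrative):
    ///
    /// ```ignore
    /// let mut responder = SyncResponder::new(server_addr);
    /// responder.receive(request)?; // `SyncRequest` from the peer
    /// let mut buf = [0u8; MAX_SYNC_MESSAGE_SIZE];
    /// while responder.ready() {
    ///     let len = responder.poll(&mut buf, &mut provider, &mut cache)?;
    ///     transport.send(&buf[..len])?; // hypothetical transport
    /// }
    /// ```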
    pub fn poll(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
        response_cache: &mut PeerCache,
    ) -> Result<usize, SyncError> {
        // TODO(chip): return a status enum instead of usize
        use SyncResponderState as S;
        let length = match self.state {
            S::New | S::Idle | S::Stopped => {
                return Err(SyncError::NotReady); // TODO(chip): return Ok(NotReady)
            }
            S::Start => {
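                // First poll after a `SyncRequest`: fold the peer's sample
                // into the response cache, work out which segments the peer
                // is missing, then start streaming them.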
                let Some(storage_id) = self.storage_id else {
                    self.state = S::Reset;
                    bug!("poll called before storage_id was set");
                };

                let storage = match provider.get_storage(storage_id) {
                    Ok(s) => s,
                    Err(e) => {
                        self.state = S::Reset;
                        return Err(e.into());
                    }
                };

                self.state = S::Send;
                for command in &self.has {
                    // We only need to check commands that are a part of our graph.
                    if let Some(cmd_loc) = storage.get_location(*command)? {
                        response_cache.add_command(storage, *command, cmd_loc)?;
                    }
                }
                self.to_send = Self::find_needed_segments(&self.has, storage)?;

                self.get_next(target, provider)?
            }
            S::Send => self.get_next(target, provider)?,
            S::Reset => {
                self.state = S::Stopped;
                let message = SyncResponseMessage::EndSession {
                    session_id: self.session_id()?,
                };
                Self::write(target, message)?
            }
        };

        Ok(length)
    }

    /// Receives a sync message and updates the responder's state for later
    /// polling.
    pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
        if self.session_id.is_none() {
            self.session_id = Some(message.session_id());
        }
        if self.session_id != Some(message.session_id()) {
            return Err(SyncError::SessionMismatch);
        }

        match message {
            SyncRequestMessage::SyncRequest {
                storage_id,
                max_bytes,
                commands,
                ..
            } => {
                self.state = SyncResponderState::Start;
                self.storage_id = Some(storage_id);
                self.bytes_sent = max_bytes;
                self.to_send = Vec::new();
                self.has = commands;
                self.next_send = 0;
                return Ok(());
            }
            SyncRequestMessage::RequestMissing { .. } => {
                todo!()
            }
            SyncRequestMessage::SyncResume { .. } => {
                todo!()
            }
            SyncRequestMessage::EndSession { .. } => {
                self.state = SyncResponderState::Stopped;
            }
        }

        Ok(())
    }

    fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    /// This (probably) returns a `Vec` of segment locations whose heads are
    /// not ancestors of any of the sample commands we have been sent. If
    /// that set is larger than `SEGMENT_BUFFER_MAX`, it contains the oldest
    /// segment heads for which that holds.
    fn find_needed_segments(
        commands: &[Address],
        storage: &impl Storage,
    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
        let mut have_locations = vec::Vec::new(); // BUG: not constant size
        for &addr in commands {
            let Some(location) = storage.get_location(addr)? else {
                // Note: We could use things we don't
                // have as a hint to know we should
                // perform a sync request.
                continue;
            };

            have_locations.push(location);
        }

        let mut heads = vec::Vec::new();
        heads.push(storage.get_head()?);

        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

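        // Walk backwards from the graph head. For each pending head: skip
        // segments already covered by `result`; if the peer has a command
        // inside the segment, queue the remainder starting at that command
        // and stop descending this branch; otherwise queue the whole segment
        // and continue into its prior segments.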
        while !heads.is_empty() {
            let current = mem::take(&mut heads);
            'heads: for head in current {
                let segment = storage.get_segment(head)?;
                if segment.contains_any(&result) {
                    continue 'heads;
                }

                for &location in &have_locations {
                    if segment.contains(location) {
                        if location != segment.head_location() {
                            if result.is_full() {
                                result.pop_back();
                            }
                            result
                                .push_front(location)
                                .ok()
                                .assume("too many segments")?;
                        }
                        continue 'heads;
                    }
                }
                heads.extend(segment.prior());

                if result.is_full() {
                    result.pop_back();
                }

                let location = segment.first_location();
                result
                    .push_front(location)
                    .ok()
                    .assume("too many segments")?;
            }
        }
        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
        for l in result {
            r.push(l).ok().assume("too many segments")?;
        }
        // Order segments to ensure that a segment isn't received before its
        // ancestor segments.
        r.sort();
        Ok(r)
    }

    fn get_next(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
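        // Everything queued in `to_send` has been sent; close the burst with
        // a `SyncEnd` so the requester knows the largest index to expect.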
        if self.next_send >= self.to_send.len() {
            self.state = SyncResponderState::Idle;
            let message = SyncResponseMessage::SyncEnd {
                session_id: self.session_id()?,
                max_index: self.message_index as u64,
                remaining: false,
            };
            let length = Self::write(target, message)?;
            return Ok(length);
        }

        let (commands, command_data, next_send) = self.get_commands(provider)?;

        let message = SyncResponseMessage::SyncResponse {
            session_id: self.session_id()?,
            response_index: self.message_index as u64,
            commands,
        };
        self.message_index = self
            .message_index
            .checked_add(1)
            .assume("message_index overflow")?;
        self.next_send = next_send;

        let length = Self::write(target, message)?;
        let total_length = length
            .checked_add(command_data.len())
            .assume("length + command_data_length mustn't overflow")?;
        target
            .get_mut(length..total_length)
            .assume("sync message fits in target")?
            .copy_from_slice(&command_data);
        Ok(total_length)
    }

    /// Writes a sync push message to `target` for the peer. The message will
    /// contain any commands that follow those the peer is already known to
    /// have.
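    ///
    /// A sketch of pushing after new commands land locally (`provider` and
    /// `transport` are illustrative):
    ///
    /// ```ignore
    /// let len = responder.push(&mut buf, &mut provider)?;
    /// if len > 0 {
    ///     transport.send(&buf[..len])?;
    /// }
    /// ```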
    pub fn push(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let Some(storage_id) = self.storage_id else {
            self.state = S::Reset;
            bug!("push called before storage_id was set");
        };

        let storage = match provider.get_storage(storage_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = S::Reset;
                return Err(e.into());
            }
        };
        self.to_send = Self::find_needed_segments(&self.has, storage)?;
        let (commands, command_data, next_send) = self.get_commands(provider)?;
        let mut length = 0;
        if !commands.is_empty() {
            let message = SyncType::Push {
                message: SyncResponseMessage::SyncResponse {
                    session_id: self.session_id()?,
                    response_index: self.message_index as u64,
                    commands,
                },
                storage_id: self.storage_id.assume("storage id must exist")?,
                address: self.server_address.clone(),
            };
            self.message_index = self
                .message_index
                .checked_add(1)
                .assume("message_index increment overflow")?;
            self.next_send = next_send;

            length = Self::write_sync_type(target, message)?;
            let total_length = length
                .checked_add(command_data.len())
                .assume("length + command_data_length mustn't overflow")?;
            target
                .get_mut(length..total_length)
                .assume("sync message fits in target")?
                .copy_from_slice(&command_data);
            length = total_length;
        }
        Ok(length)
    }

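    /// Gathers up to `COMMAND_RESPONSE_MAX` command metadata entries starting
    /// at `self.next_send`, along with the concatenated policy and command
    /// bytes, returning `(metas, raw_bytes, next_send_index)`. Callers append
    /// the raw bytes directly after the serialized message.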
    fn get_commands(
        &mut self,
        provider: &mut impl StorageProvider,
    ) -> Result<
        (
            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
            usize,
        ),
        SyncError,
    > {
        let Some(storage_id) = self.storage_id.as_ref() else {
            self.state = SyncResponderState::Reset;
            bug!("get_commands called before storage_id was set");
        };
        let storage = match provider.get_storage(*storage_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = SyncResponderState::Reset;
                return Err(e.into());
            }
        };
        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
        let mut index = self.next_send;
        for i in self.next_send..self.to_send.len() {
            if commands.is_full() {
                break;
            }
            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
            let Some(&location) = self.to_send.get(i) else {
                self.state = SyncResponderState::Reset;
                bug!("send index OOB");
            };

            let segment = storage
                .get_segment(location)
                .inspect_err(|_| self.state = SyncResponderState::Reset)?;

            let found = segment.get_from(location);

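            // Wire layout: each command contributes `policy_length` policy
            // bytes followed by `length` command bytes to `command_data`;
            // the matching `CommandMeta` tells the requester how to split it.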
            for command in &found {
                let mut policy_length = 0;

                if let Some(policy) = command.policy() {
                    policy_length = policy.len();
                    command_data
                        .extend_from_slice(policy)
                        .ok()
                        .assume("command_data is too large")?;
                }

                let bytes = command.bytes();
                command_data
                    .extend_from_slice(bytes)
                    .ok()
                    .assume("command_data is too large")?;

                let meta = CommandMeta {
                    id: command.id(),
                    priority: command.priority(),
                    parent: command.parent(),
                    policy_length: policy_length as u32,
                    length: bytes.len() as u32,
                    max_cut: command.max_cut()?,
                };

                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
                commands
                    .push(meta)
                    .ok()
                    .assume("too many commands in segment")?;
                if commands.is_full() {
                    break;
                }
            }
        }
        Ok((commands, command_data, index))
    }

    fn session_id(&self) -> Result<u128, SyncError> {
        Ok(self.session_id.assume("session id is set")?)
    }
}