aranya_runtime/sync/
responder.rs

1use alloc::vec;
2use core::mem;
3
4use buggy::{bug, BugExt};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9    requester::SyncRequestMessage, CommandMeta, SyncError, COMMAND_RESPONSE_MAX,
10    COMMAND_SAMPLE_MAX, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX, SEGMENT_BUFFER_MAX,
11};
12use crate::{
13    command::{Address, Command, CommandId},
14    storage::{GraphId, Location, Segment, Storage, StorageProvider},
15    StorageError, SyncType,
16};
17
/// Tracks which commands a peer is known to already hold.
///
/// Stores up to `PEER_HEAD_MAX` "head" addresses for the peer; heads that
/// become ancestors of newer commands are pruned in
/// [`PeerCache::add_command`].
#[derive(Default, Debug)]
pub struct PeerCache {
    // Non-redundant head addresses observed for the peer.
    heads: Vec<Address, { PEER_HEAD_MAX }>,
}
22
impl PeerCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        PeerCache { heads: Vec::new() }
    }

    /// Returns the head addresses currently tracked for the peer.
    pub fn heads(&self) -> &[Address] {
        &self.heads
    }

    /// Records that the peer holds the command at `command` / `cmd_loc`,
    /// pruning any cached heads the new command supersedes.
    ///
    /// The command is added as a new head unless it is already cached or is
    /// an ancestor of an existing head. Existing heads that are ancestors of
    /// the new command are dropped.
    pub fn add_command<S>(
        &mut self,
        storage: &mut S,
        command: Address,
        cmd_loc: Location,
    ) -> Result<(), StorageError>
    where
        S: Storage,
    {
        // Whether `command` should be appended as a new head after pruning.
        let mut add_command = true;
        // Decides whether an existing cached head should be retained when the
        // new command arrives. Also clears `add_command` when the new command
        // is redundant. Returning Err drops the head (see `unwrap_or(false)`).
        let mut retain_head = |request_head: &Address, new_head: Location| {
            let new_head_seg = storage.get_segment(new_head)?;
            let req_head_loc = storage
                .get_location(*request_head)?
                .assume("location must exist")?;
            let req_head_seg = storage.get_segment(req_head_loc)?;
            // The new command is already one of the cached heads: nothing to add.
            if let Some(new_head_command) = new_head_seg.get_command(new_head) {
                if request_head.id == new_head_command.address()?.id {
                    add_command = false;
                }
            }
            // The new command is an ancestor of this head, so the head
            // already implies the peer has it.
            if storage.is_ancestor(new_head, &req_head_seg)? {
                add_command = false;
            }
            // Keep this head only if it is not an ancestor of the new
            // command (i.e. the new command does not supersede it).
            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
        };
        self.heads
            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
        if add_command && !self.heads.is_full() {
            self.heads
                .push(command)
                .ok()
                .assume("command locations should not be full")?;
        };
        Ok(())
    }
}
69
// TODO: Use compile-time args. This initial definition results in this clippy warning:
// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
// As the buffer consts will be compile-time variables in the future, we will be
// able to tune these buffers for smaller footprints. Right now, this enum is not
// suitable for small devices (`SyncResponse` is 14848 bytes).
/// Messages sent from the responder to the requester.
///
/// NOTE: these messages are serialized with `postcard`, which encodes the
/// variant by index — do not reorder variants or the wire format breaks.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// Sent in response to a `SyncRequest`
    SyncResponse {
        /// A random-value produced by a cryptographically secure RNG.
        session_id: u128,
        /// If the responder intends to send a value of command bytes
        /// greater than the responder's configured maximum, the responder
        /// will send more than one `SyncResponse`. The first message has an
        /// index of 1, and each following is incremented.
        index: u64,
        /// Commands that the responder believes the requester does not have.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// End a sync session if `SyncRequest.max_bytes` has been reached or
    /// there are no remaining commands to send.
    SyncEnd {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Largest index of any `SyncResponse`
        max_index: u64,
        /// Set `true` if this message was sent due to reaching the `max_bytes`
        /// budget.
        remaining: bool,
    },

    /// Message sent by a responder after a sync has been completed, but before
    /// the session has ended, if it has new commands in its graph. If a
    /// requester wishes to respond to this message, it should do so with a
    /// new `SyncRequest`. This message may use the existing `session_id`.
    Offer {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Head of the branch the responder wishes to send.
        head: CommandId,
    },

    /// Message sent by either requester or responder to indicate the session
    /// has been terminated or the `session_id` is no longer valid.
    EndSession {
        /// Identifier of the session being terminated.
        session_id: u128,
    },
}
121
122impl SyncResponseMessage {
123    pub fn session_id(&self) -> u128 {
124        match self {
125            Self::SyncResponse { session_id, .. } => *session_id,
126            Self::SyncEnd { session_id, .. } => *session_id,
127            Self::Offer { session_id, .. } => *session_id,
128            Self::EndSession { session_id, .. } => *session_id,
129        }
130    }
131}
132
/// Internal state machine for the sync responder.
///
/// Replaces the previous hand-written `impl Default` with the derived form
/// (`#[derive(Default)]` + `#[default]`, stable since Rust 1.62; this file
/// already uses `let else`, so the toolchain is new enough).
#[derive(Debug, Default)]
enum SyncResponderState {
    /// No sync request has been received yet; polling returns `NotReady`.
    #[default]
    New,
    /// A `SyncRequest` was received; the next poll computes which segments
    /// to send.
    Start,
    /// Segments are queued; each poll emits another `SyncResponse`.
    Send,
    /// All queued segments have been sent; polling returns `NotReady`.
    Idle,
    /// An error occurred; the next poll emits `EndSession` and stops.
    Reset,
    /// The session has ended; polling returns `NotReady`.
    Stopped,
}
148
/// Responder side of the sync protocol: consumes [`SyncRequestMessage`]s via
/// [`SyncResponder::receive`] and produces [`SyncResponseMessage`]s via
/// [`SyncResponder::poll`] / [`SyncResponder::push`].
#[derive(Default)]
pub struct SyncResponder<A> {
    // Session id fixed by the first message received; later messages must match.
    session_id: Option<u128>,
    // Graph being synced; set from the `SyncRequest`.
    storage_id: Option<GraphId>,
    // Current position in the responder state machine.
    state: SyncResponderState,
    // NOTE(review): assigned from the request's `max_bytes` in `receive` and
    // never read in this file — name suggests a running counter; confirm.
    bytes_sent: u64,
    // Index into `to_send` of the next segment to transmit.
    next_send: usize,
    // Command sample taken from the requester's `SyncRequest`.
    has: Vec<Address, COMMAND_SAMPLE_MAX>,
    // Segment locations queued for sending, sorted ancestor-first.
    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
    // This responder's server address, embedded in `SyncType::Push` messages.
    server_address: A,
}
160
impl<A: Serialize + Clone> SyncResponder<A> {
    /// Create a new [`SyncResponder`].
    pub fn new(server_address: A) -> Self {
        SyncResponder {
            session_id: None,
            storage_id: None,
            state: SyncResponderState::New,
            bytes_sent: 0,
            next_send: 0,
            has: Vec::new(),
            to_send: Vec::new(),
            server_address,
        }
    }

    /// Returns true if [`Self::poll`] would produce a message.
    pub fn ready(&self) -> bool {
        use SyncResponderState::*;
        match self.state {
            Reset | Start | Send => true,
            New | Idle | Stopped => false,
        }
    }

    /// Write a sync message in to the target buffer. Returns the number
    /// of bytes written.
    ///
    /// # Errors
    ///
    /// Returns [`SyncError::NotReady`] when there is nothing to send (see
    /// [`Self::ready`]). A failed storage lookup moves the responder into
    /// the `Reset` state, so a subsequent poll emits `EndSession`.
    pub fn poll(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
        response_cache: &mut PeerCache,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let length = match self.state {
            S::New | S::Idle | S::Stopped => {
                return Err(SyncError::NotReady);
            }
            S::Start => {
                // First poll after a `SyncRequest`: figure out what to send.
                let Some(storage_id) = self.storage_id else {
                    self.state = S::Reset;
                    bug!("poll called before storage_id was set");
                };

                let storage = match provider.get_storage(storage_id) {
                    Ok(s) => s,
                    Err(e) => {
                        self.state = S::Reset;
                        return Err(e.into());
                    }
                };

                self.state = S::Send;
                // Record the requester's sampled commands in the peer cache.
                for command in &self.has {
                    // We only need to check commands that are a part of our graph.
                    if let Some(cmd_loc) = storage.get_location(*command)? {
                        response_cache.add_command(storage, *command, cmd_loc)?;
                    }
                }
                self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;

                self.get_next(target, provider)?
            }
            S::Send => self.get_next(target, provider)?,
            S::Reset => {
                // Terminate the session: emit `EndSession` exactly once,
                // then stay `Stopped`.
                self.state = S::Stopped;
                let message = SyncResponseMessage::EndSession {
                    session_id: self.session_id()?,
                };
                Self::write(target, message)?
            }
        };

        Ok(length)
    }

    /// Receive a sync message. Updates the responders state for later polling.
    ///
    /// The first message received fixes the session id; any later message
    /// with a different id is rejected with [`SyncError::SessionMismatch`].
    pub fn receive(&mut self, message: SyncRequestMessage) -> Result<(), SyncError> {
        if self.session_id.is_none() {
            self.session_id = Some(message.session_id());
        }
        if self.session_id != Some(message.session_id()) {
            return Err(SyncError::SessionMismatch);
        }

        match message {
            SyncRequestMessage::SyncRequest {
                storage_id,
                max_bytes,
                commands,
                ..
            } => {
                // A new request restarts the send cycle from scratch.
                self.state = SyncResponderState::Start;
                self.storage_id = Some(storage_id);
                // NOTE(review): stores the request's `max_bytes` budget into
                // a field named `bytes_sent` — confirm intended semantics.
                self.bytes_sent = max_bytes;
                self.to_send = Vec::new();
                self.has = commands;
                self.next_send = 0;
                return Ok(());
            }
            SyncRequestMessage::RequestMissing { .. } => {
                todo!()
            }
            SyncRequestMessage::SyncResume { .. } => {
                todo!()
            }
            SyncRequestMessage::EndSession { .. } => {
                self.state = SyncResponderState::Stopped;
            }
        };

        Ok(())
    }

    /// Serializes a [`SyncType`] into `target` with postcard; returns the
    /// number of bytes written.
    fn write_sync_type(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    /// Serializes a [`SyncResponseMessage`] into `target` with postcard;
    /// returns the number of bytes written.
    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
        Ok(postcard::to_slice(&msg, target)?.len())
    }

    /// Computes the segment locations the requester appears to be missing.
    ///
    /// Walks the graph backwards from the storage head. A branch stops
    /// descending once it reaches a segment containing one of the
    /// requester's sampled `commands`; otherwise the segment's first
    /// location is queued and the walk continues into its prior segments.
    /// The result is sorted so a segment is never sent before its ancestors.
    fn find_needed_segments(
        commands: &[Address],
        storage: &impl Storage,
    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
        let mut have_locations = vec::Vec::new(); //BUG: not constant size
        for &addr in commands {
            let Some(location) = storage.get_location(addr)? else {
                // Note: We could use things we don't
                // have as a hint to know we should
                // perform a sync request.
                continue;
            };

            have_locations.push(location);
        }

        let mut heads = vec::Vec::new();
        heads.push(storage.get_head()?);

        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();

        while !heads.is_empty() {
            // Take the current frontier; prior segments discovered below are
            // pushed into `heads` for the next round.
            let current = mem::take(&mut heads);
            'heads: for head in current {
                let segment = storage.get_segment(head)?;
                // Already queued (a location in `result` falls inside this
                // segment) — skip it.
                if segment.contains_any(&result) {
                    continue 'heads;
                }

                for &location in &have_locations {
                    if segment.contains(location) {
                        // The requester already has part of this segment:
                        // queue a resume point within it (unless they have it
                        // all) and stop descending this branch.
                        if location != segment.head_location() {
                            if result.is_full() {
                                // Bounded buffer: drop the back entry to make
                                // room.
                                result.pop_back();
                            }
                            result
                                .push_front(location)
                                .ok()
                                .assume("too many segments")?;
                        }
                        continue 'heads;
                    }
                }
                // The requester has none of this segment: queue all of it
                // and keep walking into its parents.
                heads.extend(segment.prior());

                if result.is_full() {
                    result.pop_back();
                }

                let location = segment.first_location();
                result
                    .push_front(location)
                    .ok()
                    .assume("too many segments")?;
            }
        }
        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
        for l in result {
            r.push(l).ok().assume("too many segments")?;
        }
        // Order segments to ensure that a segment isn't received before its
        // ancestor segments.
        r.sort();
        Ok(r)
    }

    /// Writes the next `SyncResponse` into `target`. Returns 0 and moves to
    /// `Idle` when all queued segments have been sent.
    fn get_next(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        if self.next_send >= self.to_send.len() {
            self.state = SyncResponderState::Idle;
            return Ok(0);
        }
        let (commands, command_data, index) = self.get_commands(provider)?;

        // NOTE(review): `SyncResponse` docs say indices start at 1, but
        // `next_send` starts at 0 — confirm which is intended.
        let message = SyncResponseMessage::SyncResponse {
            session_id: self.session_id()?,
            index: self.next_send as u64,
            commands,
        };

        self.next_send = index;

        // Wire layout: postcard-encoded message first, then the raw
        // policy/command bytes immediately after it.
        let length = Self::write(target, message)?;
        let total_length = length
            .checked_add(command_data.len())
            .assume("length + command_data_length mustn't overflow")?;
        target
            .get_mut(length..total_length)
            .assume("sync message fits in target")?
            .copy_from_slice(&command_data);
        Ok(total_length)
    }

    /// Writes a sync push message to target for the peer. The message will
    /// contain any commands that are after the commands in response_cache.
    ///
    /// Returns 0 if there is nothing new to push.
    pub fn push(
        &mut self,
        target: &mut [u8],
        provider: &mut impl StorageProvider,
    ) -> Result<usize, SyncError> {
        use SyncResponderState as S;
        let Some(storage_id) = self.storage_id else {
            self.state = S::Reset;
            // NOTE(review): message text copied from `poll`; this is `push`.
            bug!("poll called before storage_id was set");
        };

        let storage = match provider.get_storage(storage_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = S::Reset;
                return Err(e.into());
            }
        };
        // Recompute what the peer is missing based on the last known sample.
        self.to_send = SyncResponder::<A>::find_needed_segments(&self.has, storage)?;
        let (commands, command_data, index) = self.get_commands(provider)?;
        let mut length = 0;
        if !commands.is_empty() {
            let message = SyncType::Push {
                message: SyncResponseMessage::SyncResponse {
                    session_id: self.session_id()?,
                    index: self.next_send as u64,
                    commands,
                },
                storage_id: self.storage_id.assume("storage id must exist")?,
                address: self.server_address.clone(),
            };
            self.next_send = index;

            // Same wire layout as `get_next`: serialized message followed by
            // the raw policy/command bytes.
            length = Self::write_sync_type(target, message)?;
            let total_length = length
                .checked_add(command_data.len())
                .assume("length + command_data_length mustn't overflow")?;
            target
                .get_mut(length..total_length)
                .assume("sync message fits in target")?
                .copy_from_slice(&command_data);
            length = total_length;
        }
        Ok(length)
    }

    /// Collects command metadata and raw bytes for the next batch of queued
    /// segments, starting at `self.next_send`.
    ///
    /// Returns the metadata list (up to `COMMAND_RESPONSE_MAX` entries), the
    /// concatenated policy+command bytes, and the index of the next unsent
    /// segment.
    fn get_commands(
        &mut self,
        provider: &mut impl StorageProvider,
    ) -> Result<
        (
            Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
            Vec<u8, MAX_SYNC_MESSAGE_SIZE>,
            usize,
        ),
        SyncError,
    > {
        let Some(storage_id) = self.storage_id.as_ref() else {
            self.state = SyncResponderState::Reset;
            bug!("get_next called before storage_id was set");
        };
        let storage = match provider.get_storage(*storage_id) {
            Ok(s) => s,
            Err(e) => {
                self.state = SyncResponderState::Reset;
                return Err(e.into());
            }
        };
        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
        let mut index = self.next_send;
        for i in self.next_send..self.to_send.len() {
            if commands.is_full() {
                break;
            }
            // `index` counts whole segments consumed, not individual commands.
            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
            let Some(&location) = self.to_send.get(i) else {
                self.state = SyncResponderState::Reset;
                bug!("send index OOB");
            };

            let segment = storage
                .get_segment(location)
                .inspect_err(|_| self.state = SyncResponderState::Reset)?;

            // Commands in this segment starting from `location` — presumably
            // inclusive of `location` itself; confirm against `Segment::get_from`.
            let found = segment.get_from(location);

            for command in &found {
                let mut policy_length = 0;

                // Policy bytes (if any) precede the command bytes in the data
                // stream; `policy_length` lets the requester split them apart.
                if let Some(policy) = command.policy() {
                    policy_length = policy.len();
                    command_data
                        .extend_from_slice(policy)
                        .ok()
                        .assume("command_data is too large")?;
                }

                let bytes = command.bytes();
                command_data
                    .extend_from_slice(bytes)
                    .ok()
                    .assume("command_data is too large")?;

                let meta = CommandMeta {
                    id: command.id(),
                    priority: command.priority(),
                    parent: command.parent(),
                    policy_length: policy_length as u32,
                    length: bytes.len() as u32,
                    max_cut: command.max_cut()?,
                };

                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
                commands
                    .push(meta)
                    .ok()
                    .assume("too many commands in segment")?;
                if commands.is_full() {
                    break;
                }
            }
        }
        Ok((commands, command_data, index))
    }

    /// Returns the session id, which must have been set by [`Self::receive`].
    fn session_id(&self) -> Result<u128, SyncError> {
        Ok(self.session_id.assume("session id is set")?)
    }
}
509}