// aranya_runtime/sync/responder.rs

1use alloc::vec;
2use core::mem;
3
4use aranya_buggy::{bug, BugExt};
5use heapless::{Deque, Vec};
6use serde::{Deserialize, Serialize};
7
8use super::{
9    requester::SyncRequestMessage, CommandMeta, SyncError, COMMAND_RESPONSE_MAX,
10    COMMAND_SAMPLE_MAX, MAX_SYNC_MESSAGE_SIZE, PEER_HEAD_MAX, SEGMENT_BUFFER_MAX,
11};
12use crate::{
13    command::{Address, Command, CommandId},
14    storage::{GraphId, Location, Segment, Storage, StorageProvider},
15    StorageError,
16};
17
18#[derive(Default, Debug)]
19pub struct PeerCache {
20    heads: Vec<Address, { PEER_HEAD_MAX }>,
21}
22
23impl PeerCache {
24    pub fn new() -> Self {
25        PeerCache { heads: Vec::new() }
26    }
27
28    pub fn heads(&self) -> &[Address] {
29        &self.heads
30    }
31
32    pub fn add_command<S>(
33        &mut self,
34        storage: &mut S,
35        command: Address,
36        cmd_loc: Location,
37    ) -> Result<(), StorageError>
38    where
39        S: Storage,
40    {
41        let mut add_command = true;
42        let mut retain_head = |request_head: &Address, new_head: Location| {
43            let new_head_seg = storage.get_segment(new_head)?;
44            let req_head_loc = storage
45                .get_location(*request_head)?
46                .assume("location must exist")?;
47            let req_head_seg = storage.get_segment(req_head_loc)?;
48            if let Some(new_head_command) = new_head_seg.get_command(new_head) {
49                if request_head.id == new_head_command.address()?.id {
50                    add_command = false;
51                }
52            }
53            if storage.is_ancestor(new_head, &req_head_seg)? {
54                add_command = false;
55            }
56            Ok::<bool, StorageError>(!storage.is_ancestor(req_head_loc, &new_head_seg)?)
57        };
58        self.heads
59            .retain(|h| retain_head(h, cmd_loc).unwrap_or(false));
60        if add_command && !self.heads.is_full() {
61            self.heads
62                .push(command)
63                .ok()
64                .assume("command locations should not be full")?;
65        };
66        Ok(())
67    }
68}
69
// TODO: Use compile-time args. This initial definition results in this clippy warning:
// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
// As the buffer consts will be compile-time variables in the future, we will be
// able to tune these buffers for smaller footprints. Right now, this enum is not
// suitable for small devices (`SyncResponse` is 14848 bytes).
/// Messages sent from the responder to the requester.
///
/// NOTE(review): this type is serialized with `postcard`, where variant and
/// field order are part of the wire format — do not reorder without a
/// protocol version change.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncResponseMessage {
    /// Sent in response to a `SyncRequest`.
    SyncResponse {
        /// A random value produced by a cryptographically secure RNG.
        session_id: u128,
        /// If the command bytes the responder intends to send exceed its
        /// configured maximum, the responder sends more than one
        /// `SyncResponse`. The first message has an index of 1, and each
        /// following message increments it.
        ///
        /// NOTE(review): `SyncResponder::get_next` currently fills this with
        /// the index of the first segment in the message, starting at 0 —
        /// confirm which meaning is intended.
        index: u64,
        /// Commands that the responder believes the requester does not have.
        commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX>,
    },

    /// End a sync session if `SyncRequest.max_bytes` has been reached or
    /// there are no remaining commands to send.
    SyncEnd {
        /// A random value produced by a cryptographically secure RNG,
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Largest index of any `SyncResponse`.
        max_index: u64,
        /// Set to `true` if this message was sent due to reaching the
        /// `max_bytes` budget.
        remaining: bool,
    },

    /// Message sent by a responder after a sync has been completed, but before
    /// the session has ended, if it has new commands in its graph. If a
    /// requester wishes to respond to this message, it should do so with a
    /// new `SyncRequest`. This message may use the existing `session_id`.
    Offer {
        /// A random value produced by a cryptographically secure RNG,
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Head of the branch the responder wishes to send.
        head: CommandId,
    },

    /// Message sent by either requester or responder to indicate the session
    /// has been terminated or the `session_id` is no longer valid.
    EndSession {
        /// Identifier of the session being ended.
        session_id: u128,
    },
}
121
122impl SyncResponseMessage {
123    pub fn session_id(&self) -> u128 {
124        match self {
125            Self::SyncResponse { session_id, .. } => *session_id,
126            Self::SyncEnd { session_id, .. } => *session_id,
127            Self::Offer { session_id, .. } => *session_id,
128            Self::EndSession { session_id, .. } => *session_id,
129        }
130    }
131}
132
/// States of the sync responder's state machine.
#[derive(Debug, Default)]
enum SyncResponderState {
    /// Initial state, before any `SyncRequest` has been received.
    #[default]
    New,
    /// A `SyncRequest` was received; the next poll computes what to send.
    Start,
    /// Sending `SyncResponse` messages.
    Send,
    /// Every queued segment has been sent.
    Idle,
    /// The session must be ended; the next poll writes `EndSession`.
    Reset,
    /// The session has ended.
    Stopped,
}
148
149#[derive(Default)]
150pub struct SyncResponder {
151    session_id: Option<u128>,
152    storage_id: Option<GraphId>,
153    state: SyncResponderState,
154    bytes_sent: u64,
155    next_send: usize,
156    has: Vec<Address, COMMAND_SAMPLE_MAX>,
157    to_send: Vec<Location, SEGMENT_BUFFER_MAX>,
158}
159
160impl SyncResponder {
161    /// Create a new [`SyncResponder`].
162    pub fn new() -> Self {
163        SyncResponder {
164            session_id: None,
165            storage_id: None,
166            state: SyncResponderState::New,
167            bytes_sent: 0,
168            next_send: 0,
169            has: Vec::new(),
170            to_send: Vec::new(),
171        }
172    }
173
174    /// Returns true if [`Self::poll`] would produce a message.
175    pub fn ready(&self) -> bool {
176        use SyncResponderState::*;
177        match self.state {
178            Reset | Start | Send => true,
179            New | Idle | Stopped => false,
180        }
181    }
182
183    /// Write a sync message in to the target buffer. Returns the number
184    /// of bytes written.
185    pub fn poll(
186        &mut self,
187        target: &mut [u8],
188        provider: &mut impl StorageProvider,
189        response_cache: &mut PeerCache,
190    ) -> Result<usize, SyncError> {
191        use SyncResponderState as S;
192        let length = match self.state {
193            S::New | S::Idle | S::Stopped => {
194                return Err(SyncError::NotReady);
195            }
196            S::Start => {
197                let Some(storage_id) = self.storage_id else {
198                    self.state = S::Reset;
199                    bug!("poll called before storage_id was set");
200                };
201
202                let storage = match provider.get_storage(storage_id) {
203                    Ok(s) => s,
204                    Err(e) => {
205                        self.state = S::Reset;
206                        return Err(e.into());
207                    }
208                };
209
210                self.state = S::Send;
211                for command in &self.has {
212                    // We only need to check commands that are a part of our graph.
213                    if let Some(cmd_loc) = storage.get_location(*command)? {
214                        response_cache.add_command(storage, *command, cmd_loc)?;
215                    }
216                }
217                self.to_send = SyncResponder::find_needed_segments(&self.has, storage)?;
218
219                self.get_next(target, provider)?
220            }
221            S::Send => self.get_next(target, provider)?,
222            S::Reset => {
223                self.state = S::Stopped;
224                let message = SyncResponseMessage::EndSession {
225                    session_id: self.session_id()?,
226                };
227                Self::write(target, message)?
228            }
229        };
230
231        Ok(length)
232    }
233
234    /// Receive a sync message. Updates the responders state for later polling.
235    pub fn receive(&mut self, data: &[u8]) -> Result<(), SyncError> {
236        let message: SyncRequestMessage = postcard::from_bytes(data)?;
237        if self.session_id.is_none() {
238            self.session_id = Some(message.session_id());
239        }
240        if self.session_id != Some(message.session_id()) {
241            return Err(SyncError::SessionMismatch);
242        }
243
244        match message {
245            SyncRequestMessage::SyncRequest {
246                storage_id,
247                max_bytes,
248                commands,
249                ..
250            } => {
251                self.state = SyncResponderState::Start;
252                self.storage_id = Some(storage_id);
253                self.bytes_sent = max_bytes;
254                self.to_send = Vec::new();
255                self.has = commands;
256                self.next_send = 0;
257                return Ok(());
258            }
259            SyncRequestMessage::RequestMissing { .. } => {
260                todo!()
261            }
262            SyncRequestMessage::SyncResume { .. } => {
263                todo!()
264            }
265            SyncRequestMessage::EndSession { .. } => {
266                self.state = SyncResponderState::Stopped;
267            }
268        };
269
270        Ok(())
271    }
272
273    fn write(target: &mut [u8], msg: SyncResponseMessage) -> Result<usize, SyncError> {
274        Ok(postcard::to_slice(&msg, target)?.len())
275    }
276
277    fn find_needed_segments(
278        commands: &[Address],
279        storage: &impl Storage,
280    ) -> Result<Vec<Location, SEGMENT_BUFFER_MAX>, SyncError> {
281        let mut have_locations = vec::Vec::new(); //BUG: not constant size
282        for &addr in commands {
283            let Some(location) = storage.get_location(addr)? else {
284                // Note: We could use things we don't
285                // have as a hint to know we should
286                // perform a sync request.
287                continue;
288            };
289
290            have_locations.push(location);
291        }
292
293        let mut heads = vec::Vec::new();
294        heads.push(storage.get_head()?);
295
296        let mut result: Deque<Location, SEGMENT_BUFFER_MAX> = Deque::new();
297
298        while !heads.is_empty() {
299            let current = mem::take(&mut heads);
300            'heads: for head in current {
301                let segment = storage.get_segment(head)?;
302                if segment.contains_any(&result) {
303                    continue 'heads;
304                }
305
306                for &location in &have_locations {
307                    if segment.contains(location) {
308                        if location != segment.head_location() {
309                            if result.is_full() {
310                                result.pop_back();
311                            }
312                            result
313                                .push_front(location)
314                                .ok()
315                                .assume("too many segments")?;
316                        }
317                        continue 'heads;
318                    }
319                }
320                heads.extend(segment.prior());
321
322                if result.is_full() {
323                    result.pop_back();
324                }
325
326                let location = segment.first_location();
327                result
328                    .push_front(location)
329                    .ok()
330                    .assume("too many segments")?;
331            }
332        }
333        let mut r: Vec<Location, SEGMENT_BUFFER_MAX> = Vec::new();
334        for l in result {
335            r.push(l).ok().assume("too many segments")?;
336        }
337        // Order segments to ensure that a segment isn't received before its
338        // ancestor segments.
339        r.sort();
340        Ok(r)
341    }
342
343    fn get_next(
344        &mut self,
345        target: &mut [u8],
346        provider: &mut impl StorageProvider,
347    ) -> Result<usize, SyncError> {
348        let Some(storage_id) = self.storage_id else {
349            self.state = SyncResponderState::Reset;
350            bug!("get_next called before storage_id was set");
351        };
352
353        let storage = match provider.get_storage(storage_id) {
354            Ok(s) => s,
355            Err(e) => {
356                self.state = SyncResponderState::Reset;
357                return Err(e.into());
358            }
359        };
360
361        if self.next_send >= self.to_send.len() {
362            self.state = SyncResponderState::Idle;
363            return Ok(0);
364        }
365
366        let mut commands: Vec<CommandMeta, COMMAND_RESPONSE_MAX> = Vec::new();
367        let mut command_data: Vec<u8, MAX_SYNC_MESSAGE_SIZE> = Vec::new();
368        let mut index = self.next_send;
369
370        for i in self.next_send..self.to_send.len() {
371            if commands.is_full() {
372                break;
373            }
374            index = index.checked_add(1).assume("index + 1 mustn't overflow")?;
375            let Some(&location) = self.to_send.get(i) else {
376                self.state = SyncResponderState::Reset;
377                bug!("send index OOB");
378            };
379
380            let segment = storage
381                .get_segment(location)
382                .inspect_err(|_| self.state = SyncResponderState::Reset)?;
383
384            let found = segment.get_from(location);
385
386            for command in &found {
387                let mut policy_length = 0;
388
389                if let Some(policy) = command.policy() {
390                    policy_length = policy.len();
391                    command_data
392                        .extend_from_slice(policy)
393                        .ok()
394                        .assume("command_data is too large")?;
395                }
396
397                let bytes = command.bytes();
398                command_data
399                    .extend_from_slice(bytes)
400                    .ok()
401                    .assume("command_data is too large")?;
402
403                let meta = CommandMeta {
404                    id: command.id(),
405                    priority: command.priority(),
406                    parent: command.parent(),
407                    policy_length: policy_length as u32,
408                    length: bytes.len() as u32,
409                    max_cut: command.max_cut()?,
410                };
411
412                // FIXME(jdygert): Handle segments with more than COMMAND_RESPONSE_MAX commands.
413                commands
414                    .push(meta)
415                    .ok()
416                    .assume("too many commands in segment")?;
417                if commands.is_full() {
418                    break;
419                }
420            }
421        }
422
423        let message = SyncResponseMessage::SyncResponse {
424            session_id: self.session_id()?,
425            index: self.next_send as u64,
426            commands,
427        };
428
429        self.next_send = index;
430
431        let mut length = Self::write(target, message)?;
432        let total_length = length
433            .checked_add(command_data.len())
434            .assume("length + command_data_length mustn't overflow")?;
435        target
436            .get_mut(length..total_length)
437            .assume("sync message fits in target")?
438            .copy_from_slice(&command_data);
439        length = total_length;
440        Ok(length)
441    }
442
443    fn session_id(&self) -> Result<u128, SyncError> {
444        Ok(self.session_id.assume("session id is set")?)
445    }
446}