Skip to main content

aranya_runtime/sync/
requester.rs

1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt as _;
5use heapless::Vec;
6use serde::{Deserialize, Serialize};
7
8use super::{
9    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, PeerCache, REQUEST_MISSING_MAX,
10    SyncCommand, SyncError, dispatcher::SyncType, responder::SyncResponseMessage,
11};
12use crate::{
13    Address, GraphId, Location,
14    storage::{Segment as _, Storage as _, StorageError, StorageProvider},
15};
16
// TODO: Use compile-time args. This initial definition results in this clippy warning:
// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
// As the buffer consts will be compile-time variables in the future, we will be
// able to tune these buffers for smaller footprints. Right now, this enum is not
// suitable for small devices (`SyncRequest` is 4080 bytes).
/// Messages sent from the requester to the responder.
///
/// NOTE(review): these are serialized with `postcard` (see `receive`), whose
/// enum encoding depends on declaration order — do not reorder variants or
/// fields without a wire-format migration.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Initiate a new Sync
    SyncRequest {
        /// A new random value produced by a cryptographically secure RNG.
        session_id: u128,
        /// Specifies the graph to be synced.
        graph_id: GraphId,
        /// Specifies the maximum number of bytes worth of commands that
        /// the requester wishes to receive.
        max_bytes: u64,
        /// Sample of the commands held by the requester. The responder should
        /// respond with any commands that the requester may not have based on
        /// the provided sample. When sending commands ancestors must be sent
        /// before descendants.
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Sent by the requester if it deduces a `SyncResponse` message has been
    /// dropped.
    RequestMissing {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// `SyncResponse` indexes that the requester has not received.
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Message to request the responder resumes sending `SyncResponse`s
    /// following the specified message. This may be sent after a requester
    /// timeout or after a `SyncEnd` has been sent.
    SyncResume {
        /// A random-value produced by a cryptographically secure RNG
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Indicates the last response message the requester received.
        response_index: u64,
        /// Updates the maximum number of bytes worth of commands that
        /// the requester wishes to receive.
        max_bytes: u64,
    },

    /// Message sent by either requester or responder to indicate the session
    /// has been terminated or the `session_id` is no longer valid.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72    pub fn session_id(&self) -> u128 {
73        match self {
74            Self::SyncRequest { session_id, .. } => *session_id,
75            Self::RequestMissing { session_id, .. } => *session_id,
76            Self::SyncResume { session_id, .. } => *session_id,
77            Self::EndSession { session_id, .. } => *session_id,
78        }
79    }
80}
81
/// Internal state machine for [`SyncRequester`].
///
/// `ready`/`poll` only produce messages in the `New`, `Resync`, and
/// `Reset` states; `receive` advances the remaining states.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SyncRequesterState {
    /// Object initialized; no messages sent
    New,
    /// Have sent a start message but have received no response
    Start,
    /// Received some response messages; session not yet complete
    Waiting,
    /// Sync is complete; byte count has not yet been reached
    Idle,
    /// Session is done and no more messages should be processed
    Closed,
    /// Sync response received out of order
    Resync,
    /// SyncEnd received but not everything has been received
    PartialSync,
    /// Session will be closed
    Reset,
}
101
/// Drives the requester half of a sync session for one graph.
pub struct SyncRequester {
    // Identifies this session; echoed by every request/response message.
    session_id: u128,
    // The graph being synced.
    graph_id: GraphId,
    // Current position in the session state machine.
    state: SyncRequesterState,
    // Maximum number of bytes worth of commands the requester wants to
    // receive (sent in `SyncRequest`/`SyncResume`).
    max_bytes: u64,
    // Index expected on the next `SyncResponse`; a mismatch triggers Resync.
    next_message_index: u64,
}
109
110impl SyncRequester {
111    /// Create a new [`SyncRequester`] with a random session ID.
112    pub fn new<R: Csprng>(graph_id: GraphId, rng: R) -> Self {
113        // Randomly generate session id.
114        let mut dst = [0u8; 16];
115        rng.fill_bytes(&mut dst);
116        let session_id = u128::from_le_bytes(dst);
117
118        Self {
119            session_id,
120            graph_id,
121            state: SyncRequesterState::New,
122            max_bytes: 0,
123            next_message_index: 0,
124        }
125    }
126
127    /// Create a new [`SyncRequester`] for an existing session.
128    pub fn new_session_id(graph_id: GraphId, session_id: u128) -> Self {
129        Self {
130            session_id,
131            graph_id,
132            state: SyncRequesterState::Waiting,
133            max_bytes: 0,
134            next_message_index: 0,
135        }
136    }
137
138    /// Returns true if [`Self::poll`] would produce a message.
139    pub fn ready(&self) -> bool {
140        use SyncRequesterState as S;
141        match self.state {
142            S::New | S::Resync | S::Reset => true,
143            S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
144        }
145    }
146
147    /// Write a sync message in to the target buffer. Returns the number
148    /// of bytes written and the number of commands sent in the sample.
149    pub fn poll(
150        &mut self,
151        target: &mut [u8],
152        provider: &mut impl StorageProvider,
153        heads: &mut PeerCache,
154    ) -> Result<(usize, usize), SyncError> {
155        use SyncRequesterState as S;
156        let result = match self.state {
157            S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
158                return Err(SyncError::NotReady);
159            }
160            S::New => {
161                self.state = S::Start;
162                self.start(self.max_bytes, target, provider, heads)?
163            }
164            S::Resync => self.resume(self.max_bytes, target)?,
165            S::Reset => {
166                self.state = S::Closed;
167                self.end_session(target)?
168            }
169        };
170
171        Ok(result)
172    }
173
174    /// Receive a sync message. Returns parsed sync commands.
175    pub fn receive<'a>(
176        &mut self,
177        data: &'a [u8],
178    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
179        let (message, remaining): (SyncResponseMessage, &'a [u8]) =
180            postcard::take_from_bytes(data)?;
181
182        self.get_sync_commands(message, remaining)
183    }
184
185    /// Extract SyncCommands from a SyncResponseMessage and remaining bytes.
186    pub fn get_sync_commands<'a>(
187        &mut self,
188        message: SyncResponseMessage,
189        remaining: &'a [u8],
190    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
191        if message.session_id() != self.session_id {
192            return Err(SyncError::SessionMismatch);
193        }
194
195        let result = match message {
196            SyncResponseMessage::SyncResponse {
197                response_index,
198                commands,
199                ..
200            } => {
201                if !matches!(
202                    self.state,
203                    SyncRequesterState::Start | SyncRequesterState::Waiting
204                ) {
205                    return Err(SyncError::SessionState);
206                }
207
208                if response_index != self.next_message_index {
209                    self.state = SyncRequesterState::Resync;
210                    return Err(SyncError::MissingSyncResponse);
211                }
212                self.next_message_index = self
213                    .next_message_index
214                    .checked_add(1)
215                    .assume("next_message_index + 1 mustn't overflow")?;
216                self.state = SyncRequesterState::Waiting;
217
218                let mut result = Vec::new();
219                let mut start: usize = 0;
220                for meta in commands {
221                    let policy_len = meta.policy_length as usize;
222
223                    let policy = match policy_len == 0 {
224                        true => None,
225                        false => {
226                            let end = start
227                                .checked_add(policy_len)
228                                .assume("start + policy_len mustn't overflow")?;
229                            let policy = &remaining[start..end];
230                            start = end;
231                            Some(policy)
232                        }
233                    };
234
235                    let len = meta.length as usize;
236                    let end = start
237                        .checked_add(len)
238                        .assume("start + len mustn't overflow")?;
239                    let payload = &remaining[start..end];
240                    start = end;
241
242                    let command = SyncCommand {
243                        id: meta.id,
244                        priority: meta.priority,
245                        parent: meta.parent,
246                        policy,
247                        data: payload,
248                        max_cut: meta.max_cut,
249                    };
250
251                    result
252                        .push(command)
253                        .ok()
254                        .assume("commands is not larger than result")?;
255                }
256
257                Some(result)
258            }
259
260            SyncResponseMessage::SyncEnd { max_index, .. } => {
261                if !matches!(
262                    self.state,
263                    SyncRequesterState::Start | SyncRequesterState::Waiting
264                ) {
265                    return Err(SyncError::SessionState);
266                }
267
268                if max_index != self.next_message_index {
269                    self.state = SyncRequesterState::Resync;
270                    return Err(SyncError::MissingSyncResponse);
271                }
272
273                self.state = SyncRequesterState::PartialSync;
274
275                None
276            }
277
278            SyncResponseMessage::Offer { .. } => {
279                if self.state != SyncRequesterState::Idle {
280                    return Err(SyncError::SessionState);
281                }
282                self.state = SyncRequesterState::Resync;
283
284                None
285            }
286
287            SyncResponseMessage::EndSession { .. } => {
288                self.state = SyncRequesterState::Closed;
289                None
290            }
291        };
292
293        Ok(result)
294    }
295
296    fn write(target: &mut [u8], msg: SyncType) -> Result<usize, SyncError> {
297        Ok(postcard::to_slice(&msg, target)?.len())
298    }
299
300    fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
301        Ok((
302            Self::write(
303                target,
304                SyncType::Poll {
305                    request: SyncRequestMessage::EndSession {
306                        session_id: self.session_id,
307                    },
308                },
309            )?,
310            0,
311        ))
312    }
313
314    fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
315        if !matches!(
316            self.state,
317            SyncRequesterState::Resync | SyncRequesterState::Idle
318        ) {
319            return Err(SyncError::SessionState);
320        }
321
322        self.state = SyncRequesterState::Waiting;
323        let message = SyncType::Poll {
324            request: SyncRequestMessage::SyncResume {
325                session_id: self.session_id,
326                response_index: self
327                    .next_message_index
328                    .checked_sub(1)
329                    .assume("next_index must be positive")?,
330                max_bytes,
331            },
332        };
333
334        Ok((Self::write(target, message)?, 0))
335    }
336
337    fn get_commands(
338        &self,
339        provider: &mut impl StorageProvider,
340        peer_cache: &mut PeerCache,
341    ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
342        let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
343
344        match provider.get_storage(self.graph_id) {
345            Err(StorageError::NoSuchStorage) => (),
346            Err(err) => {
347                return Err(SyncError::Storage(err));
348            }
349            Ok(storage) => {
350                let mut cache_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
351                for address in peer_cache.heads() {
352                    let loc = storage
353                        .get_location(*address)?
354                        .assume("location must exist")?;
355                    cache_locations
356                        .push(loc)
357                        .ok()
358                        .assume("command locations should not be full")?;
359                    if commands.len() < COMMAND_SAMPLE_MAX {
360                        commands
361                            .push(*address)
362                            .map_err(|_| SyncError::CommandOverflow)?;
363                    }
364                }
365                // Start traversal from graph head and work backwards until we hit a PeerCache head
366                let head = storage.get_head()?;
367                let mut current = vec![head];
368
369                // Here we just get the first command from the most reaseant
370                // COMMAND_SAMPLE_MAX segments in the graph. This is probbly
371                // not the best strategy as if you are far enough ahead of
372                // the other client they will just send you everything they have.
373                while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
374                    let mut next = vec::Vec::new(); //BUG not constant memory
375
376                    'current: for &location in &current {
377                        // Check if we've hit a PeerCache head - if so, stop traversing this path
378                        for &peer_cache_loc in &cache_locations {
379                            // If this is the same location as a PeerCache head, we've hit it
380                            if location == peer_cache_loc {
381                                continue 'current;
382                            }
383                            // If the current location is an ancestor of a PeerCache head,
384                            // we've passed the PeerCache head, so stop traversing this path
385                            let peer_cache_segment = storage.get_segment(peer_cache_loc)?;
386                            if (peer_cache_loc.same_segment(location)
387                                && location.max_cut <= peer_cache_loc.max_cut)
388                                || storage.is_ancestor(location, &peer_cache_segment)?
389                            {
390                                continue 'current;
391                            }
392                        }
393
394                        let segment = storage.get_segment(location)?;
395                        commands
396                            .push(segment.head_address()?)
397                            .map_err(|_| SyncError::CommandOverflow)?;
398                        next.extend(segment.prior());
399                        if commands.len() >= COMMAND_SAMPLE_MAX {
400                            break 'current;
401                        }
402                    }
403
404                    current.clone_from(&next);
405                }
406            }
407        }
408
409        Ok(commands)
410    }
411
412    /// Writes a Subscribe message to target.
413    pub fn subscribe(
414        &mut self,
415        target: &mut [u8],
416        provider: &mut impl StorageProvider,
417        heads: &mut PeerCache,
418        remain_open: u64,
419        max_bytes: u64,
420    ) -> Result<usize, SyncError> {
421        let commands = self.get_commands(provider, heads)?;
422        let message = SyncType::Subscribe {
423            remain_open,
424            max_bytes,
425            commands,
426            graph_id: self.graph_id,
427        };
428
429        Self::write(target, message)
430    }
431
432    /// Writes an Unsubscribe message to target.
433    pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
434        let message = SyncType::Unsubscribe {
435            graph_id: self.graph_id,
436        };
437
438        Self::write(target, message)
439    }
440
441    fn start(
442        &mut self,
443        max_bytes: u64,
444        target: &mut [u8],
445        provider: &mut impl StorageProvider,
446        heads: &mut PeerCache,
447    ) -> Result<(usize, usize), SyncError> {
448        if !matches!(
449            self.state,
450            SyncRequesterState::Start | SyncRequesterState::New
451        ) {
452            self.state = SyncRequesterState::Reset;
453            return Err(SyncError::SessionState);
454        }
455
456        self.state = SyncRequesterState::Start;
457        self.max_bytes = max_bytes;
458
459        let command_sample = self.get_commands(provider, heads)?;
460
461        let sent = command_sample.len();
462        let message = SyncType::Poll {
463            request: SyncRequestMessage::SyncRequest {
464                session_id: self.session_id,
465                graph_id: self.graph_id,
466                max_bytes,
467                commands: command_sample,
468            },
469        };
470
471        Ok((Self::write(target, message)?, sent))
472    }
473}