aranya_runtime/sync/requester.rs

1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt;
5use heapless::Vec;
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
8use super::{
9    dispatcher::SyncType, responder::SyncResponseMessage, PeerCache, SyncCommand, SyncError,
10    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, REQUEST_MISSING_MAX,
11};
12use crate::{
13    storage::{Segment, Storage, StorageError, StorageProvider},
14    Address, Command, GraphId, Location,
15};
16
// TODO: Use compile-time args. This initial definition results in this clippy warning:
// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
// As the buffer consts will be compile-time variables in the future, we will be
// able to tune these buffers for smaller footprints. Right now, this enum is not
// suitable for small devices (`SyncRequest` is 6512 bytes).
/// Messages sent from the requester to the responder.
///
/// Requesters wrap these in `SyncType::Poll` and encode them with
/// `postcard::to_slice`. Because the encoding is serde-derived, the order of
/// variants and fields is part of the wire format: do not reorder them.
#[derive(Serialize, Deserialize, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum SyncRequestMessage {
    /// Initiate a new sync session.
    SyncRequest {
        /// A new random value produced by a cryptographically secure RNG.
        session_id: u128,
        /// Specifies the graph to be synced.
        storage_id: GraphId,
        /// Specifies the maximum number of bytes worth of commands that
        /// the requester wishes to receive.
        max_bytes: u64,
        /// Sample of the commands held by the requester. The responder should
        /// respond with any commands that the requester may not have based on
        /// the provided sample. When sending commands, ancestors must be sent
        /// before descendants.
        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
    },

    /// Sent by the requester if it deduces a `SyncResponse` message has been
    /// dropped.
    RequestMissing {
        /// A random value produced by a cryptographically secure RNG,
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// `SyncResponse` indexes that the requester has not received.
        indexes: Vec<u64, REQUEST_MISSING_MAX>,
    },

    /// Message asking the responder to resume sending `SyncResponse`s
    /// after the specified message. This may be sent after a requester
    /// timeout or after a `SyncEnd` has been received.
    SyncResume {
        /// A random value produced by a cryptographically secure RNG,
        /// corresponding to the `session_id` in the initial `SyncRequest`.
        session_id: u128,
        /// Index of the last response message the requester received.
        response_index: u64,
        /// Updates the maximum number of bytes worth of commands that
        /// the requester wishes to receive.
        max_bytes: u64,
    },

    /// Message sent by either requester or responder to indicate the session
    /// has been terminated or the `session_id` is no longer valid.
    /// `session_id` identifies the session being ended.
    EndSession { session_id: u128 },
}
70
71impl SyncRequestMessage {
72    pub fn session_id(&self) -> u128 {
73        match self {
74            Self::SyncRequest { session_id, .. } => *session_id,
75            Self::RequestMissing { session_id, .. } => *session_id,
76            Self::SyncResume { session_id, .. } => *session_id,
77            Self::EndSession { session_id, .. } => *session_id,
78        }
79    }
80}
81
/// Position of a [`SyncRequester`] in its session state machine.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum SyncRequesterState {
    /// Freshly created with a random session ID; the next poll sends a
    /// `SyncRequest` and moves to `Start`.
    New,
    /// The `SyncRequest` has been written; awaiting the first response.
    Start,
    /// At least one in-order `SyncResponse` has been accepted, a resume was
    /// sent, or the requester was created for an existing session.
    Waiting,
    /// Accepts an `Offer`, which moves the requester to `Resync`.
    Idle,
    /// The session is over (an `EndSession` was received or sent).
    Closed,
    /// A missing response was detected or an offer arrived; the next poll
    /// sends a `SyncResume`.
    Resync,
    /// A `SyncEnd` arrived whose `max_index` matched the last response received.
    PartialSync,
    /// The session hit an invalid state; the next poll sends an
    /// `EndSession` and moves to `Closed`.
    Reset,
}
93
94// The length of the Out Of Order buffer
95const OOO_LEN: usize = 4;
96pub struct SyncRequester<'a, A> {
97    session_id: u128,
98    storage_id: GraphId,
99    state: SyncRequesterState,
100    max_bytes: u64,
101    next_index: u64,
102    #[allow(unused)] // TODO(jdygert): Figure out what this is for...
103    ooo_buffer: [Option<&'a [u8]>; OOO_LEN],
104    server_address: A,
105}
106
107impl<A: DeserializeOwned + Serialize + Clone> SyncRequester<'_, A> {
108    /// Create a new [`SyncRequester`] with a random session ID.
109    pub fn new<R: Csprng>(storage_id: GraphId, rng: &mut R, server_address: A) -> Self {
110        // Randomly generate session id.
111        let mut dst = [0u8; 16];
112        rng.fill_bytes(&mut dst);
113        let session_id = u128::from_le_bytes(dst);
114
115        SyncRequester {
116            session_id,
117            storage_id,
118            state: SyncRequesterState::New,
119            max_bytes: 0,
120            next_index: 0,
121            ooo_buffer: core::array::from_fn(|_| None),
122            server_address,
123        }
124    }
125
126    /// Create a new [`SyncRequester`] for an existing session.
127    pub fn new_session_id(storage_id: GraphId, session_id: u128, server_address: A) -> Self {
128        SyncRequester {
129            session_id,
130            storage_id,
131            state: SyncRequesterState::Waiting,
132            max_bytes: 0,
133            next_index: 0,
134            ooo_buffer: core::array::from_fn(|_| None),
135            server_address,
136        }
137    }
138
139    /// Returns the server address.
140    pub fn server_addr(&self) -> A {
141        self.server_address.clone()
142    }
143
144    /// Returns true if [`Self::poll`] would produce a message.
145    pub fn ready(&self) -> bool {
146        use SyncRequesterState as S;
147        match self.state {
148            S::New | S::Resync | S::Reset => true,
149            S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
150        }
151    }
152
153    /// Write a sync message in to the target buffer. Returns the number
154    /// of bytes written and the number of commands sent in the sample.
155    pub fn poll(
156        &mut self,
157        target: &mut [u8],
158        provider: &mut impl StorageProvider,
159        heads: &mut PeerCache,
160    ) -> Result<(usize, usize), SyncError> {
161        use SyncRequesterState as S;
162        let result = match self.state {
163            S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
164                return Err(SyncError::NotReady)
165            }
166            S::New => {
167                self.state = S::Start;
168                self.start(self.max_bytes, target, provider, heads)?
169            }
170            S::Resync => self.resume(self.max_bytes, target)?,
171            S::Reset => {
172                self.state = S::Closed;
173                self.end_session(target)?
174            }
175        };
176
177        Ok(result)
178    }
179
180    /// Receive a sync message. Returns parsed sync commands.
181    pub fn receive<'a>(
182        &mut self,
183        data: &'a [u8],
184    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
185        let (message, remaining): (SyncResponseMessage, &'a [u8]) =
186            postcard::take_from_bytes(data)?;
187
188        self.get_sync_commands(message, remaining)
189    }
190
191    /// Extract SyncCommands from a SyncResponseMessage and remaining bytes.
192    pub fn get_sync_commands<'a>(
193        &mut self,
194        message: SyncResponseMessage,
195        remaining: &'a [u8],
196    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_SAMPLE_MAX>>, SyncError> {
197        if message.session_id() != self.session_id {
198            return Err(SyncError::SessionMismatch);
199        }
200
201        let result = match message {
202            SyncResponseMessage::SyncResponse {
203                index, commands, ..
204            } => {
205                if !matches!(
206                    self.state,
207                    SyncRequesterState::Start | SyncRequesterState::Waiting
208                ) {
209                    return Err(SyncError::SessionState);
210                }
211
212                if index != self.next_index {
213                    self.state = SyncRequesterState::Resync;
214                    return Err(SyncError::MissingSyncResponse);
215                }
216                self.next_index = self
217                    .next_index
218                    .checked_add(1)
219                    .assume("next_index + 1 mustn't overflow")?;
220                self.state = SyncRequesterState::Waiting;
221
222                let mut result = Vec::new();
223                let mut start: usize = 0;
224                for meta in commands {
225                    let policy_len = meta.policy_length as usize;
226
227                    let policy = match policy_len == 0 {
228                        true => None,
229                        false => {
230                            let end = start
231                                .checked_add(policy_len)
232                                .assume("start + policy_len mustn't overflow")?;
233                            let policy = &remaining[start..end];
234                            start = end;
235                            Some(policy)
236                        }
237                    };
238
239                    let len = meta.length as usize;
240                    let end = start
241                        .checked_add(len)
242                        .assume("start + len mustn't overflow")?;
243                    let payload = &remaining[start..end];
244                    start = end;
245
246                    let command = SyncCommand {
247                        id: meta.id,
248                        priority: meta.priority,
249                        parent: meta.parent,
250                        policy,
251                        data: payload,
252                        max_cut: meta.max_cut,
253                    };
254
255                    result
256                        .push(command)
257                        .ok()
258                        .assume("commands is not larger than result")?;
259                }
260
261                Some(result)
262            }
263
264            SyncResponseMessage::SyncEnd { max_index, .. } => {
265                if !matches!(
266                    self.state,
267                    SyncRequesterState::Start | SyncRequesterState::Waiting
268                ) {
269                    return Err(SyncError::SessionState);
270                }
271
272                if max_index
273                    != self
274                        .next_index
275                        .checked_sub(1)
276                        .assume("next_index must be positive")?
277                {
278                    self.state = SyncRequesterState::Resync;
279                    return Err(SyncError::MissingSyncResponse);
280                }
281
282                self.state = SyncRequesterState::PartialSync;
283
284                None
285            }
286
287            SyncResponseMessage::Offer { .. } => {
288                if self.state != SyncRequesterState::Idle {
289                    return Err(SyncError::SessionState);
290                }
291                self.state = SyncRequesterState::Resync;
292
293                None
294            }
295
296            SyncResponseMessage::EndSession { .. } => {
297                self.state = SyncRequesterState::Closed;
298                None
299            }
300        };
301
302        Ok(result)
303    }
304
305    fn write(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
306        Ok(postcard::to_slice(&msg, target)?.len())
307    }
308
309    fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
310        Ok((
311            Self::write(
312                target,
313                SyncType::Poll {
314                    request: SyncRequestMessage::EndSession {
315                        session_id: self.session_id,
316                    },
317                    address: self.server_address.clone(),
318                },
319            )?,
320            0,
321        ))
322    }
323
324    fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
325        if !matches!(
326            self.state,
327            SyncRequesterState::Resync | SyncRequesterState::Idle
328        ) {
329            return Err(SyncError::SessionState);
330        }
331
332        self.state = SyncRequesterState::Waiting;
333        let message = SyncType::Poll {
334            request: SyncRequestMessage::SyncResume {
335                session_id: self.session_id,
336                response_index: self
337                    .next_index
338                    .checked_sub(1)
339                    .assume("next_index must be positive")?,
340                max_bytes,
341            },
342            address: self.server_address.clone(),
343        };
344
345        Ok((Self::write(target, message)?, 0))
346    }
347
348    fn get_commands(
349        &self,
350        provider: &mut impl StorageProvider,
351        heads: &mut PeerCache,
352    ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
353        let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
354
355        match provider.get_storage(self.storage_id) {
356            Err(StorageError::NoSuchStorage) => (),
357            Err(err) => {
358                return Err(SyncError::Storage(err));
359            }
360            Ok(storage) => {
361                let mut command_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
362                for address in heads.heads() {
363                    command_locations
364                        .push(
365                            storage
366                                .get_location(*address)?
367                                .assume("location must exist")?,
368                        )
369                        .ok()
370                        .assume("command locations should not be full")?;
371                    if commands.len() < COMMAND_SAMPLE_MAX {
372                        commands
373                            .push(*address)
374                            .map_err(|_| SyncError::CommandOverflow)?;
375                    }
376                }
377                let head = storage.get_head()?;
378
379                let mut current = vec![head];
380
381                // Here we just get the first command from the most reaseant
382                // COMMAND_SAMPLE_MAX segments in the graph. This is probbly
383                // not the best strategy as if you are far enough ahead of
384                // the other client they will just send you everything they have.
385                while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
386                    let mut next = vec::Vec::new(); //BUG not constant memory
387
388                    'current: for &location in &current {
389                        let segment = storage.get_segment(location)?;
390
391                        let head = segment.head()?;
392                        let head_address = head.address()?;
393                        for loc in &command_locations {
394                            if loc.segment == location.segment {
395                                continue 'current;
396                            }
397                        }
398                        commands
399                            .push(head_address)
400                            .map_err(|_| SyncError::CommandOverflow)?;
401                        next.extend(segment.prior());
402                        if commands.len() >= COMMAND_SAMPLE_MAX {
403                            break 'current;
404                        }
405                    }
406
407                    current = next.to_vec();
408                }
409            }
410        }
411        Ok(commands)
412    }
413
414    /// Writes a Subscribe message to target.
415    pub fn subscribe(
416        &mut self,
417        target: &mut [u8],
418        provider: &mut impl StorageProvider,
419        heads: &mut PeerCache,
420        remain_open: u64,
421        max_bytes: u64,
422    ) -> Result<usize, SyncError> {
423        let commands = self.get_commands(provider, heads)?;
424        let message = SyncType::Subscribe {
425            remain_open,
426            max_bytes,
427            commands,
428            address: self.server_address.clone(),
429            storage_id: self.storage_id,
430        };
431
432        Self::write(target, message)
433    }
434
435    /// Writes an Unsubscribe message to target.
436    pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
437        let message = SyncType::Unsubscribe {
438            address: self.server_address.clone(),
439        };
440
441        Self::write(target, message)
442    }
443
444    fn start(
445        &mut self,
446        max_bytes: u64,
447        target: &mut [u8],
448        provider: &mut impl StorageProvider,
449        heads: &mut PeerCache,
450    ) -> Result<(usize, usize), SyncError> {
451        if !matches!(
452            self.state,
453            SyncRequesterState::Start | SyncRequesterState::New
454        ) {
455            self.state = SyncRequesterState::Reset;
456            return Err(SyncError::SessionState);
457        }
458
459        self.state = SyncRequesterState::Start;
460        self.max_bytes = max_bytes;
461
462        let commands = self.get_commands(provider, heads)?;
463
464        let sent = commands.len();
465        let message = SyncType::Poll {
466            request: SyncRequestMessage::SyncRequest {
467                session_id: self.session_id,
468                storage_id: self.storage_id,
469                max_bytes,
470                commands,
471            },
472            address: self.server_address.clone(),
473        };
474
475        Ok((Self::write(target, message)?, sent))
476    }
477}