aranya_runtime/
protocol.rs

1//! [`Engine`]/[`Policy`] test implementation.
2use alloc::vec::Vec;
3use core::convert::Infallible;
4
5use buggy::bug;
6use postcard::{from_bytes, ser_flavors::Slice, serialize_with_flavor};
7use serde::{Deserialize, Serialize};
8use tracing::trace;
9
10use super::{
11    alloc, Command, CommandId, Engine, EngineError, FactPerspective, Perspective, Policy, PolicyId,
12    Prior, Priority, Sink, StorageError, MAX_COMMAND_LENGTH,
13};
14use crate::{Address, CommandRecall, Keys, MergeIds};
15
16impl From<StorageError> for EngineError {
17    fn from(_: StorageError) -> Self {
18        EngineError::InternalError
19    }
20}
21
22impl From<postcard::Error> for EngineError {
23    fn from(_error: postcard::Error) -> Self {
24        EngineError::Read
25    }
26}
27
28impl From<Infallible> for EngineError {
29    fn from(_error: Infallible) -> Self {
30        EngineError::Write
31    }
32}
33
/// Serialized body of a [`WireProtocol::Init`] command.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WireInit {
    /// Nonce of the init command; `TestPolicy::init` stores its `u64` nonce
    /// here widened to `u128`.
    pub nonce: u128,
    /// Eight policy bytes handed out by `Command::policy`; `TestPolicy::init`
    /// fills them with the little-endian bytes of the nonce.
    pub policy_num: [u8; 8],
}
39
/// Serialized body of a [`WireProtocol::Merge`] command.
///
/// The two addresses are reported together as `Prior::Merge(left, right)`
/// by `Command::parent`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WireMerge {
    /// First parent of the merge.
    pub left: Address,
    /// Second parent of the merge.
    pub right: Address,
}
45
/// Serialized body of a [`WireProtocol::Basic`] command.
#[derive(Serialize, Deserialize, Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
pub struct WireBasic {
    /// Address of the single parent command (`Prior::Single`).
    pub parent: Address,
    /// Priority value reported as `Priority::Basic(..)`.
    ///
    /// NOTE: the name is a misspelling of "priority"; it is kept because the
    /// field is public.
    pub prority: u32,
    /// `(key, value)` pair: the policy records it as a `"payload"` fact and
    /// emits the second element as `TestEffect::Got`.
    pub payload: (u64, u64),
}
52
/// The serialized form of every command understood by the test policy.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WireProtocol {
    /// Command without a parent (`Prior::None`); carries the policy bytes.
    Init(WireInit),
    /// Command with two parents.
    Merge(WireMerge),
    /// Command with one parent and a `(u64, u64)` payload.
    Basic(WireBasic),
}
59
/// A decoded test command together with the bytes it was serialized into.
#[derive(Debug, Clone)]
pub struct TestProtocol<'a> {
    /// Command ID; the constructors in this file derive it from `data` with
    /// `CommandId::hash_for_testing_only`.
    id: CommandId,
    /// Structured form of the command.
    command: WireProtocol,
    /// Serialized bytes of `command`, returned by `Command::bytes`.
    data: &'a [u8],
}
66
67impl Command for TestProtocol<'_> {
68    fn priority(&self) -> Priority {
69        match &self.command {
70            WireProtocol::Init(_) => Priority::Init,
71            WireProtocol::Merge(_) => Priority::Merge,
72            WireProtocol::Basic(m) => Priority::Basic(m.prority),
73        }
74    }
75
76    fn id(&self) -> CommandId {
77        self.id
78    }
79
80    fn parent(&self) -> Prior<Address> {
81        match &self.command {
82            WireProtocol::Init(_) => Prior::None,
83            WireProtocol::Basic(m) => Prior::Single(m.parent),
84            WireProtocol::Merge(m) => Prior::Merge(m.left, m.right),
85        }
86    }
87
88    fn policy(&self) -> Option<&[u8]> {
89        match &self.command {
90            WireProtocol::Init(m) => Some(&m.policy_num),
91            WireProtocol::Merge(_) => None,
92            WireProtocol::Basic(_) => None,
93        }
94    }
95
96    fn bytes(&self) -> &[u8] {
97        self.data
98    }
99}
100
/// Test [`Engine`] backed by a single [`TestPolicy`].
pub struct TestEngine {
    /// The one policy returned by `get_policy`, regardless of the requested ID.
    policy: TestPolicy,
}
104
105impl TestEngine {
106    pub fn new() -> TestEngine {
107        TestEngine {
108            policy: TestPolicy::new(0),
109        }
110    }
111}
112
113impl Default for TestEngine {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl Engine for TestEngine {
120    type Policy = TestPolicy;
121    type Effect = TestEffect;
122
123    fn add_policy(&mut self, policy: &[u8]) -> Result<PolicyId, EngineError> {
124        Ok(PolicyId::new(policy[0] as usize))
125    }
126
127    fn get_policy(&self, _id: PolicyId) -> Result<&Self::Policy, EngineError> {
128        Ok(&self.policy)
129    }
130}
131
/// Test [`Policy`] that records basic-command payloads as facts.
pub struct TestPolicy {
    /// Serial number reported by `Policy::serial`.
    serial: u32,
}
135
136impl TestPolicy {
137    pub fn new(serial: u32) -> Self {
138        TestPolicy { serial }
139    }
140
141    fn origin_check_message(
142        &self,
143        command: &WireBasic,
144        facts: &mut impl FactPerspective,
145    ) -> Result<(), EngineError> {
146        let (group, count) = command.payload;
147
148        let key = group.to_be_bytes();
149        let value = count.to_be_bytes();
150
151        facts.insert("payload".into(), Keys::from_iter([key]), value.into());
152        Ok(())
153    }
154
155    fn call_rule_internal(
156        &self,
157        policy_command: &WireProtocol,
158        facts: &mut impl FactPerspective,
159        sink: &mut impl Sink<<TestPolicy as Policy>::Effect>,
160    ) -> Result<(), EngineError> {
161        if let WireProtocol::Basic(m) = &policy_command {
162            self.origin_check_message(m, facts)?;
163
164            sink.consume(TestEffect::Got(m.payload.1));
165        }
166
167        Ok(())
168    }
169
170    fn init<'a>(&self, target: &'a mut [u8], nonce: u64) -> Result<TestProtocol<'a>, EngineError> {
171        let message = WireInit {
172            nonce: u128::from(nonce),
173            policy_num: nonce.to_le_bytes(),
174        };
175
176        let command = WireProtocol::Init(message);
177        let data = write(target, &command)?;
178        let id = CommandId::hash_for_testing_only(data);
179
180        Ok(TestProtocol { id, command, data })
181    }
182
183    fn basic<'a>(
184        &self,
185        target: &'a mut [u8],
186        parent: Address,
187        payload: (u64, u64),
188    ) -> Result<TestProtocol<'a>, EngineError> {
189        let prority = 0; //BUG
190
191        let message = WireBasic {
192            parent,
193            prority,
194            payload,
195        };
196
197        let command = WireProtocol::Basic(message);
198        let data = write(target, &command)?;
199        let id = CommandId::hash_for_testing_only(data);
200
201        Ok(TestProtocol { id, command, data })
202    }
203}
204
205fn write<'a>(target: &'a mut [u8], message: &WireProtocol) -> Result<&'a mut [u8], EngineError> {
206    Ok(serialize_with_flavor::<WireProtocol, Slice<'a>, &mut [u8]>(
207        message,
208        Slice::new(target),
209    )?)
210}
211
/// Effects emitted by the test policy.
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum TestEffect {
    /// Emitted for every processed basic command; carries the second element
    /// of its payload.
    Got(u64),
}
216
/// [`Sink`] that checks consumed effects against a list of expectations.
#[derive(Debug, Clone)]
pub struct TestSink {
    /// Effects still expected, in the order they must be consumed.
    expect: Vec<TestEffect>,
    /// When `true`, consumed effects are not checked against `expect`.
    ignore_expect: bool,
}
222
223impl TestSink {
224    pub fn new() -> Self {
225        TestSink {
226            expect: Vec::new(),
227            ignore_expect: false,
228        }
229    }
230
231    pub fn ignore_expectations(&mut self, ignore: bool) {
232        self.ignore_expect = ignore;
233    }
234}
235
236impl Default for TestSink {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242impl TestSink {
243    pub fn add_expectation(&mut self, expect: TestEffect) {
244        self.expect.push(expect);
245    }
246
247    pub fn count(&self) -> usize {
248        self.expect.len()
249    }
250}
251
252impl Sink<TestEffect> for TestSink {
253    fn begin(&mut self) {
254        //NOOP
255    }
256
257    fn consume(&mut self, effect: TestEffect) {
258        trace!(?effect, "consume");
259        if !self.ignore_expect {
260            assert!(!self.expect.is_empty(), "consumed {effect:?} while empty");
261            let expect = self.expect.remove(0);
262            trace!(consuming = ?effect, expected = ?expect, remainder = ?self.expect);
263            assert_eq!(
264                effect, expect,
265                "consumed {effect:?} while expecting {expect:?}"
266            );
267        }
268    }
269
270    fn rollback(&mut self) {
271        //NOOP
272    }
273
274    fn commit(&mut self) {
275        //NOOP
276    }
277}
278
/// Actions accepted by `TestPolicy::call_action`.
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum TestActions {
    /// Produce an init command from the given nonce.
    Init(u64),
    /// Produce a basic command whose payload is the given `(key, value)` pair.
    SetValue(u64, u64),
}
284
285impl Policy for TestPolicy {
286    type Effect = TestEffect;
287    type Action<'a> = TestActions;
288    type Command<'a> = TestProtocol<'a>;
289
290    fn serial(&self) -> u32 {
291        self.serial
292    }
293
294    fn call_rule(
295        &self,
296        command: &impl Command,
297        facts: &mut impl FactPerspective,
298        sink: &mut impl Sink<Self::Effect>,
299        _recall: CommandRecall,
300    ) -> Result<(), EngineError> {
301        let policy_command: WireProtocol = from_bytes(command.bytes())?;
302        self.call_rule_internal(&policy_command, facts, sink)
303    }
304
305    fn merge<'a>(
306        &self,
307        target: &'a mut [u8],
308        ids: MergeIds,
309    ) -> Result<TestProtocol<'a>, EngineError> {
310        let (left, right) = ids.into();
311        let command = WireProtocol::Merge(WireMerge { left, right });
312        let data = write(target, &command)?;
313        let id = CommandId::hash_for_testing_only(data);
314
315        Ok(TestProtocol { id, command, data })
316    }
317
318    fn call_action(
319        &self,
320        action: Self::Action<'_>,
321        facts: &mut impl Perspective,
322        sink: &mut impl Sink<Self::Effect>,
323    ) -> Result<(), EngineError> {
324        let parent = match facts.head_address()? {
325            Prior::None => Address {
326                id: CommandId::default(),
327                max_cut: 0,
328            },
329            Prior::Single(id) => id,
330            Prior::Merge(_, _) => bug!("cannot get merge command in call_action"),
331        };
332        match action {
333            TestActions::Init(nonce) => {
334                let mut buffer = [0u8; MAX_COMMAND_LENGTH];
335                let target = buffer.as_mut_slice();
336                let command = self.init(target, nonce)?;
337
338                self.call_rule_internal(&command.command, facts, sink)?;
339
340                facts.add_command(&command)?;
341            }
342            TestActions::SetValue(key, value) => {
343                let mut buffer = [0u8; MAX_COMMAND_LENGTH];
344                let target = buffer.as_mut_slice();
345                let payload = (key, value);
346                let command = self.basic(target, parent, payload)?;
347
348                self.call_rule_internal(&command.command, facts, sink)?;
349
350                facts.add_command(&command)?;
351            }
352        }
353
354        Ok(())
355    }
356}