aranya_runtime/
protocol.rs

1//! [`Engine`]/[`Policy`] test implementation.
2use alloc::vec::Vec;
3use core::convert::Infallible;
4
5use buggy::bug;
6use postcard::{from_bytes, ser_flavors::Slice, serialize_with_flavor};
7use serde::{Deserialize, Serialize};
8use tracing::trace;
9
10use super::{
11    alloc, Command, CommandId, Engine, EngineError, FactPerspective, Perspective, Policy, PolicyId,
12    Prior, Priority, Sink, StorageError, MAX_COMMAND_LENGTH,
13};
14use crate::{Address, CommandRecall, Keys, MergeIds};
15
16impl From<StorageError> for EngineError {
17    fn from(_: StorageError) -> Self {
18        EngineError::InternalError
19    }
20}
21
22impl From<postcard::Error> for EngineError {
23    fn from(_error: postcard::Error) -> Self {
24        EngineError::Read
25    }
26}
27
28impl From<Infallible> for EngineError {
29    fn from(_error: Infallible) -> Self {
30        EngineError::Write
31    }
32}
33
34#[derive(Serialize, Deserialize, Debug, Clone)]
35pub struct WireInit {
36    pub nonce: u128,
37    pub policy_num: [u8; 8],
38}
39
40#[derive(Serialize, Deserialize, Debug, Clone)]
41pub struct WireMerge {
42    pub left: Address,
43    pub right: Address,
44}
45
46#[derive(Serialize, Deserialize, Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
47pub struct WireBasic {
48    pub parent: Address,
49    pub prority: u32,
50    pub payload: (u64, u64),
51}
52
53#[derive(Serialize, Deserialize, Debug, Clone)]
54pub enum WireProtocol {
55    Init(WireInit),
56    Merge(WireMerge),
57    Basic(WireBasic),
58}
59
60#[derive(Debug, Clone)]
61pub struct TestProtocol<'a> {
62    id: CommandId,
63    command: WireProtocol,
64    data: &'a [u8],
65}
66
67impl Command for TestProtocol<'_> {
68    fn priority(&self) -> Priority {
69        match &self.command {
70            WireProtocol::Init(_) => Priority::Init,
71            WireProtocol::Merge(_) => Priority::Merge,
72            WireProtocol::Basic(m) => Priority::Basic(m.prority),
73        }
74    }
75
76    fn id(&self) -> CommandId {
77        self.id
78    }
79
80    fn parent(&self) -> Prior<Address> {
81        match &self.command {
82            WireProtocol::Init(_) => Prior::None,
83            WireProtocol::Basic(m) => Prior::Single(m.parent),
84            WireProtocol::Merge(m) => Prior::Merge(m.left, m.right),
85        }
86    }
87
88    fn policy(&self) -> Option<&[u8]> {
89        match &self.command {
90            WireProtocol::Init(m) => Some(&m.policy_num),
91            WireProtocol::Merge(_) => None,
92            WireProtocol::Basic(_) => None,
93        }
94    }
95
96    fn bytes(&self) -> &[u8] {
97        self.data
98    }
99}
100
101pub struct TestEngine {
102    policy: TestPolicy,
103}
104
105impl TestEngine {
106    pub fn new() -> TestEngine {
107        TestEngine {
108            policy: TestPolicy::new(0),
109        }
110    }
111}
112
113impl Default for TestEngine {
114    fn default() -> Self {
115        Self::new()
116    }
117}
118
119impl Engine for TestEngine {
120    type Policy = TestPolicy;
121    type Effect = TestEffect;
122
123    fn add_policy(&mut self, policy: &[u8]) -> Result<PolicyId, EngineError> {
124        Ok(PolicyId::new(policy[0] as usize))
125    }
126
127    fn get_policy(&self, _id: PolicyId) -> Result<&Self::Policy, EngineError> {
128        Ok(&self.policy)
129    }
130}
131
/// [`Policy`] implementation used by the runtime tests.
pub struct TestPolicy {
    /// Value returned by `Policy::serial`.
    serial: u32,
}
135
136impl TestPolicy {
137    pub fn new(serial: u32) -> Self {
138        TestPolicy { serial }
139    }
140
141    fn origin_check_message(
142        &self,
143        command: &WireBasic,
144        facts: &mut impl FactPerspective,
145    ) -> Result<(), EngineError> {
146        let (group, count) = command.payload;
147
148        let key = group.to_be_bytes();
149        let value = count.to_be_bytes();
150
151        facts.insert("payload".into(), Keys::from_iter([key]), value.into());
152        Ok(())
153    }
154
155    fn call_rule_internal(
156        &self,
157        policy_command: &WireProtocol,
158        facts: &mut impl FactPerspective,
159        sink: &mut impl Sink<<TestPolicy as Policy>::Effect>,
160    ) -> Result<(), EngineError> {
161        if let WireProtocol::Basic(m) = &policy_command {
162            self.origin_check_message(m, facts)?;
163
164            sink.consume(TestEffect::Got(m.payload.1));
165        }
166
167        Ok(())
168    }
169
170    fn init<'a>(&self, target: &'a mut [u8], nonce: u64) -> Result<TestProtocol<'a>, EngineError> {
171        let message = WireInit {
172            nonce: u128::from(nonce),
173            policy_num: nonce.to_le_bytes(),
174        };
175
176        let command = WireProtocol::Init(message);
177        let data = write(target, &command)?;
178        let id = CommandId::hash_for_testing_only(data);
179
180        Ok(TestProtocol { id, command, data })
181    }
182
183    fn basic<'a>(
184        &self,
185        target: &'a mut [u8],
186        parent: Address,
187        payload: (u64, u64),
188    ) -> Result<TestProtocol<'a>, EngineError> {
189        let prority = 0; //BUG
190
191        let message = WireBasic {
192            parent,
193            prority,
194            payload,
195        };
196
197        let command = WireProtocol::Basic(message);
198        let data = write(target, &command)?;
199        let id = CommandId::hash_for_testing_only(data);
200
201        Ok(TestProtocol { id, command, data })
202    }
203}
204
205fn write<'a>(target: &'a mut [u8], message: &WireProtocol) -> Result<&'a mut [u8], EngineError> {
206    Ok(serialize_with_flavor::<WireProtocol, Slice<'a>, &mut [u8]>(
207        message,
208        Slice::new(target),
209    )?)
210}
211
/// Effect emitted by the test policy when it evaluates a basic command.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TestEffect {
    /// Carries the second element (`count`) of the command's payload.
    Got(u64),
}
216
217#[derive(Debug, Clone)]
218pub struct TestSink {
219    expect: Vec<TestEffect>,
220    ignore_expect: bool,
221}
222
223impl TestSink {
224    pub fn new() -> Self {
225        TestSink {
226            expect: Vec::new(),
227            ignore_expect: false,
228        }
229    }
230
231    pub fn ignore_expectations(&mut self, ignore: bool) {
232        self.ignore_expect = ignore;
233    }
234}
235
236impl Default for TestSink {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
242impl TestSink {
243    pub fn add_expectation(&mut self, expect: TestEffect) {
244        self.expect.push(expect);
245    }
246
247    pub fn count(&self) -> usize {
248        self.expect.len()
249    }
250}
251
252impl Sink<TestEffect> for TestSink {
253    fn begin(&mut self) {
254        //NOOP
255    }
256
257    fn consume(&mut self, effect: TestEffect) {
258        trace!(?effect, "consume");
259        if !self.ignore_expect {
260            let expect = self.expect.remove(0);
261            assert_eq!(effect, expect);
262        }
263    }
264
265    fn rollback(&mut self) {
266        //NOOP
267    }
268
269    fn commit(&mut self) {
270        //NOOP
271    }
272}
273
274#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
275pub enum TestActions {
276    Init(u64),
277    SetValue(u64, u64),
278}
279
280impl Policy for TestPolicy {
281    type Effect = TestEffect;
282    type Action<'a> = TestActions;
283    type Command<'a> = TestProtocol<'a>;
284
285    fn serial(&self) -> u32 {
286        self.serial
287    }
288
289    fn call_rule(
290        &self,
291        command: &impl Command,
292        facts: &mut impl FactPerspective,
293        sink: &mut impl Sink<Self::Effect>,
294        _recall: CommandRecall,
295    ) -> Result<(), EngineError> {
296        let policy_command: WireProtocol = from_bytes(command.bytes())?;
297        self.call_rule_internal(&policy_command, facts, sink)
298    }
299
300    fn merge<'a>(
301        &self,
302        target: &'a mut [u8],
303        ids: MergeIds,
304    ) -> Result<TestProtocol<'a>, EngineError> {
305        let (left, right) = ids.into();
306        let command = WireProtocol::Merge(WireMerge { left, right });
307        let data = write(target, &command)?;
308        let id = CommandId::hash_for_testing_only(data);
309
310        Ok(TestProtocol { id, command, data })
311    }
312
313    fn call_action(
314        &self,
315        action: Self::Action<'_>,
316        facts: &mut impl Perspective,
317        sink: &mut impl Sink<Self::Effect>,
318    ) -> Result<(), EngineError> {
319        let parent = match facts.head_address()? {
320            Prior::None => Address {
321                id: CommandId::default(),
322                max_cut: 0,
323            },
324            Prior::Single(id) => id,
325            Prior::Merge(_, _) => bug!("cannot get merge command in call_action"),
326        };
327        match action {
328            TestActions::Init(nonce) => {
329                let mut buffer = [0u8; MAX_COMMAND_LENGTH];
330                let target = buffer.as_mut_slice();
331                let command = self.init(target, nonce)?;
332
333                self.call_rule_internal(&command.command, facts, sink)?;
334
335                facts.add_command(&command)?;
336            }
337            TestActions::SetValue(key, value) => {
338                let mut buffer = [0u8; MAX_COMMAND_LENGTH];
339                let target = buffer.as_mut_slice();
340                let payload = (key, value);
341                let command = self.basic(target, parent, payload)?;
342
343                self.call_rule_internal(&command.command, facts, sink)?;
344
345                facts.add_command(&command)?;
346            }
347        }
348
349        Ok(())
350    }
351}