1use alloc::vec::Vec;
3use core::convert::Infallible;
4
5use buggy::bug;
6use postcard::{from_bytes, ser_flavors::Slice, serialize_with_flavor};
7use serde::{Deserialize, Serialize};
8use tracing::trace;
9
10use super::{
11 alloc, Command, CommandId, Engine, EngineError, FactPerspective, Perspective, Policy, PolicyId,
12 Prior, Priority, Sink, StorageError, MAX_COMMAND_LENGTH,
13};
14use crate::{Address, CommandRecall, Keys, MergeIds};
15
16impl From<StorageError> for EngineError {
17 fn from(_: StorageError) -> Self {
18 EngineError::InternalError
19 }
20}
21
22impl From<postcard::Error> for EngineError {
23 fn from(_error: postcard::Error) -> Self {
24 EngineError::Read
25 }
26}
27
28impl From<Infallible> for EngineError {
29 fn from(_error: Infallible) -> Self {
30 EngineError::Write
31 }
32}
33
34#[derive(Serialize, Deserialize, Debug, Clone)]
35pub struct WireInit {
36 pub nonce: u128,
37 pub policy_num: [u8; 8],
38}
39
40#[derive(Serialize, Deserialize, Debug, Clone)]
41pub struct WireMerge {
42 pub left: Address,
43 pub right: Address,
44}
45
46#[derive(Serialize, Deserialize, Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
47pub struct WireBasic {
48 pub parent: Address,
49 pub prority: u32,
50 pub payload: (u64, u64),
51}
52
53#[derive(Serialize, Deserialize, Debug, Clone)]
54pub enum WireProtocol {
55 Init(WireInit),
56 Merge(WireMerge),
57 Basic(WireBasic),
58}
59
60#[derive(Debug, Clone)]
61pub struct TestProtocol<'a> {
62 id: CommandId,
63 command: WireProtocol,
64 data: &'a [u8],
65}
66
67impl Command for TestProtocol<'_> {
68 fn priority(&self) -> Priority {
69 match &self.command {
70 WireProtocol::Init(_) => Priority::Init,
71 WireProtocol::Merge(_) => Priority::Merge,
72 WireProtocol::Basic(m) => Priority::Basic(m.prority),
73 }
74 }
75
76 fn id(&self) -> CommandId {
77 self.id
78 }
79
80 fn parent(&self) -> Prior<Address> {
81 match &self.command {
82 WireProtocol::Init(_) => Prior::None,
83 WireProtocol::Basic(m) => Prior::Single(m.parent),
84 WireProtocol::Merge(m) => Prior::Merge(m.left, m.right),
85 }
86 }
87
88 fn policy(&self) -> Option<&[u8]> {
89 match &self.command {
90 WireProtocol::Init(m) => Some(&m.policy_num),
91 WireProtocol::Merge(_) => None,
92 WireProtocol::Basic(_) => None,
93 }
94 }
95
96 fn bytes(&self) -> &[u8] {
97 self.data
98 }
99}
100
101pub struct TestEngine {
102 policy: TestPolicy,
103}
104
105impl TestEngine {
106 pub fn new() -> TestEngine {
107 TestEngine {
108 policy: TestPolicy::new(0),
109 }
110 }
111}
112
113impl Default for TestEngine {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl Engine for TestEngine {
120 type Policy = TestPolicy;
121 type Effect = TestEffect;
122
123 fn add_policy(&mut self, policy: &[u8]) -> Result<PolicyId, EngineError> {
124 Ok(PolicyId::new(policy[0] as usize))
125 }
126
127 fn get_policy(&self, _id: PolicyId) -> Result<&Self::Policy, EngineError> {
128 Ok(&self.policy)
129 }
130}
131
/// Policy used in tests.
pub struct TestPolicy {
    /// Value reported by `Policy::serial`.
    serial: u32,
}
135
136impl TestPolicy {
137 pub fn new(serial: u32) -> Self {
138 TestPolicy { serial }
139 }
140
141 fn origin_check_message(
142 &self,
143 command: &WireBasic,
144 facts: &mut impl FactPerspective,
145 ) -> Result<(), EngineError> {
146 let (group, count) = command.payload;
147
148 let key = group.to_be_bytes();
149 let value = count.to_be_bytes();
150
151 facts.insert("payload".into(), Keys::from_iter([key]), value.into());
152 Ok(())
153 }
154
155 fn call_rule_internal(
156 &self,
157 policy_command: &WireProtocol,
158 facts: &mut impl FactPerspective,
159 sink: &mut impl Sink<<TestPolicy as Policy>::Effect>,
160 ) -> Result<(), EngineError> {
161 if let WireProtocol::Basic(m) = &policy_command {
162 self.origin_check_message(m, facts)?;
163
164 sink.consume(TestEffect::Got(m.payload.1));
165 }
166
167 Ok(())
168 }
169
170 fn init<'a>(&self, target: &'a mut [u8], nonce: u64) -> Result<TestProtocol<'a>, EngineError> {
171 let message = WireInit {
172 nonce: u128::from(nonce),
173 policy_num: nonce.to_le_bytes(),
174 };
175
176 let command = WireProtocol::Init(message);
177 let data = write(target, &command)?;
178 let id = CommandId::hash_for_testing_only(data);
179
180 Ok(TestProtocol { id, command, data })
181 }
182
183 fn basic<'a>(
184 &self,
185 target: &'a mut [u8],
186 parent: Address,
187 payload: (u64, u64),
188 ) -> Result<TestProtocol<'a>, EngineError> {
189 let prority = 0; let message = WireBasic {
192 parent,
193 prority,
194 payload,
195 };
196
197 let command = WireProtocol::Basic(message);
198 let data = write(target, &command)?;
199 let id = CommandId::hash_for_testing_only(data);
200
201 Ok(TestProtocol { id, command, data })
202 }
203}
204
205fn write<'a>(target: &'a mut [u8], message: &WireProtocol) -> Result<&'a mut [u8], EngineError> {
206 Ok(serialize_with_flavor::<WireProtocol, Slice<'a>, &mut [u8]>(
207 message,
208 Slice::new(target),
209 )?)
210}
211
/// Effects the test policy can emit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TestEffect {
    /// Emitted for a basic command; carries the second element of its
    /// payload.
    Got(u64),
}
216
217#[derive(Debug, Clone)]
218pub struct TestSink {
219 expect: Vec<TestEffect>,
220 ignore_expect: bool,
221}
222
223impl TestSink {
224 pub fn new() -> Self {
225 TestSink {
226 expect: Vec::new(),
227 ignore_expect: false,
228 }
229 }
230
231 pub fn ignore_expectations(&mut self, ignore: bool) {
232 self.ignore_expect = ignore;
233 }
234}
235
236impl Default for TestSink {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
242impl TestSink {
243 pub fn add_expectation(&mut self, expect: TestEffect) {
244 self.expect.push(expect);
245 }
246
247 pub fn count(&self) -> usize {
248 self.expect.len()
249 }
250}
251
252impl Sink<TestEffect> for TestSink {
253 fn begin(&mut self) {
254 }
256
257 fn consume(&mut self, effect: TestEffect) {
258 trace!(?effect, "consume");
259 if !self.ignore_expect {
260 let expect = self.expect.remove(0);
261 assert_eq!(effect, expect);
262 }
263 }
264
265 fn rollback(&mut self) {
266 }
268
269 fn commit(&mut self) {
270 }
272}
273
274#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
275pub enum TestActions {
276 Init(u64),
277 SetValue(u64, u64),
278}
279
280impl Policy for TestPolicy {
281 type Effect = TestEffect;
282 type Action<'a> = TestActions;
283 type Command<'a> = TestProtocol<'a>;
284
285 fn serial(&self) -> u32 {
286 self.serial
287 }
288
289 fn call_rule(
290 &self,
291 command: &impl Command,
292 facts: &mut impl FactPerspective,
293 sink: &mut impl Sink<Self::Effect>,
294 _recall: CommandRecall,
295 ) -> Result<(), EngineError> {
296 let policy_command: WireProtocol = from_bytes(command.bytes())?;
297 self.call_rule_internal(&policy_command, facts, sink)
298 }
299
300 fn merge<'a>(
301 &self,
302 target: &'a mut [u8],
303 ids: MergeIds,
304 ) -> Result<TestProtocol<'a>, EngineError> {
305 let (left, right) = ids.into();
306 let command = WireProtocol::Merge(WireMerge { left, right });
307 let data = write(target, &command)?;
308 let id = CommandId::hash_for_testing_only(data);
309
310 Ok(TestProtocol { id, command, data })
311 }
312
313 fn call_action(
314 &self,
315 action: Self::Action<'_>,
316 facts: &mut impl Perspective,
317 sink: &mut impl Sink<Self::Effect>,
318 ) -> Result<(), EngineError> {
319 let parent = match facts.head_address()? {
320 Prior::None => Address {
321 id: CommandId::default(),
322 max_cut: 0,
323 },
324 Prior::Single(id) => id,
325 Prior::Merge(_, _) => bug!("cannot get merge command in call_action"),
326 };
327 match action {
328 TestActions::Init(nonce) => {
329 let mut buffer = [0u8; MAX_COMMAND_LENGTH];
330 let target = buffer.as_mut_slice();
331 let command = self.init(target, nonce)?;
332
333 self.call_rule_internal(&command.command, facts, sink)?;
334
335 facts.add_command(&command)?;
336 }
337 TestActions::SetValue(key, value) => {
338 let mut buffer = [0u8; MAX_COMMAND_LENGTH];
339 let target = buffer.as_mut_slice();
340 let payload = (key, value);
341 let command = self.basic(target, parent, payload)?;
342
343 self.call_rule_internal(&command.command, facts, sink)?;
344
345 facts.add_command(&command)?;
346 }
347 }
348
349 Ok(())
350 }
351}