1use alloc::vec::Vec;
3use core::convert::Infallible;
4
5use buggy::bug;
6use postcard::{from_bytes, ser_flavors::Slice, serialize_with_flavor};
7use serde::{Deserialize, Serialize};
8use tracing::trace;
9
10use super::{
11 alloc, Command, CommandId, Engine, EngineError, FactPerspective, Perspective, Policy, PolicyId,
12 Prior, Priority, Sink, StorageError, MAX_COMMAND_LENGTH,
13};
14use crate::{Address, CommandRecall, Keys, MergeIds};
15
16impl From<StorageError> for EngineError {
17 fn from(_: StorageError) -> Self {
18 EngineError::InternalError
19 }
20}
21
22impl From<postcard::Error> for EngineError {
23 fn from(_error: postcard::Error) -> Self {
24 EngineError::Read
25 }
26}
27
28impl From<Infallible> for EngineError {
29 fn from(_error: Infallible) -> Self {
30 EngineError::Write
31 }
32}
33
/// Wire payload of an init command, the only command kind with no parent.
///
/// [`TestPolicy`] fills both fields from a single `u64` nonce.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WireInit {
    /// The nonce, widened from the `u64` supplied by the init action.
    pub nonce: u128,
    /// Bytes reported by `Command::policy`; when built by [`TestPolicy`] this
    /// is the nonce in little-endian order.
    pub policy_num: [u8; 8],
}
39
/// Wire payload of a merge command, which has two parents.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WireMerge {
    /// First parent, reported as the left side of `Prior::Merge`.
    pub left: Address,
    /// Second parent, reported as the right side of `Prior::Merge`.
    pub right: Address,
}
45
/// Wire payload of a basic command, the only kind that carries data.
///
/// The derived ordering compares fields in declaration order: `parent`, then
/// `prority`, then `payload`.
#[derive(Serialize, Deserialize, Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
pub struct WireBasic {
    /// The single parent this command builds on.
    pub parent: Address,
    /// Priority reported as `Priority::Basic`.
    /// NOTE(review): misspelling of "priority"; left as-is because the field
    /// is public and may be referenced outside this file.
    pub prority: u32,
    /// `(group, count)`: `group` keys the stored `"payload"` fact, and `count`
    /// is both the fact value and the value emitted in `TestEffect::Got`.
    pub payload: (u64, u64),
}
52
/// Every command kind in the test protocol.
///
/// This is the value serialized with postcard into a command's bytes and
/// decoded again when rules are evaluated. Variants are encoded by index, so
/// reordering them changes the wire format.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WireProtocol {
    /// Root command with no parent.
    Init(WireInit),
    /// Command joining two parents.
    Merge(WireMerge),
    /// Data-carrying command with a single parent.
    Basic(WireBasic),
}
59
/// A test command: its decoded form together with its encoded bytes.
#[derive(Debug, Clone)]
pub struct TestProtocol<'a> {
    /// Command id; when built by `TestPolicy` this is
    /// `CommandId::hash_for_testing_only` over `data`.
    id: CommandId,
    /// The decoded command.
    command: WireProtocol,
    /// The postcard encoding of `command`, borrowed from the caller's buffer.
    data: &'a [u8],
}
66
67impl Command for TestProtocol<'_> {
68 fn priority(&self) -> Priority {
69 match &self.command {
70 WireProtocol::Init(_) => Priority::Init,
71 WireProtocol::Merge(_) => Priority::Merge,
72 WireProtocol::Basic(m) => Priority::Basic(m.prority),
73 }
74 }
75
76 fn id(&self) -> CommandId {
77 self.id
78 }
79
80 fn parent(&self) -> Prior<Address> {
81 match &self.command {
82 WireProtocol::Init(_) => Prior::None,
83 WireProtocol::Basic(m) => Prior::Single(m.parent),
84 WireProtocol::Merge(m) => Prior::Merge(m.left, m.right),
85 }
86 }
87
88 fn policy(&self) -> Option<&[u8]> {
89 match &self.command {
90 WireProtocol::Init(m) => Some(&m.policy_num),
91 WireProtocol::Merge(_) => None,
92 WireProtocol::Basic(_) => None,
93 }
94 }
95
96 fn bytes(&self) -> &[u8] {
97 self.data
98 }
99}
100
/// An [`Engine`] that owns exactly one [`TestPolicy`] and hands it out for
/// every policy id.
pub struct TestEngine {
    /// The sole policy, created with serial 0 by `TestEngine::new`.
    policy: TestPolicy,
}
104
105impl TestEngine {
106 pub fn new() -> TestEngine {
107 TestEngine {
108 policy: TestPolicy::new(0),
109 }
110 }
111}
112
113impl Default for TestEngine {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl Engine for TestEngine {
120 type Policy = TestPolicy;
121 type Effect = TestEffect;
122
123 fn add_policy(&mut self, policy: &[u8]) -> Result<PolicyId, EngineError> {
124 Ok(PolicyId::new(policy[0] as usize))
125 }
126
127 fn get_policy(&self, _id: PolicyId) -> Result<&Self::Policy, EngineError> {
128 Ok(&self.policy)
129 }
130}
131
/// The policy served by [`TestEngine`].
pub struct TestPolicy {
    /// Value reported by `Policy::serial`.
    serial: u32,
}
135
136impl TestPolicy {
137 pub fn new(serial: u32) -> Self {
138 TestPolicy { serial }
139 }
140
141 fn origin_check_message(
142 &self,
143 command: &WireBasic,
144 facts: &mut impl FactPerspective,
145 ) -> Result<(), EngineError> {
146 let (group, count) = command.payload;
147
148 let key = group.to_be_bytes();
149 let value = count.to_be_bytes();
150
151 facts.insert("payload".into(), Keys::from_iter([key]), value.into());
152 Ok(())
153 }
154
155 fn call_rule_internal(
156 &self,
157 policy_command: &WireProtocol,
158 facts: &mut impl FactPerspective,
159 sink: &mut impl Sink<<TestPolicy as Policy>::Effect>,
160 ) -> Result<(), EngineError> {
161 if let WireProtocol::Basic(m) = &policy_command {
162 self.origin_check_message(m, facts)?;
163
164 sink.consume(TestEffect::Got(m.payload.1));
165 }
166
167 Ok(())
168 }
169
170 fn init<'a>(&self, target: &'a mut [u8], nonce: u64) -> Result<TestProtocol<'a>, EngineError> {
171 let message = WireInit {
172 nonce: u128::from(nonce),
173 policy_num: nonce.to_le_bytes(),
174 };
175
176 let command = WireProtocol::Init(message);
177 let data = write(target, &command)?;
178 let id = CommandId::hash_for_testing_only(data);
179
180 Ok(TestProtocol { id, command, data })
181 }
182
183 fn basic<'a>(
184 &self,
185 target: &'a mut [u8],
186 parent: Address,
187 payload: (u64, u64),
188 ) -> Result<TestProtocol<'a>, EngineError> {
189 let prority = 0; let message = WireBasic {
192 parent,
193 prority,
194 payload,
195 };
196
197 let command = WireProtocol::Basic(message);
198 let data = write(target, &command)?;
199 let id = CommandId::hash_for_testing_only(data);
200
201 Ok(TestProtocol { id, command, data })
202 }
203}
204
205fn write<'a>(target: &'a mut [u8], message: &WireProtocol) -> Result<&'a mut [u8], EngineError> {
206 Ok(serialize_with_flavor::<WireProtocol, Slice<'a>, &mut [u8]>(
207 message,
208 Slice::new(target),
209 )?)
210}
211
/// Effects emitted by [`TestPolicy`].
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum TestEffect {
    /// Emitted for each basic command; carries the `count` (second) element
    /// of its payload.
    Got(u64),
}
216
/// A [`Sink`] that checks consumed effects against a queue of expectations.
#[derive(Debug, Clone)]
pub struct TestSink {
    /// Effects still expected, in the order they must be consumed.
    expect: Vec<TestEffect>,
    /// When `true`, consumed effects are only traced, never checked.
    ignore_expect: bool,
}
222
223impl TestSink {
224 pub fn new() -> Self {
225 TestSink {
226 expect: Vec::new(),
227 ignore_expect: false,
228 }
229 }
230
231 pub fn ignore_expectations(&mut self, ignore: bool) {
232 self.ignore_expect = ignore;
233 }
234}
235
236impl Default for TestSink {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
242impl TestSink {
243 pub fn add_expectation(&mut self, expect: TestEffect) {
244 self.expect.push(expect);
245 }
246
247 pub fn count(&self) -> usize {
248 self.expect.len()
249 }
250}
251
252impl Sink<TestEffect> for TestSink {
253 fn begin(&mut self) {
254 }
256
257 fn consume(&mut self, effect: TestEffect) {
258 trace!(?effect, "consume");
259 if !self.ignore_expect {
260 assert!(!self.expect.is_empty(), "consumed {effect:?} while empty");
261 let expect = self.expect.remove(0);
262 trace!(consuming = ?effect, expected = ?expect, remainder = ?self.expect);
263 assert_eq!(
264 effect, expect,
265 "consumed {effect:?} while expecting {expect:?}"
266 );
267 }
268 }
269
270 fn rollback(&mut self) {
271 }
273
274 fn commit(&mut self) {
275 }
277}
278
/// Actions accepted by `TestPolicy`'s `call_action`.
#[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum TestActions {
    /// Creates an init command from the given nonce.
    Init(u64),
    /// Creates a basic command with payload `(key, value)` whose parent is the
    /// current head.
    SetValue(u64, u64),
}
284
285impl Policy for TestPolicy {
286 type Effect = TestEffect;
287 type Action<'a> = TestActions;
288 type Command<'a> = TestProtocol<'a>;
289
290 fn serial(&self) -> u32 {
291 self.serial
292 }
293
294 fn call_rule(
295 &self,
296 command: &impl Command,
297 facts: &mut impl FactPerspective,
298 sink: &mut impl Sink<Self::Effect>,
299 _recall: CommandRecall,
300 ) -> Result<(), EngineError> {
301 let policy_command: WireProtocol = from_bytes(command.bytes())?;
302 self.call_rule_internal(&policy_command, facts, sink)
303 }
304
305 fn merge<'a>(
306 &self,
307 target: &'a mut [u8],
308 ids: MergeIds,
309 ) -> Result<TestProtocol<'a>, EngineError> {
310 let (left, right) = ids.into();
311 let command = WireProtocol::Merge(WireMerge { left, right });
312 let data = write(target, &command)?;
313 let id = CommandId::hash_for_testing_only(data);
314
315 Ok(TestProtocol { id, command, data })
316 }
317
318 fn call_action(
319 &self,
320 action: Self::Action<'_>,
321 facts: &mut impl Perspective,
322 sink: &mut impl Sink<Self::Effect>,
323 ) -> Result<(), EngineError> {
324 let parent = match facts.head_address()? {
325 Prior::None => Address {
326 id: CommandId::default(),
327 max_cut: 0,
328 },
329 Prior::Single(id) => id,
330 Prior::Merge(_, _) => bug!("cannot get merge command in call_action"),
331 };
332 match action {
333 TestActions::Init(nonce) => {
334 let mut buffer = [0u8; MAX_COMMAND_LENGTH];
335 let target = buffer.as_mut_slice();
336 let command = self.init(target, nonce)?;
337
338 self.call_rule_internal(&command.command, facts, sink)?;
339
340 facts.add_command(&command)?;
341 }
342 TestActions::SetValue(key, value) => {
343 let mut buffer = [0u8; MAX_COMMAND_LENGTH];
344 let target = buffer.as_mut_slice();
345 let payload = (key, value);
346 let command = self.basic(target, parent, payload)?;
347
348 self.call_rule_internal(&command.command, facts, sink)?;
349
350 facts.add_command(&command)?;
351 }
352 }
353
354 Ok(())
355 }
356}