1use std::convert::TryInto;
3use std::fmt;
4use std::num::NonZeroU16;
5
/// Identifier of a symbolic variable that can appear in an [`Expr`].
///
/// The id is stored as a `NonZeroU16`, so an id of `0` is not representable.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Var(NonZeroU16);

impl Var {
    /// Creates a variable from any id that converts into a `NonZeroU16`.
    ///
    /// # Panics
    ///
    /// Panics if the conversion fails, e.g. when `id` is zero. Callers are
    /// expected to hand out valid ids, so a failure here is a bug upstream.
    pub(crate) fn with_id<T>(id: T) -> Self
    where
        T: TryInto<NonZeroU16>,
        T::Error: std::fmt::Debug,
    {
        let id: NonZeroU16 = id
            .try_into()
            .expect("variable id must be a non-zero value that fits in a u16");
        Var(id)
    }
}

impl fmt::Display for Var {
    /// Formats the variable as `var<id>`, e.g. `var1`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "var{}", self.0)
    }
}
25
26#[derive(Debug, Clone, Eq, PartialEq, Hash)]
28pub struct Expr {
29 ops: Vec<Sym>,
30}
31
32impl fmt::Display for Expr {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 self.walk(&mut DisplayVisit(f))
35 }
36}
37
38impl From<Var> for Expr {
39 fn from(v: Var) -> Expr {
40 Self {
41 ops: vec![Sym::Var(v)],
42 }
43 }
44}
45
46impl Expr {
47 fn concat(op: Sym, args: &[&Self]) -> Self {
48 assert_eq!(op.children() as usize, args.len());
49
50 let capacity = 1 + args.iter().map(|x| x.ops.len()).sum::<usize>();
51 let mut ops = Vec::with_capacity(capacity);
52 ops.push(op);
53 for arg in args {
54 ops.extend_from_slice(&arg.ops);
55 }
56
57 Self { ops }
58 }
59
60 pub fn address() -> Self {
62 Self {
63 ops: vec![Sym::Address],
64 }
65 }
66
67 pub fn origin() -> Self {
69 Self {
70 ops: vec![Sym::Origin],
71 }
72 }
73
74 pub fn caller() -> Self {
76 Self {
77 ops: vec![Sym::Caller],
78 }
79 }
80
81 pub fn call_value() -> Self {
83 Self {
84 ops: vec![Sym::CallValue],
85 }
86 }
87
88 pub fn call_data_size() -> Self {
90 Self {
91 ops: vec![Sym::CallDataSize],
92 }
93 }
94
95 pub fn code_size() -> Self {
97 Self {
98 ops: vec![Sym::CodeSize],
99 }
100 }
101
102 pub fn gas_price() -> Self {
104 Self {
105 ops: vec![Sym::GasPrice],
106 }
107 }
108
109 pub fn return_data_size() -> Self {
111 Self {
112 ops: vec![Sym::ReturnDataSize],
113 }
114 }
115
116 pub fn coinbase() -> Self {
118 Self {
119 ops: vec![Sym::Coinbase],
120 }
121 }
122
123 pub fn timestamp() -> Self {
125 Self {
126 ops: vec![Sym::Timestamp],
127 }
128 }
129
130 pub fn number() -> Self {
132 Self {
133 ops: vec![Sym::Number],
134 }
135 }
136
137 pub fn difficulty() -> Self {
139 Self {
140 ops: vec![Sym::Difficulty],
141 }
142 }
143
144 pub fn gas_limit() -> Self {
146 Self {
147 ops: vec![Sym::GasLimit],
148 }
149 }
150
151 pub fn chain_id() -> Self {
153 Self {
154 ops: vec![Sym::ChainId],
155 }
156 }
157
158 pub fn self_balance() -> Self {
160 Self {
161 ops: vec![Sym::SelfBalance],
162 }
163 }
164
165 pub fn base_fee() -> Self {
167 Self {
168 ops: vec![Sym::BaseFee],
169 }
170 }
171
172 pub fn pc(offset: u16) -> Self {
174 Self {
175 ops: vec![Sym::GetPc(offset)],
176 }
177 }
178
179 pub fn m_size() -> Self {
181 Self {
182 ops: vec![Sym::MSize],
183 }
184 }
185
186 pub fn gas() -> Self {
188 Self {
189 ops: vec![Sym::Gas],
190 }
191 }
192
193 pub fn create(value: &Self, offset: &Self, length: &Self) -> Self {
195 Self::concat(Sym::Create, &[value, offset, length])
196 }
197
198 pub fn create2(value: &Self, offset: &Self, length: &Self, salt: &Self) -> Self {
200 Self::concat(Sym::Create2, &[value, offset, length, salt])
201 }
202
203 pub fn call_code(
205 gas: &Self,
206 addr: &Self,
207 value: &Self,
208 args_offset: &Self,
209 args_len: &Self,
210 ret_offset: &Self,
211 ret_len: &Self,
212 ) -> Self {
213 Self::concat(
214 Sym::CallCode,
215 &[gas, addr, value, args_offset, args_len, ret_offset, ret_len],
216 )
217 }
218
219 pub fn call(
221 gas: &Self,
222 addr: &Self,
223 value: &Self,
224 args_offset: &Self,
225 args_len: &Self,
226 ret_offset: &Self,
227 ret_len: &Self,
228 ) -> Self {
229 Self::concat(
230 Sym::Call,
231 &[gas, addr, value, args_offset, args_len, ret_offset, ret_len],
232 )
233 }
234
235 pub fn static_call(
237 gas: &Self,
238 addr: &Self,
239 args_offset: &Self,
240 args_len: &Self,
241 ret_offset: &Self,
242 ret_len: &Self,
243 ) -> Self {
244 Self::concat(
245 Sym::StaticCall,
246 &[gas, addr, args_offset, args_len, ret_offset, ret_len],
247 )
248 }
249
250 pub fn delegate_call(
252 gas: &Self,
253 addr: &Self,
254 args_offset: &Self,
255 args_len: &Self,
256 ret_offset: &Self,
257 ret_len: &Self,
258 ) -> Self {
259 Self::concat(
260 Sym::DelegateCall,
261 &[gas, addr, args_offset, args_len, ret_offset, ret_len],
262 )
263 }
264
265 pub fn add(&self, rhs: &Self) -> Self {
267 Self::concat(Sym::Add, &[self, rhs])
268 }
269
270 pub fn sub(&self, rhs: &Self) -> Self {
272 Self::concat(Sym::Sub, &[self, rhs])
273 }
274
275 pub fn mul(&self, rhs: &Self) -> Self {
277 Self::concat(Sym::Mul, &[self, rhs])
278 }
279
280 pub fn div(&self, rhs: &Self) -> Self {
282 Self::concat(Sym::Div, &[self, rhs])
283 }
284
285 pub fn s_div(&self, rhs: &Self) -> Self {
287 Self::concat(Sym::SDiv, &[self, rhs])
288 }
289
290 pub fn modulo(&self, rhs: &Self) -> Self {
292 Self::concat(Sym::Mod, &[self, rhs])
293 }
294
295 pub fn s_modulo(&self, rhs: &Self) -> Self {
297 Self::concat(Sym::SMod, &[self, rhs])
298 }
299
300 pub fn add_mod(&self, add: &Self, modulo: &Self) -> Self {
302 Self::concat(Sym::AddMod, &[self, add, modulo])
303 }
304
305 pub fn mul_mod(&self, mul: &Self, modulo: &Self) -> Self {
307 Self::concat(Sym::MulMod, &[self, mul, modulo])
308 }
309
310 pub fn exp(&self, rhs: &Self) -> Self {
312 Self::concat(Sym::Exp, &[self, rhs])
313 }
314
315 pub fn lt(&self, rhs: &Self) -> Self {
317 Self::concat(Sym::Lt, &[self, rhs])
318 }
319
320 pub fn gt(&self, rhs: &Self) -> Self {
322 Self::concat(Sym::Gt, &[self, rhs])
323 }
324
325 pub fn s_lt(&self, rhs: &Self) -> Self {
327 Self::concat(Sym::SLt, &[self, rhs])
328 }
329
330 pub fn s_gt(&self, rhs: &Self) -> Self {
332 Self::concat(Sym::SGt, &[self, rhs])
333 }
334
335 pub fn is_eq(&self, rhs: &Self) -> Self {
337 Self::concat(Sym::Eq, &[self, rhs])
338 }
339
340 pub fn and(&self, rhs: &Self) -> Self {
342 Self::concat(Sym::And, &[self, rhs])
343 }
344
345 pub fn or(&self, rhs: &Self) -> Self {
347 Self::concat(Sym::Or, &[self, rhs])
348 }
349
350 pub fn xor(&self, rhs: &Self) -> Self {
352 Self::concat(Sym::Xor, &[self, rhs])
353 }
354
355 pub fn byte(&self, value: &Self) -> Self {
357 Self::concat(Sym::Byte, &[self, value])
358 }
359
360 pub fn shl(&self, rhs: &Self) -> Self {
362 Self::concat(Sym::Shl, &[self, rhs])
363 }
364
365 pub fn shr(&self, value: &Self) -> Self {
367 Self::concat(Sym::Shr, &[self, value])
368 }
369
370 pub fn sar(&self, rhs: &Self) -> Self {
372 Self::concat(Sym::Sar, &[self, rhs])
373 }
374
375 pub fn keccak256(offset: &Self, len: &Self) -> Self {
377 Self::concat(Sym::Keccak256, &[offset, len])
378 }
379
380 pub fn sign_extend(&self, b: &Self) -> Self {
382 Self::concat(Sym::SignExtend, &[self, b])
383 }
384
385 pub fn is_zero(&self) -> Self {
387 Self::concat(Sym::IsZero, &[self])
388 }
389
390 pub fn not(&self) -> Self {
392 Self::concat(Sym::Not, &[self])
393 }
394
395 pub fn block_hash(&self) -> Self {
397 Self::concat(Sym::BlockHash, &[self])
398 }
399
400 pub fn balance(&self) -> Self {
402 Self::concat(Sym::Balance, &[self])
403 }
404
405 pub fn call_data_load(&self) -> Self {
407 Self::concat(Sym::CallDataLoad, &[self])
408 }
409
410 pub fn ext_code_size(&self) -> Self {
412 Self::concat(Sym::ExtCodeSize, &[self])
413 }
414
415 pub fn ext_code_hash(&self) -> Self {
417 Self::concat(Sym::ExtCodeHash, &[self])
418 }
419
420 pub fn m_load(&self) -> Self {
422 Self::concat(Sym::MLoad, &[self])
423 }
424
425 pub fn s_load(&self) -> Self {
427 Self::concat(Sym::SLoad, &[self])
428 }
429
430 pub fn as_var(&self) -> Option<Var> {
433 match self.ops.as_slice() {
434 [Sym::Var(v)] => Some(*v),
435 _ => None,
436 }
437 }
438
439 pub fn constant<A>(arr: A) -> Self
441 where
442 A: AsRef<[u8]>,
443 {
444 let arr = arr.as_ref();
445 let mut buf = [0u8; 32];
446 let start = buf.len() - arr.len();
447 buf[start..].copy_from_slice(arr);
448 Self {
449 ops: vec![Sym::Const(buf.into())],
450 }
451 }
452
453 #[cfg(test)]
454 pub(crate) fn constant_offset<T: Into<u128>>(offset: T) -> Self {
455 let offset: u128 = offset.into();
456 let mut buf = [0u8; 32];
457 buf[16..].copy_from_slice(&offset.to_be_bytes());
458
459 Self {
460 ops: vec![Sym::Const(buf.into())],
461 }
462 }
463}
464
/// Visitor that writes an expression's human-readable form into the wrapped
/// `fmt::Formatter`.
struct DisplayVisit<'w, 'f>(&'w mut fmt::Formatter<'f>);
468
469impl<'a, 'b> Visit for DisplayVisit<'a, 'b> {
470 type Error = fmt::Error;
471
472 fn empty(&mut self) -> fmt::Result {
473 write!(self.0, "{{}}")
474 }
475
476 fn exit(&mut self, op: &Sym) -> fmt::Result {
477 match op {
478 Sym::Const(_) => Ok(()),
479 Sym::Var(_) => Ok(()),
480 Sym::IsZero => write!(self.0, " = 0)"),
481 _ => write!(self.0, ")"),
482 }
483 }
484
485 fn between(&mut self, op: &Sym, idx: u8) -> fmt::Result {
486 let txt = match op {
487 Sym::Add => " + ",
488 Sym::Mul => " × ",
489 Sym::Sub => " - ",
490 Sym::Div => " ÷ ",
491 Sym::SDiv => " ÷⃡ ",
492 Sym::Mod => " ﹪ ",
493 Sym::SMod => " ﹪⃡ ",
494 Sym::AddMod => match idx {
495 0 => " + ",
496 1 => ") ﹪ ",
497 _ => unreachable!(),
498 },
499 Sym::MulMod => match idx {
500 0 => " × ",
501 1 => ") ﹪ ",
502 _ => unreachable!(),
503 },
504 Sym::Exp => " ** ",
505 Sym::Lt => " < ",
506 Sym::Gt => " > ",
507 Sym::SLt => " <⃡ ",
508 Sym::SGt => " >⃡ ",
509 Sym::Eq => " = ",
510 Sym::And => " & ",
511 Sym::Or => " | ",
512 Sym::Xor => " ^ ",
513 q if q.children() < 2 => unreachable!(),
514 _ => ", ",
515 };
516
517 write!(self.0, "{}", txt)
518 }
519
520 fn enter(&mut self, op: &Sym) -> fmt::Result {
521 match op {
522 Sym::Const(v) => {
523 write!(self.0, "0x{}", hex::encode(**v))
525 }
526 Sym::Var(v) => write!(self.0, "{}", v),
527 Sym::AddMod => write!(self.0, "(("),
528 Sym::MulMod => write!(self.0, "(("),
529 Sym::Keccak256 => write!(self.0, "keccak256("),
530 Sym::Byte => write!(self.0, "byte("),
531 Sym::SignExtend => write!(self.0, "signextend("),
532 Sym::Not => write!(self.0, "~("),
533 Sym::CallDataLoad => write!(self.0, "calldata("),
534 Sym::ExtCodeSize => write!(self.0, "extcodesize("),
535 Sym::ExtCodeHash => write!(self.0, "extcodehash("),
536 Sym::MLoad => write!(self.0, "mload("),
537 Sym::SLoad => write!(self.0, "sload("),
538 Sym::Address => write!(self.0, "address("),
539 Sym::Balance => write!(self.0, "balance("),
540 Sym::Origin => write!(self.0, "origin("),
541 Sym::Caller => write!(self.0, "caller("),
542 Sym::CallValue => write!(self.0, "callvalue("),
543 Sym::CallDataSize => write!(self.0, "calldatasize("),
544 Sym::CodeSize => write!(self.0, "codesize("),
545 Sym::GasPrice => write!(self.0, "gasprice("),
546 Sym::ReturnDataSize => write!(self.0, "returndatasize("),
547 Sym::BlockHash => write!(self.0, "blockhash("),
548 Sym::Coinbase => write!(self.0, "coinbase("),
549 Sym::Timestamp => write!(self.0, "timestamp("),
550 Sym::Number => write!(self.0, "number("),
551 Sym::Difficulty => write!(self.0, "difficulty("),
552 Sym::GasLimit => write!(self.0, "gaslimit("),
553 Sym::ChainId => write!(self.0, "chainid("),
554 Sym::SelfBalance => write!(self.0, "selfbalance("),
555 Sym::BaseFee => write!(self.0, "basefee("),
556 Sym::GetPc(pc) => write!(self.0, "pc({}", pc),
557 Sym::MSize => write!(self.0, "msize("),
558 Sym::Gas => write!(self.0, "gas("),
559 Sym::Create => write!(self.0, "create("),
560 Sym::CallCode => write!(self.0, "callcode("),
561 Sym::Call => write!(self.0, "call("),
562 Sym::StaticCall => write!(self.0, "staticcall("),
563 Sym::DelegateCall => write!(self.0, "delegatecall("),
564 Sym::Shl => write!(self.0, "shl("),
565 Sym::Shr => write!(self.0, "shr("),
566 Sym::Sar => write!(self.0, "sar("),
567 _ => write!(self.0, "("),
568 }
569 }
570}
571
572impl Expr {
573 pub fn walk<V>(&self, visitor: &mut V) -> Result<(), V::Error>
576 where
577 V: Visit,
578 {
579 if self.ops.is_empty() {
580 visitor.empty()
581 } else {
582 Self::inner_walk(&self.ops, visitor)?;
583 Ok(())
586 }
587 }
588
589 fn inner_walk<'a, V>(mut ops: &'a [Sym], visitor: &mut V) -> Result<&'a [Sym], V::Error>
590 where
591 V: Visit,
592 {
593 if ops.is_empty() {
594 unreachable!();
595 }
596
597 let op = &ops[0];
598
599 visitor.enter(op)?;
600
601 for idx in 0..op.children() {
602 ops = Self::inner_walk(&ops[1..], visitor)?;
603
604 if (idx + 1) < op.children() {
605 visitor.between(op, idx)?;
606 }
607 }
608
609 visitor.exit(op)?;
610
611 Ok(ops)
612 }
613}
614
615pub trait Visit {
617 type Error;
619
620 fn empty(&mut self) -> Result<(), Self::Error> {
622 Ok(())
623 }
624
625 fn enter(&mut self, _: &Sym) -> Result<(), Self::Error> {
627 Ok(())
628 }
629
630 fn between(&mut self, _: &Sym, _: u8) -> Result<(), Self::Error> {
632 Ok(())
633 }
634
635 fn exit(&mut self, _: &Sym) -> Result<(), Self::Error> {
637 Ok(())
638 }
639}
640
641#[derive(Debug, Clone, Eq, PartialEq, Hash)]
647pub enum Sym {
648 Const(Box<[u8; 32]>),
650
651 Var(Var),
653
654 Add,
656
657 Mul,
659
660 Sub,
662
663 Div,
665
666 SDiv,
668
669 Mod,
671
672 SMod,
674
675 AddMod,
677
678 MulMod,
680
681 Exp,
683
684 Lt,
686
687 Gt,
689
690 SLt,
692
693 SGt,
695
696 Eq,
698
699 And,
701
702 Or,
704
705 Xor,
707
708 Byte,
710
711 Shl,
713
714 Shr,
716
717 Sar,
719
720 Keccak256,
722
723 SignExtend,
725
726 IsZero,
728
729 Not,
731
732 CallDataLoad,
734
735 ExtCodeSize,
737
738 ExtCodeHash,
740
741 MLoad,
743
744 SLoad,
746
747 Balance,
749
750 BlockHash,
752
753 Address,
755
756 Origin,
758
759 Caller,
761
762 CallValue,
764
765 CallDataSize,
767
768 CodeSize,
770
771 GasPrice,
773
774 ReturnDataSize,
776
777 Coinbase,
779
780 Timestamp,
782
783 Number,
785
786 Difficulty,
788
789 GasLimit,
791
792 ChainId,
794
795 SelfBalance,
797
798 BaseFee,
800
801 GetPc(u16),
803
804 MSize,
806
807 Gas,
809
810 Create,
812
813 Create2,
815
816 CallCode,
818
819 Call,
821
822 StaticCall,
824
825 DelegateCall,
827}
828
829impl Sym {
830 fn children(&self) -> u8 {
831 match self {
832 Sym::Add
833 | Sym::Mul
834 | Sym::Sub
835 | Sym::Div
836 | Sym::SDiv
837 | Sym::Mod
838 | Sym::SMod
839 | Sym::Exp
840 | Sym::Lt
841 | Sym::Gt
842 | Sym::SLt
843 | Sym::SGt
844 | Sym::Eq
845 | Sym::And
846 | Sym::Or
847 | Sym::Xor
848 | Sym::Byte
849 | Sym::Shl
850 | Sym::Shr
851 | Sym::Sar
852 | Sym::SignExtend
853 | Sym::Keccak256 => 2,
854
855 Sym::IsZero
856 | Sym::Not
857 | Sym::CallDataLoad
858 | Sym::ExtCodeSize
859 | Sym::ExtCodeHash
860 | Sym::BlockHash
861 | Sym::Balance
862 | Sym::MLoad
863 | Sym::SLoad => 1,
864
865 Sym::Address
866 | Sym::Origin
867 | Sym::Caller
868 | Sym::CallValue
869 | Sym::CallDataSize
870 | Sym::CodeSize
871 | Sym::GasPrice
872 | Sym::ReturnDataSize
873 | Sym::Coinbase
874 | Sym::Timestamp
875 | Sym::Number
876 | Sym::Difficulty
877 | Sym::GasLimit
878 | Sym::ChainId
879 | Sym::SelfBalance
880 | Sym::BaseFee
881 | Sym::GetPc(_)
882 | Sym::MSize
883 | Sym::Gas
884 | Sym::Const(_)
885 | Sym::Var(_) => 0,
886
887 Sym::AddMod | Sym::MulMod | Sym::Create => 3,
888
889 Sym::Create2 => 4,
890
891 Sym::Call | Sym::CallCode => 7,
892
893 Sym::DelegateCall | Sym::StaticCall => 6,
894 }
895 }
896}
897
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders a raw symbol list through `Expr`'s `Display` implementation.
    fn render(ops: Vec<Sym>) -> String {
        Expr { ops }.to_string()
    }

    #[test]
    fn expr_display_add_mod() {
        let var = Var::with_id(NonZeroU16::new(1).unwrap());
        let rendered = render(vec![Sym::AddMod, Sym::Caller, Sym::Origin, Sym::Var(var)]);
        assert_eq!("((caller() + origin()) ﹪ var1)", rendered);
    }

    #[test]
    fn expr_display_call() {
        let rendered = render(vec![
            Sym::Call,
            Sym::Gas,
            Sym::Caller,
            Sym::CallValue,
            Sym::SLoad,
            Sym::GetPc(3),
            Sym::MLoad,
            Sym::Origin,
            Sym::Number,
            Sym::Timestamp,
        ]);
        assert_eq!(
            "call(gas(), caller(), callvalue(), sload(pc(3)), mload(origin()), number(), timestamp())",
            rendered
        );
    }

    #[test]
    fn expr_display_add() {
        // Each 32-byte constant renders as 64 hex digits after `0x`.
        let expected = format!("(0x{} + 0x{})", "0".repeat(64), "f".repeat(64));
        let rendered = render(vec![
            Sym::Add,
            Sym::Const(Box::new([0x00; 32])),
            Sym::Const(Box::new([0xff; 32])),
        ]);
        assert_eq!(expected, rendered);
    }
}