1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4
5use crate::context::{Context, Info};
6use crate::error::{ExecError, RevertError};
7use crate::opcodes::*;
8
/// Per-call inputs handed to [`Machine::run`] and forwarded to every opcode handler.
#[derive(Debug, Clone, Default)]
pub struct CallInfo<W: Word> {
    /// NOTE(review): presumably the transaction originator (what ORIGIN reports) — confirm.
    pub origin: W::Addr,
    /// NOTE(review): presumably the immediate caller (what CALLER reports) — confirm.
    pub caller: W::Addr,
    /// Value attached to the call, as a machine word.
    pub call_value: W,
    /// Raw input bytes of the call.
    pub calldata: Vec<u8>,
}
16
/// Machine word the interpreter computes with: stack entries, transient-storage keys and
/// values, and operands of every arithmetic/bitwise opcode.
///
/// NOTE(review): implementations are presumably fixed-width unsigned integers (the tests in
/// this file use `U256`). Overflow, division-by-zero and signedness semantics of the
/// required methods are left to the implementor — confirm against each impl.
pub trait Word: Clone + Debug + Default + Copy + PartialEq + Eq + PartialOrd + Ord + Hash {
    /// Address type associated with this word size (converted via `from_addr` / `to_addr`).
    type Addr: Clone + Debug + Default + Copy;
    /// Maximum value of the word (presumably all bits set — confirm).
    const MAX: Self;
    /// The value zero.
    const ZERO: Self;
    /// The value one; the default `neg` relies on it.
    const ONE: Self;
    /// Width of the word in bits; the default `is_neg` inspects bit `BITS - 1`.
    const BITS: usize;
    /// Widens an address into a word.
    fn from_addr(addr: Self::Addr) -> Self;
    /// Converts a word into an address (how excess high bits are treated is up to the impl).
    fn to_addr(self) -> Self::Addr;
    /// Hexadecimal rendering of the word (exact format is implementation-defined).
    fn hex(&self) -> String;
    /// Lowest 64 bits of the word.
    fn low_u64(&self) -> u64;
    /// Builds a word from a `u64`.
    fn from_u64(val: u64) -> Self;
    /// Whether bit number `bit` is set (presumably bit 0 is least significant — confirm).
    fn bit(&self, bit: usize) -> bool;
    /// Sign test for a two's-complement reading of the word: reports bit `BITS - 1`.
    fn is_neg(&self) -> bool {
        self.bit(Self::BITS - 1)
    }
    /// Converts to `usize`, returning an `ExecError` on failure
    /// (presumably when the value does not fit — confirm).
    fn to_usize(&self) -> Result<usize, ExecError>;
    /// Builds a word from big-endian bytes.
    fn from_big_endian(slice: &[u8]) -> Self;
    /// Serializes the word as big-endian bytes.
    fn to_big_endian(&self) -> Vec<u8>;
    // Arithmetic and bitwise primitives. NOTE(review): presumably wrapping modulo 2^BITS,
    // with EVM-style results for division by zero — confirm against the implementations.
    fn add(self, other: Self) -> Self;
    fn sub(self, other: Self) -> Self;
    fn mul(self, other: Self) -> Self;
    fn div(self, other: Self) -> Self;
    fn rem(self, other: Self) -> Self;
    fn shl(self, other: Self) -> Self;
    fn shr(self, other: Self) -> Self;
    fn and(self, other: Self) -> Self;
    fn or(self, other: Self) -> Self;
    fn xor(self, other: Self) -> Self;
    fn pow(self, other: Self) -> Self;
    /// Bitwise complement.
    fn not(self) -> Self;
    /// Two's-complement negation, computed as `!self + 1`.
    fn neg(self) -> Self {
        self.not().add(Self::ONE)
    }
    /// Less-than comparison (presumably unsigned — confirm).
    fn lt(self, other: Self) -> bool;
    /// Greater-than comparison, defined as `other.lt(self)`.
    fn gt(self, other: Self) -> bool {
        other.lt(self)
    }
    /// NOTE(review): presumably `(self + other) % n` without intermediate overflow — confirm.
    fn addmod(self, other: Self, n: Self) -> Self;
    /// NOTE(review): presumably `(self * other) % n` without intermediate overflow — confirm.
    fn mulmod(self, other: Self, n: Self) -> Self;
}
57
/// Interpreter state for one execution of a piece of bytecode.
#[derive(Debug, Clone, Default)]
pub struct Machine<W: Word> {
    /// Address the code executes as (presumably what ADDRESS reports — confirm).
    pub address: W::Addr,
    /// Bytecode being executed; `run` stops once `pc` reaches its length.
    pub code: Vec<u8>,
    /// Index into `code` of the next opcode to dispatch.
    pub pc: usize,
    /// Gas consumed so far. NOTE(review): not updated by the code visible here.
    pub gas_used: usize,
    /// Value stack; the top of the stack is the last element (`pop_stack` pops from the end).
    pub stack: Vec<W>,
    /// Byte-addressed memory, grown with zero bytes by `mem_put`.
    pub memory: Vec<u8>,
    /// Word-to-word store; presumably transient storage backing TLOAD/TSTORE — confirm.
    pub transient: HashMap<W, W>,
    /// Presumably the return data of the most recent sub-call (RETURNDATASIZE/COPY) — confirm.
    pub last_return: Option<Vec<u8>>,
}
69
70impl<W: Word> Machine<W> {
71    pub fn new(address: W::Addr, code: Vec<u8>) -> Self {
72        Self {
73            address,
74            code,
75            pc: 0,
76            gas_used: 0,
77            stack: Vec::new(),
78            memory: Vec::new(),
79            transient: HashMap::new(),
80            last_return: None,
81        }
82    }
83    pub fn run<C: Context<W>>(
84        mut self,
85        ctx: &mut C,
86        call_info: &CallInfo<W>,
87    ) -> Result<ExecutionResult, ExecError> {
88        let mut opcode_table: HashMap<u8, Box<dyn OpcodeHandler<W, C>>> = HashMap::new();
89        opcode_table.insert(0x00, Box::new(OpcodeHalt));
90        opcode_table.insert(0x01, Box::new(OpcodeBinaryOp::Add));
91        opcode_table.insert(0x02, Box::new(OpcodeBinaryOp::Mul));
92        opcode_table.insert(0x03, Box::new(OpcodeBinaryOp::Sub));
93        opcode_table.insert(0x04, Box::new(OpcodeBinaryOp::Div));
94        opcode_table.insert(0x05, Box::new(OpcodeBinaryOp::Sdiv));
95        opcode_table.insert(0x06, Box::new(OpcodeBinaryOp::Mod));
96        opcode_table.insert(0x07, Box::new(OpcodeBinaryOp::Smod));
97        opcode_table.insert(0x08, Box::new(OpcodeModularOp::AddMod));
98        opcode_table.insert(0x09, Box::new(OpcodeModularOp::MulMod));
99        opcode_table.insert(0x0a, Box::new(OpcodeBinaryOp::Exp));
100        opcode_table.insert(0x0b, Box::new(OpcodeBinaryOp::SignExtend));
101        opcode_table.insert(0x10, Box::new(OpcodeBinaryOp::Lt));
102        opcode_table.insert(0x11, Box::new(OpcodeBinaryOp::Gt));
103        opcode_table.insert(0x12, Box::new(OpcodeBinaryOp::Slt));
104        opcode_table.insert(0x13, Box::new(OpcodeBinaryOp::Sgt));
105        opcode_table.insert(0x14, Box::new(OpcodeBinaryOp::Eq));
106        opcode_table.insert(0x15, Box::new(OpcodeUnaryOp::IsZero));
107        opcode_table.insert(0x16, Box::new(OpcodeBinaryOp::And));
108        opcode_table.insert(0x17, Box::new(OpcodeBinaryOp::Or));
109        opcode_table.insert(0x18, Box::new(OpcodeBinaryOp::Xor));
110        opcode_table.insert(0x19, Box::new(OpcodeUnaryOp::Not));
111        opcode_table.insert(0x1a, Box::new(OpcodeBinaryOp::Byte));
112        opcode_table.insert(0x1b, Box::new(OpcodeBinaryOp::Shl));
113        opcode_table.insert(0x1c, Box::new(OpcodeBinaryOp::Shr));
114        opcode_table.insert(0x1d, Box::new(OpcodeBinaryOp::Sar));
115        opcode_table.insert(0x20, Box::new(OpcodeKeccak));
116        opcode_table.insert(0x30, Box::new(OpcodeAddress));
117        opcode_table.insert(0x31, Box::new(OpcodeBalance));
118        opcode_table.insert(0x32, Box::new(OpcodeOrigin));
119        opcode_table.insert(0x33, Box::new(OpcodeCaller));
120        opcode_table.insert(0x34, Box::new(OpcodeCallValue));
121        opcode_table.insert(0x35, Box::new(OpcodeCalldataLoad));
122        opcode_table.insert(0x36, Box::new(OpcodeCalldataSize));
123        opcode_table.insert(0x37, Box::new(OpcodeCalldataCopy));
124        opcode_table.insert(0x38, Box::new(OpcodeCodeSize));
125        opcode_table.insert(0x39, Box::new(OpcodeCodeCopy));
126        opcode_table.insert(0x3a, Box::new(OpcodeInfo(Info::GasPrice)));
127        opcode_table.insert(0x3b, Box::new(OpcodeExtCodeSize));
128        opcode_table.insert(0x3c, Box::new(OpcodeExtCodeCopy));
129        opcode_table.insert(0x3d, Box::new(OpcodeReturnDataSize));
130        opcode_table.insert(0x3e, Box::new(OpcodeReturnDataCopy));
131        opcode_table.insert(0x3f, Box::new(OpcodeExtCodeHash));
132        opcode_table.insert(0x40, Box::new(OpcodeBlockHash));
133        opcode_table.insert(0x41, Box::new(OpcodeInfo(Info::Coinbase)));
134        opcode_table.insert(0x42, Box::new(OpcodeInfo(Info::Timestamp)));
135        opcode_table.insert(0x43, Box::new(OpcodeInfo(Info::Number)));
136        opcode_table.insert(0x44, Box::new(OpcodeInfo(Info::PrevRandao)));
137        opcode_table.insert(0x45, Box::new(OpcodeInfo(Info::GasLimit)));
138        opcode_table.insert(0x46, Box::new(OpcodeInfo(Info::ChainId)));
139        opcode_table.insert(0x47, Box::new(OpcodeSelfBalance));
140        opcode_table.insert(0x48, Box::new(OpcodeInfo(Info::BaseFee)));
141        opcode_table.insert(0x49, Box::new(OpcodeBlobHash));
142        opcode_table.insert(0x4a, Box::new(OpcodeInfo(Info::BlobBaseFee)));
143        opcode_table.insert(0x50, Box::new(OpcodePop));
144        opcode_table.insert(0x51, Box::new(OpcodeMload));
145        opcode_table.insert(0x52, Box::new(OpcodeMstore));
146        opcode_table.insert(0x53, Box::new(OpcodeMstore8));
147        opcode_table.insert(0x54, Box::new(OpcodeSload));
148        opcode_table.insert(0x55, Box::new(OpcodeSstore));
149        opcode_table.insert(0x56, Box::new(OpcodeJump));
150        opcode_table.insert(0x57, Box::new(OpcodeJumpi));
151        opcode_table.insert(0x5b, Box::new(OpcodeJumpDest));
152        opcode_table.insert(0x5c, Box::new(OpcodeTload));
153        opcode_table.insert(0x5d, Box::new(OpcodeTstore));
154        opcode_table.insert(0x5e, Box::new(OpcodeMcopy));
155        for sz in 0..=32 {
156            opcode_table.insert(0x5f + sz, Box::new(OpcodePush(sz)));
157        }
158        for sz in 0..16 {
159            opcode_table.insert(0x80 + sz, Box::new(OpcodeDup(sz)));
160        }
161        for sz in 0..16 {
162            opcode_table.insert(0x90 + sz, Box::new(OpcodeSwap(sz)));
163        }
164        for sz in 0..5 {
165            opcode_table.insert(0xa0 + sz, Box::new(OpcodeLog(sz)));
166        }
167        opcode_table.insert(0xf0, Box::new(OpcodeCreate));
168        opcode_table.insert(0xf1, Box::new(OpcodeCall::Call));
169        opcode_table.insert(0xf2, Box::new(OpcodeUnsupported(0xf2)));
170        opcode_table.insert(0xf3, Box::new(OpcodeReturn));
171        opcode_table.insert(0xf2, Box::new(OpcodeCall::DelegateCall));
172        opcode_table.insert(0xf2, Box::new(OpcodeCreate2));
173        opcode_table.insert(0xfa, Box::new(OpcodeCall::StaticCall));
174        opcode_table.insert(0xfd, Box::new(OpcodeRevert));
175        opcode_table.insert(0xff, Box::new(OpcodeSelfDestruct));
176
177        while self.pc < self.code.len() {
178            let opcode = self.code[self.pc];
179            if let Some(opcode_fn) = opcode_table.get(&opcode) {
180                if let Some(res) = opcode_fn.call(ctx, &mut self, call_info)? {
181                    return Ok(res);
182                }
183            } else {
184                return Err(RevertError::UnknownOpcode(opcode).into());
185            }
186        }
187        Ok(ExecutionResult::Halted)
188    }
189
190    pub fn mem_put(
191        &mut self,
192        target_offset: usize,
193        source: &[u8],
194        source_offset: usize,
195        len: usize,
196    ) {
197        if source_offset >= source.len() {
198            return;
199        }
200        let source_end = std::cmp::min(source_offset + len, source.len());
201        let src = &source[source_offset..source_end];
202        let expected_len = target_offset + src.len();
203        if expected_len > self.memory.len() {
204            self.memory.resize(expected_len, 0);
205        }
206        self.memory[target_offset..target_offset + src.len()].copy_from_slice(src);
207    }
208    pub fn mem_get(&mut self, offset: usize, size: usize) -> Vec<u8> {
209        let mut ret = vec![0u8; size];
210        if offset < self.memory.len() {
211            let sz = std::cmp::min(self.memory.len().saturating_sub(offset), size);
212            ret[..sz].copy_from_slice(&self.memory[offset..offset + sz]);
213        }
214        ret
215    }
216    pub fn pop_stack(&mut self) -> Result<W, ExecError> {
217        Ok(self
218            .stack
219            .pop()
220            .ok_or(RevertError::NotEnoughValuesOnStack)?)
221    }
222}
223
#[cfg(test)]
mod tests {
    use alloy_primitives::primitives::{Address, U256};

    use super::Machine;

    /// A machine with no code and empty memory, shared by the memory tests.
    fn empty_machine() -> Machine<U256> {
        Machine::<U256>::new(Address::ZERO, vec![])
    }

    #[test]
    fn test_mem_put() {
        let mut m = empty_machine();
        assert!(m.memory.is_empty());

        // Each step: (target_offset, source, source_offset, len, memory after the call).
        // Steps run in order against the same machine, so each builds on the previous one.
        let steps: [(usize, &[u8], usize, usize, &[u8]); 8] = [
            (2, &[1, 2, 3], 1, 10, &[0, 0, 2, 3]),
            (1, &[4, 5, 6], 5, 10, &[0, 0, 2, 3]),
            (1, &[4, 5, 6], 1, 1, &[0, 5, 2, 3]),
            (1, &[7, 8, 9], 1, 0, &[0, 5, 2, 3]),
            (3, &[7, 8, 9], 2, 10, &[0, 5, 2, 9]),
            (4, &[10, 11, 12, 13], 0, 2, &[0, 5, 2, 9, 10, 11]),
            (4, &[10, 11, 12, 13], 0, 100, &[0, 5, 2, 9, 10, 11, 12, 13]),
            (
                10,
                &[10, 11, 12, 13],
                2,
                100,
                &[0, 5, 2, 9, 10, 11, 12, 13, 0, 0, 12, 13],
            ),
        ];
        for (target_offset, source, source_offset, len, expected) in steps {
            m.mem_put(target_offset, source, source_offset, len);
            assert_eq!(m.memory, expected);
        }
    }

    #[test]
    fn test_mem_get() {
        let mut m = empty_machine();
        m.memory = vec![0, 10, 20, 30, 40, 50];

        assert_eq!(m.mem_get(1, 3), vec![10, 20, 30]);

        // Reading past the end of memory pads the result with zeros up to `size`.
        let mut padded: Vec<u8> = vec![0, 10, 20, 30, 40, 50];
        padded.resize(100, 0);
        assert_eq!(m.mem_get(0, 100), padded);

        assert_eq!(m.mem_get(100, 2), vec![0, 0]);
        assert_eq!(m.mem_get(5, 2), vec![50, 0]);
    }
}