use crate::errors::{BytecodeError, LinkError};
use serde::de::{Error as DeError, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashSet;
use std::fmt::{Formatter, Result as FmtResult};
use std::mem;
use web3::types::{Address, Bytes};
/// EVM bytecode stored as a hex string (without any `0x` prefix) that may
/// still contain unlinked library placeholders.
///
/// A placeholder is `__` followed by a library name, right-padded with `_`
/// to 40 characters in total — the same width as a hex-encoded address — so
/// linking can substitute the address in place.
#[derive(Clone, Debug, Default, Serialize)]
pub struct Bytecode(String);

impl Bytecode {
    /// Parses bytecode from a hex string, with or without a leading `0x`.
    ///
    /// Library placeholders (a `__` marker plus 38 further characters) are
    /// kept verbatim for later [`Bytecode::link`] calls; everything outside a
    /// placeholder must be a hex digit. An empty string yields empty bytecode.
    ///
    /// # Errors
    ///
    /// - [`BytecodeError::InvalidLength`] if the input length is odd.
    /// - [`BytecodeError::PlaceholderTooShort`] if a `__` marker is followed by
    ///   fewer than 40 characters (marker included).
    /// - [`BytecodeError::InvalidHexDigit`] with the first non-hex character
    ///   found outside a placeholder.
    pub fn from_hex_str(s: &str) -> Result<Self, BytecodeError> {
        if s.is_empty() {
            return Ok(Self::default());
        }
        // Checked before stripping `0x`; the prefix has even length, so the
        // parity of the remaining code is the same.
        if s.len() % 2 == 1 {
            return Err(BytecodeError::InvalidLength);
        }

        let code = match s.strip_prefix("0x") {
            Some(rest) => rest,
            None => s,
        };

        // Every segment between placeholders must be pure hex.
        CodeIter(code).try_for_each(|segment| {
            match segment?.chars().find(|c| !c.is_ascii_hexdigit()) {
                Some(bad) => Err(BytecodeError::InvalidHexDigit(bad)),
                None => Ok(()),
            }
        })?;

        Ok(Self(code.to_string()))
    }

    /// Replaces every placeholder for library `name` with the hex encoding of
    /// `address`.
    ///
    /// # Errors
    ///
    /// Returns [`LinkError::NotFound`] if the bytecode has no placeholder for
    /// `name`.
    ///
    /// # Panics
    ///
    /// Panics if `name` is longer than 38 bytes, as it cannot fit into a
    /// 40-character placeholder.
    pub fn link<S>(&mut self, name: S, address: Address) -> Result<(), LinkError>
    where
        S: AsRef<str>,
    {
        let name = name.as_ref();
        assert!(name.len() <= 38, "invalid library name for linking");

        let placeholder = format!("__{:_<38}", name);
        if self.0.contains(&placeholder) {
            self.0 = self.0.replace(&placeholder, &to_fixed_hex(&address));
            Ok(())
        } else {
            Err(LinkError::NotFound(name.to_string()))
        }
    }

    /// Decodes the (fully linked) bytecode into raw bytes.
    ///
    /// # Errors
    ///
    /// Returns [`LinkError::UndefinedLibrary`] naming the first library whose
    /// placeholder is still present.
    pub fn to_bytes(&self) -> Result<Bytes, LinkError> {
        if let Some(library) = self.undefined_libraries().next() {
            return Err(LinkError::UndefinedLibrary(library.to_string()));
        }
        let raw = hex::decode(&self.0).expect("valid hex");
        Ok(Bytes(raw))
    }

    /// Returns an iterator over the distinct names of libraries that still
    /// need linking, in order of first appearance.
    pub fn undefined_libraries(&self) -> LibIter<'_> {
        let seen = HashSet::new();
        LibIter {
            cursor: &self.0,
            seen,
        }
    }

    /// Returns `true` while at least one library placeholder remains.
    pub fn requires_linking(&self) -> bool {
        let mut pending = self.undefined_libraries();
        pending.next().is_some()
    }

    /// Returns `true` if there is no code at all.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}
/// Iterator over the hex segments of unprefixed bytecode, skipping library
/// placeholders.
///
/// At each `__` marker a 40-byte placeholder is skipped and the text before
/// it is yielded as one segment; the text after the last placeholder is the
/// final segment. After an error is yielded the iterator is exhausted, so it
/// can be drained without repeating the same error forever.
struct CodeIter<'a>(&'a str);

impl<'a> Iterator for CodeIter<'a> {
    type Item = Result<&'a str, BytecodeError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.0.is_empty() {
            return None;
        }

        let pos = match self.0.find("__") {
            Some(pos) => pos,
            // No placeholder left: the remainder is the last segment.
            None => return Some(Ok(mem::take(&mut self.0))),
        };

        let (block, tail) = self.0.split_at(pos);

        if tail.len() < 40 {
            // Fuse: stop iterating once an error has been reported.
            self.0 = "";
            return Some(Err(BytecodeError::PlaceholderTooShort));
        }

        if !tail.is_char_boundary(40) {
            // Byte 40 falls inside a multi-byte character, so the placeholder
            // cannot be exactly 40 bytes long and slicing there would panic.
            // Report the straddling character instead of aborting.
            let bad = tail
                .char_indices()
                .take_while(|&(i, _)| i < 40)
                .last()
                .map(|(_, c)| c)
                .expect("tail starts with `__`, so it has a char before byte 40");
            self.0 = "";
            return Some(Err(BytecodeError::InvalidHexDigit(bad)));
        }

        self.0 = &tail[40..];
        Some(Ok(block))
    }
}
/// Iterator over the distinct names of libraries whose placeholders are
/// still present in a piece of bytecode.
///
/// At each `__` marker a 40-byte placeholder is taken and the `_` padding
/// around the name is trimmed. Names are yielded once each, in order of first
/// appearance.
pub struct LibIter<'a> {
    /// Remaining bytecode that has not been scanned yet.
    cursor: &'a str,
    /// Names already yielded, used to skip duplicates.
    seen: HashSet<&'a str>,
}

impl<'a> Iterator for LibIter<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        while let Some(pos) = self.cursor.find("__") {
            let cursor = self.cursor;
            let rest = &cursor[pos..];

            // A well-formed placeholder is exactly 40 bytes. If fewer bytes
            // remain, or byte 40 is not a char boundary, treat the whole
            // remainder as the placeholder instead of panicking in `split_at`.
            let (placeholder, tail) = if rest.is_char_boundary(40) {
                rest.split_at(40)
            } else {
                (rest, "")
            };

            self.cursor = tail;
            let lib = placeholder.trim_matches('_');
            if self.seen.insert(lib) {
                return Some(lib);
            }
        }
        None
    }
}
impl<'de> Deserialize<'de> for Bytecode {
    /// Deserializes bytecode from a string, parsed with
    /// [`Bytecode::from_hex_str`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = BytecodeVisitor;
        deserializer.deserialize_str(visitor)
    }
}

/// Serde visitor that accepts string input and parses it as bytecode.
struct BytecodeVisitor;

impl<'de> Visitor<'de> for BytecodeVisitor {
    type Value = Bytecode;

    fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
        formatter.write_str("valid EVM bytecode string representation")
    }

    /// Parses `value` as hex bytecode and turns any parse failure into a
    /// custom deserializer error.
    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
    where
        E: DeError,
    {
        match Bytecode::from_hex_str(value) {
            Ok(code) => Ok(code),
            Err(err) => Err(E::custom(err)),
        }
    }
}
/// Formats `address` as zero-padded lowercase hex without a `0x` prefix;
/// `link` uses this text to replace a 40-character library placeholder.
fn to_fixed_hex(address: &Address) -> String {
    format!("{:0width$x}", address, width = 40)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Library placeholder for `name`: `__` plus the name, padded with `_` to
    /// 40 characters.
    fn placeholder_for(name: &str) -> String {
        format!("__{:_<38}", name)
    }

    /// Parses `0x61<ph><ph>61<ph>`, i.e. three copies of the placeholder.
    fn sample_code(ph: &str) -> Bytecode {
        Bytecode::from_hex_str(&format!("0x61{}{}61{}", ph, ph, ph)).unwrap()
    }

    #[test]
    fn default_bytecode_is_empty() {
        let code = Bytecode::default();
        assert!(code.is_empty());
    }

    #[test]
    fn empty_hex_bytecode_is_empty() {
        let code = Bytecode::from_hex_str("0x").unwrap();
        assert!(code.is_empty());
    }

    #[test]
    fn unprefixed_hex_bytecode_is_not_empty() {
        let code = Bytecode::from_hex_str("feedface").unwrap();
        assert!(!code.is_empty());
    }

    #[test]
    fn to_fixed_hex_() {
        let cases = [
            (
                "0x0000000000000000000000000000000000000000",
                "0000000000000000000000000000000000000000",
            ),
            (
                "0x0102030405060708091020304050607080900001",
                "0102030405060708091020304050607080900001",
            ),
            (
                "0x9fac3b52be975567103c4695a2835bba40076da1",
                "9fac3b52be975567103c4695a2835bba40076da1",
            ),
        ];
        for (value, expected) in cases.iter() {
            let address: Address = value.strip_prefix("0x").unwrap().parse().unwrap();
            assert_eq!(to_fixed_hex(&address), *expected);
        }
    }

    #[test]
    fn bytecode_link_success() {
        let name = "name";
        let mut bytecode = sample_code(&placeholder_for(name));
        bytecode.link(name, Address::zero()).unwrap();

        let zeros = [0u8; 20];
        let expected = [
            &[0x61u8][..],
            &zeros[..],
            &zeros[..],
            &[0x61u8][..],
            &zeros[..],
        ]
        .concat();
        assert_eq!(bytecode.to_bytes().unwrap().0, expected);
    }

    #[test]
    fn bytecode_link_fail() {
        let mut bytecode = sample_code(&placeholder_for("name0"));
        let outcome = bytecode.link("name1", Address::zero());
        assert!(
            matches!(outcome, Err(LinkError::NotFound(_))),
            "should fail with not found error"
        );
    }
}