use super::Error;
use derive_more::From;
use std::collections::{
hash_map::Entry,
HashMap,
};
/// Registry and dispatcher for chain extensions.
///
/// Extensions are stored by [`ExtensionId`] and invoked through
/// `ChainExtensionHandler::eval`, which writes each call's result into a
/// reusable output buffer owned by the handler.
pub struct ChainExtensionHandler {
    /// Registered extensions, keyed by the id each one reports via
    /// [`ChainExtension::ext_id`].
    registered: HashMap<ExtensionId, Box<dyn ChainExtension>>,
    /// Buffer handed to [`ChainExtension::call`]; cleared before every dispatch
    /// and kept around so its allocation is reused across calls.
    output: Vec<u8>,
}
/// Identifier of a registered chain extension.
///
/// Obtained from [`ChainExtension::ext_id`] when registering, and from the
/// upper 16 bits of the `id` passed to `ChainExtensionHandler::eval` when
/// dispatching.
#[derive(Debug, Clone, Copy, From, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(scale::Encode, scale::Decode)]
pub struct ExtensionId(u16);
/// A chain extension that can be registered with a [`ChainExtensionHandler`].
pub trait ChainExtension {
    /// Returns the 16-bit identifier this extension is registered under.
    ///
    /// `ChainExtensionHandler::eval` routes a call to this extension when the
    /// upper 16 bits of its `id` argument equal this value.
    fn ext_id(&self) -> u16;
    /// Handles a call to function `func_id` of this extension.
    ///
    /// When invoked through `ChainExtensionHandler::eval`, `func_id` is the lower
    /// 16 bits of the `id` passed to `eval`, `input` is forwarded unchanged, and
    /// `output` is empty on entry. Whatever the extension writes into `output` is
    /// returned to the caller of `eval` together with the returned `u32` status
    /// code. NOTE(review): the meaning of particular status code values is not
    /// defined here; presumably it is up to each implementor.
    // `output` is deliberately `&mut Vec<u8>` rather than `&mut [u8]`: the handler
    // passes an empty, cleared buffer, so the extension must be able to grow it.
    #[allow(clippy::ptr_arg)]
    fn call(&mut self, func_id: u16, input: &[u8], output: &mut Vec<u8>) -> u32;
}
impl Default for ChainExtensionHandler {
fn default() -> Self {
ChainExtensionHandler::new()
}
}
impl ChainExtensionHandler {
pub fn new() -> Self {
Self {
registered: HashMap::new(),
output: Vec::new(),
}
}
pub fn reset(&mut self) {
self.registered.clear();
self.output.clear();
}
pub fn register(&mut self, extension: Box<dyn ChainExtension>) {
let ext_id = extension.ext_id();
self.registered.insert(ExtensionId::from(ext_id), extension);
}
pub fn eval(&mut self, id: u32, input: &[u8]) -> Result<(u32, &[u8]), Error> {
self.output.clear();
let func_id = (id & 0x0000FFFF) as u16;
let ext_id = (id >> 16) as u16;
let extension_id = ExtensionId::from(ext_id);
match self.registered.entry(extension_id) {
Entry::Occupied(occupied) => {
let status_code =
occupied.into_mut().call(func_id, input, &mut self.output);
Ok((status_code, &mut self.output))
}
Entry::Vacant(_vacant) => Err(Error::UnregisteredChainExtension),
}
}
}