use crate::delta_ext::{addition, deserialize, serialize, subtraction, DeltaOp};
use aptos_crypto::hash::DefaultHasher;
use aptos_types::vm_status::StatusCode;
use better_any::{Tid, TidAble};
use move_deps::{
move_binary_format::errors::{PartialVMError, PartialVMResult},
move_core_types::account_address::AccountAddress,
move_table_extension::{TableHandle, TableResolver},
move_vm_runtime::{
native_functions,
native_functions::{NativeContext, NativeFunctionTable},
},
move_vm_types::{
loaded_data::runtime_types::Type,
natives::function::NativeResult,
pop_arg,
values::{Reference, Struct, StructRef, Value},
},
};
use smallvec::smallvec;
use std::{
cell::RefCell,
collections::{BTreeMap, BTreeSet, VecDeque},
convert::TryInto,
sync::Arc,
};
/// Describes how the `value` stored in an aggregator must be interpreted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AggregatorState {
    /// The stored value is the actual value of the aggregator.
    Data,
    /// The stored value is a delta that still has to be added to the value
    /// kept in storage.
    PositiveDelta,
}
/// Uniquely identifies an aggregator: the table `handle` it is stored under
/// and its `key` within that table.
///
/// `Debug` is derived so ids can be printed in logs and assertion messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct AggregatorID {
    /// Handle of the table holding the aggregator value.
    pub handle: u128,
    /// Key of the aggregator inside the table.
    pub key: u128,
}

impl AggregatorID {
    /// Builds an id from a table handle and a key.
    fn new(handle: u128, key: u128) -> Self {
        AggregatorID { handle, key }
    }
}
/// In-memory representation of a single aggregator touched by the current
/// context.
struct Aggregator {
/// Either the actual value (`AggregatorState::Data`) or the delta
/// accumulated so far (`AggregatorState::PositiveDelta`).
value: u128,
/// Tells how `value` must be interpreted.
state: AggregatorState,
/// Passed to `addition` on every add; presumably the upper bound the
/// aggregator value may not exceed (NOTE(review): confirm in `delta_ext`).
limit: u128,
}
impl Aggregator {
    /// Adds `value` to the stored value (or to the stored delta).
    ///
    /// # Errors
    /// Propagates whatever error `addition` reports for the given `limit`.
    fn add(&mut self, value: u128) -> PartialVMResult<()> {
        self.value = addition(self.value, value, self.limit)?;
        Ok(())
    }

    /// Subtracts `value` from the aggregator.
    ///
    /// Callers must call [`Aggregator::materialize`] first so that the
    /// aggregator is in the `Data` state.
    ///
    /// # Errors
    /// Propagates whatever error `subtraction` reports.
    ///
    /// # Panics
    /// Panics if the aggregator is still in the `PositiveDelta` state.
    fn sub(&mut self, value: u128) -> PartialVMResult<()> {
        match self.state {
            AggregatorState::Data => {
                self.value = subtraction(self.value, value)?;
                Ok(())
            }
            AggregatorState::PositiveDelta => {
                unreachable!("subtraction always materializes the value")
            }
        }
    }

    /// Turns a delta into an actual value by reading the base value from the
    /// resolver and adding the accumulated delta to it.
    ///
    /// Does nothing when the aggregator is already in the `Data` state.
    ///
    /// # Errors
    /// Returns an extension error when the resolver fails or has no entry for
    /// `id`, and propagates any error from `addition`. On error the
    /// aggregator is left untouched.
    fn materialize(
        &mut self,
        context: &NativeAggregatorContext,
        id: &AggregatorID,
    ) -> PartialVMResult<()> {
        if self.state == AggregatorState::Data {
            return Ok(());
        }
        debug_assert!(self.state == AggregatorState::PositiveDelta);

        let key_bytes = serialize(&id.key);
        // The error is built lazily so the common, successful path does not
        // allocate an error message it never uses.
        let bytes = context
            .resolver
            .resolve_table_entry(&TableHandle(id.handle), &key_bytes)
            .map_err(|_| extension_error("could not find the value of the aggregator"))?
            .ok_or_else(|| extension_error("could not find the value of the aggregator"))?;

        let base = deserialize(&bytes);
        self.value = addition(base, self.value, self.limit)?;
        self.state = AggregatorState::Data;
        Ok(())
    }
}
/// Bookkeeping for every aggregator created, accessed or destroyed through
/// one `NativeAggregatorContext`.
#[derive(Default)]
struct AggregatorData {
    /// Ids created through `create_new_aggregator` and not removed since.
    new_aggregators: BTreeSet<AggregatorID>,
    /// Ids removed through `remove_aggregator` that were not created here.
    destroyed_aggregators: BTreeSet<AggregatorID>,
    /// Aggregators currently tracked, by id.
    aggregators: BTreeMap<AggregatorID, Aggregator>,
    /// Number of `create_new_aggregator` calls so far. Unlike
    /// `aggregators.len()`, it never decreases when an aggregator is removed.
    num_created: u128,
}

impl AggregatorData {
    /// Returns the aggregator stored under `id`.
    ///
    /// If it is not tracked yet, it is inserted as a zero delta
    /// (`PositiveDelta`) with the given `limit`.
    fn get_aggregator(&mut self, id: AggregatorID, limit: u128) -> &mut Aggregator {
        // `or_insert_with` already yields the entry, so no second lookup.
        self.aggregators.entry(id).or_insert_with(|| Aggregator {
            value: 0,
            state: AggregatorState::PositiveDelta,
            limit,
        })
    }

    /// Returns the number of aggregators created so far.
    ///
    /// This is a creation counter, not the size of the live set: the result
    /// is hashed into the key of each new aggregator, and a count that
    /// shrinks on removal would hand out the same key twice (create, create,
    /// destroy the first, create again).
    fn num_aggregators(&self) -> u128 {
        self.num_created
    }

    /// Registers a brand new aggregator with value zero in the `Data` state.
    fn create_new_aggregator(&mut self, id: AggregatorID, limit: u128) {
        let aggregator = Aggregator {
            value: 0,
            state: AggregatorState::Data,
            limit,
        };
        self.aggregators.insert(id, aggregator);
        self.new_aggregators.insert(id);
        self.num_created += 1;
    }

    /// Stops tracking the aggregator `id`.
    ///
    /// An aggregator created here simply disappears; any other id is
    /// recorded as destroyed.
    fn remove_aggregator(&mut self, id: AggregatorID) {
        self.aggregators.remove(&id);
        // `BTreeSet::remove` reports whether the id was present.
        if !self.new_aggregators.remove(&id) {
            self.destroyed_aggregators.insert(id);
        }
    }
}
/// A single change to an aggregator, as produced by
/// `NativeAggregatorContext::into_change_set`.
#[derive(Debug)]
pub enum AggregatorChange {
/// Emitted for aggregators in the `Data` state; carries their value.
Write(u128),
/// Emitted for aggregators in the `PositiveDelta` state; carries the delta
/// operation built from their value and limit.
Merge(DeltaOp),
/// Emitted for destroyed aggregators.
Delete,
}
/// All aggregator changes collected from a `NativeAggregatorContext`.
pub struct AggregatorChangeSet {
/// The change recorded for each aggregator id.
pub changes: BTreeMap<AggregatorID, AggregatorChange>,
}
/// Extension state read by the aggregator native functions through
/// `context.extensions().get::<NativeAggregatorContext>()`.
#[derive(Tid)]
pub struct NativeAggregatorContext<'a> {
// Hashed together with the aggregator count to derive keys of new aggregators.
txn_hash: u128,
// Used to read stored aggregator values when a delta is materialized.
resolver: &'a dyn TableResolver,
// Interior mutability: natives only get a shared reference to the context.
aggregator_data: RefCell<AggregatorData>,
}
impl<'a> NativeAggregatorContext<'a> {
pub fn new(txn_hash: u128, resolver: &'a dyn TableResolver) -> Self {
Self {
txn_hash,
resolver,
aggregator_data: Default::default(),
}
}
pub fn into_change_set(self) -> AggregatorChangeSet {
let NativeAggregatorContext {
aggregator_data, ..
} = self;
let AggregatorData {
destroyed_aggregators,
aggregators,
..
} = aggregator_data.into_inner();
let mut changes = BTreeMap::new();
for (id, aggregator) in aggregators {
let Aggregator {
value,
state,
limit,
} = aggregator;
let change = match state {
AggregatorState::Data => AggregatorChange::Write(value),
AggregatorState::PositiveDelta => {
let delta_op = DeltaOp::Addition { value, limit };
AggregatorChange::Merge(delta_op)
}
};
changes.insert(id, change);
}
for id in destroyed_aggregators {
changes.insert(id, AggregatorChange::Delete);
}
AggregatorChangeSet { changes }
}
}
/// Builds the table of aggregator native functions published under
/// `aggregator_addr`.
///
/// Each entry is `(module name, function name, implementation)`.
pub fn aggregator_natives(aggregator_addr: AccountAddress) -> NativeFunctionTable {
native_functions::make_table(
aggregator_addr,
&[
// Natives of the `aggregator` module.
("aggregator", "add", Arc::new(native_add)),
("aggregator", "read", Arc::new(native_read)),
("aggregator", "destroy", Arc::new(native_destroy)),
("aggregator", "sub", Arc::new(native_sub)),
// Native of the `aggregator_factory` module.
(
"aggregator_factory",
"new_aggregator",
Arc::new(native_new_aggregator),
),
],
)
}
fn native_new_aggregator(
context: &mut NativeContext,
_ty_args: Vec<Type>,
mut args: VecDeque<Value>,
) -> PartialVMResult<NativeResult> {
if !cfg!(any(test, feature = "aggregator-extension")) {
return Err(not_supported_error());
}
assert!(args.len() == 2);
let limit = pop_arg!(args, u128);
let handle = get_handle(&pop_arg!(args, StructRef))?;
let aggregator_context = context.extensions().get::<NativeAggregatorContext>();
let mut aggregator_data = aggregator_context.aggregator_data.borrow_mut();
let txn_hash_buffer = u128::to_be_bytes(aggregator_context.txn_hash);
let num_aggregators_buffer = u128::to_be_bytes(aggregator_data.num_aggregators());
let mut hasher = DefaultHasher::new(&[0_u8; 0]);
hasher.update(&txn_hash_buffer);
hasher.update(&num_aggregators_buffer);
let hash = hasher.finish();
let bytes = &hash.to_vec()[..16];
let key = u128::from_be_bytes(bytes.try_into().expect("not enough bytes"));
let id = AggregatorID::new(handle, key);
aggregator_data.create_new_aggregator(id, limit);
Ok(NativeResult::ok(
0,
smallvec![Value::struct_(Struct::pack(vec![
Value::u128(handle),
Value::u128(key),
Value::u128(limit),
]))],
))
}
/// Native implementation registered as `aggregator::add`.
///
/// Pops the value to add and a reference to the aggregator struct, then adds
/// the value to the tracked aggregator without materializing it.
///
/// # Errors
/// Returns the "not supported" error unless built for tests or with the
/// `aggregator-extension` feature; propagates decoding and addition errors.
///
/// # Panics
/// Panics when the number of arguments is not 2.
fn native_add(
    context: &mut NativeContext,
    _ty_args: Vec<Type>,
    mut args: VecDeque<Value>,
) -> PartialVMResult<NativeResult> {
    if !cfg!(any(test, feature = "aggregator-extension")) {
        return Err(not_supported_error());
    }
    assert!(args.len() == 2);

    // The value is popped first, then the aggregator reference.
    let delta = pop_arg!(args, u128);
    let target = pop_arg!(args, StructRef);
    let (handle, key, limit) = get_aggregator_fields(&target)?;
    let id = AggregatorID::new(handle, key);

    let ctx = context.extensions().get::<NativeAggregatorContext>();
    ctx.aggregator_data
        .borrow_mut()
        .get_aggregator(id, limit)
        .add(delta)?;

    Ok(NativeResult::ok(0, smallvec![]))
}
/// Native implementation registered as `aggregator::read`.
///
/// Pops a reference to the aggregator struct, materializes the tracked
/// aggregator and returns its value as a `u128`.
///
/// # Errors
/// Returns the "not supported" error unless built for tests or with the
/// `aggregator-extension` feature; propagates decoding and
/// materialization errors.
///
/// # Panics
/// Panics when the number of arguments is not 1.
fn native_read(
    context: &mut NativeContext,
    _ty_args: Vec<Type>,
    mut args: VecDeque<Value>,
) -> PartialVMResult<NativeResult> {
    if !cfg!(any(test, feature = "aggregator-extension")) {
        return Err(not_supported_error());
    }
    assert!(args.len() == 1);

    let target = pop_arg!(args, StructRef);
    let (handle, key, limit) = get_aggregator_fields(&target)?;
    let id = AggregatorID::new(handle, key);

    let ctx = context.extensions().get::<NativeAggregatorContext>();
    let mut data = ctx.aggregator_data.borrow_mut();
    let aggregator = data.get_aggregator(id, limit);

    // Reading needs the actual value, so a pending delta is resolved first.
    aggregator.materialize(ctx, &id)?;
    let current = aggregator.value;

    Ok(NativeResult::ok(0, smallvec![Value::u128(current)]))
}
/// Native implementation registered as `aggregator::sub`.
///
/// Pops the value to subtract and a reference to the aggregator struct,
/// materializes the tracked aggregator and subtracts the value from it.
///
/// # Errors
/// Returns the "not supported" error unless built for tests or with the
/// `aggregator-extension` feature; propagates decoding, materialization
/// and subtraction errors.
///
/// # Panics
/// Panics when the number of arguments is not 2.
fn native_sub(
    context: &mut NativeContext,
    _ty_args: Vec<Type>,
    mut args: VecDeque<Value>,
) -> PartialVMResult<NativeResult> {
    if !cfg!(any(test, feature = "aggregator-extension")) {
        return Err(not_supported_error());
    }
    assert!(args.len() == 2);

    // The value is popped first, then the aggregator reference.
    let amount = pop_arg!(args, u128);
    let target = pop_arg!(args, StructRef);
    let (handle, key, limit) = get_aggregator_fields(&target)?;
    let id = AggregatorID::new(handle, key);

    let ctx = context.extensions().get::<NativeAggregatorContext>();
    let mut data = ctx.aggregator_data.borrow_mut();
    let aggregator = data.get_aggregator(id, limit);

    // `sub` only works on actual values, so a pending delta is resolved first.
    aggregator.materialize(ctx, &id)?;
    aggregator.sub(amount)?;

    Ok(NativeResult::ok(0, smallvec![]))
}
/// Native implementation registered as `aggregator::destroy`.
///
/// Pops the aggregator struct by value, unpacks its fields and removes the
/// corresponding aggregator from the tracked data.
///
/// # Errors
/// Returns the "not supported" error unless built for tests or with the
/// `aggregator-extension` feature; propagates decoding errors.
///
/// # Panics
/// Panics when the number of arguments is not 1.
fn native_destroy(
    context: &mut NativeContext,
    _ty_args: Vec<Type>,
    mut args: VecDeque<Value>,
) -> PartialVMResult<NativeResult> {
    if !cfg!(any(test, feature = "aggregator-extension")) {
        return Err(not_supported_error());
    }
    assert!(args.len() == 1);

    let owned = pop_arg!(args, Struct);
    // The limit is not needed to identify the aggregator.
    let (handle, key, _) = unpack_aggregator_struct(owned)?;

    let ctx = context.extensions().get::<NativeAggregatorContext>();
    let mut data = ctx.aggregator_data.borrow_mut();
    data.remove_aggregator(AggregatorID::new(handle, key));

    Ok(NativeResult::ok(0, smallvec![]))
}
// Field positions used when native code navigates struct values.
// NOTE(review): presumably these mirror the field order of the matching Move
// structs — confirm against the Move sources.

// Field of the struct passed to `get_handle` that holds the inner struct.
const PHANTOM_TABLE_FIELD_INDEX: usize = 0;
// Field of that inner struct that holds the `u128` handle.
const TABLE_HANDLE_FIELD_INDEX: usize = 0;
// Positions of the `handle`, `key` and `limit` fields of the aggregator
// struct, in the same order `native_new_aggregator` packs them.
const HANDLE_FIELD_INDEX: usize = 0;
const KEY_FIELD_INDEX: usize = 1;
const LIMIT_FIELD_INDEX: usize = 2;
/// Reads the `u128` table handle nested inside `aggregator_table`.
///
/// The handle is found by borrowing the struct's first field as a struct and
/// then reading that struct's first field.
///
/// # Errors
/// Propagates any borrow, cast or read error from the VM values.
fn get_handle(aggregator_table: &StructRef) -> PartialVMResult<u128> {
    let inner_table = aggregator_table
        .borrow_field(PHANTOM_TABLE_FIELD_INDEX)?
        .value_as::<StructRef>()?;
    let handle_ref = inner_table
        .borrow_field(TABLE_HANDLE_FIELD_INDEX)?
        .value_as::<Reference>()?;
    let handle = handle_ref.read_ref()?.value_as::<u128>()?;
    Ok(handle)
}
/// Reads the field at position `index` of the referenced struct.
///
/// # Errors
/// Propagates any borrow, cast or read error from the VM values.
fn get_aggregator_field(aggregator: &StructRef, index: usize) -> PartialVMResult<Value> {
    aggregator
        .borrow_field(index)?
        .value_as::<Reference>()?
        .read_ref()
}
/// Reads the `(handle, key, limit)` triple from a referenced aggregator struct.
///
/// # Errors
/// Propagates the first error hit while reading or casting a field.
fn get_aggregator_fields(aggregator: &StructRef) -> PartialVMResult<(u128, u128, u128)> {
    let read_u128 = |index: usize| -> PartialVMResult<u128> {
        get_aggregator_field(aggregator, index)?.value_as::<u128>()
    };
    // Tuple elements are evaluated left to right: handle, key, then limit.
    Ok((
        read_u128(HANDLE_FIELD_INDEX)?,
        read_u128(KEY_FIELD_INDEX)?,
        read_u128(LIMIT_FIELD_INDEX)?,
    ))
}
/// Consumes an aggregator struct and returns its `(handle, key, limit)` fields.
///
/// # Errors
/// Propagates unpacking errors and casts that fail, and returns an
/// extension error if a field is missing.
///
/// # Panics
/// Panics when the struct does not have exactly three fields.
fn unpack_aggregator_struct(aggregator_struct: Struct) -> PartialVMResult<(u128, u128, u128)> {
    let mut fields: Vec<Value> = aggregator_struct.unpack()?.collect();
    assert!(fields.len() == 3);

    // The error is built lazily, so a successful pop allocates no message.
    let pop_with_err = |vec: &mut Vec<Value>, msg: &str| -> PartialVMResult<u128> {
        vec.pop()
            .ok_or_else(|| extension_error(msg))?
            .value_as::<u128>()
    };

    // Fields are popped from the back, i.e. in reverse declaration order.
    let limit = pop_with_err(&mut fields, "unable to pop 'limit' field")?;
    let key = pop_with_err(&mut fields, "unable to pop 'key' field")?;
    let handle = pop_with_err(&mut fields, "unable to pop 'handle' field")?;
    Ok((handle, key, limit))
}
/// Builds a `VM_EXTENSION_ERROR` carrying `message`.
fn extension_error(message: impl ToString) -> PartialVMError {
    let text = message.to_string();
    PartialVMError::new(StatusCode::VM_EXTENSION_ERROR).with_message(text)
}
/// Sub-status attached to the abort raised when the feature is disabled.
/// NOTE(review): the value looks like a category/reason encoding — confirm
/// against the Move error conventions.
const ENOT_SUPPORTED: u64 = 0x0C_0003;

/// Builds the `ABORTED` error returned by every native when the aggregator
/// extension is not enabled.
fn not_supported_error() -> PartialVMError {
    let message = String::from("this experimental feature is not supported");
    PartialVMError::new(StatusCode::ABORTED)
        .with_message(message)
        .with_sub_status(ENOT_SUPPORTED)
}
// Unit tests for the aggregator bookkeeping, backed by a fake in-memory
// storage.
#[cfg(test)]
mod test {
use super::*;
use aptos_state_view::StateView;
use aptos_types::state_store::{state_key::StateKey, table::TableHandle as AptosTableHandle};
use claim::{assert_err, assert_matches, assert_ok};
use move_deps::move_table_extension::TableOperation;
use once_cell::sync::Lazy;
use std::collections::HashMap;
/// In-memory storage used as the resolver in the tests.
#[derive(Default)]
pub struct FakeTestStorage {
data: HashMap<StateKey, Vec<u8>>,
}
impl FakeTestStorage {
/// Creates storage that already holds aggregator 4 (value 900) and
/// aggregator 5 (value 5).
fn new() -> Self {
let mut data = HashMap::new();
data.insert(id_to_state_key(test_id(4)), serialize(&900));
data.insert(id_to_state_key(test_id(5)), serialize(&5));
FakeTestStorage { data }
}
}
impl StateView for FakeTestStorage {
fn get_state_value(&self, state_key: &StateKey) -> anyhow::Result<Option<Vec<u8>>> {
Ok(self.data.get(state_key).cloned())
}
fn is_genesis(&self) -> bool {
self.data.is_empty()
}
}
impl TableResolver for FakeTestStorage {
// Table entries are looked up through the `StateView` implementation.
fn resolve_table_entry(
&self,
handle: &TableHandle,
key: &[u8],
) -> Result<Option<Vec<u8>>, anyhow::Error> {
let state_key = StateKey::table_item(AptosTableHandle::from(*handle), key.to_vec());
self.get_state_value(&state_key)
}
// Constant cost; the tests do not depend on it.
fn operation_cost(&self, _op: TableOperation, _key_size: usize, _val_size: usize) -> u64 {
1
}
}
/// Builds an aggregator id with handle 0 and the given key.
fn test_id(key: u128) -> AggregatorID {
AggregatorID::new(0, key)
}
/// Maps an aggregator id to the state key its value is stored under.
fn id_to_state_key(id: AggregatorID) -> StateKey {
let key_bytes = serialize(&id.key);
StateKey::table_item(AptosTableHandle(id.handle), key_bytes)
}
/// Populates the context with a mix of new, accessed and removed
/// aggregators.
fn test_set_up(context: &NativeAggregatorContext) {
let mut aggregator_data = context.aggregator_data.borrow_mut();
// Aggregators 0, 1 and 2 are created by this context.
aggregator_data.create_new_aggregator(test_id(0), 1000);
aggregator_data.create_new_aggregator(test_id(1), 1000);
aggregator_data.create_new_aggregator(test_id(2), 1000);
// Aggregators 3, 4 and 5 are only accessed, so they start as zero deltas.
aggregator_data.get_aggregator(test_id(3), 1000);
aggregator_data.get_aggregator(test_id(4), 1000);
aggregator_data.get_aggregator(test_id(5), 10);
// 0 was created here, 3 was only accessed, 6 was never touched.
aggregator_data.remove_aggregator(test_id(0));
aggregator_data.remove_aggregator(test_id(3));
aggregator_data.remove_aggregator(test_id(6));
}
// Shared storage instance with a 'static lifetime for the contexts below.
#[allow(clippy::redundant_closure)]
static TEST_RESOLVER: Lazy<FakeTestStorage> = Lazy::new(|| FakeTestStorage::new());
/// Checks which change is produced for each kind of aggregator.
#[test]
fn test_into_change_set() {
let context = NativeAggregatorContext::new(0, &*TEST_RESOLVER);
test_set_up(&context);
let AggregatorChangeSet { changes } = context.into_change_set();
// Created and removed by the same context: no change at all.
assert!(!changes.contains_key(&test_id(0)));
// Created by this context: written with the initial value.
assert_matches!(
changes.get(&test_id(1)).unwrap(),
AggregatorChange::Write(0)
);
assert_matches!(
changes.get(&test_id(2)).unwrap(),
AggregatorChange::Write(0)
);
// Accessed and then removed: deleted.
assert_matches!(changes.get(&test_id(3)).unwrap(), AggregatorChange::Delete);
// Only accessed: merged as a zero delta carrying the limit.
assert_matches!(
changes.get(&test_id(4)).unwrap(),
AggregatorChange::Merge(DeltaOp::Addition {
value: 0,
limit: 1000
})
);
assert_matches!(
changes.get(&test_id(5)).unwrap(),
AggregatorChange::Merge(DeltaOp::Addition {
value: 0,
limit: 10
})
);
// Removed without ever being accessed: deleted.
assert_matches!(changes.get(&test_id(6)).unwrap(), AggregatorChange::Delete);
}
/// Exercises add, materialize and sub directly on tracked aggregators.
#[test]
fn test_aggregator_natives() {
let context = NativeAggregatorContext::new(0, &*TEST_RESOLVER);
test_set_up(&context);
let mut aggregator_data = context.aggregator_data.borrow_mut();
// Aggregator 1 holds an actual value; adding past the limit of 1000 fails.
let aggregator = aggregator_data.get_aggregator(test_id(1), 1000);
assert_matches!(aggregator.state, AggregatorState::Data);
assert_eq!(aggregator.value, 0);
assert_ok!(aggregator.add(100));
assert_ok!(aggregator.add(900));
assert_matches!(aggregator.state, AggregatorState::Data);
assert_eq!(aggregator.value, 1000);
assert_err!(aggregator.add(1));
// Aggregator 4 is a delta over the stored 900; materializing a delta of
// 200 fails, presumably because 900 + 200 exceeds the limit of 1000.
let aggregator = aggregator_data.get_aggregator(test_id(4), 1000);
assert_matches!(aggregator.state, AggregatorState::PositiveDelta);
assert_eq!(aggregator.value, 0);
assert_ok!(aggregator.add(100));
assert_ok!(aggregator.add(100));
assert_matches!(aggregator.state, AggregatorState::PositiveDelta);
assert_eq!(aggregator.value, 200);
assert_err!(aggregator.materialize(&context, &test_id(4)));
// Aggregator 5 is a delta over the stored 5; materializing gives 5 + 2 = 7,
// after which subtraction is allowed.
let aggregator = aggregator_data.get_aggregator(test_id(5), 10);
assert_matches!(aggregator.state, AggregatorState::PositiveDelta);
assert_eq!(aggregator.value, 0);
assert_ok!(aggregator.add(2));
assert_matches!(aggregator.state, AggregatorState::PositiveDelta);
assert_eq!(aggregator.value, 2);
assert_ok!(aggregator.materialize(&context, &test_id(5)));
assert_matches!(aggregator.state, AggregatorState::Data);
assert_eq!(aggregator.value, 7);
assert_ok!(aggregator.sub(7));
assert_matches!(aggregator.state, AggregatorState::Data);
assert_eq!(aggregator.value, 0);
}
}