use crate::{
traits::{
AsTransactionAuthorizedOrigin, DecodeWithVersion, DecodeWithVersionWithMemTracking,
DispatchInfoOf, DispatchOriginOf, DispatchTransaction, Dispatchable, Pipeline,
PipelineMetadataBuilder, PipelineVersion, PostDispatchInfoOf, TransactionExtension,
},
transaction_validity::{TransactionSource, TransactionValidityError, ValidTransaction},
};
use codec::{Decode, DecodeWithMemTracking, Encode};
use core::fmt::Debug;
use scale_info::TypeInfo;
use sp_weights::Weight;
/// A single transaction extension bound to the extension version `VERSION`.
///
/// Only the wrapped `extension` is stored, so the encoding produced by the
/// `Encode` derive is exactly the extension's own encoding. The version is
/// carried in the type, not in the encoded bytes.
#[derive(Clone, PartialEq, Eq, Debug, Encode, TypeInfo)]
pub struct PipelineAtVers<const VERSION: u8, Extension> {
    /// The wrapped transaction extension.
    pub extension: Extension,
}
impl<const VERSION: u8, Extension> PipelineAtVers<VERSION, Extension> {
    /// Wraps `extension` so it is handled under extension version `VERSION`.
    ///
    /// This is a `const fn`, so it can also build values in constant or
    /// static contexts. The extension is only moved into the wrapper; no
    /// bounds are required.
    pub const fn new(extension: Extension) -> Self {
        Self { extension }
    }
}
impl<const VERSION: u8, Extension: Decode> DecodeWithVersion
for PipelineAtVers<VERSION, Extension>
{
fn decode_with_version<I: codec::Input>(
extension_version: u8,
input: &mut I,
) -> Result<Self, codec::Error> {
if extension_version == VERSION {
Ok(PipelineAtVers { extension: Extension::decode(input)? })
} else {
Err(codec::Error::from("Invalid extension version"))
}
}
}
// Marker impl: the wrapper qualifies for memory-tracked versioned decoding
// whenever the wrapped extension supports `DecodeWithMemTracking`.
impl<const VERSION: u8, Extension> DecodeWithVersionWithMemTracking
    for PipelineAtVers<VERSION, Extension>
where
    Extension: DecodeWithMemTracking,
{
}
impl<const VERSION: u8, Ext> PipelineVersion for PipelineAtVers<VERSION, Ext> {
    /// Returns the compile-time `VERSION` this wrapper was instantiated with.
    fn version(&self) -> u8 {
        VERSION
    }
}
impl<const VERSION: u8, Call, Ext> Pipeline<Call> for PipelineAtVers<VERSION, Ext>
where
    Call: Dispatchable<RuntimeOrigin: AsTransactionAuthorizedOrigin> + Encode,
    Ext: TransactionExtension<Call>,
{
    /// Registers the wrapped extension's metadata with `builder` under `VERSION`.
    fn build_metadata(builder: &mut PipelineMetadataBuilder) {
        let extension_metadata = Ext::metadata();
        builder.push_versioned_extension(VERSION, extension_metadata);
    }

    /// Validates `call` through the wrapped extension, passing `VERSION` as the
    /// extension version.
    ///
    /// Only the `ValidTransaction` (the first tuple element of the inner
    /// result) is returned. The remaining elements are discarded.
    fn validate_only(
        &self,
        origin: DispatchOriginOf<Call>,
        call: &Call,
        info: &DispatchInfoOf<Call>,
        len: usize,
        source: TransactionSource,
    ) -> Result<ValidTransaction, TransactionValidityError> {
        self.extension
            .validate_only(origin, call, info, len, source, VERSION)
            .map(|(valid, ..)| valid)
    }

    /// Validates, prepares and dispatches `call` through the wrapped extension,
    /// passing `VERSION` as the extension version.
    fn dispatch_transaction(
        self,
        origin: DispatchOriginOf<Call>,
        call: Call,
        info: &DispatchInfoOf<Call>,
        len: usize,
    ) -> crate::ApplyExtrinsicResultWithInfo<PostDispatchInfoOf<Call>> {
        let Self { extension } = self;
        extension.dispatch_transaction(origin, call, info, len, VERSION)
    }

    /// Returns the weight the wrapped extension reports for `call`.
    fn weight(&self, call: &Call) -> Weight {
        self.extension.weight(call)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        traits::{
            Dispatchable, Implication, TransactionExtension, TransactionSource, ValidateResult,
        },
        transaction_validity::{InvalidTransaction, TransactionValidityError, ValidTransaction},
        DispatchError,
    };
    use codec::{Decode, DecodeWithMemTracking, Encode};
    use sp_weights::Weight;

    /// Mock call; dispatching it fails when the origin id is `0`.
    #[derive(Clone, Debug, Encode, Decode, PartialEq, Eq)]
    pub struct MockCall(pub u64);

    /// Mock origin identified by a plain integer; always reported as authorized.
    #[derive(Debug)]
    pub struct MockOrigin(pub u64);

    impl AsTransactionAuthorizedOrigin for MockOrigin {
        fn is_transaction_authorized(&self) -> bool {
            true
        }
    }

    impl Dispatchable for MockCall {
        type RuntimeOrigin = MockOrigin;
        type Config = ();
        type Info = ();
        type PostInfo = ();

        fn dispatch(
            self,
            origin: Self::RuntimeOrigin,
        ) -> crate::DispatchResultWithInfo<Self::PostInfo> {
            if origin.0 == 0 {
                return Err(DispatchError::Other("origin is 0").into());
            }
            Ok(Default::default())
        }
    }

    /// Test extension: rejected with `Custom(1)` when `token == 0`, and
    /// reports `w` as its ref-time weight.
    #[derive(PartialEq, Eq, Clone, Debug, Encode, Decode, DecodeWithMemTracking, TypeInfo)]
    pub struct SimpleExtension {
        pub token: u32,
        pub w: u64,
    }

    impl TransactionExtension<MockCall> for SimpleExtension {
        const IDENTIFIER: &'static str = "SimpleExtension";
        type Implicit = ();

        fn implicit(&self) -> Result<Self::Implicit, TransactionValidityError> {
            Ok(())
        }

        type Val = ();
        type Pre = ();

        fn weight(&self, _call: &MockCall) -> Weight {
            Weight::from_parts(self.w, 0)
        }

        fn validate(
            &self,
            origin: MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
            _self_implicit: Self::Implicit,
            _inherited_implication: &impl Implication,
            _source: TransactionSource,
        ) -> ValidateResult<Self::Val, MockCall> {
            if self.token == 0 {
                Err(InvalidTransaction::Custom(1).into())
            } else {
                Ok((ValidTransaction::default(), (), origin))
            }
        }

        fn prepare(
            self,
            _val: Self::Val,
            _origin: &MockOrigin,
            _call: &MockCall,
            _info: &DispatchInfoOf<MockCall>,
            _len: usize,
        ) -> Result<Self::Pre, TransactionValidityError> {
            Ok(())
        }
    }

    pub type ExtV3 = PipelineAtVers<3, SimpleExtension>;
    pub type ExtV10 = PipelineAtVers<10, SimpleExtension>;

    #[test]
    fn new_wraps_extension() {
        let extension = SimpleExtension { token: 7, w: 3 };
        let wrapped = ExtV3::new(extension.clone());
        assert_eq!(wrapped, ExtV3 { extension });
    }

    #[test]
    fn decode_with_correct_version_succeeds() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
        let encoded = ext_v3.encode();
        let mut input = &encoded[..];
        let decoded = <ExtV3 as DecodeWithVersion>::decode_with_version(3, &mut input)
            .expect("should decode fine with matching version");
        // The whole encoding must round-trip and be fully consumed.
        assert_eq!(decoded, ext_v3);
        assert!(input.is_empty());
    }

    #[test]
    fn decode_with_incorrect_version_fails() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
        let encoded = ext_v3.encode();
        let decode_err = <ExtV3 as DecodeWithVersion>::decode_with_version(10, &mut &encoded[..])
            .expect_err("should fail decode due to invalid version");
        assert!(decode_err.to_string().contains("Invalid extension version"));
    }

    #[test]
    fn version_is_correct() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 55, w: 1234 } };
        assert_eq!(ext_v3.version(), 3);
        let ext_v10 = ExtV10 { extension: SimpleExtension { token: 1, w: 1 } };
        assert_eq!(ext_v10.version(), 10);
    }

    #[test]
    fn pipeline_functions_work() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 999, w: 50 } };
        let call = MockCall(0xf00);
        assert_eq!(ext_v3.weight(&call).ref_time(), 50);

        // A zero token must be rejected by the wrapped extension.
        let invalid_ext_v3 = ExtV3 { extension: SimpleExtension { token: 0, w: 10 } };
        let validity = invalid_ext_v3.validate_only(
            MockOrigin(1),
            &call,
            &Default::default(),
            0,
            TransactionSource::External,
        );
        assert_eq!(validity, Err(TransactionValidityError::Invalid(InvalidTransaction::Custom(1))));

        let valid = ext_v3
            .validate_only(MockOrigin(2), &call, &Default::default(), 0, TransactionSource::Local)
            .expect("non-zero token should validate");
        assert_eq!(valid, ValidTransaction::default());
    }

    #[test]
    fn dispatch_transaction_works() {
        let ext_v3 = ExtV3 { extension: SimpleExtension { token: 1, w: 10 } };
        let call = MockCall(123);
        let info = Default::default();
        let len = 0usize;

        ext_v3
            .clone()
            .dispatch_transaction(MockOrigin(1), call.clone(), &info, len)
            .expect("valid dispatch")
            .expect("should be OK");

        let res_fail = ext_v3.dispatch_transaction(MockOrigin(0), call, &info, len);
        let block_err = res_fail.expect("valid").expect_err("should fail");
        assert_eq!(block_err.error, DispatchError::Other("origin is 0"));
    }
}