use core::marker::PhantomData;
use crate::descriptor::{self, Descriptor, DescriptorValidator};
use crate::parser::keys::KeyToken;
use crate::parser::{Fragment, NodeIndex};
use crate::script::{AddressBuilderError, ScriptBuilderError};
use crate::type_checker::CorrectnessPropertiesVisitor;
use crate::{Vec, parser::AST};
use crate::{limits, parser, type_checker};
use alloc::string::String;
use bitcoin::{Address, Network, ScriptBuf};
/// A visitor over the AST stored in a [`Context`], producing a `T` per visited node.
///
/// Implementors only provide [`ASTVisitor::visit_ast`]; the other methods are
/// conveniences that resolve a node (by index, or the root) and forward to it.
pub(crate) trait ASTVisitor<T> {
    /// Error returned when visiting a node fails.
    type Error;

    /// Visits `node`, which belongs to `ctx`.
    fn visit_ast(&mut self, ctx: &Context, node: &AST) -> Result<T, Self::Error>;

    /// Visits the node stored at `index` in `ctx`'s node list.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds for `ctx`'s node list.
    #[inline]
    fn visit_ast_by_index(&mut self, ctx: &Context, index: NodeIndex) -> Result<T, Self::Error> {
        // Go through the accessor rather than indexing the field directly, so lookup
        // semantics live in one place (`Context::get_node`).
        self.visit_ast(ctx, ctx.get_node(index))
    }

    /// Visits `ctx` starting from its root node.
    #[inline]
    fn visit(&mut self, ctx: &Context) -> Result<T, Self::Error> {
        // `get_root` already returns `&AST`; no extra borrow needed.
        self.visit_ast(ctx, ctx.get_root())
    }
}
/// A parsed descriptor: its AST plus the descriptor kinds it was parsed with.
///
/// Build a validated instance with `Context::try_from(&str)`.
pub struct Context {
    /// Every AST node of the parsed tree; a `NodeIndex` is a position in this vector.
    nodes: Vec<AST>,
    /// The root node of the tree; `ASTVisitor::visit` starts traversal here.
    root: AST,
    /// The outermost descriptor kind; `is_wrapped` checks it against `Descriptor::Sh`.
    top_level_descriptor: Descriptor,
    /// The inner descriptor kind, returned by `descriptor()` and used for the
    /// script-size limit check during `try_from`.
    inner_descriptor: Descriptor,
}
impl Context {
    /// Assembles a context from already-parsed parts.
    ///
    /// This only stores its arguments; no validation is performed. Use
    /// `Context::try_from(&str)` for the checked entry point.
    pub(crate) fn new(
        nodes: Vec<AST>,
        root: AST,
        top_level_descriptor: Descriptor,
        inner_descriptor: Descriptor,
    ) -> Self {
        Self {
            nodes,
            root,
            top_level_descriptor,
            inner_descriptor,
        }
    }

    /// Returns all AST nodes; a `NodeIndex` is a position in this slice.
    pub fn get_nodes(&self) -> &[AST] {
        &self.nodes[..]
    }

    /// Returns the root node of the tree.
    pub fn get_root(&self) -> &AST {
        &self.root
    }

    /// Returns a clone of the outermost descriptor kind.
    pub fn top_level_descriptor(&self) -> Descriptor {
        self.top_level_descriptor.clone()
    }

    /// Returns a clone of the inner descriptor kind.
    pub fn descriptor(&self) -> Descriptor {
        self.inner_descriptor.clone()
    }

    /// Returns `true` when the top-level descriptor is `Descriptor::Sh`.
    pub fn is_wrapped(&self) -> bool {
        self.top_level_descriptor == Descriptor::Sh
    }

    /// Returns the node stored at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds for the node list.
    pub fn get_node(&self, index: NodeIndex) -> &AST {
        &self.nodes[index as usize]
    }

    /// Computes satisfactions for the tree, starting at the root, using `satisfier`.
    #[cfg(feature = "satisfy")]
    pub fn satisfy(
        &self,
        satisfier: &dyn crate::satisfy::Satisfier,
    ) -> Result<crate::satisfy::Satisfactions, crate::satisfy::SatisfyError> {
        // `get_root` already returns `&AST`; no extra borrow needed.
        crate::satisfy::satisfy(self, satisfier, self.get_root())
    }

    /// Calls `callback` with a mutable reference to every key in the tree.
    ///
    /// Keys are taken from `PkK`, `PkH`, `Multi`, `MultiA`, `RawPkH`, `RawTr` and
    /// `RawPk` fragments, visiting nodes in storage order and multi-key lists in order.
    pub fn iterate_keys_mut(&mut self, mut callback: impl FnMut(&mut KeyToken)) {
        for node in &mut self.nodes {
            match &mut node.fragment {
                Fragment::PkK { key } => callback(key),
                Fragment::PkH { key } => callback(key),
                Fragment::Multi { keys, .. } => {
                    for key in keys.iter_mut() {
                        callback(key);
                    }
                }
                Fragment::MultiA { keys, .. } => {
                    for key in keys.iter_mut() {
                        callback(key);
                    }
                }
                Fragment::RawPkH { key } => callback(key),
                Fragment::RawTr { key, .. } => callback(key),
                Fragment::RawPk { key } => callback(key),
                _ => (),
            }
        }
    }

    /// Calls `callback` with a shared reference to every key in the tree.
    ///
    /// Visits the same fragments, in the same order, as [`Context::iterate_keys_mut`].
    pub fn iterate_keys(&self, mut callback: impl FnMut(&KeyToken)) {
        for node in &self.nodes {
            match &node.fragment {
                Fragment::PkK { key } => callback(key),
                Fragment::PkH { key } => callback(key),
                Fragment::Multi { keys, .. } => {
                    for key in keys.iter() {
                        callback(key);
                    }
                }
                Fragment::MultiA { keys, .. } => {
                    for key in keys.iter() {
                        callback(key);
                    }
                }
                Fragment::RawPkH { key } => callback(key),
                Fragment::RawTr { key, .. } => callback(key),
                Fragment::RawPk { key } => callback(key),
                _ => (),
            }
        }
    }

    /// Replaces every key in the tree with `key.derive(index)`.
    ///
    /// Covers the same fragments as [`Context::iterate_keys_mut`], including the
    /// `RawTr` and `RawPk` keys, so no key position is left underived.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a key's `derive`. Keys processed before
    /// the failing one have already been replaced in place.
    pub fn derive(&mut self, index: u32) -> Result<(), String> {
        // Replaces a single key with its derived form.
        fn derive_key(key: &mut KeyToken, index: u32) -> Result<(), String> {
            let derived = key.derive(index)?;
            *key = derived;
            Ok(())
        }

        for node in &mut self.nodes {
            match &mut node.fragment {
                Fragment::PkK { key } | Fragment::PkH { key } | Fragment::RawPkH { key } => {
                    derive_key(key, index)?;
                }
                // Previously skipped, leaving ranged keys in these positions underived
                // even though `iterate_keys*` treats them as keys.
                Fragment::RawTr { key, .. } => derive_key(key, index)?,
                Fragment::RawPk { key } => derive_key(key, index)?,
                Fragment::Multi { keys, .. } => {
                    for key in keys.iter_mut() {
                        derive_key(key, index)?;
                    }
                }
                Fragment::MultiA { keys, .. } => {
                    for key in keys.iter_mut() {
                        derive_key(key, index)?;
                    }
                }
                _ => (),
            }
        }
        Ok(())
    }

    /// Serializes this context into a `String` using the crate's `Serializer`.
    pub fn serialize(&self) -> String {
        let mut serializer = crate::utils::serialize::Serializer::new();
        serializer.serialize(self)
    }

    /// Builds the script for this context; delegates to `crate::script::build_script`.
    pub fn build_script<'a>(&self) -> Result<ScriptBuf, ScriptBuilderError<'a>> {
        crate::script::build_script(self)
    }

    /// Builds an address for `network`; delegates to `crate::script::build_address`.
    pub fn build_address<'a>(&self, network: Network) -> Result<Address, AddressBuilderError<'a>> {
        crate::script::build_address(self, network)
    }
}
/// Errors returned by `Context::try_from(&str)`, one variant per validation stage.
///
/// The lifetime `'a` is that of the input string, carried by the parser error.
#[cfg_attr(feature = "debug", derive(Debug))]
pub enum ContextError<'a> {
    /// The input string could not be parsed.
    ParserError(parser::ParseError<'a>),
    /// The parsed tree failed the correctness-properties type check.
    TypeCheckerError(type_checker::CorrectnessPropertiesVisitorError),
    /// `DescriptorValidator` rejected the parsed tree.
    DescriptorVisitorError(descriptor::DescriptorVisitorError),
    /// The recursion-depth or script-size limit check failed.
    LimitsError(limits::LimitsError),
}
impl<'a> TryFrom<&'a str> for Context {
    type Error = ContextError<'a>;

    /// Parses `value` and runs every validation pass over the resulting context.
    ///
    /// The stages run in this order and stop at the first failure:
    /// 1. parsing ([`ContextError::ParserError`]),
    /// 2. type checking ([`ContextError::TypeCheckerError`]),
    /// 3. descriptor validation ([`ContextError::DescriptorVisitorError`]),
    /// 4. recursion-depth, then script-size limits ([`ContextError::LimitsError`]).
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        let context = match parser::parse(value) {
            Ok(parsed) => parsed,
            Err(err) => return Err(ContextError::ParserError(err)),
        };

        // The type checker yields the tree metrics that the limit checks consume.
        let mut type_checker = CorrectnessPropertiesVisitor::new();
        let properties = type_checker
            .visit(&context)
            .map_err(ContextError::TypeCheckerError)?;

        let () = DescriptorValidator::new()
            .validate(&context)
            .map_err(ContextError::DescriptorVisitorError)?;

        // The script-size check only runs if the recursion-depth check passed.
        limits::check_recursion_depth(properties.tree_height)
            .and_then(|_| limits::check_script_size(&context.descriptor(), properties.pk_cost))
            .map_err(ContextError::LimitsError)?;

        Ok(context)
    }
}