use std::any::Any;
use std::collections::BTreeMap;
use std::io::{Read, Write};
use amplify::DumbDefault;
use strict_encoding::{StrictDecode, StrictEncode};
use wallet::psbt::Psbt;
use super::tx_graph::TxGraph;
use super::Funding;
use crate::channel::FundingError;
use crate::{extension, ChannelConstructor, ChannelExtension, Extension};
/// Channel-level nomenclature.
///
/// Extends [`extension::Nomenclature`] with the constructor type of a
/// [`Channel`] and with the extender/modifier sets installed when a channel is
/// built through its `Default` implementation.
pub trait Nomenclature: extension::Nomenclature
where
    <Self as extension::Nomenclature>::State: State,
{
    /// Constructor extension that every [`Channel`] of this nomenclature owns.
    type Constructor: ChannelConstructor<Self>;

    /// Extenders installed into a default-constructed [`Channel`].
    ///
    /// Empty unless an implementor overrides it.
    fn default_extenders() -> Vec<Box<dyn ChannelExtension<Self>>> {
        Vec::new()
    }

    /// Modifiers installed into a default-constructed [`Channel`].
    ///
    /// Empty unless an implementor overrides it.
    fn default_modifiers() -> Vec<Box<dyn ChannelExtension<Self>>> {
        Vec::new()
    }

    /// Nomenclature-specific handling of a peer `message`.
    ///
    /// [`Channel`]'s `update_from_peer` calls this first, before forwarding the
    /// message to the constructor, extenders and modifiers.
    fn update_from_peer(
        channel: &mut Channel<Self>,
        message: &Self::PeerMessage,
    ) -> Result<(), <Self as extension::Nomenclature>::Error>;
}
/// Serializable snapshot of a channel's state.
///
/// A [`Channel`] is strict-encoded by storing itself into a fresh
/// `dumb_default()` value of this type and encoding that value. It is decoded
/// by decoding this type and loading it into a default channel.
pub trait State: StrictEncode + StrictDecode + DumbDefault {
/// Extracts the funding information recorded in this state.
fn to_funding(&self) -> Funding;
/// Records `funding` into this state.
fn set_funding(&mut self, funding: &Funding);
}
/// Channel extensions keyed by their nomenclature identity.
///
/// Being a `BTreeMap`, iteration (and therefore the order in which a channel
/// visits its extensions) follows the ordering of the identity keys.
pub type ExtensionQueue<N> = BTreeMap<N, Box<dyn ChannelExtension<N>>>;
/// A channel assembled from a constructor extension, extender extensions and
/// modifier extensions, together with its current funding.
#[derive(Getters)]
pub struct Channel<N>
where
N: Nomenclature,
N::State: State,
{
/// Current funding: replaced by `set_funding` / `set_funding_amount` and
/// restored by `load_state`.
funding: Funding,
/// Constructor extension; the channel's hooks consult it before any
/// extender or modifier.
// NOTE(review): `as_mut` presumably makes the `Getters` derive also emit a
// mutable accessor for this field — confirm against the derive's docs.
#[getter(as_mut)]
constructor: N::Constructor,
/// Extenders, keyed and ordered by identity.
extenders: ExtensionQueue<N>,
/// Modifiers, keyed and ordered by identity; the channel's hooks visit them
/// after the extenders.
modifiers: ExtensionQueue<N>,
}
impl<N> Channel<N>
where
N: 'static + Nomenclature,
N::State: State,
{
pub fn new(
constructor: N::Constructor,
extenders: impl IntoIterator<Item = Box<dyn ChannelExtension<N>>>,
modifiers: impl IntoIterator<Item = Box<dyn ChannelExtension<N>>>,
) -> Self {
Self {
funding: Funding::new(),
constructor,
extenders: extenders.into_iter().fold(
ExtensionQueue::<N>::new(),
|mut queue, e| {
queue.insert(e.identity(), e);
queue
},
),
modifiers: modifiers.into_iter().fold(
ExtensionQueue::<N>::new(),
|mut queue, e| {
queue.insert(e.identity(), e);
queue
},
),
}
}
pub fn extension<E>(&'static self, id: N) -> Option<&E> {
self.extenders
.get(&id)
.map(|ext| ext as &dyn Any)
.and_then(|ext| ext.downcast_ref())
.or_else(|| {
self.modifiers
.get(&id)
.map(|ext| ext as &dyn Any)
.and_then(|ext| ext.downcast_ref())
})
}
pub fn extension_mut<E>(&'static mut self, id: N) -> Option<&mut E> {
self.extenders
.get_mut(&id)
.map(|ext| &mut *ext as &mut dyn Any)
.and_then(|ext| ext.downcast_mut())
.or_else(|| {
self.modifiers
.get_mut(&id)
.map(|ext| &mut *ext as &mut dyn Any)
.and_then(|ext| ext.downcast_mut())
})
}
#[inline]
pub fn extender(&self, id: N) -> Option<&dyn ChannelExtension<N>> {
self.extenders
.get(&id)
.map(|e| e.as_ref() as &dyn ChannelExtension<N>)
}
#[inline]
pub fn modifier(&self, id: N) -> Option<&dyn ChannelExtension<N>> {
self.modifiers
.get(&id)
.map(|e| e.as_ref() as &dyn ChannelExtension<N>)
}
#[inline]
pub fn extender_mut(
&mut self,
id: N,
) -> Option<&mut dyn ChannelExtension<N>> {
self.extenders
.get_mut(&id)
.map(|e| e.as_mut() as &mut dyn ChannelExtension<N>)
}
#[inline]
pub fn modifier_mut(
&mut self,
id: N,
) -> Option<&mut dyn ChannelExtension<N>> {
self.modifiers
.get_mut(&id)
.map(|e| e.as_mut() as &mut dyn ChannelExtension<N>)
}
#[inline]
pub fn add_extender(&mut self, extension: Box<dyn ChannelExtension<N>>) {
self.extenders.insert(extension.identity(), extension);
}
#[inline]
pub fn add_modifier(&mut self, modifier: Box<dyn ChannelExtension<N>>) {
self.modifiers.insert(modifier.identity(), modifier);
}
pub fn commitment_tx(
&mut self,
remote: bool,
) -> Result<Psbt, <N as extension::Nomenclature>::Error> {
let mut tx_graph = TxGraph::from_funding(&self.funding);
self.build_graph(&mut tx_graph, remote)?;
Ok(tx_graph.render_cmt())
}
#[inline]
pub fn set_funding_amount(&mut self, amount: u64) {
self.funding = Funding::preliminary(amount)
}
}
/// Funding-related operations, available when the nomenclature error can be
/// built from a [`FundingError`].
impl<N> Channel<N>
where
    N: 'static + Nomenclature,
    N::State: State,
    <N as extension::Nomenclature>::Error: From<FundingError>,
{
    /// Installs `funding_psbt` as the channel funding, then renders the
    /// commitment transaction for it.
    ///
    /// `remote` is forwarded to `commitment_tx`. Nothing is rendered if
    /// setting the funding fails.
    #[inline]
    pub fn refund_tx(
        &mut self,
        funding_psbt: Psbt,
        remote: bool,
    ) -> Result<Psbt, <N as extension::Nomenclature>::Error> {
        self.set_funding(funding_psbt)
            .and_then(|_| self.commitment_tx(remote))
    }

    /// Lets the constructor enrich `psbt` using the current funding, then
    /// replaces the funding with one derived from the enriched PSBT.
    ///
    /// The stored funding is left untouched if either step fails.
    #[inline]
    pub fn set_funding(
        &mut self,
        mut psbt: Psbt,
    ) -> Result<(), <N as extension::Nomenclature>::Error> {
        self.constructor.enrich_funding(&mut psbt, &self.funding)?;
        let enriched = Funding::with(psbt)?;
        self.funding = enriched;
        Ok(())
    }
}
impl<N> Default for Channel<N>
where
    N: 'static + Nomenclature + Default,
    N::State: State,
{
    /// Builds a channel from the nomenclature's default constructor and its
    /// default extender and modifier sets.
    fn default() -> Self {
        let constructor = N::Constructor::default();
        let extenders = N::default_extenders();
        let modifiers = N::default_modifiers();
        Self::new(constructor, extenders, modifiers)
    }
}
impl<N> StrictEncode for Channel<N>
where
    N: 'static + Nomenclature,
    N::State: State,
{
    /// Encodes the channel as its state snapshot.
    ///
    /// The snapshot starts from `N::State::dumb_default()`, is filled in by
    /// `store_state`, and is then strict-encoded into `e`.
    fn strict_encode<E: Write>(
        &self,
        e: E,
    ) -> Result<usize, strict_encoding::Error> {
        let snapshot = {
            let mut blank = N::State::dumb_default();
            self.store_state(&mut blank);
            blank
        };
        snapshot.strict_encode(e)
    }
}
impl<N> StrictDecode for Channel<N>
where
    N: 'static + Nomenclature,
    N::State: State,
{
    /// Decodes a state snapshot from `d` and loads it into a
    /// default-constructed channel.
    fn strict_decode<D: Read>(d: D) -> Result<Self, strict_encoding::Error> {
        N::State::strict_decode(d).map(|decoded| {
            let mut restored = Self::default();
            restored.load_state(&decoded);
            restored
        })
    }
}
impl<N> Extension<N> for Channel<N>
where
    N: 'static + Nomenclature,
    N::State: State,
{
    /// Identity of the channel as a whole: the nomenclature's default value.
    fn identity(&self) -> N {
        N::default()
    }

    /// Applies a state-change `request` to the constructor, then to every
    /// extender, then to every modifier (each queue in identity order).
    ///
    /// Each component receives mutable access to `message`.
    ///
    /// # Errors
    /// Stops at, and returns, the first component error.
    fn state_change(
        &mut self,
        request: &<N as extension::Nomenclature>::UpdateRequest,
        message: &mut <N as extension::Nomenclature>::PeerMessage,
    ) -> Result<(), <N as extension::Nomenclature>::Error> {
        self.constructor.state_change(request, message)?;
        for extension in self.extenders.values_mut() {
            extension.state_change(request, message)?;
        }
        // Modifiers must be visited too; this mirrors the constructor →
        // extenders → modifiers order used by every other hook here.
        for extension in self.modifiers.values_mut() {
            extension.state_change(request, message)?;
        }
        Ok(())
    }

    /// Forwards a local update `message` to the constructor, extenders and
    /// modifiers, in that order.
    ///
    /// # Errors
    /// Stops at, and returns, the first component error.
    fn update_from_local(
        &mut self,
        message: &<N as extension::Nomenclature>::UpdateMessage,
    ) -> Result<(), <N as extension::Nomenclature>::Error> {
        self.constructor.update_from_local(message)?;
        self.extenders
            .iter_mut()
            .try_for_each(|(_, e)| e.update_from_local(message))?;
        self.modifiers
            .iter_mut()
            .try_for_each(|(_, e)| e.update_from_local(message))?;
        Ok(())
    }

    /// Handles a peer `message`.
    ///
    /// The nomenclature-level `N::update_from_peer` hook runs first, followed
    /// by the constructor, the extenders and the modifiers.
    ///
    /// # Errors
    /// Stops at, and returns, the first error.
    fn update_from_peer(
        &mut self,
        message: &<N as extension::Nomenclature>::PeerMessage,
    ) -> Result<(), <N as extension::Nomenclature>::Error> {
        N::update_from_peer(self, message)?;
        self.constructor.update_from_peer(message)?;
        self.extenders
            .iter_mut()
            .try_for_each(|(_, e)| e.update_from_peer(message))?;
        self.modifiers
            .iter_mut()
            .try_for_each(|(_, e)| e.update_from_peer(message))?;
        Ok(())
    }

    /// Restores the funding and every component from `state`: the
    /// constructor, then the extenders, then the modifiers.
    fn load_state(&mut self, state: &N::State) {
        self.funding = state.to_funding();
        self.constructor.load_state(state);
        for extension in self.extenders.values_mut() {
            extension.load_state(state);
        }
        for extension in self.modifiers.values_mut() {
            extension.load_state(state);
        }
    }

    /// Writes the funding and every component into `state`: the constructor,
    /// then the extenders, then the modifiers.
    fn store_state(&self, state: &mut N::State) {
        state.set_funding(&self.funding);
        self.constructor.store_state(state);
        for extension in self.extenders.values() {
            extension.store_state(state);
        }
        for extension in self.modifiers.values() {
            extension.store_state(state);
        }
    }
}
impl<N> ChannelExtension<N> for Channel<N>
where
    N: 'static + Nomenclature,
    N::State: State,
{
    /// Boxes a default-constructed channel as a channel extension.
    #[inline]
    fn new() -> Box<dyn ChannelExtension<N>> {
        let fresh = Self::default();
        Box::new(fresh)
    }

    /// Lets every component contribute to `tx_graph`: the constructor first,
    /// then the extenders, then the modifiers (each queue in identity order).
    ///
    /// # Errors
    /// Returns the first error any component produces.
    fn build_graph(
        &self,
        tx_graph: &mut TxGraph,
        as_remote_node: bool,
    ) -> Result<(), <N as extension::Nomenclature>::Error> {
        self.constructor.build_graph(tx_graph, as_remote_node)?;
        let components = self.extenders.values().chain(self.modifiers.values());
        for component in components {
            component.build_graph(tx_graph, as_remote_node)?;
        }
        Ok(())
    }
}
/// A sequence of states addressed by an integer height.
// NOTE(review): the exact meaning of `bottom` vs `dig` (and whether `height`
// counts stored states) is not established in this file — confirm against
// implementors.
pub trait History {
/// Type of state stored at each height.
type State;
/// Error returned by the fallible accessors.
type Error: std::error::Error;
/// Current height of the history.
fn height(&self) -> usize;
/// Returns the state stored at `height`.
fn get(&self, height: usize) -> Result<Self::State, Self::Error>;
/// Returns the top-most state.
fn top(&self) -> Result<Self::State, Self::Error>;
/// Returns the bottom-most state.
fn bottom(&self) -> Result<Self::State, Self::Error>;
/// Returns a state from deeper in the history — semantics TODO confirm.
fn dig(&self) -> Result<Self::State, Self::Error>;
/// Appends `state`, returning `self` so calls can be chained.
fn push(&mut self, state: Self::State) -> Result<&mut Self, Self::Error>;
}