1use std::any::Any;
15use std::collections::BTreeMap;
16use std::io::{Read, Write};
17
18use amplify::DumbDefault;
19use strict_encoding::{StrictDecode, StrictEncode};
20use wallet::psbt::Psbt;
21
22use super::tx_graph::TxGraph;
23use super::Funding;
24use crate::channel::FundingError;
25use crate::{extension, ChannelConstructor, ChannelExtension, Extension};
26
27pub trait Nomenclature: extension::Nomenclature
33where
34 <Self as extension::Nomenclature>::State: State,
35{
36 type Constructor: ChannelConstructor<Self>;
37
38 fn default_extenders() -> Vec<Box<dyn ChannelExtension<Self>>> {
40 Vec::default()
41 }
42
43 fn default_modifiers() -> Vec<Box<dyn ChannelExtension<Self>>> {
45 Vec::default()
46 }
47
48 fn update_from_peer(
51 channel: &mut Channel<Self>,
52 message: &Self::PeerMessage,
53 ) -> Result<(), <Self as extension::Nomenclature>::Error>;
54}
55
/// Serializable state snapshot of a [`Channel`].
///
/// `Channel`'s strict encoding starts from `DumbDefault::dumb_default()`,
/// fills it through `store_state`, and encodes the result. Decoding does the
/// reverse via `load_state`.
pub trait State: StrictEncode + StrictDecode + DumbDefault {
    /// Returns the channel funding recorded in this state.
    fn to_funding(&self) -> Funding;
    /// Records `funding` into this state.
    fn set_funding(&mut self, funding: &Funding);
}
61
/// Set of channel extensions keyed by their nomenclature identity.
///
/// Being a `BTreeMap`, iteration visits extensions in ascending key order,
/// and inserting a second extension with the same key replaces the first.
pub type ExtensionQueue<N> = BTreeMap<N, Box<dyn ChannelExtension<N>>>;
63
/// A channel assembled from a constructor plus two ordered sets of
/// extensions (extenders and modifiers).
///
/// Transaction-graph building and state updates run the constructor first,
/// then every extender, then every modifier.
#[derive(Getters)]
pub struct Channel<N>
where
    N: Nomenclature,
    N::State: State,
{
    /// Current funding of the channel.
    funding: Funding,

    /// Constructor for the base channel logic; it runs before any extension.
    /// NOTE(review): `getter(as_mut)` presumably requests a mutable accessor
    /// from the `Getters` derive — confirm against the derive crate.
    #[getter(as_mut)]
    constructor: N::Constructor,

    /// Extenders keyed by identity; applied after the constructor.
    extenders: ExtensionQueue<N>,

    /// Modifiers keyed by identity; applied after all extenders.
    modifiers: ExtensionQueue<N>,
}
104
105impl<N> Channel<N>
106where
107 N: 'static + Nomenclature,
108 N::State: State,
109{
110 pub fn new(
112 constructor: N::Constructor,
113 extenders: impl IntoIterator<Item = Box<dyn ChannelExtension<N>>>,
114 modifiers: impl IntoIterator<Item = Box<dyn ChannelExtension<N>>>,
115 ) -> Self {
116 Self {
117 funding: Funding::new(),
118 constructor,
119 extenders: extenders.into_iter().fold(
120 ExtensionQueue::<N>::new(),
121 |mut queue, e| {
122 queue.insert(e.identity(), e);
123 queue
124 },
125 ),
126 modifiers: modifiers.into_iter().fold(
127 ExtensionQueue::<N>::new(),
128 |mut queue, e| {
129 queue.insert(e.identity(), e);
130 queue
131 },
132 ),
133 }
134 }
135
136 pub fn extension<E>(&'static self, id: N) -> Option<&E> {
137 self.extenders
138 .get(&id)
139 .map(|ext| ext as &dyn Any)
140 .and_then(|ext| ext.downcast_ref())
141 .or_else(|| {
142 self.modifiers
143 .get(&id)
144 .map(|ext| ext as &dyn Any)
145 .and_then(|ext| ext.downcast_ref())
146 })
147 }
148
149 pub fn extension_mut<E>(&'static mut self, id: N) -> Option<&mut E> {
150 self.extenders
151 .get_mut(&id)
152 .map(|ext| &mut *ext as &mut dyn Any)
153 .and_then(|ext| ext.downcast_mut())
154 .or_else(|| {
155 self.modifiers
156 .get_mut(&id)
157 .map(|ext| &mut *ext as &mut dyn Any)
158 .and_then(|ext| ext.downcast_mut())
159 })
160 }
161
162 #[inline]
164 pub fn extender(&self, id: N) -> Option<&dyn ChannelExtension<N>> {
165 self.extenders
166 .get(&id)
167 .map(|e| e.as_ref() as &dyn ChannelExtension<N>)
168 }
169
170 #[inline]
172 pub fn modifier(&self, id: N) -> Option<&dyn ChannelExtension<N>> {
173 self.modifiers
174 .get(&id)
175 .map(|e| e.as_ref() as &dyn ChannelExtension<N>)
176 }
177
178 #[inline]
180 pub fn extender_mut(
181 &mut self,
182 id: N,
183 ) -> Option<&mut dyn ChannelExtension<N>> {
184 self.extenders
185 .get_mut(&id)
186 .map(|e| e.as_mut() as &mut dyn ChannelExtension<N>)
187 }
188
189 #[inline]
191 pub fn modifier_mut(
192 &mut self,
193 id: N,
194 ) -> Option<&mut dyn ChannelExtension<N>> {
195 self.modifiers
196 .get_mut(&id)
197 .map(|e| e.as_mut() as &mut dyn ChannelExtension<N>)
198 }
199
200 #[inline]
204 pub fn add_extender(&mut self, extension: Box<dyn ChannelExtension<N>>) {
205 self.extenders.insert(extension.identity(), extension);
206 }
207
208 #[inline]
212 pub fn add_modifier(&mut self, modifier: Box<dyn ChannelExtension<N>>) {
213 self.modifiers.insert(modifier.identity(), modifier);
214 }
215
216 pub fn commitment_tx(
218 &mut self,
219 remote: bool,
220 ) -> Result<Psbt, <N as extension::Nomenclature>::Error> {
221 let mut tx_graph = TxGraph::from_funding(&self.funding);
222 self.build_graph(&mut tx_graph, remote)?;
223 Ok(tx_graph.render_cmt())
224 }
225
226 #[inline]
227 pub fn set_funding_amount(&mut self, amount: u64) {
228 self.funding = Funding::preliminary(amount)
229 }
230}
231
232impl<N> Channel<N>
233where
234 N: 'static + Nomenclature,
235 N::State: State,
236 <N as extension::Nomenclature>::Error: From<FundingError>,
237{
238 #[inline]
241 pub fn refund_tx(
242 &mut self,
243 funding_psbt: Psbt,
244 remote: bool,
245 ) -> Result<Psbt, <N as extension::Nomenclature>::Error> {
246 self.set_funding(funding_psbt)?;
247 self.commitment_tx(remote)
248 }
249
250 #[inline]
251 pub fn set_funding(
252 &mut self,
253 mut psbt: Psbt,
254 ) -> Result<(), <N as extension::Nomenclature>::Error> {
255 self.constructor.enrich_funding(&mut psbt, &self.funding)?;
256 self.funding = Funding::with(psbt)?;
257 Ok(())
258 }
259}
260
261impl<N> Default for Channel<N>
262where
263 N: 'static + Nomenclature + Default,
264 N::State: State,
265{
266 fn default() -> Self {
267 Channel::new(
268 N::Constructor::default(),
269 N::default_extenders(),
270 N::default_modifiers(),
271 )
272 }
273}
274
275impl<N> StrictEncode for Channel<N>
276where
277 N: 'static + Nomenclature,
278 N::State: State,
279{
280 fn strict_encode<E: Write>(
281 &self,
282 e: E,
283 ) -> Result<usize, strict_encoding::Error> {
284 let mut state = N::State::dumb_default();
285 self.store_state(&mut state);
286 state.strict_encode(e)
287 }
288}
289
290impl<N> StrictDecode for Channel<N>
291where
292 N: 'static + Nomenclature,
293 N::State: State,
294{
295 fn strict_decode<D: Read>(d: D) -> Result<Self, strict_encoding::Error> {
296 let state = N::State::strict_decode(d)?;
297 let mut channel = Channel::default();
298 channel.load_state(&state);
299 Ok(channel)
300 }
301}
302
303impl<N> Extension<N> for Channel<N>
306where
307 N: 'static + Nomenclature,
308 N::State: State,
309{
310 fn identity(&self) -> N {
311 N::default()
312 }
313
314 fn state_change(
315 &mut self,
316 request: &<N as extension::Nomenclature>::UpdateRequest,
317 message: &mut <N as extension::Nomenclature>::PeerMessage,
318 ) -> Result<(), <N as extension::Nomenclature>::Error> {
319 self.constructor.state_change(request, message)?;
320 for extension in self.extenders.values_mut() {
321 extension.state_change(request, message)?;
322 }
323 for extension in self.extenders.values_mut() {
324 extension.state_change(request, message)?;
325 }
326 Ok(())
327 }
328
329 fn update_from_local(
330 &mut self,
331 message: &<N as extension::Nomenclature>::UpdateMessage,
332 ) -> Result<(), <N as extension::Nomenclature>::Error> {
333 self.constructor.update_from_local(message)?;
334 self.extenders
335 .iter_mut()
336 .try_for_each(|(_, e)| e.update_from_local(message))?;
337 self.modifiers
338 .iter_mut()
339 .try_for_each(|(_, e)| e.update_from_local(message))?;
340 Ok(())
341 }
342
343 fn update_from_peer(
344 &mut self,
345 message: &<N as extension::Nomenclature>::PeerMessage,
346 ) -> Result<(), <N as extension::Nomenclature>::Error> {
347 N::update_from_peer(self, message)?;
348 self.constructor.update_from_peer(message)?;
349 self.extenders
350 .iter_mut()
351 .try_for_each(|(_, e)| e.update_from_peer(message))?;
352 self.modifiers
353 .iter_mut()
354 .try_for_each(|(_, e)| e.update_from_peer(message))?;
355 Ok(())
356 }
357
358 fn load_state(&mut self, state: &N::State) {
359 self.funding = state.to_funding();
360 self.constructor.load_state(state);
361 for extension in self.extenders.values_mut() {
362 extension.load_state(state);
363 }
364 for extension in self.extenders.values_mut() {
365 extension.load_state(state);
366 }
367 }
368
369 fn store_state(&self, state: &mut N::State) {
370 state.set_funding(&self.funding);
371 self.constructor.store_state(state);
372 for extension in self.extenders.values() {
373 extension.store_state(state);
374 }
375 for extension in self.extenders.values() {
376 extension.store_state(state);
377 }
378 }
379}
380
381impl<N> ChannelExtension<N> for Channel<N>
385where
386 N: 'static + Nomenclature,
387 N::State: State,
388{
389 #[inline]
390 fn new() -> Box<dyn ChannelExtension<N>> {
391 Box::new(Channel::default())
392 }
393
394 fn build_graph(
395 &self,
396 tx_graph: &mut TxGraph,
397 as_remote_node: bool,
398 ) -> Result<(), <N as extension::Nomenclature>::Error> {
399 self.constructor.build_graph(tx_graph, as_remote_node)?;
400 self.extenders
401 .iter()
402 .try_for_each(|(_, e)| e.build_graph(tx_graph, as_remote_node))?;
403 self.modifiers
404 .iter()
405 .try_for_each(|(_, e)| e.build_graph(tx_graph, as_remote_node))?;
406 Ok(())
407 }
408}
409
/// Storage of successive states.
///
/// NOTE(review): the semantics below are inferred from method names only;
/// confirm them against the implementors.
pub trait History {
    /// Type of a stored state.
    type State;
    /// Error returned by fallible accessors.
    type Error: std::error::Error;

    /// Presumably the number of stored states — confirm.
    fn height(&self) -> usize;
    /// Retrieves the state at `height`.
    fn get(&self, height: usize) -> Result<Self::State, Self::Error>;
    /// Presumably the most recent state — confirm.
    fn top(&self) -> Result<Self::State, Self::Error>;
    /// Presumably the oldest state — confirm.
    fn bottom(&self) -> Result<Self::State, Self::Error>;
    /// NOTE(review): exact meaning of `dig` is not visible here — document
    /// once confirmed.
    fn dig(&self) -> Result<Self::State, Self::Error>;
    /// Appends `state`, returning `self` to allow chaining.
    fn push(&mut self, state: Self::State) -> Result<&mut Self, Self::Error>;
}