1use std::collections::{btree_map, BTreeMap};
17use std::io::{Read, Write};
18
19use amplify::DumbDefault;
20use internet2::presentation::sphinx::{Hop, SphinxPayload};
21use p2p::bolt::PaymentRequest;
22use strict_encoding::{StrictDecode, StrictEncode};
23
24use crate::{extension, Extension, RouterExtension};
25
26pub type ExtensionQueue<N> = BTreeMap<N, Box<dyn RouterExtension<N>>>;
27
28pub trait Nomenclature
34where
35 Self: extension::Nomenclature,
36{
37 type HopPayload: SphinxPayload;
38
39 fn default_extensions() -> Vec<Box<dyn RouterExtension<Self>>>;
40
41 fn update_from_peer(
44 router: &mut Router<Self>,
45 message: &Self::PeerMessage,
46 ) -> Result<(), <Self as extension::Nomenclature>::Error>;
47}
48
49pub struct Router<N>
52where
53 N: Nomenclature,
54{
55 extensions: ExtensionQueue<N>,
56}
57
58impl<N> Router<N>
59where
60 N: Nomenclature + 'static,
61{
62 pub fn new(
64 extensions: impl IntoIterator<Item = Box<dyn RouterExtension<N>>>,
65 ) -> Self {
66 Self {
67 extensions: extensions.into_iter().fold(
68 ExtensionQueue::<N>::new(),
69 |mut queue, e| {
70 queue.insert(e.identity(), e);
71 queue
72 },
73 ),
74 }
75 }
76
77 #[inline]
78 pub fn extensions(
79 &self,
80 ) -> btree_map::Iter<N, Box<dyn RouterExtension<N>>> {
81 self.extensions.iter()
82 }
83
84 #[inline]
85 pub fn extensions_mut(
86 &mut self,
87 ) -> btree_map::IterMut<N, Box<dyn RouterExtension<N>>> {
88 self.extensions.iter_mut()
89 }
90
91 #[inline]
93 pub fn add_extension(&mut self, extension: Box<dyn RouterExtension<N>>) {
94 self.extensions.insert(extension.identity(), extension);
95 }
96
97 pub fn compute_route(
98 &mut self,
99 payment: PaymentRequest,
100 ) -> Vec<Hop<N::HopPayload>> {
101 let mut route = vec![];
102 self.build_route(payment, &mut route);
103 route
104 }
105}
106
107impl<N> Default for Router<N>
108where
109 N: 'static + Nomenclature + Default,
110{
111 fn default() -> Self {
112 Router::new(N::default_extensions())
113 }
114}
115
116impl<N> StrictEncode for Router<N>
117where
118 N: 'static + Nomenclature,
119 N::State: StrictEncode,
120{
121 fn strict_encode<E: Write>(
122 &self,
123 e: E,
124 ) -> Result<usize, strict_encoding::Error> {
125 let mut state = N::State::dumb_default();
126 self.store_state(&mut state);
127 state.strict_encode(e)
128 }
129}
130
131impl<N> StrictDecode for Router<N>
132where
133 N: 'static + Nomenclature,
134 N::State: StrictDecode,
135{
136 fn strict_decode<D: Read>(d: D) -> Result<Self, strict_encoding::Error> {
137 let state = N::State::strict_decode(d)?;
138 let mut router = Router::default();
139 router.load_state(&state);
140 Ok(router)
141 }
142}
143
144impl<N> Extension<N> for Router<N>
145where
146 N: extension::Nomenclature + Nomenclature,
147{
148 #[inline]
149 fn identity(&self) -> N {
150 N::default()
151 }
152
153 fn state_change(
154 &mut self,
155 request: &<N as extension::Nomenclature>::UpdateRequest,
156 message: &mut <N as extension::Nomenclature>::PeerMessage,
157 ) -> Result<(), <N as extension::Nomenclature>::Error> {
158 for extension in self.extensions.values_mut() {
159 extension.state_change(request, message)?;
160 }
161 Ok(())
162 }
163
164 fn update_from_local(
165 &mut self,
166 message: &<N as extension::Nomenclature>::UpdateMessage,
167 ) -> Result<(), <N as extension::Nomenclature>::Error> {
168 self.extensions
169 .iter_mut()
170 .try_for_each(|(_, e)| e.update_from_local(message))?;
171 Ok(())
172 }
173
174 fn update_from_peer(
175 &mut self,
176 message: &<N as extension::Nomenclature>::PeerMessage,
177 ) -> Result<(), <N as extension::Nomenclature>::Error> {
178 N::update_from_peer(self, message)?;
179 self.extensions
180 .iter_mut()
181 .try_for_each(|(_, e)| e.update_from_peer(message))?;
182 Ok(())
183 }
184
185 fn load_state(&mut self, state: &N::State) {
186 for extension in self.extensions.values_mut() {
187 extension.load_state(state);
188 }
189 }
190
191 fn store_state(&self, state: &mut N::State) {
192 for extension in self.extensions.values() {
193 extension.store_state(state);
194 }
195 }
196}
197
198impl<N> RouterExtension<N> for Router<N>
199where
200 N: Nomenclature + 'static,
201{
202 fn new() -> Box<dyn RouterExtension<N>>
203 where
204 Self: Sized,
205 {
206 Box::new(Router::default())
207 }
208
209 fn build_route(
210 &mut self,
211 payment: PaymentRequest,
212 route: &mut Vec<Hop<N::HopPayload>>,
213 ) {
214 for extension in self.extensions.values_mut() {
215 extension.build_route(payment, route);
216 }
217 }
218}