1use std::{
16 fmt::{Debug, Formatter},
17 sync::Arc,
18};
19
20use amaru_consensus::headers_tree::data_generation::Action;
21use amaru_kernel::{
22 BlockHeader, EraHistory, IsHeader, NetworkName, Peer, Point, ProtocolParameters, Transaction,
23 cardano::network_block::make_encoded_block,
24};
25use amaru_mempool::InMemoryMempool;
26use amaru_ouroboros::{ChainStore, ConnectionsResource, TxId, in_memory_consensus_store::InMemConsensusStore};
27use amaru_stores::in_memory::MemoryStore;
28use anyhow::anyhow;
29use parking_lot::Mutex;
30use pure_stage::trace_buffer::TraceBuffer;
31
32use crate::{
33 stages::config::{Config, StoreType},
34 tests::{
35 configuration::NodeType::{NodeUnderTest, UpstreamNode},
36 in_memory_connection_provider::InMemoryConnectionProvider,
37 test_data::{create_transactions, create_transactions_in_mempool},
38 },
39};
40
/// Settings and shared resources used to build one node in an in-memory, multi-node test.
///
/// Instances are usually obtained from `Default`, `NodeTestConfig::initiator` or
/// `NodeTestConfig::responder` and then refined with the `with_*` builder methods.
#[derive(Clone)]
pub struct NodeTestConfig {
    /// Consensus chain store. `with_validated_blocks` writes headers/blocks into it and
    /// `make_node_configuration` hands it to the node as its chain store.
    pub chain_store: Arc<dyn ChainStore<BlockHeader>>,
    /// Mempool handed to the node; `with_txs` creates transactions in it.
    pub mempool: Arc<InMemoryMempool<Transaction>>,
    /// Connection provider resource (an `InMemoryConnectionProvider` by default).
    pub connections: ConnectionsResource,
    /// Chain length for the scenario (set by `initiator`/`responder`). It is not read in this
    /// file — TODO confirm how callers consume it.
    pub chain_length: usize,
    /// Upstream peers; their names become `Config::upstream_peers`.
    pub upstream_peers: Vec<Peer>,
    /// Listen address, also used as the `id` field of the span created by `enter_span`.
    pub listen_address: String,
    /// Mailbox size for the node. It is not read in this file — presumably consumed when
    /// wiring stages; TODO confirm.
    pub mailbox_size: usize,
    /// Shared `pure_stage` trace buffer.
    pub trace_buffer: Arc<Mutex<TraceBuffer>>,
    /// Seed value for the scenario. It is not read in this file — presumably used for
    /// deterministic randomness; TODO confirm.
    pub seed: u64,
    /// Header-tree actions (from `headers_tree::data_generation`) attached to this node.
    pub actions: Vec<Action>,
    /// Role of this node in the test topology.
    pub node_type: NodeType,
    /// Network whose era history and protocol parameters are used to build the node config.
    pub network_name: NetworkName,
}
63
64impl Debug for NodeTestConfig {
65 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("NodeTestConfig")
67 .field("chain_length", &self.chain_length)
68 .field("upstream_peers", &self.upstream_peers)
69 .field("listen_address", &self.listen_address)
70 .field("mailbox_size", &self.mailbox_size)
71 .field("seed", &self.seed)
72 .field("actions", &self.actions)
73 .field("node_type", &self.node_type)
74 .finish()
75 }
76}
77
/// Role a node plays in a multi-node test topology.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeType {
    /// The peer the node under test connects to (the role given by `NodeTestConfig::responder`).
    UpstreamNode,
    /// The node being exercised (the role of `Default` and `NodeTestConfig::initiator`).
    NodeUnderTest,
    /// Not referenced in this file; presumably a node following the node under test.
    DownstreamNode,
}
84
85impl Default for NodeTestConfig {
86 fn default() -> Self {
87 Self {
88 chain_store: Arc::new(InMemConsensusStore::default()),
89 mempool: Arc::new(InMemoryMempool::default()),
90 connections: Arc::new(InMemoryConnectionProvider::default()),
91 chain_length: 10,
92 upstream_peers: vec![Peer::new("127.0.0.1:3001")],
93 listen_address: "127.0.0.1:3000".to_string(),
94 mailbox_size: 10000,
95 trace_buffer: Arc::new(Mutex::new(TraceBuffer::default())),
96 seed: 42,
97 actions: Vec::new(),
98 network_name: NetworkName::Preprod,
99 node_type: NodeUnderTest,
100 }
101 }
102}
103
104impl NodeTestConfig {
105 pub fn enter_span(&self) -> tracing::span::EnteredSpan {
107 tracing::info_span!("node", id = %self.listen_address).entered()
108 }
109
110 pub fn initiator() -> Self {
111 Self::default()
112 .with_chain_length(INITIATOR_BLOCKS_NB)
113 .with_txs(INITIATOR_TXS_NB)
114 .with_upstream_peer(Peer::new("127.0.0.1:3001"))
115 .with_listen_address("127.0.0.1:3000")
116 .with_node_type(NodeUnderTest)
117 }
118
119 pub fn responder() -> Self {
120 Self::default()
121 .with_chain_length(RESPONDER_BLOCKS_NB)
122 .with_txs(RESPONDER_TXS_NB)
123 .with_no_upstream_peers()
124 .with_listen_address("127.0.0.1:3001")
125 .with_node_type(UpstreamNode)
126 }
127
128 pub fn era_history(&self) -> &EraHistory {
129 self.network_name.into()
130 }
131
132 pub fn protocol_parameters(&self) -> anyhow::Result<&ProtocolParameters> {
133 self.network_name.try_into().map_err(|e: String| anyhow!(e))
134 }
135
136 pub fn with_no_upstream_peers(mut self) -> Self {
137 self.upstream_peers = vec![];
138 self
139 }
140
141 pub fn with_listen_address(mut self, listen_address: &str) -> Self {
142 self.listen_address = listen_address.to_string();
143 self
144 }
145
146 pub fn with_chain_length(mut self, chain_length: usize) -> Self {
147 self.chain_length = chain_length;
148 self
149 }
150
151 pub fn with_chain_store(mut self, chain_store: Arc<dyn ChainStore<BlockHeader>>) -> Self {
152 self.chain_store = chain_store;
153 self
154 }
155
156 pub fn with_mempool(mut self, mempool: Arc<InMemoryMempool<Transaction>>) -> Self {
157 self.mempool = mempool;
158 self
159 }
160
161 pub fn with_connections(mut self, connections: ConnectionsResource) -> Self {
162 self.connections = connections;
163 self
164 }
165
166 pub fn with_mailbox_size(mut self, size: usize) -> Self {
167 self.mailbox_size = size;
168 self
169 }
170
171 pub fn with_trace_buffer(mut self, trace_buffer: Arc<Mutex<TraceBuffer>>) -> Self {
172 self.trace_buffer = trace_buffer;
173 self
174 }
175
176 pub fn with_seed(mut self, seed: u64) -> Self {
177 self.seed = seed;
178 self
179 }
180
181 pub fn with_network_name(mut self, network_name: NetworkName) -> Self {
182 self.network_name = network_name;
183 self
184 }
185
186 pub fn with_txs(self, txs_nb: usize) -> Self {
187 create_transactions_in_mempool(self.mempool.clone(), txs_nb);
188 self
189 }
190
191 pub fn with_upstream_peer(mut self, upstream_peer: Peer) -> Self {
192 self.upstream_peers = vec![upstream_peer];
193 self
194 }
195
196 pub fn with_upstream_peers(mut self, upstream_peers: Vec<Peer>) -> Self {
197 self.upstream_peers = upstream_peers;
198 self
199 }
200
201 pub fn with_actions(mut self, actions: Vec<Action>) -> Self {
202 self.actions = actions;
203 self
204 }
205
206 pub fn upstream_peers(&self) -> Vec<Peer> {
207 self.upstream_peers.clone()
208 }
209
210 pub fn with_node_type(mut self, node_type: NodeType) -> Self {
211 self.node_type = node_type;
212 self
213 }
214
215 #[expect(clippy::unwrap_used)]
223 pub fn with_validated_blocks(self, headers: Vec<BlockHeader>) -> Self {
224 let _span = self.enter_span();
225 for header in headers.iter() {
226 tracing::info!(
227 "storing block for header {} (parent hash: {})",
228 header.point(),
229 header.parent_hash().unwrap_or(Point::Origin.hash())
230 );
231 self.chain_store.store_header(header).unwrap();
232 self.chain_store.store_block(&header.hash(), &make_encoded_block(header, self.era_history())).unwrap();
233 self.chain_store.roll_forward_chain(&header.point()).unwrap();
234 }
235
236 if let Some(header) = headers.first() {
237 tracing::info!("set the anchor to {}", header.point());
238 self.chain_store.set_anchor_hash(&header.hash()).unwrap();
239 tracing::info!("set the tip to {}", header.point());
240 self.chain_store.set_best_chain_hash(&header.hash()).unwrap();
241 }
242 self
243 }
244
245 pub fn make_node_configuration(&self) -> anyhow::Result<Config> {
249 let mut config = Config {
250 upstream_peers: self.upstream_peers.iter().map(|p| p.name.clone()).collect(),
251 network: self.network_name,
252 network_magic: self.network_name.to_network_magic(),
253 ..Default::default()
254 };
255
256 config.listen_address = self.listen_address.clone();
257
258 let ledger_store = MemoryStore::new(self.era_history().clone(), self.protocol_parameters()?.clone());
263 let chain_anchor = self
264 .chain_store
265 .load_header(&self.chain_store.get_anchor_hash())
266 .map(|h| h.point())
267 .unwrap_or(Point::Origin);
268 ledger_store.set_tip(chain_anchor);
269
270 config.ledger_store = StoreType::InMem(ledger_store);
271 config.chain_store = StoreType::InMem(self.chain_store.clone());
272 Ok(config)
273 }
274}
275
/// Chain length configured by `NodeTestConfig::responder`.
pub const RESPONDER_BLOCKS_NB: usize = 10;
/// Chain length configured by `NodeTestConfig::initiator`.
pub const INITIATOR_BLOCKS_NB: usize = 4;

/// Number of transactions `NodeTestConfig::responder` asks `with_txs` to create; `get_tx_ids`
/// generates the same count.
pub const RESPONDER_TXS_NB: usize = 10;
/// Number of transactions `NodeTestConfig::initiator` asks `with_txs` to create.
pub const INITIATOR_TXS_NB: usize = 10;
281
282pub fn get_tx_ids() -> Vec<TxId> {
284 create_transactions(RESPONDER_TXS_NB).into_iter().map(|tx| TxId::from(&tx)).collect()
285}