1use crate::block_index::{BlockIndex, BlockIndexError, BlockValidationStatus};
19use abtc_domain::consensus::connect::{
20 connect_block, disconnect_block, BlockConnectResult, ConnectBlockError, MemoryUtxoSet, UtxoView,
21};
22use abtc_domain::consensus::ConsensusParams;
23use abtc_domain::primitives::{Block, BlockHash, OutPoint};
24use abtc_ports::{ChainStateStore, UtxoEntry};
25
26use std::collections::HashMap;
27
28#[derive(Debug)]
32pub enum ChainStateError {
33 OrphanBlock,
35 ValidationFailed(ConnectBlockError),
37 IndexError(BlockIndexError),
39 MissingBlockData(BlockHash),
41 ReorgFailed {
43 disconnected: u32,
45 reason: Box<ConnectBlockError>,
47 },
48 CorruptedIndex(BlockHash),
50 NoForkPoint,
52}
53
54impl std::fmt::Display for ChainStateError {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 ChainStateError::OrphanBlock => write!(f, "orphan block (unknown parent)"),
58 ChainStateError::ValidationFailed(e) => write!(f, "validation failed: {}", e),
59 ChainStateError::IndexError(e) => write!(f, "block index error: {}", e),
60 ChainStateError::MissingBlockData(h) => {
61 write!(f, "missing block data for {}", h)
62 }
63 ChainStateError::ReorgFailed {
64 disconnected,
65 reason,
66 } => write!(
67 f,
68 "reorg failed after disconnecting {} blocks: {}",
69 disconnected, reason
70 ),
71 ChainStateError::CorruptedIndex(h) => {
72 write!(f, "corrupted block index: missing entry for {}", h)
73 }
74 ChainStateError::NoForkPoint => write!(f, "no fork point found during reorg"),
75 }
76 }
77}
78
79impl std::error::Error for ChainStateError {}
80
81impl From<BlockIndexError> for ChainStateError {
82 fn from(e: BlockIndexError) -> Self {
83 match e {
84 BlockIndexError::OrphanHeader => ChainStateError::OrphanBlock,
85 other => ChainStateError::IndexError(other),
86 }
87 }
88}
89
90impl From<ConnectBlockError> for ChainStateError {
91 fn from(e: ConnectBlockError) -> Self {
92 ChainStateError::ValidationFailed(e)
93 }
94}
95
96#[derive(Debug)]
100pub enum ProcessBlockResult {
101 Connected {
103 hash: BlockHash,
105 height: u32,
107 },
108 Reorged {
110 hash: BlockHash,
112 height: u32,
114 disconnected: u32,
116 connected: u32,
118 },
119 SideChain {
121 hash: BlockHash,
123 height: u32,
125 },
126 AlreadyKnown {
128 hash: BlockHash,
130 },
131}
132
133pub struct ChainState {
140 index: BlockIndex,
142 utxo_set: MemoryUtxoSet,
144 params: ConsensusParams,
146 undo_data: HashMap<BlockHash, BlockConnectResult>,
149 blocks: HashMap<BlockHash, Block>,
153 tip: BlockHash,
155 tip_height: u32,
157 verify_scripts: bool,
159}
160
161impl ChainState {
162 pub fn new(genesis: Block, params: ConsensusParams) -> Result<Self, ChainStateError> {
167 let genesis_hash = genesis.block_hash();
168 let header = genesis.header.clone();
169
170 let mut index = BlockIndex::new_with_pow_limit(params.pow_limit_bits);
172 index.init_genesis(header);
173
174 let utxo_set = MemoryUtxoSet::new();
176 let result = connect_block(&genesis, 0, &utxo_set, ¶ms, false)?;
177
178 let mut chain_state = ChainState {
179 index,
180 utxo_set,
181 params,
182 undo_data: HashMap::new(),
183 blocks: HashMap::new(),
184 tip: genesis_hash,
185 tip_height: 0,
186 verify_scripts: true,
187 };
188
189 chain_state.utxo_set.apply_connect(&result);
191 chain_state.undo_data.insert(genesis_hash, result);
192 chain_state.blocks.insert(genesis_hash, genesis);
193
194 Ok(chain_state)
195 }
196
197 pub fn set_verify_scripts(&mut self, verify: bool) {
199 self.verify_scripts = verify;
200 }
201
202 pub fn process_block(&mut self, block: Block) -> Result<ProcessBlockResult, ChainStateError> {
207 let hash = block.block_hash();
208
209 if self.blocks.contains_key(&hash) {
211 return Ok(ProcessBlockResult::AlreadyKnown { hash });
212 }
213
214 let (_, reorg_signalled) = self.index.add_header(block.header.clone())?;
216
217 self.blocks.insert(hash, block);
219
220 let entry = self.index.get(&hash).unwrap();
221 let height = entry.height;
222
223 if self.index.get(&hash).unwrap().header.prev_block_hash == self.tip {
225 return self.connect_tip(hash, height);
226 }
227
228 if reorg_signalled {
230 return self.activate_best_chain(hash);
231 }
232
233 Ok(ProcessBlockResult::SideChain { hash, height })
235 }
236
237 fn connect_tip(
239 &mut self,
240 hash: BlockHash,
241 height: u32,
242 ) -> Result<ProcessBlockResult, ChainStateError> {
243 let block = self.blocks.get(&hash).unwrap();
244
245 let result = connect_block(
246 block,
247 height,
248 &self.utxo_set,
249 &self.params,
250 self.verify_scripts,
251 )?;
252
253 self.utxo_set.apply_connect(&result);
254 self.undo_data.insert(hash, result);
255 self.tip = hash;
256 self.tip_height = height;
257 self.index
258 .set_status(&hash, BlockValidationStatus::FullyValidated);
259
260 Ok(ProcessBlockResult::Connected { hash, height })
261 }
262
263 fn activate_best_chain(
268 &mut self,
269 new_tip_hash: BlockHash,
270 ) -> Result<ProcessBlockResult, ChainStateError> {
271 let old_chain = self.index.get_ancestor_chain(&self.tip);
273 let new_chain = self.index.get_ancestor_chain(&new_tip_hash);
274
275 let old_set: std::collections::HashSet<BlockHash> = old_chain.iter().copied().collect();
277
278 let fork_hash = new_chain
281 .iter()
282 .rev()
283 .find(|h| old_set.contains(h))
284 .copied()
285 .ok_or(ChainStateError::NoForkPoint)?;
286
287 let fork_height = self
288 .index
289 .get(&fork_hash)
290 .ok_or(ChainStateError::CorruptedIndex(fork_hash))?
291 .height;
292
293 let to_disconnect: Vec<BlockHash> = old_chain
295 .iter()
296 .take_while(|h| **h != fork_hash)
297 .copied()
298 .collect();
299
300 let to_connect: Vec<BlockHash> = new_chain
303 .iter()
304 .rev()
305 .skip_while(|h| **h != fork_hash)
306 .skip(1) .copied()
308 .collect();
309
310 let num_disconnect = to_disconnect.len() as u32;
311 let num_connect = to_connect.len() as u32;
312
313 for (i, old_hash) in to_disconnect.iter().enumerate() {
315 let undo = self
316 .undo_data
317 .remove(old_hash)
318 .ok_or(ChainStateError::MissingBlockData(*old_hash))?;
319
320 let disc = disconnect_block(&undo);
321 self.utxo_set.apply_disconnect(&disc);
322 self.index
323 .set_status(old_hash, BlockValidationStatus::HeaderValid);
324
325 self.tip = if i + 1 < to_disconnect.len() {
326 self.index
329 .get(old_hash)
330 .ok_or(ChainStateError::CorruptedIndex(*old_hash))?
331 .header
332 .prev_block_hash
333 } else {
334 fork_hash
335 };
336 }
337
338 self.tip = fork_hash;
339 self.tip_height = fork_height;
340
341 for new_hash in &to_connect {
343 let height = self
344 .index
345 .get(new_hash)
346 .ok_or(ChainStateError::CorruptedIndex(*new_hash))?
347 .height;
348 let block = self
349 .blocks
350 .get(new_hash)
351 .ok_or(ChainStateError::MissingBlockData(*new_hash))?;
352
353 let result = connect_block(
354 block,
355 height,
356 &self.utxo_set,
357 &self.params,
358 self.verify_scripts,
359 )
360 .map_err(|e| ChainStateError::ReorgFailed {
361 disconnected: num_disconnect,
362 reason: Box::new(e),
363 })?;
364
365 self.utxo_set.apply_connect(&result);
366 self.undo_data.insert(*new_hash, result);
367 self.tip = *new_hash;
368 self.tip_height = height;
369 self.index
370 .set_status(new_hash, BlockValidationStatus::FullyValidated);
371 }
372
373 let final_height = self.tip_height;
374
375 Ok(ProcessBlockResult::Reorged {
376 hash: new_tip_hash,
377 height: final_height,
378 disconnected: num_disconnect,
379 connected: num_connect,
380 })
381 }
382
383 pub fn tip(&self) -> BlockHash {
387 self.tip
388 }
389
390 pub fn tip_height(&self) -> u32 {
392 self.tip_height
393 }
394
395 pub fn index(&self) -> &BlockIndex {
397 &self.index
398 }
399
400 pub fn index_mut(&mut self) -> &mut BlockIndex {
402 &mut self.index
403 }
404
405 pub fn utxo_set(&self) -> &MemoryUtxoSet {
407 &self.utxo_set
408 }
409
410 pub fn get_block(&self, hash: &BlockHash) -> Option<&Block> {
412 self.blocks.get(hash)
413 }
414
415 pub fn params(&self) -> &ConsensusParams {
417 &self.params
418 }
419
420 pub fn block_count(&self) -> usize {
422 self.blocks.len()
423 }
424
425 pub fn utxo_count(&self) -> usize {
427 self.utxo_set.len()
428 }
429
430 pub fn has_utxo(&self, outpoint: &OutPoint) -> bool {
432 self.utxo_set.get_utxo(outpoint).is_some()
433 }
434
435 pub fn get_block_hash_at_height(&self, height: u32) -> Option<BlockHash> {
437 self.index.get_hash_at_height(height)
438 }
439
440 pub async fn flush_to_store(
448 &self,
449 store: &dyn ChainStateStore,
450 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
451 let utxo_adds: Vec<(abtc_domain::primitives::Txid, u32, UtxoEntry)> = self
453 .utxo_set
454 .iter()
455 .map(|(outpoint, entry)| {
456 (
457 outpoint.txid,
458 outpoint.vout,
459 UtxoEntry {
460 output: entry.output.clone(),
461 height: entry.height,
462 is_coinbase: entry.is_coinbase,
463 },
464 )
465 })
466 .collect();
467
468 store.write_utxo_set(utxo_adds, Vec::new()).await?;
470 store.write_chain_tip(self.tip, self.tip_height).await?;
471
472 tracing::info!(
473 "Flushed {} UTXOs to persistent store (tip={}, height={})",
474 self.utxo_set.len(),
475 self.tip,
476 self.tip_height
477 );
478
479 Ok(())
480 }
481
482 pub async fn flush_block_delta(
486 &self,
487 result: &BlockConnectResult,
488 store: &dyn ChainStateStore,
489 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
490 let adds: Vec<(abtc_domain::primitives::Txid, u32, UtxoEntry)> = result
491 .created
492 .iter()
493 .map(|(outpoint, entry)| {
494 (
495 outpoint.txid,
496 outpoint.vout,
497 UtxoEntry {
498 output: entry.output.clone(),
499 height: entry.height,
500 is_coinbase: entry.is_coinbase,
501 },
502 )
503 })
504 .collect();
505
506 let removes: Vec<(abtc_domain::primitives::Txid, u32)> = result
507 .spent
508 .keys()
509 .map(|outpoint| (outpoint.txid, outpoint.vout))
510 .collect();
511
512 store.write_utxo_set(adds, removes).await?;
513 store.write_chain_tip(self.tip, self.tip_height).await?;
514
515 Ok(())
516 }
517
518 pub async fn load_tip_from_store(
522 store: &dyn ChainStateStore,
523 ) -> Result<Option<(BlockHash, u32)>, Box<dyn std::error::Error + Send + Sync>> {
524 let (hash, height) = store.get_best_chain_tip().await?;
525 if hash == BlockHash::zero() && height == 0 {
526 return Ok(None);
527 }
528 Ok(Some((hash, height)))
529 }
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use abtc_domain::primitives::{
        Amount, Block, BlockHash, BlockHeader, Hash256, Transaction, TxOut,
    };
    use abtc_domain::Script;

    fn mainnet_params() -> ConsensusParams {
        ConsensusParams::mainnet()
    }

    /// Builds a one-transaction genesis block whose coinbase pays 5_000_000_000 sat,
    /// with the merkle root recomputed from its contents.
    fn make_genesis() -> Block {
        let payout = TxOut::new(
            Amount::from_sat(5_000_000_000),
            Script::from_bytes(vec![0x76, 0xA9]),
        );
        let coinbase_script = Script::from_bytes(vec![0x04, 0xFF, 0xFF, 0x00, 0x1D]);
        let coinbase = Transaction::coinbase(0, coinbase_script, vec![payout]);

        let header = BlockHeader {
            version: 1,
            prev_block_hash: BlockHash::zero(),
            merkle_root: Hash256::zero(),
            time: 1231006505,
            bits: 0x1d00ffff,
            nonce: 0,
        };
        let mut genesis = Block::new(header, vec![coinbase]);
        genesis.header.merkle_root = genesis.compute_merkle_root();
        genesis
    }

    /// Fresh, empty in-memory store for the persistence tests.
    fn empty_store() -> abtc_adapters::storage::InMemoryChainStateStore {
        abtc_adapters::storage::InMemoryChainStateStore::new()
    }

    #[test]
    fn test_chain_state_creation() {
        let genesis = make_genesis();
        let genesis_hash = genesis.block_hash();
        let cs = ChainState::new(genesis, mainnet_params()).unwrap();

        assert_eq!(cs.tip_height(), 0);
        assert_eq!(cs.tip(), genesis_hash);
    }

    #[test]
    fn test_duplicate_block_is_already_known() {
        let genesis = make_genesis();
        let mut cs = ChainState::new(genesis.clone(), mainnet_params()).unwrap();

        let outcome = cs.process_block(genesis).unwrap();
        assert!(
            matches!(outcome, ProcessBlockResult::AlreadyKnown { .. }),
            "Expected AlreadyKnown, got {:?}",
            outcome
        );
    }

    #[test]
    fn test_orphan_block_rejected() {
        let mut cs = ChainState::new(make_genesis(), mainnet_params()).unwrap();

        let unknown_parent = BlockHash::from_hash(Hash256::from_bytes([0xFF; 32]));
        let orphan = Block::new(
            BlockHeader {
                version: 1,
                prev_block_hash: unknown_parent,
                merkle_root: Hash256::from_bytes([0u8; 32]),
                time: 1231006505 + 600,
                bits: 0x1d00ffff,
                nonce: 42,
            },
            vec![],
        );

        let outcome = cs.process_block(orphan);
        assert!(
            matches!(outcome, Err(ChainStateError::OrphanBlock)),
            "Expected OrphanBlock error, got {:?}",
            outcome
        );
    }

    #[test]
    fn test_set_verify_scripts() {
        let mut cs = ChainState::new(make_genesis(), mainnet_params()).unwrap();

        assert!(cs.verify_scripts);
        cs.set_verify_scripts(false);
        assert!(!cs.verify_scripts);
    }

    #[test]
    fn test_chain_state_error_display() {
        let orphan_msg = ChainStateError::OrphanBlock.to_string();
        assert!(orphan_msg.contains("orphan"));

        let missing_msg = ChainStateError::MissingBlockData(BlockHash::zero()).to_string();
        assert!(missing_msg.contains("missing block data"));
    }

    #[test]
    fn test_chain_state_error_from_block_index_orphan() {
        let converted = ChainStateError::from(BlockIndexError::OrphanHeader);
        assert!(
            matches!(converted, ChainStateError::OrphanBlock),
            "Expected OrphanBlock"
        );
    }

    #[tokio::test]
    async fn test_load_tip_from_empty_store() {
        let store = empty_store();
        let loaded = ChainState::load_tip_from_store(&store).await.unwrap();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_load_tip_from_populated_store() {
        let store = empty_store();
        let tip = BlockHash::from_hash(Hash256::from_bytes([0x42; 32]));
        store.write_chain_tip(tip, 100).await.unwrap();

        let loaded = ChainState::load_tip_from_store(&store).await.unwrap();
        assert_eq!(loaded, Some((tip, 100)));
    }

    #[tokio::test]
    async fn test_flush_to_store() {
        let cs = ChainState::new(make_genesis(), mainnet_params()).unwrap();
        let store = empty_store();

        cs.flush_to_store(&store).await.unwrap();

        let (stored_tip, stored_height) = store.get_best_chain_tip().await.unwrap();
        assert_eq!(stored_tip, cs.tip());
        assert_eq!(stored_height, 0);
    }

    #[test]
    fn regression_chain_state_error_variants_exist() {
        let hash = BlockHash::from_hash(Hash256::from_bytes([0xAB; 32]));
        let corrupted = format!("{}", ChainStateError::CorruptedIndex(hash));
        let no_fork = format!("{}", ChainStateError::NoForkPoint);
        assert!(corrupted.contains("corrupted"));
        assert!(no_fork.contains("fork"));
    }
}