1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
//! Types for variable difficulty Merkle Mountain Range (MMR) in CKB.
//!
//! ## References
//!
//! - [CKB RFC 0044](https://github.com/nervosnetwork/rfcs/blob/master/rfcs/0044-ckb-light-client/0044-ckb-light-client.md)

use ckb_hash::new_blake2b;
use ckb_merkle_mountain_range::{Error as MMRError, Merge, MerkleProof, Result as MMRResult, MMR};

use crate::{
    core,
    core::{BlockNumber, EpochNumber, EpochNumberWithFraction, ExtraHashView, HeaderView},
    packed,
    prelude::*,
    utilities::compact_to_difficulty,
    U256,
};

/// A struct to implement MMR `Merge` trait
///
/// This is a data-less unit type; the merge logic lives in its `Merge`
/// implementation, which combines two `packed::HeaderDigest` nodes.
pub struct MergeHeaderDigest;
/// MMR root
///
/// A chain-root MMR whose items are `packed::HeaderDigest`, merged with
/// [`MergeHeaderDigest`]. `S` is passed through as the MMR's third type
/// parameter (presumably its backing store — confirm against the
/// `ckb_merkle_mountain_range` version in use).
pub type ChainRootMMR<S> = MMR<packed::HeaderDigest, MergeHeaderDigest, S>;
/// MMR proof
///
/// A Merkle proof over `packed::HeaderDigest` items, merged with
/// [`MergeHeaderDigest`].
pub type MMRProof = MerkleProof<packed::HeaderDigest, MergeHeaderDigest>;

/// A Header and the fields which are used to do verification for its extra hash.
///
/// `uncles_hash` and the hash of `extension` are the inputs from which the
/// header's `extra_hash` is recomputed and compared in `is_valid`;
/// `parent_chain_root` is the chain root for the parent block.
#[derive(Debug, Clone)]
pub struct VerifiableHeader {
    // The header being verified.
    header: HeaderView,
    // Uncles hash of the block; one input of the recomputed extra hash.
    uncles_hash: packed::Byte32,
    // Optional block extension; its raw-data hash is the other input of the
    // extra hash, and (after MMR activation) it is expected to start with the
    // parent chain root hash.
    extension: Option<packed::Bytes>,
    // The chain root (MMR header digest) for the parent block.
    parent_chain_root: packed::HeaderDigest,
}

impl core::BlockView {
    /// Get the MMR header digest of the block.
    ///
    /// A block's digest is simply the digest of its header (see
    /// `HeaderView::digest`); no other block content is involved.
    pub fn digest(&self) -> packed::HeaderDigest {
        let header = self.header();
        header.digest()
    }
}

impl core::HeaderView {
    /// Get the MMR header digest of the header.
    ///
    /// The resulting leaf digest covers exactly one block:
    /// - `children_hash` is the header hash;
    /// - `total_difficulty` is the header's own difficulty (`self.difficulty()`);
    /// - every `start_*` / `end_*` pair (number, epoch, timestamp, compact
    ///   target) holds the same value copied from the raw header.
    pub fn digest(&self) -> packed::HeaderDigest {
        let raw = self.data().raw();
        let (number, epoch) = (raw.number(), raw.epoch());
        let (timestamp, compact_target) = (raw.timestamp(), raw.compact_target());
        packed::HeaderDigest::new_builder()
            .children_hash(self.hash())
            .total_difficulty(self.difficulty().pack())
            .start_number(number.clone())
            .start_epoch(epoch.clone())
            .start_timestamp(timestamp.clone())
            .start_compact_target(compact_target.clone())
            .end_number(number)
            .end_epoch(epoch)
            .end_timestamp(timestamp)
            .end_compact_target(compact_target)
            .build()
    }
}

/// Trait for representing a header digest.
///
/// Implemented for `packed::HeaderDigest` in this module.
pub trait HeaderDigest {
    /// Verify the header digest
    ///
    /// Returns `Ok(())` when the digest's own fields are mutually consistent,
    /// or `Err` with a human-readable message describing the first
    /// inconsistency found.
    fn verify(&self) -> Result<(), String>;
}

impl HeaderDigest for packed::HeaderDigest {
    /// Verify the MMR header digest.
    ///
    /// The checks run in order and the first failure is reported:
    ///
    /// 1. `start_number` must not be bigger than `end_number`.
    /// 2. Unless the two epochs are equal, the start epoch must not be after
    ///    the end epoch (compared by epoch number, then by index).
    /// 3. When both ends share the same epoch number, the start and end
    ///    compact targets must be equal, and `total_difficulty` must equal
    ///    `compact_to_difficulty(compact_target)` multiplied by the number of
    ///    blocks in the inclusive range `[start_number, end_number]`.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first check that failed. A block
    /// range whose inclusive count does not fit in a `u64` is rejected too.
    fn verify(&self) -> Result<(), String> {
        // 1. Check block numbers.
        let start_number: BlockNumber = self.start_number().unpack();
        let end_number: BlockNumber = self.end_number().unpack();
        if start_number > end_number {
            let errmsg = format!(
                "failed since the start block number is bigger than the end ([{start_number},{end_number}])"
            );
            return Err(errmsg);
        }

        // 2. Check epochs.
        let start_epoch: EpochNumberWithFraction = self.start_epoch().unpack();
        let end_epoch: EpochNumberWithFraction = self.end_epoch().unpack();
        let start_epoch_number = start_epoch.number();
        let end_epoch_number = end_epoch.number();
        if start_epoch != end_epoch
            && ((start_epoch_number > end_epoch_number)
                || (start_epoch_number == end_epoch_number
                    && start_epoch.index() > end_epoch.index()))
        {
            let errmsg = format!(
                "failed since the start epoch is bigger than the end ([{start_epoch:#},{end_epoch:#}])"
            );
            return Err(errmsg);
        }

        // 3. Check difficulties when in the same epoch.
        let start_compact_target: u32 = self.start_compact_target().unpack();
        let end_compact_target: u32 = self.end_compact_target().unpack();
        let total_difficulty: U256 = self.total_difficulty().unpack();
        if start_epoch_number == end_epoch_number {
            if start_compact_target != end_compact_target {
                // In the same epoch, all compact targets should be same.
                let errmsg = format!(
                    "failed since the compact targets should be same during epochs ([{start_epoch:#},{end_epoch:#}])"
                );
                return Err(errmsg);
            } else {
                // Sum all blocks difficulties to check total difficulty.
                // `end_number - start_number` cannot underflow (step 1), but the
                // inclusive count of `[0, u64::MAX]` does not fit in a u64. The
                // digest may come from an untrusted peer, so reject it instead
                // of panicking (overflow checks on) or wrapping to zero.
                let blocks_count = (end_number - start_number)
                    .checked_add(1)
                    .ok_or_else(|| {
                        format!(
                            "failed since the blocks count overflows ([{start_number},{end_number}])"
                        )
                    })?;
                let block_difficulty = compact_to_difficulty(start_compact_target);
                let total_difficulty_calculated = block_difficulty * blocks_count;
                if total_difficulty != total_difficulty_calculated {
                    let errmsg = format!(
                        "failed since total difficulty is {total_difficulty} but the calculated is {total_difficulty_calculated} \
                        during epochs ([{start_epoch:#},{end_epoch:#}])"
                    );
                    return Err(errmsg);
                }
            }
        }

        Ok(())
    }
}

impl Merge for MergeHeaderDigest {
    type Item = packed::HeaderDigest;

    /// Merges two adjacent digests into their parent digest.
    ///
    /// `lhs` must cover the blocks immediately before those covered by `rhs`:
    /// - block numbers must be continuous (`lhs.end_number + 1 == rhs.start_number`);
    /// - `rhs`'s start epoch must be the successor of `lhs`'s end epoch. This
    ///   check is skipped when `lhs`'s end epoch reports `is_genesis()`.
    ///
    /// The parent digest is built as follows:
    /// - `start_*` fields are taken from `lhs` and `end_*` fields from `rhs`;
    /// - `total_difficulty` is the sum of both children;
    /// - `children_hash` is the `ckb_hash::new_blake2b` hash of the two
    ///   children's MMR hashes, `lhs` first.
    ///
    /// # Errors
    ///
    /// Returns `MMRError::MergeError` when either continuity check fails.
    fn merge(lhs: &Self::Item, rhs: &Self::Item) -> MMRResult<Self::Item> {
        // 1. Check block numbers.
        let lhs_end_number: BlockNumber = lhs.end_number().unpack();
        let rhs_start_number: BlockNumber = rhs.start_number().unpack();
        // `checked_add` keeps an `lhs` ending at `u64::MAX` from wrapping to 0
        // (or panicking when overflow checks are on) and being accepted as
        // continuous with an `rhs` that starts at block 0.
        if lhs_end_number.checked_add(1) != Some(rhs_start_number) {
            let errmsg = format!(
                "failed since the blocks isn't continuous ([-,{lhs_end_number}], [{rhs_start_number},-])"
            );
            return Err(MMRError::MergeError(errmsg));
        }

        // 2. Check epochs.
        let lhs_end_epoch: EpochNumberWithFraction = lhs.end_epoch().unpack();
        let rhs_start_epoch: EpochNumberWithFraction = rhs.start_epoch().unpack();
        if !rhs_start_epoch.is_successor_of(lhs_end_epoch) && !lhs_end_epoch.is_genesis() {
            let errmsg = format!(
                "failed since the epochs isn't continuous ([-,{lhs_end_epoch:#}], [{rhs_start_epoch:#},-])",
            );
            return Err(MMRError::MergeError(errmsg));
        }

        // Only hash and sum once both children are known to be adjacent, so
        // rejected merges do no extra work.
        let children_hash = {
            let mut hasher = new_blake2b();
            let mut hash = [0u8; 32];
            hasher.update(&lhs.calc_mmr_hash().raw_data());
            hasher.update(&rhs.calc_mmr_hash().raw_data());
            hasher.finalize(&mut hash);
            hash
        };

        let total_difficulty = {
            let l: U256 = lhs.total_difficulty().unpack();
            let r: U256 = rhs.total_difficulty().unpack();
            l + r
        };

        Ok(Self::Item::new_builder()
            .children_hash(children_hash.pack())
            .total_difficulty(total_difficulty.pack())
            .start_number(lhs.start_number())
            .start_epoch(lhs.start_epoch())
            .start_timestamp(lhs.start_timestamp())
            .start_compact_target(lhs.start_compact_target())
            .end_number(rhs.end_number())
            .end_epoch(rhs.end_epoch())
            .end_timestamp(rhs.end_timestamp())
            .end_compact_target(rhs.end_compact_target())
            .build())
    }

    /// Merges two peaks by delegating to `merge` with the arguments swapped.
    ///
    /// NOTE(review): the swap presumably exists because the MMR crate passes
    /// peaks right-to-left when bagging them, while `merge` requires the
    /// earlier range as `lhs` — confirm against the
    /// `ckb_merkle_mountain_range` version in use.
    fn merge_peaks(lhs: &Self::Item, rhs: &Self::Item) -> MMRResult<Self::Item> {
        Self::merge(rhs, lhs)
    }
}

impl From<packed::VerifiableHeader> for VerifiableHeader {
    /// Converts the packed form into a `VerifiableHeader`.
    ///
    /// The packed header is turned into a `HeaderView`, and the packed
    /// optional extension into an `Option`; the other fields are moved as-is.
    fn from(raw: packed::VerifiableHeader) -> Self {
        let header = raw.header().into_view();
        let uncles_hash = raw.uncles_hash();
        let extension = raw.extension().to_opt();
        let parent_chain_root = raw.parent_chain_root();
        Self::new(header, uncles_hash, extension, parent_chain_root)
    }
}

impl VerifiableHeader {
    /// Creates a new verifiable header from a header view and the fields
    /// needed to verify its extra hash.
    pub fn new(
        header: HeaderView,
        uncles_hash: packed::Byte32,
        extension: Option<packed::Bytes>,
        parent_chain_root: packed::HeaderDigest,
    ) -> Self {
        Self {
            header,
            uncles_hash,
            extension,
            parent_chain_root,
        }
    }

    /// Checks if the current verifiable header is valid.
    ///
    /// When the header's epoch compares greater than
    /// `EpochNumberWithFraction::new(mmr_activated_epoch_number, 0, 1)`:
    /// - a genesis header must carry a default `parent_chain_root`;
    /// - any other header must have an extension whose raw data starts with
    ///   the MMR hash of `parent_chain_root`.
    ///
    /// In all cases the header's `extra_hash` must equal the extra hash
    /// recomputed from `uncles_hash` and the extension's raw-data hash.
    pub fn is_valid(&self, mmr_activated_epoch_number: EpochNumber) -> bool {
        let header = self.header();
        let activation_epoch = EpochNumberWithFraction::new(mmr_activated_epoch_number, 0, 1);

        // Past the activation point, the header must commit to its parent's
        // chain root.
        if header.epoch() > activation_epoch {
            let chain_root_committed = if header.is_genesis() {
                // Genesis has no parent: its parent chain root must be the
                // default value.
                self.parent_chain_root().is_default()
            } else {
                // Other blocks must prefix their extension with the parent
                // chain root hash; a missing extension fails the check.
                match self.extension() {
                    Some(extension) => {
                        let parent_chain_root_hash = self.parent_chain_root().calc_mmr_hash();
                        extension
                            .raw_data()
                            .starts_with(parent_chain_root_hash.as_slice())
                    }
                    None => false,
                }
            };
            if !chain_root_committed {
                return false;
            }
        }

        // Recompute the extra hash from its inputs and compare it with the
        // one stored in the header.
        let extension_hash = self
            .extension()
            .map(|extension| extension.calc_raw_data_hash());
        let extra_hash_view = ExtraHashView::new(self.uncles_hash(), extension_hash);
        extra_hash_view.extra_hash() == header.extra_hash()
    }

    /// Returns the header.
    pub fn header(&self) -> &HeaderView {
        &self.header
    }

    /// Returns a copy of the uncles hash.
    pub fn uncles_hash(&self) -> packed::Byte32 {
        self.uncles_hash.clone()
    }

    /// Returns a copy of the extension, if any.
    pub fn extension(&self) -> Option<packed::Bytes> {
        self.extension.clone()
    }

    /// Returns a copy of the chain root for its parent block.
    pub fn parent_chain_root(&self) -> packed::HeaderDigest {
        self.parent_chain_root.clone()
    }

    /// Returns the total difficulty up to and including this block.
    ///
    /// This is the parent chain root's total difficulty plus this header's
    /// own difficulty, derived from its compact target.
    pub fn total_difficulty(&self) -> U256 {
        let own_difficulty = compact_to_difficulty(self.header.compact_target());
        let parent_total: U256 = self.parent_chain_root.total_difficulty().unpack();
        parent_total + own_difficulty
    }
}

/// A builder which builds the content of a message used for proving.
///
/// It is implemented in this module for the `packed::Send*Builder` types of
/// the light-client protocol. The built entity must convert into
/// `packed::LightClientMessageUnion`. Every setter takes the builder by value
/// and returns it, so calls can be chained.
pub trait ProverMessageBuilder: Builder
where
    Self::Entity: Into<packed::LightClientMessageUnion>,
{
    /// The type of the proved items.
    type ProvedItems;
    /// The type of the missing items (`()` for messages that carry none).
    type MissingItems;
    /// Set the verifiable header which includes the chain root.
    fn set_last_header(self, last_header: packed::VerifiableHeader) -> Self;
    /// Set the proof for all items which require verifying.
    fn set_proof(self, proof: packed::HeaderDigestVec) -> Self;
    /// Set the proved items.
    fn set_proved_items(self, items: Self::ProvedItems) -> Self;
    /// Set the missing items.
    fn set_missing_items(self, items: Self::MissingItems) -> Self;
}

impl ProverMessageBuilder for packed::SendLastStateProofBuilder {
    /// The proved headers, stored in the message's `headers` field.
    type ProvedItems = packed::VerifiableHeaderVec;
    /// This message has no missing-items field.
    type MissingItems = ();

    fn set_last_header(self, last_header: packed::VerifiableHeader) -> Self {
        self.last_header(last_header)
    }

    fn set_proof(self, proof: packed::HeaderDigestVec) -> Self {
        self.proof(proof)
    }

    /// Stores the proved headers in the `headers` field.
    fn set_proved_items(self, items: Self::ProvedItems) -> Self {
        self.headers(items)
    }

    /// No-op: there is nothing to record for missing items.
    fn set_missing_items(self, _items: Self::MissingItems) -> Self {
        self
    }
}

impl ProverMessageBuilder for packed::SendBlocksProofBuilder {
    /// The proved block headers.
    type ProvedItems = packed::HeaderVec;
    /// Hashes of the requested blocks that could not be proved.
    type MissingItems = packed::Byte32Vec;

    fn set_last_header(self, last_header: packed::VerifiableHeader) -> Self {
        self.last_header(last_header)
    }

    fn set_proof(self, proof: packed::HeaderDigestVec) -> Self {
        self.proof(proof)
    }

    /// Stores the proved headers in the `headers` field.
    fn set_proved_items(self, items: Self::ProvedItems) -> Self {
        self.headers(items)
    }

    /// Stores the missing hashes in the `missing_block_hashes` field.
    fn set_missing_items(self, items: Self::MissingItems) -> Self {
        self.missing_block_hashes(items)
    }
}

impl ProverMessageBuilder for packed::SendBlocksProofV1Builder {
    /// `(headers, blocks_uncles_hash, blocks_extension)`.
    type ProvedItems = (packed::HeaderVec, packed::Byte32Vec, packed::BytesOptVec);
    /// Hashes of the requested blocks that could not be proved.
    type MissingItems = packed::Byte32Vec;

    fn set_last_header(self, last_header: packed::VerifiableHeader) -> Self {
        self.last_header(last_header)
    }

    fn set_proof(self, proof: packed::HeaderDigestVec) -> Self {
        self.proof(proof)
    }

    /// Unpacks the tuple into the `headers`, `blocks_uncles_hash` and
    /// `blocks_extension` fields, in that order.
    fn set_proved_items(self, items: Self::ProvedItems) -> Self {
        let (headers, uncles_hashes, extensions) = items;
        self.headers(headers)
            .blocks_uncles_hash(uncles_hashes)
            .blocks_extension(extensions)
    }

    /// Stores the missing hashes in the `missing_block_hashes` field.
    fn set_missing_items(self, items: Self::MissingItems) -> Self {
        self.missing_block_hashes(items)
    }
}

impl ProverMessageBuilder for packed::SendTransactionsProofBuilder {
    /// The filtered blocks containing the proved transactions.
    type ProvedItems = packed::FilteredBlockVec;
    /// Hashes of the requested transactions that could not be proved.
    type MissingItems = packed::Byte32Vec;

    fn set_last_header(self, last_header: packed::VerifiableHeader) -> Self {
        self.last_header(last_header)
    }

    fn set_proof(self, proof: packed::HeaderDigestVec) -> Self {
        self.proof(proof)
    }

    /// Stores the filtered blocks in the `filtered_blocks` field.
    fn set_proved_items(self, items: Self::ProvedItems) -> Self {
        self.filtered_blocks(items)
    }

    /// Stores the missing hashes in the `missing_tx_hashes` field.
    fn set_missing_items(self, items: Self::MissingItems) -> Self {
        self.missing_tx_hashes(items)
    }
}

impl ProverMessageBuilder for packed::SendTransactionsProofV1Builder {
    /// `(filtered_blocks, blocks_uncles_hash, blocks_extension)`.
    type ProvedItems = (
        packed::FilteredBlockVec,
        packed::Byte32Vec,
        packed::BytesOptVec,
    );
    /// Hashes of the requested transactions that could not be proved.
    type MissingItems = packed::Byte32Vec;

    fn set_last_header(self, last_header: packed::VerifiableHeader) -> Self {
        self.last_header(last_header)
    }

    fn set_proof(self, proof: packed::HeaderDigestVec) -> Self {
        self.proof(proof)
    }

    /// Unpacks the tuple into the `filtered_blocks`, `blocks_uncles_hash`
    /// and `blocks_extension` fields, in that order.
    fn set_proved_items(self, items: Self::ProvedItems) -> Self {
        let (filtered_blocks, uncles_hashes, extensions) = items;
        self.filtered_blocks(filtered_blocks)
            .blocks_uncles_hash(uncles_hashes)
            .blocks_extension(extensions)
    }

    /// Stores the missing hashes in the `missing_tx_hashes` field.
    fn set_missing_items(self, items: Self::MissingItems) -> Self {
        self.missing_tx_hashes(items)
    }
}