use crate::hash::CryptoHash;
use crate::merkle::MerklePath;
use crate::sharding::{
ReceiptProof, ShardChunk, ShardChunkHeader, ShardChunkHeaderV1, ShardChunkV1,
};
use crate::state_part::{StatePart, StatePartV0};
use crate::types::{BlockHeight, EpochId, ShardId, StateRoot, StateRootNode};
use borsh::{BorshDeserialize, BorshSerialize};
use near_primitives_core::types::{EpochHeight, ProtocolVersion};
use near_primitives_core::version::ProtocolFeature;
use near_schema_checker_lib::ProtocolSchema;
use std::sync::Arc;
/// A hash paired with a shared (`Arc`) list of receipt proofs.
///
/// Used as the element type of `incoming_receipts_proofs` in the state-sync headers.
/// NOTE(review): the `CryptoHash` is presumably the hash of the block these receipts
/// belong to — confirm against the code that builds these responses.
#[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ReceiptProofResponse(pub CryptoHash, pub Arc<Vec<ReceiptProof>>);
/// A hash paired with a Merkle path.
///
/// Used in the `root_proofs` field of the state-sync headers.
/// NOTE(review): presumably the path proves inclusion of the hash under some root —
/// confirm at the verification site.
#[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct RootProof(pub CryptoHash, pub MerklePath);
/// Key for a state-sync header, made of a shard id and a hash.
///
/// NOTE(review): the hash is presumably the sync hash — confirm against the callers /
/// storage column that build this key.
#[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct StateHeaderKey(pub ShardId, pub CryptoHash);
/// Key for a single state part.
///
/// Fields are listed in Borsh serialization order; do not reorder them.
#[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct StatePartKey(
    /// Hash the part is keyed under (presumably the sync hash — TODO confirm).
    pub CryptoHash,
    /// Shard the part belongs to.
    pub ShardId,
    /// Part index (presumably the `part_id` — TODO confirm).
    pub u64,
);
/// Identifies either a single state part (by index) or the state-sync header.
///
/// Carried in [`StateRequestAck::part_id_or_header`].
#[derive(
    Copy, PartialEq, Eq, Clone, Debug, Hash, BorshSerialize, BorshDeserialize, ProtocolSchema,
)]
pub enum PartIdOrHeader {
    /// A specific state part, identified by its index.
    Part { part_id: u64 },
    /// The state-sync header.
    Header,
}
/// Maps a [`PartIdOrHeader`] to a static lowercase label: `"part"` or `"header"`.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket impl
/// derives `Into<&'static str> for PartIdOrHeader` from this, so existing `.into()`
/// callers keep working, and `<&str>::from(x)` becomes available as well.
impl From<PartIdOrHeader> for &'static str {
    fn from(value: PartIdOrHeader) -> Self {
        match value {
            PartIdOrHeader::Part { .. } => "part",
            PartIdOrHeader::Header => "header",
        }
    }
}
/// Status carried in the `body` of a [`StateRequestAck`].
///
/// NOTE(review): the variant meanings below are inferred from their names — confirm
/// against the code that sends and handles these acks.
#[derive(Copy, PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub enum StateRequestAckBody {
    /// Presumably: the responder will send the requested data.
    WillRespond,
    /// Presumably: the responder is too busy to serve the request now.
    Busy,
    /// Presumably: the responder could not serve the request.
    Error,
}
/// Maps a [`StateRequestAckBody`] to a static snake_case label:
/// `"will_respond"`, `"busy"` or `"error"`.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket impl
/// derives `Into<&'static str> for StateRequestAckBody` from this, so existing
/// `.into()` callers keep working.
impl From<StateRequestAckBody> for &'static str {
    fn from(value: StateRequestAckBody) -> Self {
        match value {
            StateRequestAckBody::WillRespond => "will_respond",
            StateRequestAckBody::Busy => "busy",
            StateRequestAckBody::Error => "error",
        }
    }
}
/// Acknowledgement of a state request.
///
/// It identifies the request by shard, sync hash and part/header, and carries a status
/// in `body`. Borsh encodes fields in declaration order, so do not reorder them.
#[derive(PartialEq, Eq, Clone, Debug, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct StateRequestAck {
    /// Shard the acknowledged request refers to.
    pub shard_id: ShardId,
    /// Sync hash of the acknowledged request.
    pub sync_hash: CryptoHash,
    /// Whether a specific part or the header was requested.
    pub part_id_or_header: PartIdOrHeader,
    /// Status of the request.
    pub body: StateRequestAckBody,
}
/// Version 1 state-sync header, storing `ShardChunkV1` / `ShardChunkHeaderV1` directly.
///
/// Borsh encodes fields in declaration order, so the order below is part of the
/// wire/storage format and must not change.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ShardStateSyncResponseHeaderV1 {
    /// The shard chunk (V1 representation).
    pub chunk: ShardChunkV1,
    /// Merkle path for `chunk` (presumably proving its inclusion — TODO confirm).
    pub chunk_proof: MerklePath,
    /// Optional header of a previous chunk.
    pub prev_chunk_header: Option<ShardChunkHeaderV1>,
    /// Optional Merkle path for `prev_chunk_header`.
    pub prev_chunk_proof: Option<MerklePath>,
    /// Proofs for incoming receipts.
    pub incoming_receipts_proofs: Vec<ReceiptProofResponse>,
    /// Nested root proofs (presumably aligned with `incoming_receipts_proofs` — TODO confirm).
    pub root_proofs: Vec<Vec<RootProof>>,
    /// State root node; its `memory_usage` determines the number of state parts.
    pub state_root_node: StateRootNode,
}
/// Version 2 state-sync header.
///
/// It has the same fields as V1, but stores the versioned `ShardChunk` / `ShardChunkHeader`
/// types. Borsh encodes fields in declaration order, so do not reorder them.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ShardStateSyncResponseHeaderV2 {
    /// The shard chunk (versioned representation).
    pub chunk: ShardChunk,
    /// Merkle path for `chunk` (presumably proving its inclusion — TODO confirm).
    pub chunk_proof: MerklePath,
    /// Optional header of a previous chunk.
    pub prev_chunk_header: Option<ShardChunkHeader>,
    /// Optional Merkle path for `prev_chunk_header`.
    pub prev_chunk_proof: Option<MerklePath>,
    /// Proofs for incoming receipts.
    pub incoming_receipts_proofs: Vec<ReceiptProofResponse>,
    /// Nested root proofs (presumably aligned with `incoming_receipts_proofs` — TODO confirm).
    pub root_proofs: Vec<Vec<RootProof>>,
    /// State root node; its `memory_usage` determines the number of state parts.
    pub state_root_node: StateRootNode,
}
/// Describes a set of cached state parts: all of them, none, or an explicit bitmap.
///
/// Carried in `ShardStateSyncResponseV3::cached_parts`.
/// With `#[borsh(use_discriminant = true)]`, Borsh writes the explicit discriminants
/// below as the variant tag, so those values are part of the encoding.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum CachedParts {
    /// Every part is cached.
    AllParts = 0,
    /// No part is cached.
    NoParts = 1,
    /// Cached parts given as a bitmap (presumably one bit per part id — TODO confirm).
    BitArray(BitArray) = 2,
}
/// Fixed-capacity bitmap packed into bytes.
///
/// `BitArray::new` sizes `data` to hold `capacity` bits at 8 bits per byte.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct BitArray {
    /// Packed bit storage.
    data: Vec<u8>,
    /// Number of bits the array can represent.
    capacity: u64,
}
impl BitArray {
pub fn new(capacity: u64) -> Self {
let num_bytes = (capacity + 7) / 8;
Self { data: vec![0; num_bytes as usize], capacity }
}
}
/// Versioned state-sync header wrapping either a V1 or a V2 header.
///
/// With `#[borsh(use_discriminant = true)]`, Borsh writes the explicit discriminants
/// (`0`, `1`) as the variant tag.
/// NOTE(review): unlike the surrounding wire types, this enum does not derive
/// `ProtocolSchema` — confirm that is intentional.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
// The lint is silenced rather than boxing the variants; boxing would change how callers
// construct and match on them. NOTE(review): presumably the V1/V2 size difference is
// what triggers the lint — confirm if revisiting.
#[allow(clippy::large_enum_variant)]
pub enum ShardStateSyncResponseHeader {
    /// Header using the V1 chunk types.
    V1(ShardStateSyncResponseHeaderV1) = 0,
    /// Header using the versioned chunk types.
    V2(ShardStateSyncResponseHeaderV2) = 1,
}
impl ShardStateSyncResponseHeader {
#[inline]
pub fn take_chunk(self) -> ShardChunk {
match self {
Self::V1(header) => ShardChunk::V1(header.chunk),
Self::V2(header) => header.chunk,
}
}
#[inline]
pub fn cloned_chunk(&self) -> ShardChunk {
match self {
Self::V1(header) => ShardChunk::V1(header.chunk.clone()),
Self::V2(header) => header.chunk.clone(),
}
}
#[inline]
pub fn cloned_prev_chunk_header(&self) -> Option<ShardChunkHeader> {
match self {
Self::V1(header) => header.prev_chunk_header.clone().map(ShardChunkHeader::V1),
Self::V2(header) => header.prev_chunk_header.clone(),
}
}
#[inline]
pub fn chunk_height_included(&self) -> BlockHeight {
match self {
Self::V1(header) => header.chunk.header.height_included,
Self::V2(header) => header.chunk.height_included(),
}
}
#[inline]
pub fn chunk_prev_state_root(&self) -> StateRoot {
match self {
Self::V1(header) => header.chunk.header.inner.prev_state_root,
Self::V2(header) => header.chunk.prev_state_root(),
}
}
#[inline]
pub fn chunk_proof(&self) -> &MerklePath {
match self {
Self::V1(header) => &header.chunk_proof,
Self::V2(header) => &header.chunk_proof,
}
}
#[inline]
pub fn prev_chunk_proof(&self) -> &Option<MerklePath> {
match self {
Self::V1(header) => &header.prev_chunk_proof,
Self::V2(header) => &header.prev_chunk_proof,
}
}
#[inline]
pub fn incoming_receipts_proofs(&self) -> &[ReceiptProofResponse] {
match self {
Self::V1(header) => &header.incoming_receipts_proofs,
Self::V2(header) => &header.incoming_receipts_proofs,
}
}
#[inline]
pub fn root_proofs(&self) -> &[Vec<RootProof>] {
match self {
Self::V1(header) => &header.root_proofs,
Self::V2(header) => &header.root_proofs,
}
}
#[inline]
pub fn state_root_node(&self) -> &StateRootNode {
match self {
Self::V1(header) => &header.state_root_node,
Self::V2(header) => &header.state_root_node,
}
}
pub fn num_state_parts(&self) -> u64 {
get_num_state_parts(self.state_root_node().memory_usage)
}
}
/// Version 1 state-sync response: an optional V1 header and/or an optional raw part.
///
/// `part` is `(part_id, raw bytes)`. Borsh encodes fields in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ShardStateSyncResponseV1 {
    pub header: Option<ShardStateSyncResponseHeaderV1>,
    pub part: Option<(u64, Vec<u8>)>,
}
impl ShardStateSyncResponseV1 {
    /// Index of the carried part, or `None` when the response has no part.
    pub fn part_id(&self) -> Option<u64> {
        let (id, _) = self.part.as_ref()?;
        Some(*id)
    }

    /// Byte length of the carried part's payload, or `None` when there is no part.
    pub fn payload_length(&self) -> Option<usize> {
        let (_, payload) = self.part.as_ref()?;
        Some(payload.len())
    }
}
/// Version 2 state-sync response: an optional V2 header and/or an optional raw part
/// `(part_id, raw bytes)`.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ShardStateSyncResponseV2 {
    pub header: Option<ShardStateSyncResponseHeaderV2>,
    pub part: Option<(u64, Vec<u8>)>,
}
/// Version 3 state-sync response: like V2, plus cached-parts information.
///
/// `ShardStateSyncResponse::new_from_header_or_part` builds this version when
/// `StatePartsCompression` is disabled, always with `cached_parts: None` and
/// `can_generate: false`.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ShardStateSyncResponseV3 {
    pub header: Option<ShardStateSyncResponseHeaderV2>,
    pub part: Option<(u64, Vec<u8>)>,
    /// NOTE(review): presumably the parts the responder has cached — confirm.
    pub cached_parts: Option<CachedParts>,
    /// NOTE(review): presumably whether the responder can generate parts — confirm.
    pub can_generate: bool,
}
/// Version 4 state-sync response, carrying a typed `StatePart` instead of raw bytes.
///
/// `ShardStateSyncResponse::new_from_header_or_part` builds this version when
/// `ProtocolFeature::StatePartsCompression` is enabled.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
pub struct ShardStateSyncResponseV4 {
    pub header: Option<ShardStateSyncResponseHeaderV2>,
    pub part: Option<(u64, StatePart)>,
}
/// Versioned state-sync response.
///
/// With `#[borsh(use_discriminant = true)]`, Borsh writes the explicit discriminants
/// (`0`..=`3`) as the variant tag, so they must stay stable. The constructors in this
/// file only produce V3 or V4; V1/V2 remain decodable.
#[derive(Debug, Clone, PartialEq, Eq, BorshSerialize, BorshDeserialize, ProtocolSchema)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum ShardStateSyncResponse {
    V1(ShardStateSyncResponseV1) = 0,
    V2(ShardStateSyncResponseV2) = 1,
    V3(ShardStateSyncResponseV3) = 2,
    V4(ShardStateSyncResponseV4) = 3,
}
impl ShardStateSyncResponse {
pub fn new_from_header(
header: Option<ShardStateSyncResponseHeaderV2>,
protocol_version: ProtocolVersion,
) -> Self {
Self::new_from_header_or_part(header, None, protocol_version)
}
pub fn new_from_part(
part: Option<(u64, StatePart)>,
protocol_version: ProtocolVersion,
) -> Self {
Self::new_from_header_or_part(None, part, protocol_version)
}
fn new_from_header_or_part(
header: Option<ShardStateSyncResponseHeaderV2>,
part: Option<(u64, StatePart)>,
protocol_version: ProtocolVersion,
) -> Self {
if ProtocolFeature::StatePartsCompression.enabled(protocol_version) {
return Self::V4(ShardStateSyncResponseV4 { header, part });
}
let part = match part {
None => None,
Some((part_id, StatePart::V0(part))) => Some((part_id, part.0)),
_ => panic!("StatePartsCompression not supported and part={part:?}"),
};
Self::V3(ShardStateSyncResponseV3 { header, part, cached_parts: None, can_generate: false })
}
pub fn take_header(self) -> Option<ShardStateSyncResponseHeader> {
match self {
Self::V1(response) => response.header.map(ShardStateSyncResponseHeader::V1),
Self::V2(response) => response.header.map(ShardStateSyncResponseHeader::V2),
Self::V3(response) => response.header.map(ShardStateSyncResponseHeader::V2),
Self::V4(response) => response.header.map(ShardStateSyncResponseHeader::V2),
}
}
pub fn part_id(&self) -> Option<u64> {
match self {
Self::V1(response) => response.part.as_ref().map(|(part_id, _)| *part_id),
Self::V2(response) => response.part.as_ref().map(|(part_id, _)| *part_id),
Self::V3(response) => response.part.as_ref().map(|(part_id, _)| *part_id),
Self::V4(response) => response.part.as_ref().map(|(part_id, _)| *part_id),
}
}
pub fn take_part(self) -> Option<(u64, StatePart)> {
match self {
Self::V1(response) => {
response.part.map(|(idx, part)| (idx, StatePart::V0(StatePartV0(part))))
}
Self::V2(response) => {
response.part.map(|(idx, part)| (idx, StatePart::V0(StatePartV0(part))))
}
Self::V3(response) => {
response.part.map(|(idx, part)| (idx, StatePart::V0(StatePartV0(part))))
}
Self::V4(response) => response.part,
}
}
pub fn payload_length(&self) -> Option<usize> {
match self {
Self::V1(response) => response.part.as_ref().map(|(_, part)| part.len()),
Self::V2(response) => response.part.as_ref().map(|(_, part)| part.len()),
Self::V3(response) => response.part.as_ref().map(|(_, part)| part.len()),
Self::V4(response) => response.part.as_ref().map(|(_, part)| part.payload_length()),
}
}
}
/// Memory budget per state part; [`get_num_state_parts`] divides by this value.
pub const STATE_PART_MEMORY_LIMIT: bytesize::ByteSize = bytesize::ByteSize(30 * bytesize::MIB);

/// Returns how many state parts are needed to cover `memory_usage`, i.e.
/// `ceil(memory_usage / STATE_PART_MEMORY_LIMIT)`. Returns 0 for 0.
pub fn get_num_state_parts(memory_usage: u64) -> u64 {
    let limit = STATE_PART_MEMORY_LIMIT.as_u64();
    // Round up via quotient + remainder. The former `(memory_usage + limit - 1) / limit`
    // overflowed for `memory_usage` near `u64::MAX` (panic in debug, wrong small
    // result in release). The input is not bounded here: it is read from
    // `StateRootNode::memory_usage` of a state-sync response header.
    memory_usage / limit + u64::from(memory_usage % limit != 0)
}
/// Progress of a state-sync dump for an epoch.
///
/// With `#[borsh(use_discriminant = true)]`, Borsh writes the explicit discriminants
/// (`0`..=`2`) as the variant tag. NOTE(review): variant meanings below are inferred from
/// their names — confirm against the dumper.
#[derive(BorshSerialize, BorshDeserialize, Debug, Clone, serde::Serialize, ProtocolSchema)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum StateSyncDumpProgress {
    /// Presumably: all state for the epoch has been dumped.
    AllDumped {
        epoch_id: EpochId,
        epoch_height: EpochHeight,
    } = 0,
    /// Presumably: dumping was skipped for the epoch.
    Skipped { epoch_id: EpochId, epoch_height: EpochHeight } = 1,
    /// Presumably: dumping for the epoch is ongoing, relative to `sync_hash`.
    InProgress {
        epoch_id: EpochId,
        epoch_height: EpochHeight,
        sync_hash: CryptoHash,
    } = 2,
}
#[cfg(test)]
mod tests {
    use crate::state_sync::{STATE_PART_MEMORY_LIMIT, get_num_state_parts};

    /// `get_num_state_parts` must round `memory_usage / STATE_PART_MEMORY_LIMIT` up.
    #[test]
    fn test_get_num_state_parts() {
        let limit = STATE_PART_MEMORY_LIMIT.as_u64();
        // (memory_usage, expected number of parts)
        let cases = [
            (0, 0),
            (1, 1),
            (limit, 1),
            (limit + 1, 2),
            (limit * 100, 100),
            (limit * 100 + 1, 101),
        ];
        for (memory_usage, expected_parts) in cases {
            assert_eq!(get_num_state_parts(memory_usage), expected_parts);
        }
    }
}