provable-sdk 0.1.2

Provable SDK for Rust
Documentation
use crate::error::{ProvableError, Result};
use crate::types::{
    MerkleProofCompatibilityMismatch, MerkleProofCompatibilityResult, MerkleProofInput,
    MerkleProofLevel, MerkleProofResponse, NormalizedMerkleProof,
};
use crate::util::normalize_hex_string;
use crate::util::normalize_level_counts;
use std::collections::HashMap;

/// Converts any accepted proof input into a [`NormalizedMerkleProof`].
///
/// An input that is already normalized is returned unchanged; a raw response
/// is validated and converted by `normalize_merkle_proof_response`.
///
/// # Errors
///
/// Propagates any error produced while normalizing a raw response.
pub fn normalize_merkle_proof(input: impl Into<MerkleProofInput>) -> Result<NormalizedMerkleProof> {
    let resolved: MerkleProofInput = input.into();
    match resolved {
        MerkleProofInput::Response(response) => normalize_merkle_proof_response(response),
        MerkleProofInput::Normalized(already_normalized) => Ok(already_normalized),
    }
}

/// Validates a raw proof response and converts it into a [`NormalizedMerkleProof`].
///
/// The original response is kept in the `raw` field of the result.
///
/// # Errors
///
/// * When `raw.success` is false: the error carries `raw.error`, else
///   `raw.message`, else the text "Missing merkle proof".
/// * "Invalid merkle proof structure" when `hash_item` cannot be normalized
///   or `data_type` is empty.
/// * "Invalid proof hash" when any entry of `proof` cannot be normalized.
fn normalize_merkle_proof_response(raw: MerkleProofResponse) -> Result<NormalizedMerkleProof> {
    if !raw.success {
        // `or_else` keeps the fallback lazy: `raw.message` is only cloned when
        // `raw.error` is absent.
        let reason = raw
            .error
            .clone()
            .or_else(|| raw.message.clone())
            .unwrap_or_else(|| "Missing merkle proof".to_string());
        return Err(ProvableError::new(reason));
    }

    let hash_item = normalize_hex_string(&raw.hash_item)
        .ok_or_else(|| ProvableError::new("Invalid merkle proof structure"))?;
    // Unlike `hash_item`, a root that fails normalization is not an error here;
    // it falls back to the default value.
    let root = normalize_hex_string(&raw.root).unwrap_or_default();
    if raw.data_type.is_empty() {
        return Err(ProvableError::new("Invalid merkle proof structure"));
    }

    // Stops at the first hash that fails normalization.
    let proof = raw
        .proof
        .iter()
        .map(|hash| {
            normalize_hex_string(hash).ok_or_else(|| ProvableError::new("Invalid proof hash"))
        })
        .collect::<Result<Vec<_>>>()?;

    // The clones are required because `raw` itself is moved into the result.
    Ok(NormalizedMerkleProof {
        data_type: raw.data_type.clone(),
        hash_item,
        proof,
        root,
        position: raw.position,
        levels: raw.levels,
        level_counts: raw.level_counts.clone(),
        level_starts: raw.level_starts.clone(),
        raw,
    })
}

pub fn get_merkle_proof_levels(
    input: impl Into<MerkleProofInput>,
) -> Result<Vec<MerkleProofLevel>> {
    let proof = normalize_merkle_proof(input)?;
    let level_counts =
        normalize_level_counts(&proof.level_counts, proof.levels, proof.proof.len())?;
    let mut levels = Vec::with_capacity(level_counts.len());
    let mut offset = 0usize;
    for (level, count) in level_counts.iter().enumerate() {
        let start = proof
            .level_starts
            .get(level)
            .copied()
            .unwrap_or_else(|| default_level_start(proof.position, level, *count));
        let hashes = proof.proof[offset..offset + count].to_vec();
        levels.push(MerkleProofLevel {
            level,
            start,
            count: *count,
            hashes,
        });
        offset += count;
    }
    Ok(levels)
}

/// Compares two proofs and reports every difference found between them.
///
/// Both inputs are normalized and split into levels. The check then records a
/// mismatch when:
///
/// * `data_type`, `hash_item` or `position` differ between the proofs;
/// * a level of the previous proof does not exist in the next proof
///   (`missing_level`);
/// * a position of the previous proof is absent from the same level of the
///   next proof (`missing_position`);
/// * the hashes stored at the same level and position differ
///   (`hash_mismatch`).
///
/// `checked_entries` counts the previous-proof hashes whose position was also
/// present in the next proof. The result is `compatible` only when no mismatch
/// was recorded.
///
/// # Errors
///
/// Propagates any error from normalizing either input or computing its levels.
pub fn check_merkle_proof_compatibility(
    previous_input: impl Into<MerkleProofInput>,
    next_input: impl Into<MerkleProofInput>,
) -> Result<MerkleProofCompatibilityResult> {
    let previous = normalize_merkle_proof(previous_input)?;
    let next = normalize_merkle_proof(next_input)?;
    let previous_levels = get_merkle_proof_levels(previous.clone())?;
    let next_levels = get_merkle_proof_levels(next.clone())?;
    let mut mismatches = Vec::new();

    // Proof-wide fields first, in a fixed order: data_type, hash_item, position.
    if previous.data_type != next.data_type {
        let message = format!(
            "data_type mismatch previous={} next={}",
            previous.data_type, next.data_type
        );
        mismatches.push(MerkleProofCompatibilityMismatch {
            kind: "data_type".to_string(),
            message,
            ..Default::default()
        });
    }
    if previous.hash_item != next.hash_item {
        let message = format!(
            "hash_item mismatch previous={} next={}",
            previous.hash_item, next.hash_item
        );
        mismatches.push(MerkleProofCompatibilityMismatch {
            kind: "hash_item".to_string(),
            previous_hash: Some(previous.hash_item.clone()),
            next_hash: Some(next.hash_item.clone()),
            message,
            ..Default::default()
        });
    }
    if previous.position != next.position {
        let message = format!(
            "position mismatch previous={} next={}",
            previous.position, next.position
        );
        mismatches.push(MerkleProofCompatibilityMismatch {
            kind: "position".to_string(),
            previous_position: Some(previous.position),
            next_position: Some(next.position),
            message,
            ..Default::default()
        });
    }

    // Index the next proof as level -> (position -> (hash, index within level)).
    let mut next_by_level = HashMap::with_capacity(next_levels.len());
    for candidate in &next_levels {
        let mut by_position = HashMap::with_capacity(candidate.hashes.len());
        for (index, hash) in candidate.hashes.iter().enumerate() {
            by_position.insert(candidate.start + index as i64, (hash.clone(), index));
        }
        next_by_level.insert(candidate.level, by_position);
    }

    let mut checked_entries = 0usize;
    for earlier in &previous_levels {
        let level_number = earlier.level;
        let next_entries = match next_by_level.get(&level_number) {
            Some(entries) => entries,
            None => {
                mismatches.push(MerkleProofCompatibilityMismatch {
                    kind: "missing_level".to_string(),
                    level: Some(level_number),
                    message: format!("missing level={} in new proof", level_number),
                    ..Default::default()
                });
                continue;
            }
        };

        for (previous_index, previous_hash) in earlier.hashes.iter().enumerate() {
            let position = earlier.start + previous_index as i64;
            match next_entries.get(&position) {
                None => {
                    mismatches.push(MerkleProofCompatibilityMismatch {
                        kind: "missing_position".to_string(),
                        level: Some(level_number),
                        position: Some(position),
                        previous_index: Some(previous_index),
                        previous_hash: Some(previous_hash.clone()),
                        message: format!(
                            "missing level={} position={} in new proof",
                            level_number, position
                        ),
                        ..Default::default()
                    });
                }
                Some((next_hash, next_index)) => {
                    // Only positions present in both proofs count as checked.
                    checked_entries += 1;
                    if next_hash != previous_hash {
                        mismatches.push(MerkleProofCompatibilityMismatch {
                            kind: "hash_mismatch".to_string(),
                            level: Some(level_number),
                            position: Some(position),
                            previous_index: Some(previous_index),
                            next_index: Some(*next_index),
                            previous_hash: Some(previous_hash.clone()),
                            next_hash: Some(next_hash.clone()),
                            message: format!(
                                "hash mismatch level={} position={} previous={} next={}",
                                level_number, position, previous_hash, next_hash
                            ),
                            ..Default::default()
                        });
                    }
                }
            }
        }
    }

    let compatible = mismatches.is_empty();
    Ok(MerkleProofCompatibilityResult {
        compatible,
        checked_entries,
        previous,
        next,
        previous_levels,
        next_levels,
        mismatches,
    })
}

/// Computes the fallback start position of a level's hash run.
///
/// The proof position is first mapped to the given level, then rounded toward
/// zero to a multiple of `count`. When `count` is zero the mapped position is
/// returned as-is, which avoids a division by zero.
fn default_level_start(position: i64, level: usize, count: usize) -> i64 {
    let at_level = position_at_level(position, level);
    match count {
        0 => at_level,
        _ => {
            let width = count as i64;
            (at_level / width) * width
        }
    }
}

/// Maps a position to the given level by dividing it by 256 once per level.
///
/// Division truncates toward zero, so the result is 0 once enough levels have
/// been applied (at most eight for any `i64`).
fn position_at_level(position: i64, level: usize) -> i64 {
    // Divisor applied per level; presumably the fan-out of the tree — confirm
    // against the proof producer.
    const STEP: i64 = 256;

    let mut current = position;
    for _ in 0..level {
        // Zero is a fixed point of the division, so stopping here gives the
        // same result while keeping the cost constant for very large `level`.
        if current == 0 {
            break;
        }
        current /= STEP;
    }
    current
}