use serde::{Deserialize, Serialize};
use serde_big_array::BigArray;
use crate::address::Address;
/// Domain-separation prefix placed in front of the bincode-encoded digest to
/// form the exact byte string the verifier signs (see [`signing_input_bytes`]).
/// Changing it invalidates every previously produced signature.
pub const DOMAIN_TAG: &str = "omninode.inference_attestation.v1";
/// Maximum length of [`InferenceAttestationDigest::session_id`], measured in
/// UTF-8 bytes (`str::len`), enforced by [`verify_attestation_signature`].
pub const MAX_SESSION_ID_BYTES: usize = 256;
/// The attested content that the inference verifier signs.
///
/// The bincode encoding of this struct (see [`canonical_digest_bytes`]) is the
/// body of the signed message, so field order and field types are part of the
/// signature format: reordering or retyping fields changes the encoded bytes
/// and breaks verification of existing signatures.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceAttestationDigest {
/// Session identifier; must be at most [`MAX_SESSION_ID_BYTES`] bytes, which
/// is checked by [`verify_attestation_signature`].
pub session_id: String,
/// 32-byte model hash included in the signed digest.
pub model_hash: [u8; 32],
/// 32-byte manifest root included in the signed digest.
pub manifest_root: [u8; 32],
/// 32-byte response hash included in the signed digest.
pub response_hash: [u8; 32],
/// 32-byte proof root included in the signed digest.
pub proof_root: [u8; 32],
}
/// Transaction payload carrying an attestation digest plus the verifier's
/// Ed25519 signature over `signing_input_bytes(&digest)`; checked by
/// [`verify_attestation_signature`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceAttestationTxData {
/// The signed attestation content.
pub digest: InferenceAttestationDigest,
// serde's built-in array impls only cover lengths up to 32, so the 64-byte
// signature goes through serde_big_array.
#[serde(with = "BigArray")]
/// Raw 64-byte Ed25519 signature produced by the verifier.
pub verifier_signature: [u8; 64],
}
/// Errors returned while encoding or verifying an inference attestation.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum AttestationError {
/// `session_id` exceeded [`MAX_SESSION_ID_BYTES`]; carries the actual length
/// in bytes.
#[error("session_id is {0} bytes; max is {MAX_SESSION_ID_BYTES}")]
SessionIdTooLong(usize),
/// Either the public key bytes could not be decoded into a verifying key, or
/// the signature failed strict Ed25519 verification. Both cases map here.
#[error("inner verifier signature is invalid for the supplied public key")]
InvalidSignature,
/// bincode serialization of the digest failed; carries the bincode error text.
#[error("bincode-serializing digest failed: {0}")]
Serialization(String),
}
/// Encodes `digest` with bincode's default configuration, producing the
/// canonical byte form that is signed (after the domain tag).
///
/// # Errors
///
/// Returns [`AttestationError::Serialization`] holding the rendered bincode
/// error if encoding fails.
pub fn canonical_digest_bytes(
    digest: &InferenceAttestationDigest,
) -> Result<Vec<u8>, AttestationError> {
    bincode::serialize(digest)
        .map_err(|err| err.to_string())
        .map_err(AttestationError::Serialization)
}
/// Builds the exact message the verifier signs: the bytes of [`DOMAIN_TAG`]
/// immediately followed by [`canonical_digest_bytes`] of `digest`, with no
/// separator or length prefix between them.
///
/// # Errors
///
/// Propagates [`AttestationError::Serialization`] from
/// [`canonical_digest_bytes`].
pub fn signing_input_bytes(
    digest: &InferenceAttestationDigest,
) -> Result<Vec<u8>, AttestationError> {
    canonical_digest_bytes(digest).map(|body| [DOMAIN_TAG.as_bytes(), body.as_slice()].concat())
}
/// Checks that `tx_data.verifier_signature` is a valid Ed25519 signature by
/// `verifier_public_key` over [`signing_input_bytes`] of `tx_data.digest`.
///
/// Verification uses ed25519-dalek's `verify_strict`.
///
/// # Errors
///
/// * [`AttestationError::SessionIdTooLong`] if `session_id` is longer than
///   [`MAX_SESSION_ID_BYTES`] bytes. This is checked before any encoding work.
/// * [`AttestationError::Serialization`] if the digest cannot be encoded.
/// * [`AttestationError::InvalidSignature`] if the public key bytes do not
///   decode to a verifying key, or if the signature does not verify.
pub fn verify_attestation_signature(
    tx_data: &InferenceAttestationTxData,
    verifier_public_key: &[u8; 32],
) -> Result<(), AttestationError> {
    let session_len = tx_data.digest.session_id.len();
    if session_len > MAX_SESSION_ID_BYTES {
        return Err(AttestationError::SessionIdTooLong(session_len));
    }

    let message = signing_input_bytes(&tx_data.digest)?;
    let signature = ed25519_dalek::Signature::from_bytes(&tx_data.verifier_signature);

    // Key-decoding and verification failures are deliberately collapsed into
    // the same error variant.
    ed25519_dalek::VerifyingKey::from_bytes(verifier_public_key)
        .and_then(|key| key.verify_strict(&message, &signature))
        .map_err(|_| AttestationError::InvalidSignature)
}
/// Derives the [`Address`] for a verifier's 32-byte public key by delegating
/// to `Address::from_public_key`.
pub fn verifier_address(public_key: &[u8; 32]) -> Address {
Address::from_public_key(public_key)
}
/// Domain-separation prefix hashed ahead of the bincode-encoded
/// `(session_id, verifier_address)` pair in [`inference_attestation_key`].
pub const INFERENCE_ATTESTATION_KEY_DOMAIN: &[u8] = b"InferenceAttestationKeyV1";
/// Computes the 32-byte key for an attestation identified by
/// `(session_id, verifier_address)`.
///
/// The key is `BLAKE3(INFERENCE_ATTESTATION_KEY_DOMAIN || bincode((session_id, verifier_address)))`.
///
/// # Panics
///
/// Panics only if bincode fails to serialize the pair, which the `expect`
/// message asserts cannot happen.
pub fn inference_attestation_key(
    session_id: &str,
    verifier_address: &Address,
) -> [u8; 32] {
    let encoded_pair = bincode::serialize(&(session_id, verifier_address))
        .expect("bincode of (String, Address) cannot fail");
    let digest = blake3::Hasher::new()
        .update(INFERENCE_ATTESTATION_KEY_DOMAIN)
        .update(&encoded_pair)
        .finalize();
    *digest.as_bytes()
}
/// Domain-separation prefix hashed ahead of the raw `session_id` bytes in
/// [`session_index_prefix`].
pub const INFERENCE_ATTESTATION_SESSION_INDEX_DOMAIN: &[u8] =
b"InferenceAttestationSessionIndexV1";
/// Number of leading BLAKE3 output bytes kept as the session-id part of a
/// session index key (see [`session_index_prefix`] and [`session_index_key`]).
pub const SESSION_ID_HASH_BYTES: usize = 16;
/// Builds the 36-byte session index key for `(session_id, verifier_address)`.
///
/// Layout: the first [`SESSION_ID_HASH_BYTES`] bytes are
/// [`session_index_prefix`]`(session_id)`, and the remaining bytes are
/// `verifier_address.as_bytes()`. All keys for one session therefore share a
/// common prefix. This is presumably meant for prefix scans by session;
/// confirm against the storage code.
///
/// # Panics
///
/// Panics if `verifier_address.as_bytes()` is not exactly
/// `36 - SESSION_ID_HASH_BYTES` (20) bytes long, because `copy_from_slice`
/// requires equal lengths.
pub fn session_index_key(
    session_id: &str,
    verifier_address: &Address,
) -> [u8; 36] {
    let mut key = [0u8; 36];
    let (hash_part, address_part) = key.split_at_mut(SESSION_ID_HASH_BYTES);
    hash_part.copy_from_slice(&session_index_prefix(session_id));
    address_part.copy_from_slice(verifier_address.as_bytes());
    key
}
/// Returns the first [`SESSION_ID_HASH_BYTES`] bytes of
/// `BLAKE3(INFERENCE_ATTESTATION_SESSION_INDEX_DOMAIN || session_id)`.
///
/// The session id is hashed as its raw UTF-8 bytes, not bincode-encoded.
pub fn session_index_prefix(session_id: &str) -> [u8; SESSION_ID_HASH_BYTES] {
    let digest = blake3::Hasher::new()
        .update(INFERENCE_ATTESTATION_SESSION_INDEX_DOMAIN)
        .update(session_id.as_bytes())
        .finalize();
    let full = digest.as_bytes();
    // The 32-byte BLAKE3 output is truncated to its leading bytes.
    std::array::from_fn(|i| full[i])
}
/// Status report produced by [`classify_inference_attestation_status`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceAttestationStatusInfo {
/// One of `"unknown"`, `"submitted"`, `"included"`, `"finalized"` or `"failed"`.
pub status: String,
/// Block height from the receipt, set whenever a receipt was used.
pub included_at_height: Option<u64>,
/// Failure description from the receipt status; set only for `"failed"`.
pub reason: Option<String>,
}
/// Classifies an inference attestation from the lookups the caller already
/// performed (stored transaction, mempool entry, receipt).
///
/// * `"unknown"`: neither `stored_tx` nor `mempool_tx` is a V2
///   `InferenceAttestation` transaction. Also returned when one of them is,
///   but there is no receipt and `mempool_tx` is `None`.
/// * `"failed"`: a receipt exists and its status is not `TxStatus::Success`.
///   `reason` carries the status description.
/// * `"finalized"`: a successful receipt where
///   `current_height >= block_height + finality_depth` (saturating add).
/// * `"included"`: a successful receipt that has not reached that depth.
/// * `"submitted"`: no receipt, but `mempool_tx` is present.
///
/// `included_at_height` is the receipt's `block_height` whenever a receipt is
/// used, and `None` otherwise.
///
/// NOTE(review): the `"submitted"` arm checks only `mempool_tx.is_some()`, not
/// whether that mempool entry is itself an attestation. This is presumably
/// fine because both lookups use the same tx hash; confirm at the call sites.
pub fn classify_inference_attestation_status(
    stored_tx: Option<&crate::SignedTransaction>,
    mempool_tx: Option<&crate::SignedTransaction>,
    receipt: Option<&crate::Receipt>,
    current_height: u64,
    finality_depth: u64,
) -> InferenceAttestationStatusInfo {
    use crate::{TxInner, TxPayload, TxStatus};

    fn status_info(
        status: &str,
        included_at_height: Option<u64>,
        reason: Option<String>,
    ) -> InferenceAttestationStatusInfo {
        InferenceAttestationStatusInfo {
            status: status.to_string(),
            included_at_height,
            reason,
        }
    }

    let carries_attestation = |signed: &crate::SignedTransaction| -> bool {
        matches!(&signed.inner, TxInner::V2(v2)
            if matches!(&v2.payload, TxPayload::InferenceAttestation(_)))
    };

    // Stored copy is consulted first, then the mempool copy (short-circuits).
    let known_attestation = stored_tx
        .into_iter()
        .chain(mempool_tx)
        .any(carries_attestation);
    if !known_attestation {
        return status_info("unknown", None, None);
    }

    match receipt {
        Some(r) if !matches!(r.status, TxStatus::Success) => status_info(
            "failed",
            Some(r.block_height),
            Some(r.status.description().to_string()),
        ),
        Some(r) => {
            let finality_height = r.block_height.saturating_add(finality_depth);
            let label = if current_height >= finality_height {
                "finalized"
            } else {
                "included"
            };
            status_info(label, Some(r.block_height), None)
        }
        None if mempool_tx.is_some() => status_info("submitted", None, None),
        None => status_info("unknown", None, None),
    }
}
/// An attestation digest and signature together with inclusion metadata.
///
/// NOTE(review): this appears to be the stored form of an included attestation,
/// i.e. the fields of [`InferenceAttestationTxData`] plus height and tx hash.
/// Confirm against the code that constructs it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InferenceAttestationRecord {
/// The attested content.
pub digest: InferenceAttestationDigest,
// serde's built-in array impls only cover lengths up to 32, so the 64-byte
// signature goes through serde_big_array.
#[serde(with = "BigArray")]
/// Raw 64-byte verifier signature.
pub verifier_signature: [u8; 64],
/// Block height at which the attestation was included.
pub included_at_height: u64,
/// Hash of the transaction that carried the attestation.
pub tx_hash: crate::Hash,
}