use std::str::FromStr;
use rialo_s_sdk::pubkey::Pubkey;
use thiserror::Error;
use validator::ValidationErrors;
use crate::constants::*;
/// Error type returned by the request-validation helpers in this module.
///
/// Each variant carries a detail string that is interpolated into its `Display`
/// output; the display text comes from the `#[error(...)]` attributes via `thiserror`.
#[derive(Debug, Error, Clone)]
pub enum ValidationError {
/// Displayed as `Invalid format: <detail>`. The `From<ValidationErrors>` conversion in
/// this module produces this variant when exactly one field error was collected.
#[error("Invalid format: {0}")]
InvalidFormat(String),
/// Displayed as `Value out of range: <detail>`.
/// NOTE(review): not constructed in this file; presumably produced by callers.
#[error("Value out of range: {0}")]
OutOfRange(String),
/// Displayed as `Missing required field: <detail>`.
/// NOTE(review): not constructed in this file; presumably produced by callers.
#[error("Missing required field: {0}")]
MissingField(String),
/// Displayed as `Invalid signature: <detail>`.
/// NOTE(review): not constructed in this file; presumably produced by callers.
#[error("Invalid signature: {0}")]
InvalidSignature(String),
/// Displayed as `Invalid encoding: <detail>. Supported encodings: base64, base58`.
/// NOTE(review): the message lists only base64/base58, while `validate_encoding` in this
/// file also accepts "json" and "jsonParsed" — confirm which list is intended.
#[error("Invalid encoding: {0}. Supported encodings: base64, base58")]
InvalidEncoding(String),
/// Displayed as `Invalid public key: <detail>`.
/// NOTE(review): not constructed in this file; presumably produced by callers.
#[error("Invalid public key: {0}")]
InvalidPublicKey(String),
/// Displayed as `Invalid transaction: <detail>`.
/// NOTE(review): not constructed in this file; presumably produced by callers.
#[error("Invalid transaction: {0}")]
InvalidTransaction(String),
/// Displayed as `Multiple validation errors: <detail>`. The `From<ValidationErrors>`
/// conversion uses it both for several joined messages and, with the text
/// "Unknown validation error", when no messages could be collected.
#[error("Multiple validation errors: {0}")]
Multiple(String),
}
impl From<ValidationErrors> for ValidationError {
fn from(errors: ValidationErrors) -> Self {
let mut error_messages: Vec<String> = Vec::new();
for (field, field_errors) in errors.field_errors() {
for error in field_errors {
let message = format!(
"{}: {}",
field,
error
.message
.as_ref()
.unwrap_or(&"validation failed".into())
);
error_messages.push(message);
}
}
for (field, struct_errors) in errors.errors() {
if let validator::ValidationErrorsKind::Struct(nested_errors) = struct_errors {
for (nested_field, nested_field_errors) in nested_errors.field_errors() {
for error in nested_field_errors {
let message = format!(
"{}.{}: {}",
field,
nested_field,
error
.message
.as_ref()
.unwrap_or(&"validation failed".into())
);
error_messages.push(message);
}
}
}
}
if error_messages.is_empty() {
ValidationError::Multiple("Unknown validation error".to_string())
} else if error_messages.len() == 1 {
ValidationError::InvalidFormat(error_messages[0].clone())
} else {
ValidationError::Multiple(error_messages.join(", "))
}
}
}
/// Convenience alias for results whose error type is this module's [`ValidationError`].
pub type ValidationResult<T> = std::result::Result<T, ValidationError>;
/// Checks that `pubkey` parses as a [`Pubkey`] via its `FromStr` implementation.
///
/// # Errors
/// Returns code `invalid_pubkey` when parsing fails; the parse error itself is discarded.
pub fn validate_pubkey(pubkey: &str) -> Result<(), validator::ValidationError> {
    let parsed = Pubkey::from_str(pubkey);
    if parsed.is_err() {
        Err(validator::ValidationError::new("invalid_pubkey"))
    } else {
        Ok(())
    }
}
/// Checks that `data` decodes with `fastcrypto`'s `Base64` codec; the decoded bytes are discarded.
///
/// # Errors
/// Returns code `invalid_base64` when decoding fails.
pub fn validate_base64(data: &str) -> Result<(), validator::ValidationError> {
    use fastcrypto::encoding::{Base64, Encoding};
    match Base64::decode(data) {
        Ok(_decoded) => Ok(()),
        Err(_) => Err(validator::ValidationError::new("invalid_base64")),
    }
}
/// Checks that `data` decodes with `fastcrypto`'s `Base58` codec; the decoded bytes are discarded.
///
/// # Errors
/// Returns code `invalid_base58` when decoding fails.
pub fn validate_base58(data: &str) -> Result<(), validator::ValidationError> {
    use fastcrypto::encoding::{Base58, Encoding};
    match Base58::decode(data) {
        Ok(_decoded) => Ok(()),
        Err(_) => Err(validator::ValidationError::new("invalid_base58")),
    }
}
pub fn validate_signature(signature: &str) -> Result<(), validator::ValidationError> {
if signature.len() > MAX_SIGNATURE_LENGTH {
return Err(validator::ValidationError::new("invalid_signature_length"));
}
validate_base58(signature)
}
/// Checks that `nonce` is non-empty and at most `MAX_NONCE_LENGTH` bytes long.
///
/// # Errors
/// - `empty_nonce` for an empty string (checked first);
/// - `nonce_too_long` when the byte length exceeds `MAX_NONCE_LENGTH`.
pub fn validate_nonce(nonce: &str) -> Result<(), validator::ValidationError> {
    match nonce.len() {
        0 => Err(validator::ValidationError::new("empty_nonce")),
        len if len > MAX_NONCE_LENGTH => Err(validator::ValidationError::new("nonce_too_long")),
        _ => Ok(()),
    }
}
/// Checks that `kelvins` does not exceed `MAX_KELVINS` (the bound itself is allowed).
///
/// # Errors
/// Returns code `kelvins_too_large` when `kelvins > MAX_KELVINS`.
pub fn validate_kelvins(kelvins: u64) -> Result<(), validator::ValidationError> {
    if kelvins <= MAX_KELVINS {
        Ok(())
    } else {
        Err(validator::ValidationError::new("kelvins_too_large"))
    }
}
/// Checks that a pagination `limit` lies in `1..=MAX_PAGINATION_LIMIT`.
///
/// Takes `&u64` so it can be used directly as a `validator` custom-function target.
///
/// # Errors
/// - `limit_zero` when `limit` is 0;
/// - `limit_too_large` when `limit` exceeds `MAX_PAGINATION_LIMIT`.
pub fn validate_limit(limit: &u64) -> Result<(), validator::ValidationError> {
    match *limit {
        0 => Err(validator::ValidationError::new("limit_zero")),
        value if value > MAX_PAGINATION_LIMIT => {
            Err(validator::ValidationError::new("limit_too_large"))
        }
        _ => Ok(()),
    }
}
pub fn validate_pubkey_array(pubkeys: &[String]) -> Result<(), validator::ValidationError> {
for pubkey in pubkeys {
validate_pubkey(pubkey)?;
}
Ok(())
}
pub fn validate_signatures_array(signatures: &[String]) -> Result<(), validator::ValidationError> {
for signature in signatures {
validate_signature(signature)?;
}
Ok(())
}
/// Validates an airdrop amount expressed in kelvins.
///
/// The general [`validate_kelvins`] bound is checked first. The amount must then not
/// exceed `MAX_AIRDROP_AMOUNT` and must be non-zero.
///
/// # Errors
/// - any error from [`validate_kelvins`];
/// - `airdrop_amount_too_large` when `kelvins > MAX_AIRDROP_AMOUNT`;
/// - `airdrop_amount_zero` when `kelvins == 0`.
pub fn validate_airdrop_amount(kelvins: u64) -> Result<(), validator::ValidationError> {
    validate_kelvins(kelvins)?;
    match kelvins {
        amount if amount > MAX_AIRDROP_AMOUNT => {
            Err(validator::ValidationError::new("airdrop_amount_too_large"))
        }
        0 => Err(validator::ValidationError::new("airdrop_amount_zero")),
        _ => Ok(()),
    }
}
pub fn validate_airdrop_amount_i64(kelvins: i64) -> Result<(), validator::ValidationError> {
if kelvins < 0 {
return Err(validator::ValidationError::new("airdrop_amount_negative"));
}
if kelvins == 0 {
return Err(validator::ValidationError::new("airdrop_amount_zero"));
}
let kelvins_u64 = kelvins as u64;
validate_kelvins(kelvins_u64)?;
if kelvins_u64 > MAX_AIRDROP_AMOUNT {
return Err(validator::ValidationError::new("airdrop_amount_too_large"));
}
Ok(())
}
/// Validates a `u16` page-size `limit`; it must lie in `1..=MAX_PAGINATION_LIMIT`.
///
/// # Errors
/// - `limit_must_be_positive` when `limit` is 0;
/// - `limit_exceeds_maximum` when `limit` is greater than `MAX_PAGINATION_LIMIT`.
pub fn validate_signature_limit(limit: &u16) -> Result<(), validator::ValidationError> {
    if *limit == 0 {
        return Err(validator::ValidationError::new("limit_must_be_positive"));
    }
    // Widen the u16 rather than narrowing the u64 constant. `MAX_PAGINATION_LIMIT as u16`
    // silently wraps modulo 65536 whenever the constant exceeds `u16::MAX`, which would
    // reject valid limits (or every limit, if the constant were a multiple of 65536).
    if u64::from(*limit) > MAX_PAGINATION_LIMIT {
        return Err(validator::ValidationError::new("limit_exceeds_maximum"));
    }
    Ok(())
}
pub fn validate_transaction_data(transaction: &str) -> Result<(), validator::ValidationError> {
if transaction.is_empty() {
return Ok(());
}
if validate_base64(transaction).is_ok() {
return validate_transaction_structure_base64(transaction);
}
if validate_base58(transaction).is_ok() {
return validate_transaction_structure_base58(transaction);
}
Err(validator::ValidationError::new(
"invalid_transaction_encoding",
))
}
/// Decodes `transaction` as base64 and checks the resulting bytes.
///
/// # Errors
/// - `invalid_base64_transaction` if base64 decoding fails;
/// - any error from `validate_transaction_bytes` otherwise.
fn validate_transaction_structure_base64(
    transaction: &str,
) -> Result<(), validator::ValidationError> {
    use fastcrypto::encoding::{Base64, Encoding};
    match Base64::decode(transaction) {
        Ok(bytes) => validate_transaction_bytes(&bytes),
        Err(_) => Err(validator::ValidationError::new("invalid_base64_transaction")),
    }
}
/// Decodes `transaction` as base58 and checks the resulting bytes.
///
/// # Errors
/// - `invalid_base58_transaction` if base58 decoding fails;
/// - any error from `validate_transaction_bytes` otherwise.
fn validate_transaction_structure_base58(
    transaction: &str,
) -> Result<(), validator::ValidationError> {
    use fastcrypto::encoding::{Base58, Encoding};
    match Base58::decode(transaction) {
        Ok(bytes) => validate_transaction_bytes(&bytes),
        Err(_) => Err(validator::ValidationError::new("invalid_base58_transaction")),
    }
}
/// Checks decoded transaction bytes for size and structure.
///
/// 1. The byte length must lie in `MIN_TRANSACTION_SIZE..=MAX_TRANSACTION_SIZE`.
/// 2. The bytes must bincode-deserialize as a `VersionedTransaction`. If that fails, a
///    legacy `Transaction` is tried; the second attempt runs only when the first fails.
///
/// # Errors
/// - `transaction_too_small` / `transaction_too_large` for out-of-range lengths;
/// - `invalid_transaction_structure` if neither deserialization succeeds.
fn validate_transaction_bytes(transaction_bytes: &[u8]) -> Result<(), validator::ValidationError> {
    let size = transaction_bytes.len();
    if size < MIN_TRANSACTION_SIZE {
        return Err(validator::ValidationError::new("transaction_too_small"));
    }
    if size > MAX_TRANSACTION_SIZE {
        return Err(validator::ValidationError::new("transaction_too_large"));
    }

    let is_versioned =
        bincode::deserialize::<rialo_s_sdk::transaction::VersionedTransaction>(transaction_bytes)
            .is_ok();
    // `||` short-circuits, so the legacy decode only happens when the versioned one failed.
    let is_well_formed = is_versioned
        || bincode::deserialize::<rialo_s_sdk::transaction::Transaction>(transaction_bytes)
            .is_ok();

    if is_well_formed {
        Ok(())
    } else {
        Err(validator::ValidationError::new(
            "invalid_transaction_structure",
        ))
    }
}
/// Parses `limit` as a `u64` and then applies [`validate_limit`].
///
/// # Errors
/// - `invalid_limit_format` if `limit` is not a valid `u64` literal;
/// - any error from [`validate_limit`] otherwise.
pub fn validate_limit_string(limit: &str) -> Result<(), validator::ValidationError> {
    match limit.parse::<u64>() {
        Ok(parsed) => validate_limit(&parsed),
        Err(_) => Err(validator::ValidationError::new("invalid_limit_format")),
    }
}
/// Checks a blockhash string's length and base58 encoding.
///
/// The string length must lie in `MIN_BLOCKHASH_LENGTH..=MAX_BLOCKHASH_LENGTH`; the
/// contents are then passed to [`validate_base58`].
///
/// # Errors
/// - `invalid_blockhash_length` when the length is out of range;
/// - any error from [`validate_base58`] otherwise.
pub fn validate_blockhash(blockhash: &str) -> Result<(), validator::ValidationError> {
    let length_ok = (MIN_BLOCKHASH_LENGTH..=MAX_BLOCKHASH_LENGTH).contains(&blockhash.len());
    if !length_ok {
        return Err(validator::ValidationError::new("invalid_blockhash_length"));
    }
    validate_base58(blockhash)
}
pub fn validate_addresses(addresses: &[String]) -> Result<(), validator::ValidationError> {
for address in addresses {
validate_pubkey(address)?;
}
Ok(())
}
pub fn validate_signatures(signatures: &[String]) -> Result<(), validator::ValidationError> {
for signature in signatures {
validate_signature(signature)?;
}
Ok(())
}
/// Accepts only the encoding names `json`, `jsonParsed`, `base58` and `base64` (case-sensitive).
///
/// # Errors
/// Returns code `invalid_encoding_format` for any other string.
pub fn validate_encoding(encoding: &str) -> Result<(), validator::ValidationError> {
    let supported = matches!(encoding, "json" | "jsonParsed" | "base58" | "base64");
    if supported {
        Ok(())
    } else {
        Err(validator::ValidationError::new("invalid_encoding_format"))
    }
}
/// Accepts a maximum transaction version of 0 or 1 only.
///
/// # Errors
/// Returns code `invalid_max_transaction_version` for any value above 1.
pub fn validate_max_transaction_version(version: &u8) -> Result<(), validator::ValidationError> {
    match *version {
        0 | 1 => Ok(()),
        _ => Err(validator::ValidationError::new(
            "invalid_max_transaction_version",
        )),
    }
}
/// Runs `validator`'s derived checks on `request` and hands the request back if they pass.
///
/// # Errors
/// Converts any `ValidationErrors` into this module's [`ValidationError`] via its `From` impl.
pub fn validate_request<T>(request: T) -> ValidationResult<T>
where
    T: validator::Validate,
{
    match request.validate() {
        Ok(()) => Ok(request),
        Err(errors) => Err(errors.into()),
    }
}
// Unit tests for the individual validators and for the
// `ValidationErrors` -> `ValidationError` conversion.
#[cfg(test)]
mod tests {
use super::*;
/// Boundary checks: 1 and MAX_PAGINATION_LIMIT are accepted; 0 and MAX + 1 are rejected.
#[test]
fn test_validate_limit() {
assert!(validate_limit(&1).is_ok());
assert!(validate_limit(&MAX_PAGINATION_LIMIT).is_ok());
assert!(validate_limit(&0).is_err());
assert!(validate_limit(&(MAX_PAGINATION_LIMIT + 1)).is_err());
}
/// A normal nonce passes; empty and over-long nonces fail.
#[test]
fn test_validate_nonce() {
assert!(validate_nonce("valid_nonce").is_ok());
assert!(validate_nonce("").is_err());
// Assumes MAX_NONCE_LENGTH < 65 — TODO confirm against the constants module.
let long_nonce = "x".repeat(65);
assert!(validate_nonce(&long_nonce).is_err());
}
/// Every supported encoding name is accepted; an unknown name is rejected.
#[test]
fn test_validate_encoding() {
assert!(validate_encoding("json").is_ok());
assert!(validate_encoding("jsonParsed").is_ok());
assert!(validate_encoding("base58").is_ok());
assert!(validate_encoding("base64").is_ok());
assert!(validate_encoding("invalid").is_err());
}
/// Signature validation: a typical signature, signatures whose base58 form is shortened by
/// leading zero bytes, a non-base58 string, and an over-long string.
#[test]
fn test_validate_signature() {
use fastcrypto::encoding::{Base58, Encoding};
let valid_sig =
"5VERv8NMvzbJMEkV8xnrLkEaWRtSz9CosKDYjCJjBRnbJLgp8uirBgmQpjKhoR4tjF3ZpRzrFmBV6UjKdiSZkQUW";
assert!(validate_signature(valid_sig).is_ok());
// 64 bytes with only the last byte set: base58 emits one '1' per leading zero byte,
// so the encoded form is much shorter than a typical signature string.
let mut sig_with_zeros = [0u8; 64];
sig_with_zeros[63] = 1; let short_sig = Base58::encode(sig_with_zeros);
assert!(
short_sig.len() < 87,
"Expected short signature due to leading zeros, got len {}",
short_sig.len()
);
assert!(
validate_signature(&short_sig).is_ok(),
"Signature with leading zeros should be valid"
);
// All 64 bytes zero: encodes to 64 '1' characters.
let all_zeros = [0u8; 64];
let all_zeros_sig = Base58::encode(all_zeros);
assert_eq!(all_zeros_sig.len(), 64);
assert!(
validate_signature(&all_zeros_sig).is_ok(),
"All-zeros signature should be valid"
);
// '!' is not in the base58 alphabet.
assert!(validate_signature("invalid!signature").is_err());
// Presumably longer than MAX_SIGNATURE_LENGTH — confirm against the constants module.
let too_long = "1".repeat(100);
assert!(validate_signature(&too_long).is_err());
}
/// Versions 0 and 1 are accepted; 2 is rejected.
#[test]
fn test_validate_max_transaction_version() {
assert!(validate_max_transaction_version(&0).is_ok());
assert!(validate_max_transaction_version(&1).is_ok());
assert!(validate_max_transaction_version(&2).is_err());
}
/// Two failing top-level fields must convert into `ValidationError::Multiple` whose text
/// names both fields and includes both custom messages.
#[test]
fn test_validation_error_from_field_errors() {
use validator::Validate;
#[derive(Validate)]
struct TestStruct {
#[validate(length(min = 1, message = "Field cannot be empty"))]
field1: String,
#[validate(range(min = 0, max = 100, message = "Must be between 0 and 100"))]
field2: usize,
}
// Both fields are invalid: field1 is empty and field2 exceeds 100.
let test = TestStruct {
field1: "".to_string(), field2: 150, };
let validation_result = test.validate();
assert!(validation_result.is_err());
let validation_errors = validation_result.unwrap_err();
let error: ValidationError = validation_errors.into();
match error {
ValidationError::Multiple(msg) => {
assert!(msg.contains("field1"));
assert!(msg.contains("Field cannot be empty"));
assert!(msg.contains("field2"));
assert!(msg.contains("Must be between 0 and 100"));
}
other => panic!("Expected ValidationError::Multiple, got {:?}", other),
}
}
/// An error inside a `#[validate(nested)]` struct must be reported with a dotted path
/// (`config.max_retries`) together with its custom message.
#[test]
fn test_validation_error_from_nested_struct_errors() {
use validator::Validate;
#[derive(Validate)]
struct NestedConfig {
#[validate(range(
min = 0,
max = 100,
message = "Max retries must be between 0 and 100"
))]
max_retries: usize,
#[validate(range(min = 0, message = "Min slot must be non-negative"))]
min_slot: u64,
}
#[derive(Validate)]
struct ParentStruct {
#[validate(length(min = 1, message = "Name cannot be empty"))]
name: String,
#[validate(nested)]
config: NestedConfig,
}
// Only the nested max_retries (150 > 100) is invalid; the parent's own field is fine.
let test = ParentStruct {
name: "valid".to_string(),
config: NestedConfig {
max_retries: 150, min_slot: 0,
},
};
let validation_result = test.validate();
assert!(validation_result.is_err());
let validation_errors = validation_result.unwrap_err();
let error: ValidationError = validation_errors.into();
let error_msg = error.to_string();
assert!(
error_msg.contains("config.max_retries"),
"Expected 'config.max_retries' in error message, got: {}",
error_msg
);
assert!(
error_msg.contains("Max retries must be between 0 and 100"),
"Expected validation message in error, got: {}",
error_msg
);
}
/// A top-level field error and a nested-struct error must both appear in the converted
/// message, with the nested one under its dotted path.
#[test]
fn test_validation_error_from_mixed_errors() {
use validator::Validate;
#[derive(Validate)]
struct NestedConfig {
#[validate(range(min = 0, max = 100, message = "Nested field must be 0-100"))]
nested_field: usize,
}
#[derive(Validate)]
struct ParentStruct {
#[validate(length(min = 1, message = "Parent field cannot be empty"))]
parent_field: String,
#[validate(nested)]
config: NestedConfig,
}
// parent_field is empty and nested_field exceeds 100.
let test = ParentStruct {
parent_field: "".to_string(), config: NestedConfig {
nested_field: 150, },
};
let validation_result = test.validate();
assert!(validation_result.is_err());
let validation_errors = validation_result.unwrap_err();
let error: ValidationError = validation_errors.into();
let error_msg = error.to_string();
assert!(
error_msg.contains("parent_field")
&& error_msg.contains("Parent field cannot be empty"),
"Expected parent field error in message, got: {}",
error_msg
);
assert!(
error_msg.contains("config.nested_field")
&& error_msg.contains("Nested field must be 0-100"),
"Expected nested field error in message, got: {}",
error_msg
);
}
}