Skip to main content

rialo_api_types/
validation.rs

1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Validation for API request types
5
6use std::str::FromStr;
7
8use rialo_s_sdk::pubkey::Pubkey;
9use thiserror::Error;
10use validator::ValidationErrors;
11
12use crate::constants::*;
13
14/// Validation error types for RPC requests
15#[derive(Debug, Error, Clone)]
16pub enum ValidationError {
17    #[error("Invalid format: {0}")]
18    InvalidFormat(String),
19
20    #[error("Value out of range: {0}")]
21    OutOfRange(String),
22
23    #[error("Missing required field: {0}")]
24    MissingField(String),
25
26    #[error("Invalid signature: {0}")]
27    InvalidSignature(String),
28
29    #[error("Invalid encoding: {0}. Supported encodings: base64, base58")]
30    InvalidEncoding(String),
31
32    #[error("Invalid public key: {0}")]
33    InvalidPublicKey(String),
34
35    #[error("Invalid transaction: {0}")]
36    InvalidTransaction(String),
37
38    #[error("Multiple validation errors: {0}")]
39    Multiple(String),
40}
41
42impl From<ValidationErrors> for ValidationError {
43    fn from(errors: ValidationErrors) -> Self {
44        let mut error_messages: Vec<String> = Vec::new();
45
46        // Collect field-level errors
47        for (field, field_errors) in errors.field_errors() {
48            for error in field_errors {
49                let message = format!(
50                    "{}: {}",
51                    field,
52                    error
53                        .message
54                        .as_ref()
55                        .unwrap_or(&"validation failed".into())
56                );
57                error_messages.push(message);
58            }
59        }
60
61        // Collect struct-level errors (from nested validations)
62        for (field, struct_errors) in errors.errors() {
63            if let validator::ValidationErrorsKind::Struct(nested_errors) = struct_errors {
64                for (nested_field, nested_field_errors) in nested_errors.field_errors() {
65                    for error in nested_field_errors {
66                        let message = format!(
67                            "{}.{}: {}",
68                            field,
69                            nested_field,
70                            error
71                                .message
72                                .as_ref()
73                                .unwrap_or(&"validation failed".into())
74                        );
75                        error_messages.push(message);
76                    }
77                }
78            }
79        }
80
81        if error_messages.is_empty() {
82            ValidationError::Multiple("Unknown validation error".to_string())
83        } else if error_messages.len() == 1 {
84            ValidationError::InvalidFormat(error_messages[0].clone())
85        } else {
86            ValidationError::Multiple(error_messages.join(", "))
87        }
88    }
89}
90
91/// Result type for validation operations
92pub type ValidationResult<T> = Result<T, ValidationError>;
93
94/// Validate that the protocol version is 0
95pub fn validate_protocol_version(version: u16) -> Result<(), validator::ValidationError> {
96    if version != 0 {
97        let mut err = validator::ValidationError::new("invalid_protocol_version");
98        err.message = Some(format!("Protocol version must be 0, got {}", version).into());
99        return Err(err);
100    }
101    Ok(())
102}
103
104/// Validate Solana public key format
105pub fn validate_pubkey(pubkey: &str) -> Result<(), validator::ValidationError> {
106    Pubkey::from_str(pubkey).map_err(|_| validator::ValidationError::new("invalid_pubkey"))?;
107    Ok(())
108}
109
110/// Validate base64 encoded data
111pub fn validate_base64(data: &str) -> Result<(), validator::ValidationError> {
112    use fastcrypto::encoding::{Base64, Encoding};
113    Base64::decode(data).map_err(|_| validator::ValidationError::new("invalid_base64"))?;
114    Ok(())
115}
116
117/// Validate base58 encoded data
118pub fn validate_base58(data: &str) -> Result<(), validator::ValidationError> {
119    use fastcrypto::encoding::{Base58, Encoding};
120    Base58::decode(data).map_err(|_| validator::ValidationError::new("invalid_base58"))?;
121    Ok(())
122}
123
124/// Validate transaction signature
125pub fn validate_signature(signature: &str) -> Result<(), validator::ValidationError> {
126    // Maximum string length for a 64-byte signature in base58 is 88 characters
127    // Minimum can be as low as 64 characters if the signature has many leading zeros
128    // (each leading zero byte encodes as a single '1' character in base58)
129    if signature.len() > MAX_SIGNATURE_LENGTH {
130        return Err(validator::ValidationError::new("invalid_signature_length"));
131    }
132    validate_base58(signature)
133}
134
135/// Validate nonce format (should be valid UTF-8 and reasonable length)
136pub fn validate_nonce(nonce: &str) -> Result<(), validator::ValidationError> {
137    if nonce.is_empty() {
138        return Err(validator::ValidationError::new("empty_nonce"));
139    }
140    if nonce.len() > MAX_NONCE_LENGTH {
141        return Err(validator::ValidationError::new("nonce_too_long"));
142    }
143    Ok(())
144}
145
146/// Validate kelvins amount (should be reasonable)
147pub fn validate_kelvins(kelvins: u64) -> Result<(), validator::ValidationError> {
148    // Validate against maximum possible kelvins (500 million RLO * 1e9 kelvins/RLO)
149    if kelvins > MAX_KELVINS {
150        return Err(validator::ValidationError::new("kelvins_too_large"));
151    }
152    Ok(())
153}
154
155/// Validate limit parameter for paginated requests
156pub fn validate_limit(limit: &u64) -> Result<(), validator::ValidationError> {
157    if *limit == 0 {
158        return Err(validator::ValidationError::new("limit_zero"));
159    }
160    if *limit > MAX_PAGINATION_LIMIT {
161        return Err(validator::ValidationError::new("limit_too_large"));
162    }
163    Ok(())
164}
165
166/// Custom validator for array of public keys
167pub fn validate_pubkey_array(pubkeys: &[String]) -> Result<(), validator::ValidationError> {
168    for pubkey in pubkeys {
169        validate_pubkey(pubkey)?;
170    }
171    Ok(())
172}
173
174/// Custom validator for array of signatures
175pub fn validate_signatures_array(signatures: &[String]) -> Result<(), validator::ValidationError> {
176    for signature in signatures {
177        validate_signature(signature)?;
178    }
179    Ok(())
180}
181
182/// Custom validator for airdrop amounts
183pub fn validate_airdrop_amount(kelvins: u64) -> Result<(), validator::ValidationError> {
184    validate_kelvins(kelvins)?;
185
186    // Additional airdrop-specific validation
187    if kelvins > MAX_AIRDROP_AMOUNT {
188        return Err(validator::ValidationError::new("airdrop_amount_too_large"));
189    }
190
191    if kelvins == 0 {
192        return Err(validator::ValidationError::new("airdrop_amount_zero"));
193    }
194
195    Ok(())
196}
197
198/// Custom validator for airdrop amounts (i64 version)
199pub fn validate_airdrop_amount_i64(kelvins: i64) -> Result<(), validator::ValidationError> {
200    // Check for negative values
201    if kelvins < 0 {
202        return Err(validator::ValidationError::new("airdrop_amount_negative"));
203    }
204
205    // Check for zero
206    if kelvins == 0 {
207        return Err(validator::ValidationError::new("airdrop_amount_zero"));
208    }
209
210    // Convert to u64 for other validations
211    let kelvins_u64 = kelvins as u64;
212    validate_kelvins(kelvins_u64)?;
213
214    // Additional airdrop-specific validation
215    if kelvins_u64 > MAX_AIRDROP_AMOUNT {
216        return Err(validator::ValidationError::new("airdrop_amount_too_large"));
217    }
218
219    Ok(())
220}
221
222/// Validate signature limit (1-1000)
223pub fn validate_signature_limit(limit: &u16) -> Result<(), validator::ValidationError> {
224    if *limit == 0 {
225        return Err(validator::ValidationError::new("limit_must_be_positive"));
226    }
227    if *limit > MAX_PAGINATION_LIMIT as u16 {
228        return Err(validator::ValidationError::new("limit_exceeds_maximum"));
229    }
230    Ok(())
231}
232
233/// Custom validator for transaction data based on encoding
234pub fn validate_transaction_data(transaction: &str) -> Result<(), validator::ValidationError> {
235    // Empty strings are handled by the length validator, so skip custom validation for them
236    if transaction.is_empty() {
237        return Ok(());
238    }
239
240    // Validate the encoding format - try base64 first (most common)
241    if validate_base64(transaction).is_ok() {
242        // Now validate that the decoded data can be parsed as a transaction
243        return validate_transaction_structure_base64(transaction);
244    }
245
246    // If base64 fails, try base58
247    if validate_base58(transaction).is_ok() {
248        return validate_transaction_structure_base58(transaction);
249    }
250
251    // If neither encoding works, return error
252    Err(validator::ValidationError::new(
253        "invalid_transaction_encoding",
254    ))
255}
256
257/// Validate that base64 encoded data represents a valid transaction structure
258fn validate_transaction_structure_base64(
259    transaction: &str,
260) -> Result<(), validator::ValidationError> {
261    use fastcrypto::encoding::{Base64, Encoding};
262
263    // Decode the base64 data
264    let decoded = Base64::decode(transaction)
265        .map_err(|_| validator::ValidationError::new("invalid_base64_transaction"))?;
266
267    validate_transaction_bytes(&decoded)
268}
269
270/// Validate that base58 encoded data represents a valid transaction structure  
271fn validate_transaction_structure_base58(
272    transaction: &str,
273) -> Result<(), validator::ValidationError> {
274    use fastcrypto::encoding::{Base58, Encoding};
275
276    // Decode the base58 data
277    let decoded = Base58::decode(transaction)
278        .map_err(|_| validator::ValidationError::new("invalid_base58_transaction"))?;
279
280    validate_transaction_bytes(&decoded)
281}
282
283/// Validate the raw transaction bytes represent a valid transaction structure
284fn validate_transaction_bytes(transaction_bytes: &[u8]) -> Result<(), validator::ValidationError> {
285    // Basic size validation - transactions should be at least some minimum size
286    if transaction_bytes.len() < MIN_TRANSACTION_SIZE {
287        return Err(validator::ValidationError::new("transaction_too_small"));
288    }
289
290    // Maximum transaction size check (Solana has a 1232 byte limit)
291    if transaction_bytes.len() > MAX_TRANSACTION_SIZE {
292        return Err(validator::ValidationError::new("transaction_too_large"));
293    }
294
295    // Try to deserialize as a VersionedTransaction to validate structure
296    match bincode::deserialize::<rialo_s_sdk::transaction::VersionedTransaction>(transaction_bytes)
297    {
298        Ok(_) => Ok(()),
299        Err(_) => {
300            // If VersionedTransaction fails, try legacy Transaction format
301            match bincode::deserialize::<rialo_s_sdk::transaction::Transaction>(transaction_bytes) {
302                Ok(_) => Ok(()),
303                Err(_) => Err(validator::ValidationError::new(
304                    "invalid_transaction_structure",
305                )),
306            }
307        }
308    }
309}
310
311/// Custom validator for limit as string (some endpoints use string format)
312pub fn validate_limit_string(limit: &str) -> Result<(), validator::ValidationError> {
313    let limit_val: u64 = limit
314        .parse()
315        .map_err(|_| validator::ValidationError::new("invalid_limit_format"))?;
316    validate_limit(&limit_val)?;
317    Ok(())
318}
319
320/// Validate blockhash format (should be valid base58)
321pub fn validate_blockhash(blockhash: &str) -> Result<(), validator::ValidationError> {
322    // Solana blockhashes are base58 encoded and should be 32 bytes (44 characters in base58)
323    if blockhash.len() < MIN_BLOCKHASH_LENGTH || blockhash.len() > MAX_BLOCKHASH_LENGTH {
324        return Err(validator::ValidationError::new("invalid_blockhash_length"));
325    }
326    validate_base58(blockhash)
327}
328
329/// Validate array of addresses (public keys)
330pub fn validate_addresses(addresses: &[String]) -> Result<(), validator::ValidationError> {
331    for address in addresses {
332        validate_pubkey(address)?;
333    }
334    Ok(())
335}
336
337/// Validate array of signatures
338pub fn validate_signatures(signatures: &[String]) -> Result<(), validator::ValidationError> {
339    for signature in signatures {
340        validate_signature(signature)?;
341    }
342    Ok(())
343}
344
345/// Validate encoding format
346pub fn validate_encoding(encoding: &str) -> Result<(), validator::ValidationError> {
347    match encoding {
348        "json" | "jsonParsed" | "base58" | "base64" => Ok(()),
349        _ => Err(validator::ValidationError::new("invalid_encoding_format")),
350    }
351}
352
353/// Validate max transaction version
354pub fn validate_max_transaction_version(version: &u8) -> Result<(), validator::ValidationError> {
355    if *version <= 1 {
356        Ok(())
357    } else {
358        Err(validator::ValidationError::new(
359            "invalid_max_transaction_version",
360        ))
361    }
362}
363
364/// Validation middleware that validates a request
365pub fn validate_request<T>(request: T) -> ValidationResult<T>
366where
367    T: validator::Validate,
368{
369    request.validate().map_err(ValidationError::from)?;
370    Ok(request)
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_limit() {
        assert!(validate_limit(&1).is_ok());
        assert!(validate_limit(&MAX_PAGINATION_LIMIT).is_ok());
        assert!(validate_limit(&0).is_err());
        assert!(validate_limit(&(MAX_PAGINATION_LIMIT + 1)).is_err());
    }

    #[test]
    fn test_validate_limit_string() {
        assert!(validate_limit_string("1").is_ok());
        assert!(validate_limit_string("0").is_err());
        assert!(validate_limit_string("not-a-number").is_err());
        assert!(validate_limit_string(&(MAX_PAGINATION_LIMIT + 1).to_string()).is_err());
    }

    #[test]
    fn test_validate_nonce() {
        assert!(validate_nonce("valid_nonce").is_ok());
        assert!(validate_nonce("").is_err());
        // Derive the boundary from the constant so the test tracks MAX_NONCE_LENGTH.
        let max_nonce = "x".repeat(MAX_NONCE_LENGTH);
        assert!(validate_nonce(&max_nonce).is_ok());
        let long_nonce = "x".repeat(MAX_NONCE_LENGTH + 1);
        assert!(validate_nonce(&long_nonce).is_err());
    }

    #[test]
    fn test_validate_airdrop_amount_i64() {
        assert!(validate_airdrop_amount_i64(-1).is_err());
        assert!(validate_airdrop_amount_i64(0).is_err());
        assert!(validate_airdrop_amount(0).is_err());
        assert!(validate_airdrop_amount_i64(1).is_ok());
    }

    #[test]
    fn test_validate_encoding() {
        assert!(validate_encoding("json").is_ok());
        assert!(validate_encoding("jsonParsed").is_ok());
        assert!(validate_encoding("base58").is_ok());
        assert!(validate_encoding("base64").is_ok());
        assert!(validate_encoding("invalid").is_err());
    }

    #[test]
    fn test_validate_signature() {
        use fastcrypto::encoding::{Base58, Encoding};

        // Valid signature (87-88 chars, decodes to 64 bytes)
        let valid_sig =
            "5VERv8NMvzbJMEkV8xnrLkEaWRtSz9CosKDYjCJjBRnbJLgp8uirBgmQpjKhoR4tjF3ZpRzrFmBV6UjKdiSZkQUW";
        assert!(validate_signature(valid_sig).is_ok());

        // Signature with leading zeros (shorter string, still 64 bytes when decoded)
        // Create a signature with leading zero bytes
        let mut sig_with_zeros = [0u8; 64];
        sig_with_zeros[63] = 1; // Only last byte is non-zero
        let short_sig = Base58::encode(sig_with_zeros);
        // This will be much shorter than 87 chars due to leading zeros
        assert!(
            short_sig.len() < 87,
            "Expected short signature due to leading zeros, got len {}",
            short_sig.len()
        );
        assert!(
            validate_signature(&short_sig).is_ok(),
            "Signature with leading zeros should be valid"
        );

        // All zeros signature (64 '1' characters in base58)
        let all_zeros = [0u8; 64];
        let all_zeros_sig = Base58::encode(all_zeros);
        assert_eq!(all_zeros_sig.len(), 64);
        assert!(
            validate_signature(&all_zeros_sig).is_ok(),
            "All-zeros signature should be valid"
        );

        // Invalid: not base58
        assert!(validate_signature("invalid!signature").is_err());

        // Invalid: too long string
        let too_long = "1".repeat(100);
        assert!(validate_signature(&too_long).is_err());
    }

    #[test]
    fn test_validate_max_transaction_version() {
        assert!(validate_max_transaction_version(&0).is_ok());
        assert!(validate_max_transaction_version(&1).is_ok());
        assert!(validate_max_transaction_version(&2).is_err());
    }

    #[test]
    fn test_validation_error_from_field_errors() {
        use validator::Validate;

        // Create a struct with field-level validation errors
        #[derive(Validate)]
        struct TestStruct {
            #[validate(length(min = 1, message = "Field cannot be empty"))]
            field1: String,
            #[validate(range(min = 0, max = 100, message = "Must be between 0 and 100"))]
            field2: usize,
        }

        let test = TestStruct {
            field1: "".to_string(), // Invalid: empty
            field2: 150,            // Invalid: out of range
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        // Should be Multiple since there are 2 errors
        match error {
            ValidationError::Multiple(msg) => {
                assert!(msg.contains("field1"));
                assert!(msg.contains("Field cannot be empty"));
                assert!(msg.contains("field2"));
                assert!(msg.contains("Must be between 0 and 100"));
            }
            other => panic!("Expected ValidationError::Multiple, got {:?}", other),
        }
    }

    #[test]
    fn test_validation_error_from_nested_struct_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct NestedConfig {
            #[validate(range(
                min = 0,
                max = 100,
                message = "Max retries must be between 0 and 100"
            ))]
            max_retries: usize,
            #[validate(range(min = 0, message = "Min slot must be non-negative"))]
            min_slot: u64,
        }

        #[derive(Validate)]
        struct ParentStruct {
            #[validate(length(min = 1, message = "Name cannot be empty"))]
            name: String,
            #[validate(nested)]
            config: NestedConfig,
        }

        let test = ParentStruct {
            name: "valid".to_string(),
            config: NestedConfig {
                max_retries: 150, // Invalid: exceeds 100
                min_slot: 0,
            },
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        // Should contain the nested field path
        let error_msg = error.to_string();
        assert!(
            error_msg.contains("config.max_retries"),
            "Expected 'config.max_retries' in error message, got: {}",
            error_msg
        );
        assert!(
            error_msg.contains("Max retries must be between 0 and 100"),
            "Expected validation message in error, got: {}",
            error_msg
        );
    }

    #[test]
    fn test_validation_error_from_mixed_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct NestedConfig {
            #[validate(range(min = 0, max = 100, message = "Nested field must be 0-100"))]
            nested_field: usize,
        }

        #[derive(Validate)]
        struct ParentStruct {
            #[validate(length(min = 1, message = "Parent field cannot be empty"))]
            parent_field: String,
            #[validate(nested)]
            config: NestedConfig,
        }

        let test = ParentStruct {
            parent_field: "".to_string(), // Invalid: empty
            config: NestedConfig {
                nested_field: 150, // Invalid: out of range
            },
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        let error_msg = error.to_string();
        // Should contain both parent and nested errors
        assert!(
            error_msg.contains("parent_field")
                && error_msg.contains("Parent field cannot be empty"),
            "Expected parent field error in message, got: {}",
            error_msg
        );
        assert!(
            error_msg.contains("config.nested_field")
                && error_msg.contains("Nested field must be 0-100"),
            "Expected nested field error in message, got: {}",
            error_msg
        );
    }
}