Skip to main content

rialo_api_types/
validation.rs

1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Validation for API request types
5
6use std::str::FromStr;
7
8use rialo_s_sdk::pubkey::Pubkey;
9use thiserror::Error;
10use validator::ValidationErrors;
11
12use crate::constants::*;
13
14/// Validation error types for RPC requests
15#[derive(Debug, Error, Clone)]
16pub enum ValidationError {
17    #[error("Invalid format: {0}")]
18    InvalidFormat(String),
19
20    #[error("Value out of range: {0}")]
21    OutOfRange(String),
22
23    #[error("Missing required field: {0}")]
24    MissingField(String),
25
26    #[error("Invalid signature: {0}")]
27    InvalidSignature(String),
28
29    #[error("Invalid encoding: {0}. Supported encodings: base64, base58")]
30    InvalidEncoding(String),
31
32    #[error("Invalid public key: {0}")]
33    InvalidPublicKey(String),
34
35    #[error("Invalid transaction: {0}")]
36    InvalidTransaction(String),
37
38    #[error("Multiple validation errors: {0}")]
39    Multiple(String),
40}
41
42impl From<ValidationErrors> for ValidationError {
43    fn from(errors: ValidationErrors) -> Self {
44        let mut error_messages: Vec<String> = Vec::new();
45
46        // Collect field-level errors
47        for (field, field_errors) in errors.field_errors() {
48            for error in field_errors {
49                let message = format!(
50                    "{}: {}",
51                    field,
52                    error
53                        .message
54                        .as_ref()
55                        .unwrap_or(&"validation failed".into())
56                );
57                error_messages.push(message);
58            }
59        }
60
61        // Collect struct-level errors (from nested validations)
62        for (field, struct_errors) in errors.errors() {
63            if let validator::ValidationErrorsKind::Struct(nested_errors) = struct_errors {
64                for (nested_field, nested_field_errors) in nested_errors.field_errors() {
65                    for error in nested_field_errors {
66                        let message = format!(
67                            "{}.{}: {}",
68                            field,
69                            nested_field,
70                            error
71                                .message
72                                .as_ref()
73                                .unwrap_or(&"validation failed".into())
74                        );
75                        error_messages.push(message);
76                    }
77                }
78            }
79        }
80
81        if error_messages.is_empty() {
82            ValidationError::Multiple("Unknown validation error".to_string())
83        } else if error_messages.len() == 1 {
84            ValidationError::InvalidFormat(error_messages[0].clone())
85        } else {
86            ValidationError::Multiple(error_messages.join(", "))
87        }
88    }
89}
90
91/// Result type for validation operations
92pub type ValidationResult<T> = Result<T, ValidationError>;
93
94/// Validate Solana public key format
95pub fn validate_pubkey(pubkey: &str) -> Result<(), validator::ValidationError> {
96    Pubkey::from_str(pubkey).map_err(|_| validator::ValidationError::new("invalid_pubkey"))?;
97    Ok(())
98}
99
100/// Validate base64 encoded data
101pub fn validate_base64(data: &str) -> Result<(), validator::ValidationError> {
102    use fastcrypto::encoding::{Base64, Encoding};
103    Base64::decode(data).map_err(|_| validator::ValidationError::new("invalid_base64"))?;
104    Ok(())
105}
106
107/// Validate base58 encoded data
108pub fn validate_base58(data: &str) -> Result<(), validator::ValidationError> {
109    use fastcrypto::encoding::{Base58, Encoding};
110    Base58::decode(data).map_err(|_| validator::ValidationError::new("invalid_base58"))?;
111    Ok(())
112}
113
114/// Validate transaction signature
115pub fn validate_signature(signature: &str) -> Result<(), validator::ValidationError> {
116    // Maximum string length for a 64-byte signature in base58 is 88 characters
117    // Minimum can be as low as 64 characters if the signature has many leading zeros
118    // (each leading zero byte encodes as a single '1' character in base58)
119    if signature.len() > MAX_SIGNATURE_LENGTH {
120        return Err(validator::ValidationError::new("invalid_signature_length"));
121    }
122    validate_base58(signature)
123}
124
125/// Validate nonce format (should be valid UTF-8 and reasonable length)
126pub fn validate_nonce(nonce: &str) -> Result<(), validator::ValidationError> {
127    if nonce.is_empty() {
128        return Err(validator::ValidationError::new("empty_nonce"));
129    }
130    if nonce.len() > MAX_NONCE_LENGTH {
131        return Err(validator::ValidationError::new("nonce_too_long"));
132    }
133    Ok(())
134}
135
136/// Validate kelvins amount (should be reasonable)
137pub fn validate_kelvins(kelvins: u64) -> Result<(), validator::ValidationError> {
138    // Validate against maximum possible kelvins (500 million RLO * 1e9 kelvins/RLO)
139    if kelvins > MAX_KELVINS {
140        return Err(validator::ValidationError::new("kelvins_too_large"));
141    }
142    Ok(())
143}
144
145/// Validate limit parameter for paginated requests
146pub fn validate_limit(limit: &u64) -> Result<(), validator::ValidationError> {
147    if *limit == 0 {
148        return Err(validator::ValidationError::new("limit_zero"));
149    }
150    if *limit > MAX_PAGINATION_LIMIT {
151        return Err(validator::ValidationError::new("limit_too_large"));
152    }
153    Ok(())
154}
155
156/// Custom validator for array of public keys
157pub fn validate_pubkey_array(pubkeys: &[String]) -> Result<(), validator::ValidationError> {
158    for pubkey in pubkeys {
159        validate_pubkey(pubkey)?;
160    }
161    Ok(())
162}
163
164/// Custom validator for array of signatures
165pub fn validate_signatures_array(signatures: &[String]) -> Result<(), validator::ValidationError> {
166    for signature in signatures {
167        validate_signature(signature)?;
168    }
169    Ok(())
170}
171
172/// Custom validator for airdrop amounts
173pub fn validate_airdrop_amount(kelvins: u64) -> Result<(), validator::ValidationError> {
174    validate_kelvins(kelvins)?;
175
176    // Additional airdrop-specific validation
177    if kelvins > MAX_AIRDROP_AMOUNT {
178        return Err(validator::ValidationError::new("airdrop_amount_too_large"));
179    }
180
181    if kelvins == 0 {
182        return Err(validator::ValidationError::new("airdrop_amount_zero"));
183    }
184
185    Ok(())
186}
187
188/// Custom validator for airdrop amounts (i64 version)
189pub fn validate_airdrop_amount_i64(kelvins: i64) -> Result<(), validator::ValidationError> {
190    // Check for negative values
191    if kelvins < 0 {
192        return Err(validator::ValidationError::new("airdrop_amount_negative"));
193    }
194
195    // Check for zero
196    if kelvins == 0 {
197        return Err(validator::ValidationError::new("airdrop_amount_zero"));
198    }
199
200    // Convert to u64 for other validations
201    let kelvins_u64 = kelvins as u64;
202    validate_kelvins(kelvins_u64)?;
203
204    // Additional airdrop-specific validation
205    if kelvins_u64 > MAX_AIRDROP_AMOUNT {
206        return Err(validator::ValidationError::new("airdrop_amount_too_large"));
207    }
208
209    Ok(())
210}
211
212/// Validate signature limit (1-1000)
213pub fn validate_signature_limit(limit: &u16) -> Result<(), validator::ValidationError> {
214    if *limit == 0 {
215        return Err(validator::ValidationError::new("limit_must_be_positive"));
216    }
217    if *limit > MAX_PAGINATION_LIMIT as u16 {
218        return Err(validator::ValidationError::new("limit_exceeds_maximum"));
219    }
220    Ok(())
221}
222
223/// Custom validator for transaction data based on encoding
224pub fn validate_transaction_data(transaction: &str) -> Result<(), validator::ValidationError> {
225    // Empty strings are handled by the length validator, so skip custom validation for them
226    if transaction.is_empty() {
227        return Ok(());
228    }
229
230    // Validate the encoding format - try base64 first (most common)
231    if validate_base64(transaction).is_ok() {
232        // Now validate that the decoded data can be parsed as a transaction
233        return validate_transaction_structure_base64(transaction);
234    }
235
236    // If base64 fails, try base58
237    if validate_base58(transaction).is_ok() {
238        return validate_transaction_structure_base58(transaction);
239    }
240
241    // If neither encoding works, return error
242    Err(validator::ValidationError::new(
243        "invalid_transaction_encoding",
244    ))
245}
246
247/// Validate that base64 encoded data represents a valid transaction structure
248fn validate_transaction_structure_base64(
249    transaction: &str,
250) -> Result<(), validator::ValidationError> {
251    use fastcrypto::encoding::{Base64, Encoding};
252
253    // Decode the base64 data
254    let decoded = Base64::decode(transaction)
255        .map_err(|_| validator::ValidationError::new("invalid_base64_transaction"))?;
256
257    validate_transaction_bytes(&decoded)
258}
259
260/// Validate that base58 encoded data represents a valid transaction structure  
261fn validate_transaction_structure_base58(
262    transaction: &str,
263) -> Result<(), validator::ValidationError> {
264    use fastcrypto::encoding::{Base58, Encoding};
265
266    // Decode the base58 data
267    let decoded = Base58::decode(transaction)
268        .map_err(|_| validator::ValidationError::new("invalid_base58_transaction"))?;
269
270    validate_transaction_bytes(&decoded)
271}
272
273/// Validate the raw transaction bytes represent a valid transaction structure
274fn validate_transaction_bytes(transaction_bytes: &[u8]) -> Result<(), validator::ValidationError> {
275    // Basic size validation - transactions should be at least some minimum size
276    if transaction_bytes.len() < MIN_TRANSACTION_SIZE {
277        return Err(validator::ValidationError::new("transaction_too_small"));
278    }
279
280    // Maximum transaction size check (Solana has a 1232 byte limit)
281    if transaction_bytes.len() > MAX_TRANSACTION_SIZE {
282        return Err(validator::ValidationError::new("transaction_too_large"));
283    }
284
285    // Try to deserialize as a VersionedTransaction to validate structure
286    match bincode::deserialize::<rialo_s_sdk::transaction::VersionedTransaction>(transaction_bytes)
287    {
288        Ok(_) => Ok(()),
289        Err(_) => {
290            // If VersionedTransaction fails, try legacy Transaction format
291            match bincode::deserialize::<rialo_s_sdk::transaction::Transaction>(transaction_bytes) {
292                Ok(_) => Ok(()),
293                Err(_) => Err(validator::ValidationError::new(
294                    "invalid_transaction_structure",
295                )),
296            }
297        }
298    }
299}
300
301/// Custom validator for limit as string (some endpoints use string format)
302pub fn validate_limit_string(limit: &str) -> Result<(), validator::ValidationError> {
303    let limit_val: u64 = limit
304        .parse()
305        .map_err(|_| validator::ValidationError::new("invalid_limit_format"))?;
306    validate_limit(&limit_val)?;
307    Ok(())
308}
309
310/// Validate blockhash format (should be valid base58)
311pub fn validate_blockhash(blockhash: &str) -> Result<(), validator::ValidationError> {
312    // Solana blockhashes are base58 encoded and should be 32 bytes (44 characters in base58)
313    if blockhash.len() < MIN_BLOCKHASH_LENGTH || blockhash.len() > MAX_BLOCKHASH_LENGTH {
314        return Err(validator::ValidationError::new("invalid_blockhash_length"));
315    }
316    validate_base58(blockhash)
317}
318
319/// Validate array of addresses (public keys)
320pub fn validate_addresses(addresses: &[String]) -> Result<(), validator::ValidationError> {
321    for address in addresses {
322        validate_pubkey(address)?;
323    }
324    Ok(())
325}
326
327/// Validate array of signatures
328pub fn validate_signatures(signatures: &[String]) -> Result<(), validator::ValidationError> {
329    for signature in signatures {
330        validate_signature(signature)?;
331    }
332    Ok(())
333}
334
335/// Validate encoding format
336pub fn validate_encoding(encoding: &str) -> Result<(), validator::ValidationError> {
337    match encoding {
338        "json" | "jsonParsed" | "base58" | "base64" => Ok(()),
339        _ => Err(validator::ValidationError::new("invalid_encoding_format")),
340    }
341}
342
343/// Validate max transaction version
344pub fn validate_max_transaction_version(version: &u8) -> Result<(), validator::ValidationError> {
345    if *version <= 1 {
346        Ok(())
347    } else {
348        Err(validator::ValidationError::new(
349            "invalid_max_transaction_version",
350        ))
351    }
352}
353
354/// Validation middleware that validates a request
355pub fn validate_request<T>(request: T) -> ValidationResult<T>
356where
357    T: validator::Validate,
358{
359    request.validate().map_err(ValidationError::from)?;
360    Ok(request)
361}
362
#[cfg(test)]
mod tests {
    use fastcrypto::encoding::{Base58, Encoding};
    use validator::Validate;

    use super::*;

    /// Run `validate()` on a fixture that is built to fail, and convert the result.
    fn into_validation_error<T: Validate>(value: &T) -> ValidationError {
        value
            .validate()
            .expect_err("fixture was built to fail validation")
            .into()
    }

    #[test]
    fn test_validate_limit() {
        for accepted in [1, MAX_PAGINATION_LIMIT] {
            assert!(validate_limit(&accepted).is_ok());
        }
        for rejected in [0, MAX_PAGINATION_LIMIT + 1] {
            assert!(validate_limit(&rejected).is_err());
        }
    }

    #[test]
    fn test_validate_nonce() {
        assert!(validate_nonce("valid_nonce").is_ok());
        assert!(validate_nonce("").is_err());
        assert!(validate_nonce(&"x".repeat(65)).is_err());
    }

    #[test]
    fn test_validate_encoding() {
        for supported in ["json", "jsonParsed", "base58", "base64"] {
            assert!(validate_encoding(supported).is_ok());
        }
        assert!(validate_encoding("invalid").is_err());
    }

    #[test]
    fn test_validate_signature() {
        // A typical full-length (88-character) signature.
        let typical =
            "5VERv8NMvzbJMEkV8xnrLkEaWRtSz9CosKDYjCJjBRnbJLgp8uirBgmQpjKhoR4tjF3ZpRzrFmBV6UjKdiSZkQUW";
        assert!(validate_signature(typical).is_ok());

        // 63 leading zero bytes each encode as a single '1', so the string is much
        // shorter while still decoding to 64 bytes.
        let mut mostly_zero = [0u8; 64];
        mostly_zero[63] = 1;
        let short = Base58::encode(mostly_zero);
        assert!(
            short.len() < 87,
            "Expected short signature due to leading zeros, got len {}",
            short.len()
        );
        assert!(
            validate_signature(&short).is_ok(),
            "Signature with leading zeros should be valid"
        );

        // Sixty-four zero bytes encode as sixty-four '1' characters.
        let all_zero = Base58::encode([0u8; 64]);
        assert_eq!(all_zero.len(), 64);
        assert!(
            validate_signature(&all_zero).is_ok(),
            "All-zeros signature should be valid"
        );

        // '!' is outside the base58 alphabet.
        assert!(validate_signature("invalid!signature").is_err());
        // Exceeds the maximum string length.
        assert!(validate_signature(&"1".repeat(100)).is_err());
    }

    #[test]
    fn test_validate_max_transaction_version() {
        for (version, accepted) in [(0u8, true), (1, true), (2, false)] {
            assert_eq!(validate_max_transaction_version(&version).is_ok(), accepted);
        }
    }

    #[test]
    fn test_validation_error_from_field_errors() {
        #[derive(Validate)]
        struct TwoInvalidFields {
            #[validate(length(min = 1, message = "Field cannot be empty"))]
            field1: String,
            #[validate(range(min = 0, max = 100, message = "Must be between 0 and 100"))]
            field2: usize,
        }

        let error = into_validation_error(&TwoInvalidFields {
            field1: String::new(), // empty
            field2: 150,           // above the range
        });

        // Two failing fields must be reported together.
        match error {
            ValidationError::Multiple(msg) => {
                for needle in [
                    "field1",
                    "Field cannot be empty",
                    "field2",
                    "Must be between 0 and 100",
                ] {
                    assert!(msg.contains(needle));
                }
            }
            other => panic!("Expected ValidationError::Multiple, got {:?}", other),
        }
    }

    #[test]
    fn test_validation_error_from_nested_struct_errors() {
        #[derive(Validate)]
        struct RetryConfig {
            #[validate(range(
                min = 0,
                max = 100,
                message = "Max retries must be between 0 and 100"
            ))]
            max_retries: usize,
            #[validate(range(min = 0, message = "Min slot must be non-negative"))]
            min_slot: u64,
        }

        #[derive(Validate)]
        struct NamedWithConfig {
            #[validate(length(min = 1, message = "Name cannot be empty"))]
            name: String,
            #[validate(nested)]
            config: RetryConfig,
        }

        let fixture = NamedWithConfig {
            name: "valid".to_string(),
            config: RetryConfig {
                max_retries: 150, // above 100
                min_slot: 0,
            },
        };

        // The nested field must be reported with its full dotted path.
        let rendered = into_validation_error(&fixture).to_string();
        assert!(
            rendered.contains("config.max_retries"),
            "Expected 'config.max_retries' in error message, got: {}",
            rendered
        );
        assert!(
            rendered.contains("Max retries must be between 0 and 100"),
            "Expected validation message in error, got: {}",
            rendered
        );
    }

    #[test]
    fn test_validation_error_from_mixed_errors() {
        #[derive(Validate)]
        struct InnerConfig {
            #[validate(range(min = 0, max = 100, message = "Nested field must be 0-100"))]
            nested_field: usize,
        }

        #[derive(Validate)]
        struct OuterWithConfig {
            #[validate(length(min = 1, message = "Parent field cannot be empty"))]
            parent_field: String,
            #[validate(nested)]
            config: InnerConfig,
        }

        let fixture = OuterWithConfig {
            parent_field: String::new(), // empty
            config: InnerConfig {
                nested_field: 150, // above the range
            },
        };

        // Top-level and nested failures must both appear.
        let rendered = into_validation_error(&fixture).to_string();
        assert!(
            rendered.contains("parent_field") && rendered.contains("Parent field cannot be empty"),
            "Expected parent field error in message, got: {}",
            rendered
        );
        assert!(
            rendered.contains("config.nested_field")
                && rendered.contains("Nested field must be 0-100"),
            "Expected nested field error in message, got: {}",
            rendered
        );
    }
}