rialo_api_types/
validation.rs

1// Copyright (c) Subzero Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Validation for API request types
5
6use std::str::FromStr;
7
8use rialo_s_sdk::pubkey::Pubkey;
9use thiserror::Error;
10use validator::ValidationErrors;
11
12use crate::constants::*;
13
14/// Validation error types for RPC requests
15#[derive(Debug, Error, Clone)]
16pub enum ValidationError {
17    #[error("Invalid format: {0}")]
18    InvalidFormat(String),
19
20    #[error("Value out of range: {0}")]
21    OutOfRange(String),
22
23    #[error("Missing required field: {0}")]
24    MissingField(String),
25
26    #[error("Invalid signature: {0}")]
27    InvalidSignature(String),
28
29    #[error("Invalid encoding: {0}. Supported encodings: base64, base58")]
30    InvalidEncoding(String),
31
32    #[error("Invalid public key: {0}")]
33    InvalidPublicKey(String),
34
35    #[error("Invalid transaction: {0}")]
36    InvalidTransaction(String),
37
38    #[error("Multiple validation errors: {0}")]
39    Multiple(String),
40}
41
42impl From<ValidationErrors> for ValidationError {
43    fn from(errors: ValidationErrors) -> Self {
44        let mut error_messages: Vec<String> = Vec::new();
45
46        // Collect field-level errors
47        for (field, field_errors) in errors.field_errors() {
48            for error in field_errors {
49                let message = format!(
50                    "{}: {}",
51                    field,
52                    error
53                        .message
54                        .as_ref()
55                        .unwrap_or(&"validation failed".into())
56                );
57                error_messages.push(message);
58            }
59        }
60
61        // Collect struct-level errors (from nested validations)
62        for (field, struct_errors) in errors.errors() {
63            if let validator::ValidationErrorsKind::Struct(nested_errors) = struct_errors {
64                for (nested_field, nested_field_errors) in nested_errors.field_errors() {
65                    for error in nested_field_errors {
66                        let message = format!(
67                            "{}.{}: {}",
68                            field,
69                            nested_field,
70                            error
71                                .message
72                                .as_ref()
73                                .unwrap_or(&"validation failed".into())
74                        );
75                        error_messages.push(message);
76                    }
77                }
78            }
79        }
80
81        if error_messages.is_empty() {
82            ValidationError::Multiple("Unknown validation error".to_string())
83        } else if error_messages.len() == 1 {
84            ValidationError::InvalidFormat(error_messages[0].clone())
85        } else {
86            ValidationError::Multiple(error_messages.join(", "))
87        }
88    }
89}
90
91/// Result type for validation operations
92pub type ValidationResult<T> = Result<T, ValidationError>;
93
94/// Validate Solana public key format
95pub fn validate_pubkey(pubkey: &str) -> Result<(), validator::ValidationError> {
96    Pubkey::from_str(pubkey).map_err(|_| validator::ValidationError::new("invalid_pubkey"))?;
97    Ok(())
98}
99
100/// Validate base64 encoded data
101pub fn validate_base64(data: &str) -> Result<(), validator::ValidationError> {
102    use fastcrypto::encoding::{Base64, Encoding};
103    Base64::decode(data).map_err(|_| validator::ValidationError::new("invalid_base64"))?;
104    Ok(())
105}
106
107/// Validate base58 encoded data
108pub fn validate_base58(data: &str) -> Result<(), validator::ValidationError> {
109    use fastcrypto::encoding::{Base58, Encoding};
110    Base58::decode(data).map_err(|_| validator::ValidationError::new("invalid_base58"))?;
111    Ok(())
112}
113
114/// Validate transaction signature
115pub fn validate_signature(signature: &str) -> Result<(), validator::ValidationError> {
116    // Solana signatures are base58 encoded and should be 87 or 88 characters long
117    // (87 when there are leading zero bytes that get omitted in base58 encoding)
118    if signature.len() < MIN_SIGNATURE_LENGTH || signature.len() > MAX_SIGNATURE_LENGTH {
119        return Err(validator::ValidationError::new("invalid_signature_length"));
120    }
121    validate_base58(signature)
122}
123
124/// Validate nonce format (should be valid UTF-8 and reasonable length)
125pub fn validate_nonce(nonce: &str) -> Result<(), validator::ValidationError> {
126    if nonce.is_empty() {
127        return Err(validator::ValidationError::new("empty_nonce"));
128    }
129    if nonce.len() > MAX_NONCE_LENGTH {
130        return Err(validator::ValidationError::new("nonce_too_long"));
131    }
132    Ok(())
133}
134
135/// Validate lamports amount (should be reasonable)
136pub fn validate_lamports(lamports: u64) -> Result<(), validator::ValidationError> {
137    // Validate against maximum possible lamports (500 million SOL * 1e9 lamports/SOL)
138    if lamports > MAX_LAMPORTS {
139        return Err(validator::ValidationError::new("lamports_too_large"));
140    }
141    Ok(())
142}
143
144/// Validate limit parameter for paginated requests
145pub fn validate_limit(limit: &u64) -> Result<(), validator::ValidationError> {
146    if *limit == 0 {
147        return Err(validator::ValidationError::new("limit_zero"));
148    }
149    if *limit > MAX_PAGINATION_LIMIT {
150        return Err(validator::ValidationError::new("limit_too_large"));
151    }
152    Ok(())
153}
154
155/// Custom validator for array of public keys
156pub fn validate_pubkey_array(pubkeys: &[String]) -> Result<(), validator::ValidationError> {
157    for pubkey in pubkeys {
158        validate_pubkey(pubkey)?;
159    }
160    Ok(())
161}
162
163/// Custom validator for array of signatures
164pub fn validate_signatures_array(signatures: &[String]) -> Result<(), validator::ValidationError> {
165    for signature in signatures {
166        validate_signature(signature)?;
167    }
168    Ok(())
169}
170
171/// Custom validator for airdrop amounts
172pub fn validate_airdrop_amount(lamports: u64) -> Result<(), validator::ValidationError> {
173    validate_lamports(lamports)?;
174
175    // Additional airdrop-specific validation
176    if lamports > MAX_AIRDROP_AMOUNT {
177        return Err(validator::ValidationError::new("airdrop_amount_too_large"));
178    }
179
180    if lamports == 0 {
181        return Err(validator::ValidationError::new("airdrop_amount_zero"));
182    }
183
184    Ok(())
185}
186
187/// Custom validator for airdrop amounts (i64 version)
188pub fn validate_airdrop_amount_i64(lamports: i64) -> Result<(), validator::ValidationError> {
189    // Check for negative values
190    if lamports < 0 {
191        return Err(validator::ValidationError::new("airdrop_amount_negative"));
192    }
193
194    // Check for zero
195    if lamports == 0 {
196        return Err(validator::ValidationError::new("airdrop_amount_zero"));
197    }
198
199    // Convert to u64 for other validations
200    let lamports_u64 = lamports as u64;
201    validate_lamports(lamports_u64)?;
202
203    // Additional airdrop-specific validation
204    if lamports_u64 > MAX_AIRDROP_AMOUNT {
205        return Err(validator::ValidationError::new("airdrop_amount_too_large"));
206    }
207
208    Ok(())
209}
210
211/// Validate signature limit (1-1000)
212pub fn validate_signature_limit(limit: &u16) -> Result<(), validator::ValidationError> {
213    if *limit == 0 {
214        return Err(validator::ValidationError::new("limit_must_be_positive"));
215    }
216    if *limit > MAX_PAGINATION_LIMIT as u16 {
217        return Err(validator::ValidationError::new("limit_exceeds_maximum"));
218    }
219    Ok(())
220}
221
222/// Custom validator for transaction data based on encoding
223pub fn validate_transaction_data(transaction: &str) -> Result<(), validator::ValidationError> {
224    // Empty strings are handled by the length validator, so skip custom validation for them
225    if transaction.is_empty() {
226        return Ok(());
227    }
228
229    // Validate the encoding format - try base64 first (most common)
230    if validate_base64(transaction).is_ok() {
231        // Now validate that the decoded data can be parsed as a transaction
232        return validate_transaction_structure_base64(transaction);
233    }
234
235    // If base64 fails, try base58
236    if validate_base58(transaction).is_ok() {
237        return validate_transaction_structure_base58(transaction);
238    }
239
240    // If neither encoding works, return error
241    Err(validator::ValidationError::new(
242        "invalid_transaction_encoding",
243    ))
244}
245
246/// Validate that base64 encoded data represents a valid transaction structure
247fn validate_transaction_structure_base64(
248    transaction: &str,
249) -> Result<(), validator::ValidationError> {
250    use fastcrypto::encoding::{Base64, Encoding};
251
252    // Decode the base64 data
253    let decoded = Base64::decode(transaction)
254        .map_err(|_| validator::ValidationError::new("invalid_base64_transaction"))?;
255
256    validate_transaction_bytes(&decoded)
257}
258
259/// Validate that base58 encoded data represents a valid transaction structure  
260fn validate_transaction_structure_base58(
261    transaction: &str,
262) -> Result<(), validator::ValidationError> {
263    use fastcrypto::encoding::{Base58, Encoding};
264
265    // Decode the base58 data
266    let decoded = Base58::decode(transaction)
267        .map_err(|_| validator::ValidationError::new("invalid_base58_transaction"))?;
268
269    validate_transaction_bytes(&decoded)
270}
271
272/// Validate the raw transaction bytes represent a valid transaction structure
273fn validate_transaction_bytes(transaction_bytes: &[u8]) -> Result<(), validator::ValidationError> {
274    // Basic size validation - transactions should be at least some minimum size
275    if transaction_bytes.len() < MIN_TRANSACTION_SIZE {
276        return Err(validator::ValidationError::new("transaction_too_small"));
277    }
278
279    // Maximum transaction size check (Solana has a 1232 byte limit)
280    if transaction_bytes.len() > MAX_TRANSACTION_SIZE {
281        return Err(validator::ValidationError::new("transaction_too_large"));
282    }
283
284    // Try to deserialize as a VersionedTransaction to validate structure
285    match bincode::deserialize::<rialo_s_sdk::transaction::VersionedTransaction>(transaction_bytes)
286    {
287        Ok(_) => Ok(()),
288        Err(_) => {
289            // If VersionedTransaction fails, try legacy Transaction format
290            match bincode::deserialize::<rialo_s_sdk::transaction::Transaction>(transaction_bytes) {
291                Ok(_) => Ok(()),
292                Err(_) => Err(validator::ValidationError::new(
293                    "invalid_transaction_structure",
294                )),
295            }
296        }
297    }
298}
299
300/// Custom validator for limit as string (some endpoints use string format)
301pub fn validate_limit_string(limit: &str) -> Result<(), validator::ValidationError> {
302    let limit_val: u64 = limit
303        .parse()
304        .map_err(|_| validator::ValidationError::new("invalid_limit_format"))?;
305    validate_limit(&limit_val)?;
306    Ok(())
307}
308
309/// Validate blockhash format (should be valid base58)
310pub fn validate_blockhash(blockhash: &str) -> Result<(), validator::ValidationError> {
311    // Solana blockhashes are base58 encoded and should be 32 bytes (44 characters in base58)
312    if blockhash.len() < MIN_BLOCKHASH_LENGTH || blockhash.len() > MAX_BLOCKHASH_LENGTH {
313        return Err(validator::ValidationError::new("invalid_blockhash_length"));
314    }
315    validate_base58(blockhash)
316}
317
318/// Validate array of addresses (public keys)
319pub fn validate_addresses(addresses: &[String]) -> Result<(), validator::ValidationError> {
320    for address in addresses {
321        validate_pubkey(address)?;
322    }
323    Ok(())
324}
325
326/// Validate array of signatures
327pub fn validate_signatures(signatures: &[String]) -> Result<(), validator::ValidationError> {
328    for signature in signatures {
329        validate_signature(signature)?;
330    }
331    Ok(())
332}
333
334/// Validate encoding format
335pub fn validate_encoding(encoding: &str) -> Result<(), validator::ValidationError> {
336    match encoding {
337        "json" | "jsonParsed" | "base58" | "base64" => Ok(()),
338        _ => Err(validator::ValidationError::new("invalid_encoding_format")),
339    }
340}
341
342/// Validate max transaction version
343pub fn validate_max_transaction_version(version: &u8) -> Result<(), validator::ValidationError> {
344    if *version <= 1 {
345        Ok(())
346    } else {
347        Err(validator::ValidationError::new(
348            "invalid_max_transaction_version",
349        ))
350    }
351}
352
353/// Validation middleware that validates a request
354pub fn validate_request<T>(request: T) -> ValidationResult<T>
355where
356    T: validator::Validate,
357{
358    request.validate().map_err(ValidationError::from)?;
359    Ok(request)
360}
361
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_limit() {
        assert!(validate_limit(&1).is_ok());
        assert!(validate_limit(&MAX_PAGINATION_LIMIT).is_ok());
        assert!(validate_limit(&0).is_err());
        assert!(validate_limit(&(MAX_PAGINATION_LIMIT + 1)).is_err());
    }

    #[test]
    fn test_validate_limit_string() {
        assert!(validate_limit_string("1").is_ok());
        assert!(validate_limit_string("0").is_err());
        assert!(validate_limit_string("-1").is_err());
        assert!(validate_limit_string("ten").is_err());
    }

    #[test]
    fn test_validate_signature_limit() {
        assert!(validate_signature_limit(&1).is_ok());
        assert!(validate_signature_limit(&0).is_err());
    }

    #[test]
    fn test_validate_nonce() {
        assert!(validate_nonce("valid_nonce").is_ok());
        assert!(validate_nonce("").is_err());
        // Derive the boundary from the constant instead of hard-coding it, so the
        // test keeps exercising the edge if MAX_NONCE_LENGTH changes.
        assert!(validate_nonce(&"x".repeat(MAX_NONCE_LENGTH)).is_ok());
        assert!(validate_nonce(&"x".repeat(MAX_NONCE_LENGTH + 1)).is_err());
    }

    #[test]
    fn test_validate_airdrop_amount_i64() {
        assert_eq!(
            validate_airdrop_amount_i64(-1).unwrap_err().code,
            "airdrop_amount_negative"
        );
        assert_eq!(
            validate_airdrop_amount_i64(0).unwrap_err().code,
            "airdrop_amount_zero"
        );
    }

    #[test]
    fn test_validate_transaction_data_encoding() {
        // Empty input is deferred to the length validator.
        assert!(validate_transaction_data("").is_ok());
        // '!' is neither base64 nor base58.
        assert_eq!(
            validate_transaction_data("!!!").unwrap_err().code,
            "invalid_transaction_encoding"
        );
    }

    #[test]
    fn test_validate_encoding() {
        assert!(validate_encoding("json").is_ok());
        assert!(validate_encoding("jsonParsed").is_ok());
        assert!(validate_encoding("base58").is_ok());
        assert!(validate_encoding("base64").is_ok());
        assert!(validate_encoding("invalid").is_err());
    }

    #[test]
    fn test_validate_max_transaction_version() {
        assert!(validate_max_transaction_version(&0).is_ok());
        assert!(validate_max_transaction_version(&1).is_ok());
        assert!(validate_max_transaction_version(&2).is_err());
    }

    #[test]
    fn test_validation_error_from_field_errors() {
        use validator::Validate;

        // Create a struct with field-level validation errors
        #[derive(Validate)]
        struct TestStruct {
            #[validate(length(min = 1, message = "Field cannot be empty"))]
            field1: String,
            #[validate(range(min = 0, max = 100, message = "Must be between 0 and 100"))]
            field2: usize,
        }

        let test = TestStruct {
            field1: "".to_string(), // Invalid: empty
            field2: 150,            // Invalid: out of range
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        // Should be Multiple since there are 2 errors
        match error {
            ValidationError::Multiple(msg) => {
                assert!(msg.contains("field1"));
                assert!(msg.contains("Field cannot be empty"));
                assert!(msg.contains("field2"));
                assert!(msg.contains("Must be between 0 and 100"));
            }
            other => panic!("Expected ValidationError::Multiple, got {:?}", other),
        }
    }

    #[test]
    fn test_validation_error_from_nested_struct_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct NestedConfig {
            #[validate(range(
                min = 0,
                max = 100,
                message = "Max retries must be between 0 and 100"
            ))]
            max_retries: usize,
            #[validate(range(min = 0, message = "Min slot must be non-negative"))]
            min_slot: u64,
        }

        #[derive(Validate)]
        struct ParentStruct {
            #[validate(length(min = 1, message = "Name cannot be empty"))]
            name: String,
            #[validate(nested)]
            config: NestedConfig,
        }

        let test = ParentStruct {
            name: "valid".to_string(),
            config: NestedConfig {
                max_retries: 150, // Invalid: exceeds 100
                min_slot: 0,
            },
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        // Should contain the nested field path
        let error_msg = error.to_string();
        assert!(
            error_msg.contains("config.max_retries"),
            "Expected 'config.max_retries' in error message, got: {}",
            error_msg
        );
        assert!(
            error_msg.contains("Max retries must be between 0 and 100"),
            "Expected validation message in error, got: {}",
            error_msg
        );
    }

    #[test]
    fn test_validation_error_from_mixed_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct NestedConfig {
            #[validate(range(min = 0, max = 100, message = "Nested field must be 0-100"))]
            nested_field: usize,
        }

        #[derive(Validate)]
        struct ParentStruct {
            #[validate(length(min = 1, message = "Parent field cannot be empty"))]
            parent_field: String,
            #[validate(nested)]
            config: NestedConfig,
        }

        let test = ParentStruct {
            parent_field: "".to_string(), // Invalid: empty
            config: NestedConfig {
                nested_field: 150, // Invalid: out of range
            },
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        let error_msg = error.to_string();
        // Should contain both parent and nested errors
        assert!(
            error_msg.contains("parent_field")
                && error_msg.contains("Parent field cannot be empty"),
            "Expected parent field error in message, got: {}",
            error_msg
        );
        assert!(
            error_msg.contains("config.nested_field")
                && error_msg.contains("Nested field must be 0-100"),
            "Expected nested field error in message, got: {}",
            error_msg
        );
    }
}