1use std::str::FromStr;
7
8use rialo_s_sdk::pubkey::Pubkey;
9use thiserror::Error;
10use validator::ValidationErrors;
11
12use crate::constants::*;
13
/// Errors reported by the request-validation helpers in this module.
///
/// Every variant carries a human-readable detail string that is interpolated
/// into the variant's display message.
#[derive(Debug, Error, Clone)]
pub enum ValidationError {
    /// The input did not have the expected format.
    #[error("Invalid format: {0}")]
    InvalidFormat(String),

    /// A value fell outside its permitted range.
    #[error("Value out of range: {0}")]
    OutOfRange(String),

    /// A required field was not supplied.
    #[error("Missing required field: {0}")]
    MissingField(String),

    /// A signature was rejected.
    #[error("Invalid signature: {0}")]
    InvalidSignature(String),

    /// The payload encoding was rejected; the display message lists base64 and
    /// base58 as the supported encodings.
    #[error("Invalid encoding: {0}. Supported encodings: base64, base58")]
    InvalidEncoding(String),

    /// A public key was rejected.
    #[error("Invalid public key: {0}")]
    InvalidPublicKey(String),

    /// A transaction was rejected.
    #[error("Invalid transaction: {0}")]
    InvalidTransaction(String),

    /// Several validation failures combined into a single message.
    #[error("Multiple validation errors: {0}")]
    Multiple(String),
}
41
42impl From<ValidationErrors> for ValidationError {
43 fn from(errors: ValidationErrors) -> Self {
44 let mut error_messages: Vec<String> = Vec::new();
45
46 for (field, field_errors) in errors.field_errors() {
48 for error in field_errors {
49 let message = format!(
50 "{}: {}",
51 field,
52 error
53 .message
54 .as_ref()
55 .unwrap_or(&"validation failed".into())
56 );
57 error_messages.push(message);
58 }
59 }
60
61 for (field, struct_errors) in errors.errors() {
63 if let validator::ValidationErrorsKind::Struct(nested_errors) = struct_errors {
64 for (nested_field, nested_field_errors) in nested_errors.field_errors() {
65 for error in nested_field_errors {
66 let message = format!(
67 "{}.{}: {}",
68 field,
69 nested_field,
70 error
71 .message
72 .as_ref()
73 .unwrap_or(&"validation failed".into())
74 );
75 error_messages.push(message);
76 }
77 }
78 }
79 }
80
81 if error_messages.is_empty() {
82 ValidationError::Multiple("Unknown validation error".to_string())
83 } else if error_messages.len() == 1 {
84 ValidationError::InvalidFormat(error_messages[0].clone())
85 } else {
86 ValidationError::Multiple(error_messages.join(", "))
87 }
88 }
89}
90
/// Shorthand for a `Result` whose error type is this module's [`ValidationError`].
pub type ValidationResult<T> = Result<T, ValidationError>;
93
94pub fn validate_protocol_version(version: u16) -> Result<(), validator::ValidationError> {
96 if version != 0 {
97 let mut err = validator::ValidationError::new("invalid_protocol_version");
98 err.message = Some(format!("Protocol version must be 0, got {}", version).into());
99 return Err(err);
100 }
101 Ok(())
102}
103
104pub fn validate_pubkey(pubkey: &str) -> Result<(), validator::ValidationError> {
106 Pubkey::from_str(pubkey).map_err(|_| validator::ValidationError::new("invalid_pubkey"))?;
107 Ok(())
108}
109
110pub fn validate_base64(data: &str) -> Result<(), validator::ValidationError> {
112 use fastcrypto::encoding::{Base64, Encoding};
113 Base64::decode(data).map_err(|_| validator::ValidationError::new("invalid_base64"))?;
114 Ok(())
115}
116
117pub fn validate_base58(data: &str) -> Result<(), validator::ValidationError> {
119 use fastcrypto::encoding::{Base58, Encoding};
120 Base58::decode(data).map_err(|_| validator::ValidationError::new("invalid_base58"))?;
121 Ok(())
122}
123
124pub fn validate_signature(signature: &str) -> Result<(), validator::ValidationError> {
126 if signature.len() > MAX_SIGNATURE_LENGTH {
130 return Err(validator::ValidationError::new("invalid_signature_length"));
131 }
132 validate_base58(signature)
133}
134
135pub fn validate_nonce(nonce: &str) -> Result<(), validator::ValidationError> {
137 if nonce.is_empty() {
138 return Err(validator::ValidationError::new("empty_nonce"));
139 }
140 if nonce.len() > MAX_NONCE_LENGTH {
141 return Err(validator::ValidationError::new("nonce_too_long"));
142 }
143 Ok(())
144}
145
146pub fn validate_kelvins(kelvins: u64) -> Result<(), validator::ValidationError> {
148 if kelvins > MAX_KELVINS {
150 return Err(validator::ValidationError::new("kelvins_too_large"));
151 }
152 Ok(())
153}
154
155pub fn validate_limit(limit: &u64) -> Result<(), validator::ValidationError> {
157 if *limit == 0 {
158 return Err(validator::ValidationError::new("limit_zero"));
159 }
160 if *limit > MAX_PAGINATION_LIMIT {
161 return Err(validator::ValidationError::new("limit_too_large"));
162 }
163 Ok(())
164}
165
166pub fn validate_pubkey_array(pubkeys: &[String]) -> Result<(), validator::ValidationError> {
168 for pubkey in pubkeys {
169 validate_pubkey(pubkey)?;
170 }
171 Ok(())
172}
173
174pub fn validate_signatures_array(signatures: &[String]) -> Result<(), validator::ValidationError> {
176 for signature in signatures {
177 validate_signature(signature)?;
178 }
179 Ok(())
180}
181
182pub fn validate_airdrop_amount(kelvins: u64) -> Result<(), validator::ValidationError> {
184 validate_kelvins(kelvins)?;
185
186 if kelvins > MAX_AIRDROP_AMOUNT {
188 return Err(validator::ValidationError::new("airdrop_amount_too_large"));
189 }
190
191 if kelvins == 0 {
192 return Err(validator::ValidationError::new("airdrop_amount_zero"));
193 }
194
195 Ok(())
196}
197
198pub fn validate_airdrop_amount_i64(kelvins: i64) -> Result<(), validator::ValidationError> {
200 if kelvins < 0 {
202 return Err(validator::ValidationError::new("airdrop_amount_negative"));
203 }
204
205 if kelvins == 0 {
207 return Err(validator::ValidationError::new("airdrop_amount_zero"));
208 }
209
210 let kelvins_u64 = kelvins as u64;
212 validate_kelvins(kelvins_u64)?;
213
214 if kelvins_u64 > MAX_AIRDROP_AMOUNT {
216 return Err(validator::ValidationError::new("airdrop_amount_too_large"));
217 }
218
219 Ok(())
220}
221
222pub fn validate_signature_limit(limit: &u16) -> Result<(), validator::ValidationError> {
224 if *limit == 0 {
225 return Err(validator::ValidationError::new("limit_must_be_positive"));
226 }
227 if *limit > MAX_PAGINATION_LIMIT as u16 {
228 return Err(validator::ValidationError::new("limit_exceeds_maximum"));
229 }
230 Ok(())
231}
232
233pub fn validate_transaction_data(transaction: &str) -> Result<(), validator::ValidationError> {
235 if transaction.is_empty() {
237 return Ok(());
238 }
239
240 if validate_base64(transaction).is_ok() {
242 return validate_transaction_structure_base64(transaction);
244 }
245
246 if validate_base58(transaction).is_ok() {
248 return validate_transaction_structure_base58(transaction);
249 }
250
251 Err(validator::ValidationError::new(
253 "invalid_transaction_encoding",
254 ))
255}
256
257fn validate_transaction_structure_base64(
259 transaction: &str,
260) -> Result<(), validator::ValidationError> {
261 use fastcrypto::encoding::{Base64, Encoding};
262
263 let decoded = Base64::decode(transaction)
265 .map_err(|_| validator::ValidationError::new("invalid_base64_transaction"))?;
266
267 validate_transaction_bytes(&decoded)
268}
269
270fn validate_transaction_structure_base58(
272 transaction: &str,
273) -> Result<(), validator::ValidationError> {
274 use fastcrypto::encoding::{Base58, Encoding};
275
276 let decoded = Base58::decode(transaction)
278 .map_err(|_| validator::ValidationError::new("invalid_base58_transaction"))?;
279
280 validate_transaction_bytes(&decoded)
281}
282
283fn validate_transaction_bytes(transaction_bytes: &[u8]) -> Result<(), validator::ValidationError> {
285 if transaction_bytes.len() < MIN_TRANSACTION_SIZE {
287 return Err(validator::ValidationError::new("transaction_too_small"));
288 }
289
290 if transaction_bytes.len() > MAX_TRANSACTION_SIZE {
292 return Err(validator::ValidationError::new("transaction_too_large"));
293 }
294
295 match bincode::deserialize::<rialo_s_sdk::transaction::VersionedTransaction>(transaction_bytes)
297 {
298 Ok(_) => Ok(()),
299 Err(_) => {
300 match bincode::deserialize::<rialo_s_sdk::transaction::Transaction>(transaction_bytes) {
302 Ok(_) => Ok(()),
303 Err(_) => Err(validator::ValidationError::new(
304 "invalid_transaction_structure",
305 )),
306 }
307 }
308 }
309}
310
311pub fn validate_limit_string(limit: &str) -> Result<(), validator::ValidationError> {
313 let limit_val: u64 = limit
314 .parse()
315 .map_err(|_| validator::ValidationError::new("invalid_limit_format"))?;
316 validate_limit(&limit_val)?;
317 Ok(())
318}
319
320pub fn validate_blockhash(blockhash: &str) -> Result<(), validator::ValidationError> {
322 if blockhash.len() < MIN_BLOCKHASH_LENGTH || blockhash.len() > MAX_BLOCKHASH_LENGTH {
324 return Err(validator::ValidationError::new("invalid_blockhash_length"));
325 }
326 validate_base58(blockhash)
327}
328
329pub fn validate_addresses(addresses: &[String]) -> Result<(), validator::ValidationError> {
331 for address in addresses {
332 validate_pubkey(address)?;
333 }
334 Ok(())
335}
336
337pub fn validate_signatures(signatures: &[String]) -> Result<(), validator::ValidationError> {
339 for signature in signatures {
340 validate_signature(signature)?;
341 }
342 Ok(())
343}
344
345pub fn validate_encoding(encoding: &str) -> Result<(), validator::ValidationError> {
347 match encoding {
348 "json" | "jsonParsed" | "base58" | "base64" => Ok(()),
349 _ => Err(validator::ValidationError::new("invalid_encoding_format")),
350 }
351}
352
353pub fn validate_max_transaction_version(version: &u8) -> Result<(), validator::ValidationError> {
355 if *version <= 1 {
356 Ok(())
357 } else {
358 Err(validator::ValidationError::new(
359 "invalid_max_transaction_version",
360 ))
361 }
362}
363
364pub fn validate_request<T>(request: T) -> ValidationResult<T>
366where
367 T: validator::Validate,
368{
369 request.validate().map_err(ValidationError::from)?;
370 Ok(request)
371}
372
// Unit tests for the validation helpers and for the conversion from
// `validator::ValidationErrors` into `ValidationError`.
#[cfg(test)]
mod tests {
    use super::*;

    // Boundary checks for `validate_limit`: 1 and the maximum pass, 0 and
    // maximum + 1 fail.
    #[test]
    fn test_validate_limit() {
        assert!(validate_limit(&1).is_ok());
        assert!(validate_limit(&MAX_PAGINATION_LIMIT).is_ok());
        assert!(validate_limit(&0).is_err());
        assert!(validate_limit(&(MAX_PAGINATION_LIMIT + 1)).is_err());
    }

    // A normal nonce passes; empty and over-long nonces fail.
    #[test]
    fn test_validate_nonce() {
        assert!(validate_nonce("valid_nonce").is_ok());
        assert!(validate_nonce("").is_err());
        // NOTE(review): 65 presumably is MAX_NONCE_LENGTH + 1 — confirm against constants.
        let long_nonce = "x".repeat(65);
        assert!(validate_nonce(&long_nonce).is_err());
    }

    // Each accepted encoding name passes; an unknown name fails.
    #[test]
    fn test_validate_encoding() {
        assert!(validate_encoding("json").is_ok());
        assert!(validate_encoding("jsonParsed").is_ok());
        assert!(validate_encoding("base58").is_ok());
        assert!(validate_encoding("base64").is_ok());
        assert!(validate_encoding("invalid").is_err());
    }

    // Signatures of varying encoded length are accepted as long as they are
    // valid base58 and not over-long.
    #[test]
    fn test_validate_signature() {
        use fastcrypto::encoding::{Base58, Encoding};

        let valid_sig =
            "5VERv8NMvzbJMEkV8xnrLkEaWRtSz9CosKDYjCJjBRnbJLgp8uirBgmQpjKhoR4tjF3ZpRzrFmBV6UjKdiSZkQUW";
        assert!(validate_signature(valid_sig).is_ok());

        // 64 bytes that are all zero except the last one: the encoded form is
        // shorter than 87 characters (asserted below) and must still be accepted.
        let mut sig_with_zeros = [0u8; 64];
        sig_with_zeros[63] = 1;
        let short_sig = Base58::encode(sig_with_zeros);
        assert!(
            short_sig.len() < 87,
            "Expected short signature due to leading zeros, got len {}",
            short_sig.len()
        );
        assert!(
            validate_signature(&short_sig).is_ok(),
            "Signature with leading zeros should be valid"
        );

        // 64 zero bytes encode to exactly 64 characters (asserted below).
        let all_zeros = [0u8; 64];
        let all_zeros_sig = Base58::encode(all_zeros);
        assert_eq!(all_zeros_sig.len(), 64);
        assert!(
            validate_signature(&all_zeros_sig).is_ok(),
            "All-zeros signature should be valid"
        );

        // '!' is expected to be rejected by the base58 check.
        assert!(validate_signature("invalid!signature").is_err());

        // NOTE(review): 100 characters presumably exceeds MAX_SIGNATURE_LENGTH — confirm.
        let too_long = "1".repeat(100);
        assert!(validate_signature(&too_long).is_err());
    }

    // Versions 0 and 1 pass; 2 fails.
    #[test]
    fn test_validate_max_transaction_version() {
        assert!(validate_max_transaction_version(&0).is_ok());
        assert!(validate_max_transaction_version(&1).is_ok());
        assert!(validate_max_transaction_version(&2).is_err());
    }

    // Two failing top-level fields must convert into `ValidationError::Multiple`
    // that mentions both field names and both messages.
    #[test]
    fn test_validation_error_from_field_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct TestStruct {
            #[validate(length(min = 1, message = "Field cannot be empty"))]
            field1: String,
            #[validate(range(min = 0, max = 100, message = "Must be between 0 and 100"))]
            field2: usize,
        }

        // Both fields violate their constraints.
        let test = TestStruct {
            field1: "".to_string(),
            field2: 150,
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        match error {
            ValidationError::Multiple(msg) => {
                assert!(msg.contains("field1"));
                assert!(msg.contains("Field cannot be empty"));
                assert!(msg.contains("field2"));
                assert!(msg.contains("Must be between 0 and 100"));
            }
            other => panic!("Expected ValidationError::Multiple, got {:?}", other),
        }
    }

    // A failure inside a nested struct must be reported with a dotted path
    // (`config.max_retries`) and its message.
    #[test]
    fn test_validation_error_from_nested_struct_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct NestedConfig {
            #[validate(range(
                min = 0,
                max = 100,
                message = "Max retries must be between 0 and 100"
            ))]
            max_retries: usize,
            #[validate(range(min = 0, message = "Min slot must be non-negative"))]
            min_slot: u64,
        }

        #[derive(Validate)]
        struct ParentStruct {
            #[validate(length(min = 1, message = "Name cannot be empty"))]
            name: String,
            #[validate(nested)]
            config: NestedConfig,
        }

        // Only the nested `max_retries` field is out of range.
        let test = ParentStruct {
            name: "valid".to_string(),
            config: NestedConfig {
                max_retries: 150,
                min_slot: 0,
            },
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        let error_msg = error.to_string();
        assert!(
            error_msg.contains("config.max_retries"),
            "Expected 'config.max_retries' in error message, got: {}",
            error_msg
        );
        assert!(
            error_msg.contains("Max retries must be between 0 and 100"),
            "Expected validation message in error, got: {}",
            error_msg
        );
    }

    // A top-level failure and a nested failure must both appear in the
    // converted error's display text.
    #[test]
    fn test_validation_error_from_mixed_errors() {
        use validator::Validate;

        #[derive(Validate)]
        struct NestedConfig {
            #[validate(range(min = 0, max = 100, message = "Nested field must be 0-100"))]
            nested_field: usize,
        }

        #[derive(Validate)]
        struct ParentStruct {
            #[validate(length(min = 1, message = "Parent field cannot be empty"))]
            parent_field: String,
            #[validate(nested)]
            config: NestedConfig,
        }

        // Both the parent field and the nested field violate their constraints.
        let test = ParentStruct {
            parent_field: "".to_string(),
            config: NestedConfig {
                nested_field: 150,
            },
        };

        let validation_result = test.validate();
        assert!(validation_result.is_err());

        let validation_errors = validation_result.unwrap_err();
        let error: ValidationError = validation_errors.into();

        let error_msg = error.to_string();
        assert!(
            error_msg.contains("parent_field")
                && error_msg.contains("Parent field cannot be empty"),
            "Expected parent field error in message, got: {}",
            error_msg
        );
        assert!(
            error_msg.contains("config.nested_field")
                && error_msg.contains("Nested field must be 0-100"),
            "Expected nested field error in message, got: {}",
            error_msg
        );
    }
}