1use std::path::Path;
6use std::str::FromStr;
7
/// Errors returned by the input-parsing helpers in this module.
///
/// Each `String` payload is a human-readable detail that is interpolated
/// into the `Display` message given by the `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum InputError {
    /// A path did not exist on disk; carries the offending path.
    #[error("File not found: {0}")]
    FileNotFound(String),
    /// The input's textual layout was wrong (e.g. malformed hex/base64, or a
    /// threshold not written as `N-of-M`).
    #[error("Invalid format: {0}")]
    InvalidFormat(String),
    /// The input was structurally fine but its value was rejected (e.g. an
    /// unparsable number, or a threshold larger than its total).
    #[error("Invalid value: {0}")]
    InvalidValue(String),
    /// Wraps an underlying I/O failure; `#[from]` generates
    /// `From<std::io::Error>` so `?` converts it automatically.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
20
21pub fn parse_file_path(path: &str) -> Result<String, InputError> {
23 let path = Path::new(path);
24
25 if !path.exists() {
26 return Err(InputError::FileNotFound(path.to_string_lossy().to_string()));
27 }
28
29 Ok(path.to_string_lossy().to_string())
30}
31
32pub fn parse_hex(hex_str: &str) -> Result<Vec<u8>, InputError> {
34 hex::decode(hex_str).map_err(|e| InputError::InvalidFormat(format!("Invalid hex string: {e}")))
35}
36
37pub fn parse_base64(base64_str: &str) -> Result<Vec<u8>, InputError> {
39 use base64::{engine::general_purpose, Engine as _};
40 general_purpose::STANDARD
41 .decode(base64_str)
42 .map_err(|e| InputError::InvalidFormat(format!("Invalid base64 string: {e}")))
43}
44
45pub fn parse_number<T>(value: &str) -> Result<T, InputError>
47where
48 T: FromStr,
49 T::Err: std::fmt::Display,
50{
51 value
52 .parse()
53 .map_err(|e| InputError::InvalidValue(format!("Invalid number: {e}")))
54}
55
/// Splits `value` on commas, trims surrounding whitespace from each piece,
/// and drops pieces that are empty after trimming.
///
/// Never fails: an empty or separator-only input yields an empty `Vec`.
pub fn parse_comma_separated(value: &str) -> Vec<String> {
    // Trim and filter on borrowed slices so only surviving items are allocated.
    value
        .split(',')
        .map(str::trim)
        .filter(|item| !item.is_empty())
        .map(String::from)
        .collect()
}
64
65pub fn parse_threshold(threshold: &str) -> Result<(usize, usize), InputError> {
67 let parts: Vec<&str> = threshold.split("-of-").collect();
68
69 if parts.len() != 2 {
70 return Err(InputError::InvalidFormat(
71 "Threshold must be in format 'N-of-M'".to_string(),
72 ));
73 }
74
75 let threshold_num = parts[0]
76 .parse::<usize>()
77 .map_err(|e| InputError::InvalidValue(format!("Invalid threshold number: {e}")))?;
78
79 let total_num = parts[1]
80 .parse::<usize>()
81 .map_err(|e| InputError::InvalidValue(format!("Invalid total number: {e}")))?;
82
83 if threshold_num > total_num {
84 return Err(InputError::InvalidValue(
85 "Threshold cannot be greater than total".to_string(),
86 ));
87 }
88
89 Ok((threshold_num, total_num))
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    #[test]
    fn test_parse_hex() {
        assert_eq!(parse_hex("deadbeef").unwrap(), vec![0xde, 0xad, 0xbe, 0xef]);
    }

    #[test]
    fn test_parse_invalid_hex() {
        assert!(matches!(parse_hex("invalid"), Err(InputError::InvalidFormat(_))));
    }

    #[test]
    fn test_parse_base64() {
        assert_eq!(parse_base64("dGVzdA==").unwrap(), b"test");
    }

    #[test]
    fn test_parse_invalid_base64() {
        assert!(matches!(
            parse_base64("not base64!"),
            Err(InputError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_parse_number() {
        let result: Result<u32, _> = parse_number("42");
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_parse_invalid_number() {
        // A negative value cannot be parsed into an unsigned type.
        let result: Result<u32, _> = parse_number("-1");
        assert!(matches!(result, Err(InputError::InvalidValue(_))));
    }

    #[test]
    fn test_parse_comma_separated() {
        assert_eq!(parse_comma_separated("a, b, c"), vec!["a", "b", "c"]);
        // Empty segments and surrounding whitespace are dropped.
        assert_eq!(parse_comma_separated(" a ,, ,b,"), vec!["a", "b"]);
        assert!(parse_comma_separated("").is_empty());
    }

    #[test]
    fn test_parse_threshold() {
        assert_eq!(parse_threshold("3-of-5").unwrap(), (3, 5));
        assert_eq!(parse_threshold("5-of-5").unwrap(), (5, 5));
    }

    #[test]
    fn test_parse_invalid_threshold() {
        assert!(matches!(parse_threshold("3-5"), Err(InputError::InvalidFormat(_))));
        assert!(matches!(
            parse_threshold("1-of-2-of-3"),
            Err(InputError::InvalidFormat(_))
        ));
        assert!(matches!(parse_threshold("x-of-5"), Err(InputError::InvalidValue(_))));
    }

    #[test]
    fn test_parse_threshold_greater_than_total() {
        assert!(matches!(parse_threshold("6-of-5"), Err(InputError::InvalidValue(_))));
    }

    #[test]
    fn test_parse_file_path() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.txt");
        let mut file = File::create(&file_path).unwrap();
        file.write_all(b"test").unwrap();

        // The validated path is returned unchanged.
        let path_str = file_path.to_str().unwrap();
        assert_eq!(parse_file_path(path_str).unwrap(), path_str);
    }

    #[test]
    fn test_parse_nonexistent_file() {
        let result = parse_file_path("/nonexistent/file.txt");
        assert!(matches!(result, Err(InputError::FileNotFound(_))));
    }
}