oar_ocr/pipeline/
config.rs1use crate::core::OCRError;
7use crate::pipeline::OAROCRConfig;
8use std::path::Path;
9
/// On-disk serialization format of a pipeline configuration file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigFormat {
    /// TOML, selected by a `.toml` extension.
    Toml,
    /// JSON, selected by a `.json` extension.
    Json,
}

impl ConfigFormat {
    /// Infers the configuration format from the extension of `path`.
    ///
    /// Matching is ASCII case-insensitive, so `config.TOML` and `config.Json`
    /// are recognized as well as their lowercase spellings.
    ///
    /// Returns `None` when the path has no extension, when the extension is not
    /// valid UTF-8, or when it is neither `toml` nor `json`.
    pub fn from_extension(path: &Path) -> Option<Self> {
        let ext = path.extension()?.to_str()?;
        // Extensions on case-insensitive filesystems (and hand-written paths)
        // frequently appear upper- or mixed-case; treat them all alike.
        if ext.eq_ignore_ascii_case("toml") {
            Some(Self::Toml)
        } else if ext.eq_ignore_ascii_case("json") {
            Some(Self::Json)
        } else {
            None
        }
    }
}
29
/// Stateless entry point for reading and writing pipeline configurations as TOML or JSON.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConfigLoader;
32
33impl ConfigLoader {
34 pub fn load_from_file(path: &Path) -> Result<OAROCRConfig, OCRError> {
54 let format = ConfigFormat::from_extension(path).ok_or_else(|| OCRError::ConfigError {
55 message: format!("Unsupported config file extension: {:?}", path.extension()),
56 })?;
57
58 let content = std::fs::read_to_string(path).map_err(|e| OCRError::ConfigError {
59 message: format!("Failed to read config file {}: {}", path.display(), e),
60 })?;
61
62 Self::load_from_string(&content, format)
63 }
64
65 pub fn load_from_string(content: &str, format: ConfigFormat) -> Result<OAROCRConfig, OCRError> {
76 match format {
77 ConfigFormat::Toml => Self::load_from_toml(content),
78 ConfigFormat::Json => Self::load_from_json(content),
79 }
80 }
81
82 pub fn load_from_toml(content: &str) -> Result<OAROCRConfig, OCRError> {
84 toml::from_str(content).map_err(|e| OCRError::ConfigError {
85 message: format!("Failed to parse TOML config: {e}"),
86 })
87 }
88
89 pub fn load_from_json(content: &str) -> Result<OAROCRConfig, OCRError> {
91 serde_json::from_str(content).map_err(|e| OCRError::ConfigError {
92 message: format!("Failed to parse JSON config: {e}"),
93 })
94 }
95
96 pub fn save_to_file(config: &OAROCRConfig, path: &Path) -> Result<(), OCRError> {
107 let format = ConfigFormat::from_extension(path).ok_or_else(|| OCRError::ConfigError {
108 message: format!("Unsupported config file extension: {:?}", path.extension()),
109 })?;
110
111 let content = Self::save_to_string(config, format)?;
112
113 std::fs::write(path, content).map_err(|e| OCRError::ConfigError {
114 message: format!("Failed to write config file {}: {}", path.display(), e),
115 })
116 }
117
118 pub fn save_to_string(config: &OAROCRConfig, format: ConfigFormat) -> Result<String, OCRError> {
129 match format {
130 ConfigFormat::Toml => Self::save_to_toml(config),
131 ConfigFormat::Json => Self::save_to_json(config),
132 }
133 }
134
135 pub fn save_to_toml(config: &OAROCRConfig) -> Result<String, OCRError> {
137 toml::to_string_pretty(config).map_err(|e| OCRError::ConfigError {
138 message: format!("Failed to serialize config to TOML: {e}"),
139 })
140 }
141
142 pub fn save_to_json(config: &OAROCRConfig) -> Result<String, OCRError> {
144 serde_json::to_string_pretty(config).map_err(|e| OCRError::ConfigError {
145 message: format!("Failed to serialize config to JSON: {e}"),
146 })
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds the minimal configuration shared by the round-trip tests.
    fn sample_config() -> OAROCRConfig {
        OAROCRConfig::new(
            PathBuf::from("detection.onnx"),
            PathBuf::from("recognition.onnx"),
            PathBuf::from("dict.txt"),
        )
    }

    #[test]
    fn test_config_format_detection() {
        assert!(matches!(
            ConfigFormat::from_extension(Path::new("config.toml")),
            Some(ConfigFormat::Toml)
        ));
        assert!(matches!(
            ConfigFormat::from_extension(Path::new("config.json")),
            Some(ConfigFormat::Json)
        ));
        assert!(ConfigFormat::from_extension(Path::new("config.txt")).is_none());
        assert!(ConfigFormat::from_extension(Path::new("config")).is_none());
    }

    #[test]
    fn test_toml_roundtrip() {
        let config = sample_config();

        let toml_str = ConfigLoader::save_to_toml(&config).unwrap();
        let loaded_config = ConfigLoader::load_from_toml(&toml_str).unwrap();

        // Check the same fields as the JSON round-trip so both formats are held
        // to the same standard.
        assert_eq!(
            config.character_dict_path,
            loaded_config.character_dict_path
        );
        assert_eq!(
            config.recognition.model_input_shape,
            loaded_config.recognition.model_input_shape
        );
    }

    #[test]
    fn test_json_roundtrip() {
        let config = sample_config();

        let json_str = ConfigLoader::save_to_json(&config).unwrap();
        let loaded_config = ConfigLoader::load_from_json(&json_str).unwrap();

        assert_eq!(
            config.character_dict_path,
            loaded_config.character_dict_path
        );
        assert_eq!(
            config.recognition.model_input_shape,
            loaded_config.recognition.model_input_shape
        );
    }

    #[test]
    fn test_string_dispatch_roundtrip() {
        let config = sample_config();

        for &format in &[ConfigFormat::Toml, ConfigFormat::Json] {
            let text = ConfigLoader::save_to_string(&config, format).unwrap();
            let loaded_config = ConfigLoader::load_from_string(&text, format).unwrap();
            assert_eq!(
                config.character_dict_path,
                loaded_config.character_dict_path
            );
        }
    }

    #[test]
    fn test_unsupported_extension_is_rejected() {
        let config = sample_config();

        // The extension is checked before any filesystem access, so these
        // neither read nor create files.
        assert!(ConfigLoader::load_from_file(Path::new("config.txt")).is_err());
        assert!(ConfigLoader::load_from_file(Path::new("config")).is_err());
        assert!(ConfigLoader::save_to_file(&config, Path::new("config.txt")).is_err());
    }

    #[test]
    fn test_malformed_input_is_rejected() {
        assert!(ConfigLoader::load_from_toml("this is = = not toml").is_err());
        assert!(ConfigLoader::load_from_json("{ not json").is_err());
    }
}