heisenberg_data_processing/
lib.rs1use std::{
9 fmt,
10 path::{Path, PathBuf},
11 str::FromStr,
12 sync::{
13 LazyLock,
14 atomic::{AtomicUsize, Ordering},
15 },
16};
17
18pub use error::{DataError, Result};
19use polars::prelude::*;
20use serde::{Deserialize, Serialize};
21use tracing::{info, warn};
22
23use crate::processed::{generate_processed_data, save_processed_data_to_parquet};
24
25pub mod embedded;
26pub mod error;
27pub mod processed;
28pub mod raw;
29
/// Directory name used when no explicit data directory is configured.
pub const DATA_DIR_DEFAULT: &str = "heisenberg_data";

/// Resolved data directory for the process: the `DATA_DIR` env var if set,
/// otherwise the result of `get_default_data_dir()`.
// NOTE(review): this reads `DATA_DIR` while `get_default_data_dir` reads
// `HEISENBERG_DATA_DIR` — confirm both env var names are intentional.
pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
    std::env::var("DATA_DIR").map_or_else(|_| get_default_data_dir(), PathBuf::from)
});
35
/// Selects which GeoNames export dump backs the processed data.
///
/// Variant names correspond to GeoNames dump archive names (see
/// `places_url`); `TestData` is a local fixture with no download URL.
// NOTE(review): serde's `snake_case` renames `AllCountries` to
// "all_countries", while `Display`/`FromStr` use "allCountries"/
// "allcountries" — confirm the serialized form is intended to differ.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DataSource {
    /// Default source: the `cities15000` dump.
    #[default]
    Cities15000,
    Cities5000,
    Cities1000,
    Cities500,
    /// The full `allCountries` dump.
    AllCountries,
    /// Local test data; `places_url` returns `None` for this variant.
    TestData,
}
55
56impl fmt::Display for DataSource {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 Self::Cities15000 => write!(f, "cities15000"),
60 Self::Cities5000 => write!(f, "cities5000"),
61 Self::Cities1000 => write!(f, "cities1000"),
62 Self::Cities500 => write!(f, "cities500"),
63 Self::AllCountries => write!(f, "allCountries"),
64 Self::TestData => write!(f, "test_data"),
65 }
66 }
67}
68
69impl DataSource {
70 pub const BASE_URL: &str = "https://download.geonames.org/export/dump/";
71 pub const PROCESSED_DIR: &str = "processed";
72
73 pub fn data_source_dir(&self) -> PathBuf {
74 DATA_DIR.join(self.to_string())
75 }
76
77 fn processed_dir(&self) -> PathBuf {
78 self.data_source_dir().join(Self::PROCESSED_DIR)
79 }
80
81 pub fn places_url(&self) -> Option<String> {
82 match self {
83 Self::TestData => {
84 warn!("Using test data, no download URL available");
85 None
86 }
87 _ => Some(format!("{}{}.zip", Self::BASE_URL, &self)),
88 }
89 }
90
91 #[must_use]
92 pub fn admin_parquet(&self) -> PathBuf {
93 self.processed_dir().join("admin_search.parquet")
94 }
95
96 #[must_use]
97 pub fn place_parquet(&self) -> PathBuf {
98 self.processed_dir().join("place_search.parquet")
99 }
100}
101
102impl FromStr for DataSource {
103 type Err = String;
104
105 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
106 match s.to_lowercase().as_str() {
107 "cities15000" => Ok(Self::Cities15000),
108 "cities5000" => Ok(Self::Cities5000),
109 "cities1000" => Ok(Self::Cities1000),
110 "cities500" => Ok(Self::Cities500),
111 "allcountries" => Ok(Self::AllCountries),
112 "test_data" | "test" => Ok(Self::TestData),
113 _ => Err(format!(
114 "Invalid DataSource: {s}. Valid options are: cities15000, cities5000, cities1000, cities500, allCountries, test_data"
115 )),
116 }
117 }
118}
119
// Monotonic counter giving each test/doctest its own data directory.
static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);

// Process-wide temporary directory that per-test data dirs live under.
// NOTE(review): statics are never dropped, so `TempDir`'s drop-time cleanup
// will not run — removal relies on OS temp-dir policy; confirm acceptable.
static TEST_DIR: LazyLock<tempfile::TempDir> =
    LazyLock::new(|| tempfile::tempdir().expect("Failed to create temporary test directory"));
124
125fn get_default_data_dir() -> PathBuf {
127 if std::env::var("CARGO_TARGET_TMPDIR").is_ok() {
129 let test_id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
131 return TEST_DIR
132 .path()
133 .join(format!("heisenberg_doctest_{test_id}"));
134 }
135 #[cfg(any(test, doctest))]
136 {
137 TEST_DIR.path().to_path_buf().join(format!(
138 "heisenberg_data_test_{}",
139 TEST_COUNTER.fetch_add(1, Ordering::SeqCst)
140 ))
141 }
142 #[cfg(not(any(test, doctest)))]
143 {
144 if let Ok(data_dir) = std::env::var("HEISENBERG_DATA_DIR") {
146 return PathBuf::from(data_dir);
147 }
148
149 if let Ok(workspace_root) = std::env::var("CARGO_WORKSPACE_DIR") {
151 return PathBuf::from(workspace_root).join(DATA_DIR_DEFAULT);
152 }
153
154 if std::env::var("CARGO_PKG_NAME").is_ok() {
156 return PathBuf::from(DATA_DIR_DEFAULT.to_string());
158 }
159
160 #[cfg(feature = "system-dirs")]
162 {
163 if let Some(proj_dirs) =
164 directories::ProjectDirs::from("com", "heisenberg", "heisenberg-data")
165 {
166 return proj_dirs.cache_dir().to_path_buf();
167 }
168 }
169
170 PathBuf::from(format!("./{DATA_DIR_DEFAULT}"))
172 }
173}
174
175fn load_single_parquet_file(path: impl Into<Arc<Path>>) -> Result<LazyFrame> {
176 LazyFrame::scan_parquet(PlPath::Local(path.into()), ScanArgsParquet::default())
177 .map_err(Into::into)
178}
179
180fn load_parquet_files(admin_path: &Path, place_path: &Path) -> Result<(LazyFrame, LazyFrame)> {
181 let admin_df = load_single_parquet_file(admin_path)?;
182 let place_df = load_single_parquet_file(place_path)?;
183 Ok((admin_df, place_df))
184}
185
186fn validate_data_files(data_source: &DataSource) -> Result<(PathBuf, PathBuf)> {
188 let admin_path = data_source.admin_parquet();
189 let place_path = data_source.place_parquet();
190
191 if !admin_path.exists() || !place_path.exists() {
193 return Err(DataError::RequiredFilesNotFound);
194 }
195
196 if let Err(e) = LazyFrame::scan_parquet(
198 PlPath::Local(admin_path.clone().into()),
199 ScanArgsParquet::default(),
200 ) {
201 warn!("Admin file corrupted or unreadable: {}", e);
202 return Err(DataError::RequiredFilesNotFound);
203 }
204
205 if let Err(e) = LazyFrame::scan_parquet(
206 PlPath::Local(place_path.clone().into()),
207 ScanArgsParquet::default(),
208 ) {
209 warn!("Place file corrupted or unreadable: {}", e);
210 return Err(DataError::RequiredFilesNotFound);
211 }
212
213 Ok((admin_path, place_path))
214}
215
216fn clean_data_files(data_source: &DataSource) -> Result<()> {
218 let admin_path = data_source.admin_parquet();
219 let place_path = data_source.place_parquet();
220
221 if admin_path.exists() {
222 std::fs::remove_file(&admin_path)?;
223 info!("Removed corrupted admin file: {:?}", admin_path);
224 }
225
226 if place_path.exists() {
227 std::fs::remove_file(&place_path)?;
228 info!("Removed corrupted place file: {:?}", place_path);
229 }
230
231 Ok(())
232}
233
/// Ensure processed data files exist for `data_source`, regenerating them
/// from a fresh download when missing or corrupted. Returns the
/// `(admin, place)` parquet paths.
///
/// # Errors
/// Returns [`DataError::RequiredFilesNotFound`] when regeneration is not
/// possible (the `download-data` feature is disabled), or propagates any
/// download/processing/I/O error.
fn ensure_data_files(data_source: &DataSource) -> Result<(PathBuf, PathBuf)> {
    // Fast path: both files already present and readable.
    if let Ok(paths) = validate_data_files(data_source) {
        info!("Using existing processed data for {}", data_source);
        return Ok(paths);
    }
    info!(
        "Data files missing or corrupted for {}, regenerating...",
        data_source
    );
    // Remove partial/corrupt leftovers before regenerating.
    clean_data_files(data_source)?;

    #[cfg(feature = "download-data")]
    {
        info!("Generating processed data for {}", data_source);
        std::fs::create_dir_all(data_source.processed_dir())?;
        let temp_files = raw::fetch::download_data(data_source)?;

        let (admin_df, place_df) = generate_processed_data(temp_files)?;

        let admin_path = data_source.admin_parquet();
        let place_path = data_source.place_parquet();

        save_processed_data_to_parquet(admin_df, &admin_path)?;
        save_processed_data_to_parquet(place_df, &place_path)?;

        info!(
            "Processed data saved to {:?} and {:?}",
            admin_path, place_path
        );

        // Re-validate so callers get the same guarantee as the fast path.
        validate_data_files(data_source)
    }
    #[cfg(not(feature = "download-data"))]
    {
        warn!("download_data feature not enabled, cannot regenerate data");
        Err(DataError::RequiredFilesNotFound)
    }
}
277
278pub fn get_data(data_source: &DataSource) -> Result<(LazyFrame, LazyFrame)> {
280 let (admin_path, place_path) = ensure_data_files(data_source)?;
281 load_parquet_files(&admin_path, &place_path)
282}
283
284pub fn get_admin_data(data_source: &DataSource) -> Result<LazyFrame> {
289 let (admin_path, _place_path) = ensure_data_files(data_source)?;
290 load_single_parquet_file(admin_path)
291}
292
293pub fn get_place_data(data_source: &DataSource) -> Result<LazyFrame> {
298 let (_admin_path, place_path) = ensure_data_files(data_source)?;
299 load_single_parquet_file(place_path)
300}
301
/// Returns `true` when both processed parquet files for `data_source`
/// exist and can be opened (no regeneration is attempted).
#[must_use]
pub fn data_exists(data_source: &DataSource) -> bool {
    validate_data_files(data_source).is_ok()
}
307
/// Delete any existing processed files and rebuild them from scratch,
/// returning the freshly loaded `(admin, place)` lazy frames.
///
/// # Errors
/// Propagates I/O errors from cleanup and any error from `get_data`.
pub fn regenerate_data(data_source: &DataSource) -> Result<(LazyFrame, LazyFrame)> {
    info!("Force regenerating data for {}", data_source);

    // Removing the files first forces `get_data` down its regeneration path.
    clean_data_files(data_source)?;

    get_data(data_source)
}
320
#[cfg(test)]
pub(crate) mod tests_utils {
    //! Shared assertion helpers for `DataFrame`-based tests.
    use num_traits::NumCast;
    use polars::prelude::*;

    /// Assert that `df` contains every column named in `expected_columns`.
    pub fn assert_has_columns(df: &DataFrame, expected_columns: &[&str]) {
        let actual_columns: Vec<_> = df.get_column_names().iter().map(|s| s.as_str()).collect();
        for expected_col in expected_columns {
            assert!(
                actual_columns.contains(expected_col),
                "Missing column: {expected_col}. Available columns: {actual_columns:?}"
            );
        }
    }

    /// Assert that `column` exists in `df` and has exactly `expected_type`.
    pub fn assert_column_type(df: &DataFrame, column: &str, expected_type: &DataType) {
        let actual_type = df
            .column(column)
            .unwrap_or_else(|_| panic!("Column '{column}' not found"))
            .dtype();
        assert_eq!(
            actual_type, expected_type,
            "Column '{column}' has wrong type. Expected: {expected_type:?}, Got: {actual_type:?}"
        );
    }

    /// Assert that `column` exists in `df` and contains no null values.
    pub fn assert_no_nulls_in_column(df: &DataFrame, column: &str) {
        let null_count = df
            .column(column)
            .unwrap_or_else(|_| panic!("Column '{column}' not found"))
            .null_count();
        assert_eq!(
            null_count, 0,
            "Column '{column}' contains {null_count} null values"
        );
    }

    /// Assert that the values of `column` lie within `[min_val, max_val]`.
    ///
    /// NOTE(review): if either reduction or numeric extraction fails (e.g.
    /// incompatible dtype), the range check is silently skipped — behavior
    /// preserved from the original; consider panicking instead.
    pub fn assert_column_range<T>(df: &DataFrame, column: &str, min_val: T, max_val: T)
    where
        T: std::fmt::Debug + NumCast + PartialOrd + Clone + Copy + 'static,
    {
        let series = df
            .column(column)
            .unwrap_or_else(|_| panic!("Column '{column}' not found"));

        // BUG FIX: the original called `max_reduce()` for BOTH bounds, so the
        // lower bound was checked against the column's maximum. The first
        // reduction must be `min_reduce()`.
        if let (Ok(min_actual), Ok(max_actual)) = (
            series
                .min_reduce()
                .and_then(|m| m.as_any_value().try_extract::<T>()),
            series
                .max_reduce()
                .and_then(|m| m.as_any_value().try_extract::<T>()),
        ) {
            assert!(
                min_actual >= min_val,
                "Column '{column}' min value {min_actual:?} is below expected minimum {min_val:?}"
            );
            assert!(
                max_actual <= max_val,
                "Column '{column}' max value {max_actual:?} is above expected maximum {max_val:?}"
            );
        }
    }
}