1use std::fs;
19use std::path::{Path, PathBuf};
20
21use memmap2::Mmap;
22
23use crate::error::{Error, Result};
24use crate::io;
25use crate::types::{ObservationTable, StringInterner};
26
/// Observation table whose columns are memory-mapped from the files of a
/// cache directory (opened with `MmapObservations::from_cache`).
///
/// Each `_*_mmap` field owns the mapping of one column file; it is never read
/// directly and exists only to keep that mapping alive for as long as the
/// table. The matching `*_ptr` field is that mapping's base address cast to the
/// column's element type, which the `ObservationTable` accessors turn into
/// slices of `len` elements.
pub struct MmapObservations {
    /// Number of rows; `from_cache` checks that every column holds exactly
    /// this many elements.
    len: usize,
    /// Mapping of `id.bin`; accessed through `id_ptr`.
    _id_mmap: Mmap,
    /// Mapping of `time_mjd.bin`; accessed through `time_ptr`.
    _time_mmap: Mmap,
    /// Mapping of `ra.bin`; accessed through `ra_ptr`.
    _ra_mmap: Mmap,
    /// Mapping of `dec.bin`; accessed through `dec_ptr`.
    _dec_mmap: Mmap,
    /// Mapping of `observatory_code.bin`; accessed through `obs_code_ptr`.
    _obs_code_mmap: Mmap,
    /// Mapping of `object_id.bin`; accessed through `object_id_ptr`.
    _object_id_mmap: Mmap,
    /// Mapping of `night.bin`; accessed through `night_ptr`.
    _night_mmap: Mmap,
    /// Base of `_id_mmap`, viewed as `u64`s.
    id_ptr: *const u64,
    /// Base of `_time_mmap`, viewed as `f64`s.
    time_ptr: *const f64,
    /// Base of `_ra_mmap`, viewed as `f64`s.
    ra_ptr: *const f64,
    /// Base of `_dec_mmap`, viewed as `f64`s.
    dec_ptr: *const f64,
    /// Base of `_obs_code_mmap`, viewed as `u32`s.
    obs_code_ptr: *const u32,
    /// Base of `_object_id_mmap`, viewed as `u64`s.
    object_id_ptr: *const u64,
    /// Base of `_night_mmap`, viewed as `i64`s.
    night_ptr: *const i64,
}
50
// SAFETY: the raw `*_ptr` fields are what stop the compiler from deriving
// `Send`/`Sync`. Each one points into a `Mmap` owned by this same struct, is
// `*const` and only ever used to build shared (`&[T]`) slices, and stays
// valid for as long as the struct, and therefore its mapping, lives. Moving or
// sharing the struct across threads thus carries the same guarantees as moving
// or sharing the owned `Mmap`s. Like any mmap-backed view, this still assumes
// the underlying cache files are not modified while they are mapped.
unsafe impl Send for MmapObservations {}
// SAFETY: all access through the stored pointers is read-only; see the
// justification on the `Send` impl.
unsafe impl Sync for MmapObservations {}
54
55impl MmapObservations {
56 pub fn from_cache(cache_dir: &Path) -> Result<(Self, StringInterner)> {
60 let id_mmap = mmap_file(&cache_dir.join("id.bin"))?;
61 let time_mmap = mmap_file(&cache_dir.join("time_mjd.bin"))?;
62 let ra_mmap = mmap_file(&cache_dir.join("ra.bin"))?;
63 let dec_mmap = mmap_file(&cache_dir.join("dec.bin"))?;
64 let obs_code_mmap = mmap_file(&cache_dir.join("observatory_code.bin"))?;
65 let object_id_mmap = mmap_file(&cache_dir.join("object_id.bin"))?;
66 let night_mmap = mmap_file(&cache_dir.join("night.bin"))?;
67
68 let len = id_mmap.len() / std::mem::size_of::<u64>();
69
70 let expected_f64 = len * std::mem::size_of::<f64>();
72 let expected_i64 = len * std::mem::size_of::<i64>();
73 let expected_u32 = len * std::mem::size_of::<u32>();
74 if time_mmap.len() != expected_f64
75 || ra_mmap.len() != expected_f64
76 || dec_mmap.len() != expected_f64
77 || object_id_mmap.len() != expected_f64
78 || night_mmap.len() != expected_i64
79 || obs_code_mmap.len() != expected_u32
80 {
81 return Err(Error::InvalidInput(
82 "Cache files have inconsistent lengths".to_string(),
83 ));
84 }
85
86 let id_ptr = id_mmap.as_ptr() as *const u64;
87 let time_ptr = time_mmap.as_ptr() as *const f64;
88 let ra_ptr = ra_mmap.as_ptr() as *const f64;
89 let dec_ptr = dec_mmap.as_ptr() as *const f64;
90 let obs_code_ptr = obs_code_mmap.as_ptr() as *const u32;
91 let object_id_ptr = object_id_mmap.as_ptr() as *const u64;
92 let night_ptr = night_mmap.as_ptr() as *const i64;
93
94 let interner_json = fs::read_to_string(cache_dir.join("interner.json"))?;
96 let interner: StringInterner = serde_json::from_str(&interner_json)
97 .map_err(|e| Error::InvalidInput(format!("Failed to parse interner: {e}")))?;
98
99 Ok((
100 MmapObservations {
101 len,
102 _id_mmap: id_mmap,
103 _time_mmap: time_mmap,
104 _ra_mmap: ra_mmap,
105 _dec_mmap: dec_mmap,
106 _obs_code_mmap: obs_code_mmap,
107 _object_id_mmap: object_id_mmap,
108 _night_mmap: night_mmap,
109 id_ptr,
110 time_ptr,
111 ra_ptr,
112 dec_ptr,
113 obs_code_ptr,
114 object_id_ptr,
115 night_ptr,
116 },
117 interner,
118 ))
119 }
120}
121
122impl ObservationTable for MmapObservations {
123 fn len(&self) -> usize {
124 self.len
125 }
126
127 fn ids(&self) -> &[u64] {
128 unsafe { std::slice::from_raw_parts(self.id_ptr, self.len) }
131 }
132
133 fn times_mjd(&self) -> &[f64] {
134 unsafe { std::slice::from_raw_parts(self.time_ptr, self.len) }
135 }
136
137 fn ra(&self) -> &[f64] {
138 unsafe { std::slice::from_raw_parts(self.ra_ptr, self.len) }
139 }
140
141 fn dec(&self) -> &[f64] {
142 unsafe { std::slice::from_raw_parts(self.dec_ptr, self.len) }
143 }
144
145 fn nights(&self) -> &[i64] {
146 unsafe { std::slice::from_raw_parts(self.night_ptr, self.len) }
147 }
148
149 fn object_ids(&self) -> &[u64] {
150 unsafe { std::slice::from_raw_parts(self.object_id_ptr, self.len) }
151 }
152
153 fn observatory_codes(&self) -> &[u32] {
154 unsafe { std::slice::from_raw_parts(self.obs_code_ptr, self.len) }
155 }
156}
157
158pub fn write_cache(
160 cache_dir: &Path,
161 obs: &impl ObservationTable,
162 interner: &StringInterner,
163) -> Result<()> {
164 fs::create_dir_all(cache_dir)?;
165
166 write_slice(&cache_dir.join("id.bin"), obs.ids())?;
167 write_slice(&cache_dir.join("time_mjd.bin"), obs.times_mjd())?;
168 write_slice(&cache_dir.join("ra.bin"), obs.ra())?;
169 write_slice(&cache_dir.join("dec.bin"), obs.dec())?;
170 write_slice(
171 &cache_dir.join("observatory_code.bin"),
172 obs.observatory_codes(),
173 )?;
174 write_slice(&cache_dir.join("object_id.bin"), obs.object_ids())?;
175 write_slice(&cache_dir.join("night.bin"), obs.nights())?;
176
177 let interner_json = serde_json::to_string(interner)
178 .map_err(|e| Error::InvalidInput(format!("Failed to serialize interner: {e}")))?;
179 fs::write(cache_dir.join("interner.json"), interner_json)?;
180
181 Ok(())
182}
183
184pub fn load_observations_cached(parquet_path: &Path) -> Result<(MmapObservations, StringInterner)> {
191 let cache_dir = cache_dir_for(parquet_path);
192
193 if is_cache_valid(parquet_path, &cache_dir) {
194 return MmapObservations::from_cache(&cache_dir);
195 }
196
197 let (obs, interner, _obs_code_interner) = io::read_observations(parquet_path)?;
199 write_cache(&cache_dir, &obs, &interner)?;
200
201 MmapObservations::from_cache(&cache_dir)
203}
204
/// Returns the cache directory for a parquet file. It is the parquet path with
/// `.difi_cache` appended to its last component (`obs.parquet` becomes
/// `obs.parquet.difi_cache`), so the cache sits next to its source.
///
/// The suffix is appended to the raw OS string rather than going through
/// `with_extension`. This keeps the original extension and works for non-UTF-8
/// paths.
fn cache_dir_for(parquet_path: &Path) -> PathBuf {
    let mut dir_name = parquet_path.to_path_buf().into_os_string();
    dir_name.push(".difi_cache");
    dir_name.into()
}
214
/// Reports whether `cache_dir` holds a complete cache that is at least as new
/// as `parquet_path`.
///
/// Every file the cache consists of must exist, and each must have a
/// modification time no earlier than the parquet file's. Checking only one
/// marker file is not enough: an interrupted write can leave that file fresh
/// while others are missing or stale from a previous run.
///
/// Any metadata error, whether on the parquet file or on a cache file, is
/// treated as "not valid" so that the caller rebuilds the cache.
fn is_cache_valid(parquet_path: &Path, cache_dir: &Path) -> bool {
    // Every file written by `write_cache` and read by `MmapObservations::from_cache`.
    const CACHE_FILES: [&str; 8] = [
        "id.bin",
        "time_mjd.bin",
        "ra.bin",
        "dec.bin",
        "observatory_code.bin",
        "object_id.bin",
        "night.bin",
        "interner.json",
    ];

    let Ok(parquet_time) = fs::metadata(parquet_path).and_then(|meta| meta.modified()) else {
        return false;
    };

    CACHE_FILES.iter().all(|name| {
        matches!(
            fs::metadata(cache_dir.join(name)).and_then(|meta| meta.modified()),
            Ok(cache_time) if cache_time >= parquet_time
        )
    })
}
231
232fn mmap_file(path: &Path) -> Result<Mmap> {
233 let file = fs::File::open(path)?;
234 unsafe { Mmap::map(&file).map_err(Error::Io) }
237}
238
239fn write_slice<T: Copy>(path: &Path, data: &[T]) -> Result<()> {
240 let bytes = unsafe {
241 std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
242 };
243 fs::write(path, bytes)?;
244 Ok(())
245}