pyref_core/
io.rs

1use astrors_fork::io::hdus::{
2    image::{imagehdu::ImageHDU, ImageData},
3    primaryhdu::PrimaryHDU,
4};
5use ndarray::{ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr};
6use polars::prelude::*;
7use std::ops::Mul;
8
9use crate::errors::FitsLoaderError;
10
/// Computes the momentum-transfer magnitude `q = 4π·sin(θ)/λ`.
///
/// * `lam` – wavelength; the result is in the reciprocal of its units.
/// * `theta` – angle in **degrees** (converted internally with `to_radians`).
///
/// Negative results (e.g. from a negative angle or wavelength) are clamped to `0.0`.
/// Division by `lam` is unchecked, so `lam == 0.0` yields infinity or NaN. A NaN input
/// propagates to a NaN output.
pub fn q(lam: f64, theta: f64) -> f64 {
    let q = 4.0 * std::f64::consts::PI * theta.to_radians().sin() / lam;
    // Explicit comparison (rather than `f64::max`) so NaN is propagated, not turned into 0.0.
    if q < 0.0 {
        0.0
    } else {
        q
    }
}
18
19pub fn col_from_array(
20    name: PlSmallStr,
21    array: ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
22) -> Result<Column, PolarsError> {
23    // Transpose the array to ensure the dimensions are swapped, which should
24    // counteract the downstream processing issue.
25    let transposed_array = array.t(); // Get a transposed view
26
27    let rows = transposed_array.len_of(Axis(0));
28    let cols = transposed_array.len_of(Axis(1));
29
30    // Create a C-contiguous version of the transposed array.
31    let c_contiguous_array = transposed_array.as_standard_layout().into_owned();
32    let flat_data = c_contiguous_array.as_slice().unwrap();
33
34    let mut list_builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
35        name.clone(),
36        rows,
37        rows * cols,
38        DataType::Int64,
39    );
40
41    for row_chunk in flat_data.chunks_exact(cols) {
42        list_builder.append_slice(row_chunk);
43    }
44
45    let series_of_lists = list_builder.finish().into_series();
46
47    // Implode the series of lists into a single row containing a list of lists.
48    let series_of_list_of_lists = series_of_lists.implode()?;
49
50    // Finally, cast to a 2D Array type with explicit dimensions.
51    // This preserves the 2D structure of the data.
52    let array_series = series_of_list_of_lists.cast(&DataType::Array(
53        Box::new(DataType::Array(Box::new(DataType::Int64), cols)),
54        rows,
55    ))?;
56
57    Ok(array_series.into_column())
58}
59// ================== CCD Raw Data Processing ============
60pub fn add_calculated_domains(lzf: LazyFrame) -> DataFrame {
61    let h = physical_constants::PLANCK_CONSTANT_IN_EV_PER_HZ;
62    let c = physical_constants::SPEED_OF_LIGHT_IN_VACUUM * 1e10;
63
64    // Collect schema once to check for column existence
65    let schema = lzf.clone().collect_schema().unwrap_or_default();
66    let has_column = |name: &str| schema.iter().any(|(col_name, _)| col_name == name);
67
68    // Start with basic sorting that won't fail even if columns don't exist
69    let mut lz = lzf;
70
71    // Apply optional sorting only if the columns exist
72    if has_column("DATE") {
73        lz = lz.sort(["DATE"], Default::default());
74    }
75
76    if has_column("file_name") {
77        lz = lz.sort(["file_name"], Default::default());
78    }
79
80    // Conditionally apply column transformations only if the columns exist
81    if has_column("EXPOSURE") {
82        lz = lz.with_column(col("EXPOSURE").round(3).alias("EXPOSURE"));
83    }
84
85    if has_column("Higher Order Suppressor") {
86        lz = lz.with_column(
87            col("Higher Order Suppressor")
88                .round(2)
89                .alias("Higher Order Suppressor"),
90        );
91    }
92
93    if has_column("Horizontal Exit Slit Size") {
94        lz = lz.with_column(
95            col("Horizontal Exit Slit Size")
96                .round(1)
97                .alias("Horizontal Exit Slit Size"),
98        );
99    }
100
101    if has_column("Beamline Energy") {
102        lz = lz.with_column(col("Beamline Energy").round(1).alias("Beamline Energy"));
103
104        // Only calculate Lambda if Beamline Energy exists
105        lz = lz.with_column(
106            col("Beamline Energy")
107                .pow(-1)
108                .mul(lit(h * c))
109                .alias("Lambda"),
110        );
111    }
112
113    // Add Q column if required columns exist
114    lz = lz.with_column(
115        when(
116            col("Sample Theta")
117                .is_not_null()
118                .and(col("Lambda").is_not_null()),
119        )
120        .then(as_struct(vec![col("Sample Theta"), col("Lambda")]).map(
121            move |s| {
122                let struc = s.struct_()?;
123                let th_series = struc.field_by_name("Sample Theta")?;
124                let theta = th_series.f64()?;
125                let lam_series = struc.field_by_name("Lambda")?;
126                let lam = lam_series.f64()?;
127
128                let out: Float64Chunked = theta
129                    .into_iter()
130                    .zip(lam.iter())
131                    .map(|(theta, lam)| match (theta, lam) {
132                        (Some(theta), Some(lam)) => Some(q(lam, theta)),
133                        _ => None,
134                    })
135                    .collect();
136
137                Ok(Some(out.into_column()))
138            },
139            GetOutput::from_type(DataType::Float64),
140        ))
141        .otherwise(lit(NULL))
142        .alias("Q"),
143    );
144
145    // Collect the final DataFrame only once at the end
146    lz.collect().unwrap_or_else(|_| DataFrame::empty())
147}
148
149/// Reads a single FITS file and converts it to a Polars DataFrame.
150///
151/// # Arguments
152///
153/// * `file_path` - Path to the FITS file to read
154/// * `header_items` - List of header values to extract
155///
156/// # Returns
157///
158/// A `Result` containing either the DataFrame or a `FitsLoaderError`.
159pub fn process_image(img: &ImageHDU) -> Result<Vec<Column>, FitsLoaderError> {
160    let bzero = img
161        .header
162        .get_card("BZERO")
163        .ok_or_else(|| FitsLoaderError::MissingHeaderKey("BZERO".into()))?
164        .value
165        .as_int()
166        .ok_or_else(|| FitsLoaderError::FitsError("BZERO not an integer".into()))?;
167
168    match &img.data {
169        ImageData::I16(image) => {
170            let data = image.map(|&x| i64::from(x as i64 + bzero));
171            // Implement row-by-row background subtraction
172            let subtracted = subtract_background(&data);
173            // Locate the index tuple with the maximum value
174            // Find the coordinates of the maximum value in the 2D array
175            let max_coords = {
176                let mut max_coords = (0, 0);
177                let mut max_val = i64::MIN;
178
179                for (idx, &val) in subtracted.indexed_iter() {
180                    if val > max_val {
181                        max_val = val;
182                        max_coords = (idx[0], idx[1]);
183                    }
184                }
185
186                max_coords
187            };
188
189            // Check if the beam is too close to any edge (top, bottom, left, right)
190            let msg = if max_coords.0 < 20
191                || max_coords.0 > (subtracted.len_of(Axis(0)) - 20)
192                || max_coords.1 < 20
193                || max_coords.1 > (subtracted.len_of(Axis(1)) - 20)
194            {
195                "Simple Detection Error: Beam is too close to the edge"
196            } else {
197                ""
198            };
199            // Calculate a simple reflectivity result from the subtracted data
200            let (db_sum, scaled_bg) = { simple_reflectivity(&subtracted, max_coords) };
201
202            Ok(vec![
203                col_from_array("RAW".into(), data.clone()).unwrap(),
204                col_from_array("SUBTRACTED".into(), subtracted.clone()).unwrap(),
205                Column::new("Simple Spot X".into(), vec![max_coords.0 as u64]),
206                Column::new("Simple Spot Y".into(), vec![max_coords.1 as u64]),
207                Column::new(
208                    "Simple Reflectivity".into(),
209                    vec![(db_sum - scaled_bg) as f64],
210                ),
211                Series::new("status".into(), vec![msg.to_string()]).into_column(),
212            ])
213        }
214        _ => Err(FitsLoaderError::UnsupportedImageData),
215    }
216}
217
218fn simple_reflectivity(
219    subtracted: &ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
220    max_index: (usize, usize),
221) -> (i64, i64) {
222    // Convert max_index to 2D coordinates
223    let beam_y = max_index.0;
224    // Row
225    let beam_x = max_index.1;
226    // Column
227
228    // Define ROI size
229    let roi = 5;
230    // Region of interest size
231
232    // Define ROI boundaries
233    let roi_start_y = beam_y.saturating_sub(roi);
234    let roi_end_y = (beam_y + roi + 1).min(subtracted.len_of(Axis(0)));
235    let roi_start_x = beam_x.saturating_sub(roi);
236    let roi_end_x = (beam_x + roi + 1).min(subtracted.len_of(Axis(1)));
237
238    // Initialize sums and counts
239    let mut db_sum = 0i64;
240    let mut db_count = 0i64;
241    let mut bg_sum = 0i64;
242    let mut bg_count = 0i64;
243
244    // Iterate over all rows and columns
245    for y in 0..subtracted.len_of(Axis(0)) {
246        for x in 0..subtracted.len_of(Axis(1)) {
247            let value = subtracted[[y, x]];
248            if value == 0 {
249                continue;
250            }
251
252            if (roi_start_y <= y && y < roi_end_y) && (roi_start_x <= x && x < roi_end_x) {
253                db_sum += value;
254                db_count += 1;
255            } else {
256                bg_sum += value;
257                bg_count += 1;
258            }
259        }
260    }
261
262    // Handle edge cases
263    if bg_count == 0 || db_sum == 0 {
264        (0, 0)
265    } else {
266        // Scale background sum based on ratio of counts
267        let scaled_bg = (bg_sum * db_count) / bg_count;
268        (db_sum, scaled_bg)
269    }
270}
271
272fn subtract_background(
273    data: &ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
274) -> ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>> {
275    // Get a view of the data with 5 pixels sliced from each side
276    let view = data.slice(ndarray::s![5..-5, 5..-5]);
277    let rows = view.len_of(Axis(0));
278    let cols = view.len_of(Axis(1));
279
280    // Extract the left and right columns (first and last 20 columns)
281    let left = view.slice(ndarray::s![.., ..20]);
282    let right = view.slice(ndarray::s![.., (cols - 20)..]);
283
284    // Calculate the sum of left and right regions
285    let left_sum: i64 = left.iter().copied().sum();
286    let right_sum: i64 = right.iter().copied().sum();
287
288    // Create background array to store row means
289    let mut background = ndarray::Array1::zeros(rows);
290
291    // Determine which side to use for background
292    if left_sum < right_sum {
293        // Use right side as background
294        for (i, row) in right.axis_iter(Axis(0)).enumerate() {
295            background[i] = row.iter().copied().sum::<i64>() / row.len() as i64;
296        }
297    } else {
298        // Use left side as background
299        for (i, row) in left.axis_iter(Axis(0)).enumerate() {
300            background[i] = row.iter().copied().sum::<i64>() / row.len() as i64;
301        }
302    }
303
304    // Create a new owned array from the view
305    let mut result = view.to_owned();
306
307    // Subtract background from each row
308    for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
309        let bg = background[i];
310        for val in row.iter_mut() {
311            *val -= bg;
312        }
313    }
314    result.into_dyn()
315}
316
317pub fn process_metadata(
318    hdu: &PrimaryHDU,
319    keys: &Vec<String>,
320) -> Result<Vec<Column>, FitsLoaderError> {
321    if keys.is_empty() {
322        // If no specific keys are requested, return all header values
323        Ok(hdu
324            .header
325            .iter()
326            .filter(|card| !card.keyword.as_str().to_lowercase().contains("comment"))
327            .map(|card| {
328                let name = card.keyword.as_str();
329                let value = card.value.as_float().unwrap_or(0.0);
330                Column::new(name.into(), &[value])
331            })
332            .collect())
333    } else {
334        // Process each requested header key
335        let mut columns = Vec::new();
336
337        for key in keys {
338            // Special handling for Beamline Energy
339            if key == "Beamline Energy" {
340                // First try to get "Beamline Energy"
341                if let Some(card) = hdu.header.get_card(key) {
342                    if let Some(val) = card.value.as_float() {
343                        columns.push(Column::new(key.into(), &[val]));
344                        continue;
345                    }
346                }
347
348                // Then fall back to "Beamline Energy Goal" if "Beamline Energy" is not present
349                if let Some(card) = hdu.header.get_card("Beamline Energy Goal") {
350                    if let Some(val) = card.value.as_float() {
351                        columns.push(Column::new(key.into(), &[val]));
352                        continue;
353                    }
354                }
355
356                // If neither value is available, use a default
357                columns.push(Column::new(key.into(), &[0.0]));
358                continue;
359            }
360
361            // Special handling for Date header (it's a string value, not a float)
362            if key == "DATE" {
363                if let Some(card) = hdu.header.get_card(key) {
364                    let val = card.value.to_string();
365                    columns.push(Column::new(key.into(), &[val]));
366                    continue;
367                }
368                // If DATE is not present, use a default empty string
369                columns.push(Column::new(key.into(), &["".to_string()]));
370                continue;
371            }
372
373            // For other headers, don't fail if they're missing
374            let val = match hdu.header.get_card(key) {
375                Some(card) => card.value.as_float().unwrap_or(1.0),
376                None => 0.0, // Default value for missing headers
377            };
378
379            // Use the snake_case name from the enum variant
380            columns.push(Column::new(key.into(), &[val]));
381        }
382
383        Ok(columns)
384    }
385}
386
387pub fn process_file_name(path: std::path::PathBuf) -> Vec<Column> {
388    // Extract just the file name without extension
389    let file_name = path.file_stem().unwrap().to_str().unwrap_or("");
390
391    // Just return the file name directly, without extracting frame numbers or scan IDs
392    vec![Column::new("file_name".into(), vec![file_name])]
393}