pyref_core/
io.rs

1use astrors_fork::io::hdus::{
2    image::{imagehdu::ImageHDU, ImageData},
3    primaryhdu::PrimaryHDU,
4};
5use ndarray::{ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr};
6use polars::prelude::*;
7use std::ops::Mul;
8
9use crate::errors::FitsLoaderError;
10
/// Momentum transfer `4π·sin(θ)/λ` for a reflection geometry.
///
/// `theta` is an angle in degrees (it is converted with `to_radians`); `lam` is
/// the wavelength in any length unit, and the result is in the inverse of that
/// unit. A negative result (negative angle or wavelength) is clamped to `0.0`;
/// a `NaN` result is returned unchanged.
pub fn q(lam: f64, theta: f64) -> f64 {
    let momentum = 4.0 * std::f64::consts::PI * theta.to_radians().sin() / lam;
    if momentum < 0.0 {
        0.0
    } else {
        momentum
    }
}
18
19pub fn col_from_array(
20    name: PlSmallStr,
21    array: ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
22) -> Result<Column, PolarsError> {
23    let rows = array.len_of(Axis(0));
24    let cols = array.len_of(Axis(1));
25
26    let mut list_builder =
27        ListPrimitiveChunkedBuilder::<Int64Type>::new(name, rows, rows * cols, DataType::Int64);
28
29    for row in array.axis_iter(Axis(0)) {
30        match row.as_slice() {
31            Some(s) => list_builder.append_slice(s),
32            None => list_builder.append_slice(&row.to_owned().into_raw_vec()),
33        }
34    }
35
36    let series_of_lists = list_builder.finish().into_series();
37
38    // Implode into a single row containing a list of lists.
39    let series_of_list_of_lists = series_of_lists.implode()?;
40
41    // Cast to a 2D Array type.
42    let array_series = series_of_list_of_lists.cast(&DataType::Array(
43        Box::new(DataType::Array(Box::new(DataType::Int64), cols)),
44        rows,
45    ))?;
46
47    Ok(array_series.into_column())
48}
49// ================== CCD Raw Data Processing ============
50pub fn add_calculated_domains(lzf: LazyFrame) -> DataFrame {
51    let h = physical_constants::PLANCK_CONSTANT_IN_EV_PER_HZ;
52    let c = physical_constants::SPEED_OF_LIGHT_IN_VACUUM * 1e10;
53
54    // Collect schema once to check for column existence
55    let schema = lzf.clone().collect_schema().unwrap_or_default();
56    let has_column = |name: &str| schema.iter().any(|(col_name, _)| col_name == name);
57
58    // Start with basic sorting that won't fail even if columns don't exist
59    let mut lz = lzf;
60
61    // Apply optional sorting only if the columns exist
62    if has_column("DATE") {
63        lz = lz.sort(["DATE"], Default::default());
64    }
65
66    if has_column("file_name") {
67        lz = lz.sort(["file_name"], Default::default());
68    }
69
70    // Conditionally apply column transformations only if the columns exist
71    if has_column("EXPOSURE") {
72        lz = lz.with_column(col("EXPOSURE").round(3).alias("EXPOSURE"));
73    }
74
75    if has_column("Higher Order Suppressor") {
76        lz = lz.with_column(
77            col("Higher Order Suppressor")
78                .round(2)
79                .alias("Higher Order Suppressor"),
80        );
81    }
82
83    if has_column("Horizontal Exit Slit Size") {
84        lz = lz.with_column(
85            col("Horizontal Exit Slit Size")
86                .round(1)
87                .alias("Horizontal Exit Slit Size"),
88        );
89    }
90
91    if has_column("Beamline Energy") {
92        lz = lz.with_column(col("Beamline Energy").round(1).alias("Beamline Energy"));
93
94        // Only calculate Lambda if Beamline Energy exists
95        lz = lz.with_column(
96            col("Beamline Energy")
97                .pow(-1)
98                .mul(lit(h * c))
99                .alias("Lambda"),
100        );
101    }
102
103    // Add Q column if required columns exist
104    lz = lz.with_column(
105        when(
106            col("Sample Theta")
107                .is_not_null()
108                .and(col("Lambda").is_not_null()),
109        )
110        .then(as_struct(vec![col("Sample Theta"), col("Lambda")]).map(
111            move |s| {
112                let struc = s.struct_()?;
113                let th_series = struc.field_by_name("Sample Theta")?;
114                let theta = th_series.f64()?;
115                let lam_series = struc.field_by_name("Lambda")?;
116                let lam = lam_series.f64()?;
117
118                let out: Float64Chunked = theta
119                    .into_iter()
120                    .zip(lam.iter())
121                    .map(|(theta, lam)| match (theta, lam) {
122                        (Some(theta), Some(lam)) => Some(q(lam, theta)),
123                        _ => None,
124                    })
125                    .collect();
126
127                Ok(Some(out.into_column()))
128            },
129            GetOutput::from_type(DataType::Float64),
130        ))
131        .otherwise(lit(NULL))
132        .alias("Q"),
133    );
134
135    // Collect the final DataFrame only once at the end
136    lz.collect().unwrap_or_else(|_| DataFrame::empty())
137}
138
139/// Reads a single FITS file and converts it to a Polars DataFrame.
140///
141/// # Arguments
142///
143/// * `file_path` - Path to the FITS file to read
144/// * `header_items` - List of header values to extract
145///
146/// # Returns
147///
148/// A `Result` containing either the DataFrame or a `FitsLoaderError`.
149pub fn process_image(img: &ImageHDU) -> Result<Vec<Column>, FitsLoaderError> {
150    let bzero = img
151        .header
152        .get_card("BZERO")
153        .ok_or_else(|| FitsLoaderError::MissingHeaderKey("BZERO".into()))?
154        .value
155        .as_int()
156        .ok_or_else(|| FitsLoaderError::FitsError("BZERO not an integer".into()))?;
157
158    match &img.data {
159        ImageData::I16(image) => {
160            let data = image.map(|&x| i64::from(x as i64 + bzero));
161            // Implement row-by-row background subtraction
162            let subtracted = subtract_background(&data);
163            // Locate the index tuple with the maximum value
164            // Find the coordinates of the maximum value in the 2D array
165            let max_coords = {
166                let mut max_coords = (0, 0);
167                let mut max_val = i64::MIN;
168
169                for (idx, &val) in subtracted.indexed_iter() {
170                    if val > max_val {
171                        max_val = val;
172                        max_coords = (idx[0], idx[1]);
173                    }
174                }
175
176                max_coords
177            };
178
179            // Check if the beam is too close to any edge (top, bottom, left, right)
180            let msg = if max_coords.0 < 20
181                || max_coords.0 > (subtracted.len_of(Axis(0)) - 20)
182                || max_coords.1 < 20
183                || max_coords.1 > (subtracted.len_of(Axis(1)) - 20)
184            {
185                "Simple Detection Error: Beam is too close to the edge"
186            } else {
187                ""
188            };
189            // Calculate a simple reflectivity result from the subtracted data
190            let (db_sum, scaled_bg) = { simple_reflectivity(&subtracted, max_coords) };
191
192            Ok(vec![
193                col_from_array("RAW".into(), data.clone()).unwrap(),
194                col_from_array("SUBTRACTED".into(), subtracted.clone()).unwrap(),
195                Column::new("Simple Spot X".into(), vec![max_coords.0 as u64]),
196                Column::new("Simple Spot Y".into(), vec![max_coords.1 as u64]),
197                Column::new(
198                    "Simple Reflectivity".into(),
199                    vec![(db_sum - scaled_bg) as f64],
200                ),
201                Series::new("status".into(), vec![msg.to_string()]).into_column(),
202            ])
203        }
204        _ => Err(FitsLoaderError::UnsupportedImageData),
205    }
206}
207
208fn simple_reflectivity(
209    subtracted: &ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
210    max_index: (usize, usize),
211) -> (i64, i64) {
212    // Convert max_index to 2D coordinates
213    let beam_y = max_index.0;
214    // Row
215    let beam_x = max_index.1;
216    // Column
217
218    // Define ROI size
219    let roi = 5;
220    // Region of interest size
221
222    // Define ROI boundaries
223    let roi_start_y = beam_y.saturating_sub(roi);
224    let roi_end_y = (beam_y + roi + 1).min(subtracted.len_of(Axis(0)));
225    let roi_start_x = beam_x.saturating_sub(roi);
226    let roi_end_x = (beam_x + roi + 1).min(subtracted.len_of(Axis(1)));
227
228    // Initialize sums and counts
229    let mut db_sum = 0i64;
230    let mut db_count = 0i64;
231    let mut bg_sum = 0i64;
232    let mut bg_count = 0i64;
233
234    // Iterate over all rows and columns
235    for y in 0..subtracted.len_of(Axis(0)) {
236        for x in 0..subtracted.len_of(Axis(1)) {
237            let value = subtracted[[y, x]];
238            if value == 0 {
239                continue;
240            }
241
242            if (roi_start_y <= y && y < roi_end_y) && (roi_start_x <= x && x < roi_end_x) {
243                db_sum += value;
244                db_count += 1;
245            } else {
246                bg_sum += value;
247                bg_count += 1;
248            }
249        }
250    }
251
252    // Handle edge cases
253    if bg_count == 0 || db_sum == 0 {
254        (0, 0)
255    } else {
256        // Scale background sum based on ratio of counts
257        let scaled_bg = (bg_sum * db_count) / bg_count;
258        (db_sum, scaled_bg)
259    }
260}
261
262fn subtract_background(
263    data: &ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
264) -> ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>> {
265    // Get a view of the data with 5 pixels sliced from each side
266    let view = data.slice(ndarray::s![5..-5, 5..-5]);
267    let rows = view.len_of(Axis(0));
268    let cols = view.len_of(Axis(1));
269
270    // Extract the left and right columns (first and last 20 columns)
271    let left = view.slice(ndarray::s![.., ..20]);
272    let right = view.slice(ndarray::s![.., (cols - 20)..]);
273
274    // Calculate the sum of left and right regions
275    let left_sum: i64 = left.iter().copied().sum();
276    let right_sum: i64 = right.iter().copied().sum();
277
278    // Create background array to store row means
279    let mut background = ndarray::Array1::zeros(rows);
280
281    // Determine which side to use for background
282    if left_sum < right_sum {
283        // Use right side as background
284        for (i, row) in right.axis_iter(Axis(0)).enumerate() {
285            background[i] = row.iter().copied().sum::<i64>() / row.len() as i64;
286        }
287    } else {
288        // Use left side as background
289        for (i, row) in left.axis_iter(Axis(0)).enumerate() {
290            background[i] = row.iter().copied().sum::<i64>() / row.len() as i64;
291        }
292    }
293
294    // Create a new owned array from the view
295    let mut result = view.to_owned();
296
297    // Subtract background from each row
298    for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
299        let bg = background[i];
300        for val in row.iter_mut() {
301            *val -= bg;
302        }
303    }
304    result.into_dyn()
305}
306
307pub fn process_metadata(
308    hdu: &PrimaryHDU,
309    keys: &Vec<String>,
310) -> Result<Vec<Column>, FitsLoaderError> {
311    if keys.is_empty() {
312        // If no specific keys are requested, return all header values
313        Ok(hdu
314            .header
315            .iter()
316            .filter(|card| !card.keyword.as_str().to_lowercase().contains("comment"))
317            .map(|card| {
318                let name = card.keyword.as_str();
319                let value = card.value.as_float().unwrap_or(0.0);
320                Column::new(name.into(), &[value])
321            })
322            .collect())
323    } else {
324        // Process each requested header key
325        let mut columns = Vec::new();
326
327        for key in keys {
328            // Special handling for Beamline Energy
329            if key == "Beamline Energy" {
330                // First try to get "Beamline Energy"
331                if let Some(card) = hdu.header.get_card(key) {
332                    if let Some(val) = card.value.as_float() {
333                        columns.push(Column::new(key.into(), &[val]));
334                        continue;
335                    }
336                }
337
338                // Then fall back to "Beamline Energy Goal" if "Beamline Energy" is not present
339                if let Some(card) = hdu.header.get_card("Beamline Energy Goal") {
340                    if let Some(val) = card.value.as_float() {
341                        columns.push(Column::new(key.into(), &[val]));
342                        continue;
343                    }
344                }
345
346                // If neither value is available, use a default
347                columns.push(Column::new(key.into(), &[0.0]));
348                continue;
349            }
350
351            // Special handling for Date header (it's a string value, not a float)
352            if key == "DATE" {
353                if let Some(card) = hdu.header.get_card(key) {
354                    let val = card.value.to_string();
355                    columns.push(Column::new(key.into(), &[val]));
356                    continue;
357                }
358                // If DATE is not present, use a default empty string
359                columns.push(Column::new(key.into(), &["".to_string()]));
360                continue;
361            }
362
363            // For other headers, don't fail if they're missing
364            let val = match hdu.header.get_card(key) {
365                Some(card) => card.value.as_float().unwrap_or(1.0),
366                None => 0.0, // Default value for missing headers
367            };
368
369            // Use the snake_case name from the enum variant
370            columns.push(Column::new(key.into(), &[val]));
371        }
372
373        Ok(columns)
374    }
375}
376
377pub fn process_file_name(path: std::path::PathBuf) -> Vec<Column> {
378    // Extract just the file name without extension
379    let file_name = path.file_stem().unwrap().to_str().unwrap_or("");
380
381    // Just return the file name directly, without extracting frame numbers or scan IDs
382    vec![Column::new("file_name".into(), vec![file_name])]
383}