pyref_core/
io.rs

1use astrors_fork::io::hdus::{
2    image::{imagehdu::ImageHDU, ImageData},
3    primaryhdu::PrimaryHDU,
4};
5use ndarray::{ArrayBase, Axis, Dim, IxDynImpl, OwnedRepr};
6use polars::prelude::*;
7use std::ops::Mul;
8
9use crate::errors::FitsLoaderError;
10
/// Computes the momentum transfer `4π·sin(θ) / λ`.
///
/// * `lam` - wavelength; the result is expressed in the reciprocal of its unit.
/// * `theta` - angle in degrees (converted to radians before taking the sine).
///
/// Negative results are clamped to `0.0`; a NaN result (for example `0.0 / 0.0`)
/// is returned unchanged.
pub fn q(lam: f64, theta: f64) -> f64 {
    let momentum_transfer = 4.0 * std::f64::consts::PI * theta.to_radians().sin() / lam;
    if momentum_transfer < 0.0 {
        0.0
    } else {
        momentum_transfer
    }
}
18
19pub fn col_from_array(
20    name: PlSmallStr,
21    array: ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
22) -> Result<Column, PolarsError> {
23    let size0 = array.len_of(Axis(0));
24    let size1 = array.len_of(Axis(1));
25
26    let mut s = Column::new_empty(
27        name,
28        &DataType::List(Box::new(DataType::List(Box::new(DataType::Int64)))),
29    );
30
31    let mut chunked_builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(
32        PlSmallStr::EMPTY,
33        array.len_of(Axis(1)),
34        array.len_of(Axis(1)) * array.len_of(Axis(0)),
35        DataType::Int64,
36    );
37    for col in array.axis_iter(Axis(1)) {
38        match col.as_slice() {
39            Some(col) => chunked_builder.append_slice(col),
40            None => chunked_builder.append_slice(&col.to_owned().into_raw_vec()),
41        }
42    }
43    let new_series = chunked_builder
44        .finish()
45        .into_series()
46        .implode()
47        .unwrap()
48        .into_column();
49    let _ = s.extend(&new_series);
50    let s = s.cast(&DataType::Array(
51        Box::new(DataType::Array(Box::new(DataType::Int32), size0)),
52        size1,
53    ));
54    s
55}
56// ================== CCD Raw Data Processing ============
57pub fn add_calculated_domains(lzf: LazyFrame) -> DataFrame {
58    let h = physical_constants::PLANCK_CONSTANT_IN_EV_PER_HZ;
59    let c = physical_constants::SPEED_OF_LIGHT_IN_VACUUM * 1e10;
60
61    // Collect schema once to check for column existence
62    let schema = lzf.clone().collect_schema().unwrap_or_default();
63    let has_column = |name: &str| schema.iter().any(|(col_name, _)| col_name == name);
64
65    // Start with basic sorting that won't fail even if columns don't exist
66    let mut lz = lzf;
67
68    // Apply optional sorting only if the columns exist
69    if has_column("DATE") {
70        lz = lz.sort(["DATE"], Default::default());
71    }
72
73    if has_column("file_name") {
74        lz = lz.sort(["file_name"], Default::default());
75    }
76
77    // Conditionally apply column transformations only if the columns exist
78    if has_column("EXPOSURE") {
79        lz = lz.with_column(col("EXPOSURE").round(3).alias("EXPOSURE"));
80    }
81
82    if has_column("Higher Order Suppressor") {
83        lz = lz.with_column(
84            col("Higher Order Suppressor")
85                .round(2)
86                .alias("Higher Order Suppressor"),
87        );
88    }
89
90    if has_column("Horizontal Exit Slit Size") {
91        lz = lz.with_column(
92            col("Horizontal Exit Slit Size")
93                .round(1)
94                .alias("Horizontal Exit Slit Size"),
95        );
96    }
97
98    if has_column("Beamline Energy") {
99        lz = lz.with_column(col("Beamline Energy").round(1).alias("Beamline Energy"));
100
101        // Only calculate Lambda if Beamline Energy exists
102        lz = lz.with_column(
103            col("Beamline Energy")
104                .pow(-1)
105                .mul(lit(h * c))
106                .alias("Lambda"),
107        );
108    }
109
110    // Add Q column if required columns exist
111    lz = lz.with_column(
112        when(
113            col("Sample Theta")
114                .is_not_null()
115                .and(col("Lambda").is_not_null()),
116        )
117        .then(as_struct(vec![col("Sample Theta"), col("Lambda")]).map(
118            move |s| {
119                let struc = s.struct_()?;
120                let th_series = struc.field_by_name("Sample Theta")?;
121                let theta = th_series.f64()?;
122                let lam_series = struc.field_by_name("Lambda")?;
123                let lam = lam_series.f64()?;
124
125                let out: Float64Chunked = theta
126                    .into_iter()
127                    .zip(lam.iter())
128                    .map(|(theta, lam)| match (theta, lam) {
129                        (Some(theta), Some(lam)) => Some(q(lam, theta)),
130                        _ => None,
131                    })
132                    .collect();
133
134                Ok(Some(out.into_column()))
135            },
136            GetOutput::from_type(DataType::Float64),
137        ))
138        .otherwise(lit(NULL))
139        .alias("Q"),
140    );
141
142    // Collect the final DataFrame only once at the end
143    lz.collect().unwrap_or_else(|_| DataFrame::empty())
144}
145
146/// Reads a single FITS file and converts it to a Polars DataFrame.
147///
148/// # Arguments
149///
150/// * `file_path` - Path to the FITS file to read
151/// * `header_items` - List of header values to extract
152///
153/// # Returns
154///
155/// A `Result` containing either the DataFrame or a `FitsLoaderError`.
156pub fn process_image(img: &ImageHDU) -> Result<Vec<Column>, FitsLoaderError> {
157    let bzero = img
158        .header
159        .get_card("BZERO")
160        .ok_or_else(|| FitsLoaderError::MissingHeaderKey("BZERO".into()))?
161        .value
162        .as_int()
163        .ok_or_else(|| FitsLoaderError::FitsError("BZERO not an integer".into()))?;
164
165    match &img.data {
166        ImageData::I16(image) => {
167            let data = image.map(|&x| i64::from(x as i64 + bzero));
168            // Implement row-by-row background subtraction
169            let subtracted = subtract_background(&data);
170            // Locate the index tuple with the maximum value
171            // Find the coordinates of the maximum value in the 2D array
172            let max_coords = {
173                let mut max_coords = (0, 0);
174                let mut max_val = i64::MIN;
175
176                for (idx, &val) in subtracted.indexed_iter() {
177                    if val > max_val {
178                        max_val = val;
179                        max_coords = (idx[0], idx[1]);
180                    }
181                }
182
183                max_coords
184            };
185
186            // Check if the beam is too close to any edge (top, bottom, left, right)
187            let msg = if max_coords.0 < 20
188                || max_coords.0 > (subtracted.len_of(Axis(0)) - 20)
189                || max_coords.1 < 20
190                || max_coords.1 > (subtracted.len_of(Axis(1)) - 20)
191            {
192                "Simple Detection Error: Beam is too close to the edge"
193            } else {
194                ""
195            };
196            // Calculate a simple reflectivity result from the subtracted data
197            let (db_sum, scaled_bg) = { simple_reflectivity(&subtracted, max_coords) };
198
199            Ok(vec![
200                col_from_array("RAW".into(), data.clone()).unwrap(),
201                col_from_array("SUBTRACTED".into(), subtracted.clone()).unwrap(),
202                Column::new("Simple Spot X".into(), vec![max_coords.0 as u64]),
203                Column::new("Simple Spot Y".into(), vec![max_coords.1 as u64]),
204                Column::new(
205                    "Simple Reflectivity".into(),
206                    vec![(db_sum - scaled_bg) as f64],
207                ),
208                Series::new("status".into(), vec![msg.to_string()]).into_column(),
209            ])
210        }
211        _ => Err(FitsLoaderError::UnsupportedImageData),
212    }
213}
214
215fn simple_reflectivity(
216    subtracted: &ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
217    max_index: (usize, usize),
218) -> (i64, i64) {
219    // Convert max_index to 2D coordinates
220    let beam_y = max_index.0;
221    // Row
222    let beam_x = max_index.1;
223    // Column
224
225    // Define ROI size
226    let roi = 5;
227    // Region of interest size
228
229    // Define ROI boundaries
230    let roi_start_y = beam_y.saturating_sub(roi);
231    let roi_end_y = (beam_y + roi + 1).min(subtracted.len_of(Axis(0)));
232    let roi_start_x = beam_x.saturating_sub(roi);
233    let roi_end_x = (beam_x + roi + 1).min(subtracted.len_of(Axis(1)));
234
235    // Initialize sums and counts
236    let mut db_sum = 0i64;
237    let mut db_count = 0i64;
238    let mut bg_sum = 0i64;
239    let mut bg_count = 0i64;
240
241    // Iterate over all rows and columns
242    for y in 0..subtracted.len_of(Axis(0)) {
243        for x in 0..subtracted.len_of(Axis(1)) {
244            let value = subtracted[[y, x]];
245            if value == 0 {
246                continue;
247            }
248
249            if (roi_start_y <= y && y < roi_end_y) && (roi_start_x <= x && x < roi_end_x) {
250                db_sum += value;
251                db_count += 1;
252            } else {
253                bg_sum += value;
254                bg_count += 1;
255            }
256        }
257    }
258
259    // Handle edge cases
260    if bg_count == 0 || db_sum == 0 {
261        (0, 0)
262    } else {
263        // Scale background sum based on ratio of counts
264        let scaled_bg = (bg_sum * db_count) / bg_count;
265        (db_sum, scaled_bg)
266    }
267}
268
269fn subtract_background(
270    data: &ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>>,
271) -> ArrayBase<OwnedRepr<i64>, Dim<IxDynImpl>> {
272    // Get a view of the data with 5 pixels sliced from each side
273    let view = data.slice(ndarray::s![5..-5, 5..-5]);
274    let rows = view.len_of(Axis(0));
275    let cols = view.len_of(Axis(1));
276
277    // Extract the left and right columns (first and last 20 columns)
278    let left = view.slice(ndarray::s![.., ..20]);
279    let right = view.slice(ndarray::s![.., (cols - 20)..]);
280
281    // Calculate the sum of left and right regions
282    let left_sum: i64 = left.iter().copied().sum();
283    let right_sum: i64 = right.iter().copied().sum();
284
285    // Create background array to store row means
286    let mut background = ndarray::Array1::zeros(rows);
287
288    // Determine which side to use for background
289    if left_sum < right_sum {
290        // Use right side as background
291        for (i, row) in right.axis_iter(Axis(0)).enumerate() {
292            background[i] = row.iter().copied().sum::<i64>() / row.len() as i64;
293        }
294    } else {
295        // Use left side as background
296        for (i, row) in left.axis_iter(Axis(0)).enumerate() {
297            background[i] = row.iter().copied().sum::<i64>() / row.len() as i64;
298        }
299    }
300
301    // Create a new owned array from the view
302    let mut result = view.to_owned();
303
304    // Subtract background from each row
305    for (i, mut row) in result.axis_iter_mut(Axis(0)).enumerate() {
306        let bg = background[i];
307        for val in row.iter_mut() {
308            *val -= bg;
309        }
310    }
311    result.into_dyn()
312}
313
314pub fn process_metadata(
315    hdu: &PrimaryHDU,
316    keys: &Vec<String>,
317) -> Result<Vec<Column>, FitsLoaderError> {
318    if keys.is_empty() {
319        // If no specific keys are requested, return all header values
320        Ok(hdu
321            .header
322            .iter()
323            .filter(|card| !card.keyword.as_str().to_lowercase().contains("comment"))
324            .map(|card| {
325                let name = card.keyword.as_str();
326                let value = card.value.as_float().unwrap_or(0.0);
327                Column::new(name.into(), &[value])
328            })
329            .collect())
330    } else {
331        // Process each requested header key
332        let mut columns = Vec::new();
333
334        for key in keys {
335            // Special handling for Beamline Energy
336            if key == "Beamline Energy" {
337                // First try to get "Beamline Energy"
338                if let Some(card) = hdu.header.get_card(key) {
339                    if let Some(val) = card.value.as_float() {
340                        columns.push(Column::new(key.into(), &[val]));
341                        continue;
342                    }
343                }
344
345                // Then fall back to "Beamline Energy Goal" if "Beamline Energy" is not present
346                if let Some(card) = hdu.header.get_card("Beamline Energy Goal") {
347                    if let Some(val) = card.value.as_float() {
348                        columns.push(Column::new(key.into(), &[val]));
349                        continue;
350                    }
351                }
352
353                // If neither value is available, use a default
354                columns.push(Column::new(key.into(), &[0.0]));
355                continue;
356            }
357
358            // Special handling for Date header (it's a string value, not a float)
359            if key == "DATE" {
360                if let Some(card) = hdu.header.get_card(key) {
361                    let val = card.value.to_string();
362                    columns.push(Column::new(key.into(), &[val]));
363                    continue;
364                }
365                // If DATE is not present, use a default empty string
366                columns.push(Column::new(key.into(), &["".to_string()]));
367                continue;
368            }
369
370            // For other headers, don't fail if they're missing
371            let val = match hdu.header.get_card(key) {
372                Some(card) => card.value.as_float().unwrap_or(1.0),
373                None => 0.0, // Default value for missing headers
374            };
375
376            // Use the snake_case name from the enum variant
377            columns.push(Column::new(key.into(), &[val]));
378        }
379
380        Ok(columns)
381    }
382}
383
384pub fn process_file_name(path: std::path::PathBuf) -> Vec<Column> {
385    // Extract just the file name without extension
386    let file_name = path.file_stem().unwrap().to_str().unwrap_or("");
387
388    // Just return the file name directly, without extracting frame numbers or scan IDs
389    vec![Column::new("file_name".into(), vec![file_name])]
390}