1mod error;
16mod fee;
17#[cfg(test)]
18mod tests;
19
20pub(crate) use error::BeamError;
21pub(crate) use fee::FEEBeam;
22
23use std::{path::Path, str::FromStr};
24
25use itertools::Itertools;
26use log::debug;
27use marlu::{AzEl, Jones};
28use ndarray::prelude::*;
29use strum::IntoEnumIterator;
30
31#[cfg(any(feature = "cuda", feature = "hip"))]
32use crate::gpu::{DevicePointer, GpuFloat};
33
34#[derive(
36 Debug,
37 Clone,
38 Copy,
39 PartialEq,
40 Eq,
41 strum_macros::Display,
42 strum_macros::EnumIter,
43 strum_macros::EnumString,
44)]
45#[allow(clippy::upper_case_acronyms)]
46pub enum BeamType {
47 #[strum(serialize = "fee")]
49 FEE,
50
51 #[strum(serialize = "none")]
53 None,
54}
55
56impl Default for BeamType {
57 fn default() -> Self {
58 Self::FEE
59 }
60}
61
62lazy_static::lazy_static! {
63 pub(crate) static ref BEAM_TYPES_COMMA_SEPARATED: String = BeamType::iter().map(|s| s.to_string().to_lowercase()).join(", ");
64}
65
/// A beam object that produces Jones matrices for directions on the sky.
///
/// The `Sync + Send` bounds allow a single beam object to be shared between
/// threads.
pub trait Beam: Sync + Send {
    /// Which [`BeamType`] this beam object is.
    fn get_beam_type(&self) -> BeamType;

    /// The number of tiles this beam object was set up for.
    fn get_num_tiles(&self) -> usize;

    /// The per-tile dipole delays, or `None` if this beam doesn't use delays.
    ///
    /// NOTE(review): presumably one row per tile and 16 columns, matching what
    /// `validate_delays` enforces for [`Delays::Full`]; confirm per implementor.
    fn get_dipole_delays(&self) -> Option<ArcArray<u32, Dim<[usize; 2]>>>;

    /// A single set of 16 "ideal" dipole delays, or `None` if this beam doesn't
    /// use delays. NOTE(review): presumably derived with
    /// `Delays::get_ideal_delays`; confirm per implementor.
    fn get_ideal_dipole_delays(&self) -> Option<[u32; 16]>;

    /// The per-tile dipole gains, or `None` if this beam doesn't use gains.
    fn get_dipole_gains(&self) -> Option<ArcArray<f64, Dim<[usize; 2]>>>;

    /// The file the beam model was read from, or `None` if there isn't one.
    fn get_beam_file(&self) -> Option<&Path>;

    /// Calculate the Jones matrix for a single direction at `freq_hz`.
    ///
    /// `tile_index` selects the tile whose delays/gains should be used. What
    /// `None` means is up to the implementor (NOTE(review): presumably "use the
    /// ideal delays"; confirm). `latitude_rad` is, per its name, a latitude
    /// in radians (presumably the array's).
    fn calc_jones(
        &self,
        azel: AzEl,
        freq_hz: f64,
        tile_index: Option<usize>,
        latitude_rad: f64,
    ) -> Result<Jones<f64>, BeamError>;

    /// Like [`Beam::calc_jones`], but for many directions at once. `NoBeam`
    /// returns one Jones matrix per element of `azels`.
    fn calc_jones_array(
        &self,
        azels: &[AzEl],
        freq_hz: f64,
        tile_index: Option<usize>,
        latitude_rad: f64,
    ) -> Result<Vec<Jones<f64>>, BeamError>;

    /// Like [`Beam::calc_jones_array`], but writes into the caller-provided
    /// `results` slice instead of allocating a new `Vec`.
    ///
    /// NOTE(review): presumably `results.len() == azels.len()` is expected;
    /// confirm per implementor.
    fn calc_jones_array_inner(
        &self,
        azels: &[AzEl],
        freq_hz: f64,
        tile_index: Option<usize>,
        latitude_rad: f64,
        results: &mut [Jones<f64>],
    ) -> Result<(), BeamError>;

    /// The frequency (Hz) the beam would actually use when asked for
    /// `desired_freq_hz`. `NoBeam` returns its input unchanged; beams defined
    /// at discrete frequencies presumably return the nearest one.
    fn find_closest_freq(&self, desired_freq_hz: f64) -> f64;

    /// Discard any cached beam coefficients. This takes `&self`, so any
    /// implementor that caches must use interior mutability.
    fn empty_coeff_cache(&self);

    /// Set up a GPU counterpart of this beam for the given frequencies (Hz).
    #[cfg(any(feature = "cuda", feature = "hip"))]
    fn prepare_gpu_beam(&self, freqs_hz: &[u32]) -> Result<Box<dyn BeamGpu>, BeamError>;
}
144
#[cfg(any(feature = "cuda", feature = "hip"))]
/// The GPU counterpart of [`Beam`], created by [`Beam::prepare_gpu_beam`].
pub trait BeamGpu {
    /// Calculate Jones matrices for the given directions and write them to the
    /// device buffer `d_jones`.
    ///
    /// `az_rad` and `za_rad` are, per their names, azimuths and zenith angles
    /// in radians.
    ///
    /// # Safety
    ///
    /// `d_jones` must be a valid device pointer with room for every Jones
    /// matrix the implementor writes. NOTE(review): `NoBeamGpu` writes
    /// `az_rad.len()` `Jones<GpuFloat>` values; other implementors presumably
    /// need space scaled by the unique tile and frequency counts. Confirm.
    unsafe fn calc_jones_pair(
        &self,
        az_rad: &[GpuFloat],
        za_rad: &[GpuFloat],
        latitude_rad: f64,
        d_jones: *mut std::ffi::c_void,
    ) -> Result<(), BeamError>;

    /// Which [`BeamType`] this GPU beam is.
    fn get_beam_type(&self) -> BeamType;

    /// Device pointer to the tile map. Presumably maps each tile index to one
    /// of the `get_num_unique_tiles` unique beams (`NoBeamGpu` maps every tile
    /// to 0).
    fn get_tile_map(&self) -> *const i32;

    /// Device pointer to the frequency map. Presumably maps each frequency
    /// index to one of the `get_num_unique_freqs` unique frequencies
    /// (`NoBeamGpu` maps every frequency to 0).
    fn get_freq_map(&self) -> *const i32;

    /// The number of distinct tile beams referenced by the tile map.
    fn get_num_unique_tiles(&self) -> i32;

    /// The number of distinct frequencies referenced by the frequency map.
    fn get_num_unique_freqs(&self) -> i32;
}
182
/// Dipole delays: either one set per tile, or a single set shared by all tiles.
#[derive(Debug, Clone)]
pub enum Delays {
    /// One row of delays per tile. `validate_delays` requires 16 columns, one
    /// row per tile, and every value <= 32.
    Full(Array2<u32>),

    /// One set of delays used for every tile (`partial_to_full` copies it into
    /// each tile's row). `Delays::parse` and `validate_delays` require exactly
    /// 16 values, each <= 32.
    Partial(Vec<u32>),
}
193
194impl Delays {
195 pub(crate) fn get_ideal_delays(&self) -> [u32; 16] {
200 let mut ideal_delays = [32; 16];
201 match self {
202 Delays::Partial(v) => {
203 v.iter().enumerate().for_each(|(i, &elem)| {
206 ideal_delays[i % 16] = elem;
207 });
208 }
209 Delays::Full(a) => {
210 for row in a.outer_iter() {
212 row.iter().enumerate().for_each(|(i, &col)| {
213 let ideal_delay = ideal_delays.get_mut(i % 16).unwrap();
214
215 *ideal_delay = (*ideal_delay).min(col);
221 });
222 if ideal_delays.iter().all(|&e| e < 32) {
223 break;
224 }
225 }
226 }
227 }
228 ideal_delays
229 }
230
231 pub(crate) fn set_to_ideal_delays(&mut self) {
235 let ideal_delays = self.get_ideal_delays();
236 match self {
237 Delays::Full(a) => {
239 let ideal_delays = ArrayView1::from(&ideal_delays);
240 a.outer_iter_mut().for_each(|mut r| r.assign(&ideal_delays));
241 }
242
243 Delays::Partial { .. } => (),
245 }
246 }
247
248 pub(crate) fn parse(delays: Vec<u32>) -> Result<Delays, BeamError> {
250 if delays.len() != 16 || delays.iter().any(|&v| v > 32) {
251 return Err(BeamError::BadDelays);
252 }
253 Ok(Delays::Partial(delays))
254 }
255}
256
/// A "beam" whose Jones matrices are always the identity.
pub(crate) struct NoBeam {
    /// The number of tiles, reported by `get_num_tiles`.
    pub(crate) num_tiles: usize,
}
262
263impl Beam for NoBeam {
264 fn get_beam_type(&self) -> BeamType {
265 BeamType::None
266 }
267
268 fn get_num_tiles(&self) -> usize {
269 self.num_tiles
270 }
271
272 fn get_ideal_dipole_delays(&self) -> Option<[u32; 16]> {
273 None
274 }
275
276 fn get_dipole_delays(&self) -> Option<ArcArray<u32, Dim<[usize; 2]>>> {
277 None
278 }
279
280 fn get_dipole_gains(&self) -> Option<ArcArray<f64, Dim<[usize; 2]>>> {
281 None
282 }
283
284 fn get_beam_file(&self) -> Option<&Path> {
285 None
286 }
287
288 fn calc_jones(
289 &self,
290 _azel: AzEl,
291 _freq_hz: f64,
292 _tile_index: Option<usize>,
293 _latitude_rad: f64,
294 ) -> Result<Jones<f64>, BeamError> {
295 Ok(Jones::identity())
296 }
297
298 fn calc_jones_array(
299 &self,
300 azels: &[AzEl],
301 _freq_hz: f64,
302 _tile_index: Option<usize>,
303 _latitude_rad: f64,
304 ) -> Result<Vec<Jones<f64>>, BeamError> {
305 Ok(vec![Jones::identity(); azels.len()])
306 }
307
308 fn calc_jones_array_inner(
309 &self,
310 _azels: &[AzEl],
311 _freq_hz: f64,
312 _tile_index: Option<usize>,
313 _latitude_rad: f64,
314 results: &mut [Jones<f64>],
315 ) -> Result<(), BeamError> {
316 results.fill(Jones::identity());
317 Ok(())
318 }
319
320 fn find_closest_freq(&self, desired_freq_hz: f64) -> f64 {
321 desired_freq_hz
322 }
323
324 fn empty_coeff_cache(&self) {}
325
326 #[cfg(any(feature = "cuda", feature = "hip"))]
327 fn prepare_gpu_beam(&self, freqs_hz: &[u32]) -> Result<Box<dyn BeamGpu>, BeamError> {
328 let obj = NoBeamGpu {
329 tile_map: DevicePointer::copy_to_device(&vec![0; self.num_tiles])?,
330 freq_map: DevicePointer::copy_to_device(&vec![0; freqs_hz.len()])?,
331 };
332 Ok(Box::new(obj))
333 }
334}
335
#[cfg(any(feature = "cuda", feature = "hip"))]
/// GPU state for `NoBeam`: device copies of its tile and frequency maps.
pub(crate) struct NoBeamGpu {
    /// Device tile map (all zeros when built by `NoBeam::prepare_gpu_beam`).
    tile_map: DevicePointer<i32>,
    /// Device frequency map (all zeros when built by `NoBeam::prepare_gpu_beam`).
    freq_map: DevicePointer<i32>,
}
343
#[cfg(any(feature = "cuda", feature = "hip"))]
impl BeamGpu for NoBeamGpu {
    /// Write `az_rad.len()` identity Jones matrices to `d_jones`.
    ///
    /// The zenith angles and latitude are ignored; only the number of
    /// directions matters.
    ///
    /// # Safety
    ///
    /// `d_jones` must point to device memory with room for at least
    /// `az_rad.len()` `Jones<GpuFloat>` values.
    unsafe fn calc_jones_pair(
        &self,
        az_rad: &[GpuFloat],
        _za_rad: &[GpuFloat],
        _latitude_rad: f64,
        d_jones: *mut std::ffi::c_void,
    ) -> Result<(), BeamError> {
        // Alias the CUDA/HIP runtime names so the body below is backend-agnostic.
        #[cfg(feature = "cuda")]
        use cuda_runtime_sys::{
            cudaMemcpy as gpuMemcpy,
            cudaMemcpyKind::cudaMemcpyHostToDevice as gpuMemcpyHostToDevice,
        };
        #[cfg(feature = "hip")]
        use hip_sys::hiprt::{
            hipMemcpy as gpuMemcpy, hipMemcpyKind::hipMemcpyHostToDevice as gpuMemcpyHostToDevice,
        };

        // Build the identities on the host, then copy them to the device in one call.
        let identities: Vec<Jones<GpuFloat>> = vec![Jones::identity(); az_rad.len()];
        // SAFETY: the destination is valid per this fn's `# Safety` contract.
        // The source pointer and byte count both come from `identities`, which
        // lives until this function returns.
        // NOTE(review): the status returned by `gpuMemcpy` is discarded, so a
        // failed copy still returns `Ok(())`. Consider checking it.
        gpuMemcpy(
            d_jones,
            identities.as_ptr().cast(),
            identities.len() * std::mem::size_of::<Jones<GpuFloat>>(),
            gpuMemcpyHostToDevice,
        );
        Ok(())
    }

    fn get_beam_type(&self) -> BeamType {
        BeamType::None
    }

    fn get_tile_map(&self) -> *const i32 {
        self.tile_map.get()
    }

    fn get_freq_map(&self) -> *const i32 {
        self.freq_map.get()
    }

    // A single identity "beam" covers every tile.
    fn get_num_unique_tiles(&self) -> i32 {
        1
    }

    // A single identity "beam" covers every frequency.
    fn get_num_unique_freqs(&self) -> i32 {
        1
    }
}
393
394pub fn create_beam_object(
395 beam_type: Option<&str>,
396 num_tiles: usize,
397 dipole_delays: Delays,
398) -> Result<Box<dyn Beam>, BeamError> {
399 let beam_type = match (
400 beam_type,
401 beam_type.and_then(|b| BeamType::from_str(b).ok()),
402 ) {
403 (None, _) => BeamType::default(),
404 (Some(_), Some(b)) => b,
405 (Some(s), None) => return Err(BeamError::Unrecognised(s.to_string())),
406 };
407
408 match beam_type {
409 BeamType::None => {
410 debug!("Setting up a \"NoBeam\" object");
411 Ok(Box::new(NoBeam { num_tiles }))
412 }
413
414 BeamType::FEE => {
415 debug!("Setting up a FEE beam object");
416
417 validate_delays(&dipole_delays, num_tiles)?;
419
420 Ok(Box::new(FEEBeam::new_from_env(
423 num_tiles,
424 dipole_delays,
425 None,
426 )?))
427 }
428 }
429}
430
431fn partial_to_full(delays: Vec<u32>, num_tiles: usize) -> Array2<u32> {
434 let mut out = Array2::zeros((num_tiles, 16));
435 let d = Array1::from(delays);
436 out.outer_iter_mut().for_each(|mut tile_delays| {
437 tile_delays.assign(&d);
438 });
439 out
440}
441
442fn validate_delays(delays: &Delays, num_tiles: usize) -> Result<(), BeamError> {
443 match delays {
444 Delays::Partial(v) => {
445 if v.len() != 16 || v.iter().any(|&v| v > 32) {
446 return Err(BeamError::BadDelays);
447 }
448 }
449
450 Delays::Full(a) => {
451 if a.len_of(Axis(1)) != 16 || a.iter().any(|&v| v > 32) {
452 return Err(BeamError::BadDelays);
453 }
454 if a.len_of(Axis(0)) != num_tiles {
455 return Err(BeamError::InconsistentDelays {
456 num_rows: a.len_of(Axis(0)),
457 num_tiles,
458 });
459 }
460 }
461 }
462
463 Ok(())
464}