use ndarray::Array2;
use numpy::{IntoPyArray, PyArray2, PyArrayMethods};
use pyo3::prelude::*;
use crate::zones::{Zone, ZoneType, Zones};
use super::mueller::{Mueller, MuellerMatrix};
use super::results::Results;
#[pymethods]
impl Results {
    /// Returns `self + other` via the Rust `Add` impl for `Results` (both operands are cloned).
    fn __add__(&self, other: &Results) -> Results {
        self.clone() + other.clone()
    }

    /// Returns `self - other` via the Rust `Sub` impl for `Results` (both operands are cloned).
    fn __sub__(&self, other: &Results) -> Results {
        self.clone() - other.clone()
    }

    /// Returns `self ** exponent` via the `Pow<f32>` impl for `Results`.
    ///
    /// The optional third argument of Python's `pow(x, y, mod)` is accepted but ignored.
    fn __pow__(&self, exponent: f32, _modulo: Option<u32>) -> Results {
        use rand_distr::num_traits::Pow;
        self.clone().pow(exponent)
    }

    /// Returns `self * other`, where `other` is another `Results` or a value
    /// convertible to `f32`.
    ///
    /// Raises `TypeError` for any other operand type.
    fn __mul__(&self, other: &Bound<'_, PyAny>) -> PyResult<Results> {
        if let Ok(other_results) = other.extract::<Results>() {
            return Ok(self.clone() * other_results);
        }
        if let Ok(scalar) = other.extract::<f32>() {
            return Ok(self.clone() * scalar);
        }
        Err(pyo3::exceptions::PyTypeError::new_err(
            "unsupported operand type(s) for *: 'Results' and the provided type",
        ))
    }

    /// Reflected scalar multiplication, so `scalar * results` works like
    /// `results * scalar`.
    ///
    /// Only reached when the left operand is not a `Results`. If `other` cannot be
    /// converted to `f32`, PyO3 hands `NotImplemented` back to Python.
    fn __rmul__(&self, other: f32) -> Results {
        self.clone() * other
    }

    /// Returns `self / rhs` for a scalar divisor via the `Div<f32>` impl.
    ///
    /// No zero check is done here: dividing by `0.0` yields whatever that impl produces.
    fn __truediv__(&self, rhs: f32) -> Results {
        self.clone() / rhs
    }

    /// A copy of all zones held by these results.
    #[getter]
    pub fn get_zones(&self) -> Zones {
        self.zones.clone()
    }

    /// Returns a copy of the zone with the given label, or `None` if no zone matches.
    pub fn get_zone(&self, label: &str) -> Option<Zone> {
        self.zones.get(label).cloned()
    }

    /// Returns a copy of the first zone whose type equals `zone_type`, or `None`.
    pub fn get_zone_by_type(&self, zone_type: ZoneType) -> Option<Zone> {
        self.zones
            .iter()
            .find(|z| z.zone_type == zone_type)
            .cloned()
    }

    /// A copy of the full zone, or `None` if there is none.
    #[getter]
    pub fn get_full_zone(&self) -> Option<Zone> {
        self.zones.full_zone().cloned()
    }

    /// A copy of the first zone of type `ZoneType::Forward`, or `None`.
    #[getter]
    pub fn get_forward_zone(&self) -> Option<Zone> {
        self.get_zone_by_type(ZoneType::Forward)
    }

    /// A copy of the backward zone, or `None` if there is none.
    #[getter]
    pub fn get_backward_zone(&self) -> Option<Zone> {
        self.zones.backward_zone().cloned()
    }

    /// Bin centres as an `(n_bins, 2)` array.
    ///
    /// Each row holds `[theta.center, phi.center]` for one entry of `self.bins()`.
    #[getter]
    pub fn get_bins<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f32>> {
        // A fixed-size array per bin avoids allocating a small Vec for every row.
        let centers: Vec<f32> = self
            .bins()
            .iter()
            .flat_map(|bin| [bin.theta.center, bin.phi.center])
            .collect();
        Array2::from_shape_vec((centers.len() / 2, 2), centers)
            .expect("every bin contributes exactly two centre values")
            .into_pyarray(py)
    }

    /// 1D bin centres of the full zone as an `(n_bins, 1)` array.
    ///
    /// `None` when there is no full zone or it has no 1D field.
    #[getter]
    pub fn get_bins_1d<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<f32>>> {
        let field_1d = self.zones.full_zone()?.field_1d.as_ref()?;
        let centers: Vec<f32> = field_1d.iter().map(|result| result.bin.center).collect();
        Some(
            Array2::from_shape_vec((centers.len(), 1), centers)
                .expect("an (n, 1) shape always matches n values")
                .into_pyarray(py),
        )
    }

    /// Total Mueller matrices of the full zone's 2D field, one flattened 4x4 matrix per row.
    ///
    /// Shape is `(n_bins, 16)`, or `(0, 16)` when there is no full zone.
    #[getter]
    pub fn get_mueller<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f32>> {
        let field_2d = self
            .zones
            .full_zone()
            .map(|zone| &zone.field_2d[..])
            .unwrap_or(&[]);
        stack_muellers(py, field_2d, |r| &r.mueller_total)
    }

    /// Overwrites the total Mueller matrix of every 2D bin in the full zone.
    ///
    /// `array` must have 16 columns (a row-major 4x4 matrix per row) and at least
    /// one row per bin; extra rows are ignored. Raises `ValueError` on a shape
    /// mismatch. Does nothing when there is no full zone.
    #[setter]
    pub fn set_mueller(&mut self, array: &Bound<'_, PyArray2<f32>>) -> PyResult<()> {
        match self.zones.full_zone_mut() {
            Some(zone) => write_muellers(&mut zone.field_2d[..], array, |r| &mut r.mueller_total),
            None => Ok(()),
        }
    }

    /// Beam Mueller matrices of the full zone's 2D field, one flattened 4x4 matrix per row.
    ///
    /// Shape is `(n_bins, 16)`, or `(0, 16)` when there is no full zone.
    #[getter]
    pub fn get_mueller_beam<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f32>> {
        let field_2d = self
            .zones
            .full_zone()
            .map(|zone| &zone.field_2d[..])
            .unwrap_or(&[]);
        stack_muellers(py, field_2d, |r| &r.mueller_beam)
    }

    /// Overwrites the beam Mueller matrix of every 2D bin in the full zone.
    ///
    /// Same shape contract and errors as the `mueller` setter; a no-op without a full zone.
    #[setter]
    pub fn set_mueller_beam(&mut self, array: &Bound<'_, PyArray2<f32>>) -> PyResult<()> {
        match self.zones.full_zone_mut() {
            Some(zone) => write_muellers(&mut zone.field_2d[..], array, |r| &mut r.mueller_beam),
            None => Ok(()),
        }
    }

    /// External Mueller matrices of the full zone's 2D field, one flattened 4x4 matrix per row.
    ///
    /// Shape is `(n_bins, 16)`, or `(0, 16)` when there is no full zone.
    #[getter]
    pub fn get_mueller_ext<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray2<f32>> {
        let field_2d = self
            .zones
            .full_zone()
            .map(|zone| &zone.field_2d[..])
            .unwrap_or(&[]);
        stack_muellers(py, field_2d, |r| &r.mueller_ext)
    }

    /// Overwrites the external Mueller matrix of every 2D bin in the full zone.
    ///
    /// Same shape contract and errors as the `mueller` setter; a no-op without a full zone.
    #[setter]
    pub fn set_mueller_ext(&mut self, array: &Bound<'_, PyArray2<f32>>) -> PyResult<()> {
        match self.zones.full_zone_mut() {
            Some(zone) => write_muellers(&mut zone.field_2d[..], array, |r| &mut r.mueller_ext),
            None => Ok(()),
        }
    }

    /// Total Mueller matrices of the full zone's 1D field as an `(n_bins, 16)` array.
    ///
    /// `None` when there is no full zone or it has no 1D field.
    #[getter]
    pub fn get_mueller_1d<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<f32>>> {
        let field_1d = self.zones.full_zone()?.field_1d.as_ref()?;
        Some(stack_muellers(py, field_1d, |r| &r.mueller_total))
    }

    /// Overwrites the total Mueller matrix of every 1D bin in the full zone.
    ///
    /// Same shape contract and errors as the `mueller` setter. A no-op when there
    /// is no full zone or it has no 1D field.
    #[setter]
    pub fn set_mueller_1d(&mut self, array: &Bound<'_, PyArray2<f32>>) -> PyResult<()> {
        let Some(field_1d) = self
            .zones
            .full_zone_mut()
            .and_then(|zone| zone.field_1d.as_mut())
        else {
            return Ok(());
        };
        write_muellers(field_1d, array, |r| &mut r.mueller_total)
    }

    /// Beam Mueller matrices of the full zone's 1D field as an `(n_bins, 16)` array.
    ///
    /// `None` when there is no full zone or it has no 1D field.
    #[getter]
    pub fn get_mueller_1d_beam<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<f32>>> {
        let field_1d = self.zones.full_zone()?.field_1d.as_ref()?;
        Some(stack_muellers(py, field_1d, |r| &r.mueller_beam))
    }

    /// Overwrites the beam Mueller matrix of every 1D bin in the full zone.
    ///
    /// Same shape contract and errors as the `mueller` setter. A no-op when there
    /// is no full zone or it has no 1D field.
    #[setter]
    pub fn set_mueller_1d_beam(&mut self, array: &Bound<'_, PyArray2<f32>>) -> PyResult<()> {
        let Some(field_1d) = self
            .zones
            .full_zone_mut()
            .and_then(|zone| zone.field_1d.as_mut())
        else {
            return Ok(());
        };
        write_muellers(field_1d, array, |r| &mut r.mueller_beam)
    }

    /// External Mueller matrices of the full zone's 1D field as an `(n_bins, 16)` array.
    ///
    /// `None` when there is no full zone or it has no 1D field.
    #[getter]
    pub fn get_mueller_1d_ext<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyArray2<f32>>> {
        let field_1d = self.zones.full_zone()?.field_1d.as_ref()?;
        Some(stack_muellers(py, field_1d, |r| &r.mueller_ext))
    }

    /// Overwrites the external Mueller matrix of every 1D bin in the full zone.
    ///
    /// Same shape contract and errors as the `mueller` setter. A no-op when there
    /// is no full zone or it has no 1D field.
    #[setter]
    pub fn set_mueller_1d_ext(&mut self, array: &Bound<'_, PyArray2<f32>>) -> PyResult<()> {
        let Some(field_1d) = self
            .zones
            .full_zone_mut()
            .and_then(|zone| zone.field_1d.as_mut())
        else {
            return Ok(());
        };
        write_muellers(field_1d, array, |r| &mut r.mueller_ext)
    }

    /// Asymmetry parameter for the total component, if available.
    #[getter]
    pub fn get_asymmetry(&self) -> Option<f32> {
        use super::component::GOComponent;
        self.params.asymmetry(&GOComponent::Total)
    }

    /// Scattering cross section for the total component, if available.
    #[getter]
    pub fn get_scat_cross(&self) -> Option<f32> {
        use super::component::GOComponent;
        self.params.scatt_cross(&GOComponent::Total)
    }

    /// Extinction cross section for the total component, if available.
    #[getter]
    pub fn get_ext_cross(&self) -> Option<f32> {
        use super::component::GOComponent;
        self.params.ext_cross(&GOComponent::Total)
    }

    /// Albedo for the total component, if available.
    #[getter]
    pub fn get_albedo(&self) -> Option<f32> {
        use super::component::GOComponent;
        self.params.albedo(&GOComponent::Total)
    }

    /// Power bookkeeping as a new Python `dict`.
    ///
    /// Contains every stored power field, plus `"missing"`, which is computed by
    /// `self.powers.missing()` rather than stored.
    #[getter]
    pub fn get_powers(&self) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let dict = pyo3::types::PyDict::new(py);
            dict.set_item("input", self.powers.input)?;
            dict.set_item("output", self.powers.output)?;
            dict.set_item("absorbed", self.powers.absorbed)?;
            dict.set_item("trnc_ref", self.powers.trnc_ref)?;
            dict.set_item("trnc_rec", self.powers.trnc_rec)?;
            dict.set_item("trnc_clip", self.powers.trnc_clip)?;
            dict.set_item("trnc_energy", self.powers.trnc_energy)?;
            dict.set_item("clip_err", self.powers.clip_err)?;
            dict.set_item("trnc_area", self.powers.trnc_area)?;
            dict.set_item("trnc_cop", self.powers.trnc_cop)?;
            dict.set_item("ext_diff", self.powers.ext_diff)?;
            dict.set_item("missing", self.powers.missing())?;
            Ok(dict.into())
        })
    }

    /// Updates the power fields named by the keys of `dict`.
    ///
    /// Keys that are absent leave their field unchanged. Unrecognised keys,
    /// including the derived `"missing"`, are ignored. If any value fails to
    /// convert, the error is raised and none of the updates are applied.
    #[setter]
    pub fn set_powers(&mut self, dict: &Bound<'_, pyo3::types::PyDict>) -> PyResult<()> {
        // Stage every change on a copy so a failed `extract` cannot leave the
        // stored powers half-updated.
        let mut powers = self.powers.clone();
        if let Some(val) = dict.get_item("input")? {
            powers.input = val.extract()?;
        }
        if let Some(val) = dict.get_item("output")? {
            powers.output = val.extract()?;
        }
        if let Some(val) = dict.get_item("absorbed")? {
            powers.absorbed = val.extract()?;
        }
        if let Some(val) = dict.get_item("trnc_ref")? {
            powers.trnc_ref = val.extract()?;
        }
        if let Some(val) = dict.get_item("trnc_rec")? {
            powers.trnc_rec = val.extract()?;
        }
        if let Some(val) = dict.get_item("trnc_clip")? {
            powers.trnc_clip = val.extract()?;
        }
        if let Some(val) = dict.get_item("trnc_energy")? {
            powers.trnc_energy = val.extract()?;
        }
        if let Some(val) = dict.get_item("clip_err")? {
            powers.clip_err = val.extract()?;
        }
        if let Some(val) = dict.get_item("trnc_area")? {
            powers.trnc_area = val.extract()?;
        }
        if let Some(val) = dict.get_item("trnc_cop")? {
            powers.trnc_cop = val.extract()?;
        }
        if let Some(val) = dict.get_item("ext_diff")? {
            powers.ext_diff = val.extract()?;
        }
        self.powers = powers;
        Ok(())
    }
}

/// Number of values in one flattened 4x4 Mueller matrix.
const MUELLER_LEN: usize = 16;

/// Flattens the Mueller matrix chosen by `select` for each item into an
/// `(items.len(), 16)` NumPy array.
///
/// Kept outside the `#[pymethods]` impl so it is not exposed to Python. The element
/// order within a row is whatever `to_vec()` yields for `Mueller`; it is presumably
/// row-major, matching `from_row_slice` in `write_muellers` (TODO confirm).
fn stack_muellers<'py, T>(
    py: Python<'py>,
    items: &[T],
    select: impl Fn(&T) -> &Mueller,
) -> Bound<'py, PyArray2<f32>> {
    let mut flat: Vec<f32> = Vec::with_capacity(items.len() * MUELLER_LEN);
    for item in items {
        flat.extend(select(item).to_vec());
    }
    Array2::from_shape_vec((flat.len() / MUELLER_LEN, MUELLER_LEN), flat)
        .expect("every Mueller matrix flattens to exactly 16 values")
        .into_pyarray(py)
}

/// Writes one row of `array` into the Mueller matrix chosen by `select` for each item.
///
/// Row `i` of `array` is read as a row-major 4x4 matrix and stored in `items[i]`.
/// Any memory layout (C, Fortran, strided views) is accepted because rows are
/// copied element by element. Rows beyond `items.len()` are ignored.
///
/// # Errors
/// - `RuntimeError` if the NumPy array is already mutably borrowed from Rust.
/// - `ValueError` if `array` has fewer rows than `items` or does not have 16 columns.
fn write_muellers<T>(
    items: &mut [T],
    array: &Bound<'_, PyArray2<f32>>,
    mut select: impl FnMut(&mut T) -> &mut Mueller,
) -> PyResult<()> {
    if items.is_empty() {
        // Nothing to overwrite; matches the previous behaviour of accepting any array here.
        return Ok(());
    }
    let readonly = array
        .try_readonly()
        .map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))?;
    let view = readonly.as_array();
    let (rows, cols) = view.dim();
    if cols != MUELLER_LEN || rows < items.len() {
        return Err(pyo3::exceptions::PyValueError::new_err(format!(
            "expected an array with at least {} rows and {} columns, got shape ({}, {})",
            items.len(),
            MUELLER_LEN,
            rows,
            cols
        )));
    }
    // One reusable stack buffer: the row may not be contiguous, so copy it out
    // before handing a contiguous slice to `from_row_slice`.
    let mut buf = [0.0_f32; MUELLER_LEN];
    for (item, row) in items.iter_mut().zip(view.outer_iter()) {
        for (dst, src) in buf.iter_mut().zip(row.iter()) {
            *dst = *src;
        }
        *select(item) = Mueller::from_row_slice(&buf);
    }
    Ok(())
}