use std::io::Read;
use std::path::Path;
use std::str::FromStr;
use csv::StringRecord;
use num_complex::Complex64;
use serde::{Deserialize, Serialize};
use crate::error::{Result, VecfitError};
use crate::fit::{Options, Report, SampleMatrix};
use crate::model::Model;
use crate::shape::{Layout, ResponseSample, Shape};
/// Largest absolute imaginary part that the real-kernel export still accepts
/// when it checks that poles and coefficients are real.
const REAL_COEFFICIENT_TOLERANCE: f64 = 1e-10;
/// One pole of a real-kernel JSON model together with its residues.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealKernelPoleJson {
/// Real pole value; stored under the JSON key `q`.
#[serde(rename = "q")]
pub pole: f64,
/// Residues of this pole; stored under the JSON key `r`.
/// The import in this file requires exactly one entry per function.
#[serde(rename = "r")]
pub residues: Vec<f64>,
}
/// Serde representation of a model whose poles and coefficients are all real.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RealKernelJsonModel {
/// Optional free-form description carried along with the model.
pub description: Option<String>,
/// Number of functions (channels); stored under the JSON key `nfuncs`.
#[serde(rename = "nfuncs")]
pub n_functions: usize,
/// Number of poles; stored under the JSON key `npoles`.
/// The import in this file rejects documents where it differs from `poles.len()`.
#[serde(rename = "npoles")]
pub n_poles: usize,
/// Optional shape; omitted from the output when `None`.
/// On import a missing shape is derived from the channel count.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub shape: Option<Shape>,
/// Optional layout; omitted from the output when `None`.
/// On import a missing layout becomes `Layout::RowMajor`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub layout: Option<Layout>,
/// Direct (constant) terms, one per function; stored under the JSON key `d`.
#[serde(rename = "d")]
pub direct_terms: Vec<f64>,
/// Proportional terms, one per function; stored under the JSON key `e`.
/// On import `None` is treated as all zeros.
#[serde(rename = "e")]
pub proportional_terms: Option<Vec<f64>>,
/// The poles with their residues.
pub poles: Vec<RealKernelPoleJson>,
}
/// Serde representation of a model with complex poles and coefficients.
///
/// Every complex number is stored as a `[re, im]` pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexModelJson {
/// Poles as `[re, im]` pairs.
pub poles: Vec<[f64; 2]>,
/// Residues: one row per pole, each row holding one `[re, im]` pair per channel.
pub residues: Vec<Vec<[f64; 2]>>,
/// Optional shape; omitted from the output when `None`.
/// On import a missing shape is derived from the channel count.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub shape: Option<Shape>,
/// Optional layout; omitted from the output when `None`.
/// On import a missing layout becomes `Layout::RowMajor`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub layout: Option<Layout>,
/// Direct (constant) terms; stored under the JSON key `d`.
/// On import the length of this list is taken as the channel count.
#[serde(rename = "d")]
pub direct_terms: Vec<[f64; 2]>,
/// Proportional terms, one per channel; stored under the JSON key `e`.
#[serde(rename = "e")]
pub proportional_terms: Vec<[f64; 2]>,
/// Value of the fit report's `abs_rmse`; stored under the JSON key `rmse`.
#[serde(rename = "rmse")]
pub abs_rmse: f64,
/// Value of the fit report's `iterations`; stored under the JSON key `iters`.
#[serde(rename = "iters")]
pub iterations: usize,
}
/// How each pair of data columns in a CSV file encodes one complex value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CsvFormat {
/// First column of the pair is the magnitude, second is the phase in degrees.
MagnitudePhase,
/// First column of the pair is the real part, second is the imaginary part.
RealImag,
}
/// Frequency-response samples together with the axis they were taken on and
/// the shape/layout used to interpret each row.
#[derive(Debug, Clone)]
pub struct ParsedSamples {
// Frequencies in hertz, one per sample row.
frequency_hz: Vec<f64>,
// Complex axis points, one per sample row (the CSV loader stores j*2*pi*f here).
axis: Vec<Complex64>,
// Sample values; `samples.samples` rows by `samples.channels` channels.
samples: SampleMatrix,
// Shape used when a row is turned into a `ResponseSample`.
shape: Shape,
// Layout used when a row is turned into a `ResponseSample`.
layout: Layout,
}
impl ParsedSamples {
    /// Builds a sample set from already-parsed parts.
    ///
    /// # Errors
    /// Returns [`VecfitError::Dimension`] when
    /// * `frequency_hz` and `axis` differ in length,
    /// * the sample matrix row count differs from the axis length, or
    /// * `shape` does not describe exactly `samples.channels` channels.
    pub fn new(
        frequency_hz: Vec<f64>,
        axis: Vec<Complex64>,
        samples: SampleMatrix,
        shape: Shape,
        layout: Layout,
    ) -> Result<Self> {
        if frequency_hz.len() != axis.len() {
            return Err(VecfitError::Dimension(format!(
                "frequency_hz length {} does not match axis length {}",
                frequency_hz.len(),
                axis.len()
            )));
        }
        if samples.samples != axis.len() {
            return Err(VecfitError::Dimension(format!(
                "sample matrix rows {} do not match axis length {}",
                samples.samples,
                axis.len()
            )));
        }
        // Enforce the same shape/channel invariant that `with_shape` checks, so
        // a mismatched shape is rejected here instead of surfacing later.
        if shape.channels() != samples.channels {
            return Err(VecfitError::Dimension(format!(
                "shape {:?} expects {} channels but data has {}",
                shape.dims(),
                shape.channels(),
                samples.channels
            )));
        }
        Ok(Self {
            frequency_hz,
            axis,
            samples,
            shape,
            layout,
        })
    }

    /// Replaces the shape.
    ///
    /// # Errors
    /// Returns [`VecfitError::Dimension`] when `shape` does not describe
    /// exactly as many channels as the data has.
    pub fn with_shape(mut self, shape: Shape) -> Result<Self> {
        if shape.channels() != self.samples.channels {
            return Err(VecfitError::Dimension(format!(
                "shape {:?} expects {} channels but data has {}",
                shape.dims(),
                shape.channels(),
                self.samples.channels
            )));
        }
        self.shape = shape;
        Ok(self)
    }

    /// Frequencies in hertz, one per sample row.
    pub fn frequency_hz(&self) -> &[f64] {
        &self.frequency_hz
    }

    /// Complex axis points, one per sample row.
    pub fn axis(&self) -> &[Complex64] {
        &self.axis
    }

    /// The underlying sample matrix.
    pub fn samples(&self) -> &SampleMatrix {
        &self.samples
    }

    /// Number of sample rows.
    pub fn len(&self) -> usize {
        self.samples.samples
    }

    /// `true` when there are no sample rows.
    pub fn is_empty(&self) -> bool {
        self.samples.samples == 0
    }

    /// Number of channels per sample row.
    pub fn channels(&self) -> usize {
        self.samples.channels
    }

    /// The current shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// The current layout.
    pub fn layout(&self) -> Layout {
        self.layout
    }

    /// Builds one [`ResponseSample`] per row from a copy of the row's values,
    /// a clone of the shape and the layout.
    ///
    /// # Errors
    /// Propagates the first error returned by `ResponseSample::new`.
    pub fn responses(&self) -> Result<Vec<ResponseSample<Complex64>>> {
        (0..self.samples.samples)
            .map(|row| {
                ResponseSample::new(
                    self.samples.row(row).to_vec(),
                    self.shape.clone(),
                    self.layout,
                )
            })
            .collect()
    }

    /// Converts every response with `ResponseSample::into_scalar`.
    pub fn scalars(&self) -> Result<Vec<Complex64>> {
        self.responses()?
            .into_iter()
            .map(ResponseSample::into_scalar)
            .collect()
    }

    /// Converts every response with `ResponseSample::into_vector`.
    pub fn vectors(&self) -> Result<Vec<Vec<Complex64>>> {
        self.responses()?
            .into_iter()
            .map(ResponseSample::into_vector)
            .collect()
    }

    /// Converts every response with `ResponseSample::into_matrix`.
    pub fn matrices(&self) -> Result<Vec<Vec<Vec<Complex64>>>> {
        self.responses()?
            .into_iter()
            .map(ResponseSample::into_matrix)
            .collect()
    }

    /// Shorthand for `with_shape(Shape::scalar())`.
    pub fn scalar(self) -> Result<Self> {
        self.with_shape(Shape::scalar())
    }

    /// Shorthand for `with_shape(Shape::vector(len)?)`.
    pub fn vector(self, len: usize) -> Result<Self> {
        self.with_shape(Shape::vector(len)?)
    }

    /// Shorthand for `with_shape(Shape::matrix(rows, cols)?)`.
    pub fn matrix(self, rows: usize, cols: usize) -> Result<Self> {
        self.with_shape(Shape::matrix(rows, cols)?)
    }

    /// Shorthand for `with_shape(Shape::tensor(dims)?)`.
    pub fn tensor<I>(self, dims: I) -> Result<Self>
    where
        I: IntoIterator<Item = usize>,
    {
        self.with_shape(Shape::tensor(dims)?)
    }

    /// Replaces the layout; no validation is performed.
    pub fn with_layout(mut self, layout: Layout) -> Self {
        self.layout = layout;
        self
    }

    /// Calls `model.channel_errors` with this data's axis and sample values.
    pub fn compare(&self, model: &Model) -> Result<crate::model::ChannelErrors> {
        model.channel_errors(self.axis.as_slice(), &self.samples.values)
    }

    /// Fits a model through `Model::fit_samples` using the axis, the sample
    /// values, a clone of the shape and `options`.
    ///
    /// The stored layout is not passed to the fit.
    pub fn fit(&self, options: Options) -> Result<Model> {
        Model::fit_samples(
            crate::axis::complex(&self.axis),
            &self.samples.values,
            self.shape.clone(),
            options,
        )
    }
}
/// Samples loaded from delimited text; a thin wrapper around [`ParsedSamples`].
#[derive(Debug, Clone)]
pub struct Csv {
// The parsed data; exposed read-only through `Deref` and by value through `into_parsed`.
inner: ParsedSamples,
}
impl std::ops::Deref for Csv {
type Target = ParsedSamples;
fn deref(&self) -> &ParsedSamples {
&self.inner
}
}
impl Csv {
pub fn into_parsed(self) -> ParsedSamples {
self.inner
}
pub fn from_csv(csv_text: &str) -> Result<Self> {
Self::from_reader_inner(csv_text.as_bytes(), b',')
}
pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
let csv_text = std::fs::read_to_string(path)?;
Self::from_csv(&csv_text)
}
pub fn from_reader<R: Read>(reader: R) -> Result<Self> {
Self::from_reader_inner(reader, b',')
}
pub fn from_tsv(text: &str) -> Result<Self> {
Self::from_delimited(text, b'\t')
}
pub fn from_ssv(text: &str) -> Result<Self> {
Self::from_delimited(text, b';')
}
pub fn from_delimited(text: &str, delimiter: u8) -> Result<Self> {
Self::from_reader_inner(text.as_bytes(), delimiter)
}
pub fn from_path_delimited<P: AsRef<Path>>(path: P, delimiter: u8) -> Result<Self> {
let text = std::fs::read_to_string(path)?;
Self::from_delimited(&text, delimiter)
}
pub fn fit_csv(csv_text: &str, options: Options) -> Result<Model> {
Self::from_csv(csv_text)?.fit(options)
}
pub fn fit_path<P: AsRef<Path>>(path: P, options: Options) -> Result<Model> {
Self::from_path(path)?.fit(options)
}
fn from_reader_inner<R: Read>(reader: R, delimiter: u8) -> Result<Self> {
let mut reader = csv::ReaderBuilder::new()
.has_headers(true)
.delimiter(delimiter)
.from_reader(reader);
let headers = reader
.headers()
.map_err(|err| VecfitError::Csv(err.to_string()))?
.clone();
let (format, channels) = detect_csv_format(&headers)?;
let mut frequency_hz = Vec::new();
let mut axis = Vec::new();
let mut sample_values = Vec::new();
for record in reader.records() {
let record = record?;
let frequency = parse_frequency_hz_column(&record)?;
frequency_hz.push(frequency);
axis.push(Complex64::new(0.0, 2.0 * std::f64::consts::PI * frequency));
for channel_idx in 0..channels {
let value = match format {
CsvFormat::MagnitudePhase => parse_mag_phase_column(&record, channel_idx)?,
CsvFormat::RealImag => parse_real_imag_column(&record, channel_idx)?,
};
sample_values.push(value);
}
}
let shape = if channels == 1 {
Shape::scalar()
} else {
Shape::vector(channels)?
};
Ok(Self {
inner: ParsedSamples {
frequency_hz,
samples: SampleMatrix::new(sample_values, axis.len(), channels)?,
axis,
shape,
layout: Layout::RowMajor,
},
})
}
pub fn with_shape(mut self, shape: Shape) -> Result<Self> {
self.inner = self.inner.with_shape(shape)?;
Ok(self)
}
pub fn scalar(self) -> Result<Self> {
self.with_shape(Shape::scalar())
}
pub fn vector(self, len: usize) -> Result<Self> {
self.with_shape(Shape::vector(len)?)
}
pub fn matrix(self, rows: usize, cols: usize) -> Result<Self> {
self.with_shape(Shape::matrix(rows, cols)?)
}
pub fn tensor<I>(self, dims: I) -> Result<Self>
where
I: IntoIterator<Item = usize>,
{
self.with_shape(Shape::tensor(dims)?)
}
pub fn with_layout(mut self, layout: Layout) -> Self {
self.inner.layout = layout;
self
}
}
impl Model {
pub fn from_json(json_text: &str) -> Result<Self> {
let json: ComplexModelJson = serde_json::from_str(json_text)?;
Self::try_from(json)
}
pub fn from_json_path<P: AsRef<Path>>(path: P) -> Result<Self> {
let json_text = std::fs::read_to_string(path)?;
Self::from_json(&json_text)
}
pub fn to_json(&self) -> Result<String> {
let json = ComplexModelJson::try_from(self)?;
serde_json::to_string_pretty(&json).map_err(Into::into)
}
pub fn from_real_json(json_text: &str) -> Result<Self> {
let json: RealKernelJsonModel = serde_json::from_str(json_text)?;
Self::try_from(json)
}
pub fn from_real_json_path<P: AsRef<Path>>(path: P) -> Result<Self> {
let json_text = std::fs::read_to_string(path)?;
Self::from_real_json(&json_text)
}
pub fn to_real_json(&self, description: Option<String>) -> Result<String> {
let json = RealKernelJsonModel::try_from((self, description))?;
serde_json::to_string_pretty(&json).map_err(Into::into)
}
}
impl FromStr for Csv {
type Err = VecfitError;
fn from_str(s: &str) -> Result<Self> {
Self::from_csv(s)
}
}
impl TryFrom<ComplexModelJson> for Model {
type Error = VecfitError;
fn try_from(json: ComplexModelJson) -> Result<Self> {
let pole_count = json.poles.len();
let channels = json.direct_terms.len();
if json.residues.len() != pole_count {
return Err(VecfitError::Serialization(
"complex JSON residue rows must match pole count".to_string(),
));
}
if json.proportional_terms.len() != channels {
return Err(VecfitError::Serialization(
"complex JSON proportional term count must match channel count".to_string(),
));
}
if json.residues.iter().any(|row| row.len() != channels) {
return Err(VecfitError::Serialization(
"complex JSON residue row length must match channel count".to_string(),
));
}
let residues = json
.residues
.iter()
.flat_map(|row| row.iter().copied().map(complex_from_pair))
.collect::<Vec<_>>();
let model = Model {
poles: json.poles.into_iter().map(complex_from_pair).collect(),
residues,
channels,
constant_terms: json
.direct_terms
.into_iter()
.map(complex_from_pair)
.collect(),
proportional_terms: json
.proportional_terms
.into_iter()
.map(complex_from_pair)
.collect(),
shape: resolve_json_shape(json.shape, channels)?,
layout: json.layout.unwrap_or(Layout::RowMajor),
report: Report {
abs_rmse: json.abs_rmse,
iterations: json.iterations,
..Report::default()
},
};
model.validate()?;
Ok(model)
}
}
impl TryFrom<&Model> for ComplexModelJson {
type Error = VecfitError;
fn try_from(model: &Model) -> Result<Self> {
model.validate()?;
let residues = (0..model.poles.len())
.map(|pole_idx| {
(0..model.channels)
.map(|channel_idx| {
let value = model.residues[pole_idx * model.channels + channel_idx];
[value.re, value.im]
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
Ok(Self {
poles: model
.poles
.iter()
.map(|value| [value.re, value.im])
.collect(),
residues,
shape: Some(model.shape.clone()),
layout: Some(model.layout),
direct_terms: model
.constant_terms
.iter()
.map(|value| [value.re, value.im])
.collect(),
proportional_terms: model
.proportional_terms
.iter()
.map(|value| [value.re, value.im])
.collect(),
abs_rmse: model.report.abs_rmse,
iterations: model.report.iterations,
})
}
}
impl TryFrom<RealKernelJsonModel> for Model {
type Error = VecfitError;
fn try_from(json: RealKernelJsonModel) -> Result<Self> {
if json.poles.len() != json.n_poles {
return Err(VecfitError::Serialization(
"real-kernel JSON pole list length does not match n_poles".to_string(),
));
}
let channels = json.n_functions;
if json.direct_terms.len() != channels {
return Err(VecfitError::Serialization(
"real-kernel JSON direct term count must match n_functions".to_string(),
));
}
if let Some(proportional_terms) = &json.proportional_terms {
if proportional_terms.len() != channels {
return Err(VecfitError::Serialization(
"real-kernel JSON proportional term count must match n_functions".to_string(),
));
}
}
let mut residues = Vec::with_capacity(json.n_poles * channels);
for pole in &json.poles {
if pole.residues.len() != channels {
return Err(VecfitError::Serialization(
"real-kernel JSON residue row length must match n_functions".to_string(),
));
}
residues.extend(
pole.residues
.iter()
.copied()
.map(|value| Complex64::new(value, 0.0)),
);
}
let model = Model {
poles: json
.poles
.iter()
.map(|pole| Complex64::new(pole.pole, 0.0))
.collect(),
residues,
channels,
constant_terms: json
.direct_terms
.iter()
.copied()
.map(|value| Complex64::new(value, 0.0))
.collect(),
proportional_terms: json
.proportional_terms
.unwrap_or_else(|| vec![0.0; channels])
.into_iter()
.map(|value| Complex64::new(value, 0.0))
.collect(),
shape: resolve_json_shape(json.shape, channels)?,
layout: json.layout.unwrap_or(Layout::RowMajor),
report: Report::default(),
};
model.validate()?;
Ok(model)
}
}
impl TryFrom<(&Model, Option<String>)> for RealKernelJsonModel {
type Error = VecfitError;
fn try_from((model, description): (&Model, Option<String>)) -> Result<Self> {
model.validate()?;
if model
.poles
.iter()
.any(|pole| pole.im.abs() > REAL_COEFFICIENT_TOLERANCE)
{
return Err(VecfitError::InvalidInput(
"real-kernel export requires real poles".to_string(),
));
}
if model
.residues
.iter()
.chain(model.constant_terms.iter())
.chain(model.proportional_terms.iter())
.any(|value| value.im.abs() > REAL_COEFFICIENT_TOLERANCE)
{
return Err(VecfitError::InvalidInput(
"real-kernel export requires real coefficients".to_string(),
));
}
let poles = model
.poles
.iter()
.enumerate()
.map(|(pole_idx, pole)| RealKernelPoleJson {
pole: pole.re,
residues: (0..model.channels)
.map(|channel_idx| model.residues[pole_idx * model.channels + channel_idx].re)
.collect(),
})
.collect::<Vec<_>>();
Ok(Self {
description,
n_functions: model.channels,
n_poles: model.poles.len(),
shape: Some(model.shape.clone()),
layout: Some(model.layout),
direct_terms: model.constant_terms.iter().map(|value| value.re).collect(),
proportional_terms: Some(
model
.proportional_terms
.iter()
.map(|value| value.re)
.collect(),
),
poles,
})
}
}
/// Derives the value format and channel count from the header row.
///
/// The header must hold a frequency column followed by one or more complete
/// column pairs. The format is `RealImag` when the second header, trimmed and
/// lowercased, starts with `re`; otherwise it is `MagnitudePhase`.
///
/// Returns [`VecfitError::Csv`] when there are fewer than three columns or the
/// data columns do not form complete pairs.
fn detect_csv_format(headers: &StringRecord) -> Result<(CsvFormat, usize)> {
    let column_count = headers.len();
    if column_count < 3 {
        return Err(VecfitError::Csv(
            "expected at least a frequency column and one column pair".to_string(),
        ));
    }
    let data_columns = column_count - 1;
    if data_columns % 2 != 0 {
        return Err(VecfitError::Csv(
            "data columns must appear in complete pairs".to_string(),
        ));
    }

    let first_data_header = headers.get(1).unwrap_or("").trim().to_lowercase();
    let format = match first_data_header.starts_with("re") {
        true => CsvFormat::RealImag,
        false => CsvFormat::MagnitudePhase,
    };
    Ok((format, data_columns / 2))
}
/// Parses the first field of `record`, trimmed, as an `f64` frequency.
///
/// Returns [`VecfitError::Csv`] when the field is missing and the converted
/// float-parse error when it is not a number.
fn parse_frequency_hz_column(record: &StringRecord) -> Result<f64> {
    let field = match record.get(0) {
        Some(text) => text,
        None => return Err(VecfitError::Csv("missing frequency column".to_string())),
    };
    let frequency = field.trim().parse::<f64>()?;
    Ok(frequency)
}
/// Reads channel `channel_idx` as a magnitude / phase-in-degrees pair and
/// converts it to a complex value.
///
/// The magnitude is in column `1 + 2 * channel_idx` and the phase in the
/// column after it. Returns [`VecfitError::Csv`] for a missing column and the
/// converted float-parse error for a non-numeric field.
fn parse_mag_phase_column(record: &StringRecord, channel_idx: usize) -> Result<Complex64> {
    // Fetches one trimmed numeric field, reporting `missing` when it is absent.
    let number_at = |column: usize, missing: &str| -> Result<f64> {
        let text = record
            .get(column)
            .ok_or_else(|| VecfitError::Csv(missing.to_string()))?;
        Ok(text.trim().parse::<f64>()?)
    };

    let first_column = 1 + channel_idx * 2;
    let magnitude = number_at(first_column, "missing magnitude column")?;
    let phase_deg = number_at(first_column + 1, "missing phase column")?;

    let (sin, cos) = phase_deg.to_radians().sin_cos();
    Ok(Complex64::new(magnitude * cos, magnitude * sin))
}
/// Reads channel `channel_idx` as a real / imaginary pair.
///
/// The real part is in column `1 + 2 * channel_idx` and the imaginary part in
/// the column after it. Returns [`VecfitError::Csv`] for a missing column and
/// the converted float-parse error for a non-numeric field.
fn parse_real_imag_column(record: &StringRecord, channel_idx: usize) -> Result<Complex64> {
    // Fetches one trimmed numeric field, reporting `missing` when it is absent.
    let number_at = |column: usize, missing: &str| -> Result<f64> {
        let text = record
            .get(column)
            .ok_or_else(|| VecfitError::Csv(missing.to_string()))?;
        Ok(text.trim().parse::<f64>()?)
    };

    let first_column = 1 + channel_idx * 2;
    let real_part = number_at(first_column, "missing real column")?;
    let imaginary_part = number_at(first_column + 1, "missing imaginary column")?;
    Ok(Complex64::new(real_part, imaginary_part))
}
/// Turns a `[re, im]` pair into a complex number.
fn complex_from_pair(value: [f64; 2]) -> Complex64 {
    let [re, im] = value;
    Complex64::new(re, im)
}
/// Picks the shape for a model loaded from JSON and checks it against the
/// channel count.
///
/// An explicit shape is used as given. Without one, a single channel yields a
/// scalar shape and any other count a vector of that length. Returns
/// [`VecfitError::Serialization`] when the chosen shape does not describe
/// exactly `channels` channels.
fn resolve_json_shape(shape: Option<Shape>, channels: usize) -> Result<Shape> {
    let resolved = if let Some(explicit) = shape {
        explicit
    } else if channels == 1 {
        Shape::scalar()
    } else {
        Shape::vector(channels)?
    };

    if resolved.channels() == channels {
        return Ok(resolved);
    }
    Err(VecfitError::Serialization(format!(
        "JSON shape {:?} expects {} channels but model has {}",
        resolved.dims(),
        resolved.channels(),
        channels
    )))
}