use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, Array2};
/// Boxed, thread-shareable callback mapping a state slice to a matrix (e.g. `x ↦ J(x)`).
pub type MatrixFn = Box<dyn Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Sync + Send>;
/// Boxed, thread-shareable callback mapping a state slice to a scalar (e.g. `x ↦ H(x)`).
pub type ScalarFn = Box<dyn Fn(&[f64]) -> IntegrateResult<f64> + Sync + Send>;
/// Boxed, thread-shareable callback mapping a state slice to a vector (e.g. `x ↦ ∇H(x)`).
pub type VectorFn = Box<dyn Fn(&[f64]) -> IntegrateResult<Array1<f64>> + Sync + Send>;
/// Numerical tolerances used by `PortHamiltonianSystem`.
#[derive(Debug, Clone)]
pub struct PortHamiltonianConfig {
    /// Absolute tolerance on the entries of `J + Jᵀ` used by `validate_skew_symmetry`.
    pub skew_sym_tol: f64,
    /// Lower threshold used by `validate_psd`; a slightly negative value tolerates round-off.
    pub psd_tol: f64,
    /// Step size of the central-difference approximation of `∇H` (used when no analytic
    /// gradient is supplied).
    pub grad_epsilon: f64,
}

impl Default for PortHamiltonianConfig {
    fn default() -> Self {
        // The symmetry tolerance and the PSD threshold share one magnitude; the PSD threshold
        // is its negation so tiny negative round-off still counts as semi-definite.
        let tight: f64 = 1e-10;
        Self {
            grad_epsilon: 1e-7,
            psd_tol: -tight,
            skew_sym_tol: tight,
        }
    }
}
/// A port-Hamiltonian system whose structure matrices, Hamiltonian and input matrix are
/// supplied as state-dependent closures.
pub struct PortHamiltonianSystem {
    /// Number of state variables as declared by the caller.
    pub n_states: usize,
    /// Number of ports as declared by the caller.
    pub n_ports: usize,
    /// Hamiltonian `H(x)`.
    hamiltonian: ScalarFn,
    /// Optional analytic gradient `∇H(x)`; when `None` a finite-difference approximation is used.
    grad_hamiltonian: Option<VectorFn>,
    /// Interconnection matrix `J(x)`.
    j_matrix: MatrixFn,
    /// Dissipation matrix `R(x)`.
    r_matrix: MatrixFn,
    /// Input matrix `B(x)` (possibly constant).
    b_matrix: MatrixFn,
    /// Numerical tolerances.
    config: PortHamiltonianConfig,
}
impl PortHamiltonianSystem {
    /// Creates a port-Hamiltonian system with a constant input matrix `B`.
    ///
    /// The dynamics evaluated by [`Self::rhs`] are `dx/dt = (J(x) - R(x)) ∇H(x) + B u`, and the
    /// port output from [`Self::output`] is `y = Bᵀ ∇H(x)`. No analytic gradient is installed,
    /// so [`Self::grad_hamiltonian`] uses central finite differences until
    /// [`Self::with_grad_hamiltonian`] is called. The default [`PortHamiltonianConfig`] is used.
    pub fn new(
        n_states: usize,
        n_ports: usize,
        j_fn: impl Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Send + Sync + 'static,
        r_fn: impl Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Send + Sync + 'static,
        hamiltonian: impl Fn(&[f64]) -> IntegrateResult<f64> + Send + Sync + 'static,
        b_matrix: Array2<f64>,
    ) -> Self {
        Self {
            n_states,
            n_ports,
            j_matrix: Box::new(j_fn),
            r_matrix: Box::new(r_fn),
            hamiltonian: Box::new(hamiltonian),
            grad_hamiltonian: None,
            // `b_matrix` is already owned: move it into the closure rather than cloning it up
            // front. Each call hands out its own copy.
            b_matrix: Box::new(move |_x| Ok(b_matrix.clone())),
            config: PortHamiltonianConfig::default(),
        }
    }

    /// Installs an analytic gradient `∇H(x)`, replacing the finite-difference fallback.
    pub fn with_grad_hamiltonian(
        mut self,
        grad_fn: impl Fn(&[f64]) -> IntegrateResult<Array1<f64>> + Send + Sync + 'static,
    ) -> Self {
        self.grad_hamiltonian = Some(Box::new(grad_fn));
        self
    }

    /// Replaces the input matrix with a state-dependent `B(x)`.
    pub fn with_state_dependent_b(
        mut self,
        b_fn: impl Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Send + Sync + 'static,
    ) -> Self {
        self.b_matrix = Box::new(b_fn);
        self
    }

    /// Replaces the numerical tolerances.
    pub fn with_config(mut self, config: PortHamiltonianConfig) -> Self {
        self.config = config;
        self
    }

    /// Evaluates the Hamiltonian `H(x)`.
    pub fn hamiltonian(&self, x: &[f64]) -> IntegrateResult<f64> {
        (self.hamiltonian)(x)
    }

    /// Returns `∇H(x)`.
    ///
    /// If an analytic gradient was installed, it is returned unchanged. Otherwise each component
    /// is approximated by a central difference with step `config.grad_epsilon`, which costs
    /// `2 * x.len()` evaluations of `H`.
    ///
    /// # Errors
    /// Propagates any error returned by the gradient or Hamiltonian callback.
    pub fn grad_hamiltonian(&self, x: &[f64]) -> IntegrateResult<Array1<f64>> {
        if let Some(ref grad_fn) = self.grad_hamiltonian {
            return grad_fn(x);
        }
        let n = x.len();
        let eps = self.config.grad_epsilon;
        let mut grad = Array1::zeros(n);
        let mut x_plus = x.to_vec();
        let mut x_minus = x.to_vec();
        for i in 0..n {
            x_plus[i] = x[i] + eps;
            x_minus[i] = x[i] - eps;
            let h_plus = (self.hamiltonian)(&x_plus)?;
            let h_minus = (self.hamiltonian)(&x_minus)?;
            grad[i] = (h_plus - h_minus) / (2.0 * eps);
            // Restore the coordinate so the perturbation buffers can be reused.
            x_plus[i] = x[i];
            x_minus[i] = x[i];
        }
        Ok(grad)
    }

    /// Evaluates the interconnection matrix `J(x)`.
    pub fn j_matrix(&self, x: &[f64]) -> IntegrateResult<Array2<f64>> {
        (self.j_matrix)(x)
    }

    /// Evaluates the dissipation matrix `R(x)`.
    pub fn r_matrix(&self, x: &[f64]) -> IntegrateResult<Array2<f64>> {
        (self.r_matrix)(x)
    }

    /// Evaluates the input matrix `B(x)`.
    pub fn b_matrix(&self, x: &[f64]) -> IntegrateResult<Array2<f64>> {
        (self.b_matrix)(x)
    }

    /// Right-hand side `dx/dt = (J(x) - R(x)) ∇H(x) + B(x) u`.
    ///
    /// # Errors
    /// Propagates callback errors. Returns `IntegrateError::ValueError` if any of these hold
    /// (ndarray would otherwise panic on them):
    /// - `J` and `R` differ in shape;
    /// - their column count differs from `∇H`'s length;
    /// - `B`'s column count differs from `u.len()`;
    /// - `B`'s row count differs from `J`'s.
    pub fn rhs(&self, x: &[f64], u: &[f64]) -> IntegrateResult<Array1<f64>> {
        let j = self.j_matrix(x)?;
        let r = self.r_matrix(x)?;
        let grad_h = self.grad_hamiltonian(x)?;
        let b = self.b_matrix(x)?;

        if j.dim() != r.dim() {
            return Err(IntegrateError::ValueError(format!(
                "rhs: J has shape {:?} but R has shape {:?}",
                j.dim(),
                r.dim()
            )));
        }
        Self::ensure_dim("rhs: columns of J - R vs. length of grad H", j.ncols(), grad_h.len())?;
        Self::ensure_dim("rhs: columns of B vs. length of u", b.ncols(), u.len())?;
        Self::ensure_dim("rhs: rows of B vs. rows of J - R", j.nrows(), b.nrows())?;

        let drift = (&j - &r).dot(&grad_h);
        let forcing = b.dot(&Array1::from_vec(u.to_vec()));
        Ok(drift + forcing)
    }

    /// Port output `y = B(x)ᵀ ∇H(x)`.
    ///
    /// # Errors
    /// Propagates callback errors. Returns `IntegrateError::ValueError` if `B`'s row count
    /// differs from `∇H`'s length.
    pub fn output(&self, x: &[f64]) -> IntegrateResult<Array1<f64>> {
        let grad_h = self.grad_hamiltonian(x)?;
        self.output_with_grad(x, &grad_h)
    }

    /// Computes `B(x)ᵀ ∇H(x)` from an already evaluated gradient (avoids recomputing `∇H`).
    fn output_with_grad(&self, x: &[f64], grad_h: &Array1<f64>) -> IntegrateResult<Array1<f64>> {
        let b = self.b_matrix(x)?;
        Self::ensure_dim("output: rows of B vs. length of grad H", b.nrows(), grad_h.len())?;
        Ok(b.t().dot(grad_h))
    }

    /// Returns `(dissipation, supply_rate)` = `(∇Hᵀ R ∇H, yᵀ u)` at state `x` and input `u`.
    ///
    /// When `J` is skew-symmetric, `∇Hᵀ (J - R) ∇H = -∇Hᵀ R ∇H`. So along [`Self::rhs`]
    /// the energy rate is `dH/dt = supply_rate - dissipation`.
    ///
    /// # Errors
    /// Propagates callback errors. Returns `IntegrateError::ValueError` if `R` is not square
    /// with side `∇H.len()`, or if the output and input lengths differ.
    pub fn power_balance(&self, x: &[f64], u: &[f64]) -> IntegrateResult<(f64, f64)> {
        let r = self.r_matrix(x)?;
        let grad_h = self.grad_hamiltonian(x)?;
        Self::ensure_dim("power_balance: rows of R vs. length of grad H", r.nrows(), grad_h.len())?;
        Self::ensure_dim(
            "power_balance: columns of R vs. length of grad H",
            r.ncols(),
            grad_h.len(),
        )?;
        // Reuse the gradient instead of letting `output` recompute it (which costs 2n
        // Hamiltonian evaluations on the finite-difference path).
        let y = self.output_with_grad(x, &grad_h)?;
        Self::ensure_dim("power_balance: length of y vs. length of u", y.len(), u.len())?;

        let dissipation = grad_h.dot(&r.dot(&grad_h));
        let supply_rate = y.dot(&Array1::from_vec(u.to_vec()));
        Ok((dissipation, supply_rate))
    }

    /// Checks that `J(x)` is square and `|J + Jᵀ|` is entrywise at most `config.skew_sym_tol`.
    ///
    /// Non-square matrices and matrices containing NaN are reported as not skew-symmetric.
    ///
    /// # Errors
    /// Propagates any error returned by the `J` callback.
    pub fn validate_skew_symmetry(&self, x: &[f64]) -> IntegrateResult<bool> {
        let j = self.j_matrix(x)?;
        let n = j.nrows();
        if j.ncols() != n {
            return Ok(false);
        }
        let tol = self.config.skew_sym_tol;
        // `<= tol` makes NaN fail. Folding with `f64::max` would silently skip NaN entries.
        Ok((0..n).all(|i| (i..n).all(|k| (j[[i, k]] + j[[k, i]]).abs() <= tol)))
    }

    /// Checks that `R(x)` is symmetric positive semi-definite.
    ///
    /// Rules:
    /// - `R` must be square and finite.
    /// - It must be symmetric to within `config.skew_sym_tol`.
    /// - The smallest eigenvalue of its symmetric part must be at least `config.psd_tol`.
    ///
    /// Gershgorin diagonal dominance is tried first as a cheap sufficient test. Otherwise the
    /// eigenvalues are computed, so PSD matrices that are not diagonally dominant are accepted
    /// too (e.g. `[[1, 2], [2, 4]]`).
    ///
    /// # Errors
    /// Propagates any error returned by the `R` callback.
    pub fn validate_psd(&self, x: &[f64]) -> IntegrateResult<bool> {
        let r = self.r_matrix(x)?;
        let n = r.nrows();
        // Non-square input would index out of bounds below. NaN would make every comparison
        // false and therefore "pass".
        if r.ncols() != n || r.iter().any(|v| !v.is_finite()) {
            return Ok(false);
        }

        let sym_tol = self.config.skew_sym_tol;
        let symmetric =
            (0..n).all(|i| ((i + 1)..n).all(|k| (r[[i, k]] - r[[k, i]]).abs() <= sym_tol));
        if !symmetric {
            return Ok(false);
        }

        let psd_tol = self.config.psd_tol;
        // Fast path: every Gershgorin disc lies at or above `psd_tol`, which bounds all
        // eigenvalues from below.
        let diagonally_dominant = (0..n).all(|i| {
            let off_sum: f64 = (0..n).filter(|&k| k != i).map(|k| r[[i, k]].abs()).sum();
            r[[i, i]] - off_sum >= psd_tol
        });
        if diagonally_dominant {
            return Ok(true);
        }
        Ok(Self::min_symmetric_eigenvalue(&r) >= psd_tol)
    }

    /// Smallest eigenvalue of the symmetric part `(M + Mᵀ) / 2` of a square matrix `M`.
    ///
    /// Uses the cyclic Jacobi eigenvalue method (see Golub & Van Loan, *Matrix Computations*,
    /// section on Jacobi methods), which is robust for small dense symmetric matrices.
    /// Returns `+∞` for a 0×0 matrix, so the empty matrix counts as PSD.
    fn min_symmetric_eigenvalue(m: &Array2<f64>) -> f64 {
        const MAX_SWEEPS: usize = 64;
        let n = m.nrows();
        let mut a = Array2::from_shape_fn((n, n), |(i, k)| 0.5 * (m[[i, k]] + m[[k, i]]));
        for _ in 0..MAX_SWEEPS {
            let mut off = 0.0_f64;
            let mut total = 0.0_f64;
            for i in 0..n {
                for k in 0..n {
                    let sq = a[[i, k]] * a[[i, k]];
                    total += sq;
                    if i != k {
                        off += sq;
                    }
                }
            }
            // Converged once the off-diagonal mass is negligible relative to the whole matrix.
            if off <= f64::EPSILON * f64::EPSILON * total {
                break;
            }
            for p in 0..n {
                for q in (p + 1)..n {
                    let apq = a[[p, q]];
                    if apq == 0.0 {
                        continue;
                    }
                    // Rotation angle chosen so the (p, q) entry is annihilated. `hypot` avoids
                    // overflow when `tau` is huge.
                    let tau = (a[[q, q]] - a[[p, p]]) / (2.0 * apq);
                    let t = tau.signum() / (tau.abs() + tau.hypot(1.0));
                    let c = 1.0 / t.hypot(1.0);
                    let s = t * c;
                    // A <- A G  (columns p, q)
                    for k in 0..n {
                        let (akp, akq) = (a[[k, p]], a[[k, q]]);
                        a[[k, p]] = c * akp - s * akq;
                        a[[k, q]] = s * akp + c * akq;
                    }
                    // A <- Gᵀ A  (rows p, q)
                    for k in 0..n {
                        let (apk, aqk) = (a[[p, k]], a[[q, k]]);
                        a[[p, k]] = c * apk - s * aqk;
                        a[[q, k]] = s * apk + c * aqk;
                    }
                }
            }
        }
        (0..n).map(|i| a[[i, i]]).fold(f64::INFINITY, f64::min)
    }

    /// Returns a `ValueError` naming `context` when `expected != actual`.
    fn ensure_dim(context: &str, expected: usize, actual: usize) -> IntegrateResult<()> {
        if expected == actual {
            Ok(())
        } else {
            Err(IntegrateError::ValueError(format!(
                "{}: dimension mismatch (expected {}, got {})",
                context, expected, actual
            )))
        }
    }
}
impl std::fmt::Debug for PortHamiltonianSystem {
    /// Shows the dimensions and configuration; the boxed callbacks have no `Debug` form and
    /// are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            n_states,
            n_ports,
            config,
            ..
        } = self;
        let mut out = f.debug_struct("PortHamiltonianSystem");
        out.field("n_states", n_states);
        out.field("n_ports", n_ports);
        out.field("config", config);
        out.finish()
    }
}
/// Step-by-step constructor for a `PortHamiltonianSystem`.
///
/// `J`, `R`, `H` and one form of `B` (constant or state-dependent) must be supplied before
/// building. The analytic gradient and the configuration are optional.
pub struct PortHamiltonianBuilder {
    /// Declared number of state variables.
    n_states: usize,
    /// Declared number of ports.
    n_ports: usize,
    /// Hamiltonian `H(x)`.
    hamiltonian: Option<ScalarFn>,
    /// Optional analytic gradient `∇H(x)`.
    grad_hamiltonian: Option<VectorFn>,
    /// Interconnection matrix callback `J(x)`.
    j_fn: Option<MatrixFn>,
    /// Dissipation matrix callback `R(x)`.
    r_fn: Option<MatrixFn>,
    /// Constant input matrix `B`.
    b_matrix: Option<Array2<f64>>,
    /// State-dependent input matrix `B(x)`.
    b_fn: Option<MatrixFn>,
    /// Numerical tolerances.
    config: PortHamiltonianConfig,
}
impl PortHamiltonianBuilder {
    /// Starts a builder for a system with `n_states` states and `n_ports` ports, using the
    /// default `PortHamiltonianConfig`.
    pub fn new(n_states: usize, n_ports: usize) -> Self {
        Self {
            n_states,
            n_ports,
            j_fn: None,
            r_fn: None,
            hamiltonian: None,
            grad_hamiltonian: None,
            b_matrix: None,
            b_fn: None,
            config: PortHamiltonianConfig::default(),
        }
    }

    /// Sets the interconnection matrix callback `J(x)` (required).
    pub fn with_j(
        mut self,
        j: impl Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Send + Sync + 'static,
    ) -> Self {
        self.j_fn = Some(Box::new(j));
        self
    }

    /// Sets the dissipation matrix callback `R(x)` (required).
    pub fn with_r(
        mut self,
        r: impl Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Send + Sync + 'static,
    ) -> Self {
        self.r_fn = Some(Box::new(r));
        self
    }

    /// Sets the Hamiltonian `H(x)` (required).
    pub fn with_hamiltonian(
        mut self,
        h: impl Fn(&[f64]) -> IntegrateResult<f64> + Send + Sync + 'static,
    ) -> Self {
        self.hamiltonian = Some(Box::new(h));
        self
    }

    /// Sets an analytic gradient `∇H(x)` (optional; otherwise finite differences are used by
    /// the built system).
    pub fn with_grad_hamiltonian(
        mut self,
        gh: impl Fn(&[f64]) -> IntegrateResult<Array1<f64>> + Send + Sync + 'static,
    ) -> Self {
        self.grad_hamiltonian = Some(Box::new(gh));
        self
    }

    /// Sets a constant input matrix `B`. It is ignored if [`Self::with_b_fn`] is also used.
    pub fn with_b(mut self, b: Array2<f64>) -> Self {
        self.b_matrix = Some(b);
        self
    }

    /// Sets a state-dependent input matrix `B(x)`. It takes precedence over [`Self::with_b`].
    pub fn with_b_fn(
        mut self,
        b: impl Fn(&[f64]) -> IntegrateResult<Array2<f64>> + Send + Sync + 'static,
    ) -> Self {
        self.b_fn = Some(Box::new(b));
        self
    }

    /// Replaces the numerical tolerances.
    pub fn with_config(mut self, config: PortHamiltonianConfig) -> Self {
        self.config = config;
        self
    }

    /// Assembles the `PortHamiltonianSystem`.
    ///
    /// When both a constant `B` and a `B(x)` function were supplied, the function wins.
    ///
    /// # Errors
    /// Returns `IntegrateError::ValueError` if `J`, `R` or `H` is missing, or if neither form
    /// of `B` was supplied. Missing items are reported in that order.
    pub fn build(self) -> IntegrateResult<PortHamiltonianSystem> {
        let j_fn = self.j_fn.ok_or_else(|| {
            IntegrateError::ValueError("J matrix function is required".into())
        })?;
        let r_fn = self.r_fn.ok_or_else(|| {
            IntegrateError::ValueError("R matrix function is required".into())
        })?;
        let hamiltonian = self.hamiltonian.ok_or_else(|| {
            IntegrateError::ValueError("Hamiltonian function is required".into())
        })?;
        let b_matrix_fn: MatrixFn = match self.b_fn {
            Some(b_fn) => b_fn,
            None => {
                let b = self.b_matrix.ok_or_else(|| {
                    IntegrateError::ValueError(
                        "B matrix (constant or function) is required".into(),
                    )
                })?;
                Box::new(move |_x| Ok(b.clone()))
            }
        };
        Ok(PortHamiltonianSystem {
            n_states: self.n_states,
            n_ports: self.n_ports,
            j_matrix: j_fn,
            r_matrix: r_fn,
            hamiltonian,
            grad_hamiltonian: self.grad_hamiltonian,
            b_matrix: b_matrix_fn,
            config: self.config,
        })
    }
}