use scirs2_core::ndarray::Array1;
use crate::NeuralError;
/// Polynomial degrees for a rational activation.
///
/// `RationalActivation::new` allocates `degree + 1` coefficients for each
/// polynomial, so a degree of `n` yields coefficients for `x^0 ..= x^n`.
#[derive(Debug, Clone)]
pub struct RationalConfig {
    /// Degree of the numerator polynomial.
    pub p_degree: usize,
    /// Degree of the polynomial that appears inside the denominator.
    pub q_degree: usize,
}
impl Default for RationalConfig {
    /// Returns a configuration with degree 4 for both polynomials.
    fn default() -> Self {
        // Both polynomials share the same default degree.
        const DEFAULT_DEGREE: usize = 4;
        Self {
            p_degree: DEFAULT_DEGREE,
            q_degree: DEFAULT_DEGREE,
        }
    }
}
impl RationalConfig {
    /// Checks that the configuration describes a non-trivial rational function.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] when both `p_degree` and
    /// `q_degree` are zero. Every other combination is accepted.
    pub fn validate(&self) -> Result<(), NeuralError> {
        match (self.p_degree, self.q_degree) {
            // Only the "both degrees are zero" combination is rejected.
            (0, 0) => Err(NeuralError::InvalidArgument(String::from(
                "At least one of p_degree or q_degree must be > 0",
            ))),
            _ => Ok(()),
        }
    }
}
/// Rational activation `P(x) / (1 + |Q(x)|)` with explicit coefficient vectors.
///
/// Both coefficient arrays are stored lowest power first: entry `i` multiplies
/// `x^i` when the polynomial is evaluated.
#[derive(Debug, Clone)]
pub struct RationalActivation {
    /// Coefficients of the numerator polynomial `P`.
    pub p_coeffs: Array1<f64>,
    /// Coefficients of the polynomial `Q` used inside the denominator.
    pub q_coeffs: Array1<f64>,
    /// Copy of the configuration passed to the constructor.
    config: RationalConfig,
}
impl RationalActivation {
    /// Creates an activation whose coefficients are all zero.
    ///
    /// The numerator gets `config.p_degree + 1` coefficients and the
    /// denominator polynomial `config.q_degree + 1`. A clone of `config` is
    /// kept and can be read back through [`Self::config`].
    ///
    /// # Errors
    ///
    /// Propagates whatever error [`RationalConfig::validate`] returns.
    pub fn new(config: &RationalConfig) -> Result<Self, NeuralError> {
        config.validate()?;
        Ok(Self {
            p_coeffs: Array1::zeros(config.p_degree + 1),
            q_coeffs: Array1::zeros(config.q_degree + 1),
            config: config.clone(),
        })
    }

    /// Evaluates a polynomial at `x` with Horner's scheme.
    ///
    /// `coeffs[i]` is the coefficient of `x^i`; an empty array evaluates to `0.0`.
    fn eval_poly(coeffs: &Array1<f64>, x: f64) -> f64 {
        // Walk from the highest power down so each step is one multiply-add.
        coeffs.iter().rev().fold(0.0_f64, |acc, &c| acc * x + c)
    }

    /// Maps the raw value of `Q(x)` to the denominator `1 + |Q(x)|`.
    fn denominator_from_raw(q_raw: f64) -> f64 {
        1.0 + q_raw.abs()
    }

    /// Computes `P(x) / (1 + |Q(x)|)` for a single input.
    ///
    /// The denominator is at least `1` whenever `Q(x)` is not NaN, so the
    /// division itself never hits a zero divisor.
    pub fn evaluate(&self, x: f64) -> f64 {
        let numerator = Self::eval_poly(&self.p_coeffs, x);
        let denominator = Self::denominator_from_raw(Self::eval_poly(&self.q_coeffs, x));
        numerator / denominator
    }

    /// Applies [`Self::evaluate`] to every element of `xs`, in order.
    pub fn evaluate_batch(&self, xs: &Array1<f64>) -> Array1<f64> {
        xs.iter().copied().map(|x| self.evaluate(x)).collect()
    }

    /// Total number of coefficients across numerator and denominator.
    pub fn n_params(&self) -> usize {
        let numerator_len = self.p_coeffs.len();
        let denominator_len = self.q_coeffs.len();
        numerator_len + denominator_len
    }

    /// Configuration this activation was constructed from.
    pub fn config(&self) -> &RationalConfig {
        &self.config
    }

    /// Gradient of the output with respect to each numerator coefficient.
    ///
    /// Entry `i` equals `x^i / (1 + |Q(x)|)`. The result has the same length
    /// as `p_coeffs`.
    pub fn grad_p_coeffs(&self, x: f64) -> Array1<f64> {
        let denom = Self::denominator_from_raw(Self::eval_poly(&self.q_coeffs, x));
        let mut grad = Array1::zeros(self.p_coeffs.len());
        // Running power of x: 1, x, x^2, ...
        let mut power = 1.0_f64;
        for slot in grad.iter_mut() {
            *slot = power / denom;
            power *= x;
        }
        grad
    }

    /// Gradient of the output with respect to each denominator coefficient.
    ///
    /// Entry `i` equals `-P(x) * sign(Q(x)) * x^i / (1 + |Q(x)|)^2`, where
    /// `sign` is `+1` when `Q(x) >= 0` and `-1` otherwise. The result has the
    /// same length as `q_coeffs`.
    pub fn grad_q_coeffs(&self, x: f64) -> Array1<f64> {
        let numerator = Self::eval_poly(&self.p_coeffs, x);
        let q_raw = Self::eval_poly(&self.q_coeffs, x);
        let denom = Self::denominator_from_raw(q_raw);
        // Derivative of |Q|; the kink at Q == 0 is resolved to +1.
        let sign_q = if q_raw >= 0.0 { 1.0 } else { -1.0 };
        // Factor shared by every entry; only the power of x varies.
        let factor = -numerator * sign_q / (denom * denom);
        let mut grad = Array1::zeros(self.q_coeffs.len());
        let mut power = 1.0_f64;
        for slot in grad.iter_mut() {
            *slot = factor * power;
            power *= x;
        }
        grad
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inputs shared by the pointwise checks.
    const PROBE_POINTS: [f64; 5] = [-2.0, -1.0, 0.0, 1.0, 2.0];

    /// Fresh activation built from the default configuration (all coefficients zero).
    fn zeroed_activation() -> RationalActivation {
        RationalActivation::new(&RationalConfig::default()).expect("valid config")
    }

    /// With every coefficient at zero the numerator vanishes, so the output is zero.
    #[test]
    fn zero_coefficients_output_zero() {
        let act = zeroed_activation();
        for x in PROBE_POINTS.iter().copied() {
            let val = act.evaluate(x);
            assert!(val.abs() < 1e-14, "Expected 0 at x={x}, got {val}");
        }
    }

    /// P(x) = 1 with Q(x) = 0 gives 1 / (1 + 0) = 1 at every input.
    #[test]
    fn constant_numerator() {
        let mut act = zeroed_activation();
        act.p_coeffs[0] = 1.0;
        for x in PROBE_POINTS.iter().copied() {
            let val = act.evaluate(x);
            assert!(
                (val - 1.0).abs() < 1e-14,
                "Expected 1 at x={x}, got {val}"
            );
        }
    }

    /// The denominator stays strictly positive and the output finite on a grid over [-5, 5].
    #[test]
    fn denominator_stays_positive() {
        let mut act = zeroed_activation();
        act.p_coeffs[0] = 1.0;
        for (idx, coeff) in act.q_coeffs.iter_mut().enumerate() {
            *coeff = (idx as f64 + 1.0) * 0.7;
        }
        for x in (-50i64..=50).map(|step| step as f64 * 0.1) {
            let q_raw = RationalActivation::eval_poly(&act.q_coeffs, x);
            let denom = 1.0 + q_raw.abs();
            assert!(
                denom > 0.0,
                "Denominator non-positive at x={x}: {denom}"
            );
            let val = act.evaluate(x);
            assert!(val.is_finite(), "Non-finite output at x={x}: {val}");
        }
    }

    /// `evaluate_batch` agrees with calling `evaluate` on each element.
    #[test]
    fn batch_matches_element_wise() {
        let mut act = zeroed_activation();
        for (idx, coeff) in act.p_coeffs.iter_mut().enumerate() {
            *coeff = (idx as f64 * 0.4 + 0.1).sin();
        }
        for (idx, coeff) in act.q_coeffs.iter_mut().enumerate() {
            *coeff = (idx as f64 * 0.2 - 0.3).cos() * 0.5;
        }
        let xs = Array1::from_vec(vec![-1.5, -0.7, 0.0, 0.3, 1.1, 2.0]);
        let batch_out = act.evaluate_batch(&xs);
        for (i, x) in xs.iter().copied().enumerate() {
            let single = act.evaluate(x);
            assert!(
                (single - batch_out[i]).abs() < 1e-14,
                "Mismatch at i={i}: single={single}, batch={}",
                batch_out[i]
            );
        }
    }
}