use scirs2_core::ndarray::{Array1, Array2};
use crate::{NeuralError, Result};
/// Configuration flags carried by an IA³ adapter.
///
/// Each flag names one site whose activations may be rescaled. In this file
/// the flags are only stored by the adapter and handed back through its
/// `config()` accessor; none of the adapter methods branch on them, so acting
/// on the flags is left to the caller.
///
/// The type is comparable (`PartialEq`/`Eq`) so configurations can be checked
/// with `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ia3Config {
    /// Request scaling at the "keys" site.
    /// NOTE(review): by name this is the attention key projection — confirm
    /// against the code that consumes this flag.
    pub scale_keys: bool,
    /// Request scaling at the "values" site.
    /// NOTE(review): by name this is the attention value projection — confirm
    /// against the code that consumes this flag.
    pub scale_values: bool,
    /// Request scaling at the "ffn" site.
    /// NOTE(review): by name this is the feed-forward activation — confirm
    /// against the code that consumes this flag.
    pub scale_ffn: bool,
}

impl Default for Ia3Config {
    /// Returns a configuration with every scaling flag enabled.
    fn default() -> Self {
        Self {
            scale_keys: true,
            scale_values: true,
            scale_ffn: true,
        }
    }
}
/// IA³ adapter: a vector of per-element multiplicative scale factors.
///
/// Derives `Debug` and `Clone` (like [`Ia3Config`]) so adapters can be logged
/// and duplicated; every field already supports both.
#[derive(Debug, Clone)]
pub struct Ia3Adapter {
    /// Per-element scale factors. `new` initialises this to all ones.
    ///
    /// The field is public, so callers may overwrite individual entries or
    /// replace the whole vector; its length is expected to stay equal to the
    /// dimension the adapter was created with.
    pub scale: Array1<f64>,
    /// Configuration supplied at construction time.
    config: Ia3Config,
    /// Dimension the adapter was created for; inputs are validated against it.
    dim: usize,
}
impl Ia3Adapter {
    /// Creates an adapter for `dim`-element activations.
    ///
    /// The scale vector starts as all ones, so a freshly built adapter returns
    /// its input unchanged.
    pub fn new(dim: usize, config: Ia3Config) -> Self {
        Self {
            scale: Array1::ones(dim),
            config,
            dim,
        }
    }

    /// Verifies that the publicly writable `scale` still has `dim` elements.
    ///
    /// `scale` is a `pub` field, so a caller can replace it with a vector of a
    /// different length. Without this guard such a vector would make the
    /// element-wise product panic (or silently broadcast), make row indexing
    /// go out of bounds, or leave trailing columns unscaled because `zip`
    /// stops at the shorter side.
    ///
    /// # Errors
    /// Returns [`NeuralError::DimensionMismatch`] when `scale.len() != dim`.
    fn ensure_scale_len(&self) -> Result<()> {
        if self.scale.len() != self.dim {
            return Err(NeuralError::DimensionMismatch(format!(
                "IA³ adapter scale has {} elements, expected {}",
                self.scale.len(),
                self.dim
            )));
        }
        Ok(())
    }

    /// Returns a copy of `matrix` in which column `j` is multiplied by
    /// `scale[j]`.
    ///
    /// Callers must already have checked that `matrix.ncols()` and
    /// `scale.len()` both equal `dim`.
    fn scale_columns(&self, matrix: &Array2<f64>) -> Array2<f64> {
        let mut out = matrix.clone();
        for mut row in out.rows_mut() {
            for (x, &s) in row.iter_mut().zip(self.scale.iter()) {
                *x *= s;
            }
        }
        out
    }

    /// Scales a single vector element-wise: `output[i] = scale[i] * input[i]`.
    ///
    /// # Errors
    /// Returns [`NeuralError::DimensionMismatch`] if `scale` no longer has
    /// `dim` elements or if `input.len() != dim`.
    pub fn forward(&self, input: &Array1<f64>) -> Result<Array1<f64>> {
        self.ensure_scale_len()?;
        if input.len() != self.dim {
            return Err(NeuralError::DimensionMismatch(format!(
                "IA³ adapter expects dimension {}, got {}",
                self.dim,
                input.len()
            )));
        }
        Ok(&self.scale * input)
    }

    /// Scales every row of `input` element-wise by `scale`:
    /// `output[[b, i]] = input[[b, i]] * scale[i]`.
    ///
    /// # Errors
    /// Returns [`NeuralError::DimensionMismatch`] if `scale` no longer has
    /// `dim` elements or if `input.ncols() != dim`.
    pub fn forward_batch(&self, input: &Array2<f64>) -> Result<Array2<f64>> {
        self.ensure_scale_len()?;
        if input.ncols() != self.dim {
            return Err(NeuralError::DimensionMismatch(format!(
                "IA³ adapter expects input width {}, got {}",
                self.dim,
                input.ncols()
            )));
        }
        Ok(self.scale_columns(input))
    }

    /// Returns a copy of `weight` with row `i` multiplied by `scale[i]`
    /// (i.e. `diag(scale) · weight`).
    ///
    /// # Errors
    /// Returns [`NeuralError::DimensionMismatch`] if `scale` no longer has
    /// `dim` elements or if `weight.nrows() != dim`.
    pub fn merge_into_weight_rows(&self, weight: &Array2<f64>) -> Result<Array2<f64>> {
        self.ensure_scale_len()?;
        if weight.nrows() != self.dim {
            return Err(NeuralError::DimensionMismatch(format!(
                "IA³ merge_into_weight_rows: weight has {} rows, scale has {} elements",
                weight.nrows(),
                self.dim
            )));
        }
        let mut out = weight.clone();
        // Pair each row with its factor instead of indexing `scale[i]`.
        for (mut row, &s) in out.rows_mut().into_iter().zip(self.scale.iter()) {
            row.mapv_inplace(|v| v * s);
        }
        Ok(out)
    }

    /// Returns a copy of `weight` with column `j` multiplied by `scale[j]`
    /// (i.e. `weight · diag(scale)`).
    ///
    /// # Errors
    /// Returns [`NeuralError::DimensionMismatch`] if `scale` no longer has
    /// `dim` elements or if `weight.ncols() != dim`.
    pub fn merge_into_weight_cols(&self, weight: &Array2<f64>) -> Result<Array2<f64>> {
        self.ensure_scale_len()?;
        if weight.ncols() != self.dim {
            return Err(NeuralError::DimensionMismatch(format!(
                "IA³ merge_into_weight_cols: weight has {} cols, scale has {} elements",
                weight.ncols(),
                self.dim
            )));
        }
        Ok(self.scale_columns(weight))
    }

    /// Number of adapter parameters, which is the construction-time `dim`.
    pub fn n_params(&self) -> usize {
        self.dim
    }

    /// Dimension the adapter was created for.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Configuration supplied at construction time.
    pub fn config(&self) -> &Ia3Config {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// `forward` returns a vector of the same length as its input.
    #[test]
    fn ia3_forward_shape() {
        let adapter = Ia3Adapter::new(8, Ia3Config::default());
        let x = Array1::from_elem(8, 1.0_f64);
        let y = adapter.forward(&x).expect("forward");
        assert_eq!(y.len(), 8);
    }

    /// A freshly constructed adapter (all-ones scale) is the identity.
    #[test]
    fn ia3_ones_identity() {
        let adapter = Ia3Adapter::new(6, Ia3Config::default());
        let x = Array1::from_shape_fn(6, |i| (i + 1) as f64);
        let y = adapter.forward(&x).expect("forward");
        for (a, b) in y.iter().zip(x.iter()) {
            assert!((a - b).abs() < 1e-14, "identity broken: {a} != {b}");
        }
    }

    /// Each output element equals the input element times its scale factor.
    #[test]
    fn ia3_scaling_correct() {
        let mut adapter = Ia3Adapter::new(4, Ia3Config::default());
        for i in 0..4 {
            adapter.scale[i] = (i + 1) as f64;
        }
        let x = Array1::ones(4);
        let y = adapter.forward(&x).expect("forward");
        for i in 0..4 {
            let expected = (i + 1) as f64;
            assert!(
                (y[i] - expected).abs() < 1e-14,
                "scale mismatch at {i}: expected {expected}, got {}",
                y[i]
            );
        }
    }

    /// `forward` rejects an input whose length differs from `dim`.
    #[test]
    fn ia3_wrong_dim_returns_error() {
        let adapter = Ia3Adapter::new(4, Ia3Config::default());
        let x = Array1::ones(5);
        assert!(adapter.forward(&x).is_err());
    }

    /// Row `i` of the merged weight is the original row times `scale[i]`.
    #[test]
    fn ia3_merge_into_weight_rows() {
        let mut adapter = Ia3Adapter::new(3, Ia3Config::default());
        adapter.scale = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let w = Array2::from_shape_fn((3, 4), |(_, j)| (j + 1) as f64);
        let merged = adapter.merge_into_weight_rows(&w).expect("merge");
        for j in 0..4 {
            assert!((merged[[0, j]] - 2.0 * w[[0, j]]).abs() < 1e-14);
            assert!((merged[[1, j]] - 3.0 * w[[1, j]]).abs() < 1e-14);
            assert!((merged[[2, j]] - 4.0 * w[[2, j]]).abs() < 1e-14);
        }
    }

    /// Column `j` of the merged weight is the original column times `scale[j]`.
    #[test]
    fn ia3_merge_into_weight_cols() {
        let mut adapter = Ia3Adapter::new(3, Ia3Config::default());
        adapter.scale = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        let w = Array2::from_shape_fn((4, 3), |(i, _)| (i + 1) as f64);
        let merged = adapter.merge_into_weight_cols(&w).expect("merge");
        assert_eq!(merged.shape(), &[4, 3]);
        for i in 0..4 {
            assert!((merged[[i, 0]] - 2.0 * w[[i, 0]]).abs() < 1e-14);
            assert!((merged[[i, 1]] - 3.0 * w[[i, 1]]).abs() < 1e-14);
            assert!((merged[[i, 2]] - 4.0 * w[[i, 2]]).abs() < 1e-14);
        }
    }

    /// Every row of a batch is scaled element-wise by the same scale vector.
    #[test]
    fn ia3_batch_forward() {
        let mut adapter = Ia3Adapter::new(4, Ia3Config::default());
        adapter.scale = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let input = Array2::ones((3, 4));
        let out = adapter.forward_batch(&input).expect("batch forward");
        assert_eq!(out.shape(), &[3, 4]);
        for b in 0..3 {
            for i in 0..4 {
                let expected = (i + 1) as f64;
                assert!(
                    (out[[b, i]] - expected).abs() < 1e-14,
                    "batch [{b},{i}]: expected {expected}, got {}",
                    out[[b, i]]
                );
            }
        }
    }

    /// `forward_batch` rejects a batch whose width differs from `dim`.
    #[test]
    fn ia3_batch_wrong_width_returns_error() {
        let adapter = Ia3Adapter::new(4, Ia3Config::default());
        let input = Array2::ones((3, 5));
        assert!(adapter.forward_batch(&input).is_err());
    }

    /// Both merge helpers reject weights whose scaled axis differs from `dim`.
    #[test]
    fn ia3_merge_wrong_shape_returns_error() {
        let adapter = Ia3Adapter::new(3, Ia3Config::default());
        let tall = Array2::ones((4, 3));
        assert!(adapter.merge_into_weight_rows(&tall).is_err());
        let wide = Array2::ones((3, 4));
        assert!(adapter.merge_into_weight_cols(&wide).is_err());
    }

    /// The parameter count equals the adapter dimension.
    #[test]
    fn ia3_n_params() {
        let adapter = Ia3Adapter::new(64, Ia3Config::default());
        assert_eq!(adapter.n_params(), 64);
    }
}