use std::fmt;
use std::sync::Arc;
use crate::autograd::higher_order::grad;
use crate::autograd::no_grad::no_grad;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
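/// Finds a fixed point `x* = f(x*, params)` by forward iteration and differentiates
/// through it implicitly.
///
/// The forward pass runs `x_{k+1} = f(x_k, params)` under `no_grad`, stopping once the
/// L1 norm of the update falls below `tol` or after `max_iter` iterations. When any
/// parameter requires grad, the result carries a `FixedPointBackward` node that applies
/// the implicit function theorem instead of backpropagating through every forward
/// iteration: dx*/dθ = (I - ∂f/∂x)⁻¹ ∂f/∂θ, with the inverse applied through a
/// fixed-point iteration on the transposed Jacobian (see `FixedPointBackward::backward`).
///
/// Illustrative sketch mirroring `test_fixed_point_gradient_affine` below (it reuses the
/// test helper `leaf_scalar` and is not a doctest):
///
/// ```ignore
/// let x0 = leaf_scalar(0.0, false);
/// let b = leaf_scalar(3.0, true);
/// let x_star = fixed_point(
///     |x, params| {
///         // 0.5 * x + b, a contraction whose fixed point is x* = 2b
///         let half = Tensor::from_storage(TensorStorage::cpu(vec![0.5f32]), vec![], false)?;
///         let half_x = crate::grad_fns::arithmetic::mul(x, &half)?;
///         crate::grad_fns::arithmetic::add(&half_x, params[0])
///     },
///     &x0,
///     &[&b],
///     1000,
///     1e-8,
/// )?;
/// // backward(&x_star) then yields dx*/db = 2.
/// ```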
pub fn fixed_point<T, F>(
f: F,
x0: &Tensor<T>,
params: &[&Tensor<T>],
max_iter: usize,
tol: f64,
) -> FerrotorchResult<Tensor<T>>
where
T: Float,
F: Fn(&Tensor<T>, &[&Tensor<T>]) -> FerrotorchResult<Tensor<T>> + Send + Sync + 'static,
{
let x_star = no_grad(|| -> FerrotorchResult<Tensor<T>> {
let mut x = x0.clone();
for _ in 0..max_iter {
let x_next = f(&x, params)?;
let x_data = x.data()?;
let x_next_data = x_next.data()?;
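            // L1 norm of the update; iteration stops once this drops below `tol`.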
let norm: f64 = x_data
.iter()
.zip(x_next_data.iter())
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs())
.sum();
if norm < tol {
return Ok::<Tensor<T>, FerrotorchError>(x_next);
}
x = x_next;
}
        Ok(x)
    })?;
if params.iter().any(|p| p.requires_grad()) {
let x_star_data = x_star.data()?.to_vec();
let x_star_shape = x_star.shape().to_vec();
let storage = TensorStorage::cpu(x_star_data);
let params_owned: Vec<Tensor<T>> = params.iter().map(|p| (*p).clone()).collect();
Tensor::from_operation(
storage,
x_star_shape,
Arc::new(FixedPointBackward {
f_closure: Arc::new(f),
x_star: x_star.clone(),
params: params_owned,
                backward_max_iter: max_iter.min(50),
                backward_tol: tol,
}),
)
} else {
Ok(x_star)
}
}
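/// Backward node for `fixed_point`: stores the iteration closure, the converged `x_star`,
/// and owned copies of the parameters. `backward` first solves the adjoint fixed point
/// v = grad_output + (∂f/∂x)ᵀ v and then returns (∂f/∂θ)ᵀ v as the parameter gradients.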
struct FixedPointBackward<T: Float> {
f_closure: Arc<dyn Fn(&Tensor<T>, &[&Tensor<T>]) -> FerrotorchResult<Tensor<T>> + Send + Sync>,
x_star: Tensor<T>,
params: Vec<Tensor<T>>,
backward_max_iter: usize,
backward_tol: f64,
}
impl<T: Float> fmt::Debug for FixedPointBackward<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FixedPointBackward")
.field("x_star_shape", &self.x_star.shape())
.field("num_params", &self.params.len())
.field("backward_max_iter", &self.backward_max_iter)
.field("backward_tol", &self.backward_tol)
.finish()
}
}
impl<T: Float> GradFn<T> for FixedPointBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let n = self.x_star.numel();
let num_params = self.params.len();
let go_data = grad_output.data()?.to_vec();
let go_shape = grad_output.shape().to_vec();
let mut v_data = go_data.clone();
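        // Solve v = grad_output + (∂f/∂x)ᵀ v at x* by fixed-point iteration; the converged
        // v equals (I - (∂f/∂x)ᵀ)⁻¹ grad_output, the adjoint required by the implicit
        // function theorem.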
for _ in 0..self.backward_max_iter {
let x_fresh = Tensor::from_storage(
TensorStorage::cpu(self.x_star.data()?.to_vec()),
self.x_star.shape().to_vec(),
true,
)?;
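            // Detach the parameters so this pass differentiates only w.r.t. x (i.e. ∂f/∂x).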
let params_detached: Vec<Tensor<T>> = self
.params
.iter()
.map(|p| {
Tensor::from_storage(
TensorStorage::cpu(p.data().unwrap().to_vec()),
p.shape().to_vec(),
false,
)
.unwrap()
})
.collect();
let params_ref: Vec<&Tensor<T>> = params_detached.iter().collect();
            let y = (self.f_closure)(&x_fresh, &params_ref)?;
let v_tensor = Tensor::from_storage(
TensorStorage::cpu(v_data.clone()),
go_shape.clone(),
false,
)?;
let yv = elementwise_mul_sum(&y, &v_tensor)?;
let grads = grad(&yv, &[&x_fresh], false, false)?;
let jt_v = match &grads[0] {
Some(g) => g.data()?.to_vec(),
None => vec![<T as num_traits::Zero>::zero(); n],
};
let mut v_new = Vec::with_capacity(n);
let mut diff_norm: f64 = 0.0;
for i in 0..n {
let val = T::from(go_data[i].to_f64().unwrap() + jt_v[i].to_f64().unwrap())
.unwrap();
diff_norm += (val.to_f64().unwrap() - v_data[i].to_f64().unwrap()).abs();
v_new.push(val);
}
v_data = v_new;
if diff_norm < self.backward_tol {
break;
}
}
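        // With the adjoint v fixed, rebuild the graph once at the detached x* with the
        // original `requires_grad` flags and differentiate vᵀ f(x*, θ) w.r.t. the parameters.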
let x_detached = Tensor::from_storage(
TensorStorage::cpu(self.x_star.data()?.to_vec()),
self.x_star.shape().to_vec(),
false,
)?;
let params_with_grad: Vec<Tensor<T>> = self
.params
.iter()
.map(|p| {
Tensor::from_storage(
TensorStorage::cpu(p.data().unwrap().to_vec()),
p.shape().to_vec(),
p.requires_grad(),
)
.unwrap()
})
.collect();
let params_ref: Vec<&Tensor<T>> = params_with_grad.iter().collect();
        let y = (self.f_closure)(&x_detached, &params_ref)?;
let v_tensor = Tensor::from_storage(
TensorStorage::cpu(v_data),
go_shape,
false,
)?;
let loss = elementwise_mul_sum(&y, &v_tensor)?;
let grad_inputs: Vec<&Tensor<T>> = params_with_grad
.iter()
.filter(|p| p.requires_grad())
.collect();
let mut result: Vec<Option<Tensor<T>>> = Vec::with_capacity(num_params);
if grad_inputs.is_empty() {
for _ in 0..num_params {
result.push(None);
}
} else {
let param_grads = grad(&loss, &grad_inputs[..], false, false)?;
let mut grad_idx = 0;
            for p in &params_with_grad {
if p.requires_grad() {
result.push(param_grads[grad_idx].clone());
grad_idx += 1;
} else {
result.push(None);
}
}
}
Ok(result)
}
fn inputs(&self) -> Vec<&Tensor<T>> {
self.params.iter().collect()
}
fn name(&self) -> &'static str {
"FixedPointBackward"
}
}
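/// Computes `sum(a * b)`, the scalar inner product used above to form vector-Jacobian
/// products via `grad`.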
fn elementwise_mul_sum<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let prod = crate::grad_fns::arithmetic::mul(a, b)?;
crate::grad_fns::reduction::sum(&prod)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::autograd::graph::backward;
use crate::storage::TensorStorage;
fn leaf_scalar(val: f32, requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], requires_grad).unwrap()
}
fn assert_approx(actual: f32, expected: f32, tol: f32, msg: &str) {
assert!(
(actual - expected).abs() < tol,
"{msg}: expected {expected}, got {actual}"
);
}
#[test]
fn test_fixed_point_affine() {
let x0 = leaf_scalar(0.0, false);
let dummy_param = leaf_scalar(1.0, false);
let x_star = fixed_point(
|x, _params| {
let half = Tensor::from_storage(
TensorStorage::cpu(vec![0.5f32]),
vec![],
false,
)?;
let one = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32]),
vec![],
false,
)?;
let half_x = crate::grad_fns::arithmetic::mul(x, &half)?;
crate::grad_fns::arithmetic::add(&half_x, &one)
},
&x0,
&[&dummy_param],
1000,
1e-8,
)
.unwrap();
assert_approx(x_star.item().unwrap(), 2.0, 1e-4, "fixed point of 0.5x + 1");
}
#[test]
fn test_fixed_point_contractive_to_zero() {
let x0 = leaf_scalar(10.0, false);
let a = leaf_scalar(0.5, false);
let x_star = fixed_point(
|x, params| crate::grad_fns::arithmetic::mul(x, params[0]),
&x0,
&[&a],
1000,
1e-8,
)
.unwrap();
assert_approx(x_star.item().unwrap(), 0.0, 1e-4, "fixed point of 0.5*x");
}
#[test]
fn test_fixed_point_tolerance() {
let x0 = leaf_scalar(0.0, false);
let dummy_param = leaf_scalar(1.0, false);
let x_star = fixed_point(
|x, _params| {
let half = Tensor::from_storage(
TensorStorage::cpu(vec![0.5f32]),
vec![],
false,
)?;
let one = Tensor::from_storage(
TensorStorage::cpu(vec![1.0f32]),
vec![],
false,
)?;
let half_x = crate::grad_fns::arithmetic::mul(x, &half)?;
crate::grad_fns::arithmetic::add(&half_x, &one)
},
&x0,
&[&dummy_param],
1000,
            0.1,
        )
.unwrap();
let val = x_star.item().unwrap();
assert!(
(val - 2.0).abs() < 0.2,
"loose tolerance: expected near 2.0, got {val}"
);
}
#[test]
fn test_fixed_point_max_iter_reached() {
let x0 = leaf_scalar(0.0, false);
let dummy_param = leaf_scalar(1.0, false);
let x_star = fixed_point(
|x, _params| {
let scale = Tensor::from_storage(
TensorStorage::cpu(vec![0.99f32]),
vec![],
false,
)?;
let bias = Tensor::from_storage(
TensorStorage::cpu(vec![0.01f32]),
vec![],
false,
)?;
let sx = crate::grad_fns::arithmetic::mul(x, &scale)?;
crate::grad_fns::arithmetic::add(&sx, &bias)
},
&x0,
&[&dummy_param],
            5,
            1e-10,
)
.unwrap();
let val = x_star.item().unwrap();
assert!(val > 0.0, "should have made some progress from x0=0");
assert!(val < 1.0, "should not have reached x*=1 in 5 iterations");
}
#[test]
fn test_fixed_point_gradient_affine() {
let x0 = leaf_scalar(0.0, false);
let b = leaf_scalar(3.0, true);
let x_star = fixed_point(
|x, params| {
let half = Tensor::from_storage(
TensorStorage::cpu(vec![0.5f32]),
vec![],
false,
)?;
let half_x = crate::grad_fns::arithmetic::mul(x, &half)?;
crate::grad_fns::arithmetic::add(&half_x, params[0])
},
&x0,
&[&b],
1000,
1e-8,
)
.unwrap();
assert_approx(x_star.item().unwrap(), 6.0, 1e-3, "x* = 2b = 6");
backward(&x_star).unwrap();
let grad_b = b.grad().unwrap().unwrap();
assert_approx(grad_b.item().unwrap(), 2.0, 0.2, "dx*/db = 2");
}
#[test]
fn test_fixed_point_gradient_scaling() {
let x0 = leaf_scalar(10.0, false);
let a = leaf_scalar(0.5, true);
let x_star = fixed_point(
|x, params| crate::grad_fns::arithmetic::mul(x, params[0]),
&x0,
&[&a],
1000,
1e-8,
)
.unwrap();
assert_approx(x_star.item().unwrap(), 0.0, 1e-3, "x* = 0");
backward(&x_star).unwrap();
let grad_a = a.grad().unwrap().unwrap();
assert_approx(grad_a.item().unwrap(), 0.0, 0.1, "dx*/da = 0");
}
#[test]
fn test_fixed_point_no_grad_params() {
let x0 = leaf_scalar(0.0, false);
let b = leaf_scalar(3.0, false);
let x_star = fixed_point(
|x, params| {
let half = Tensor::from_storage(
TensorStorage::cpu(vec![0.5f32]),
vec![],
false,
)?;
let half_x = crate::grad_fns::arithmetic::mul(x, &half)?;
crate::grad_fns::arithmetic::add(&half_x, params[0])
},
&x0,
&[&b],
1000,
1e-8,
)
.unwrap();
assert_approx(x_star.item().unwrap(), 6.0, 1e-3, "x* = 2b = 6");
assert!(
x_star.grad_fn().is_none(),
"no grad_fn when params don't require grad"
);
}
}