use crate::autograd::context::{is_anomaly_detection_enabled, is_grad_enabled};
use crate::autograd::{GradFn, Variable};
use crate::error::RusTorchError;
use crate::tensor::Tensor;
use num_traits::Float;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Error type returned by the gradient utilities in this file; an alias for [`RusTorchError`].
pub type GradError = RusTorchError;
/// Computes the gradients of `outputs` with respect to `inputs`.
///
/// The backward pass is run on every output that requires gradients, and the
/// gradient stored on each input is then read back and returned.
///
/// # Arguments
///
/// * `outputs` - Variables to differentiate. Must not be empty.
/// * `inputs` - Variables whose gradients are returned. Must not be empty.
/// * `grad_outputs` - Optional seed gradients, one per output. When `None`,
///   every output must hold exactly one element and is seeded with a tensor of
///   ones that has the output's shape.
/// * `retain_graph` - When `false`, `zero_grad()` is called on every input
///   before the backward pass. When `true`, whatever gradient the inputs
///   already hold is left in place.
/// * `create_graph` - Currently not used by this function.
///
/// # Returns
///
/// One entry per input, in the same order: a clone of the gradient stored on
/// the input (which may itself be `None`) when the input requires gradients,
/// and `None` otherwise.
///
/// # Errors
///
/// * `RusTorchError::Autograd` if gradient computation is disabled, or if
///   anomaly detection is enabled and a returned gradient contains a NaN or an
///   infinite value.
/// * `RusTorchError::InvalidParameters` if `outputs` or `inputs` is empty, or
///   if `grad_outputs` is `None` and an output does not hold exactly one
///   element.
/// * `RusTorchError::ShapeMismatch` if `grad_outputs` is provided with a
///   length different from `outputs`.
///
/// # Panics
///
/// Panics if acquiring the read lock on a variable's data or gradient fails
/// (the results of `read()` are unwrapped).
pub fn grad<T>(
    outputs: &[Variable<T>],
    inputs: &[Variable<T>],
    grad_outputs: Option<&[Tensor<T>]>,
    retain_graph: bool,
    create_graph: bool,
) -> Result<Vec<Option<Tensor<T>>>, GradError>
where
    T: Float
        + Send
        + Sync
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive
        + std::fmt::Debug,
{
    // Nothing below acts on `create_graph`; bind it explicitly so the unused
    // parameter is a visible decision rather than a silent omission.
    let _ = create_graph;

    if !is_grad_enabled() {
        return Err(RusTorchError::Autograd {
            message: "Gradient computation is disabled. Use enable_grad() context manager."
                .to_string(),
        });
    }
    if outputs.is_empty() {
        return Err(RusTorchError::InvalidParameters {
            operation: "grad".to_string(),
            message: "At least one output must be provided".to_string(),
        });
    }
    if inputs.is_empty() {
        return Err(RusTorchError::InvalidParameters {
            operation: "grad".to_string(),
            message: "At least one input must be provided".to_string(),
        });
    }

    let initial_grads: Vec<Tensor<T>> = match grad_outputs {
        Some(seeds) => {
            if seeds.len() != outputs.len() {
                return Err(RusTorchError::ShapeMismatch {
                    expected: vec![outputs.len()],
                    actual: vec![seeds.len()],
                });
            }
            seeds.to_vec()
        }
        None => {
            // Validate and seed in a single pass: each output's data is locked
            // once, and the shape used for the seed is the one that was checked.
            let mut seeds: Vec<Tensor<T>> = Vec::with_capacity(outputs.len());
            for (i, output) in outputs.iter().enumerate() {
                let data_lock = output.data();
                let data = data_lock.read().unwrap();
                if data.numel() != 1 {
                    return Err(RusTorchError::InvalidParameters {
                        operation: "grad".to_string(),
                        message: format!(
                            "Output {} is not scalar and no grad_output provided",
                            i
                        ),
                    });
                }
                seeds.push(Tensor::ones(data.shape()));
            }
            seeds
        }
    };

    if !retain_graph {
        for input in inputs {
            input.zero_grad();
        }
    }

    for (output, seed) in outputs.iter().zip(initial_grads.iter()) {
        if output.requires_grad() {
            output.backward_with_grad(Some(seed.clone()));
        }
    }

    let mut result_gradients = Vec::with_capacity(inputs.len());
    for input in inputs {
        if input.requires_grad() {
            let grad_lock = input.grad();
            let stored = grad_lock.read().unwrap();
            result_gradients.push(stored.clone());
        } else {
            result_gradients.push(None);
        }
    }

    if is_anomaly_detection_enabled() {
        // Report the first offending element, scanning inputs in order.
        for (i, entry) in result_gradients.iter().enumerate() {
            if let Some(stored) = entry {
                let values = stored.as_array();
                for &val in values.iter() {
                    if val.is_nan() {
                        return Err(RusTorchError::Autograd {
                            message: format!("NaN detected in gradient for input {}", i),
                        });
                    }
                    if val.is_infinite() {
                        return Err(RusTorchError::Autograd {
                            message: format!("Infinity detected in gradient for input {}", i),
                        });
                    }
                }
            }
        }
    }

    Ok(result_gradients)
}
/// Evaluates `func` on fresh copies of `inputs` and returns the gradient of
/// its scalar result with respect to those copies.
///
/// Each input's data is cloned into a new variable created with
/// `Variable::new(data, true)`; `func` receives these copies, not the
/// original `inputs`.
///
/// # Arguments
///
/// * `func` - Function to differentiate. Its result must hold exactly one
///   element.
/// * `inputs` - Variables supplying the point at which `func` is evaluated.
/// * `create_graph` - Forwarded unchanged to [`grad`].
///
/// # Returns
///
/// The result of [`grad`] for the function output and the copied inputs, with
/// one entry per input.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` if the output of `func` does not
/// hold exactly one element, and propagates any error returned by [`grad`].
///
/// # Panics
///
/// Panics if acquiring the read lock on a variable's data fails (the results
/// of `read()` are unwrapped).
pub fn gradient<T, F>(
    func: F,
    inputs: &[Variable<T>],
    create_graph: bool,
) -> Result<Vec<Option<Tensor<T>>>, GradError>
where
    T: Float
        + Send
        + Sync
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive
        + std::fmt::Debug,
    F: FnOnce(&[Variable<T>]) -> Variable<T>,
{
    let mut grad_inputs = Vec::with_capacity(inputs.len());
    for input in inputs {
        let input_data = input.data().read().unwrap().clone();
        grad_inputs.push(Variable::new(input_data, true));
    }

    let output = func(&grad_inputs);

    // Read the element count in a single statement so the read guard is
    // released here. Keeping it alive across the `grad` call below would make
    // this thread lock the same output data a second time while still holding
    // the first guard.
    let output_is_scalar = output.data().read().unwrap().numel() == 1;
    if !output_is_scalar {
        return Err(RusTorchError::InvalidParameters {
            operation: "gradient".to_string(),
            message: "Function output must be scalar for gradient computation".to_string(),
        });
    }

    grad(&[output], &grad_inputs, None, false, create_graph)
}
/// Records `var` in `visited` and reports whether it counts as part of a graph.
///
/// Returns `true` when the variable's id was already in `visited`, or when the
/// variable has a gradient function. Returns `false` for a variable seen for
/// the first time that has no gradient function. The id is inserted into
/// `visited` in every case.
///
/// This function inspects only `var` itself; it does not walk to any other
/// variable.
pub fn is_variable_in_graph<T>(var: &Variable<T>, visited: &mut HashSet<usize>) -> bool
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    // `insert` returns `false` when the id was already present, so a single
    // call both records the visit and detects a repeat.
    if !visited.insert(var.id()) {
        return true;
    }
    var.grad_fn().is_some()
}
/// Checks that a gradient computation over `outputs` and `inputs` can yield
/// something.
///
/// # Errors
///
/// Returns `RusTorchError::InvalidParameters` when no output requires
/// gradients, or, failing that check, when no input requires gradients. An
/// empty slice counts as "none require gradients".
pub fn validate_grad_setup<T>(
    outputs: &[Variable<T>],
    inputs: &[Variable<T>],
) -> Result<(), GradError>
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    // Both failures share the operation name and differ only in the message.
    let rejection = |message: &str| RusTorchError::InvalidParameters {
        operation: "validate_grad_setup".to_string(),
        message: message.to_string(),
    };

    let no_output_tracked = outputs.iter().all(|output| !output.requires_grad());
    if no_output_tracked {
        return Err(rejection("At least one output must require gradients"));
    }

    let no_input_tracked = inputs.iter().all(|input| !input.requires_grad());
    if no_input_tracked {
        return Err(rejection("At least one input must require gradients"));
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::Variable;
    use crate::tensor::Tensor;

    /// Absolute tolerance used when comparing computed gradients.
    const TOLERANCE: f32 = 1e-6;

    /// Builds a one-element variable of shape `[1]`, created with the flag `true`.
    fn scalar_variable(value: f32) -> Variable<f32> {
        Variable::new(Tensor::from_vec(vec![value], vec![1]), true)
    }

    /// Returns element 0 of the gradient at `index`, panicking if it is `None`.
    fn first_element(gradients: &[Option<Tensor<f32>>], index: usize) -> f32 {
        gradients[index].as_ref().unwrap().as_array()[0]
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// d(x * x)/dx at x = 3 is 6.
    #[test]
    fn test_grad_simple() {
        let x = scalar_variable(3.0);
        let y = &x * &x;

        let gradients = grad(&[y], &[x.clone()], None, false, false).unwrap();

        assert!(gradients[0].is_some());
        assert_close(first_element(&gradients, 0), 6.0);
    }

    /// For z = x * y at (2, 3): dz/dx = 3 and dz/dy = 2.
    #[test]
    fn test_grad_multiple_inputs() {
        let x = scalar_variable(2.0);
        let y = scalar_variable(3.0);
        let z = &x * &y;

        let gradients = grad(&[z], &[x.clone(), y.clone()], None, false, false).unwrap();

        assert!(gradients[0].is_some());
        assert!(gradients[1].is_some());
        assert_close(first_element(&gradients, 0), 3.0);
        assert_close(first_element(&gradients, 1), 2.0);
    }

    /// For f(a, b) = a * a + b * b at (2, 3): df/da = 4 and df/db = 6.
    #[test]
    fn test_gradient_function() {
        let inputs = vec![scalar_variable(2.0), scalar_variable(3.0)];

        let gradients = gradient(
            |vars| &vars[0] * &vars[0] + &vars[1] * &vars[1],
            &inputs,
            false,
        )
        .unwrap();

        assert!(gradients[0].is_some());
        assert!(gradients[1].is_some());
        assert_close(first_element(&gradients, 0), 4.0);
        assert_close(first_element(&gradients, 1), 6.0);
    }
}