use crate::error::{Result, SklearsError};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use syn::{Attribute, Expr, FnArg, ItemFn, ReturnType, Stmt, Type};
/// A pair of `f64`s carrying a value (`real`) together with a derivative
/// component (`dual`), used for forward-mode differentiation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual {
    /// Value component.
    pub real: f64,
    /// Derivative component.
    pub dual: f64,
}

impl Dual {
    /// Builds a dual number from an explicit value and derivative component.
    pub fn new(real: f64, dual: f64) -> Self {
        Dual { real, dual }
    }

    /// Builds a dual number whose derivative component is `1.0`.
    pub fn variable(value: f64) -> Self {
        Dual {
            real: value,
            dual: 1.0,
        }
    }

    /// Builds a dual number whose derivative component is `0.0`.
    pub fn constant(value: f64) -> Self {
        Dual {
            real: value,
            dual: 0.0,
        }
    }

    /// Returns the value component.
    pub fn value(&self) -> f64 {
        let Dual { real, .. } = *self;
        real
    }

    /// Returns the derivative component.
    pub fn derivative(&self) -> f64 {
        let Dual { dual, .. } = *self;
        dual
    }
}
/// Adds two dual numbers component by component.
impl std::ops::Add for Dual {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let real = self.real + other.real;
        let dual = self.dual + other.dual;
        Self::new(real, dual)
    }
}
/// Subtracts two dual numbers component by component.
impl std::ops::Sub for Dual {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        let real = self.real - other.real;
        let dual = self.dual - other.dual;
        Self::new(real, dual)
    }
}
/// Multiplies two dual numbers using the product rule for the derivative
/// component: `(a, a') * (b, b') = (a * b, a * b' + a' * b)`.
impl std::ops::Mul for Dual {
    type Output = Self;

    fn mul(self, other: Self) -> Self {
        let value = self.real * other.real;
        let left_term = self.real * other.dual;
        let right_term = self.dual * other.real;
        Self::new(value, left_term + right_term)
    }
}
/// Divides two dual numbers using the quotient rule for the derivative
/// component: `(a, a') / (b, b') = (a / b, (a' * b - a * b') / b²)`.
///
/// Division by a dual number whose `real` part is zero follows ordinary `f64`
/// semantics (infinities / NaN); no error is raised.
impl std::ops::Div for Dual {
    type Output = Self;

    fn div(self, other: Self) -> Self {
        // Use true division so the value component matches the plain `f64`
        // result exactly. Multiplying by a rounded reciprocal does not:
        // `49.0 * (1.0 / 49.0)` is `0.9999999999999999`.
        let value = self.real / other.real;
        // Divide twice instead of dividing by `b * b`, which would overflow
        // for very large denominators.
        let numerator = self.dual * other.real - self.real * other.dual;
        let slope = numerator / other.real / other.real;
        Self::new(value, slope)
    }
}
/// A scalar value tracked for reverse-mode differentiation.
#[derive(Debug, Clone)]
pub struct Variable {
/// Identifier assigned when the variable is created.
pub id: VariableId,
/// Value carried by the variable.
pub value: f64,
/// Gradient stored for this variable; starts at `0.0` and is changed through
/// `set_gradient`, `add_gradient` and `zero_gradient`.
pub gradient: f64,
/// Graph node associated with this variable; `None` unless the variable was
/// built with `with_graph`.
pub node: Option<Arc<ComputationNode>>,
}
/// Numeric identifier of a [`Variable`]. It is `Copy`, comparable and hashable,
/// and is used as the key of `ComputationTape::variables`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VariableId(pub u64);
impl Variable {
pub fn new(value: f64) -> Self {
static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
let id = VariableId(NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst));
Self {
id,
value,
gradient: 0.0,
node: None,
}
}
pub fn with_graph(value: f64, tape: Arc<Mutex<ComputationTape>>) -> Self {
let mut var = Self::new(value);
let node = ComputationNode {
operation: Operation::Input,
inputs: Vec::new(),
output_id: var.id,
gradient_fn: Box::new(|_inputs, _output_grad| Vec::new()),
};
var.node = Some(Arc::new(node));
if let Ok(mut tape_guard) = tape.lock() {
tape_guard.add_node(var.node.as_ref().expect("value should be present").clone());
}
var
}
pub fn set_gradient(&mut self, gradient: f64) {
self.gradient = gradient;
}
pub fn add_gradient(&mut self, gradient: f64) {
self.gradient += gradient;
}
pub fn zero_gradient(&mut self) {
self.gradient = 0.0;
}
}
/// Boxed callback used by a computation node during the backward pass.
///
/// `ComputationTape::backward` calls it with the values of the node's inputs
/// and the gradient of the node's output, and pairs the returned gradients
/// with the node's inputs in order.
pub type GradientFunction = Box<dyn Fn(&[f64], f64) -> Vec<f64> + Send + Sync>;
/// One recorded step of a computation: an operation tag, the ids it consumed,
/// the id it produced, and the callback used to propagate gradients.
pub struct ComputationNode {
/// Tag naming the operation this node represents.
pub operation: Operation,
/// Ids of the variables consumed by this node, in argument order.
pub inputs: Vec<VariableId>,
/// Id of the variable produced by this node.
pub output_id: VariableId,
/// Callback producing one gradient per entry of `inputs`.
pub gradient_fn: GradientFunction,
}
/// Manual `Debug` implementation: the gradient callback cannot be printed, so
/// it is rendered as the placeholder string `"<function>"`.
impl std::fmt::Debug for ComputationNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ComputationNode");
        builder.field("operation", &self.operation);
        builder.field("inputs", &self.inputs);
        builder.field("output_id", &self.output_id);
        builder.field("gradient_fn", &"<function>");
        builder.finish()
    }
}
/// Tag stored on a `ComputationNode` naming the operation it represents.
///
/// Within this file only `Input` is constructed (by `Variable::with_graph`);
/// the remaining variants are presumably attached by code recording other
/// operations — verify against callers.
#[derive(Debug, Clone, PartialEq)]
pub enum Operation {
/// Node created for a variable fed into the graph; it has no inputs.
Input,
Add,
Sub,
Mul,
Div,
Pow,
Exp,
Ln,
Sin,
Cos,
Tanh,
Sigmoid,
ReLU,
/// Operation identified only by a caller-supplied name.
Custom(String),
}
#[derive(Debug)]
pub struct ComputationTape {
pub nodes: Vec<Arc<ComputationNode>>,
pub variables: HashMap<VariableId, Variable>,
pub execution_order: Vec<VariableId>,
}
impl ComputationTape {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
variables: HashMap::new(),
execution_order: Vec::new(),
}
}
pub fn add_node(&mut self, node: Arc<ComputationNode>) {
self.execution_order.push(node.output_id);
self.nodes.push(node);
}
pub fn register_variable(&mut self, var: Variable) {
self.variables.insert(var.id, var);
}
pub fn backward(&mut self, root_gradient: f64) -> Result<()> {
for var in self.variables.values_mut() {
var.zero_gradient();
}
if let Some(root_id) = self.execution_order.last() {
if let Some(root_var) = self.variables.get_mut(root_id) {
root_var.set_gradient(root_gradient);
}
}
for &node_id in self.execution_order.iter().rev() {
if let Some(node) = self.nodes.iter().find(|n| n.output_id == node_id) {
let output_gradient = self
.variables
.get(&node_id)
.map(|v| v.gradient)
.unwrap_or(0.0);
let input_values: Vec<f64> = node
.inputs
.iter()
.filter_map(|&id| self.variables.get(&id).map(|v| v.value))
.collect();
let input_gradients = (node.gradient_fn)(&input_values, output_gradient);
for (&input_id, &gradient) in node.inputs.iter().zip(input_gradients.iter()) {
if let Some(input_var) = self.variables.get_mut(&input_id) {
input_var.add_gradient(gradient);
}
}
}
}
Ok(())
}
pub fn get_gradient(&self, id: VariableId) -> Option<f64> {
self.variables.get(&id).map(|v| v.gradient)
}
pub fn clear(&mut self) {
self.nodes.clear();
self.variables.clear();
self.execution_order.clear();
}
}
impl Default for ComputationTape {
fn default() -> Self {
Self::new()
}
}
/// Options controlling code generation for an `autodiff`-annotated function.
///
/// Of these fields, only `mode` is read by the generators visible in this
/// file; the remaining flags are carried but not consulted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutodiffConfig {
/// Differentiation strategy; selects which generator is used.
pub mode: ADMode,
/// Highest derivative order requested (default `1`).
pub max_order: u32,
/// SIMD flag (default `false`).
pub simd: bool,
/// GPU flag (default `false`).
pub gpu: bool,
/// Symbolic flag (default `false`).
pub symbolic: bool,
/// Names of optimizations to apply (default empty).
pub optimizations: Vec<String>,
}
/// Differentiation strategy; `generate_autodiff_impl` dispatches on it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ADMode {
/// Handled by `generate_forward_mode`.
Forward,
/// Handled by `generate_reverse_mode`.
Reverse,
/// Handled by `generate_mixed_mode`, which emits both forward and reverse code.
Mixed,
/// Handled by `generate_symbolic_mode`.
Symbolic,
}
impl Default for AutodiffConfig {
fn default() -> Self {
Self {
mode: ADMode::Forward,
max_order: 1,
simd: false,
gpu: false,
symbolic: false,
optimizations: Vec::new(),
}
}
}
/// Builds the [`AutodiffConfig`] for a function from its attributes.
///
/// Starts from the default configuration and selects forward mode when an
/// `autodiff` attribute is present. The attribute's arguments are not
/// inspected, so the result currently always equals the default configuration.
pub fn parse_autodiff_attributes(attrs: &[Attribute]) -> Result<AutodiffConfig> {
    let mut config = AutodiffConfig::default();
    let has_autodiff_attr = attrs.iter().any(|attr| attr.path().is_ident("autodiff"));
    if has_autodiff_attr {
        config.mode = ADMode::Forward;
    }
    Ok(config)
}
pub fn generate_autodiff_impl(func: &ItemFn, config: &AutodiffConfig) -> Result<TokenStream> {
let original_name = &func.sig.ident;
let autodiff_name = syn::Ident::new(&format!("{}_autodiff", original_name), Span::call_site());
match config.mode {
ADMode::Forward => generate_forward_mode(func, &autodiff_name, config),
ADMode::Reverse => generate_reverse_mode(func, &autodiff_name, config),
ADMode::Mixed => generate_mixed_mode(func, &autodiff_name, config),
ADMode::Symbolic => generate_symbolic_mode(func, &autodiff_name, config),
}
}
fn generate_forward_mode(
func: &ItemFn,
autodiff_name: &syn::Ident,
_config: &AutodiffConfig,
) -> Result<TokenStream> {
let original_name = &func.sig.ident;
let inputs = &func.sig.inputs;
let output = &func.sig.output;
let dual_inputs = transform_inputs_to_dual(inputs)?;
let dual_output = transform_output_to_dual(output)?;
let dual_body = transform_body_to_dual(&func.block)?;
let generated = quote! {
pub fn #autodiff_name(#dual_inputs) -> #dual_output {
#dual_body
}
pub fn #original_name _derivative(x: f64) -> (f64, f64) {
let dual_x = Dual::variable(x);
let result = #autodiff_name(dual_x);
(result.value(), result.derivative())
}
};
Ok(generated)
}
fn generate_reverse_mode(
func: &ItemFn,
autodiff_name: &syn::Ident,
_config: &AutodiffConfig,
) -> Result<TokenStream> {
let original_name = &func.sig.ident;
let inputs = &func.sig.inputs;
let var_inputs = transform_inputs_to_variables(inputs)?;
let tape_body = transform_body_to_tape(&func.block)?;
let generated = quote! {
pub fn #autodiff_name(#var_inputs, tape: Arc<Mutex<ComputationTape>>) -> Variable {
#tape_body
}
pub fn #original_name _gradients(inputs: &[f64]) -> Vec<f64> {
let tape = Arc::new(Mutex::new(ComputationTape::new()));
let vars: Vec<Variable> = inputs.iter()
.map(|&x| Variable::with_graph(x, tape.clone()))
.collect();
let output = #autodiff_name(vars, tape.clone());
if let Ok(mut tape_guard) = tape.lock() {
let _ = tape_guard.backward(1.0);
vars.iter()
.map(|v| tape_guard.get_gradient(v.id).unwrap_or(0.0))
.collect()
} else {
vec![0.0; inputs.len()]
}
}
};
Ok(generated)
}
fn generate_mixed_mode(
func: &ItemFn,
autodiff_name: &syn::Ident,
config: &AutodiffConfig,
) -> Result<TokenStream> {
let forward_impl = generate_forward_mode(func, autodiff_name, config)?;
let reverse_name = syn::Ident::new(&format!("{}_reverse", autodiff_name), Span::call_site());
let reverse_impl = generate_reverse_mode(func, &reverse_name, config)?;
let generated = quote! {
#forward_impl
#reverse_impl
pub fn #autodiff_name _mixed(inputs: &[f64], forward_vars: &[usize]) -> (f64, Vec<f64>) {
let gradients = vec![0.0; inputs.len()];
(0.0, gradients)
}
};
Ok(generated)
}
fn generate_symbolic_mode(
func: &ItemFn,
autodiff_name: &syn::Ident,
_config: &AutodiffConfig,
) -> Result<TokenStream> {
let original_name = &func.sig.ident;
let generated = quote! {
pub fn #autodiff_name() -> SymbolicExpression {
SymbolicExpression::new("derivative")
}
pub fn #original_name _latex() -> String {
let expr = #autodiff_name();
expr.to_latex()
}
};
Ok(generated)
}
/// Rewrites a parameter list for forward mode: parameters typed exactly `f64`
/// become `Dual`, every other typed parameter is kept unchanged.
///
/// # Errors
///
/// Returns `SklearsError::InvalidOperation` for any parameter that is not a
/// typed pattern (for example a `self` receiver).
fn transform_inputs_to_dual(
    inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
) -> Result<TokenStream> {
    let dual_inputs = inputs
        .iter()
        .map(|input| match input {
            FnArg::Typed(pat_type) => {
                let pat = &pat_type.pat;
                let ty = &*pat_type.ty;
                let is_f64 = match ty {
                    Type::Path(type_path) => type_path.path.is_ident("f64"),
                    _ => false,
                };
                Ok(if is_f64 {
                    quote! { #pat: Dual }
                } else {
                    quote! { #pat: #ty }
                })
            }
            _ => Err(SklearsError::InvalidOperation(
                "Unsupported function parameter type".to_string(),
            )),
        })
        .collect::<Result<Vec<TokenStream>>>()?;
    Ok(quote! { #(#dual_inputs),* })
}
fn transform_output_to_dual(output: &ReturnType) -> Result<TokenStream> {
match output {
ReturnType::Type(_, ty) => match &**ty {
Type::Path(type_path) if type_path.path.is_ident("f64") => Ok(quote! { Dual }),
ty => Ok(quote! { #ty }),
},
ReturnType::Default => Ok(quote! { () }),
}
}
fn transform_body_to_dual(block: &syn::Block) -> Result<TokenStream> {
let mut transformed_stmts = Vec::new();
for stmt in &block.stmts {
let transformed = transform_statement_to_dual(stmt)?;
transformed_stmts.push(transformed);
}
Ok(quote! { { #(#transformed_stmts)* } })
}
/// Rewrites a single statement for forward mode.
///
/// * Expression statements have their expression rewritten by
///   `transform_expression_to_dual`; a trailing `;` is preserved so that
///   consecutive statements stay separated in the generated block.
/// * `let` statements with an initializer have the initializer rewritten; the
///   attributes, the pattern and any `else` branch (`let … else { … }`) are
///   emitted unchanged.
/// * Every other statement is emitted unchanged.
///
/// # Errors
///
/// Propagates errors from `transform_expression_to_dual`.
fn transform_statement_to_dual(stmt: &Stmt) -> Result<TokenStream> {
    match stmt {
        Stmt::Expr(expr, semi) => {
            let transformed_expr = transform_expression_to_dual(expr)?;
            // `semi` is `None` for a tail expression and `Some(;)` otherwise.
            Ok(quote! { #transformed_expr #semi })
        }
        Stmt::Local(local) => match &local.init {
            Some(local_init) => {
                let attrs = &local.attrs;
                let pat = &local.pat;
                let transformed_init = transform_expression_to_dual(&local_init.expr)?;
                // Keep the diverging branch of `let … else`; dropping it would
                // turn a refutable pattern into a plain `let`.
                let diverge = local_init
                    .diverge
                    .as_ref()
                    .map(|(else_token, else_expr)| quote! { #else_token #else_expr });
                Ok(quote! { #(#attrs)* let #pat = #transformed_init #diverge; })
            }
            None => Ok(quote! { #stmt }),
        },
        _ => Ok(quote! { #stmt }),
    }
}
/// Rewrites an expression so that it operates on `Dual` values.
///
/// * Binary expressions: both operands are rewritten and parenthesised.
/// * Parenthesised and grouped expressions: the inner expression is rewritten,
///   so literals inside `( … )` are lifted like any other literal.
/// * Calls: arguments are rewritten; calls to the plain identifiers `exp`,
///   `ln`, `sin` and `cos` are redirected to `dual_exp`, `dual_ln`, `dual_sin`
///   and `dual_cos`.
/// * Float and integer literals become `Dual::constant(…)`.
/// * Everything else (unary operators, method calls, blocks, …) is emitted
///   unchanged.
///
/// # Errors
///
/// Returns `SklearsError::InvalidOperation` when a numeric literal cannot be
/// parsed (`f64` for floats, `i64` for integers).
fn transform_expression_to_dual(expr: &Expr) -> Result<TokenStream> {
    match expr {
        Expr::Binary(binary_expr) => {
            let left = transform_expression_to_dual(&binary_expr.left)?;
            let right = transform_expression_to_dual(&binary_expr.right)?;
            let op = &binary_expr.op;
            Ok(quote! { (#left) #op (#right) })
        }
        Expr::Paren(paren_expr) => {
            // Without this arm `(x + 1.0) * 2.0` would keep the inner `1.0`
            // as an `f64` and fail to type-check against `Dual` operands.
            let inner = transform_expression_to_dual(&paren_expr.expr)?;
            Ok(quote! { (#inner) })
        }
        Expr::Group(group_expr) => transform_expression_to_dual(&group_expr.expr),
        Expr::Call(call_expr) => {
            let func = &call_expr.func;
            let args = call_expr
                .args
                .iter()
                .map(transform_expression_to_dual)
                .collect::<Result<Vec<TokenStream>>>()?;
            let dual_intrinsic = match &**func {
                Expr::Path(path) if path.path.is_ident("exp") => Some(quote! { dual_exp }),
                Expr::Path(path) if path.path.is_ident("ln") => Some(quote! { dual_ln }),
                Expr::Path(path) if path.path.is_ident("sin") => Some(quote! { dual_sin }),
                Expr::Path(path) if path.path.is_ident("cos") => Some(quote! { dual_cos }),
                _ => None,
            };
            Ok(match dual_intrinsic {
                Some(intrinsic) => quote! { #intrinsic(#(#args),*) },
                None => quote! { #func(#(#args),*) },
            })
        }
        Expr::Lit(lit_expr) => match &lit_expr.lit {
            syn::Lit::Float(float_lit) => {
                let parsed_value: f64 = float_lit.base10_digits().parse().map_err(|_| {
                    SklearsError::InvalidOperation("Invalid float literal".to_string())
                })?;
                Ok(quote! { Dual::constant(#parsed_value) })
            }
            syn::Lit::Int(int_lit) => {
                let parsed_value: i64 = int_lit.base10_digits().parse().map_err(|_| {
                    SklearsError::InvalidOperation("Invalid int literal".to_string())
                })?;
                Ok(quote! { Dual::constant(#parsed_value as f64) })
            }
            _ => Ok(quote! { #expr }),
        },
        _ => Ok(quote! { #expr }),
    }
}
/// Rewrites a parameter list for reverse mode: parameters typed exactly `f64`
/// become `Variable`, every other typed parameter is kept unchanged.
///
/// # Errors
///
/// Returns `SklearsError::InvalidOperation` for any parameter that is not a
/// typed pattern (for example a `self` receiver).
fn transform_inputs_to_variables(
    inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
) -> Result<TokenStream> {
    let mut var_inputs = Vec::new();
    for input in inputs {
        let pat_type = match input {
            FnArg::Typed(pat_type) => pat_type,
            _ => {
                return Err(SklearsError::InvalidOperation(
                    "Unsupported function parameter type".to_string(),
                ));
            }
        };
        let pat = &pat_type.pat;
        let rewritten = match &*pat_type.ty {
            Type::Path(type_path) if type_path.path.is_ident("f64") => {
                quote! { #pat: Variable }
            }
            other => quote! { #pat: #other },
        };
        var_inputs.push(rewritten);
    }
    Ok(quote! { #(#var_inputs),* })
}
/// Produces the body of the generated reverse-mode function.
///
/// The original block is ignored: the emitted body only creates a fresh
/// tape-backed variable with value `0.0`.
fn transform_body_to_tape(_block: &syn::Block) -> Result<TokenStream> {
    let placeholder_body = quote! {
        {
            Variable::with_graph(0.0, tape)
        }
    };
    Ok(placeholder_body)
}
pub fn dual_exp(x: Dual) -> Dual {
let exp_x = x.real.exp();
Dual::new(exp_x, x.dual * exp_x)
}
pub fn dual_ln(x: Dual) -> Dual {
Dual::new(x.real.ln(), x.dual / x.real)
}
pub fn dual_sin(x: Dual) -> Dual {
Dual::new(x.real.sin(), x.dual * x.real.cos())
}
pub fn dual_cos(x: Dual) -> Dual {
Dual::new(x.real.cos(), -x.dual * x.real.sin())
}
pub fn dual_tanh(x: Dual) -> Dual {
let tanh_x = x.real.tanh();
Dual::new(tanh_x, x.dual * (1.0 - tanh_x * tanh_x))
}
pub fn dual_sigmoid(x: Dual) -> Dual {
let sigmoid_x = 1.0 / (1.0 + (-x.real).exp());
Dual::new(sigmoid_x, x.dual * sigmoid_x * (1.0 - sigmoid_x))
}
/// Raises a dual number to a constant power: value `base^exponent`,
/// derivative `base' * exponent * base^(exponent - 1)`.
///
/// For `exponent == 0.0` the result is the constant `1`, so the derivative is
/// reported as `0.0` even at `base == 0`. The general formula would evaluate
/// `0 * 0^(-1) = 0 * inf = NaN` there.
pub fn dual_pow(base: Dual, exponent: f64) -> Dual {
    let value = base.real.powf(exponent);
    let slope = if exponent == 0.0 {
        0.0
    } else {
        base.dual * exponent * base.real.powf(exponent - 1.0)
    };
    Dual::new(value, slope)
}
/// Expression tree used for symbolic differentiation and LaTeX rendering.
#[derive(Debug, Clone, PartialEq)]
pub enum SymbolicExpression {
/// Named variable.
Variable(String),
/// Numeric constant.
Constant(f64),
/// Sum `left + right`.
Add(Box<SymbolicExpression>, Box<SymbolicExpression>),
/// Difference `left - right`.
Sub(Box<SymbolicExpression>, Box<SymbolicExpression>),
/// Product `left * right`.
Mul(Box<SymbolicExpression>, Box<SymbolicExpression>),
/// Quotient `numerator / denominator`.
Div(Box<SymbolicExpression>, Box<SymbolicExpression>),
/// Power `base ^ exponent`.
Pow(Box<SymbolicExpression>, Box<SymbolicExpression>),
/// Named function applied to a list of arguments. `differentiate` knows the
/// one-argument functions `sin`, `cos`, `exp` and `ln`.
Function(String, Vec<SymbolicExpression>),
}
impl SymbolicExpression {
pub fn new(name: &str) -> Self {
Self::Variable(name.to_string())
}
pub fn differentiate(&self, var: &str) -> Self {
match self {
SymbolicExpression::Variable(v) if v == var => SymbolicExpression::Constant(1.0),
SymbolicExpression::Variable(_) => SymbolicExpression::Constant(0.0),
SymbolicExpression::Constant(_) => SymbolicExpression::Constant(0.0),
SymbolicExpression::Add(left, right) => SymbolicExpression::Add(
Box::new(left.differentiate(var)),
Box::new(right.differentiate(var)),
),
SymbolicExpression::Sub(left, right) => SymbolicExpression::Sub(
Box::new(left.differentiate(var)),
Box::new(right.differentiate(var)),
),
SymbolicExpression::Mul(left, right) => {
SymbolicExpression::Add(
Box::new(SymbolicExpression::Mul(
Box::new(left.differentiate(var)),
right.clone(),
)),
Box::new(SymbolicExpression::Mul(
left.clone(),
Box::new(right.differentiate(var)),
)),
)
}
SymbolicExpression::Div(left, right) => {
SymbolicExpression::Div(
Box::new(SymbolicExpression::Sub(
Box::new(SymbolicExpression::Mul(
Box::new(left.differentiate(var)),
right.clone(),
)),
Box::new(SymbolicExpression::Mul(
left.clone(),
Box::new(right.differentiate(var)),
)),
)),
Box::new(SymbolicExpression::Pow(
right.clone(),
Box::new(SymbolicExpression::Constant(2.0)),
)),
)
}
SymbolicExpression::Pow(base, exp) => {
match (&**base, &**exp) {
(_, SymbolicExpression::Constant(n)) => {
SymbolicExpression::Mul(
Box::new(SymbolicExpression::Mul(
Box::new(SymbolicExpression::Constant(*n)),
Box::new(SymbolicExpression::Pow(
base.clone(),
Box::new(SymbolicExpression::Constant(n - 1.0)),
)),
)),
Box::new(base.differentiate(var)),
)
}
_ => {
SymbolicExpression::Mul(
Box::new(self.clone()),
Box::new(SymbolicExpression::Add(
Box::new(SymbolicExpression::Mul(
Box::new(exp.differentiate(var)),
Box::new(SymbolicExpression::Function(
"ln".to_string(),
vec![*base.clone()],
)),
)),
Box::new(SymbolicExpression::Mul(
exp.clone(),
Box::new(SymbolicExpression::Div(
Box::new(base.differentiate(var)),
base.clone(),
)),
)),
)),
)
}
}
}
SymbolicExpression::Function(name, args) => {
self.differentiate_function(name, args, var)
}
}
}
fn differentiate_function(&self, name: &str, args: &[SymbolicExpression], var: &str) -> Self {
match name {
"sin" if args.len() == 1 => {
SymbolicExpression::Mul(
Box::new(SymbolicExpression::Function(
"cos".to_string(),
args.to_vec(),
)),
Box::new(args[0].differentiate(var)),
)
}
"cos" if args.len() == 1 => {
SymbolicExpression::Mul(
Box::new(SymbolicExpression::Constant(-1.0)),
Box::new(SymbolicExpression::Mul(
Box::new(SymbolicExpression::Function(
"sin".to_string(),
args.to_vec(),
)),
Box::new(args[0].differentiate(var)),
)),
)
}
"exp" if args.len() == 1 => {
SymbolicExpression::Mul(
Box::new(self.clone()),
Box::new(args[0].differentiate(var)),
)
}
"ln" if args.len() == 1 => {
SymbolicExpression::Div(
Box::new(args[0].differentiate(var)),
Box::new(args[0].clone()),
)
}
_ => {
SymbolicExpression::Function(format!("d{}_d{}", name, var), args.to_vec())
}
}
}
pub fn to_latex(&self) -> String {
match self {
SymbolicExpression::Variable(v) => v.clone(),
SymbolicExpression::Constant(c) => {
if c.fract() == 0.0 {
format!("{}", *c as i64)
} else {
format!("{:.3}", c)
}
}
SymbolicExpression::Add(left, right) => {
format!("({} + {})", left.to_latex(), right.to_latex())
}
SymbolicExpression::Sub(left, right) => {
format!("({} - {})", left.to_latex(), right.to_latex())
}
SymbolicExpression::Mul(left, right) => {
format!("({} \\cdot {})", left.to_latex(), right.to_latex())
}
SymbolicExpression::Div(left, right) => {
format!("\\frac{{{}}}{{{}}}", left.to_latex(), right.to_latex())
}
SymbolicExpression::Pow(base, exp) => {
format!("{}^{{{}}}", base.to_latex(), exp.to_latex())
}
SymbolicExpression::Function(name, args) => {
if args.is_empty() {
format!("\\{}", name)
} else if args.len() == 1 {
format!("\\{}({})", name, args[0].to_latex())
} else {
let arg_strs: Vec<String> = args.iter().map(|a| a.to_latex()).collect();
format!("\\{}({})", name, arg_strs.join(", "))
}
}
}
}
pub fn simplify(&self) -> Self {
match self {
SymbolicExpression::Add(left, right) => {
let left_simp = left.simplify();
let right_simp = right.simplify();
match (&left_simp, &right_simp) {
(SymbolicExpression::Constant(0.0), _) => right_simp,
(_, SymbolicExpression::Constant(0.0)) => left_simp,
(SymbolicExpression::Constant(a), SymbolicExpression::Constant(b)) => {
SymbolicExpression::Constant(a + b)
}
_ => SymbolicExpression::Add(Box::new(left_simp), Box::new(right_simp)),
}
}
SymbolicExpression::Mul(left, right) => {
let left_simp = left.simplify();
let right_simp = right.simplify();
match (&left_simp, &right_simp) {
(SymbolicExpression::Constant(0.0), _)
| (_, SymbolicExpression::Constant(0.0)) => SymbolicExpression::Constant(0.0),
(SymbolicExpression::Constant(1.0), _) => right_simp,
(_, SymbolicExpression::Constant(1.0)) => left_simp,
(SymbolicExpression::Constant(a), SymbolicExpression::Constant(b)) => {
SymbolicExpression::Constant(a * b)
}
_ => SymbolicExpression::Mul(Box::new(left_simp), Box::new(right_simp)),
}
}
SymbolicExpression::Pow(base, exponent) => {
let base_simp = base.simplify();
let exp_simp = exponent.simplify();
match (&base_simp, &exp_simp) {
(_, SymbolicExpression::Constant(1.0)) => base_simp,
(_, SymbolicExpression::Constant(0.0)) => SymbolicExpression::Constant(1.0),
(SymbolicExpression::Constant(1.0), _) => SymbolicExpression::Constant(1.0),
(SymbolicExpression::Constant(0.0), SymbolicExpression::Constant(n))
if *n > 0.0 =>
{
SymbolicExpression::Constant(0.0)
}
(SymbolicExpression::Constant(a), SymbolicExpression::Constant(b)) => {
SymbolicExpression::Constant(a.powf(*b))
}
_ => SymbolicExpression::Pow(Box::new(base_simp), Box::new(exp_simp)),
}
}
_ => self.clone(),
}
}
}
/// Approximates the second derivative of `f` at `x`.
///
/// The first derivative is obtained exactly from the dual component of
/// `f(Dual::variable(·))` at `x + h` and `x - h`; the second derivative is
/// their central difference. `h` is `cbrt(f64::EPSILON)` scaled by
/// `max(|x|, 1)`, the usual compromise between truncation and rounding error
/// for a central first difference.
///
/// The result is a numerical approximation, not an exact derivative.
pub fn second_derivative<F>(f: F, x: f64) -> f64
where
    F: Fn(Dual) -> Dual,
{
    let step = f64::EPSILON.cbrt() * x.abs().max(1.0);
    let upper = x + step;
    let lower = x - step;
    let slope_upper = f(Dual::variable(upper)).derivative();
    let slope_lower = f(Dual::variable(lower)).derivative();
    // Divide by the distance actually realised in floating point.
    (slope_upper - slope_lower) / (upper - lower)
}
/// Approximates the Hessian matrix of `f` at `x` with central finite
/// differences.
///
/// Returns an `n × n` matrix (`n = x.len()`) as a vector of rows. The matrix
/// is symmetric by construction: each off-diagonal entry is computed once and
/// mirrored.
///
/// The step for coordinate `i` is `f64::EPSILON^(1/4) * max(|x[i]|, 1)`.
/// Second differences divide by `h²`, so a step near `1e-8` would divide
/// rounding noise by about `1e-16` and give meaningless results; the
/// fourth-root step balances truncation and rounding error.
pub fn hessian<F>(f: F, x: &[f64]) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> f64,
{
    let n = x.len();
    let mut hessian = vec![vec![0.0; n]; n];
    if n == 0 {
        return hessian;
    }

    let base_step = f64::EPSILON.powf(0.25);
    let steps: Vec<f64> = x
        .iter()
        .map(|xi| base_step * xi.abs().max(1.0))
        .collect();

    // The centre value is shared by every diagonal entry.
    let f_center = f(x);
    // One scratch point is perturbed and restored instead of cloning `x` for
    // every evaluation.
    let mut probe = x.to_vec();

    for i in 0..n {
        let hi = steps[i];

        probe[i] = x[i] + hi;
        let f_plus = f(&probe);
        probe[i] = x[i] - hi;
        let f_minus = f(&probe);
        probe[i] = x[i];
        hessian[i][i] = (f_plus - 2.0 * f_center + f_minus) / (hi * hi);

        for j in (i + 1)..n {
            let hj = steps[j];

            probe[i] = x[i] + hi;
            probe[j] = x[j] + hj;
            let f_pp = f(&probe);
            probe[j] = x[j] - hj;
            let f_pm = f(&probe);
            probe[i] = x[i] - hi;
            let f_mm = f(&probe);
            probe[j] = x[j] + hj;
            let f_mp = f(&probe);
            probe[i] = x[i];
            probe[j] = x[j];

            let mixed = (f_pp - f_pm - f_mp + f_mm) / (4.0 * hi * hj);
            hessian[i][j] = mixed;
            hessian[j][i] = mixed;
        }
    }
    hessian
}
/// Evaluates `f` at `x` in forward mode and returns `(value, derivative)`.
///
/// `x` is seeded with `Dual::variable`, so the second element is the
/// derivative of `f` with respect to its argument.
pub fn forward_diff<F>(f: F, x: f64) -> (f64, f64)
where
    F: Fn(Dual) -> Dual,
{
    let evaluated = f(Dual::variable(x));
    let value = evaluated.value();
    let slope = evaluated.derivative();
    (value, slope)
}
/// Approximates the gradient of `f` at `x` with central finite differences.
///
/// Returns one partial derivative per element of `x`.
///
/// The step for coordinate `i` is `cbrt(f64::EPSILON) * max(|x[i]|, 1)`.
/// Scaling by the magnitude of the coordinate keeps the perturbation
/// representable: a fixed absolute step such as `1e-8` is absorbed entirely
/// when `|x[i]|` is large (`1e9 + 1e-8 == 1e9`), which yields a zero gradient.
pub fn gradient<F>(f: F, x: &[f64]) -> Vec<f64>
where
    F: Fn(&[f64]) -> f64,
{
    let base_step = f64::EPSILON.cbrt();
    // One scratch point is perturbed and restored instead of cloning `x` twice
    // per coordinate.
    let mut probe = x.to_vec();
    let mut grad = Vec::with_capacity(x.len());
    for (i, &xi) in x.iter().enumerate() {
        let step = base_step * xi.abs().max(1.0);
        let upper = xi + step;
        let lower = xi - step;

        probe[i] = upper;
        let f_plus = f(&probe);
        probe[i] = lower;
        let f_minus = f(&probe);
        probe[i] = xi;

        // Divide by the distance actually realised in floating point.
        grad.push((f_plus - f_minus) / (upper - lower));
    }
    grad
}
// Unit tests for the dual-number arithmetic, the symbolic engine, the
// finite-difference helpers and the tape/variable bookkeeping.
// NOTE(review): every test name below is snake_case, so the
// `non_snake_case` allowance looks unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
use super::*;
// Addition is component-wise; multiplication follows the product rule.
#[test]
fn test_dual_arithmetic() {
let x = Dual::new(2.0, 1.0);
let y = Dual::new(3.0, 0.0);
let sum = x + y;
assert_eq!(sum.real, 5.0);
assert_eq!(sum.dual, 1.0);
let product = x * y;
assert_eq!(product.real, 6.0);
assert_eq!(product.dual, 3.0);
}
// exp and ln at x = 1: exp gives (e, e), ln gives (0, 1).
#[test]
fn test_dual_math_functions() {
let x = Dual::variable(1.0);
let exp_result = dual_exp(x);
assert!((exp_result.real - std::f64::consts::E).abs() < 1e-10);
assert!((exp_result.dual - std::f64::consts::E).abs() < 1e-10);
let ln_result = dual_ln(x);
assert!((ln_result.real - 0.0).abs() < 1e-10);
assert!((ln_result.dual - 1.0).abs() < 1e-10);
}
// d/dx x^2 at x = 3 is 6.
#[test]
fn test_forward_diff() {
let f = |x: Dual| x * x;
let (value, derivative) = forward_diff(f, 3.0);
assert_eq!(value, 9.0);
assert_eq!(derivative, 6.0);
}
// d/dx x^2 must simplify to exactly `2 * x`.
#[test]
fn test_symbolic_differentiation() {
let x = SymbolicExpression::Variable("x".to_string());
let x_squared = SymbolicExpression::Pow(
Box::new(x.clone()),
Box::new(SymbolicExpression::Constant(2.0)),
);
let derivative = x_squared.differentiate("x");
let simplified = derivative.simplify();
match simplified {
SymbolicExpression::Mul(left, right) => {
assert_eq!(*left, SymbolicExpression::Constant(2.0));
assert_eq!(*right, SymbolicExpression::Variable("x".to_string()));
}
_ => panic!("Expected multiplication"),
}
}
// Numerical gradient of x^2 + y^2 at (2, 3) is (4, 6).
#[test]
fn test_gradient_computation() {
let f = |vars: &[f64]| vars[0] * vars[0] + vars[1] * vars[1];
let grad = gradient(f, &[2.0, 3.0]);
assert!((grad[0] - 4.0).abs() < 1e-6);
assert!((grad[1] - 6.0).abs() < 1e-6);
}
// Registered variables are retrievable and expose a gradient.
#[test]
fn test_computation_tape() {
let mut tape = ComputationTape::new();
let x = Variable::new(2.0);
let y = Variable::new(3.0);
tape.register_variable(x.clone());
tape.register_variable(y.clone());
assert_eq!(tape.variables.len(), 2);
assert!(tape.get_gradient(x.id).is_some());
}
// Each new variable gets a distinct id and starts with a zero gradient.
#[test]
fn test_variable_creation() {
let var1 = Variable::new(1.0);
let var2 = Variable::new(2.0);
assert_ne!(var1.id, var2.id);
assert_eq!(var1.value, 1.0);
assert_eq!(var2.value, 2.0);
assert_eq!(var1.gradient, 0.0);
assert_eq!(var2.gradient, 0.0);
}
// The default configuration is forward mode, first order, SIMD and GPU off.
#[test]
fn test_autodiff_config() {
let config = AutodiffConfig::default();
assert_eq!(config.mode, ADMode::Forward);
assert_eq!(config.max_order, 1);
assert!(!config.simd);
assert!(!config.gpu);
}
// Integral constants render without decimals inside `\frac`.
#[test]
fn test_symbolic_latex_output() {
let expr = SymbolicExpression::Div(
Box::new(SymbolicExpression::Variable("x".to_string())),
Box::new(SymbolicExpression::Constant(2.0)),
);
let latex = expr.to_latex();
assert_eq!(latex, "\\frac{x}{2}");
}
}