use crate::phantom::{OptimizationProblem, Riemannian, Statistical};
use crate::{OptimizationError, OptimizationResult};

use num_traits::Float;
use std::marker::PhantomData;
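/// Configuration for the natural gradient optimizer.
///
/// `fisher_regularization` is added to the diagonal of the Fisher information
/// matrix before the natural gradient direction is solved for, which keeps the
/// linear system well conditioned.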
#[derive(Clone, Debug)]
pub struct NaturalGradientConfig<T: Float> {
    pub learning_rate: T,
    pub max_iterations: usize,
    pub gradient_tolerance: T,
    pub parameter_tolerance: T,
    pub fisher_regularization: T,
    pub use_line_search: bool,
    pub line_search_beta: T,
    pub line_search_alpha: T,
}

impl<T: Float> Default for NaturalGradientConfig<T> {
    fn default() -> Self {
        Self {
            learning_rate: T::from(0.1).unwrap(),
            max_iterations: 1000,
            gradient_tolerance: T::from(1e-4).unwrap(),
            parameter_tolerance: T::from(1e-6).unwrap(),
            fisher_regularization: T::from(1e-6).unwrap(),
            use_line_search: true,
            line_search_beta: T::from(0.5).unwrap(),
            line_search_alpha: T::from(1.0).unwrap(),
        }
    }
}
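/// Result of a natural gradient optimization run.
///
/// `trajectory` is only recorded when `max_iterations` is below 1000.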
#[derive(Clone, Debug)]
pub struct NaturalGradientResult<T: Float> {
    pub parameters: Vec<T>,
    pub objective_value: T,
    pub gradient_norm: T,
    pub iterations: usize,
    pub converged: bool,
    pub trajectory: Option<Vec<Vec<T>>>,
}
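/// An objective function that can report its Fisher information matrix in
/// addition to its value and gradient.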
pub trait ObjectiveWithFisher<T: Float> {
    fn evaluate(&self, parameters: &[T]) -> T;

    fn gradient(&self, parameters: &[T]) -> Vec<T>;

    fn fisher_information(&self, parameters: &[T]) -> Vec<Vec<T>>;

    fn hessian(&self, _parameters: &[T]) -> Option<Vec<Vec<T>>> {
        None
    }
}
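/// Natural gradient optimizer: preconditions the gradient with the inverse
/// Fisher information matrix supplied by the objective.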
#[derive(Clone, Debug)]
pub struct NaturalGradientOptimizer<T: Float> {
    config: NaturalGradientConfig<T>,
    _phantom: PhantomData<T>,
}

impl<T: Float> NaturalGradientOptimizer<T> {
    pub fn new(config: NaturalGradientConfig<T>) -> Self {
        Self {
            config,
            _phantom: PhantomData,
        }
    }

    pub fn with_default_config() -> Self {
        Self::new(NaturalGradientConfig::default())
    }
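    /// Optimize a problem posed on a statistical manifold using the natural gradient.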
    pub fn optimize_statistical<
        const DIM: usize,
        C: crate::phantom::ConstraintState,
        O: crate::phantom::ObjectiveState,
        V: crate::phantom::ConvexityState,
    >(
        &self,
        _problem: &OptimizationProblem<DIM, C, O, V, Statistical>,
        objective: &impl ObjectiveWithFisher<T>,
        initial_parameters: Vec<T>,
    ) -> OptimizationResult<NaturalGradientResult<T>> {
        self.optimize_with_fisher(objective, initial_parameters)
    }
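    /// Optimize a problem posed on a Riemannian manifold; the Fisher information
    /// reported by `objective` serves as the metric.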
    pub fn optimize_riemannian<
        const DIM: usize,
        C: crate::phantom::ConstraintState,
        O: crate::phantom::ObjectiveState,
        V: crate::phantom::ConvexityState,
    >(
        &self,
        _problem: &OptimizationProblem<DIM, C, O, V, Riemannian>,
        objective: &impl ObjectiveWithFisher<T>,
        initial_parameters: Vec<T>,
    ) -> OptimizationResult<NaturalGradientResult<T>> {
        self.optimize_with_fisher(objective, initial_parameters)
    }
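    /// Core natural gradient loop. Each iteration assembles the Fisher matrix F
    /// and solves F * d = grad f for the update direction, so steps follow the
    /// gradient measured in the Fisher metric rather than the Euclidean one.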
    fn optimize_with_fisher(
        &self,
        objective: &impl ObjectiveWithFisher<T>,
        mut parameters: Vec<T>,
    ) -> OptimizationResult<NaturalGradientResult<T>> {
        // Record the trajectory only for short runs to bound memory use.
        let mut trajectory = if self.config.max_iterations < 1000 {
            Some(Vec::with_capacity(self.config.max_iterations))
        } else {
            None
        };

        let mut best_parameters = parameters.clone();
        let mut best_objective = objective.evaluate(&parameters);

        for iteration in 0..self.config.max_iterations {
            let gradient = objective.gradient(&parameters);
            let gradient_norm = self.compute_norm(&gradient);

            if gradient_norm < self.config.gradient_tolerance {
                return Ok(NaturalGradientResult {
                    parameters: best_parameters,
                    objective_value: best_objective,
                    gradient_norm,
                    iterations: iteration,
                    converged: true,
                    trajectory,
                });
            }

            let fisher = objective.fisher_information(&parameters);

            // Natural gradient direction: solve F * d = grad f instead of forming F^{-1}.
            let natural_gradient = self.solve_fisher_system(&fisher, &gradient)?;

            let step_size = if self.config.use_line_search {
                self.line_search(objective, &parameters, &natural_gradient)?
            } else {
                self.config.learning_rate
            };

            let old_parameters = parameters.clone();
            let param_updates: Vec<T> = parameters
                .iter()
                .zip(natural_gradient.iter())
                .map(|(p, ng)| *p - step_size * *ng)
                .collect();

            parameters = param_updates;

            let param_change = self.compute_parameter_change(&old_parameters, &parameters);
            if param_change < self.config.parameter_tolerance {
                return Ok(NaturalGradientResult {
                    parameters: best_parameters,
                    objective_value: best_objective,
                    gradient_norm,
                    iterations: iteration + 1,
                    converged: true,
                    trajectory,
                });
            }

            let current_objective = objective.evaluate(&parameters);
            if current_objective < best_objective {
                best_parameters = parameters.clone();
                best_objective = current_objective;
            }

            if let Some(ref mut traj) = trajectory {
                traj.push(parameters.clone());
            }
        }

        let _final_gradient = objective.gradient(&best_parameters);
        let _final_gradient_norm = self.compute_norm(&_final_gradient);

        Err(OptimizationError::ConvergenceFailure {
            iterations: self.config.max_iterations,
        })
    }
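    /// Solve F * x = g for the natural gradient direction, after regularizing
    /// the diagonal of F with `fisher_regularization`.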
    fn solve_fisher_system(&self, fisher: &[Vec<T>], gradient: &[T]) -> OptimizationResult<Vec<T>> {
        let n = fisher.len();
        if n == 0 || gradient.len() != n {
            return Err(OptimizationError::InvalidProblem {
                message: "Fisher matrix and gradient dimension mismatch".to_string(),
            });
        }

        // Tikhonov-style regularization: adding a small constant to the diagonal
        // guards against a singular or ill-conditioned Fisher matrix.
        let mut regularized_fisher = fisher.to_vec();
        for (i, row) in regularized_fisher.iter_mut().enumerate().take(n) {
            row[i] = row[i] + self.config.fisher_regularization;
        }

        self.solve_linear_system(&regularized_fisher, gradient)
    }
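    /// Solve a dense linear system with Gaussian elimination (LU factorization
    /// with partial pivoting).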
    fn solve_linear_system(&self, matrix: &[Vec<T>], rhs: &[T]) -> OptimizationResult<Vec<T>> {
        let n = matrix.len();
        let mut a = matrix.to_vec();
        let b = rhs.to_vec();

        let mut pivot: Vec<usize> = (0..n).collect();

        // LU factorization with partial pivoting; the multipliers are stored in
        // the lower triangle of `a`.
        for k in 0..n - 1 {
            let mut max_idx = k;
            for i in k + 1..n {
                if a[i][k].abs() > a[max_idx][k].abs() {
                    max_idx = i;
                }
            }

            if max_idx != k {
                a.swap(k, max_idx);
                pivot.swap(k, max_idx);
            }

            if a[k][k].abs() < T::from(1e-14).unwrap() {
                return Err(OptimizationError::NumericalError {
                    message: "Singular Fisher information matrix".to_string(),
                });
            }

            for i in k + 1..n {
                let factor = a[i][k] / a[k][k];
                #[allow(clippy::needless_range_loop)]
                for j in k + 1..n {
                    a[i][j] = a[i][j] - factor * a[k][j];
                }
                a[i][k] = factor;
            }
        }

        // Apply the recorded row permutation to the right-hand side.
        let mut perm_b = vec![T::zero(); n];
        for i in 0..n {
            perm_b[i] = b[pivot[i]];
        }

        // Forward substitution: L y = P b.
        for i in 1..n {
            for j in 0..i {
                perm_b[i] = perm_b[i] - a[i][j] * perm_b[j];
            }
        }

        // Back substitution: U x = y.
        let mut x = vec![T::zero(); n];
        for i in (0..n).rev() {
            x[i] = perm_b[i];
            for j in i + 1..n {
                x[i] = x[i] - a[i][j] * x[j];
            }
            x[i] = x[i] / a[i][i];
        }

        Ok(x)
    }
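    /// Backtracking line search: starting from `line_search_alpha`, the step is
    /// shrunk by `line_search_beta` until the objective no longer increases.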
    fn line_search(
        &self,
        objective: &impl ObjectiveWithFisher<T>,
        parameters: &[T],
        direction: &[T],
    ) -> OptimizationResult<T> {
        let mut alpha = self.config.line_search_alpha;
        let current_objective = objective.evaluate(parameters);

        for _ in 0..20 {
            let trial_params: Vec<T> = parameters
                .iter()
                .zip(direction.iter())
                .map(|(p, d)| *p - alpha * *d)
                .collect();

            let trial_objective = objective.evaluate(&trial_params);

            if trial_objective <= current_objective {
                return Ok(alpha);
            }

            alpha = alpha * self.config.line_search_beta;
        }

        // No acceptable step found within 20 backtracking steps; fall back to a
        // small fraction of the configured learning rate.
        Ok(self.config.learning_rate * T::from(0.1).unwrap())
    }

    fn compute_norm(&self, vector: &[T]) -> T {
        vector
            .iter()
            .map(|x| *x * *x)
            .fold(T::zero(), |acc, x| acc + x)
            .sqrt()
    }
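    /// Relative change in the parameter vector between iterations, used as a
    /// secondary convergence criterion.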
    fn compute_parameter_change(&self, old_params: &[T], new_params: &[T]) -> T {
        let change: T = old_params
            .iter()
            .zip(new_params.iter())
            .map(|(old, new)| (*new - *old) * (*new - *old))
            .fold(T::zero(), |acc, x| acc + x)
            .sqrt();

        let norm: T = old_params
            .iter()
            .map(|x| *x * *x)
            .fold(T::zero(), |acc, x| acc + x)
            .sqrt();

        if norm > T::zero() {
            change / norm
        } else {
            change
        }
    }
}
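/// Information geometry helpers: a finite-difference Fisher information matrix
/// for exponential families and an approximate Fisher-Rao distance.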
pub mod info_geom {
    use super::*;
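    /// Fisher information of an exponential family in natural parameters,
    /// computed as a finite-difference Hessian of the log-partition function A:
    /// I(theta)_{ij} ~ [A(theta + eps*e_i + eps*e_j) - A(theta + eps*e_i)
    ///                  - A(theta + eps*e_j) + A(theta)] / eps^2.
    ///
    /// `_sufficient_statistics` is currently unused: the identity
    /// I(theta) = Hessian of A(theta) only needs the log-partition function.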
    pub fn exponential_family_fisher<T: Float>(
        natural_parameters: &[T],
        _sufficient_statistics: &impl Fn(&[T]) -> Vec<T>,
        log_partition: &impl Fn(&[T]) -> T,
    ) -> Vec<Vec<T>> {
        let dim = natural_parameters.len();
        let eps = T::from(1e-8).unwrap();

        let mut fisher = vec![vec![T::zero(); dim]; dim];

        for i in 0..dim {
            for j in 0..dim {
                let mut params_ij = natural_parameters.to_vec();
                let mut params_i = natural_parameters.to_vec();
                let mut params_j = natural_parameters.to_vec();
                let params_base = natural_parameters.to_vec();

                params_ij[i] = params_ij[i] + eps;
                params_ij[j] = params_ij[j] + eps;

                params_i[i] = params_i[i] + eps;
                params_j[j] = params_j[j] + eps;

                let hessian_ij = (log_partition(&params_ij)
                    - log_partition(&params_i)
                    - log_partition(&params_j)
                    + log_partition(&params_base))
                    / (eps * eps);

                fisher[i][j] = hessian_ij;
            }
        }

        fisher
    }
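    /// Approximate Fisher-Rao distance between two nearby parameter vectors:
    /// the Fisher metric is evaluated at their midpoint and applied to the
    /// parameter difference.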
    pub fn statistical_distance<T: Float>(
        params1: &[T],
        params2: &[T],
        fisher_info: &impl Fn(&[T]) -> Vec<Vec<T>>,
    ) -> T {
        let midpoint: Vec<T> = params1
            .iter()
            .zip(params2.iter())
            .map(|(p1, p2)| (*p1 + *p2) / T::from(2.0).unwrap())
            .collect();

        let fisher = fisher_info(&midpoint);
        let diff: Vec<T> = params1
            .iter()
            .zip(params2.iter())
            .map(|(p1, p2)| *p1 - *p2)
            .collect();

        let mut distance_squared = T::zero();
        for i in 0..diff.len() {
            for j in 0..diff.len() {
                distance_squared = distance_squared + diff[i] * fisher[i][j] * diff[j];
            }
        }

        distance_squared.sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    struct QuadraticObjective {
        dim: usize,
    }

    impl ObjectiveWithFisher<f64> for QuadraticObjective {
        fn evaluate(&self, parameters: &[f64]) -> f64 {
            parameters.iter().map(|x| x * x).sum::<f64>() / 2.0
        }

        fn gradient(&self, parameters: &[f64]) -> Vec<f64> {
            parameters.to_vec()
        }

        fn fisher_information(&self, _parameters: &[f64]) -> Vec<Vec<f64>> {
            let mut fisher = vec![vec![0.0; self.dim]; self.dim];
            for (i, row) in fisher.iter_mut().enumerate().take(self.dim) {
                row[i] = 1.0;
            }
            fisher
        }
    }

    #[test]
    fn test_natural_gradient_quadratic() {
        let objective = QuadraticObjective { dim: 2 };
        let config = NaturalGradientConfig {
            learning_rate: 0.5,
            max_iterations: 100,
            gradient_tolerance: 1e-4,
            parameter_tolerance: 1e-6,
            fisher_regularization: 1e-6,
            use_line_search: false,
            line_search_beta: 0.5,
            line_search_alpha: 1.0,
        };

        let optimizer = NaturalGradientOptimizer::new(config);
        let initial_params = vec![0.5, 0.5];
        let result = optimizer
            .optimize_with_fisher(&objective, initial_params)
            .unwrap();

        assert!(result.converged);
        assert_relative_eq!(result.parameters[0], 0.0, epsilon = 1e-3);
        assert_relative_eq!(result.parameters[1], 0.0, epsilon = 1e-3);
        assert!(result.objective_value < 1e-4);
    }

    #[test]
    fn test_fisher_system_solve() {
        let optimizer = NaturalGradientOptimizer::<f64>::with_default_config();

        let fisher = vec![vec![2.0, 1.0], vec![1.0, 2.0]];
        let gradient = vec![3.0, 4.0];

        let solution = optimizer.solve_fisher_system(&fisher, &gradient).unwrap();

        assert_relative_eq!(solution[0], 2.0 / 3.0, epsilon = 1e-6);
        assert_relative_eq!(solution[1], 5.0 / 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_exponential_family_fisher() {
        use crate::natural_gradient::info_geom::exponential_family_fisher;

        let natural_params = vec![1.0, 2.0];

        let sufficient_stats = |_params: &[f64]| vec![1.0, 1.0];
        let log_partition = |_params: &[f64]| 1.0;
        let fisher = exponential_family_fisher(&natural_params, &sufficient_stats, &log_partition);

        assert_eq!(fisher.len(), 2, "Fisher matrix should be 2x2");
        assert_eq!(
            fisher[0].len(),
            2,
            "Fisher matrix rows should have length 2"
        );
        assert_eq!(
            fisher[1].len(),
            2,
            "Fisher matrix rows should have length 2"
        );

        assert!(fisher[0][0].abs() < 1e-6);
        assert!(fisher[1][1].abs() < 1e-6);
    }
}