trustformers-optim 0.1.1

Optimizers for TrustformeRS
Documentation
//! # Lookahead Optimizer
//!
//! This module implements the Lookahead meta-optimizer that can wrap any base optimizer
//! to improve convergence and reduce variance in the optimization process.
//!
//! ## Algorithm
//!
//! Lookahead maintains two sets of weights:
//! - Fast weights (φ): Updated by the base optimizer
//! - Slow weights (θ): Updated every k steps using interpolation
//!
//! The update rule:
//! ```text
//! φ_t = base_optimizer_update(φ_{t-1}, g_t)  // Fast weight update
//!
//! Every k steps:
//! θ_t = θ_{t-k} + α(φ_t - θ_{t-k})          // Slow weight update
//! φ_t = θ_t                                   // Reset fast weights
//! ```
//!
//! ## Benefits
//!
//! - Reduces variance in optimization trajectories
//! - Improves convergence in noisy settings
//! - Can be combined with any base optimizer
//! - Often leads to better generalization
//!
//! ## Usage Example
//!
//! ```rust,no_run
//! use trustformers_optim::{Lookahead, AdamW};
//! use trustformers_core::traits::Optimizer;
//!
//! // Create base optimizer
//! let base_optimizer = AdamW::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
//!
//! // Wrap with Lookahead
//! let mut lookahead = Lookahead::new(
//!     base_optimizer,
//!     5,        // Update slow weights every 5 steps
//!     0.5,      // Interpolation factor
//! );
//!
//! // Lookahead automatically handles fast/slow weight updates
//! // Use it just like any other optimizer with .zero_grad(), .update(), and .step()
//! ```
//!
//! ## Hyperparameter Guidelines
//!
//! ### k (Update Frequency)
//! - k=5 to k=10 is typical for most problems
//! - Smaller k: More frequent slow updates, more stable but slower
//! - Larger k: Less frequent updates, faster but potentially less stable
//!
//! ### α (Interpolation Factor)
//! - α=0.5 is the recommended default
//! - α=0.8 for more aggressive slow weight updates
//! - α=0.2 for more conservative updates
//!
//! ## Implementation Notes
//!
//! - Stores slow weights alongside the base optimizer
//! - Minimal memory overhead (2x parameter storage)
//! - Works with any base optimizer that implements the Optimizer trait
//! - Thread-safe for data parallel training

use crate::common::OptimizerState;
use std::collections::HashMap;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

/// Lookahead meta-optimizer that wraps any base optimizer.
///
/// Implements the Lookahead algorithm from "Lookahead Optimizer: k steps forward,
/// 1 step back" by Zhang et al. (2019). Lookahead maintains slow weights that are
/// updated every k steps using interpolation with the fast weights.
#[derive(Debug)]
pub struct Lookahead<T: Optimizer> {
    /// Base optimizer for fast weight updates
    base_optimizer: T,
    /// Number of fast weight update steps before slow weight update
    k: usize,
    /// Interpolation factor for slow weight updates (α in the paper)
    alpha: f32,
    /// Optimizer state tracking steps
    state: OptimizerState,
    /// Slow weights (θ in the paper)
    slow_weights: HashMap<String, Vec<f32>>,
    /// Counter for fast weight updates since last slow weight update
    fast_step_count: usize,
}

impl<T: Optimizer> Lookahead<T> {
    /// Creates a new Lookahead meta-optimizer.
    ///
    /// # Arguments
    ///
    /// * `base_optimizer` - The base optimizer to wrap
    /// * `k` - Number of fast steps before slow weight update (typical: 5-10)
    /// * `alpha` - Interpolation factor for slow weights (typical: 0.5)
    ///
    /// # Example
    ///
    /// ```
    /// use trustformers_optim::{Lookahead, AdamW};
    ///
    /// let base = AdamW::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
    /// let optimizer = Lookahead::new(base, 5, 0.5);
    /// ```
    pub fn new(base_optimizer: T, k: usize, alpha: f32) -> Self {
        assert!(k > 0, "k must be positive");
        assert!(alpha > 0.0 && alpha <= 1.0, "alpha must be in (0, 1]");

        Self {
            base_optimizer,
            k,
            alpha,
            state: OptimizerState::new(),
            slow_weights: HashMap::new(),
            fast_step_count: 0,
        }
    }

    /// Get a reference to the base optimizer.
    pub fn base_optimizer(&self) -> &T {
        &self.base_optimizer
    }

    /// Get a mutable reference to the base optimizer.
    pub fn base_optimizer_mut(&mut self) -> &mut T {
        &mut self.base_optimizer
    }

    /// Initialize slow weights from current parameter values.
    fn init_slow_weights(&mut self, parameter: &Tensor) -> Result<()> {
        match parameter {
            Tensor::F32(param) => {
                let param_id = format!("{:p}", param.as_ptr());
                self.slow_weights
                    .entry(param_id)
                    .or_insert_with(|| param.iter().cloned().collect());
                Ok(())
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor type for Lookahead",
                "init_slow_weights",
            )),
        }
    }

    /// Update slow weights using interpolation with fast weights.
    fn update_slow_weights(&mut self, parameter: &mut Tensor) -> Result<()> {
        match parameter {
            Tensor::F32(param) => {
                let param_id = format!("{:p}", param.as_ptr());

                if let Some(slow_weights) = self.slow_weights.get_mut(&param_id) {
                    if slow_weights.len() != param.len() {
                        return Err(TrustformersError::tensor_op_error(
                            "Lookahead slow weights size mismatch",
                            "slow weights validation",
                        ));
                    }

                    // Update slow weights: θ = θ + α(φ - θ)
                    for (slow_w, fast_w) in slow_weights.iter_mut().zip(param.iter()) {
                        *slow_w += self.alpha * (*fast_w - *slow_w);
                    }

                    // Copy slow weights back to parameters (reset fast weights)
                    for (p, slow_w) in param.iter_mut().zip(slow_weights.iter()) {
                        *p = *slow_w;
                    }
                }

                Ok(())
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor type for Lookahead",
                "update_slow_weights",
            )),
        }
    }
}

impl<T: Optimizer> Optimizer for Lookahead<T> {
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // Initialize slow weights if this is the first time seeing this parameter
        self.init_slow_weights(parameter)?;

        // Apply base optimizer update (fast weight update)
        self.base_optimizer.update(parameter, grad)?;

        Ok(())
    }

    fn zero_grad(&mut self) {
        self.base_optimizer.zero_grad();
    }

    fn step(&mut self) {
        // Always step the base optimizer
        self.base_optimizer.step();
        self.fast_step_count += 1;

        // Check if it's time for slow weight update
        if self.fast_step_count >= self.k {
            // Note: We need access to all parameters here for slow weight update
            // This will be handled by the calling code that has access to all parameters
            self.fast_step_count = 0;
        }

        self.state.step += 1;
    }

    fn get_lr(&self) -> f32 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f32) {
        self.base_optimizer.set_lr(lr);
    }
}

impl<T: Optimizer> Lookahead<T> {
    /// Perform slow weight update for a specific parameter.
    /// This should be called after every k fast steps for each parameter.
    pub fn slow_step(&mut self, parameter: &mut Tensor) -> Result<()> {
        if self.fast_step_count == 0 {
            // We just completed k fast steps, time for slow weight update
            self.update_slow_weights(parameter)?;
        }
        Ok(())
    }
}

/// Convenience wrapper for Lookahead + Adam combination.
pub type LookaheadAdam = Lookahead<crate::adam::Adam>;

/// Convenience wrapper for Lookahead + AdamW combination.
pub type LookaheadAdamW = Lookahead<crate::adam::AdamW>;

/// Convenience wrapper for Lookahead + RAdam combination.
pub type LookaheadRAdam = Lookahead<crate::adam::RAdam>;

/// Convenience wrapper for Lookahead + NAdam combination.
pub type LookaheadNAdam = Lookahead<crate::adam::NAdam>;

/// Convenience wrapper for Lookahead + SGD combination.
pub type LookaheadSGD = Lookahead<crate::sgd::SGD>;

#[cfg(test)]
#[path = "lookahead_tests.rs"]
mod lookahead_tests;