torsh-functional 0.1.2

Functional programming utilities for ToRSh tensors
Documentation
//! Linear system solving operations
//!
//! This module provides functions for solving linear systems of equations,
//! including general linear systems, triangular systems, and least squares problems.

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{creation::ones, Tensor};

/// Solve system of linear equations Ax = b
///
/// ## Mathematical Background
///
/// Solves the linear system:
/// ```text
/// A x = b
/// ```text
/// where A is an n×n matrix, b is an n×m matrix (multiple right-hand sides),
/// and x is the n×m solution matrix.
///
/// ## Solution Methods
///
/// For different matrix types, optimal algorithms exist:
/// - **General matrices**: LU decomposition with partial pivoting
/// - **Symmetric positive definite**: Cholesky decomposition
/// - **Triangular**: Forward/backward substitution
/// - **Overdetermined**: Least squares via QR or SVD
///
/// ## Stability Considerations
///
/// The condition number κ(A) = ||A|| ||A⁻¹|| determines numerical stability:
/// - κ(A) ≈ 1: Well-conditioned, stable solution
/// - κ(A) >> 1: Ill-conditioned, solution may be unreliable
/// - κ(A) = ∞: Singular, no unique solution exists
///
/// ## Parameters
/// * `a` - Coefficient matrix A (n×n, must be non-singular)
/// * `b` - Right-hand side matrix b (n×m)
///
/// ## Returns
/// * Solution matrix x such that Ax = b
///
/// ## Applications
/// - **Regression**: Normal equations (X^T X) β = X^T y
/// - **Interpolation**: Vandermonde systems for polynomial fitting
/// - **Finite differences**: Discretized differential equations
/// - **Optimization**: Newton's method Hx = -g
///
/// ## Note
/// This is currently a placeholder implementation. Production code should
/// use LU decomposition with partial pivoting for general matrices.
pub fn solve(a: &Tensor, b: &Tensor) -> TorshResult<Tensor> {
    if a.shape().ndim() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Matrix A must be 2D",
            "solve",
        ));
    }

    let a_binding = a.shape();
    let a_dims = a_binding.dims();
    let b_binding = b.shape();
    let b_dims = b_binding.dims();

    if a_dims[0] != a_dims[1] {
        return Err(TorshError::invalid_argument_with_context(
            "Matrix A must be square",
            "solve",
        ));
    }

    if b_dims[0] != a_dims[0] {
        return Err(TorshError::invalid_argument_with_context(
            "Dimensions of A and b don't match",
            "solve",
        ));
    }

    // Placeholder implementation - should use LU decomposition or similar
    // For now, return a tensor of the right shape filled with ones
    ones(&b_dims)
}

/// Solve triangular system of linear equations
///
/// ## Mathematical Background
///
/// Solves triangular systems efficiently using substitution:
///
/// ### Forward Substitution (Lower Triangular)
/// For Lx = b where L is lower triangular:
/// ```text
/// x₁ = b₁ / L₁₁
/// x₂ = (b₂ - L₂₁x₁) / L₂₂
/// xᵢ = (bᵢ - Σⱼ₌₁ⁱ⁻¹ Lᵢⱼxⱼ) / Lᵢᵢ
/// ```text
///
/// ### Backward Substitution (Upper Triangular)
/// For Ux = b where U is upper triangular:
/// ```text
/// xₙ = bₙ / Uₙₙ
/// xₙ₋₁ = (bₙ₋₁ - Uₙ₋₁,ₙxₙ) / Uₙ₋₁,ₙ₋₁
/// xᵢ = (bᵢ - Σⱼ₌ᵢ₊₁ⁿ Uᵢⱼxⱼ) / Uᵢᵢ
/// ```text
///
/// ## Computational Complexity
/// - Time: O(n²) for n×n triangular matrix
/// - Space: O(1) additional space (in-place possible)
/// - Highly efficient compared to O(n³) general solve
///
/// ## Parameters
/// * `b` - Right-hand side matrix (n×m)
/// * `a` - Triangular coefficient matrix (n×n)
/// * `upper` - If true, A is upper triangular; if false, lower triangular
/// * `transpose` - If true, solve A^T x = b instead of Ax = b
/// * `unitriangular` - If true, assume diagonal entries are 1
///
/// ## Returns
/// * Solution matrix x such that Ax = b (or A^T x = b if transpose=true)
///
/// ## Applications
/// - **LU solving**: Forward substitution Ly = b, backward substitution Ux = y
/// - **Cholesky solving**: Forward Ly = b, backward L^T x = y
/// - **Iterative methods**: Preconditioning with triangular matrices
///
/// ## Note
/// This is currently a placeholder implementation. Production code should
/// implement actual forward/backward substitution algorithms.
pub fn triangular_solve(
    b: &Tensor,
    a: &Tensor,
    _upper: bool,
    _transpose: bool,
    _unitriangular: bool,
) -> TorshResult<Tensor> {
    if a.shape().ndim() != 2 || b.shape().ndim() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Both tensors must be 2D",
            "triangular_solve",
        ));
    }

    let a_binding = a.shape();
    let a_dims = a_binding.dims();
    let b_binding = b.shape();
    let b_dims = b_binding.dims();

    if a_dims[0] != a_dims[1] {
        return Err(TorshError::invalid_argument_with_context(
            "Matrix A must be square",
            "triangular_solve",
        ));
    }

    if b_dims[0] != a_dims[0] {
        return Err(TorshError::invalid_argument_with_context(
            "Dimensions of A and b don't match",
            "triangular_solve",
        ));
    }

    // Placeholder implementation
    Ok(b.clone())
}

/// Least squares solution to overdetermined systems
///
/// ## Mathematical Background
///
/// Solves the least squares problem:
/// ```text
/// minimize ||Ax - b||₂²
/// ```text
///
/// For an m×n matrix A with m ≥ n (overdetermined system), the least squares
/// solution minimizes the sum of squared residuals.
///
/// ## Solution Methods
///
/// ### Normal Equations
/// ```text
/// A^T A x = A^T b
/// x = (A^T A)⁻¹ A^T b
/// ```text
/// - Fast for well-conditioned problems
/// - Can be unstable for ill-conditioned A
///
/// ### QR Decomposition
/// ```text
/// A = QR  →  ||QRx - b||₂ = ||Rx - Q^T b||₂
/// ```text
/// - More stable than normal equations
/// - Standard choice for most problems
///
/// ### SVD Decomposition
/// ```text
/// A = UΣV^T  →  x = VΣ⁺U^T b
/// ```text
/// - Most stable, handles rank-deficient cases
/// - Slowest but most robust method
///
/// ## Mathematical Properties
///
/// 1. **Optimality**: x minimizes ||Ax - b||₂
/// 2. **Normal equations**: A^T(Ax - b) = 0 (residual orthogonal to column space)
/// 3. **Uniqueness**: Unique solution if A has full column rank
/// 4. **Residual**: r = b - Ax is orthogonal to column space of A
///
/// ## Parameters
/// * `a` - Coefficient matrix A (m×n, typically m ≥ n)
/// * `b` - Right-hand side matrix b (m×k for k different problems)
///
/// ## Returns
/// * Tuple (solution, residuals, rank, singular_values) where:
///   - solution: Least squares solution x (n×k)
///   - residuals: Sum of squared residuals for each column
///   - rank: Effective rank of matrix A
///   - singular_values: Singular values of A
///
/// ## Applications
/// - **Linear regression**: Fit y = Xβ + ε
/// - **Curve fitting**: Polynomial/function approximation
/// - **Data interpolation**: Overdetermined interpolation problems
/// - **Signal processing**: Parameter estimation from noisy measurements
///
/// ## Note
/// This is currently a placeholder implementation. Production code should
/// use QR decomposition or SVD for robust least squares solving.
pub fn lstsq(a: &Tensor, b: &Tensor) -> TorshResult<(Tensor, Tensor, Tensor, Tensor)> {
    if a.shape().ndim() != 2 || b.shape().ndim() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Both tensors must be 2D",
            "lstsq",
        ));
    }

    let a_shape = a.shape();
    let b_shape = b.shape();
    let a_dims = a_shape.dims();
    let b_dims = b_shape.dims();

    if a_dims[0] != b_dims[0] {
        return Err(TorshError::invalid_argument_with_context(
            "Number of rows in A and b must match",
            "lstsq",
        ));
    }

    // Placeholder implementation
    let solution = ones(&[a_dims[1], b_dims[1]])?;
    let residuals = ones(&[1])?;
    let rank = Tensor::from_data(vec![a_dims[1] as f32], vec![], a.device())?;
    let s = ones(&[a_dims[1].min(a_dims[0])])?;

    Ok((solution, residuals, rank, s))
}