solvr 0.2.0

Advanced computing library for real-world problem solving - optimization, differential equations, interpolation, statistics, and more
Documentation
//! 2D heat equation solver: du/dt = alpha * nabla^2 u + source.
//!
//! Method of lines: spatial FDM discretization -> ODE system -> time integration.
use crate::DType;

use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::sparse::{SparseOps, SparseTensor};
use numr::tensor::Tensor;

use crate::integrate::{IntegrationAlgorithms, ODEOptions, ODEResultTensor};
use crate::pde::error::{PdeError, PdeResult};
use crate::pde::types::{BoundarySpec, FdmOptions, Grid2D, TimeDependentOptions, TimeResult};

use super::ode_extract::extract_ode_solutions;
use super::stencil::assemble_laplacian_2d;

/// Solve the 2D heat equation using method of lines.
///
/// Discretizes space with 5-point FDM stencil, then integrates the resulting
/// ODE system du/dt = alpha * L * u + source using RK45.
#[allow(clippy::too_many_arguments)]
pub fn heat_2d_impl<R, C>(
    client: &C,
    u0: &Tensor<R>,
    alpha: f64,
    source: Option<&Tensor<R>>,
    grid: &Grid2D,
    _boundary: &[BoundarySpec<R>],
    time_opts: &TimeDependentOptions,
    _options: &FdmOptions,
) -> PdeResult<TimeResult<R>>
where
    R: Runtime<DType = DType>,
    R::Client: SparseOps<R>,
    C: TensorOps<R> + ScalarOps<R> + IntegrationAlgorithms<R> + RuntimeClient<R>,
{
    let nx = grid.nx;
    let ny = grid.ny;
    let n = nx * ny;
    let device = client.device();

    if nx < 3 || ny < 3 {
        return Err(PdeError::InvalidGrid {
            context: format!("Grid must be at least 3x3, got {}x{}", nx, ny),
        });
    }

    let laplacian = assemble_laplacian_2d::<R>(grid, device).map_err(PdeError::from)?;
    let sparse_l = SparseTensor::Csr(laplacian);

    let u0_flat = u0.reshape(&[n])?;

    let source_flat = match source {
        Some(s) => Some(s.reshape(&[n])?),
        None => None,
    };

    let rhs = |_t: &Tensor<R>, u: &Tensor<R>| -> Result<Tensor<R>> {
        let lu = sparse_l.spmv(u)?;
        let alpha_lu = client.mul_scalar(&lu, alpha)?;
        match &source_flat {
            Some(s) => client.add(&alpha_lu, s),
            None => Ok(alpha_lu),
        }
    };

    let ode_opts = ODEOptions {
        rtol: 1e-6,
        atol: 1e-9,
        max_steps: 50000,
        ..Default::default()
    };

    let ode_result: ODEResultTensor<R> = client
        .solve_ivp(rhs, time_opts.t_span, &u0_flat, &ode_opts)
        .map_err(PdeError::from)?;

    if !ode_result.success {
        return Err(PdeError::DidNotConverge {
            iterations: ode_result.nfev,
            tolerance: ode_opts.rtol,
            context: ode_result
                .message
                .unwrap_or_else(|| "Heat equation ODE integration failed".to_string()),
        });
    }

    Ok(extract_ode_solutions(
        ode_result.t,
        &ode_result.y,
        n,
        n,
        &[nx, ny],
        device,
    ))
}