use super::types::{StefanConfig, StefanResult};
use crate::error::{IntegrateError, IntegrateResult};
/// Explicit finite-difference solver for the one-dimensional, one-phase Stefan (melting)
/// problem, based on the enthalpy method.
///
/// The solver is stateless; every parameter comes from the [`StefanConfig`] passed to
/// [`StefanSolver::solve`].
#[derive(Debug, Clone, Copy, Default)]
pub struct StefanSolver;

impl StefanSolver {
    /// Creates a new (stateless) solver.
    pub fn new() -> Self {
        Self
    }

    /// Integrates the melting problem from `t = 0` to exactly `config.max_time`.
    ///
    /// The domain `[0, l_max]` is discretised with `nx` equally spaced nodes. Initially every
    /// node is at `melting_temp` except node 0, which is held at `wall_temp` for all time. Each
    /// step advances the enthalpy `H` with an explicit update driven by the discrete temperature
    /// Laplacian, then recovers the temperature from `H` (`H = T` at or below the melting point,
    /// `H = T + St` above it). The last node keeps its enthalpy unchanged.
    ///
    /// Snapshots (time, interface position, temperature field) are recorded at `t = 0`, after
    /// every `output_every`-th step, and after the final step. If `max_time` is not a whole
    /// number of steps of size `dt`, the final step is shortened so that the last snapshot is
    /// taken exactly at `max_time`.
    ///
    /// # Errors
    ///
    /// - Any error returned by `config.validate()`.
    /// - [`IntegrateError::InvalidInput`] if `nx < 2`, `output_every == 0`, `dt` is not a
    ///   finite positive number, or `max_time` is not finite.
    /// - [`IntegrateError::ConvergenceError`] if the explicit stability limit
    ///   `r = diffusivity * dt / dx^2 <= 0.5` is violated.
    pub fn solve(&self, config: &StefanConfig) -> IntegrateResult<StefanResult> {
        config.validate()?;
        let nx = config.nx;
        let dt = config.dt;
        let alpha = config.diffusivity;
        let st = config.stefan_number;
        let t_m = config.melting_temp;
        let t_wall = config.wall_temp;
        let l_max = config.l_max;
        let t_max = config.max_time;
        let output_every = config.output_every;

        // `validate()` lives elsewhere; re-check the invariants this routine indexes, divides and
        // loops on, so a permissive config yields an error instead of a panic or a hang.
        if nx < 2 {
            return Err(IntegrateError::InvalidInput(format!(
                "nx must be at least 2, got {nx}"
            )));
        }
        if output_every == 0 {
            return Err(IntegrateError::InvalidInput(
                "output_every must be at least 1".to_string(),
            ));
        }
        if !dt.is_finite() || dt <= 0.0 {
            return Err(IntegrateError::InvalidInput(format!(
                "dt must be positive and finite, got {dt}"
            )));
        }
        if !t_max.is_finite() {
            return Err(IntegrateError::InvalidInput(format!(
                "max_time must be finite, got {t_max}"
            )));
        }

        let dx = l_max / (nx - 1) as f64;
        let dx2 = dx * dx;
        let r = alpha * dt / dx2;
        if r > 0.5 {
            return Err(IntegrateError::ConvergenceError(format!(
                "CFL condition violated: r = {r:.4} > 0.5. Reduce dt or increase nx."
            )));
        }

        let grid: Vec<f64> = (0..nx).map(|i| i as f64 * dx).collect();

        // Enthalpy-temperature relation: H = T up to the melting point, H = T + St above it.
        let enthalpy_of = |temp: f64| -> f64 {
            if temp > t_m {
                temp + st
            } else {
                temp
            }
        };
        let h_wall = enthalpy_of(t_wall);

        // Initial state: everything at the melting temperature except the heated wall.
        let mut u = vec![t_m; nx];
        u[0] = t_wall;
        let mut h: Vec<f64> = u.iter().map(|&temp| enthalpy_of(temp)).collect();
        // Scratch buffer for the next enthalpy field; swapped with `h` after every step.
        let mut h_next = h.clone();

        let mut times = Vec::new();
        let mut interface_positions = Vec::new();
        let mut temperature_fields = Vec::new();

        times.push(0.0);
        interface_positions.push(interface_position(&u, t_m, &grid));
        temperature_fields.push(u.clone());

        let n_steps = Self::step_count(t_max, dt);
        for step in 1..=n_steps {
            let is_last = step == n_steps;
            // The last step is shortened (never lengthened, so the stability bound still holds)
            // when max_time is not a whole number of steps; the run then ends exactly on max_time.
            let dt_step = if is_last {
                (t_max - (n_steps - 1) as f64 * dt).min(dt)
            } else {
                dt
            };
            let r_step = alpha * dt_step / dx2;

            // Interior nodes: explicit update of H from the temperature Laplacian. Window `w`
            // is u[i-1..=i+1] for interior node i.
            for ((h_new, &h_old), w) in h_next[1..nx - 1]
                .iter_mut()
                .zip(&h[1..nx - 1])
                .zip(u.windows(3))
            {
                *h_new = h_old + r_step * (w[2] - 2.0 * w[1] + w[0]);
            }
            // Wall node is pinned to the wall temperature; the far node keeps its enthalpy.
            h_next[0] = h_wall;
            h_next[nx - 1] = h[nx - 1];
            std::mem::swap(&mut h, &mut h_next);

            for (temp, &enthalpy) in u.iter_mut().zip(&h) {
                *temp = temp_from_enthalpy(enthalpy, t_m, st);
            }
            u[0] = t_wall;

            if step % output_every == 0 || is_last {
                let t_current = if is_last { t_max } else { step as f64 * dt };
                times.push(t_current);
                interface_positions.push(interface_position(&u, t_m, &grid));
                temperature_fields.push(u.clone());
            }
        }

        Ok(StefanResult {
            times,
            interface_positions,
            temperature_fields,
            grid,
        })
    }

    /// Number of steps of nominal size `dt` needed to reach `t_max`.
    ///
    /// A quotient `t_max / dt` within a few ulps of an integer is treated as that integer, so
    /// floating-point roundoff (e.g. `1.1 / 0.1` evaluates to `11.000000000000002`) does not
    /// add a spurious extra step. Returns 0 when `t_max / dt` is not positive.
    fn step_count(t_max: f64, dt: f64) -> usize {
        let ratio = t_max / dt;
        if ratio.is_nan() || ratio <= 0.0 {
            return 0;
        }
        let nearest = ratio.round();
        if (ratio - nearest).abs() <= 64.0 * f64::EPSILON * ratio {
            nearest as usize
        } else {
            ratio.ceil() as usize
        }
    }
}
/// Inverts the enthalpy-temperature relation `H = T` (solid), `H = T + St` (liquid).
///
/// Enthalpies below `t_m` are solid and map to themselves. Enthalpies from `t_m` up to and
/// including `t_m + st` belong to the phase-change (mushy) zone and map to `t_m`. Anything
/// larger is liquid and has the latent heat `st` subtracted.
#[inline]
fn temp_from_enthalpy(h: f64, t_m: f64, st: f64) -> f64 {
    match h {
        solid if solid < t_m => solid,
        mushy if mushy <= t_m + st => t_m,
        liquid => liquid - st,
    }
}
/// Locates the melt front by scanning the temperature profile away from the wall.
///
/// Returns the first downward crossing of `t_m`, i.e. the first neighbouring pair with
/// `u[i] > t_m >= u[i + 1]`, linearly interpolated between `grid[i]` and `grid[i + 1]`.
/// If there is no such crossing, returns `grid[u.len() - 1]` when the last node is above `t_m`
/// (fully melted profile). Otherwise, including for empty input, it returns `0.0`.
///
/// `grid` must be at least as long as `u`; node `i` of `u` sits at `grid[i]`.
fn interface_position(u: &[f64], t_m: f64, grid: &[f64]) -> f64 {
    let crossing = u
        .windows(2)
        .zip(grid.windows(2))
        .find(|(temps, _)| temps[0] > t_m && temps[1] <= t_m);
    if let Some((temps, xs)) = crossing {
        // temps[0] > t_m >= temps[1], so the denominator is strictly positive.
        let frac = (temps[0] - t_m) / (temps[0] - temps[1]);
        return xs[0] + frac * (xs[1] - xs[0]);
    }
    match u.last() {
        Some(&u_end) if u_end > t_m => grid[u.len() - 1],
        _ => 0.0,
    }
}
/// Returns the exact (Neumann) melt-front position `s(t)` of the one-phase Stefan problem.
///
/// The front advances as `s(t) = 2 * lambda * sqrt(alpha * t)`. Here `lambda` solves
/// `lambda * exp(lambda^2) * erf(lambda) = St / sqrt(pi)` and is computed by
/// [`find_stefan_lambda`]. The returned closure yields `0.0` for `t <= 0`.
///
/// # Errors
///
/// - [`IntegrateError::InvalidInput`] if `st` or `alpha` is not a finite, strictly positive
///   number.
/// - Any error returned by [`find_stefan_lambda`].
pub fn analytical_stefan_interface(st: f64, alpha: f64) -> IntegrateResult<impl Fn(f64) -> f64> {
    // The `is_finite` checks also reject NaN and infinity, which a bare `<= 0.0` test lets through.
    if !st.is_finite() || st <= 0.0 {
        return Err(IntegrateError::InvalidInput(
            "Stefan number must be positive and finite".to_string(),
        ));
    }
    if !alpha.is_finite() || alpha <= 0.0 {
        return Err(IntegrateError::InvalidInput(
            "diffusivity must be positive and finite".to_string(),
        ));
    }
    let lambda = find_stefan_lambda(st)?;
    Ok(move |t: f64| {
        if t <= 0.0 {
            0.0
        } else {
            2.0 * lambda * (alpha * t).sqrt()
        }
    })
}
/// Solves the Neumann transcendental equation of the one-phase Stefan problem.
///
/// Finds the root `lambda > 0` of `lambda * exp(lambda^2) * erf(lambda) = St / sqrt(pi)` by
/// bisection. The left-hand side is `0` at `lambda = 0` and strictly increasing for
/// `lambda >= 0`, so the root is unique and `[0, hi]` brackets it once the left-hand side at
/// `hi` reaches the target.
///
/// # Errors
///
/// - [`IntegrateError::InvalidInput`] if `st` is not a finite, strictly positive number.
/// - [`IntegrateError::ConvergenceError`] if no upper bracket is found.
pub fn find_stefan_lambda(st: f64) -> IntegrateResult<f64> {
    if !st.is_finite() || st <= 0.0 {
        return Err(IntegrateError::InvalidInput(
            "Stefan number must be positive and finite".to_string(),
        ));
    }
    let target = st / std::f64::consts::PI.sqrt();
    let f = |lam: f64| -> f64 { lam * (lam * lam).exp() * erf_approx(lam) - target };

    // f(0) = -target < 0. Double the upper bound until f changes sign. For any finite St the
    // root lies below ~27, because beyond that exp(lambda^2) overflows to +inf and f = +inf.
    let mut hi = 1.0_f64;
    while f(hi) < 0.0 {
        hi *= 2.0;
        if hi > 1024.0 {
            return Err(IntegrateError::ConvergenceError(
                "Could not bracket Stefan root: Stefan number too large?".to_string(),
            ));
        }
    }

    // Bisect to relative machine precision. Starting from lo = 0 keeps the tiny roots of small
    // St resolvable; a positive lower bound or an absolute tolerance would stop short of them.
    let mut lo = 0.0_f64;
    for _ in 0..2_000 {
        let mid = 0.5 * (lo + hi);
        if mid <= lo || mid >= hi {
            // Interval can no longer be split in f64.
            break;
        }
        if f(mid) < 0.0 {
            lo = mid;
        } else {
            hi = mid;
        }
        if hi - lo <= 4.0 * f64::EPSILON * hi {
            break;
        }
    }
    Ok(0.5 * (lo + hi))
}
/// Error function `erf(x)`.
///
/// Despite the `_approx` suffix, this does not implement its own approximation: it delegates
/// directly to `libm::erf`.
pub fn erf_approx(x: f64) -> f64 {
    libm::erf(x)
}
#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() <= tol
    }

    #[test]
    fn test_stefan_interface_monotone() {
        let cfg = StefanConfig {
            nx: 80,
            dt: 5e-5,
            stefan_number: 1.0,
            diffusivity: 1.0,
            melting_temp: 0.0,
            wall_temp: 1.0,
            l_max: 4.0,
            max_time: 0.2,
            output_every: 50,
        };
        let solver = StefanSolver::new();
        let result = solver.solve(&cfg).expect("solve failed");
        for i in 1..result.interface_positions.len() {
            assert!(
                result.interface_positions[i] >= result.interface_positions[i - 1] - 1e-10,
                "interface not monotone at step {i}: {} < {}",
                result.interface_positions[i],
                result.interface_positions[i - 1]
            );
        }
    }

    #[test]
    fn test_stefan_result_shape() {
        let cfg = StefanConfig::default();
        let solver = StefanSolver::new();
        let result = solver.solve(&cfg).expect("solve failed");
        assert_eq!(result.times.len(), result.interface_positions.len());
        assert_eq!(result.times.len(), result.temperature_fields.len());
        assert_eq!(result.grid.len(), cfg.nx);
        for field in &result.temperature_fields {
            assert_eq!(field.len(), cfg.nx);
        }
    }

    #[test]
    fn test_stefan_wall_temp() {
        let cfg = StefanConfig {
            nx: 50,
            dt: 2e-5,
            max_time: 0.05,
            output_every: 20,
            ..Default::default()
        };
        let solver = StefanSolver::new();
        let result = solver.solve(&cfg).expect("solve failed");
        for field in &result.temperature_fields {
            assert!(
                approx_eq(field[0], cfg.wall_temp, 1e-12),
                "wall temp changed: {}",
                field[0]
            );
        }
    }

    #[test]
    fn test_stefan_config_default() {
        let cfg = StefanConfig::default();
        assert!(cfg.nx > 0);
        assert!(cfg.dt > 0.0);
        assert!(cfg.wall_temp > cfg.melting_temp);
        cfg.validate().expect("default config should be valid");
    }

    #[test]
    fn test_stefan_analytical_small_t() {
        let st = 1.0;
        let alpha = 1.0;
        let s_fn = analytical_stefan_interface(st, alpha).expect("analytical failed");
        let lambda = find_stefan_lambda(st).expect("lambda failed");
        for &t in &[0.01, 0.05, 0.1, 0.2] {
            let s_analytical = 2.0 * lambda * (alpha * t).sqrt();
            let s_fn_val = s_fn(t);
            assert!(
                approx_eq(s_analytical, s_fn_val, 1e-12),
                "t={t}: analytical={s_analytical}, fn={s_fn_val}"
            );
        }
    }

    /// Checks lambda against the transcendental equation itself (not just its own formula),
    /// plus the reference value lambda(St = 1) ~ 0.6201.
    #[test]
    fn test_find_stefan_lambda_satisfies_transcendental_equation() {
        for &st in &[0.1, 1.0, 5.0] {
            let lambda = find_stefan_lambda(st).expect("lambda failed");
            let lhs = lambda * (lambda * lambda).exp() * erf_approx(lambda);
            let rhs = st / std::f64::consts::PI.sqrt();
            assert!(
                approx_eq(lhs, rhs, 1e-10 * rhs.max(1.0)),
                "st={st}: lhs={lhs}, rhs={rhs}"
            );
        }
        let lambda_1 = find_stefan_lambda(1.0).expect("lambda failed");
        assert!(approx_eq(lambda_1, 0.6201, 1e-3), "lambda(1) = {lambda_1}");
    }

    #[test]
    fn test_analytical_rejects_invalid_parameters() {
        assert!(analytical_stefan_interface(0.0, 1.0).is_err());
        assert!(analytical_stefan_interface(-1.0, 1.0).is_err());
        assert!(analytical_stefan_interface(1.0, 0.0).is_err());
        assert!(analytical_stefan_interface(1.0, -2.0).is_err());
    }

    #[test]
    fn test_stefan_numerical_vs_analytical() {
        let st = 1.0;
        let alpha = 1.0;
        let t_final = 0.1;
        let cfg = StefanConfig {
            nx: 200,
            dt: 1e-5,
            stefan_number: st,
            diffusivity: alpha,
            melting_temp: 0.0,
            wall_temp: 1.0,
            l_max: 3.0,
            max_time: t_final,
            output_every: 1000,
        };
        let solver = StefanSolver::new();
        let result = solver.solve(&cfg).expect("solve failed");
        let s_fn = analytical_stefan_interface(st, alpha).expect("analytical failed");
        let s_exact = s_fn(t_final);
        let s_num = *result.interface_positions.last().expect("no output");
        let rel_err = (s_num - s_exact).abs() / s_exact;
        assert!(
            rel_err < 0.15,
            "numerical vs analytical: s_num={s_num:.4}, s_exact={s_exact:.4}, rel_err={rel_err:.3}"
        );
    }
}