use crate::DType;
use numr::runtime::Runtime;
use crate::pde::error::{PdeError, PdeResult};
use crate::pde::impl_generic::stencil::SideBc;
use crate::pde::types::{BoundaryCondition, BoundarySide, BoundarySpec};
/// Collects the `(left, right)` Dirichlet boundary values for a 1D solver.
///
/// Specs are applied in order, so a later entry overrides an earlier one; `All` sets both
/// ends at once. Each Dirichlet tensor contributes only its first element, or `0.0` when the
/// tensor is empty. Sides other than `Left`, `Right` and `All` are ignored, and an end that
/// no spec touches stays at `0.0`.
///
/// # Errors
/// Returns [`PdeError::InvalidBoundary`] at the first non-Dirichlet condition encountered.
/// The message names `solver_name`.
pub fn extract_dirichlet_1d_bcs<R: Runtime<DType = DType>>(
    boundary: &[BoundarySpec<R>],
    solver_name: &str,
) -> PdeResult<(f64, f64)> {
    boundary.iter().try_fold((0.0, 0.0), |(left, right), spec| {
        let vals = match &spec.condition {
            BoundaryCondition::Dirichlet(vals) => vals,
            _ => {
                return Err(PdeError::InvalidBoundary {
                    context: format!("Only Dirichlet BCs supported for {}", solver_name),
                })
            }
        };
        let data: Vec<f64> = vals.to_vec();
        let value = data.first().copied().unwrap_or(0.0);
        Ok(match spec.side {
            BoundarySide::Left => (value, right),
            BoundarySide::Right => (left, value),
            BoundarySide::All => (value, value),
            _ => (left, right),
        })
    })
}
/// Returns the scalar Dirichlet value given by the first boundary spec.
///
/// Only `boundary[0]` is inspected. Its side is ignored, and any later specs are never read.
/// An empty `boundary` slice, or an empty Dirichlet tensor, yields `0.0`. Otherwise the
/// tensor's first element is returned.
///
/// # Errors
/// Returns [`PdeError::InvalidBoundary`] when the first spec is not Dirichlet. The message
/// names `solver_name`.
pub fn extract_dirichlet_scalar<R: Runtime<DType = DType>>(
    boundary: &[BoundarySpec<R>],
    solver_name: &str,
) -> PdeResult<f64> {
    let first = match boundary.first() {
        Some(spec) => spec,
        None => return Ok(0.0),
    };
    match &first.condition {
        BoundaryCondition::Dirichlet(vals) => {
            let data: Vec<f64> = vals.to_vec();
            Ok(data.first().copied().unwrap_or(0.0))
        }
        _ => Err(PdeError::InvalidBoundary {
            context: format!("Only Dirichlet BCs supported for {}", solver_name),
        }),
    }
}
/// Resolves the per-side boundary kind of a 2D domain.
///
/// The result is ordered `[Left, Right, Bottom, Top]`, and every side defaults to
/// `SideBc::Dirichlet`. Specs are applied in order, so later entries override earlier ones.
/// `All` overwrites all four sides, and other sides are ignored. A Neumann tensor
/// contributes its first element as the flux value, or `0.0` when the tensor is empty.
///
/// # Errors
/// Returns [`PdeError::InvalidBoundary`] when `Periodic` is set on exactly one side of the
/// Left/Right pair or of the Bottom/Top pair.
pub fn side_conditions_2d<R: Runtime<DType = DType>>(
    boundary: &[BoundarySpec<R>],
) -> PdeResult<[SideBc; 4]> {
    let mut sides = [SideBc::Dirichlet; 4];
    for spec in boundary {
        let kind = match &spec.condition {
            BoundaryCondition::Dirichlet(_) => SideBc::Dirichlet,
            BoundaryCondition::Neumann(vals) => {
                let data: Vec<f64> = vals.to_vec();
                SideBc::Neumann(data.first().copied().unwrap_or(0.0))
            }
            BoundaryCondition::Periodic => SideBc::Periodic,
        };
        // Slots of `sides` this spec writes to.
        let targets: &[usize] = match spec.side {
            BoundarySide::Left => &[0],
            BoundarySide::Right => &[1],
            BoundarySide::Bottom => &[2],
            BoundarySide::Top => &[3],
            BoundarySide::All => &[0, 1, 2, 3],
            _ => &[],
        };
        for &slot in targets {
            sides[slot] = kind;
        }
    }

    // Periodicity wraps one side onto its opposite, so both sides of a pair must agree.
    let is_periodic = |side: SideBc| matches!(side, SideBc::Periodic);
    let pairs = [
        (0, 1, "Periodic BC must be applied to both Left and Right sides (paired)"),
        (2, 3, "Periodic BC must be applied to both Bottom and Top sides (paired)"),
    ];
    for &(a, b, message) in pairs.iter() {
        if is_periodic(sides[a]) != is_periodic(sides[b]) {
            return Err(PdeError::InvalidBoundary {
                context: message.to_string(),
            });
        }
    }
    Ok(sides)
}
/// Builds the flattened Dirichlet boundary-value field for an `nx` × `ny` grid.
///
/// Node `(i, j)`, with `i` in `0..nx` and `j` in `0..ny`, is stored at index `i * ny + j`.
/// `Left`/`Right` are therefore the columns `i = 0` / `i = nx - 1`, and `Bottom`/`Top` are
/// the rows `j = 0` / `j = ny - 1`. Specs are applied in order, so where sides overlap
/// (e.g. corners) a later entry overwrites an earlier one.
///
/// * `Left`/`Right` copy up to `ny` values from the tensor; `Bottom`/`Top` copy up to `nx`.
///   If the tensor is shorter, the remaining nodes of that side are left untouched.
/// * `All` writes the tensor's first element (or `0.0` if empty) to every perimeter node.
/// * Neumann and periodic specs, and sides other than the five above, write nothing.
///
/// Nodes that no Dirichlet spec writes are `0.0`. A grid with no nodes (`nx == 0` or
/// `ny == 0`) yields an empty vector.
///
/// # Errors
/// Currently never returns an error.
pub fn extract_boundary_values_2d<R: Runtime<DType = DType>>(
    boundary: &[BoundarySpec<R>],
    nx: usize,
    ny: usize,
) -> PdeResult<Vec<f64>> {
    let n = nx * ny;
    let mut values = vec![0.0; n];
    // An empty grid has no boundary. Returning here also guarantees nx >= 1 and ny >= 1
    // below, so `nx - 1` / `ny - 1` cannot underflow and `step_by(ny)` /
    // `chunks_exact_mut(ny)` never see a zero step.
    if n == 0 {
        return Ok(values);
    }
    for spec in boundary {
        let vals = match &spec.condition {
            BoundaryCondition::Dirichlet(vals) => vals,
            // These conditions prescribe no nodal values in this field.
            BoundaryCondition::Neumann(_) | BoundaryCondition::Periodic => continue,
        };
        let v: Vec<f64> = vals.to_vec();
        match spec.side {
            BoundarySide::Left => {
                let copy_len = ny.min(v.len());
                values[..copy_len].copy_from_slice(&v[..copy_len]);
            }
            BoundarySide::Right => {
                let start = (nx - 1) * ny;
                let copy_len = ny.min(v.len());
                values[start..start + copy_len].copy_from_slice(&v[..copy_len]);
            }
            BoundarySide::Bottom => {
                // Indices 0, ny, 2*ny, ... are the j = 0 nodes of successive columns.
                for (node, &val) in values.iter_mut().step_by(ny).zip(&v) {
                    *node = val;
                }
            }
            BoundarySide::Top => {
                // Indices ny-1, 2*ny-1, ... are the j = ny - 1 nodes of successive columns.
                for (node, &val) in values[ny - 1..].iter_mut().step_by(ny).zip(&v) {
                    *node = val;
                }
            }
            BoundarySide::All => {
                let val = v.first().copied().unwrap_or(0.0);
                // Visit only perimeter nodes: the whole first and last columns, plus both
                // end points of every interior column.
                for (i, column) in values.chunks_exact_mut(ny).enumerate() {
                    if i == 0 || i == nx - 1 {
                        column.fill(val);
                    } else {
                        column[0] = val;
                        column[ny - 1] = val;
                    }
                }
            }
            _ => {}
        }
    }
    Ok(values)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};
    use numr::tensor::Tensor;

    fn device() -> CpuDevice {
        CpuDevice::new()
    }

    /// Builds a one-element CPU tensor holding `val`.
    fn scalar_tensor(val: f64) -> Tensor<CpuRuntime> {
        let dev = device();
        Tensor::<CpuRuntime>::from_slice(&[val], &[1], &dev)
    }

    #[test]
    fn test_dirichlet_1d_left_right() {
        let bcs = vec![
            BoundarySpec {
                side: BoundarySide::Left,
                condition: BoundaryCondition::Dirichlet(scalar_tensor(1.0)),
            },
            BoundarySpec {
                side: BoundarySide::Right,
                condition: BoundaryCondition::Dirichlet(scalar_tensor(2.0)),
            },
        ];
        assert_eq!(extract_dirichlet_1d_bcs(&bcs, "test").unwrap(), (1.0, 2.0));
    }

    #[test]
    fn test_dirichlet_1d_later_spec_overrides_all() {
        let bcs = vec![
            BoundarySpec {
                side: BoundarySide::All,
                condition: BoundaryCondition::Dirichlet(scalar_tensor(3.0)),
            },
            BoundarySpec {
                side: BoundarySide::Right,
                condition: BoundaryCondition::Dirichlet(scalar_tensor(4.0)),
            },
        ];
        assert_eq!(extract_dirichlet_1d_bcs(&bcs, "test").unwrap(), (3.0, 4.0));
    }

    #[test]
    fn test_dirichlet_1d_rejects_non_dirichlet() {
        let bcs = vec![BoundarySpec {
            side: BoundarySide::Left,
            condition: BoundaryCondition::Neumann(scalar_tensor(1.0)),
        }];
        assert!(extract_dirichlet_1d_bcs(&bcs, "test").is_err());
    }

    #[test]
    fn test_dirichlet_scalar_empty_defaults_to_zero() {
        let bcs: Vec<BoundarySpec<CpuRuntime>> = vec![];
        assert_eq!(extract_dirichlet_scalar(&bcs, "test").unwrap(), 0.0);
    }

    #[test]
    fn test_dirichlet_scalar_reads_first_spec() {
        let bcs = vec![BoundarySpec {
            side: BoundarySide::All,
            condition: BoundaryCondition::Dirichlet(scalar_tensor(4.5)),
        }];
        assert_eq!(extract_dirichlet_scalar(&bcs, "test").unwrap(), 4.5);
    }

    #[test]
    fn test_dirichlet_scalar_rejects_periodic() {
        let bcs: Vec<BoundarySpec<CpuRuntime>> = vec![BoundarySpec {
            side: BoundarySide::All,
            condition: BoundaryCondition::Periodic,
        }];
        assert!(extract_dirichlet_scalar(&bcs, "test").is_err());
    }

    #[test]
    fn test_side_conditions_default_dirichlet() {
        let bcs: Vec<BoundarySpec<CpuRuntime>> = vec![];
        let sides = side_conditions_2d(&bcs).unwrap();
        assert_eq!(sides, [SideBc::Dirichlet; 4]);
    }

    #[test]
    fn test_side_conditions_per_side_mixed() {
        let bcs = vec![
            BoundarySpec {
                side: BoundarySide::Left,
                condition: BoundaryCondition::Dirichlet(scalar_tensor(0.0)),
            },
            BoundarySpec {
                side: BoundarySide::Right,
                condition: BoundaryCondition::Neumann(scalar_tensor(2.5)),
            },
            BoundarySpec {
                side: BoundarySide::Bottom,
                condition: BoundaryCondition::Dirichlet(scalar_tensor(1.0)),
            },
            BoundarySpec {
                side: BoundarySide::Top,
                condition: BoundaryCondition::Neumann(scalar_tensor(-1.0)),
            },
        ];
        let sides = side_conditions_2d(&bcs).unwrap();
        assert_eq!(sides[0], SideBc::Dirichlet);
        assert_eq!(sides[1], SideBc::Neumann(2.5));
        assert_eq!(sides[2], SideBc::Dirichlet);
        assert_eq!(sides[3], SideBc::Neumann(-1.0));
    }

    #[test]
    fn test_side_conditions_periodic_paired_ok() {
        let bcs: Vec<BoundarySpec<CpuRuntime>> = vec![
            BoundarySpec {
                side: BoundarySide::Left,
                condition: BoundaryCondition::Periodic,
            },
            BoundarySpec {
                side: BoundarySide::Right,
                condition: BoundaryCondition::Periodic,
            },
        ];
        let sides = side_conditions_2d(&bcs).unwrap();
        assert_eq!(sides[0], SideBc::Periodic);
        assert_eq!(sides[1], SideBc::Periodic);
        assert_eq!(sides[2], SideBc::Dirichlet);
        assert_eq!(sides[3], SideBc::Dirichlet);
    }

    #[test]
    fn test_side_conditions_all_periodic() {
        let bcs: Vec<BoundarySpec<CpuRuntime>> = vec![BoundarySpec {
            side: BoundarySide::All,
            condition: BoundaryCondition::Periodic,
        }];
        let sides = side_conditions_2d(&bcs).unwrap();
        assert_eq!(sides, [SideBc::Periodic; 4]);
    }

    #[test]
    fn test_side_conditions_periodic_unpaired_errors() {
        let bcs: Vec<BoundarySpec<CpuRuntime>> = vec![BoundarySpec {
            side: BoundarySide::Left,
            condition: BoundaryCondition::Periodic,
        }];
        assert!(side_conditions_2d(&bcs).is_err());
    }

    #[test]
    fn test_extract_values_skips_non_dirichlet() {
        let bcs = vec![BoundarySpec {
            side: BoundarySide::All,
            condition: BoundaryCondition::Neumann(scalar_tensor(3.0)),
        }];
        let values = extract_boundary_values_2d(&bcs, 5, 5).unwrap();
        assert!(values.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_extract_values_dirichlet_side() {
        let bcs = vec![BoundarySpec {
            side: BoundarySide::Left,
            condition: BoundaryCondition::Dirichlet(Tensor::<CpuRuntime>::from_slice(
                &[7.0, 7.0, 7.0],
                &[3],
                &device(),
            )),
        }];
        let (nx, ny) = (4, 3);
        let values = extract_boundary_values_2d(&bcs, nx, ny).unwrap();
        for (j, &v) in values.iter().take(ny).enumerate() {
            assert_eq!(v, 7.0, "left column node j={j}");
        }
    }

    #[test]
    fn test_extract_values_right_column() {
        let bcs = vec![BoundarySpec {
            side: BoundarySide::Right,
            condition: BoundaryCondition::Dirichlet(Tensor::<CpuRuntime>::from_slice(
                &[1.0, 2.0, 3.0],
                &[3],
                &device(),
            )),
        }];
        let (nx, ny) = (4, 3);
        let values = extract_boundary_values_2d(&bcs, nx, ny).unwrap();
        // The right column occupies the last `ny` entries (i = nx - 1).
        assert_eq!(&values[(nx - 1) * ny..], &[1.0, 2.0, 3.0]);
        assert!(values[..(nx - 1) * ny].iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_extract_values_bottom_and_top_rows() {
        let bottom = [1.0, 2.0, 3.0, 4.0];
        let top = [5.0, 6.0, 7.0, 8.0];
        let bcs = vec![
            BoundarySpec {
                side: BoundarySide::Bottom,
                condition: BoundaryCondition::Dirichlet(Tensor::<CpuRuntime>::from_slice(
                    &bottom,
                    &[4],
                    &device(),
                )),
            },
            BoundarySpec {
                side: BoundarySide::Top,
                condition: BoundaryCondition::Dirichlet(Tensor::<CpuRuntime>::from_slice(
                    &top,
                    &[4],
                    &device(),
                )),
            },
        ];
        let (nx, ny) = (4, 3);
        let values = extract_boundary_values_2d(&bcs, nx, ny).unwrap();
        for i in 0..nx {
            assert_eq!(values[i * ny], bottom[i], "bottom node i={i}");
            assert_eq!(values[i * ny + ny - 1], top[i], "top node i={i}");
            assert_eq!(values[i * ny + 1], 0.0, "interior row node i={i}");
        }
    }

    #[test]
    fn test_extract_values_all_sets_perimeter_only() {
        let bcs = vec![BoundarySpec {
            side: BoundarySide::All,
            condition: BoundaryCondition::Dirichlet(scalar_tensor(5.0)),
        }];
        let (nx, ny) = (4, 3);
        let values = extract_boundary_values_2d(&bcs, nx, ny).unwrap();
        for i in 0..nx {
            for j in 0..ny {
                let on_perimeter = i == 0 || i == nx - 1 || j == 0 || j == ny - 1;
                let expected = if on_perimeter { 5.0 } else { 0.0 };
                assert_eq!(values[i * ny + j], expected, "node i={i} j={j}");
            }
        }
    }
}