use crate::DType;
use numr::ops::{LinalgOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::pde::error::{PdeError, PdeResult};
use crate::pde::types::{BoundarySpec, FdmOptions, FemResult};
use super::boundary::extract_dirichlet_1d_bcs;
/// Solves a 1D two-point boundary value problem with piecewise-linear finite elements.
///
/// Each element `[x_e, x_{e+1}]` of length `h` contributes the stiffness block
/// `(1/h) * [[1, -1], [-1, 1]]` (the weak form `∫ u' v' dx`, i.e. the discretisation
/// of `-u'' = f`) and the nodal load `h * f_i / 2` to each of its two nodes.
/// The two end nodes carry Dirichlet values and are eliminated: only the
/// `n - 2` interior unknowns are solved for, and the known boundary values are
/// moved to the right-hand side.
///
/// # Arguments
///
/// * `client` - Runtime client providing the device and the dense linear solve.
/// * `f_rhs` - Source term sampled at the nodes; must hold at least as many values
///   as there are nodes (read as `f64`).
/// * `x_nodes` - Node coordinates, strictly increasing (read as `f64`). The node
///   count is taken from the first dimension of its shape.
/// * `boundary` - Boundary specifications; forwarded to
///   `extract_dirichlet_1d_bcs`, which yields the left and right boundary values.
/// * `_options` - Currently unused.
///
/// # Returns
///
/// A [`FemResult`] whose `solution` has one value per node (boundary values at both
/// ends, solved values in between) and whose `nodes` is a clone of `x_nodes`.
/// Because the system is solved directly, `iterations` is reported as `1` and
/// `residual_norm` as `0.0`; the residual is not actually computed.
///
/// # Errors
///
/// * [`PdeError::InvalidGrid`] if there are fewer than 3 nodes, if `f_rhs` holds
///   fewer values than there are nodes, or if any element length is not strictly
///   positive (including NaN).
/// * Any error returned by `extract_dirichlet_1d_bcs`.
/// * Any error from the linear solve, converted through `PdeError::from`.
pub fn fem_1d_impl<R, C>(
    client: &C,
    f_rhs: &Tensor<R>,
    x_nodes: &Tensor<R>,
    boundary: &[BoundarySpec<R>],
    _options: &FdmOptions,
) -> PdeResult<FemResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + LinalgOps<R> + RuntimeClient<R>,
{
    let n = x_nodes.shape()[0];
    let device = client.device();

    if n < 3 {
        return Err(PdeError::InvalidGrid {
            context: format!("Need at least 3 nodes for 1D FEM, got {}", n),
        });
    }

    let x_data: Vec<f64> = x_nodes.to_vec();
    let f_data: Vec<f64> = f_rhs.to_vec();

    let (bc_left, bc_right) = extract_dirichlet_1d_bcs(boundary, "1D FEM")?;

    // The assembly loop reads f_data[0..n]; reject a short source term up front
    // instead of panicking on an out-of-bounds index.
    if f_data.len() < n {
        return Err(PdeError::InvalidGrid {
            context: format!(
                "Right-hand side has {} values but the grid has {} nodes",
                f_data.len(),
                n
            ),
        });
    }

    let n_interior = n - 2;
    let mut k_data = vec![0.0_f64; n_interior * n_interior];
    let mut f_vec = vec![0.0_f64; n_interior];

    // Maps a mesh node to its row/column in the reduced system.
    // Node 0 and node n-1 are Dirichlet nodes and have no unknown (`None`).
    let interior_index = |node: usize| node.checked_sub(1).filter(|&i| i < n_interior);

    for e in 0..(n - 1) {
        let h = x_data[e + 1] - x_data[e];
        // `h <= 0.0` alone is false for NaN, which would silently poison the system.
        if h.is_nan() || h <= 0.0 {
            return Err(PdeError::InvalidGrid {
                context: format!("Non-positive element length at element {}", e),
            });
        }

        let ke = 1.0 / h;
        let local_k = [[ke, -ke], [-ke, ke]];
        let fe = [h * f_data[e] / 2.0, h * f_data[e + 1] / 2.0];
        let nodes = [e, e + 1];

        for (li, &row_node) in nodes.iter().enumerate() {
            // Rows belonging to Dirichlet nodes are eliminated from the system.
            let row = match interior_index(row_node) {
                Some(row) => row,
                None => continue,
            };

            // Each interior node receives its share of the element load exactly once.
            f_vec[row] += fe[li];

            for (lj, &col_node) in nodes.iter().enumerate() {
                match interior_index(col_node) {
                    Some(col) => k_data[row * n_interior + col] += local_k[li][lj],
                    None => {
                        // Coupling to a known boundary value moves to the right-hand side.
                        let bc_value = if col_node == 0 { bc_left } else { bc_right };
                        f_vec[row] -= local_k[li][lj] * bc_value;
                    }
                }
            }
        }
    }

    let k_tensor = Tensor::<R>::from_slice(&k_data, &[n_interior, n_interior], device);
    let f_tensor = Tensor::<R>::from_slice(&f_vec, &[n_interior, 1], device);

    let u_interior = client.solve(&k_tensor, &f_tensor).map_err(PdeError::from)?;
    let u_int_data: Vec<f64> = u_interior.to_vec();

    // Re-insert the prescribed boundary values around the solved interior unknowns.
    let mut full_solution = vec![0.0; n];
    full_solution[0] = bc_left;
    full_solution[1..(n_interior + 1)].copy_from_slice(&u_int_data[..n_interior]);
    full_solution[n - 1] = bc_right;

    let solution = Tensor::<R>::from_slice(&full_solution, &[n], device);

    Ok(FemResult {
        solution,
        nodes: x_nodes.clone(),
        iterations: 1,
        residual_norm: 0.0,
    })
}