Skip to main content

oxirs_physics/fem/
mod.rs

1//! Finite Element Method (FEM) for Structural and Thermal Analysis
2//!
3//! Provides simplified 1D/2D FEM solvers using direct Gaussian elimination
4//! for the global stiffness system. Accepts Bar1D, Beam1D (assembled as an axial bar),
5//! Triangle2D, and Quad2D (assembled as two constant-strain triangles) element types.
6//!
7//! # Example — 1D bar under axial load
8//!
9//! ```rust
10//! use oxirs_physics::fem::{
11//!     FemMesh, FemMaterial, ElementType, DofType, FemSolver, NodalLoad,
12//! };
13//!
14//! let mut mesh = FemMesh::new();
15//! let mat = FemMaterial {
16//!     youngs_modulus: 200e9,
17//!     poissons_ratio: 0.3,
18//!     thermal_conductivity: 50.0,
19//!     density: 7850.0,
20//! };
21//!
22//! let n0 = mesh.add_node(0.0, 0.0);
23//! let n1 = mesh.add_node(1.0, 0.0);
24//! mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
25//! mesh.add_element(vec![n0, n1], mat, ElementType::Bar1D);
26//!
27//! let solver = FemSolver::new();
28//! let loads = vec![NodalLoad { node_id: n1, fx: 1000.0, fy: 0.0 }];
29//! let sol = solver.solve_static(&mesh, &loads);
30//! assert!(sol.converged);
31//! ```
32
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35
36// ─────────────────────────────────────────────
37// Public Data Model
38// ─────────────────────────────────────────────
39
40/// Degree-of-freedom type for boundary conditions.
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42pub enum DofType {
43    /// Structural displacement (m).
44    Displacement,
45    /// Temperature (K).
46    Temperature,
47    /// Pressure (Pa).
48    Pressure,
49}
50
51/// Fixed boundary condition applied to a node.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct BoundaryCondition {
54    /// Which DOF is constrained.
55    pub dof: DofType,
56    /// Prescribed value.
57    pub value: f64,
58}
59
60/// FEM node with optional boundary condition.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct FemNode {
63    /// Node identifier (0-based).
64    pub id: usize,
65    /// X coordinate (m).
66    pub x: f64,
67    /// Y coordinate (m).
68    pub y: f64,
69    /// Optional fixed boundary condition.
70    pub boundary_condition: Option<BoundaryCondition>,
71}
72
73/// Isotropic linear-elastic material properties.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct FemMaterial {
76    /// Young's modulus E (Pa).
77    pub youngs_modulus: f64,
78    /// Poisson's ratio ν (dimensionless).
79    pub poissons_ratio: f64,
80    /// Thermal conductivity k (W / m·K).
81    pub thermal_conductivity: f64,
82    /// Mass density ρ (kg / m³).
83    pub density: f64,
84}
85
86impl Default for FemMaterial {
87    fn default() -> Self {
88        Self {
89            youngs_modulus: 200e9,
90            poissons_ratio: 0.3,
91            thermal_conductivity: 50.0,
92            density: 7850.0,
93        }
94    }
95}
96
97/// Supported element types.
98#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
99pub enum ElementType {
100    /// 1-D two-node truss / axial bar.
101    Bar1D,
102    /// 2-D three-node constant-strain triangle (CST).
103    Triangle2D,
104    /// 2-D four-node bilinear quadrilateral.
105    Quad2D,
106    /// 1-D Euler-Bernoulli beam (2 DOF per node: v, θ).
107    Beam1D,
108}
109
110/// Finite element connecting two or more nodes.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct FemElement {
113    /// Element identifier (0-based).
114    pub id: usize,
115    /// Ordered list of participating node IDs.
116    pub node_ids: Vec<usize>,
117    /// Material assigned to this element.
118    pub material: FemMaterial,
119    /// Topology/formulation type.
120    pub element_type: ElementType,
121}
122
123/// Applied nodal force (structural).
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct NodalLoad {
126    /// Target node id.
127    pub node_id: usize,
128    /// Force in x-direction (N).
129    pub fx: f64,
130    /// Force in y-direction (N).
131    pub fy: f64,
132}
133
134/// Uniform heat flux applied to an element (W / m²).
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct ElementHeatFlux {
137    /// Target element id.
138    pub element_id: usize,
139    /// Heat flux magnitude (W / m²).
140    pub q: f64,
141}
142
143// ─────────────────────────────────────────────
144// Mesh
145// ─────────────────────────────────────────────
146
147/// Finite element mesh: collection of nodes and elements.
148#[derive(Debug, Clone, Serialize, Deserialize, Default)]
149pub struct FemMesh {
150    /// All mesh nodes.
151    pub nodes: Vec<FemNode>,
152    /// All mesh elements.
153    pub elements: Vec<FemElement>,
154}
155
156impl FemMesh {
157    /// Create an empty mesh.
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Add a node at `(x, y)` and return its id.
163    pub fn add_node(&mut self, x: f64, y: f64) -> usize {
164        let id = self.nodes.len();
165        self.nodes.push(FemNode {
166            id,
167            x,
168            y,
169            boundary_condition: None,
170        });
171        id
172    }
173
174    /// Add an element and return its id.
175    pub fn add_element(
176        &mut self,
177        node_ids: Vec<usize>,
178        material: FemMaterial,
179        element_type: ElementType,
180    ) -> usize {
181        let id = self.elements.len();
182        self.elements.push(FemElement {
183            id,
184            node_ids,
185            material,
186            element_type,
187        });
188        id
189    }
190
191    /// Apply a fixed boundary condition to a node DOF.
192    pub fn set_boundary_condition(&mut self, node_id: usize, dof: DofType, value: f64) {
193        if let Some(node) = self.nodes.get_mut(node_id) {
194            node.boundary_condition = Some(BoundaryCondition { dof, value });
195        }
196    }
197
198    /// Number of nodes.
199    pub fn node_count(&self) -> usize {
200        self.nodes.len()
201    }
202
203    /// Number of elements.
204    pub fn element_count(&self) -> usize {
205        self.elements.len()
206    }
207}
208
209// ─────────────────────────────────────────────
210// Solution types
211// ─────────────────────────────────────────────
212
213/// Structural static FEM solution.
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct FemSolution {
216    /// Displacement (dx, dy) per node (m).
217    pub displacements: Vec<(f64, f64)>,
218    /// Von Mises stress per element (Pa).
219    pub von_mises_stress: Vec<f64>,
220    /// Maximum displacement magnitude (m).
221    pub max_displacement: f64,
222    /// Whether the solver converged.
223    pub converged: bool,
224}
225
226/// Thermal FEM solution.
227#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct ThermalSolution {
229    /// Nodal temperatures (K).
230    pub temperatures: Vec<f64>,
231    /// Heat flux (qx, qy) per element (W / m²).
232    pub heat_flux: Vec<(f64, f64)>,
233    /// Maximum nodal temperature (K).
234    pub max_temperature: f64,
235    /// Whether the solver converged.
236    pub converged: bool,
237}
238
239// ─────────────────────────────────────────────
240// Gaussian Elimination (dense, in-place)
241// ─────────────────────────────────────────────
242
/// Solve the dense linear system `A x = b` by Gaussian elimination with
/// partial (row) pivoting.
///
/// The system size is `n = b.len()`; only the leading `n × n` part of `a` is
/// read. Both `a` and `b` are overwritten with intermediate elimination data.
///
/// Returns `None` when
/// * `a` has fewer than `n` rows or one of its first `n` rows has fewer than
///   `n` columns,
/// * a pivot smaller than `1e-30` in magnitude is met (singular system), or
/// * the computed solution contains a NaN or infinite entry.
fn gaussian_elimination(a: &mut [Vec<f64>], b: &mut [f64]) -> Option<Vec<f64>> {
    let n = b.len();

    // A malformed system is reported as unsolvable instead of panicking on an
    // out-of-range index further down.
    if a.len() < n || a.iter().take(n).any(|row| row.len() < n) {
        return None;
    }

    for col in 0..n {
        // Partial pivoting: pick the row at or below `col` with the largest
        // magnitude in this column.
        let mut max_row = col;
        let mut max_val = a[col][col].abs();
        for (row, a_row) in a.iter().enumerate().take(n).skip(col + 1) {
            let candidate = a_row[col].abs();
            if candidate > max_val {
                max_val = candidate;
                max_row = row;
            }
        }
        if max_val < 1e-30 {
            return None; // singular
        }
        a.swap(col, max_row);
        b.swap(col, max_row);

        // Split the matrix so the pivot row can be read while the rows below
        // it are updated; this avoids copying the pivot row for every row.
        let (upper, lower) = a.split_at_mut(col + 1);
        let pivot_row = &upper[col][col..n];
        let pivot = pivot_row[0];
        let b_pivot = b[col];
        // `b[col + 1..]` has exactly `n - col - 1` entries, so the zip also
        // stops at row `n` when `a` holds surplus rows.
        for (a_row, b_row) in lower.iter_mut().zip(b[col + 1..].iter_mut()) {
            let factor = a_row[col] / pivot;
            for (entry, &pivot_entry) in a_row[col..n].iter_mut().zip(pivot_row) {
                *entry -= factor * pivot_entry;
            }
            *b_row -= factor * b_pivot;
        }
    }

    // Back-substitution on the upper-triangular system.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let mut value = b[i];
        for (coeff, x_j) in a[i][i + 1..n].iter().zip(&x[i + 1..]) {
            value -= coeff * x_j;
        }
        x[i] = value / a[i][i];
    }

    // NaN/inf entries mean the input was not a usable system; do not present
    // them as a valid solution.
    if x.iter().all(|v| v.is_finite()) {
        Some(x)
    } else {
        None
    }
}
286
287// ─────────────────────────────────────────────
288// Element stiffness matrices
289// ─────────────────────────────────────────────
290
291/// Compute Bar1D local stiffness matrix (2×2) and map it to global DOFs.
292/// Global DOFs: node i → DOF 2i (x), node j → DOF 2j (x only for 1D).
293fn bar1d_element_stiffness(
294    n_dofs: usize,
295    elem: &FemElement,
296    nodes: &[FemNode],
297    k_global: &mut [Vec<f64>],
298    cross_section_area: f64,
299) {
300    if elem.node_ids.len() < 2 {
301        return;
302    }
303    let ni = elem.node_ids[0];
304    let nj = elem.node_ids[1];
305    let xi = nodes[ni].x;
306    let xj = nodes[nj].x;
307    let yi = nodes[ni].y;
308    let yj = nodes[nj].y;
309
310    let dx = xj - xi;
311    let dy = yj - yi;
312    let length = (dx * dx + dy * dy).sqrt();
313    if length < 1e-30 {
314        return;
315    }
316
317    let ae_over_l = elem.material.youngs_modulus * cross_section_area / length;
318    let c = dx / length;
319    let s = dy / length;
320
321    // 4×4 local stiffness (in global x-y for 2 nodes × 2 DOFs each)
322    let dofs = [2 * ni, 2 * ni + 1, 2 * nj, 2 * nj + 1];
323    let ke_local = [
324        [c * c, c * s, -c * c, -c * s],
325        [c * s, s * s, -c * s, -s * s],
326        [-c * c, -c * s, c * c, c * s],
327        [-c * s, -s * s, c * s, s * s],
328    ];
329
330    for (a, &ga) in dofs.iter().enumerate() {
331        for (b, &gb) in dofs.iter().enumerate() {
332            if ga < n_dofs && gb < n_dofs {
333                k_global[ga][gb] += ae_over_l * ke_local[a][b];
334            }
335        }
336    }
337}
338
339/// Compute Triangle2D CST element stiffness and assemble into global matrix.
340fn triangle2d_element_stiffness(
341    n_dofs: usize,
342    elem: &FemElement,
343    nodes: &[FemNode],
344    k_global: &mut [Vec<f64>],
345    thickness: f64,
346) {
347    if elem.node_ids.len() < 3 {
348        return;
349    }
350    let ni = elem.node_ids[0];
351    let nj = elem.node_ids[1];
352    let nk = elem.node_ids[2];
353
354    let xi = nodes[ni].x;
355    let yi = nodes[ni].y;
356    let xj = nodes[nj].x;
357    let yj = nodes[nj].y;
358    let xk = nodes[nk].x;
359    let yk = nodes[nk].y;
360
361    let area = 0.5 * ((xj - xi) * (yk - yi) - (xk - xi) * (yj - yi));
362    if area.abs() < 1e-30 {
363        return;
364    }
365
366    let e = elem.material.youngs_modulus;
367    let nu = elem.material.poissons_ratio;
368    let factor = e / (1.0 - nu * nu);
369
370    // Shape function derivatives (constant for CST)
371    let b_mat = [
372        [yj - yk, 0.0, yk - yi, 0.0, yi - yj, 0.0],
373        [0.0, xk - xj, 0.0, xi - xk, 0.0, xj - xi],
374        [xk - xj, yj - yk, xi - xk, yk - yi, xj - xi, yi - yj],
375    ];
376    let scale = 1.0 / (2.0 * area);
377    let b_mat: Vec<Vec<f64>> = b_mat
378        .iter()
379        .map(|row| row.iter().map(|&v| v * scale).collect())
380        .collect();
381
382    // Constitutive matrix D (plane stress)
383    let d_mat = [
384        [factor, factor * nu, 0.0],
385        [factor * nu, factor, 0.0],
386        [0.0, 0.0, factor * (1.0 - nu) / 2.0],
387    ];
388
389    // Ke = t * A * B^T * D * B  (6×6)
390    let mut ke = vec![vec![0.0f64; 6]; 6];
391    for i in 0..6 {
392        for j in 0..6 {
393            let mut sum = 0.0;
394            for p in 0..3 {
395                for q in 0..3 {
396                    sum += b_mat[p][i] * d_mat[p][q] * b_mat[q][j];
397                }
398            }
399            ke[i][j] = thickness * area.abs() * sum;
400        }
401    }
402
403    let dofs = [2 * ni, 2 * ni + 1, 2 * nj, 2 * nj + 1, 2 * nk, 2 * nk + 1];
404    for (a, &ga) in dofs.iter().enumerate() {
405        for (b, &gb) in dofs.iter().enumerate() {
406            if ga < n_dofs && gb < n_dofs {
407                k_global[ga][gb] += ke[a][b];
408            }
409        }
410    }
411}
412
413// ─────────────────────────────────────────────
414// Thermal element stiffness
415// ─────────────────────────────────────────────
416
417/// Bar1D thermal conductivity element (1 DOF per node = temperature).
418fn bar1d_thermal_stiffness(
419    n_nodes: usize,
420    elem: &FemElement,
421    nodes: &[FemNode],
422    k_global: &mut [Vec<f64>],
423    cross_section_area: f64,
424) {
425    if elem.node_ids.len() < 2 {
426        return;
427    }
428    let ni = elem.node_ids[0];
429    let nj = elem.node_ids[1];
430    let dx = nodes[nj].x - nodes[ni].x;
431    let dy = nodes[nj].y - nodes[ni].y;
432    let length = (dx * dx + dy * dy).sqrt();
433    if length < 1e-30 {
434        return;
435    }
436
437    let k_coeff = elem.material.thermal_conductivity * cross_section_area / length;
438    // 2×2 conductivity matrix
439    let pairs = [
440        (ni, ni, k_coeff),
441        (nj, nj, k_coeff),
442        (ni, nj, -k_coeff),
443        (nj, ni, -k_coeff),
444    ];
445    for (r, c, val) in pairs {
446        if r < n_nodes && c < n_nodes {
447            k_global[r][c] += val;
448        }
449    }
450}
451
452// ─────────────────────────────────────────────
453// Solver
454// ─────────────────────────────────────────────
455
/// FEM solver: assembles global stiffness and solves via Gaussian elimination.
///
/// `FemSolver::default()` yields the documented default parameters
/// (cross-section area 1e-4 m², thickness 0.01 m).
#[derive(Debug, Clone)]
pub struct FemSolver {
    /// Cross-section area used for Bar1D elements (m²). Defaults to 1e-4 m².
    pub cross_section_area: f64,
    /// Thickness for 2-D elements (m). Defaults to 0.01 m.
    pub thickness: f64,
}

// `Default` is written by hand: deriving it would zero both fields, which
// contradicts the documented defaults and produces zero-stiffness elements.
impl Default for FemSolver {
    fn default() -> Self {
        Self {
            cross_section_area: 1e-4,
            thickness: 0.01,
        }
    }
}
464
465impl FemSolver {
466    /// Create a solver with default parameters.
467    pub fn new() -> Self {
468        Self {
469            cross_section_area: 1e-4,
470            thickness: 0.01,
471        }
472    }
473
474    /// Create a solver with explicit cross-section area and thickness.
475    pub fn with_params(cross_section_area: f64, thickness: f64) -> Self {
476        Self {
477            cross_section_area,
478            thickness,
479        }
480    }
481
482    /// Solve a static structural problem.
483    pub fn solve_static(&self, mesh: &FemMesh, loads: &[NodalLoad]) -> FemSolution {
484        let n_nodes = mesh.node_count();
485        let n_dofs = 2 * n_nodes; // x and y DOF per node
486
487        // Assemble global stiffness
488        let mut k = vec![vec![0.0f64; n_dofs]; n_dofs];
489        for elem in &mesh.elements {
490            match elem.element_type {
491                ElementType::Bar1D | ElementType::Beam1D => {
492                    bar1d_element_stiffness(
493                        n_dofs,
494                        elem,
495                        &mesh.nodes,
496                        &mut k,
497                        self.cross_section_area,
498                    );
499                }
500                ElementType::Triangle2D => {
501                    triangle2d_element_stiffness(n_dofs, elem, &mesh.nodes, &mut k, self.thickness);
502                }
503                ElementType::Quad2D => {
504                    // Approximate as two triangles (CST split)
505                    if elem.node_ids.len() >= 4 {
506                        let tri_a = FemElement {
507                            id: elem.id,
508                            node_ids: vec![elem.node_ids[0], elem.node_ids[1], elem.node_ids[2]],
509                            material: elem.material.clone(),
510                            element_type: ElementType::Triangle2D,
511                        };
512                        let tri_b = FemElement {
513                            id: elem.id,
514                            node_ids: vec![elem.node_ids[0], elem.node_ids[2], elem.node_ids[3]],
515                            material: elem.material.clone(),
516                            element_type: ElementType::Triangle2D,
517                        };
518                        triangle2d_element_stiffness(
519                            n_dofs,
520                            &tri_a,
521                            &mesh.nodes,
522                            &mut k,
523                            self.thickness,
524                        );
525                        triangle2d_element_stiffness(
526                            n_dofs,
527                            &tri_b,
528                            &mesh.nodes,
529                            &mut k,
530                            self.thickness,
531                        );
532                    }
533                }
534            }
535        }
536
537        // Build force vector
538        let mut f = vec![0.0f64; n_dofs];
539        for load in loads {
540            let dof_x = 2 * load.node_id;
541            let dof_y = 2 * load.node_id + 1;
542            if dof_x < n_dofs {
543                f[dof_x] += load.fx;
544            }
545            if dof_y < n_dofs {
546                f[dof_y] += load.fy;
547            }
548        }
549
550        // Collect all boundary conditions (DOF index → prescribed value)
551        let mut bc_map: HashMap<usize, f64> = HashMap::new();
552        for node in &mesh.nodes {
553            if let Some(ref bc) = node.boundary_condition {
554                match bc.dof {
555                    DofType::Displacement => {
556                        // Pin both x and y for a displacement BC
557                        bc_map.insert(2 * node.id, bc.value);
558                        bc_map.insert(2 * node.id + 1, bc.value);
559                    }
560                    DofType::Temperature | DofType::Pressure => {
561                        bc_map.insert(2 * node.id, bc.value);
562                    }
563                }
564            }
565        }
566
567        // Apply boundary conditions (large-number / penalty method)
568        let penalty = 1e30;
569        for (&dof, &val) in &bc_map {
570            if dof < n_dofs {
571                k[dof][dof] += penalty;
572                f[dof] += penalty * val;
573            }
574        }
575
576        // Stabilize unconstrained zero-stiffness DOFs (e.g. transverse DOFs
577        // of pure 1D bar elements) to prevent singularity.  A small spring
578        // stiffness 1e-6 × max_diag is added to any zero diagonal entry that
579        // has no prescribed BC, yielding a near-zero (but well-defined) free
580        // displacement for that DOF.
581        {
582            let max_diag = (0..n_dofs).map(|i| k[i][i].abs()).fold(0.0f64, f64::max);
583            let stab = if max_diag > 0.0 { max_diag * 1e-6 } else { 1.0 };
584            for (i, k_row) in k.iter_mut().enumerate().take(n_dofs) {
585                if k_row[i].abs() < 1e-30 && !bc_map.contains_key(&i) {
586                    k_row[i] += stab;
587                }
588            }
589        }
590
591        // Solve
592        let mut k_copy = k.clone();
593        let mut f_copy = f.clone();
594        let solution = gaussian_elimination(&mut k_copy, &mut f_copy);
595
596        match solution {
597            None => FemSolution {
598                displacements: vec![(0.0, 0.0); n_nodes],
599                von_mises_stress: vec![0.0; mesh.element_count()],
600                max_displacement: 0.0,
601                converged: false,
602            },
603            Some(u) => {
604                let displacements: Vec<(f64, f64)> =
605                    (0..n_nodes).map(|i| (u[2 * i], u[2 * i + 1])).collect();
606
607                let von_mises_stress =
608                    compute_von_mises(&u, mesh, self.cross_section_area, self.thickness);
609
610                let max_displacement = displacements
611                    .iter()
612                    .map(|(dx, dy)| (dx * dx + dy * dy).sqrt())
613                    .fold(0.0f64, f64::max);
614
615                FemSolution {
616                    displacements,
617                    von_mises_stress,
618                    max_displacement,
619                    converged: true,
620                }
621            }
622        }
623    }
624
625    /// Solve a steady-state thermal conduction problem.
626    pub fn solve_thermal(&self, mesh: &FemMesh, heat_flux: &[ElementHeatFlux]) -> ThermalSolution {
627        let n_nodes = mesh.node_count();
628
629        // Assemble global thermal conductivity matrix (1 DOF per node)
630        let mut k = vec![vec![0.0f64; n_nodes]; n_nodes];
631        for elem in &mesh.elements {
632            bar1d_thermal_stiffness(n_nodes, elem, &mesh.nodes, &mut k, self.cross_section_area);
633        }
634
635        // Build heat load vector: distribute element heat flux as nodal load
636        let mut q_vec = vec![0.0f64; n_nodes];
637        for hf in heat_flux {
638            if let Some(elem) = mesh.elements.get(hf.element_id) {
639                let n = elem.node_ids.len() as f64;
640                for &nid in &elem.node_ids {
641                    if nid < n_nodes {
642                        q_vec[nid] += hf.q / n;
643                    }
644                }
645            }
646        }
647
648        // Apply temperature boundary conditions (penalty method)
649        let penalty = 1e30;
650        for node in &mesh.nodes {
651            if let Some(ref bc) = node.boundary_condition {
652                if bc.dof == DofType::Temperature {
653                    let dof = node.id;
654                    if dof < n_nodes {
655                        k[dof][dof] += penalty;
656                        q_vec[dof] += penalty * bc.value;
657                    }
658                }
659            }
660        }
661
662        let mut k_copy = k.clone();
663        let mut q_copy = q_vec.clone();
664        let solution = gaussian_elimination(&mut k_copy, &mut q_copy);
665
666        match solution {
667            None => ThermalSolution {
668                temperatures: vec![0.0; n_nodes],
669                heat_flux: vec![(0.0, 0.0); mesh.element_count()],
670                max_temperature: 0.0,
671                converged: false,
672            },
673            Some(temps) => {
674                let element_heat_flux: Vec<(f64, f64)> = mesh
675                    .elements
676                    .iter()
677                    .map(|elem| compute_element_heat_flux(elem, &temps, &mesh.nodes))
678                    .collect();
679
680                let max_temperature = temps.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
681
682                ThermalSolution {
683                    temperatures: temps,
684                    heat_flux: element_heat_flux,
685                    max_temperature,
686                    converged: true,
687                }
688            }
689        }
690    }
691}
692
693// ─────────────────────────────────────────────
694// Post-processing helpers
695// ─────────────────────────────────────────────
696
697/// Compute approximate Von Mises stress for each element.
698fn compute_von_mises(
699    u: &[f64],
700    mesh: &FemMesh,
701    cross_section_area: f64,
702    thickness: f64,
703) -> Vec<f64> {
704    mesh.elements
705        .iter()
706        .map(|elem| element_von_mises(u, elem, &mesh.nodes, cross_section_area, thickness))
707        .collect()
708}
709
710/// Compute Von Mises stress for a single element.
711fn element_von_mises(
712    u: &[f64],
713    elem: &FemElement,
714    nodes: &[FemNode],
715    cross_section_area: f64,
716    _thickness: f64,
717) -> f64 {
718    match elem.element_type {
719        ElementType::Bar1D | ElementType::Beam1D => {
720            if elem.node_ids.len() < 2 {
721                return 0.0;
722            }
723            let ni = elem.node_ids[0];
724            let nj = elem.node_ids[1];
725            let xi = nodes[ni].x;
726            let xj = nodes[nj].x;
727            let yi = nodes[ni].y;
728            let yj = nodes[nj].y;
729            let dx = xj - xi;
730            let dy = yj - yi;
731            let length = (dx * dx + dy * dy).sqrt().max(1e-30);
732            let c = dx / length;
733            let s = dy / length;
734
735            // Axial strain from nodal displacements
736            let u_i = if 2 * ni + 1 < u.len() {
737                c * u[2 * ni] + s * u[2 * ni + 1]
738            } else {
739                0.0
740            };
741            let u_j = if 2 * nj + 1 < u.len() {
742                c * u[2 * nj] + s * u[2 * nj + 1]
743            } else {
744                0.0
745            };
746            let strain = (u_j - u_i) / length;
747            (elem.material.youngs_modulus * strain).abs()
748        }
749        ElementType::Triangle2D | ElementType::Quad2D => {
750            // Simplified: use displacement magnitude divided by element size
751            if elem.node_ids.is_empty() {
752                return 0.0;
753            }
754            let avg_disp: f64 = elem
755                .node_ids
756                .iter()
757                .map(|&nid| {
758                    let dx = u.get(2 * nid).copied().unwrap_or(0.0);
759                    let dy = u.get(2 * nid + 1).copied().unwrap_or(0.0);
760                    (dx * dx + dy * dy).sqrt()
761                })
762                .sum::<f64>()
763                / elem.node_ids.len() as f64;
764
765            // Characteristic element size (average side length)
766            let char_len = cross_section_area.sqrt().max(1e-10);
767            elem.material.youngs_modulus * avg_disp / char_len
768        }
769    }
770}
771
772/// Compute heat flux vector for a bar element.
773fn compute_element_heat_flux(elem: &FemElement, temps: &[f64], nodes: &[FemNode]) -> (f64, f64) {
774    if elem.node_ids.len() < 2 {
775        return (0.0, 0.0);
776    }
777    let ni = elem.node_ids[0];
778    let nj = elem.node_ids[1];
779    if ni >= temps.len() || nj >= temps.len() {
780        return (0.0, 0.0);
781    }
782    let xi = nodes[ni].x;
783    let xj = nodes[nj].x;
784    let yi = nodes[ni].y;
785    let yj = nodes[nj].y;
786    let dx = xj - xi;
787    let dy = yj - yi;
788    let length = (dx * dx + dy * dy).sqrt().max(1e-30);
789    let dt_dl = (temps[nj] - temps[ni]) / length;
790    let k = elem.material.thermal_conductivity;
791    let qx = -k * dt_dl * (dx / length);
792    let qy = -k * dt_dl * (dy / length);
793    (qx, qy)
794}
795
796// ─────────────────────────────────────────────
797// Tests
798// ─────────────────────────────────────────────
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803
804    // ---- Helper ----
805
806    fn steel() -> FemMaterial {
807        FemMaterial {
808            youngs_modulus: 200e9,
809            poissons_ratio: 0.3,
810            thermal_conductivity: 50.0,
811            density: 7850.0,
812        }
813    }
814
815    fn aluminium() -> FemMaterial {
816        FemMaterial {
817            youngs_modulus: 70e9,
818            poissons_ratio: 0.33,
819            thermal_conductivity: 205.0,
820            density: 2700.0,
821        }
822    }
823
824    // ────────────────────────────────────────
825    // Mesh tests
826    // ────────────────────────────────────────
827
828    #[test]
829    fn test_mesh_new_is_empty() {
830        let mesh = FemMesh::new();
831        assert_eq!(mesh.node_count(), 0);
832        assert_eq!(mesh.element_count(), 0);
833    }
834
835    #[test]
836    fn test_mesh_add_nodes() {
837        let mut mesh = FemMesh::new();
838        let id0 = mesh.add_node(0.0, 0.0);
839        let id1 = mesh.add_node(1.0, 0.0);
840        let id2 = mesh.add_node(0.5, 1.0);
841        assert_eq!(id0, 0);
842        assert_eq!(id1, 1);
843        assert_eq!(id2, 2);
844        assert_eq!(mesh.node_count(), 3);
845    }
846
847    #[test]
848    fn test_mesh_add_element_bar() {
849        let mut mesh = FemMesh::new();
850        let n0 = mesh.add_node(0.0, 0.0);
851        let n1 = mesh.add_node(1.0, 0.0);
852        let eid = mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
853        assert_eq!(eid, 0);
854        assert_eq!(mesh.element_count(), 1);
855    }
856
857    #[test]
858    fn test_mesh_add_triangle_element() {
859        let mut mesh = FemMesh::new();
860        let n0 = mesh.add_node(0.0, 0.0);
861        let n1 = mesh.add_node(1.0, 0.0);
862        let n2 = mesh.add_node(0.5, 1.0);
863        mesh.add_element(vec![n0, n1, n2], steel(), ElementType::Triangle2D);
864        assert_eq!(mesh.element_count(), 1);
865    }
866
867    #[test]
868    fn test_mesh_set_boundary_condition() {
869        let mut mesh = FemMesh::new();
870        let n0 = mesh.add_node(0.0, 0.0);
871        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
872        let bc = mesh.nodes[n0]
873            .boundary_condition
874            .as_ref()
875            .expect("BC should be set");
876        assert_eq!(bc.dof, DofType::Displacement);
877        assert_eq!(bc.value, 0.0);
878    }
879
880    #[test]
881    fn test_boundary_condition_temperature() {
882        let mut mesh = FemMesh::new();
883        let n0 = mesh.add_node(0.0, 0.0);
884        mesh.set_boundary_condition(n0, DofType::Temperature, 300.0);
885        let bc = mesh.nodes[n0]
886            .boundary_condition
887            .as_ref()
888            .expect("BC must be set");
889        assert_eq!(bc.dof, DofType::Temperature);
890        assert!((bc.value - 300.0).abs() < 1e-10);
891    }
892
893    #[test]
894    fn test_boundary_condition_pressure() {
895        let mut mesh = FemMesh::new();
896        let n0 = mesh.add_node(0.0, 0.0);
897        mesh.set_boundary_condition(n0, DofType::Pressure, 101325.0);
898        let bc = mesh.nodes[n0]
899            .boundary_condition
900            .as_ref()
901            .expect("BC must be set");
902        assert_eq!(bc.dof, DofType::Pressure);
903        assert!((bc.value - 101325.0).abs() < 1.0);
904    }
905
906    #[test]
907    fn test_boundary_condition_invalid_node_is_noop() {
908        let mut mesh = FemMesh::new();
909        // Node 99 does not exist — should not panic
910        mesh.set_boundary_condition(99, DofType::Displacement, 0.0);
911        assert_eq!(mesh.node_count(), 0);
912    }
913
914    #[test]
915    fn test_mesh_multiple_elements() {
916        let mut mesh = FemMesh::new();
917        let n0 = mesh.add_node(0.0, 0.0);
918        let n1 = mesh.add_node(1.0, 0.0);
919        let n2 = mesh.add_node(2.0, 0.0);
920        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
921        mesh.add_element(vec![n1, n2], aluminium(), ElementType::Bar1D);
922        assert_eq!(mesh.element_count(), 2);
923    }
924
925    // ────────────────────────────────────────
926    // Solver — static structural
927    // ────────────────────────────────────────
928
929    /// 1-D bar fixed at left, force F at right.
930    /// Analytical: u_right = F·L / (E·A)
931    #[test]
932    fn test_bar1d_axial_displacement() {
933        let a = 1e-4; // 1 cm²
934        let e = 200e9; // steel
935        let l = 1.0;
936        let force = 10_000.0; // 10 kN
937
938        let mut mesh = FemMesh::new();
939        let n0 = mesh.add_node(0.0, 0.0);
940        let n1 = mesh.add_node(l, 0.0);
941        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
942        mesh.add_element(
943            vec![n0, n1],
944            FemMaterial {
945                youngs_modulus: e,
946                ..FemMaterial::default()
947            },
948            ElementType::Bar1D,
949        );
950
951        let solver = FemSolver::with_params(a, 0.01);
952        let loads = vec![NodalLoad {
953            node_id: n1,
954            fx: force,
955            fy: 0.0,
956        }];
957        let sol = solver.solve_static(&mesh, &loads);
958
959        assert!(sol.converged, "Solver should converge");
960        let expected = force * l / (e * a);
961        let actual = sol.displacements[n1].0;
962        assert!(
963            (actual - expected).abs() / expected < 0.01,
964            "Expected ~{expected:.6e} m, got {actual:.6e} m"
965        );
966    }
967
968    #[test]
969    fn test_bar1d_zero_force_zero_displacement() {
970        let mut mesh = FemMesh::new();
971        let n0 = mesh.add_node(0.0, 0.0);
972        let n1 = mesh.add_node(1.0, 0.0);
973        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
974        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
975
976        let solver = FemSolver::new();
977        let sol = solver.solve_static(&mesh, &[]);
978        assert!(sol.converged);
979        assert!(sol.max_displacement < 1e-20);
980    }
981
982    #[test]
983    fn test_two_bar_series() {
984        // Two bars in series, all horizontal, fixed at left end, force at right.
985        let a = 1e-4;
986        let e = 200e9;
987        let l = 0.5;
988        let force = 5000.0;
989
990        let mut mesh = FemMesh::new();
991        let n0 = mesh.add_node(0.0, 0.0);
992        let n1 = mesh.add_node(l, 0.0);
993        let n2 = mesh.add_node(2.0 * l, 0.0);
994        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
995        let mat = FemMaterial {
996            youngs_modulus: e,
997            ..FemMaterial::default()
998        };
999        mesh.add_element(vec![n0, n1], mat.clone(), ElementType::Bar1D);
1000        mesh.add_element(vec![n1, n2], mat, ElementType::Bar1D);
1001
1002        let solver = FemSolver::with_params(a, 0.01);
1003        let loads = vec![NodalLoad {
1004            node_id: n2,
1005            fx: force,
1006            fy: 0.0,
1007        }];
1008        let sol = solver.solve_static(&mesh, &loads);
1009        assert!(sol.converged);
1010        // Total deformation = F * (L1 + L2) / (E * A)
1011        let expected = force * (2.0 * l) / (e * a);
1012        let actual = sol.displacements[n2].0;
1013        assert!(
1014            (actual - expected).abs() / expected < 0.01,
1015            "Two-bar series: expected {expected:.6e}, got {actual:.6e}"
1016        );
1017    }
1018
1019    #[test]
1020    fn test_solver_converged_flag() {
1021        let mut mesh = FemMesh::new();
1022        let n0 = mesh.add_node(0.0, 0.0);
1023        let n1 = mesh.add_node(1.0, 0.0);
1024        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
1025        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1026        let solver = FemSolver::new();
1027        let sol = solver.solve_static(&mesh, &[]);
1028        assert!(sol.converged);
1029    }
1030
1031    #[test]
1032    fn test_max_displacement_positive() {
1033        let mut mesh = FemMesh::new();
1034        let n0 = mesh.add_node(0.0, 0.0);
1035        let n1 = mesh.add_node(1.0, 0.0);
1036        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
1037        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1038        let solver = FemSolver::new();
1039        let loads = vec![NodalLoad {
1040            node_id: n1,
1041            fx: 10_000.0,
1042            fy: 0.0,
1043        }];
1044        let sol = solver.solve_static(&mesh, &loads);
1045        assert!(sol.max_displacement > 0.0);
1046    }
1047
1048    #[test]
1049    fn test_von_mises_stress_non_negative() {
1050        let mut mesh = FemMesh::new();
1051        let n0 = mesh.add_node(0.0, 0.0);
1052        let n1 = mesh.add_node(1.0, 0.0);
1053        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
1054        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1055        let solver = FemSolver::new();
1056        let loads = vec![NodalLoad {
1057            node_id: n1,
1058            fx: 10_000.0,
1059            fy: 0.0,
1060        }];
1061        let sol = solver.solve_static(&mesh, &loads);
1062        for &s in &sol.von_mises_stress {
1063            assert!(s >= 0.0, "Von Mises stress must be non-negative");
1064        }
1065    }
1066
1067    #[test]
1068    fn test_triangle2d_mesh_solves() {
1069        let mut mesh = FemMesh::new();
1070        let n0 = mesh.add_node(0.0, 0.0);
1071        let n1 = mesh.add_node(1.0, 0.0);
1072        let n2 = mesh.add_node(0.5, 1.0);
1073        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
1074        mesh.set_boundary_condition(n1, DofType::Displacement, 0.0);
1075        mesh.add_element(vec![n0, n1, n2], steel(), ElementType::Triangle2D);
1076        let solver = FemSolver::with_params(1e-4, 0.01);
1077        let loads = vec![NodalLoad {
1078            node_id: n2,
1079            fx: 0.0,
1080            fy: -500.0,
1081        }];
1082        let sol = solver.solve_static(&mesh, &loads);
1083        assert!(sol.converged);
1084        assert_eq!(sol.displacements.len(), 3);
1085    }
1086
1087    #[test]
1088    fn test_quad2d_mesh_solves() {
1089        let mut mesh = FemMesh::new();
1090        let n0 = mesh.add_node(0.0, 0.0);
1091        let n1 = mesh.add_node(1.0, 0.0);
1092        let n2 = mesh.add_node(1.0, 1.0);
1093        let n3 = mesh.add_node(0.0, 1.0);
1094        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
1095        mesh.set_boundary_condition(n1, DofType::Displacement, 0.0);
1096        mesh.add_element(vec![n0, n1, n2, n3], steel(), ElementType::Quad2D);
1097        let solver = FemSolver::with_params(1e-4, 0.01);
1098        let loads = vec![
1099            NodalLoad {
1100                node_id: n2,
1101                fx: 0.0,
1102                fy: -1000.0,
1103            },
1104            NodalLoad {
1105                node_id: n3,
1106                fx: 0.0,
1107                fy: -1000.0,
1108            },
1109        ];
1110        let sol = solver.solve_static(&mesh, &loads);
1111        assert!(sol.converged);
1112        assert_eq!(sol.displacements.len(), 4);
1113    }
1114
1115    // ────────────────────────────────────────
1116    // Solver — thermal
1117    // ────────────────────────────────────────
1118
1119    /// 1-D rod: T(0) = 100 K, T(L) = 200 K, linear profile expected.
1120    #[test]
1121    fn test_thermal_1d_linear_temperature() {
1122        let mut mesh = FemMesh::new();
1123        let n0 = mesh.add_node(0.0, 0.0);
1124        let n1 = mesh.add_node(1.0, 0.0);
1125        mesh.set_boundary_condition(n0, DofType::Temperature, 100.0);
1126        mesh.set_boundary_condition(n1, DofType::Temperature, 200.0);
1127        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1128
1129        let solver = FemSolver::with_params(1e-4, 0.01);
1130        let sol = solver.solve_thermal(&mesh, &[]);
1131        assert!(sol.converged);
1132        assert!((sol.temperatures[0] - 100.0).abs() < 1.0);
1133        assert!((sol.temperatures[1] - 200.0).abs() < 1.0);
1134    }
1135
1136    #[test]
1137    fn test_thermal_max_temperature() {
1138        let mut mesh = FemMesh::new();
1139        let n0 = mesh.add_node(0.0, 0.0);
1140        let n1 = mesh.add_node(1.0, 0.0);
1141        mesh.set_boundary_condition(n0, DofType::Temperature, 300.0);
1142        mesh.set_boundary_condition(n1, DofType::Temperature, 500.0);
1143        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1144
1145        let solver = FemSolver::with_params(1e-4, 0.01);
1146        let sol = solver.solve_thermal(&mesh, &[]);
1147        assert!(sol.converged);
1148        assert!((sol.max_temperature - 500.0).abs() < 10.0);
1149    }
1150
1151    #[test]
1152    fn test_thermal_heat_flux_direction() {
1153        let mut mesh = FemMesh::new();
1154        let n0 = mesh.add_node(0.0, 0.0);
1155        let n1 = mesh.add_node(1.0, 0.0);
1156        // Higher temperature at left → heat flows right (positive qx)
1157        mesh.set_boundary_condition(n0, DofType::Temperature, 400.0);
1158        mesh.set_boundary_condition(n1, DofType::Temperature, 300.0);
1159        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1160
1161        let solver = FemSolver::with_params(1e-4, 0.01);
1162        let sol = solver.solve_thermal(&mesh, &[]);
1163        assert!(sol.converged);
1164        // qx = -k * dT/dx; dT/dx < 0 → qx > 0
1165        assert!(
1166            sol.heat_flux[0].0 > 0.0,
1167            "Heat should flow from hot to cold (positive qx)"
1168        );
1169    }
1170
1171    #[test]
1172    fn test_thermal_three_node_rod() {
1173        let mut mesh = FemMesh::new();
1174        let n0 = mesh.add_node(0.0, 0.0);
1175        let n1 = mesh.add_node(0.5, 0.0);
1176        let n2 = mesh.add_node(1.0, 0.0);
1177        mesh.set_boundary_condition(n0, DofType::Temperature, 100.0);
1178        mesh.set_boundary_condition(n2, DofType::Temperature, 200.0);
1179        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1180        mesh.add_element(vec![n1, n2], steel(), ElementType::Bar1D);
1181
1182        let solver = FemSolver::with_params(1e-4, 0.01);
1183        let sol = solver.solve_thermal(&mesh, &[]);
1184        assert!(sol.converged);
1185        // Mid-node should be ≈ 150 K (linear)
1186        assert!((sol.temperatures[1] - 150.0).abs() < 5.0);
1187    }
1188
1189    #[test]
1190    fn test_thermal_heat_flux_applied() {
1191        let mut mesh = FemMesh::new();
1192        let n0 = mesh.add_node(0.0, 0.0);
1193        let n1 = mesh.add_node(1.0, 0.0);
1194        mesh.set_boundary_condition(n0, DofType::Temperature, 300.0);
1195        mesh.add_element(vec![n0, n1], steel(), ElementType::Bar1D);
1196
1197        let solver = FemSolver::with_params(1e-4, 0.01);
1198        let hf = vec![ElementHeatFlux {
1199            element_id: 0,
1200            q: 1000.0,
1201        }];
1202        let sol = solver.solve_thermal(&mesh, &hf);
1203        assert!(sol.converged);
1204        // Right node temperature must be >= boundary temperature
1205        assert!(sol.temperatures[1] >= 300.0 - 1.0);
1206    }
1207
1208    #[test]
1209    fn test_element_heat_flux_struct() {
1210        let hf = ElementHeatFlux {
1211            element_id: 3,
1212            q: 500.0,
1213        };
1214        assert_eq!(hf.element_id, 3);
1215        assert!((hf.q - 500.0).abs() < 1e-10);
1216    }
1217
1218    #[test]
1219    fn test_nodal_load_struct() {
1220        let load = NodalLoad {
1221            node_id: 2,
1222            fx: 100.0,
1223            fy: -50.0,
1224        };
1225        assert_eq!(load.node_id, 2);
1226        assert!((load.fx - 100.0).abs() < 1e-10);
1227        assert!((load.fy + 50.0).abs() < 1e-10);
1228    }
1229
1230    #[test]
1231    fn test_fem_material_default() {
1232        let mat = FemMaterial::default();
1233        assert!(mat.youngs_modulus > 0.0);
1234        assert!(mat.poissons_ratio > 0.0 && mat.poissons_ratio < 0.5);
1235        assert!(mat.thermal_conductivity > 0.0);
1236        assert!(mat.density > 0.0);
1237    }
1238
1239    #[test]
1240    fn test_gaussian_elimination_identity() {
1241        // Solve I x = b → x = b
1242        let mut a = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1243        let mut b = vec![3.0, 7.0];
1244        let x = gaussian_elimination(&mut a, &mut b).expect("Should converge");
1245        assert!((x[0] - 3.0).abs() < 1e-10);
1246        assert!((x[1] - 7.0).abs() < 1e-10);
1247    }
1248
1249    #[test]
1250    fn test_gaussian_elimination_2x2() {
1251        // 2x + y = 5
1252        // x  + 3y = 10
1253        // Solution: x = 1, y = 3
1254        let mut a = vec![vec![2.0, 1.0], vec![1.0, 3.0]];
1255        let mut b = vec![5.0, 10.0];
1256        let x = gaussian_elimination(&mut a, &mut b).expect("Should converge");
1257        assert!((x[0] - 1.0).abs() < 1e-10);
1258        assert!((x[1] - 3.0).abs() < 1e-10);
1259    }
1260
1261    #[test]
1262    fn test_beam1d_solves_like_bar() {
1263        let mut mesh = FemMesh::new();
1264        let n0 = mesh.add_node(0.0, 0.0);
1265        let n1 = mesh.add_node(1.0, 0.0);
1266        mesh.set_boundary_condition(n0, DofType::Displacement, 0.0);
1267        mesh.add_element(vec![n0, n1], steel(), ElementType::Beam1D);
1268        let solver = FemSolver::new();
1269        let loads = vec![NodalLoad {
1270            node_id: n1,
1271            fx: 5000.0,
1272            fy: 0.0,
1273        }];
1274        let sol = solver.solve_static(&mesh, &loads);
1275        assert!(sol.converged);
1276        assert!(sol.max_displacement > 0.0);
1277    }
1278}