Skip to main content

oxiphysics_python/world_api/
fem.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! FEM (Finite Element Method) Assembly and Solver.
5
6#![allow(missing_docs)]
7
8// ===========================================================================
9// FEM (Finite Element Method) Assembly and Solver
10// ===========================================================================
11
/// A single 2-node bar (truss) element for FEM assembly.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct FemBarElement {
    /// Global indices of the two end nodes.
    pub nodes: [usize; 2],
    /// Young's modulus times cross-sectional area (EA).
    pub ea: f64,
    /// Undeformed length of the element.
    pub length: f64,
}

impl FemBarElement {
    /// Lengths at or below this value are treated as degenerate.
    const MIN_LENGTH: f64 = 1e-15;

    /// Axial stiffness `EA / L`.
    ///
    /// Returns `0.0` for a degenerate element (length not strictly greater
    /// than [`Self::MIN_LENGTH`], which also covers negative and NaN lengths)
    /// so that assembly and force recovery agree on what "degenerate" means.
    fn axial_stiffness(&self) -> f64 {
        if self.length > Self::MIN_LENGTH {
            self.ea / self.length
        } else {
            0.0
        }
    }
}

/// FEM assembly for a truss structure.
///
/// Assembles a global stiffness matrix from bar elements, applies boundary
/// conditions, and solves with a direct (dense) solver suitable for
/// demonstration/testing with small meshes.
///
/// Only the x-direction DOF (`node * 3`) of each node receives element
/// stiffness; every DOF without stiffness is pinned to zero displacement
/// during assembly.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct PyFemAssembly {
    /// Number of degrees of freedom (nodes × 3 for 3-D).
    pub n_dofs: usize,
    /// Bar elements.
    elements: Vec<FemBarElement>,
    /// External force vector as supplied by the caller (never modified by
    /// assembly or solving).
    forces: Vec<f64>,
    /// Fixed DOF indices (Dirichlet BCs: displacement = 0).
    fixed_dofs: Vec<usize>,
    /// Global stiffness matrix (dense, row-major, n_dofs × n_dofs).
    k_global: Vec<f64>,
    /// Per-DOF flag set during assembly: `true` when the DOF is held at zero
    /// (explicitly fixed, or pinned because it has no stiffness).
    constrained: Vec<bool>,
    /// Displacement solution vector (from the last successful solve).
    displacements: Vec<f64>,
    /// Whether the stiffness matrix matches the current elements and BCs.
    assembled: bool,
    /// Whether `displacements` matches the current elements, BCs and loads.
    solved: bool,
}

impl PyFemAssembly {
    /// Diagonal value written into the row of a DOF that is held at zero.
    const PIN_STIFFNESS: f64 = 1.0e20;
    /// Diagonal magnitude below which a DOF counts as having no stiffness.
    const EMPTY_DIAGONAL_TOL: f64 = 1e-30;
    /// Pivot magnitude below which the system is reported as singular.
    const SINGULAR_PIVOT_TOL: f64 = 1e-15;

    /// Create a new FEM assembly with `n_nodes` 3-D nodes.
    pub fn new(n_nodes: usize) -> Self {
        let n_dofs = n_nodes * 3;
        Self {
            n_dofs,
            elements: Vec::new(),
            forces: vec![0.0; n_dofs],
            fixed_dofs: Vec::new(),
            k_global: vec![0.0; n_dofs * n_dofs],
            constrained: vec![false; n_dofs],
            displacements: vec![0.0; n_dofs],
            assembled: false,
            solved: false,
        }
    }

    /// Map a node index to its x-direction DOF, if that DOF exists.
    ///
    /// Uses checked arithmetic so an absurdly large node index is reported as
    /// out of range instead of overflowing.
    fn axial_dof(node: usize, n_dofs: usize) -> Option<usize> {
        node.checked_mul(3).filter(|&dof| dof < n_dofs)
    }

    /// Add a bar element between two nodes with given EA and undeformed length.
    ///
    /// Elements that reference a node outside the mesh are stored but
    /// contribute no stiffness.
    pub fn add_bar_element(&mut self, node_a: usize, node_b: usize, ea: f64, length: f64) {
        self.elements.push(FemBarElement {
            nodes: [node_a, node_b],
            ea,
            length,
        });
        self.assembled = false;
        self.solved = false;
    }

    /// Apply an external force to a DOF (node_index * 3 + direction).
    ///
    /// Forces accumulate; out-of-range DOFs are ignored.
    pub fn apply_force(&mut self, dof: usize, force: f64) {
        if dof < self.n_dofs {
            self.forces[dof] += force;
            // The stored displacements no longer correspond to the loads.
            self.solved = false;
        }
    }

    /// Fix a DOF (zero displacement Dirichlet BC).
    ///
    /// Out-of-range DOFs and DOFs that are already fixed are ignored.
    pub fn fix_dof(&mut self, dof: usize) {
        if dof < self.n_dofs && !self.fixed_dofs.contains(&dof) {
            self.fixed_dofs.push(dof);
            // The boundary conditions are baked into `k_global`, so a new
            // constraint must force a re-assembly before the next solve.
            self.assembled = false;
            self.solved = false;
        }
    }

    /// Assemble the global stiffness matrix from all bar elements.
    ///
    /// For each bar element the 2×2 (1-D axial) stiffness matrix contribution
    /// `k_e = EA/L * [[1, -1], [-1, 1]]` is added to the x-direction DOF
    /// rows/columns of the element's two nodes. Afterwards every DOF without
    /// stiffness and every fixed DOF is constrained to zero displacement.
    ///
    /// The external force vector is left untouched; constrained DOFs are
    /// recorded separately and their load is ignored when solving.
    pub fn assemble(&mut self) {
        let n = self.n_dofs;
        self.k_global.fill(0.0);
        self.constrained.fill(false);

        for elem in &self.elements {
            // Axial DOFs: use DOF 0 of each node (x-direction simplified).
            let ends = (
                Self::axial_dof(elem.nodes[0], n),
                Self::axial_dof(elem.nodes[1], n),
            );
            if let (Some(dof_a), Some(dof_b)) = ends {
                let ea_l = elem.axial_stiffness();
                self.k_global[dof_a * n + dof_a] += ea_l;
                self.k_global[dof_a * n + dof_b] -= ea_l;
                self.k_global[dof_b * n + dof_a] -= ea_l;
                self.k_global[dof_b * n + dof_b] += ea_l;
            }
        }

        // DOFs with no stiffness contribution would make the system singular,
        // so pin them to zero displacement.
        for i in 0..n {
            if self.k_global[i * n + i].abs() < Self::EMPTY_DIAGONAL_TOL {
                self.k_global[i * n + i] = Self::PIN_STIFFNESS;
                self.constrained[i] = true;
            }
        }

        // Dirichlet BCs: replace the DOF's equation by `PIN_STIFFNESS * u = 0`.
        // Columns are left in place, which is harmless because u is zero.
        for &dof in &self.fixed_dofs {
            if dof < n {
                self.k_global[dof * n..(dof + 1) * n].fill(0.0);
                self.k_global[dof * n + dof] = Self::PIN_STIFFNESS;
                self.constrained[dof] = true;
            }
        }

        self.assembled = true;
    }

    /// Solve the linear system K·u = f using Gaussian elimination with
    /// partial pivoting.
    ///
    /// The stiffness matrix is (re-)assembled first if it is out of date.
    /// Returns `true` on success. Returns `false` if there are no DOFs, the
    /// matrix is singular, or the solution contains non-finite values; in
    /// those cases the previously stored displacements are left unchanged.
    pub fn solve(&mut self) -> bool {
        self.solved = false;
        if !self.assembled {
            self.assemble();
        }
        let n = self.n_dofs;
        if n == 0 {
            return false;
        }

        // Build an augmented matrix [K | f]; loads on constrained DOFs are
        // dropped here rather than being erased from `self.forces`.
        let stride = n + 1;
        let mut aug = vec![0.0_f64; n * stride];
        for (i, row) in aug.chunks_exact_mut(stride).enumerate() {
            row[..n].copy_from_slice(&self.k_global[i * n..(i + 1) * n]);
            row[n] = if self.constrained[i] {
                0.0
            } else {
                self.forces[i]
            };
        }

        // Forward elimination
        for col in 0..n {
            // Partial pivoting: pick the largest magnitude in this column.
            let mut max_row = col;
            let mut max_val = aug[col * stride + col].abs();
            for row in (col + 1)..n {
                let v = aug[row * stride + col].abs();
                if v > max_val {
                    max_val = v;
                    max_row = row;
                }
            }
            if max_val < Self::SINGULAR_PIVOT_TOL {
                return false; // singular
            }
            if max_row != col {
                for j in 0..=n {
                    aug.swap(col * stride + j, max_row * stride + j);
                }
            }
            let pivot = aug[col * stride + col];
            for row in (col + 1)..n {
                let factor = aug[row * stride + col] / pivot;
                for j in col..=n {
                    let sub = factor * aug[col * stride + j];
                    aug[row * stride + j] -= sub;
                }
            }
        }

        // Back substitution
        let mut u = vec![0.0_f64; n];
        for i in (0..n).rev() {
            let mut sum = aug[i * stride + n];
            for j in (i + 1)..n {
                sum -= aug[i * stride + j] * u[j];
            }
            u[i] = sum / aug[i * stride + i];
        }

        // NaN/inf inputs (e.g. a NaN `ea`) slip past the pivot test; do not
        // report such a result as a valid solution.
        if u.iter().any(|v| !v.is_finite()) {
            return false;
        }

        self.displacements = u;
        self.solved = true;
        true
    }

    /// Get the displacement at DOF `dof`, or `0.0` if out of bounds.
    ///
    /// The value comes from the most recent successful [`Self::solve`]
    /// (all zeros before the first one).
    pub fn displacement(&self, dof: usize) -> f64 {
        self.displacements.get(dof).copied().unwrap_or(0.0)
    }

    /// Get all displacements as a slice.
    ///
    /// The values come from the most recent successful [`Self::solve`]
    /// (all zeros before the first one).
    pub fn displacements(&self) -> &[f64] {
        &self.displacements
    }

    /// Compute the axial force in element `elem_idx`.
    ///
    /// The force is `EA/L * (u_b - u_a)` using the x-direction displacements
    /// of the element's end nodes (positive = tension). End nodes outside the
    /// mesh are treated as having zero displacement, and degenerate elements
    /// carry zero force.
    ///
    /// Returns `None` if the index is out of bounds or the system is unsolved
    /// (never solved, last solve failed, or the model changed since).
    pub fn element_force(&self, elem_idx: usize) -> Option<f64> {
        if !self.solved {
            return None;
        }
        let elem = self.elements.get(elem_idx)?;
        let stiffness = elem.axial_stiffness();
        if stiffness == 0.0 {
            return Some(0.0);
        }
        let axial_displacement = |node: usize| {
            node.checked_mul(3)
                .and_then(|dof| self.displacements.get(dof))
                .copied()
                .unwrap_or(0.0)
        };
        let ua = axial_displacement(elem.nodes[0]);
        let ub = axial_displacement(elem.nodes[1]);
        Some(stiffness * (ub - ua))
    }
}