Skip to main content

oxirs_physics/
thermal_analysis.rs

1//! Thermal finite-element analysis.
2//!
3//! Implements triangular element heat-conduction FEM with Dirichlet, Neumann,
4//! and Robin (convection) boundary conditions. Uses Gaussian elimination to
5//! solve K·T = F without any external linear-algebra dependencies.
6
7// ──────────────────────────────────────────────────────────────────────────────
8// Types
9// ──────────────────────────────────────────────────────────────────────────────
10
/// A single vertex of the thermal mesh.
#[derive(Debug, Clone)]
pub struct ThermalNode {
    /// Node identifier. Elements and solver results address nodes by their
    /// position in `ThermalMesh::nodes`, so keep this equal to that position.
    pub id: usize,
    /// X coordinate (meters).
    pub x: f64,
    /// Y coordinate (meters).
    pub y: f64,
    /// Starting temperature, or the prescribed value when `is_boundary` is set
    /// (Kelvin or °C — the solver does not convert units).
    pub temperature: f64,
    /// Set when this node carries a Dirichlet (fixed-temperature) condition.
    pub is_boundary: bool,
}

impl ThermalNode {
    /// Build a free (unconstrained) node at `(x, y)` whose temperature field
    /// starts at `0.0`.
    pub fn interior(id: usize, x: f64, y: f64) -> Self {
        Self {
            is_boundary: false,
            ..Self::boundary(id, x, y, 0.0)
        }
    }

    /// Build a node at `(x, y)` whose temperature is held at `temperature`.
    pub fn boundary(id: usize, x: f64, y: f64, temperature: f64) -> Self {
        Self {
            id,
            x,
            y,
            temperature,
            is_boundary: true,
        }
    }
}
49
/// Linear three-node triangle (CST – Constant Strain Triangle) for heat
/// conduction.
#[derive(Debug, Clone)]
pub struct ThermalElement {
    /// Positions of the three corner nodes in `ThermalMesh::nodes`.
    /// Counter-clockwise ordering is conventional; the solver's formulas use
    /// the absolute area for stiffness and the signed area consistently for
    /// gradients, so clockwise triangles give the same results.
    pub nodes: [usize; 3],
    /// Isotropic thermal conductivity (W/m·K).
    pub conductivity: f64,
    /// Out-of-plane thickness (m) for 2-D plane problems; scales the element
    /// conductivity matrix.
    pub thickness: f64,
}
60
/// Boundary conditions understood by the thermal solver. Each one targets a
/// single node by its position in `ThermalMesh::nodes`.
#[derive(Debug, Clone)]
pub enum ThermalBc {
    /// Dirichlet condition: hold node `node_id` at temperature `temp`.
    DirichletTemp { node_id: usize, temp: f64 },
    /// Neumann condition: heat input at a node (positive = into the domain).
    HeatFlux {
        node_id: usize,
        /// Nominally a heat flux (W/m²). The solver adds it unchanged to the
        /// node's load-vector entry; no edge length or thickness is applied,
        /// so it effectively acts as a nodal heat input.
        flux: f64,
    },
    /// Robin condition: convective exchange with an ambient medium.
    Convection {
        node_id: usize,
        /// Convection coefficient (W/m²·K). Like `flux`, it is applied as a
        /// nodal (lumped) term without any area weighting.
        h: f64,
        /// Ambient temperature, in the same unit as the nodal temperatures.
        t_inf: f64,
    },
}
81
82/// The thermal mesh: nodes + elements.
83#[derive(Debug, Clone)]
84pub struct ThermalMesh {
85    pub nodes: Vec<ThermalNode>,
86    pub elements: Vec<ThermalElement>,
87}
88
89impl ThermalMesh {
90    /// Create a mesh from node and element lists.
91    pub fn new(nodes: Vec<ThermalNode>, elements: Vec<ThermalElement>) -> Self {
92        Self { nodes, elements }
93    }
94
95    /// Number of nodes.
96    pub fn n_nodes(&self) -> usize {
97        self.nodes.len()
98    }
99}
100
/// Failure modes reported by the thermal solver.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SolverError {
    /// Elimination hit a (numerically) zero pivot: the assembled system is singular.
    SingularMatrix,
    /// The mesh cannot be solved; the payload says why.
    InvalidMesh(String),
    /// A node index outside the mesh's node list was referenced.
    InvalidNodeIndex(usize),
}

impl std::fmt::Display for SolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SingularMatrix => f.write_str("global conductivity matrix is singular"),
            Self::InvalidMesh(reason) => write!(f, "invalid mesh: {reason}"),
            Self::InvalidNodeIndex(index) => write!(f, "invalid node index: {index}"),
        }
    }
}

impl std::error::Error for SolverError {}
123
/// Everything produced by a full thermal analysis run.
#[derive(Debug, Clone)]
pub struct ThermalResult {
    /// Solved temperature at every node, indexed like `ThermalMesh::nodes`.
    pub temperatures: Vec<f64>,
    /// Largest entry of `temperatures`.
    pub max_temp: f64,
    /// Smallest entry of `temperatures`.
    pub min_temp: f64,
    /// Constant heat-flux vector `(qx, qy)` of each element, indexed like
    /// `ThermalMesh::elements`.
    pub heat_fluxes: Vec<(f64, f64)>,
}
136
137// ──────────────────────────────────────────────────────────────────────────────
138// ThermalSolver
139// ──────────────────────────────────────────────────────────────────────────────
140
141/// Thermal FEM solver.
142pub struct ThermalSolver;
143
144impl ThermalSolver {
145    /// Create a new solver instance.
146    pub fn new() -> Self {
147        Self
148    }
149
150    /// Assemble the global conductivity matrix K (n×n dense).
151    pub fn assemble_conductivity_matrix(
152        &self,
153        mesh: &ThermalMesh,
154    ) -> Result<Vec<Vec<f64>>, SolverError> {
155        let n = mesh.n_nodes();
156        let mut k = vec![vec![0.0; n]; n];
157
158        for elem in &mesh.elements {
159            let ke = self.element_conductivity_matrix(mesh, elem)?;
160            for (local_i, &global_i) in elem.nodes.iter().enumerate() {
161                for (local_j, &global_j) in elem.nodes.iter().enumerate() {
162                    k[global_i][global_j] += ke[local_i][local_j];
163                }
164            }
165        }
166        Ok(k)
167    }
168
169    /// Assemble the global load vector F (length n).
170    pub fn assemble_load_vector(&self, mesh: &ThermalMesh, bcs: &[ThermalBc]) -> Vec<f64> {
171        let n = mesh.n_nodes();
172        let mut f = vec![0.0; n];
173
174        for bc in bcs {
175            match bc {
176                ThermalBc::HeatFlux { node_id, flux } => {
177                    if *node_id < n {
178                        f[*node_id] += flux;
179                    }
180                }
181                ThermalBc::Convection { node_id, h, t_inf } => {
182                    if *node_id < n {
183                        f[*node_id] += h * t_inf;
184                    }
185                }
186                ThermalBc::DirichletTemp { .. } => {
187                    // Applied via penalty method in solve()
188                }
189            }
190        }
191        f
192    }
193
194    /// Solve K·T = F for nodal temperatures.
195    ///
196    /// Dirichlet BCs are enforced with the large-number (penalty) method.
197    pub fn solve(&self, mesh: &ThermalMesh, bcs: &[ThermalBc]) -> Result<Vec<f64>, SolverError> {
198        if mesh.nodes.is_empty() {
199            return Err(SolverError::InvalidMesh("empty mesh".to_string()));
200        }
201
202        let n = mesh.n_nodes();
203        let mut k = self.assemble_conductivity_matrix(mesh)?;
204        let mut f = self.assemble_load_vector(mesh, bcs);
205
206        // Apply Dirichlet BCs via penalty method (large number α)
207        let alpha = self.penalty_factor(&k);
208
209        // Also apply prescribed temperatures from node.is_boundary
210        for node in &mesh.nodes {
211            if node.is_boundary {
212                k[node.id][node.id] += alpha;
213                f[node.id] += alpha * node.temperature;
214            }
215        }
216
217        // Dirichlet BCs from BC list override node.temperature
218        for bc in bcs {
219            if let ThermalBc::DirichletTemp { node_id, temp } = bc {
220                if *node_id < n {
221                    k[*node_id][*node_id] += alpha;
222                    f[*node_id] += alpha * temp;
223                }
224            }
225        }
226
227        // Convection BCs modify the diagonal of K
228        for bc in bcs {
229            if let ThermalBc::Convection { node_id, h, .. } = bc {
230                if *node_id < n {
231                    k[*node_id][*node_id] += h;
232                }
233            }
234        }
235
236        gaussian_elimination(&mut k, &mut f)
237    }
238
239    /// Compute the element heat flux (qx, qy) from solved temperatures.
240    ///
241    /// For a CST element: q = −k · B · T_e
242    pub fn heat_flux_at_element(
243        &self,
244        mesh: &ThermalMesh,
245        temps: &[f64],
246        elem_id: usize,
247    ) -> (f64, f64) {
248        let elem = &mesh.elements[elem_id];
249        let [i, j, m] = elem.nodes;
250
251        let (xi, yi) = (mesh.nodes[i].x, mesh.nodes[i].y);
252        let (xj, yj) = (mesh.nodes[j].x, mesh.nodes[j].y);
253        let (xm, ym) = (mesh.nodes[m].x, mesh.nodes[m].y);
254
255        // B matrix rows (1 / 2A)
256        let two_area = (xj - xi) * (ym - yi) - (xm - xi) * (yj - yi);
257        if two_area.abs() < 1e-15 {
258            return (0.0, 0.0);
259        }
260
261        // dN/dx and dN/dy for the three shape functions
262        let b_x = [(yj - ym), (ym - yi), (yi - yj)];
263        let b_y = [(xm - xj), (xi - xm), (xj - xi)];
264
265        let ti = temps[i];
266        let tj = temps[j];
267        let tm = temps[m];
268        let t_e = [ti, tj, tm];
269
270        let grad_t_x: f64 = (b_x[0] * t_e[0] + b_x[1] * t_e[1] + b_x[2] * t_e[2]) / two_area;
271        let grad_t_y: f64 = (b_y[0] * t_e[0] + b_y[1] * t_e[1] + b_y[2] * t_e[2]) / two_area;
272
273        // q = −k ∇T
274        let qx = -elem.conductivity * grad_t_x;
275        let qy = -elem.conductivity * grad_t_y;
276        (qx, qy)
277    }
278
279    /// Solve and return a full `ThermalResult`.
280    pub fn analyze(
281        &self,
282        mesh: &ThermalMesh,
283        bcs: &[ThermalBc],
284    ) -> Result<ThermalResult, SolverError> {
285        let temperatures = self.solve(mesh, bcs)?;
286        let max_temp = temperatures
287            .iter()
288            .cloned()
289            .fold(f64::NEG_INFINITY, f64::max);
290        let min_temp = temperatures.iter().cloned().fold(f64::INFINITY, f64::min);
291        let heat_fluxes: Vec<(f64, f64)> = (0..mesh.elements.len())
292            .map(|i| self.heat_flux_at_element(mesh, &temperatures, i))
293            .collect();
294        Ok(ThermalResult {
295            temperatures,
296            max_temp,
297            min_temp,
298            heat_fluxes,
299        })
300    }
301
302    // ── Private helpers ───────────────────────────────────────────────────────
303
304    /// 3×3 element conductivity matrix for a CST element.
305    fn element_conductivity_matrix(
306        &self,
307        mesh: &ThermalMesh,
308        elem: &ThermalElement,
309    ) -> Result<[[f64; 3]; 3], SolverError> {
310        let [i, j, m] = elem.nodes;
311        let n = mesh.n_nodes();
312        for &idx in &[i, j, m] {
313            if idx >= n {
314                return Err(SolverError::InvalidNodeIndex(idx));
315            }
316        }
317
318        let (xi, yi) = (mesh.nodes[i].x, mesh.nodes[i].y);
319        let (xj, yj) = (mesh.nodes[j].x, mesh.nodes[j].y);
320        let (xm, ym) = (mesh.nodes[m].x, mesh.nodes[m].y);
321
322        let two_area = (xj - xi) * (ym - yi) - (xm - xi) * (yj - yi);
323        if two_area.abs() < 1e-15 {
324            // Degenerate (zero-area) element: contribute nothing
325            return Ok([[0.0; 3]; 3]);
326        }
327        let area = two_area.abs() / 2.0;
328
329        // Shape-function gradient components (constant over CST element)
330        let b_x = [(yj - ym), (ym - yi), (yi - yj)];
331        let b_y = [(xm - xj), (xi - xm), (xj - xi)];
332
333        let k_factor = elem.conductivity * elem.thickness * area / (two_area * two_area);
334
335        let mut ke = [[0.0; 3]; 3];
336        for r in 0..3 {
337            for c in 0..3 {
338                ke[r][c] = k_factor * (b_x[r] * b_x[c] + b_y[r] * b_y[c]);
339            }
340        }
341        Ok(ke)
342    }
343
344    /// Choose a penalty factor α ≈ 10⁶ × max(|K_{ii}|).
345    fn penalty_factor(&self, k: &[Vec<f64>]) -> f64 {
346        let max_diag = k
347            .iter()
348            .enumerate()
349            .map(|(i, row)| row[i].abs())
350            .fold(0.0_f64, f64::max);
351        if max_diag < 1e-15 {
352            1e6_f64
353        } else {
354            max_diag * 1e6
355        }
356    }
357}
358
359impl Default for ThermalSolver {
360    fn default() -> Self {
361        Self::new()
362    }
363}
364
365// ──────────────────────────────────────────────────────────────────────────────
366// Gaussian elimination
367// ──────────────────────────────────────────────────────────────────────────────
368
369/// Solve A·x = b via partial-pivot Gaussian elimination.
370/// Modifies `a` and `b` in place; returns the solution vector.
371fn gaussian_elimination(a: &mut [Vec<f64>], b: &mut [f64]) -> Result<Vec<f64>, SolverError> {
372    let n = b.len();
373
374    for col in 0..n {
375        // Partial pivot
376        let (pivot_row, _) = (col..n)
377            .map(|r| (r, a[r][col].abs()))
378            .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap_or(std::cmp::Ordering::Equal))
379            .unwrap_or((col, 0.0));
380
381        if a[pivot_row][col].abs() < 1e-15 {
382            return Err(SolverError::SingularMatrix);
383        }
384
385        a.swap(col, pivot_row);
386        b.swap(col, pivot_row);
387
388        let pivot = a[col][col];
389        for row in col + 1..n {
390            let factor = a[row][col] / pivot;
391            b[row] -= factor * b[col];
392            // row > col always here; split to satisfy borrow checker
393            let (upper, lower) = a.split_at_mut(row);
394            for (rv, cv) in lower[0][col..n].iter_mut().zip(upper[col][col..n].iter()) {
395                *rv -= factor * cv;
396            }
397        }
398    }
399
400    // Back-substitution
401    let mut x = vec![0.0; n];
402    for i in (0..n).rev() {
403        let mut s = b[i];
404        for j in i + 1..n {
405            s -= a[i][j] * x[j];
406        }
407        x[i] = s / a[i][i];
408    }
409    Ok(x)
410}
411
412// ──────────────────────────────────────────────────────────────────────────────
// Test mesh helpers
414// ──────────────────────────────────────────────────────────────────────────────
415
/// Two-triangle mesh of the unit square with uniform conductivity `k` and unit
/// thickness; every node is interior (no prescribed temperature).
///
/// Nodes: (0,0), (1,0), (1,1), (0,1)
/// Elements: [(0,1,2), (0,2,3)]
#[cfg(test)]
fn square_mesh(k: f64) -> ThermalMesh {
    const CORNERS: [(f64, f64); 4] = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)];
    const TRIANGLES: [[usize; 3]; 2] = [[0, 1, 2], [0, 2, 3]];

    let nodes = CORNERS
        .iter()
        .enumerate()
        .map(|(id, &(x, y))| ThermalNode::interior(id, x, y))
        .collect();
    let elements = TRIANGLES
        .iter()
        .map(|&nodes| ThermalElement {
            nodes,
            conductivity: k,
            thickness: 1.0,
        })
        .collect();
    ThermalMesh::new(nodes, elements)
}
442
443// ──────────────────────────────────────────────────────────────────────────────
444// Tests
445// ──────────────────────────────────────────────────────────────────────────────
446
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Dirichlet boundary conditions built from `(node_id, temp)` pairs.
    fn dirichlet(pairs: &[(usize, f64)]) -> Vec<ThermalBc> {
        pairs
            .iter()
            .map(|&(node_id, temp)| ThermalBc::DirichletTemp { node_id, temp })
            .collect()
    }

    /// For `square_mesh`: bottom edge (nodes 0, 1) held at `bottom`, top edge
    /// (nodes 2, 3) held at `top`.
    fn bottom_top_bcs(bottom: f64, top: f64) -> Vec<ThermalBc> {
        dirichlet(&[(0, bottom), (1, bottom), (2, top), (3, top)])
    }

    // ── ThermalNode ───────────────────────────────────────────────────────────

    #[test]
    fn test_node_interior() {
        let n = ThermalNode::interior(0, 1.0, 2.0);
        assert_eq!(n.id, 0);
        assert!(!n.is_boundary);
        assert_eq!(n.temperature, 0.0);
    }

    #[test]
    fn test_node_boundary() {
        let n = ThermalNode::boundary(1, 0.0, 0.0, 300.0);
        assert!(n.is_boundary);
        assert_eq!(n.temperature, 300.0);
    }

    // ── ThermalMesh ───────────────────────────────────────────────────────────

    #[test]
    fn test_mesh_n_nodes() {
        let mesh = square_mesh(10.0);
        assert_eq!(mesh.n_nodes(), 4);
    }

    #[test]
    fn test_mesh_n_elements() {
        let mesh = square_mesh(10.0);
        assert_eq!(mesh.elements.len(), 2);
    }

    // ── SolverError ───────────────────────────────────────────────────────────

    #[test]
    fn test_error_display_singular() {
        let e = SolverError::SingularMatrix;
        assert!(format!("{e}").contains("singular"));
    }

    #[test]
    fn test_error_display_invalid_mesh() {
        let e = SolverError::InvalidMesh("empty".to_string());
        assert!(format!("{e}").contains("empty"));
    }

    #[test]
    fn test_error_display_invalid_node_index() {
        let e = SolverError::InvalidNodeIndex(7);
        assert_eq!(e.to_string(), "invalid node index: 7");
    }

    #[test]
    fn test_error_is_std_error() {
        let e: Box<dyn std::error::Error> = Box::new(SolverError::SingularMatrix);
        assert!(!e.to_string().is_empty());
    }

    // ── assemble_conductivity_matrix ──────────────────────────────────────────

    #[test]
    fn test_conductivity_matrix_size() {
        let mesh = square_mesh(10.0);
        let solver = ThermalSolver::new();
        let k = solver
            .assemble_conductivity_matrix(&mesh)
            .expect("should succeed");
        assert_eq!(k.len(), 4);
        assert_eq!(k[0].len(), 4);
    }

    #[test]
    fn test_conductivity_matrix_symmetric() {
        let mesh = square_mesh(5.0);
        let solver = ThermalSolver::new();
        let k = solver
            .assemble_conductivity_matrix(&mesh)
            .expect("should succeed");
        for (i, row_i) in k.iter().enumerate() {
            for (j, &kij) in row_i.iter().enumerate() {
                let kji = k[j][i];
                assert!(
                    (kij - kji).abs() < 1e-12,
                    "K[{i}][{j}] = {kij} ≠ K[{j}][{i}] = {kji}",
                );
            }
        }
    }

    #[test]
    fn test_conductivity_matrix_positive_diagonal() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let k = solver
            .assemble_conductivity_matrix(&mesh)
            .expect("should succeed");
        for (i, row_i) in k.iter().enumerate() {
            assert!(row_i[i] >= 0.0, "K[{i}][{i}] should be non-negative");
        }
    }

    #[test]
    fn test_conductivity_matrix_row_sum_near_zero() {
        // For a pure conductivity matrix (no BCs) the row sums should be 0.
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let k = solver
            .assemble_conductivity_matrix(&mesh)
            .expect("should succeed");
        for (i, row) in k.iter().enumerate() {
            let sum: f64 = row.iter().sum();
            assert!(sum.abs() < 1e-10, "Row {i} sum = {sum} should be ~0");
        }
    }

    #[test]
    fn test_conductivity_matrix_invalid_node_index() {
        let mut mesh = square_mesh(1.0);
        mesh.elements[1].nodes = [0, 2, 9];
        let err = ThermalSolver::new()
            .assemble_conductivity_matrix(&mesh)
            .unwrap_err();
        assert_eq!(err, SolverError::InvalidNodeIndex(9));
    }

    // ── assemble_load_vector ──────────────────────────────────────────────────

    #[test]
    fn test_load_vector_zeros_without_bcs() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let f = solver.assemble_load_vector(&mesh, &[]);
        assert!(f.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_load_vector_heat_flux() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let bcs = vec![ThermalBc::HeatFlux {
            node_id: 0,
            flux: 100.0,
        }];
        let f = solver.assemble_load_vector(&mesh, &bcs);
        assert!((f[0] - 100.0).abs() < 1e-10);
    }

    #[test]
    fn test_load_vector_convection() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let bcs = vec![ThermalBc::Convection {
            node_id: 1,
            h: 20.0,
            t_inf: 300.0,
        }];
        let f = solver.assemble_load_vector(&mesh, &bcs);
        assert!((f[1] - 20.0 * 300.0).abs() < 1e-10);
    }

    #[test]
    fn test_load_vector_dirichlet_not_in_f() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let f = solver.assemble_load_vector(&mesh, &dirichlet(&[(0, 500.0)]));
        // Dirichlet is handled in solve(), not assemble_load_vector
        assert_eq!(f[0], 0.0);
    }

    // ── solve ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_solve_uniform_temperature() {
        // If all boundary nodes are at 100°C the interior should also be 100°C.
        let nodes = vec![
            ThermalNode::boundary(0, 0.0, 0.0, 100.0),
            ThermalNode::boundary(1, 1.0, 0.0, 100.0),
            ThermalNode::interior(2, 0.5, 0.5),
            ThermalNode::boundary(3, 0.0, 1.0, 100.0),
        ];
        let elements = vec![
            ThermalElement {
                nodes: [0, 1, 2],
                conductivity: 1.0,
                thickness: 1.0,
            },
            ThermalElement {
                nodes: [0, 2, 3],
                conductivity: 1.0,
                thickness: 1.0,
            },
        ];
        let mesh = ThermalMesh::new(nodes, elements);
        let solver = ThermalSolver::new();
        let bcs = dirichlet(&[(0, 100.0), (1, 100.0), (3, 100.0)]);
        let temps = solver.solve(&mesh, &bcs).expect("should succeed");
        // Interior node 2 should be ≈ 100°C
        assert!(
            (temps[2] - 100.0).abs() < 1.0,
            "Interior temp = {}",
            temps[2]
        );
    }

    #[test]
    fn test_solve_returns_n_temps() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let temps = solver
            .solve(&mesh, &bottom_top_bcs(0.0, 100.0))
            .expect("should succeed");
        assert_eq!(temps.len(), 4);
    }

    #[test]
    fn test_solve_dirichlet_nodes_match_prescribed() {
        let mesh = square_mesh(10.0);
        let solver = ThermalSolver::new();
        let temps = solver
            .solve(&mesh, &bottom_top_bcs(0.0, 200.0))
            .expect("should succeed");
        // Penalty method → very close to prescribed values
        assert!(temps[0].abs() < 1.0, "node 0 ≈ 0: {}", temps[0]);
        assert!(temps[1].abs() < 1.0, "node 1 ≈ 0: {}", temps[1]);
        assert!((temps[2] - 200.0).abs() < 1.0, "node 2 ≈ 200: {}", temps[2]);
        assert!((temps[3] - 200.0).abs() < 1.0, "node 3 ≈ 200: {}", temps[3]);
    }

    #[test]
    fn test_solve_empty_mesh_error() {
        let mesh = ThermalMesh::new(vec![], vec![]);
        let solver = ThermalSolver::new();
        assert!(matches!(
            solver.solve(&mesh, &[]),
            Err(SolverError::InvalidMesh(_))
        ));
    }

    #[test]
    fn test_solve_gradient_monotone() {
        // Linear gradient: nodes 0,1 at 0°C and nodes 2,3 at 100°C.
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let temps = solver
            .solve(&mesh, &bottom_top_bcs(0.0, 100.0))
            .expect("should succeed");
        // Nodes 2,3 are hotter than 0,1
        assert!(temps[2] > temps[0]);
        assert!(temps[3] > temps[1]);
    }

    // ── heat_flux_at_element ──────────────────────────────────────────────────

    #[test]
    fn test_heat_flux_uniform_temperature_is_zero() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let temps = vec![100.0; 4];
        let (qx, qy) = solver.heat_flux_at_element(&mesh, &temps, 0);
        assert!(qx.abs() < 1e-10, "qx should be ~0, got {qx}");
        assert!(qy.abs() < 1e-10, "qy should be ~0, got {qy}");
    }

    #[test]
    fn test_heat_flux_direction() {
        // Temperature increases in the y direction → expect nonzero qy.
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let temps = solver
            .solve(&mesh, &bottom_top_bcs(0.0, 100.0))
            .expect("should succeed");
        let (_, qy) = solver.heat_flux_at_element(&mesh, &temps, 0);
        // qy should be negative (heat flows from high T to low T, i.e. downward)
        assert!(qy < 0.0, "qy = {qy} should be < 0");
    }

    #[test]
    fn test_heat_flux_linear_field_is_exact() {
        // T = 100·y is linear, so CST reproduces it exactly: q = −k·(0, 100).
        let mesh = square_mesh(2.0);
        let solver = ThermalSolver::new();
        let temps = vec![0.0, 0.0, 100.0, 100.0];
        for elem_id in 0..mesh.elements.len() {
            let (qx, qy) = solver.heat_flux_at_element(&mesh, &temps, elem_id);
            assert!(qx.abs() < 1e-9, "element {elem_id}: qx = {qx}");
            assert!((qy + 200.0).abs() < 1e-9, "element {elem_id}: qy = {qy}");
        }
    }

    #[test]
    fn test_heat_flux_all_elements() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let temps = vec![0.0, 0.0, 100.0, 100.0];
        for elem_id in 0..mesh.elements.len() {
            let (qx, qy) = solver.heat_flux_at_element(&mesh, &temps, elem_id);
            assert!(qx.is_finite());
            assert!(qy.is_finite());
        }
    }

    // ── analyze ───────────────────────────────────────────────────────────────

    #[test]
    fn test_analyze_result_fields() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let result = solver
            .analyze(&mesh, &bottom_top_bcs(20.0, 80.0))
            .expect("should succeed");
        assert_eq!(result.temperatures.len(), 4);
        assert!(result.max_temp >= result.min_temp);
        assert_eq!(result.heat_fluxes.len(), 2);
    }

    #[test]
    fn test_analyze_max_min_temp() {
        let mesh = square_mesh(1.0);
        let solver = ThermalSolver::new();
        let result = solver
            .analyze(&mesh, &bottom_top_bcs(0.0, 100.0))
            .expect("should succeed");
        assert!(result.max_temp >= 90.0);
        assert!(result.min_temp <= 10.0);
    }

    // ── Gaussian elimination ───────────────────────────────────────────────────

    #[test]
    fn test_gaussian_elimination_2x2() {
        // 2x - y = 3 ; x + y = 3 → x=2, y=1
        let mut a = vec![vec![2.0_f64, -1.0], vec![1.0, 1.0]];
        let mut b = vec![3.0_f64, 3.0];
        let x = gaussian_elimination(&mut a, &mut b).expect("should succeed");
        assert!((x[0] - 2.0).abs() < 1e-10);
        assert!((x[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_gaussian_elimination_3x3() {
        // x + y + z = 6 ; x - y + z = 2 ; 2x + y - z = 1 → x=1, y=2, z=3
        let mut a = vec![
            vec![1.0, 1.0, 1.0],
            vec![1.0, -1.0, 1.0],
            vec![2.0, 1.0, -1.0],
        ];
        let mut b = vec![6.0, 2.0, 1.0];
        let x = gaussian_elimination(&mut a, &mut b).expect("should succeed");
        assert!((x[0] - 1.0).abs() < 1e-9);
        assert!((x[1] - 2.0).abs() < 1e-9);
        assert!((x[2] - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_gaussian_elimination_singular() {
        // The second row is twice the first → rank 1.
        let mut a = vec![vec![1.0_f64, 2.0], vec![2.0, 4.0]];
        let mut b = vec![3.0_f64, 6.0];
        assert_eq!(
            gaussian_elimination(&mut a, &mut b),
            Err(SolverError::SingularMatrix)
        );
    }

    // ── Solver construction / misc ────────────────────────────────────────────

    #[test]
    fn test_solver_default() {
        // `Default` must behave exactly like `new`.
        let mesh = square_mesh(1.0);
        let bcs = bottom_top_bcs(0.0, 100.0);
        let from_default = ThermalSolver::default()
            .solve(&mesh, &bcs)
            .expect("should succeed");
        let from_new = ThermalSolver::new()
            .solve(&mesh, &bcs)
            .expect("should succeed");
        assert_eq!(from_default, from_new);
    }

    #[test]
    fn test_boundary_node_via_is_boundary_flag() {
        let nodes = vec![
            ThermalNode::boundary(0, 0.0, 0.0, 50.0),
            ThermalNode::interior(1, 1.0, 0.0),
            ThermalNode::boundary(2, 0.5, 1.0, 50.0),
        ];
        let elements = vec![ThermalElement {
            nodes: [0, 1, 2],
            conductivity: 1.0,
            thickness: 1.0,
        }];
        let mesh = ThermalMesh::new(nodes, elements);
        let solver = ThermalSolver::new();
        let temps = solver.solve(&mesh, &[]).expect("should succeed");
        // Node 1 should converge close to 50°C
        assert!(
            (temps[1] - 50.0).abs() < 5.0,
            "Interior node temp = {}",
            temps[1]
        );
    }
}