use nalgebra::DMatrix;
use serde::{Deserialize, Serialize};
/// A single sampled point on the k-path of a band-structure calculation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KPoint {
/// Coordinates of the point. `build_bloch_matrices` uses them in the phase
/// `2π · (k · R)` with integer cell translations `R`.
pub frac: [f64; 3],
/// Label of the path vertex when this point is one; `None` for points
/// interpolated between two vertices.
pub label: Option<String>,
/// Cumulative length along the path from its first vertex, measured in the
/// same coordinates as `frac`.
pub path_distance: f64,
}
/// Result of `compute_band_structure`: eigenvalues along a k-path plus
/// quantities derived from them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandStructure {
/// Sampled k-points, in path order.
pub kpoints: Vec<KPoint>,
/// `bands[i]` holds the eigenvalues computed at `kpoints[i]`, sorted in
/// ascending order.
pub bands: Vec<Vec<f64>>,
/// Number of energies returned by the EHT solve for the structure.
/// NOTE(review): each entry of `bands` is presumably this long — confirm.
pub n_bands: usize,
/// Number of entries in `kpoints`.
pub n_kpoints: usize,
/// Estimate produced by `estimate_fermi_energy`: the median over k-points
/// of the midpoint between the highest occupied and lowest unoccupied band.
pub fermi_energy: f64,
/// Smallest positive gap between the two frontier bands at a single
/// k-point, if any.
pub direct_gap: Option<f64>,
/// Lowest unoccupied level over all k-points minus the highest occupied
/// level over all k-points, when that difference is positive.
pub indirect_gap: Option<f64>,
/// `(label, index into kpoints)` for every labelled k-point.
pub high_symmetry_points: Vec<(String, usize)>,
}
/// Settings that define the k-path sampled by `compute_band_structure`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BandStructureConfig {
/// Number of k-points generated for each segment between two consecutive
/// path vertices (the segment's end vertex is included in that count).
pub n_kpoints_per_segment: usize,
/// Ordered path vertices as `(coordinates, label)` pairs.
pub path: Vec<([f64; 3], String)>,
}
impl Default for BandStructureConfig {
    /// Returns the default configuration: 50 k-points per segment along the
    /// closed path Γ → X → M → Γ.
    fn default() -> Self {
        // Vertices of the default path, listed in traversal order.
        let vertices: [([f64; 3], &str); 4] = [
            ([0.0, 0.0, 0.0], "Γ"),
            ([0.5, 0.0, 0.0], "X"),
            ([0.5, 0.5, 0.0], "M"),
            ([0.0, 0.0, 0.0], "Γ"),
        ];
        let path = vertices
            .iter()
            .map(|&(coords, name)| (coords, String::from(name)))
            .collect();
        Self {
            n_kpoints_per_segment: 50,
            path,
        }
    }
}
pub fn compute_band_structure(
elements: &[u8],
positions: &[[f64; 3]],
lattice: &[[f64; 3]; 3],
config: &BandStructureConfig,
n_electrons: usize,
) -> Result<BandStructure, String> {
if elements.is_empty() {
return Err("No atoms provided".to_string());
}
let kpoints = generate_kpath(&config.path, config.n_kpoints_per_segment);
let n_kpts = kpoints.len();
let eht_result = crate::eht::solve_eht(elements, positions, None)?;
let n_basis = eht_result.energies.len();
let mut bands = Vec::with_capacity(n_kpts);
let mut high_sym = Vec::new();
for (k_idx, kpt) in kpoints.iter().enumerate() {
let (h_k, s_k) = build_bloch_matrices(elements, positions, lattice, &kpt.frac, n_basis);
let eigenvalues = solve_generalized_eigen(&h_k, &s_k)?;
bands.push(eigenvalues);
if let Some(ref label) = kpt.label {
high_sym.push((label.clone(), k_idx));
}
}
let n_occupied = n_electrons / 2;
let fermi_energy = estimate_fermi_energy(&bands, n_occupied);
let (direct_gap, indirect_gap) = compute_band_gaps(&bands, n_occupied);
Ok(BandStructure {
kpoints,
bands,
n_bands: n_basis,
n_kpoints: n_kpts,
fermi_energy,
direct_gap,
indirect_gap,
high_symmetry_points: high_sym,
})
}
/// Samples a piecewise-linear k-path.
///
/// The first vertex is emitted as-is; every following segment contributes
/// `n_per_segment` evenly spaced points, the last of which is the segment's
/// end vertex (carrying that vertex's label). Interpolated points have no
/// label. `path_distance` is the cumulative Euclidean length along the path,
/// measured in the same coordinates as the vertices.
///
/// An `n_per_segment` of zero is treated as one, so the path vertices and
/// their labels are always present in the output. An empty `path` yields an
/// empty vector.
fn generate_kpath(path: &[([f64; 3], String)], n_per_segment: usize) -> Vec<KPoint> {
    let (first, rest) = match path.split_first() {
        Some(parts) => parts,
        None => return Vec::new(),
    };

    // At least one step per segment, otherwise every vertex after the first
    // (and its label) would silently disappear.
    let steps = n_per_segment.max(1);

    let mut kpoints = Vec::with_capacity(rest.len().saturating_mul(steps).saturating_add(1));
    kpoints.push(KPoint {
        frac: first.0,
        label: Some(first.1.clone()),
        path_distance: 0.0,
    });

    let mut k_prev = first.0;
    // Path length accumulated up to the start of the current segment.
    let mut seg_start = 0.0;

    for (k, label) in rest {
        let dk = [k[0] - k_prev[0], k[1] - k_prev[1], k[2] - k_prev[2]];
        let seg_len = (dk[0] * dk[0] + dk[1] * dk[1] + dk[2] * dk[2]).sqrt();

        for j in 1..=steps {
            let is_endpoint = j == steps;
            let t = j as f64 / steps as f64;
            // Use the vertex itself at the end of the segment so labelled
            // points carry exactly the coordinates the caller specified.
            let frac = if is_endpoint {
                *k
            } else {
                [
                    k_prev[0] + t * dk[0],
                    k_prev[1] + t * dk[1],
                    k_prev[2] + t * dk[2],
                ]
            };
            kpoints.push(KPoint {
                frac,
                label: if is_endpoint {
                    Some(label.clone())
                } else {
                    None
                },
                // Computed from the segment start rather than by repeated
                // addition, which avoids accumulating rounding error.
                path_distance: seg_start + t * seg_len,
            });
        }

        seg_start += seg_len;
        k_prev = *k;
    }

    kpoints
}
/// Builds the k-dependent Hamiltonian and overlap matrices `(H(k), S(k))`.
///
/// The home-cell matrices are taken from the EHT builders, then the
/// interaction with each of the six neighbouring cells (±1 along every
/// lattice vector) is added, weighted by `cos(2π · k · R)`. Only the cosine
/// (real) part of the Bloch phase is applied, so both results are real.
///
/// * `lattice` - lattice vectors, one per row; they are added to `positions`
///   to obtain the atoms of a neighbouring cell.
/// * `k` - k-point coordinates, combined with the integer cell index `R`.
/// * `n_basis` - upper bound on the matrix dimension; the result is
///   `n × n` with `n = min(n_basis, rows of the home-cell overlap)`.
fn build_bloch_matrices(
    elements: &[u8],
    positions: &[[f64; 3]],
    lattice: &[[f64; 3]; 3],
    k: &[f64; 3],
    n_basis: usize,
) -> (DMatrix<f64>, DMatrix<f64>) {
    // Integer cell offsets of the neighbouring cells that are included.
    const NEIGHBOUR_CELLS: [[i32; 3]; 6] = [
        [1, 0, 0],
        [-1, 0, 0],
        [0, 1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, -1],
    ];

    let basis = crate::eht::basis::build_basis(elements, positions);
    let s_0 = crate::eht::overlap::build_overlap_matrix(&basis);
    let h_0 = crate::eht::hamiltonian::build_hamiltonian(&basis, &s_0, None);
    let n = n_basis.min(s_0.nrows());
    // Number of home-cell orbitals; in the combined basis below the
    // neighbour-cell orbitals start at this column.
    let n_home = basis.len();

    // Start from the (possibly truncated) home-cell block.
    let mut h_k = DMatrix::zeros(n, n);
    let mut s_k = DMatrix::zeros(n, n);
    for i in 0..n {
        for j in 0..n {
            h_k[(i, j)] = h_0[(i, j)];
            s_k[(i, j)] = s_0[(i, j)];
        }
    }

    for r in &NEIGHBOUR_CELLS {
        let phase = 2.0
            * std::f64::consts::PI
            * (k[0] * r[0] as f64 + k[1] * r[1] as f64 + k[2] * r[2] as f64);
        let cos_phase = phase.cos();

        // Atom positions of the neighbouring cell: p + r0*a0 + r1*a1 + r2*a2.
        let translated: Vec<[f64; 3]> = positions
            .iter()
            .map(|p| {
                [
                    p[0] + r[0] as f64 * lattice[0][0]
                        + r[1] as f64 * lattice[1][0]
                        + r[2] as f64 * lattice[2][0],
                    p[1] + r[0] as f64 * lattice[0][1]
                        + r[1] as f64 * lattice[1][1]
                        + r[2] as f64 * lattice[2][1],
                    p[2] + r[0] as f64 * lattice[0][2]
                        + r[1] as f64 * lattice[1][2]
                        + r[2] as f64 * lattice[2][2],
                ]
            })
            .collect();
        let basis_r = crate::eht::basis::build_basis(elements, &translated);

        // Overlap of the joint (home + neighbour) basis; the off-diagonal
        // block is the home/neighbour overlap.
        let mut combined = basis.clone();
        combined.extend_from_slice(&basis_r);
        let s_combined = crate::eht::overlap::build_overlap_matrix(&combined);

        // The neighbour columns begin after *all* home orbitals, not after
        // the first `n` of them; using `n` as the offset would read
        // home-cell columns whenever `n < n_home`.
        // NOTE(review): assumes the overlap matrix has one row/column per
        // basis orbital, so that `n <= n_home` — confirm.
        let s_r = s_combined
            .view((0, n_home), (n, basis_r.len()))
            .clone_owned();
        let h_r = build_intercell_hamiltonian(&basis, &basis_r, &s_r);

        let nr = n.min(s_r.nrows()).min(s_r.ncols());
        for i in 0..nr {
            for j in 0..nr {
                h_k[(i, j)] += cos_phase * h_r[(i, j)];
                s_k[(i, j)] += cos_phase * s_r[(i, j)];
            }
        }
    }

    (h_k, s_k)
}
/// Builds the Hamiltonian block that couples home-cell orbitals (`basis_0`)
/// with the orbitals of a neighbouring cell (`basis_r`).
///
/// Each element follows the Wolfsberg–Helmholz-style expression
/// `H_ij = 0.5 · K · (H_ii + H_jj) · S_ij` with `K = 1.75`, where the
/// diagonal energies are read from each orbital's `vsip` field and `S_ij`
/// comes from `s_0r`. The result has
/// `min(basis_0.len(), s_0r.nrows())` rows and
/// `min(basis_r.len(), s_0r.ncols())` columns.
fn build_intercell_hamiltonian(
    basis_0: &[crate::eht::basis::AtomicOrbital],
    basis_r: &[crate::eht::basis::AtomicOrbital],
    s_0r: &DMatrix<f64>,
) -> DMatrix<f64> {
    let rows = basis_0.len().min(s_0r.nrows());
    let cols = basis_r.len().min(s_0r.ncols());
    let k_wh = 1.75;

    DMatrix::from_fn(rows, cols, |i, j| {
        let diagonal_sum = basis_0[i].vsip + basis_r[j].vsip;
        0.5 * k_wh * diagonal_sum * s_0r[(i, j)]
    })
}
/// Solves the generalized symmetric eigenproblem `H c = E S c` and returns
/// the eigenvalues in ascending order.
///
/// The overlap matrix is diagonalized to form `S^(-1/2)` (symmetric
/// orthogonalization), the Hamiltonian is transformed to
/// `S^(-1/2) · H · S^(-1/2)`, and that matrix is diagonalized.
///
/// Overlap eigenvalues that are not larger than `1e-8` are left out of
/// `S^(-1/2)`. The transformed matrix keeps its full size, so each omitted
/// direction contributes a zero eigenvalue to the result.
///
/// # Errors
///
/// Returns `Err` when `h` and `s` are not square matrices of the same size,
/// when either contains a non-finite entry, or when the solver produces a
/// NaN eigenvalue. An empty `h` yields `Ok(vec![])`.
fn solve_generalized_eigen(h: &DMatrix<f64>, s: &DMatrix<f64>) -> Result<Vec<f64>, String> {
    let n = h.nrows();
    if n == 0 {
        return Ok(vec![]);
    }
    // Report malformed input as an error instead of panicking inside the
    // decomposition or the matrix products below.
    if h.ncols() != n || s.nrows() != n || s.ncols() != n {
        return Err(format!(
            "Generalized eigenproblem requires square matrices of equal size (H is {}x{}, S is {}x{})",
            h.nrows(),
            h.ncols(),
            s.nrows(),
            s.ncols()
        ));
    }
    if h.iter().chain(s.iter()).any(|x| !x.is_finite()) {
        return Err("Hamiltonian or overlap matrix contains a non-finite value".to_string());
    }

    let s_eigen = nalgebra::SymmetricEigen::new(s.clone());

    // S^(-1/2) = sum_i (1 / sqrt(lambda_i)) * v_i * v_i^T over the retained
    // eigenpairs of S.
    let mut s_inv_sqrt = DMatrix::zeros(n, n);
    for (i, &eval) in s_eigen.eigenvalues.iter().enumerate() {
        if eval > 1e-8 {
            let inv_sqrt = 1.0 / eval.sqrt();
            for j in 0..n {
                let scaled = inv_sqrt * s_eigen.eigenvectors[(j, i)];
                for k in 0..n {
                    s_inv_sqrt[(j, k)] += scaled * s_eigen.eigenvectors[(k, i)];
                }
            }
        }
    }

    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
    let eigen = nalgebra::SymmetricEigen::new(h_prime);

    let mut eigenvalues: Vec<f64> = eigen.eigenvalues.iter().copied().collect();
    if eigenvalues.iter().any(|e| e.is_nan()) {
        return Err("Eigenvalue decomposition produced NaN".to_string());
    }
    eigenvalues.sort_by(|a, b| a.total_cmp(b));
    Ok(eigenvalues)
}
/// Estimates the Fermi energy from per-k-point eigenvalue lists.
///
/// For every k-point that has more than `n_occupied` levels, the midpoint
/// between level `n_occupied - 1` and level `n_occupied` is taken; k-points
/// with fewer levels contribute their last level instead, and empty lists
/// contribute nothing. The result is the element at index `len / 2` of the
/// sorted candidates (the upper median for an even count).
///
/// Returns `0.0` when `n_occupied` is zero or no candidate exists.
///
/// # Panics
///
/// Panics if any candidate value is NaN.
fn estimate_fermi_energy(bands: &[Vec<f64>], n_occupied: usize) -> f64 {
    if n_occupied == 0 {
        return 0.0;
    }

    let mut midpoints: Vec<f64> = Vec::with_capacity(bands.len());
    for levels in bands {
        let candidate = match (levels.get(n_occupied - 1), levels.get(n_occupied)) {
            (Some(&highest_occupied), Some(&lowest_empty)) => {
                Some((highest_occupied + lowest_empty) / 2.0)
            }
            // Not enough levels for a midpoint: fall back to the top level.
            _ => levels.last().copied(),
        };
        if let Some(energy) = candidate {
            midpoints.push(energy);
        }
    }

    midpoints.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    match midpoints.get(midpoints.len() / 2) {
        Some(&median) => median,
        None => 0.0,
    }
}
/// Computes the `(direct, indirect)` band gaps from per-k-point eigenvalues.
///
/// At each k-point the valence-band top is level `n_occupied - 1` and the
/// conduction-band bottom is level `n_occupied`; k-points with too few
/// levels are skipped.
///
/// * The direct gap is the smallest strictly positive same-k difference
///   between those two levels; k-points where the difference is not positive
///   are ignored for this value.
/// * The indirect gap is the lowest conduction-band bottom over all k-points
///   minus the highest valence-band top, when that difference is positive.
///
/// Either value is `None` when it does not exist. In particular both are
/// `None` when `n_occupied` is zero or when no k-point has a conduction
/// level at all.
fn compute_band_gaps(bands: &[Vec<f64>], n_occupied: usize) -> (Option<f64>, Option<f64>) {
    if n_occupied == 0 {
        return (None, None);
    }

    let mut min_direct: Option<f64> = None;
    // (highest valence-band top, lowest conduction-band bottom) seen so far.
    // Kept as an Option so that "no usable k-point" cannot be mistaken for a
    // real (and enormous) gap between sentinel values.
    let mut edges: Option<(f64, f64)> = None;

    for eigenvals in bands {
        let (vb_top, cb_bottom) =
            match (eigenvals.get(n_occupied - 1), eigenvals.get(n_occupied)) {
                (Some(&vb), Some(&cb)) => (vb, cb),
                _ => continue,
            };

        let gap = cb_bottom - vb_top;
        if gap > 0.0 && min_direct.map_or(true, |best| gap < best) {
            min_direct = Some(gap);
        }

        edges = Some(match edges {
            None => (vb_top, cb_bottom),
            Some((max_vb, min_cb)) => (max_vb.max(vb_top), min_cb.min(cb_bottom)),
        });
    }

    let indirect = edges.and_then(|(max_vb, min_cb)| {
        if min_cb > max_vb {
            Some(min_cb - max_vb)
        } else {
            None
        }
    });

    (min_direct, indirect)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default Γ–X–M–Γ path sampled with 10 points per segment yields
    /// 31 k-points with the four vertices labelled at the segment ends.
    #[test]
    fn test_generate_kpath() {
        let config = BandStructureConfig::default();
        let kpoints = generate_kpath(&config.path, 10);
        assert_eq!(kpoints.len(), 31);
        assert_eq!(kpoints[0].label.as_deref(), Some("Γ"));

        let labelled: Vec<(usize, &str)> = kpoints
            .iter()
            .enumerate()
            .filter_map(|(index, point)| point.label.as_deref().map(|name| (index, name)))
            .collect();
        assert_eq!(labelled, vec![(0, "Γ"), (10, "X"), (20, "M"), (30, "Γ")]);

        // The path distance grows strictly and ends at the total path length.
        assert_eq!(kpoints[0].path_distance, 0.0);
        assert!(kpoints
            .windows(2)
            .all(|pair| pair[1].path_distance > pair[0].path_distance));
        let total_length = 0.5 + 0.5 + 0.5_f64.sqrt();
        assert!((kpoints[30].path_distance - total_length).abs() < 1e-9);
    }

    /// Two k-points with a 4.0 gap each; the valence maximum (-2.5) and the
    /// conduction minimum (1.0) sit at different k-points.
    #[test]
    fn test_band_gaps() {
        let bands = vec![vec![-5.0, -3.0, 1.0, 3.0], vec![-4.5, -2.5, 1.5, 3.5]];
        let (direct, indirect) = compute_band_gaps(&bands, 2);
        assert_eq!(direct, Some(4.0));
        assert_eq!(indirect, Some(3.5));
    }
}