use crate::atom_struct::{AtomType, OrbProj};
use crate::error::{Result, TbError};
use crate::solve_ham::*;
use crate::{Model, SpinDirection};
use ndarray::{Array1, Array2, Axis, arr1, arr2};
use ndarray_linalg::Norm;
use num_complex::Complex;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeSet, HashMap};
use std::fs;
use std::io::{self, Write};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkAtom {
pub position: Array1<f64>,
pub atom_type: AtomType,
pub projections: Vec<OrbProj>,
}
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct SkParams {
pub v_ss_sigma: Option<f64>,
pub v_sp_sigma: Option<f64>,
pub v_pp_sigma: Option<f64>,
pub v_pp_pi: Option<f64>,
pub v_sd_sigma: Option<f64>,
pub v_pd_sigma: Option<f64>,
pub v_pd_pi: Option<f64>,
pub v_dd_sigma: Option<f64>,
pub v_dd_pi: Option<f64>,
pub v_dd_delta: Option<f64>,
pub v_sf_sigma: Option<f64>,
pub v_pf_sigma: Option<f64>,
pub v_pf_pi: Option<f64>,
pub v_df_sigma: Option<f64>,
pub v_df_pi: Option<f64>,
pub v_df_delta: Option<f64>,
pub v_ff_sigma: Option<f64>,
pub v_ff_pi: Option<f64>,
pub v_ff_delta: Option<f64>,
pub v_ff_phi: Option<f64>,
}
#[derive(Debug, Clone)]
pub struct SlaterKosterModel {
pub dim_r: usize,
pub lat: Array2<f64>,
pub atoms: Vec<SkAtom>,
pub spin: bool,
pub neighbor_search_range: i32,
}
impl Default for SlaterKosterModel {
fn default() -> Self {
Self {
dim_r: 3,
lat: Array2::zeros((3, 3)),
atoms: Vec::new(),
spin: false,
neighbor_search_range: 3,
}
}
}
impl SlaterKosterModel {
pub fn new(dim_r: usize, lat: Array2<f64>, atoms: Vec<SkAtom>, spin: bool) -> Self {
Self {
dim_r,
lat,
atoms,
spin,
neighbor_search_range: 3,
}
}
pub fn with_search_range(mut self, range: i32) -> Result<Self> {
if range <= 0 {
return Err(TbError::InvalidSearchRange(range));
}
self.neighbor_search_range = range;
Ok(self)
}
fn generate_R_vectors(&self) -> Vec<Array1<isize>> {
let rng = -self.neighbor_search_range..=self.neighbor_search_range;
match self.dim_r {
1 => rng.map(|i| arr1(&[i as isize])).collect(),
2 => rng
.clone()
.flat_map(|i| rng.clone().map(move |j| arr1(&[i as isize, j as isize])))
.collect(),
3 => rng
.clone()
.flat_map(|i| {
rng.clone().flat_map({
let value = rng.clone();
move |j| {
value
.clone()
.map(move |k| arr1(&[i as isize, j as isize, k as isize]))
}
})
})
.collect(),
_ => vec![],
}
}
pub fn find_shell_distances(&self, n_neighbors: usize) -> Result<Vec<f64>> {
let mut distances = BTreeSet::new();
let R_vectors = self.generate_R_vectors();
let real_pos: Vec<Array1<f64>> = self
.atoms
.iter()
.map(|a| self.lat.dot(&a.position))
.collect();
for (i, pos_i) in real_pos.iter().enumerate() {
for (j, pos_j) in real_pos.iter().enumerate() {
for R in &R_vectors {
if i == j && R.iter().all(|&x| x == 0) {
continue;
}
let R_real = self.lat.dot(&R.mapv(|x| x as f64));
let dvec = pos_j + &R_real - pos_i;
let dist = dvec.norm_l2();
if dist > 1e-8 {
let key = (dist * 1e6).round() as i64;
distances.insert(key);
}
}
}
}
if distances.is_empty() {
return Err(TbError::NoShellsFound);
}
let mut shells: Vec<f64> = distances.into_iter().map(|x| x as f64 / 1e6).collect();
shells.sort_by(|a, b| a.partial_cmp(b).unwrap());
shells.truncate(n_neighbors);
Ok(shells)
}
}
#[allow(clippy::too_many_arguments)]
fn sk_element(
oi: OrbProj,
oj: OrbProj,
l: f64,
m: f64,
n: f64,
param: &SkParams,
pair: (AtomType, AtomType),
shell: usize,
) -> Result<f64> {
use OrbProj::*;
macro_rules! get {
($field:ident, $name:expr) => {
param.$field.ok_or_else(|| TbError::SkParameterMissing {
param: $name.to_string(),
atom1: pair.0,
atom2: pair.1,
shell,
})?
};
}
let result = match (oi, oj) {
(s, s) => get!(v_ss_sigma, "Vssσ"),
(s, px) | (px, s) => l * get!(v_sp_sigma, "Vspσ"),
(s, py) | (py, s) => m * get!(v_sp_sigma, "Vspσ"),
(s, pz) | (pz, s) => n * get!(v_sp_sigma, "Vspσ"),
(px, px) => l * l * get!(v_pp_sigma, "Vppσ") + (1.0 - l * l) * get!(v_pp_pi, "Vppπ"),
(py, py) => m * m * get!(v_pp_sigma, "Vppσ") + (1.0 - m * m) * get!(v_pp_pi, "Vppπ"),
(pz, pz) => n * n * get!(v_pp_sigma, "Vppσ") + (1.0 - n * n) * get!(v_pp_pi, "Vppπ"),
(px, py) | (py, px) => l * m * (get!(v_pp_sigma, "Vppσ") - get!(v_pp_pi, "Vppπ")),
(px, pz) | (pz, px) => l * n * (get!(v_pp_sigma, "Vppσ") - get!(v_pp_pi, "Vppπ")),
(py, pz) | (pz, py) => m * n * (get!(v_pp_sigma, "Vppσ") - get!(v_pp_pi, "Vppπ")),
(s, dxy) | (dxy, s) => (3.0_f64.sqrt()) * l * m * get!(v_sd_sigma, "Vsdσ"),
(s, dyz) | (dyz, s) => (3.0_f64.sqrt()) * m * n * get!(v_sd_sigma, "Vsdσ"),
(s, dxz) | (dxz, s) => (3.0_f64.sqrt()) * l * n * get!(v_sd_sigma, "Vsdσ"),
(s, dz2) | (dz2, s) => ((3.0 * n * n - 1.0) / 2.0) * get!(v_sd_sigma, "Vsdσ"),
(s, dx2y2) | (dx2y2, s) => {
(3.0_f64.sqrt() / 2.0) * (l * l - m * m) * get!(v_sd_sigma, "Vsdσ")
}
(px, dxy) | (dxy, px) => {
3_f64.sqrt() * l * l * m * get!(v_pd_sigma, "Vpdσ")
+ m * (1.0 - 2.0 * l * l) * get!(v_pd_pi, "Vpdπ")
}
(s, fz3) | (fz3, s) => {
n * (5.0 * n * n - 3.0) / 2.0 * get!(v_sf_sigma, "Vsfσ")
}
(s, fxz2) | (fxz2, s) => {
(3.0_f64.sqrt()) * l * (5.0 * n * n - 1.0) / 2.0 * get!(v_sf_sigma, "Vsfσ")
}
(s, fyz2) | (fyz2, s) => {
(3.0_f64.sqrt()) * m * (5.0 * n * n - 1.0) / 2.0 * get!(v_sf_sigma, "Vsfσ")
}
(s, fxyz) | (fxyz, s) => {
(15.0_f64.sqrt()) * l * m * n * get!(v_sf_sigma, "Vsfσ")
}
(s, fx3) | (fx3, s) => {
(15.0_f64.sqrt()) / 2.0 * l * (l * l - 3.0 * m * m) * get!(v_sf_sigma, "Vsfσ")
}
(s, fy3) | (fy3, s) => {
(15.0_f64.sqrt()) / 2.0 * m * (3.0 * l * l - m * m) * get!(v_sf_sigma, "Vsfσ")
}
(s, fz2x2y2) | (fz2x2y2, s) => {
(15.0_f64.sqrt()) / 2.0 * n * (l * l - m * m) * get!(v_sf_sigma, "Vsfσ")
}
_ => return Err(TbError::UnsupportedOrbitalCombination(oi, oj)),
};
Ok(result)
}
pub trait ToTbModel {
fn build_model(
&self,
n_neighbors: usize,
params: &HashMap<(AtomType, AtomType, usize), SkParams>,
) -> Result<Model>;
}
impl ToTbModel for SlaterKosterModel {
fn build_model(
&self,
n_neighbors: usize,
params: &HashMap<(AtomType, AtomType, usize), SkParams>,
) -> Result<Model> {
let mut orb_positions = Vec::new();
let mut orb_projections = Vec::new();
for atom in &self.atoms {
for proj in &atom.projections {
orb_positions.push(atom.position.clone());
orb_projections.push(*proj);
}
}
let norb = orb_positions.len();
let orb_positions_flat: Vec<f64> = orb_positions
.into_iter()
.flat_map(|v| v.into_raw_vec())
.collect();
let orb_array = Array2::from_shape_vec((norb, self.dim_r), orb_positions_flat)
.map_err(|e| TbError::Linalg(ndarray_linalg::error::LinalgError::Shape(e)))?;
let mut model = Model::tb_model(self.dim_r, self.lat.clone(), orb_array, self.spin, None)?;
model.set_projection(&orb_projections);
if n_neighbors == 0 {
return Ok(model);
}
let shells = self.find_shell_distances(n_neighbors)?;
let real_pos: Vec<Array1<f64>> = self
.atoms
.iter()
.map(|a| self.lat.dot(&a.position))
.collect();
let mut orb_offset = vec![0; self.atoms.len()];
let mut total = 0;
for (i, atom) in self.atoms.iter().enumerate() {
orb_offset[i] = total;
total += atom.projections.len();
}
let R_vectors = self.generate_R_vectors();
for (iatom, atom_i) in self.atoms.iter().enumerate() {
let pos_i = &real_pos[iatom];
let offset_i = orb_offset[iatom];
let type_i = atom_i.atom_type;
let projs_i = &atom_i.projections;
for (jatom, atom_j) in self.atoms.iter().enumerate() {
let pos_j = &real_pos[jatom];
let offset_j = orb_offset[jatom];
let type_j = atom_j.atom_type;
let projs_j = &atom_j.projections;
for R in &R_vectors {
if iatom == jatom && R.iter().all(|&x| x == 0) {
continue;
}
let R_real = self.lat.dot(&R.mapv(|x| x as f64));
let dvec = pos_j + &R_real - pos_i;
let dist = dvec.norm_l2();
if dist < 1e-8 {
continue;
}
let shell_idx = shells.iter().position(|&d| (d - dist).abs() < 1e-6);
if let Some(s_idx) = shell_idx {
let key = (type_i, type_j, s_idx);
let key_rev = (type_j, type_i, s_idx);
let param = params
.get(&key)
.or_else(|| params.get(&key_rev))
.ok_or_else(|| TbError::SkParameterMissing {
param: "any".to_string(),
atom1: type_i,
atom2: type_j,
shell: s_idx,
})?;
let l = dvec[0] / dist;
let m = if self.dim_r > 1 { dvec[1] / dist } else { 0.0 };
let n = if self.dim_r > 2 { dvec[2] / dist } else { 0.0 };
for (pi_idx, &proj_i) in projs_i.iter().enumerate() {
let io = offset_i + pi_idx;
for (pj_idx, &proj_j) in projs_j.iter().enumerate() {
let jo = offset_j + pj_idx;
if iatom == jatom && R.iter().all(|&x| x == 0) && io == jo {
continue;
}
let hop = sk_element(
proj_i,
proj_j,
l,
m,
n,
param,
(type_i, type_j),
s_idx,
)?;
if hop.abs() > 1e-12 {
model.add_hop(hop, io, jo, R, SpinDirection::None);
}
}
}
}
}
}
}
Ok(model)
}
}
fn required_sk_parameters(projs_a: &[OrbProj], projs_b: &[OrbProj]) -> Vec<String> {
let mut needed = Vec::new();
let has_s = |projs: &[OrbProj]| projs.contains(&OrbProj::s);
let has_p = |projs: &[OrbProj]| {
projs.contains(&OrbProj::px) || projs.contains(&OrbProj::py) || projs.contains(&OrbProj::pz)
};
let has_d = |projs: &[OrbProj]| {
projs.contains(&OrbProj::dxy)
|| projs.contains(&OrbProj::dyz)
|| projs.contains(&OrbProj::dxz)
|| projs.contains(&OrbProj::dz2)
|| projs.contains(&OrbProj::dx2y2)
};
let has_f = |projs: &[OrbProj]| {
projs.contains(&OrbProj::fz3)
|| projs.contains(&OrbProj::fxz2)
|| projs.contains(&OrbProj::fyz2)
|| projs.contains(&OrbProj::fzx2y2)
|| projs.contains(&OrbProj::fxyz)
|| projs.contains(&OrbProj::fxx23y2)
|| projs.contains(&OrbProj::fy3x2y2)
};
if has_s(projs_a) && has_s(projs_b) {
needed.push("v_ss_sigma".to_string());
}
if (has_s(projs_a) && has_p(projs_b)) || (has_p(projs_a) && has_s(projs_b)) {
needed.push("v_sp_sigma".to_string());
}
if has_p(projs_a) && has_p(projs_b) {
needed.push("v_pp_sigma".to_string());
needed.push("v_pp_pi".to_string());
}
if has_s(projs_a) && has_d(projs_b) || has_d(projs_a) && has_s(projs_b) {
needed.push("v_sd_sigma".to_string());
}
if has_p(projs_a) && has_d(projs_b) || has_d(projs_a) && has_p(projs_b) {
needed.push("v_pd_sigma".to_string());
needed.push("v_pd_pi".to_string());
}
if has_d(projs_a) && has_d(projs_b) {
needed.push("v_dd_sigma".to_string());
needed.push("v_dd_pi".to_string());
needed.push("v_dd_delta".to_string());
}
if has_s(projs_a) && has_f(projs_b) || has_f(projs_a) && has_s(projs_b) {
needed.push("v_sf_sigma".to_string());
}
if has_p(projs_a) && has_f(projs_b) || has_f(projs_a) && has_p(projs_b) {
needed.push("v_pf_sigma".to_string());
needed.push("v_pf_pi".to_string());
}
if has_d(projs_a) && has_f(projs_b) || has_f(projs_a) && has_d(projs_b) {
needed.push("v_df_sigma".to_string());
needed.push("v_df_pi".to_string());
needed.push("v_df_delta".to_string());
}
if has_f(projs_a) && has_f(projs_b) {
needed.push("v_ff_sigma".to_string());
needed.push("v_ff_pi".to_string());
needed.push("v_ff_delta".to_string());
needed.push("v_ff_phi".to_string());
}
needed
}
fn prompt_float(prompt: &str) -> Result<f64> {
print!("{}", prompt);
io::stdout().flush().unwrap();
let mut input = String::new();
io::stdin().read_line(&mut input).unwrap();
let trimmed = input.trim();
if trimmed.is_empty() {
return Ok(0.0);
}
trimmed
.parse()
.map_err(|_| TbError::Other("Invalid number".into()))
}
pub fn collect_sk_parameters(
sk_model: &SlaterKosterModel,
n_shells: usize,
cache_path: Option<&str>,
) -> Result<HashMap<(AtomType, AtomType, usize), SkParams>> {
if let Some(path) = cache_path {
if Path::new(path).exists() {
println!("Loading SK parameters from cache file: {}", path);
let content = fs::read_to_string(path).map_err(|e| TbError::Io(e))?;
let params: HashMap<(AtomType, AtomType, usize), SkParams> =
toml::from_str(&content)
.map_err(|e| TbError::Other(format!("TOML parse error: {}", e)))?;
return Ok(params);
}
}
let shells = sk_model.find_shell_distances(n_shells)?;
println!("\nFound {} neighbor shells:", shells.len());
for (idx, dist) in shells.iter().enumerate() {
println!(" Shell {}: distance = {:.6} Ã…", idx, dist);
}
let mut atom_type_info: HashMap<AtomType, Vec<OrbProj>> = HashMap::new();
for atom in &sk_model.atoms {
let entry = atom_type_info
.entry(atom.atom_type)
.or_insert_with(Vec::new);
for proj in &atom.projections {
if !entry.contains(proj) {
entry.push(*proj);
}
}
}
let atom_types: Vec<AtomType> = atom_type_info.keys().cloned().collect();
let mut pairs = Vec::new();
for i in 0..atom_types.len() {
for j in i..atom_types.len() {
pairs.push((atom_types[i], atom_types[j]));
}
}
let mut all_params = HashMap::new();
for shell_idx in 0..shells.len() {
println!(
"\n=== Shell {} (distance = {:.6} Ã…) ===",
shell_idx, shells[shell_idx]
);
for (a, b) in &pairs {
let projs_a = atom_type_info.get(a).unwrap();
let projs_b = atom_type_info.get(b).unwrap();
let needed = required_sk_parameters(projs_a, projs_b);
if needed.is_empty() {
continue;
}
println!(
"\nAtom pair {} - {} (shell {})",
a.to_str(),
b.to_str(),
shell_idx
);
let mut skp = SkParams::default();
for param_name in needed {
let prompt_str = format!(" {} = ", param_name);
let value = prompt_float(&prompt_str)?;
match param_name.as_str() {
"v_ss_sigma" => skp.v_ss_sigma = Some(value),
"v_sp_sigma" => skp.v_sp_sigma = Some(value),
"v_pp_sigma" => skp.v_pp_sigma = Some(value),
"v_pp_pi" => skp.v_pp_pi = Some(value),
"v_sd_sigma" => skp.v_sd_sigma = Some(value),
"v_pd_sigma" => skp.v_pd_sigma = Some(value),
"v_pd_pi" => skp.v_pd_pi = Some(value),
"v_dd_sigma" => skp.v_dd_sigma = Some(value),
"v_dd_pi" => skp.v_dd_pi = Some(value),
"v_dd_delta" => skp.v_dd_delta = Some(value),
"v_sf_sigma" => skp.v_sf_sigma = Some(value),
"v_pf_sigma" => skp.v_pf_sigma = Some(value),
"v_pf_pi" => skp.v_pf_pi = Some(value),
"v_df_sigma" => skp.v_df_sigma = Some(value),
"v_df_pi" => skp.v_df_pi = Some(value),
"v_df_delta" => skp.v_df_delta = Some(value),
"v_ff_sigma" => skp.v_ff_sigma = Some(value),
"v_ff_pi" => skp.v_ff_pi = Some(value),
"v_ff_delta" => skp.v_ff_delta = Some(value),
"v_ff_phi" => skp.v_ff_phi = Some(value),
_ => {}
}
}
all_params.insert((*a, *b, shell_idx), skp);
if a != b {
all_params.insert((*b, *a, shell_idx), skp);
}
}
}
if let Some(path) = cache_path {
let toml_string = toml::to_string(&all_params)
.map_err(|e| TbError::Other(format!("TOML serialize error: {}", e)))?;
fs::write(path, toml_string).map_err(|e| TbError::Io(e))?;
println!("SK parameters saved to {}", path);
}
Ok(all_params)
}
pub fn build_model_interactive(
sk_model: &SlaterKosterModel,
n_shells: usize,
cache_path: Option<&str>,
) -> Result<Model> {
let params = collect_sk_parameters(sk_model, n_shells, cache_path)?;
sk_model.build_model(n_shells, ¶ms)
}
pub fn write_sk_params_to_file(
path: &str,
params: &HashMap<(AtomType, AtomType, usize), SkParams>,
) -> Result<()> {
let toml_string = toml::to_string(params)
.map_err(|e| TbError::Other(format!("TOML serialize error: {}", e)))?;
fs::write(path, toml_string).map_err(|e| TbError::Io(e))?;
Ok(())
}
pub fn read_sk_params_from_file(
path: &str,
) -> Result<HashMap<(AtomType, AtomType, usize), SkParams>> {
let content = fs::read_to_string(path).map_err(|e| TbError::Io(e))?;
toml::from_str(&content).map_err(|e| TbError::Other(format!("TOML parse error: {}", e)))
}
pub fn example_graphene_interactive(cache_path: Option<&str>) -> Result<Model> {
let lat = arr2(&[[1.0, 0.0], [-0.5, 3.0f64.sqrt() / 2.0]]);
let atoms = vec![
SkAtom {
position: arr1(&[1.0 / 3.0, 2.0 / 3.0]),
atom_type: AtomType::C,
projections: vec![OrbProj::pz],
},
SkAtom {
position: arr1(&[2.0 / 3.0, 1.0 / 3.0]),
atom_type: AtomType::C,
projections: vec![OrbProj::pz],
},
];
let sk_model = SlaterKosterModel::new(2, lat, atoms, false);
build_model_interactive(&sk_model, 1, cache_path)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::arr1;
#[test]
fn test_f_orbital_detection() {
let projs = vec![OrbProj::fz3, OrbProj::fxz2];
let needed = required_sk_parameters(&projs, &projs);
assert!(needed.contains(&"v_ff_sigma".to_string()));
assert!(needed.contains(&"v_ff_pi".to_string()));
assert!(needed.contains(&"v_ff_delta".to_string()));
assert!(needed.contains(&"v_ff_phi".to_string()));
}
#[test]
fn test_sk_element_s_f() {
let param = SkParams {
v_sf_sigma: Some(1.0),
..Default::default()
};
let result = sk_element(
OrbProj::s,
OrbProj::fz3,
0.0,
0.0,
1.0,
¶m,
(AtomType::Ce, AtomType::Ce),
0,
);
assert!(result.is_ok());
assert_eq!(result.unwrap(), 1.0);
}
#[test]
fn test_build_model_with_f_orbital() {
let lat = Array2::eye(3);
let atoms = vec![SkAtom {
position: arr1(&[0.0, 0.0, 0.0]),
atom_type: AtomType::Ce,
projections: vec![OrbProj::fz3],
}];
let sk_model = SlaterKosterModel::new(3, lat, atoms, false);
let params = HashMap::new();
let model = sk_model.build_model(0, ¶ms).unwrap();
assert_eq!(model.norb(), 1);
assert_eq!(model.nsta(), 1);
let band = model.solve_band_onek(&arr1(&[0.0, 0.0, 0.0]));
assert_eq!(band.len(), 1);
}
}