#![allow(dead_code)]
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;
use std::f64::consts::PI;
pub struct CrystalStructurePredictor {
composition: ChemicalComposition,
method: PredictionMethod,
constraints: StructureConstraints,
energy_model: EnergyModel,
search_strategy: SearchStrategy,
}
#[derive(Debug, Clone)]
pub struct ChemicalComposition {
pub elements: HashMap<String, usize>,
pub total_atoms: usize,
pub stoichiometry: Option<Vec<f64>>,
pub oxidation_states: Option<HashMap<String, i32>>,
}
#[derive(Debug, Clone)]
pub enum PredictionMethod {
GlobalOptimization {
algorithm: GlobalOptAlgorithm,
max_iterations: usize,
},
DataMining {
database: StructureDatabase,
similarity_threshold: f64,
},
Evolutionary {
population_size: usize,
generations: usize,
mutation_rate: f64,
},
MachineLearning {
model: MLModel,
confidence_threshold: f64,
},
AIRSS {
num_searches: usize,
symmetry_constraints: bool,
},
}
#[derive(Debug, Clone)]
pub enum GlobalOptAlgorithm {
SimulatedAnnealing,
BasinHopping,
ParticleSwarm,
GeneticAlgorithm,
MinimumHopping,
}
#[derive(Debug, Clone)]
pub struct StructureDatabase {
pub source: DatabaseSource,
pub size: usize,
pub filters: Vec<DatabaseFilter>,
}
#[derive(Debug, Clone)]
pub enum DatabaseSource {
MaterialsProject,
ICSD,
COD,
Custom { path: String },
}
#[derive(Debug, Clone)]
pub enum DatabaseFilter {
Elements {
required: Vec<String>,
forbidden: Vec<String>,
},
SpaceGroup { allowed: Vec<u32> },
Property { name: String, min: f64, max: f64 },
Stability { max_above_hull: f64 },
}
#[derive(Debug, Clone)]
pub struct MLModel {
pub model_type: MLModelType,
pub features: FeatureRepresentation,
pub training_size: usize,
pub accuracy: f64,
}
#[derive(Debug, Clone)]
pub enum MLModelType {
GraphNN,
CGCNN,
SchNet,
MEGNet,
GAP,
}
#[derive(Debug, Clone)]
pub enum FeatureRepresentation {
CoulombMatrix,
SineMatrix,
ManyBodyTensor { order: usize },
SOAP {
cutoff: f64,
n_max: usize,
l_max: usize,
},
CrystalGraph,
}
#[derive(Debug, Clone)]
pub struct StructureConstraints {
pub lattice: LatticeConstraints,
pub symmetry: SymmetryConstraints,
pub distances: DistanceConstraints,
pub coordination: CoordinationConstraints,
pub pressure: Option<f64>,
}
#[derive(Debug, Clone)]
pub struct LatticeConstraints {
pub min_lengths: Option<Vec3D>,
pub max_lengths: Option<Vec3D>,
pub angle_ranges: Option<Vec<(f64, f64)>>,
pub volume_range: Option<(f64, f64)>,
pub lattice_type: Option<LatticeType>,
}
#[derive(Debug, Clone)]
pub enum LatticeType {
Cubic,
Tetragonal,
Orthorhombic,
Hexagonal,
Rhombohedral,
Monoclinic,
Triclinic,
}
#[derive(Debug, Clone)]
pub struct Vec3D {
pub x: f64,
pub y: f64,
pub z: f64,
}
#[derive(Debug, Clone)]
pub struct SymmetryConstraints {
pub space_groups: Option<Vec<u32>>,
pub point_groups: Option<Vec<String>>,
pub min_symmetry: Option<usize>,
pub wyckoff_positions: Option<Vec<WyckoffPosition>>,
}
#[derive(Debug, Clone)]
pub struct WyckoffPosition {
pub letter: char,
pub multiplicity: usize,
pub site_symmetry: String,
pub coordinates: Vec<Vec3D>,
}
#[derive(Debug, Clone)]
pub struct DistanceConstraints {
pub min_distances: HashMap<(String, String), f64>,
pub max_distances: HashMap<(String, String), f64>,
pub bond_lengths: Vec<BondConstraint>,
}
#[derive(Debug, Clone)]
pub struct BondConstraint {
pub atoms: (String, String),
pub target_length: f64,
pub tolerance: f64,
pub bond_order: Option<f64>,
}
#[derive(Debug, Clone)]
pub struct CoordinationConstraints {
pub coordination_numbers: HashMap<String, (usize, usize)>,
pub geometries: HashMap<String, CoordinationGeometry>,
pub allowed_ligands: HashMap<String, Vec<String>>,
}
#[derive(Debug, Clone)]
pub enum CoordinationGeometry {
Linear,
Trigonal,
Tetrahedral,
SquarePlanar,
TrigonalBipyramidal,
Octahedral,
PentagonalBipyramidal,
CubicCoordination,
Custom { angles: Vec<f64> },
}
#[derive(Debug, Clone)]
pub enum EnergyModel {
Empirical {
potential: EmpiricalPotential,
parameters: HashMap<String, f64>,
},
DFT {
functional: String,
basis_set: String,
k_points: Vec<usize>,
},
MLPotential {
model: MLPotentialModel,
uncertainty_quantification: bool,
},
TightBinding {
parameterization: String,
k_points: Vec<usize>,
},
}
#[derive(Debug, Clone)]
pub enum EmpiricalPotential {
LennardJones,
Buckingham,
Morse,
EAM,
Tersoff,
StillingerWeber,
}
#[derive(Debug, Clone)]
pub struct MLPotentialModel {
pub architecture: String,
pub training_rmse: f64,
pub validation_rmse: f64,
pub elements: Vec<String>,
}
#[derive(Debug, Clone)]
pub enum SearchStrategy {
Random { num_trials: usize },
Grid { resolution: Vec<usize> },
Bayesian {
acquisition_function: AcquisitionFunction,
num_initial: usize,
},
Metadynamics {
collective_variables: Vec<CollectiveVariable>,
bias_factor: f64,
},
}
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
ExpectedImprovement,
UCB { kappa: f64 },
ProbabilityOfImprovement,
ThompsonSampling,
}
#[derive(Debug, Clone)]
pub enum CollectiveVariable {
LatticeParameter { index: usize },
CoordinationNumber { element: String },
OrderParameter { definition: String },
Density,
}
impl CrystalStructurePredictor {
pub fn new(composition: ChemicalComposition, energy_model: EnergyModel) -> Self {
Self {
composition,
method: PredictionMethod::GlobalOptimization {
algorithm: GlobalOptAlgorithm::SimulatedAnnealing,
max_iterations: 1000,
},
constraints: StructureConstraints::default(),
energy_model,
search_strategy: SearchStrategy::Random { num_trials: 100 },
}
}
pub fn with_method(mut self, method: PredictionMethod) -> Self {
self.method = method;
self
}
pub fn with_constraints(mut self, constraints: StructureConstraints) -> Self {
self.constraints = constraints;
self
}
pub fn build_qubo(&self) -> Result<(Array2<f64>, HashMap<String, usize>), String> {
match &self.method {
PredictionMethod::GlobalOptimization { .. } => self.build_global_optimization_qubo(),
PredictionMethod::Evolutionary { .. } => self.build_evolutionary_qubo(),
_ => Err("QUBO formulation not available for this method".to_string()),
}
}
fn build_global_optimization_qubo(
&self,
) -> Result<(Array2<f64>, HashMap<String, usize>), String> {
let lattice_resolution = 10; let position_resolution = 20;
let n_lattice_vars = 6 * lattice_resolution;
let n_atoms = self.composition.total_atoms;
let n_position_vars = n_atoms * 3 * position_resolution;
let n_vars = n_lattice_vars + n_position_vars;
let mut qubo = Array2::zeros((n_vars, n_vars));
let mut var_map = HashMap::new();
self.create_lattice_variables(&mut var_map, lattice_resolution)?;
self.create_position_variables(&mut var_map, n_atoms, position_resolution, n_lattice_vars)?;
self.add_energy_objective(&mut qubo, &var_map)?;
self.add_lattice_constraints(&mut qubo, &var_map, lattice_resolution)?;
self.add_distance_constraints(&mut qubo, &var_map)?;
self.add_symmetry_constraints(&mut qubo, &var_map)?;
Ok((qubo, var_map))
}
fn create_lattice_variables(
&self,
var_map: &mut HashMap<String, usize>,
resolution: usize,
) -> Result<(), String> {
let params = ["a", "b", "c", "alpha", "beta", "gamma"];
let mut var_idx = 0;
for param in ¶ms {
for i in 0..resolution {
let var_name = format!("lattice_{param}_{i}");
var_map.insert(var_name, var_idx);
var_idx += 1;
}
}
Ok(())
}
fn create_position_variables(
&self,
var_map: &mut HashMap<String, usize>,
n_atoms: usize,
resolution: usize,
offset: usize,
) -> Result<(), String> {
let mut var_idx = offset;
for atom in 0..n_atoms {
for coord in ["x", "y", "z"] {
for i in 0..resolution {
let var_name = format!("pos_{atom}_{coord}_{i}");
var_map.insert(var_name, var_idx);
var_idx += 1;
}
}
}
Ok(())
}
fn add_energy_objective(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
) -> Result<(), String> {
match &self.energy_model {
EnergyModel::Empirical {
potential,
parameters,
} => self.add_empirical_energy(qubo, var_map, potential, parameters),
_ => {
self.add_surrogate_energy(qubo, var_map)
}
}
}
fn add_empirical_energy(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
potential: &EmpiricalPotential,
parameters: &HashMap<String, f64>,
) -> Result<(), String> {
if matches!(potential, EmpiricalPotential::LennardJones) {
let epsilon = parameters.get("epsilon").unwrap_or(&1.0);
let sigma = parameters.get("sigma").unwrap_or(&3.4);
for i in 0..self.composition.total_atoms {
for j in i + 1..self.composition.total_atoms {
self.add_pairwise_energy(qubo, var_map, i, j, *epsilon, *sigma)?;
}
}
}
Ok(())
}
fn add_pairwise_energy(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
atom1: usize,
atom2: usize,
epsilon: f64,
_sigma: f64,
) -> Result<(), String> {
for coord in ["x", "y", "z"] {
for i in 0..20 {
let var1 = format!("pos_{atom1}_{coord}_{i}");
let var2 = format!("pos_{atom2}_{coord}_{i}");
if let (Some(&idx1), Some(&idx2)) = (var_map.get(&var1), var_map.get(&var2)) {
qubo[[idx1, idx2]] -= epsilon;
}
}
}
Ok(())
}
fn add_surrogate_energy(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
) -> Result<(), String> {
self.add_coordination_energy(qubo, var_map)?;
self.add_electrostatic_energy(qubo, var_map)?;
Ok(())
}
fn add_coordination_energy(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
) -> Result<(), String> {
if !self
.constraints
.coordination
.coordination_numbers
.is_empty()
{
let _coord_numbers = &self.constraints.coordination.coordination_numbers;
let penalty = 10.0;
for i in 0..self.composition.total_atoms {
for coord in ["x", "y", "z"] {
for pos in 0..20 {
let var_name = format!("pos_{i}_{coord}_{pos}");
if let Some(&idx) = var_map.get(&var_name) {
let deviation = (pos as f64 - 10.0).abs();
qubo[[idx, idx]] += penalty * deviation / 10.0;
}
}
}
}
}
Ok(())
}
fn add_electrostatic_energy(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
) -> Result<(), String> {
if let Some(_oxidation_states) = &self.composition.oxidation_states {
let _elements: Vec<_> = self.composition.elements.keys().collect();
for i in 0..self.composition.total_atoms {
for j in i + 1..self.composition.total_atoms {
let charge1 = 1.0; let charge2 = -1.0;
let interaction = charge1 * charge2;
for coord in ["x", "y", "z"] {
for pos in 0..20 {
let var1 = format!("pos_{i}_{coord}_{pos}");
let var2 = format!("pos_{j}_{coord}_{pos}");
if let (Some(&idx1), Some(&idx2)) =
(var_map.get(&var1), var_map.get(&var2))
{
if idx1 != idx2 {
qubo[[idx1, idx2]] += interaction * 0.1;
}
}
}
}
}
}
}
Ok(())
}
fn add_lattice_constraints(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
resolution: usize,
) -> Result<(), String> {
let penalty = 100.0;
for param in ["a", "b", "c", "alpha", "beta", "gamma"] {
for i in 0..resolution {
for j in i + 1..resolution {
let var1 = format!("lattice_{param}_{i}");
let var2 = format!("lattice_{param}_{j}");
if let (Some(&idx1), Some(&idx2)) = (var_map.get(&var1), var_map.get(&var2)) {
qubo[[idx1, idx2]] += penalty;
}
}
}
}
if let Some(lattice_type) = &self.constraints.lattice.lattice_type {
self.add_lattice_type_constraints(qubo, var_map, lattice_type, resolution)?;
}
Ok(())
}
fn add_lattice_type_constraints(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
lattice_type: &LatticeType,
resolution: usize,
) -> Result<(), String> {
let penalty = 200.0;
match lattice_type {
LatticeType::Cubic => {
for i in 0..resolution {
let var_a = format!("lattice_a_{i}");
let var_b = format!("lattice_b_{i}");
let var_c = format!("lattice_c_{i}");
if let (Some(&idx_a), Some(&idx_b), Some(&idx_c)) = (
var_map.get(&var_a),
var_map.get(&var_b),
var_map.get(&var_c),
) {
qubo[[idx_a, idx_b]] -= penalty;
qubo[[idx_b, idx_c]] -= penalty;
qubo[[idx_a, idx_c]] -= penalty;
}
}
let angle_90_idx = resolution / 2; for angle in ["alpha", "beta", "gamma"] {
let var_name = format!("lattice_{angle}_{angle_90_idx}");
if let Some(&idx) = var_map.get(&var_name) {
qubo[[idx, idx]] -= penalty * 2.0;
}
}
}
LatticeType::Hexagonal => {
}
_ => {}
}
Ok(())
}
fn add_distance_constraints(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
) -> Result<(), String> {
if !self.constraints.distances.min_distances.is_empty() {
let min_distances = &self.constraints.distances.min_distances;
let penalty = 50.0;
for ((_elem1, _elem2), &_min_dist) in min_distances {
for i in 0..self.composition.total_atoms {
for j in i + 1..self.composition.total_atoms {
for coord in ["x", "y", "z"] {
for pos in 0..20 {
let var1 = format!("pos_{i}_{coord}_{pos}");
let var2 = format!("pos_{j}_{coord}_{pos}");
if let (Some(&idx1), Some(&idx2)) =
(var_map.get(&var1), var_map.get(&var2))
{
if idx1 == idx2 {
qubo[[idx1, idx2]] += penalty;
}
}
}
}
}
}
}
}
Ok(())
}
fn add_symmetry_constraints(
&self,
qubo: &mut Array2<f64>,
var_map: &HashMap<String, usize>,
) -> Result<(), String> {
if let Some(_space_groups) = &self.constraints.symmetry.space_groups {
let symmetry_bonus = -10.0;
for i in 0..self.composition.total_atoms {
for coord in ["x", "y", "z"] {
for special_pos in [0, 10, 19] {
let var_name = format!("pos_{i}_{coord}_{special_pos}");
if let Some(&idx) = var_map.get(&var_name) {
qubo[[idx, idx]] += symmetry_bonus;
}
}
}
}
}
Ok(())
}
fn build_evolutionary_qubo(&self) -> Result<(Array2<f64>, HashMap<String, usize>), String> {
let genome_length = 100; let mut qubo = Array2::zeros((genome_length, genome_length));
let mut var_map = HashMap::new();
for i in 0..genome_length {
var_map.insert(format!("gene_{i}"), i);
}
self.add_fitness_function(&mut qubo, &var_map)?;
Ok((qubo, var_map))
}
fn add_fitness_function(
&self,
qubo: &mut Array2<f64>,
_var_map: &HashMap<String, usize>,
) -> Result<(), String> {
for i in 0..qubo.shape()[0] {
qubo[[i, i]] = -1.0; }
Ok(())
}
pub fn decode_solution(
&self,
solution: &HashMap<String, bool>,
) -> Result<CrystalStructure, String> {
let lattice = self.decode_lattice(solution)?;
let positions = self.decode_positions(solution)?;
let space_group = self.determine_space_group(&lattice, &positions)?;
Ok(CrystalStructure {
composition: self.composition.clone(),
lattice,
positions,
space_group,
energy: None,
properties: HashMap::new(),
})
}
fn decode_lattice(&self, solution: &HashMap<String, bool>) -> Result<Lattice, String> {
let mut params = HashMap::new();
for param in ["a", "b", "c", "alpha", "beta", "gamma"] {
for i in 0..10 {
let var_name = format!("lattice_{param}_{i}");
if *solution.get(&var_name).unwrap_or(&false) {
let value = match param {
"a" | "b" | "c" => (i as f64).mul_add(0.5, 3.0), "alpha" | "beta" | "gamma" => (i as f64).mul_add(6.0, 60.0), _ => 0.0,
};
params.insert(param.to_string(), value);
break;
}
}
}
Ok(Lattice {
a: params.get("a").copied().unwrap_or(5.0),
b: params.get("b").copied().unwrap_or(5.0),
c: params.get("c").copied().unwrap_or(5.0),
alpha: params.get("alpha").copied().unwrap_or(90.0),
beta: params.get("beta").copied().unwrap_or(90.0),
gamma: params.get("gamma").copied().unwrap_or(90.0),
})
}
fn decode_positions(
&self,
solution: &HashMap<String, bool>,
) -> Result<Vec<AtomicPosition>, String> {
let mut positions = Vec::new();
let elements: Vec<_> = self.composition.elements.keys().cloned().collect();
for atom in 0..self.composition.total_atoms {
let mut coords = [0.0, 0.0, 0.0];
for (i, coord) in ["x", "y", "z"].iter().enumerate() {
for pos in 0..20 {
let var_name = format!("pos_{atom}_{coord}_{pos}");
if *solution.get(&var_name).unwrap_or(&false) {
coords[i] = pos as f64 / 19.0; break;
}
}
}
positions.push(AtomicPosition {
element: elements[atom % elements.len()].clone(),
x: coords[0],
y: coords[1],
z: coords[2],
occupancy: 1.0,
});
}
Ok(positions)
}
fn determine_space_group(
&self,
lattice: &Lattice,
_positions: &[AtomicPosition],
) -> Result<SpaceGroup, String> {
let lattice_type = lattice.determine_type();
Ok(SpaceGroup {
number: 1, symbol: "P1".to_string(),
lattice_type,
point_group: "1".to_string(),
})
}
}
impl Default for StructureConstraints {
fn default() -> Self {
Self {
lattice: LatticeConstraints {
min_lengths: None,
max_lengths: None,
angle_ranges: None,
volume_range: None,
lattice_type: None,
},
symmetry: SymmetryConstraints {
space_groups: None,
point_groups: None,
min_symmetry: None,
wyckoff_positions: None,
},
distances: DistanceConstraints {
min_distances: HashMap::new(),
max_distances: HashMap::new(),
bond_lengths: Vec::new(),
},
coordination: CoordinationConstraints {
coordination_numbers: HashMap::new(),
geometries: HashMap::new(),
allowed_ligands: HashMap::new(),
},
pressure: None,
}
}
}
#[derive(Debug, Clone)]
pub struct Lattice {
pub a: f64,
pub b: f64,
pub c: f64,
pub alpha: f64,
pub beta: f64,
pub gamma: f64,
}
impl Lattice {
pub fn volume(&self) -> f64 {
let alpha_rad = self.alpha.to_radians();
let beta_rad = self.beta.to_radians();
let gamma_rad = self.gamma.to_radians();
self.a
* self.b
* self.c
* (2.0 * alpha_rad.cos() * beta_rad.cos())
.mul_add(
gamma_rad.cos(),
gamma_rad.cos().mul_add(
-gamma_rad.cos(),
beta_rad.cos().mul_add(
-beta_rad.cos(),
alpha_rad.cos().mul_add(-alpha_rad.cos(), 1.0),
),
),
)
.sqrt()
}
pub fn determine_type(&self) -> LatticeType {
let tol = 0.01;
if (self.a - self.b).abs() < tol && (self.b - self.c).abs() < tol {
if (self.alpha - 90.0).abs() < tol
&& (self.beta - 90.0).abs() < tol
&& (self.gamma - 90.0).abs() < tol
{
LatticeType::Cubic
} else if (self.alpha - self.beta).abs() < tol && (self.beta - self.gamma).abs() < tol {
LatticeType::Rhombohedral
} else {
LatticeType::Triclinic
}
} else if (self.a - self.b).abs() < tol {
if (self.alpha - 90.0).abs() < tol
&& (self.beta - 90.0).abs() < tol
&& (self.gamma - 120.0).abs() < tol
{
LatticeType::Hexagonal
} else if (self.alpha - 90.0).abs() < tol
&& (self.beta - 90.0).abs() < tol
&& (self.gamma - 90.0).abs() < tol
{
LatticeType::Tetragonal
} else {
LatticeType::Monoclinic
}
} else if (self.alpha - 90.0).abs() < tol
&& (self.beta - 90.0).abs() < tol
&& (self.gamma - 90.0).abs() < tol
{
LatticeType::Orthorhombic
} else if (self.alpha - 90.0).abs() < tol && (self.gamma - 90.0).abs() < tol {
LatticeType::Monoclinic
} else {
LatticeType::Triclinic
}
}
pub fn transformation_matrix(&self) -> Array2<f64> {
let alpha_rad = self.alpha.to_radians();
let beta_rad = self.beta.to_radians();
let gamma_rad = self.gamma.to_radians();
let mut matrix = Array2::zeros((3, 3));
matrix[[0, 0]] = self.a;
matrix[[0, 1]] = self.b * gamma_rad.cos();
matrix[[0, 2]] = self.c * beta_rad.cos();
matrix[[1, 0]] = 0.0;
matrix[[1, 1]] = self.b * gamma_rad.sin();
matrix[[1, 2]] =
self.c * beta_rad.cos().mul_add(-gamma_rad.cos(), alpha_rad.cos()) / gamma_rad.sin();
matrix[[2, 0]] = 0.0;
matrix[[2, 1]] = 0.0;
matrix[[2, 2]] = self.c
* (2.0 * alpha_rad.cos() * beta_rad.cos())
.mul_add(
gamma_rad.cos(),
gamma_rad.cos().mul_add(
-gamma_rad.cos(),
beta_rad.cos().mul_add(
-beta_rad.cos(),
alpha_rad.cos().mul_add(-alpha_rad.cos(), 1.0),
),
),
)
.sqrt()
/ gamma_rad.sin();
matrix
}
}
#[derive(Debug, Clone)]
pub struct AtomicPosition {
pub element: String,
pub x: f64,
pub y: f64,
pub z: f64,
pub occupancy: f64,
}
#[derive(Debug, Clone)]
pub struct SpaceGroup {
pub number: u32,
pub symbol: String,
pub lattice_type: LatticeType,
pub point_group: String,
}
#[derive(Debug, Clone)]
pub struct CrystalStructure {
pub composition: ChemicalComposition,
pub lattice: Lattice,
pub positions: Vec<AtomicPosition>,
pub space_group: SpaceGroup,
pub energy: Option<f64>,
pub properties: HashMap<String, f64>,
}
impl CrystalStructure {
pub fn density(&self) -> f64 {
let volume = self.lattice.volume();
let mass = self.calculate_mass();
mass / volume * 1.66054
}
fn calculate_mass(&self) -> f64 {
let masses: HashMap<&str, f64> = [
("H", 1.008),
("C", 12.011),
("N", 14.007),
("O", 15.999),
("Na", 22.990),
("Mg", 24.305),
("Al", 26.982),
("Si", 28.085),
("Fe", 55.845),
]
.iter()
.copied()
.collect();
self.composition
.elements
.iter()
.map(|(elem, count)| masses.get(elem.as_str()).unwrap_or(&1.0) * *count as f64)
.sum()
}
pub fn supercell(&self, nx: usize, ny: usize, nz: usize) -> Self {
let mut new_positions = Vec::new();
for i in 0..nx {
for j in 0..ny {
for k in 0..nz {
for pos in &self.positions {
new_positions.push(AtomicPosition {
element: pos.element.clone(),
x: (pos.x + i as f64) / nx as f64,
y: (pos.y + j as f64) / ny as f64,
z: (pos.z + k as f64) / nz as f64,
occupancy: pos.occupancy,
});
}
}
}
}
let mut new_composition = self.composition.clone();
for count in new_composition.elements.values_mut() {
*count *= nx * ny * nz;
}
new_composition.total_atoms *= nx * ny * nz;
Self {
composition: new_composition,
lattice: Lattice {
a: self.lattice.a * nx as f64,
b: self.lattice.b * ny as f64,
c: self.lattice.c * nz as f64,
..self.lattice.clone()
},
positions: new_positions,
space_group: self.space_group.clone(),
energy: None,
properties: HashMap::new(),
}
}
}
pub struct PhaseTransitionAnalyzer {
structures: Vec<CrystalStructure>,
method: TransitionMethod,
order_parameters: Vec<OrderParameter>,
}
#[derive(Debug, Clone)]
pub enum TransitionMethod {
NEB { images: usize, spring_constant: f64 },
Metadynamics {
bias_factor: f64,
gaussian_width: f64,
},
TPS { shooting_moves: usize },
ML { model: String },
}
#[derive(Debug, Clone)]
pub struct OrderParameter {
pub name: String,
pub definition: OrderParameterDef,
pub range: (f64, f64),
}
#[derive(Debug, Clone)]
pub enum OrderParameterDef {
Structural { description: String },
Electronic { property: String },
Magnetic { moment_type: String },
Custom,
}
pub struct DefectModeler {
host: CrystalStructure,
defect_types: Vec<DefectType>,
interactions: DefectInteractions,
}
#[derive(Debug, Clone)]
pub enum DefectType {
Vacancy { site: usize },
Interstitial { element: String, position: Vec3D },
Substitution { site: usize, new_element: String },
Frenkel {
vacancy_site: usize,
interstitial_pos: Vec3D,
},
Schottky { sites: Vec<usize> },
GrainBoundary { plane: (i32, i32, i32), angle: f64 },
}
#[derive(Debug, Clone)]
pub struct DefectInteractions {
pub cutoff: f64,
pub model: InteractionModel,
pub clustering: bool,
}
#[derive(Debug, Clone)]
pub enum InteractionModel {
Coulombic,
Elastic { elastic_constants: Array2<f64> },
Combined,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_crystal_structure_predictor() {
let composition = ChemicalComposition {
elements: {
let mut elements = HashMap::new();
elements.insert("Na".to_string(), 1);
elements.insert("Cl".to_string(), 1);
elements
},
total_atoms: 2,
stoichiometry: Some(vec![1.0, 1.0]),
oxidation_states: Some({
let mut states = HashMap::new();
states.insert("Na".to_string(), 1);
states.insert("Cl".to_string(), -1);
states
}),
};
let energy_model = EnergyModel::Empirical {
potential: EmpiricalPotential::LennardJones,
parameters: {
let mut params = HashMap::new();
params.insert("epsilon".to_string(), 1.0);
params.insert("sigma".to_string(), 3.4);
params
},
};
let predictor = CrystalStructurePredictor::new(composition, energy_model);
let mut result = predictor.build_qubo();
assert!(result.is_ok());
}
#[test]
fn test_lattice() {
let lattice = Lattice {
a: 5.0,
b: 5.0,
c: 5.0,
alpha: 90.0,
beta: 90.0,
gamma: 90.0,
};
assert_eq!(lattice.determine_type() as u8, LatticeType::Cubic as u8);
assert!((lattice.volume() - 125.0).abs() < 0.01);
}
#[test]
fn test_supercell() {
let structure = CrystalStructure {
composition: ChemicalComposition {
elements: {
let mut elements = HashMap::new();
elements.insert("Si".to_string(), 1);
elements
},
total_atoms: 1,
stoichiometry: None,
oxidation_states: None,
},
lattice: Lattice {
a: 5.0,
b: 5.0,
c: 5.0,
alpha: 90.0,
beta: 90.0,
gamma: 90.0,
},
positions: vec![AtomicPosition {
element: "Si".to_string(),
x: 0.0,
y: 0.0,
z: 0.0,
occupancy: 1.0,
}],
space_group: SpaceGroup {
number: 225,
symbol: "Fm-3m".to_string(),
lattice_type: LatticeType::Cubic,
point_group: "m-3m".to_string(),
},
energy: None,
properties: HashMap::new(),
};
let supercell = structure.supercell(2, 2, 2);
assert_eq!(supercell.positions.len(), 8);
assert_eq!(supercell.composition.total_atoms, 8);
}
}