import os
import fastatomstruct as fs
from black import format_str, FileMode
from pyo3_stubgen import genentry
# Lines placed at the very top of the generated fastatomstruct.pyi; the stub
# text for every class and function is appended to this list further below.
out = ["from typing import List, Tuple, Union, Optional\n"]
out.extend(["import ase", "import fastatomstruct", "import numpy as np\n"])
# Stub for the FilterTag class. The example code must be valid Python:
# continuation lines of a multi-line statement use the "..." prompt, and the
# class referenced throughout is FilterTag (there is no plain "Filter").
out.append(
    """class FilterTag:
    \"\"\"Filter atoms based on tags.

    On creation of a :code:`FilterTag`, you have to specify which atoms
    should be regarded as "center" atoms and as "other" atoms, respectively.
    Atoms that have a tag other than :code:`center` or :code:`other`
    will be disregarded. The last argument (:code:`center_is_other`, a boolean)
    specifies whether "center" atoms should also be regarded as "other" atoms.

    Examples
    --------
    Suppose that we have a NaCl system and want to calculate the **partial
    Na-Na, Na-Cl and Cl-Cl pair correlation functions**. This can be achieved
    by first tagging all Cl atoms with tag 1:

    >>> from ase.build import bulk
    >>> a = 5.64
    >>> nacl = bulk("NaCl", "rocksalt", a=a) * (5, 5, 5)
    >>> nacl.rattle()
    >>> tags = nacl.get_tags()
    >>> tags[nacl.numbers == 17] = 1
    >>> nacl.set_tags(tags)

    For the partial Na-Cl correlation function, we can then use
    :code:`FilterTag(0, 1, False)`:

    >>> import fastatomstruct as fs
    >>> r_na_cl, rdf_na_cl = fs.radial_distribution_function(
    ...     nacl, 10, 200, fs.FilterTag(0, 1, False)
    ... )

    Analogously, the Na-Na pair correlation function is

    >>> r_na_na, rdf_na_na = fs.radial_distribution_function(
    ...     nacl, 10, 200, fs.FilterTag(0, 0, False)
    ... )

    The :code:`center_is_other` argument will not matter in this case.
    Now suppose you want to calculate the **partial three-body correlation**
    around the Na atoms (including atoms of any kind around those atoms).
    This can be achieved as follows:

    >>> tbc = fs.tbc(nacl, 3, 10, 250, fs.FilterTag(0, 1, True))
    \"\"\"
    ...
"""
)
# Stub for the FilterElement class. The example imports `fs` before using it,
# uses "..." for continuation lines, avoids shadowing the builtin `filter`,
# and uses :code: markup like the FilterTag stub.
out.append(
    """class FilterElement:
    \"\"\"Filter atoms based on elements.

    On creation of a :code:`FilterElement`, you have to specify which atoms
    should be regarded as "center" atoms and as "other" atoms, respectively.
    Atoms that have an element other than :code:`center` or :code:`other`
    will be disregarded. The last argument (:code:`center_is_other`, a boolean)
    specifies whether "center" atoms should also be regarded as "other" atoms.

    Examples
    --------
    Suppose that we have a NaCl system and want to calculate the **partial
    Na-Na, Na-Cl and Cl-Cl pair correlation functions**. This can be achieved
    as follows:

    >>> import fastatomstruct as fs
    >>> from ase.build import bulk
    >>> a = 5.64
    >>> nacl = bulk("NaCl", "rocksalt", a=a) * (5, 5, 5)
    >>> nacl.rattle()
    >>> na_cl_filter = fs.FilterElement(fs.Element.Na, fs.Element.Cl, False)
    >>> r_na_cl, rdf_na_cl = fs.radial_distribution_function(
    ...     nacl, 10, 200, na_cl_filter
    ... )
    >>> na_na_filter = fs.FilterElement(fs.Element.Na, fs.Element.Na, False)
    >>> r_na_na, rdf_na_na = fs.radial_distribution_function(
    ...     nacl, 10, 200, na_na_filter
    ... )

    The :code:`center_is_other` argument will not matter in this case.
    Now suppose you want to calculate the **partial three-body correlation**
    around the Na atoms (including atoms of any kind around those atoms).
    This can be achieved as follows:

    >>> tbc = fs.tbc(
    ...     nacl, 3, 10, 250, fs.FilterElement(fs.Element.Na, fs.Element.Cl, True)
    ... )
    \"\"\"
    ...
"""
)
# Stub for the Element enum. The stub text carries its own
# "from enum import Enum" line, so that import is emitted right here in the
# generated file rather than with the header imports.
out.extend(
    [
        """from enum import Enum
class Element(Enum):
    \"\"\"Enum representing chemical elements.

    The :code:`Element` enum represents chemical elements. It can be used
    to filter atoms based on their element. All chemical elements can be accessed
    (e.g. :code:`Element.H`, :code:`Element.He`, etc.). There is also a special element
    :code:`Element.Undefined` that can be used to filter atoms with an undefined element.
    \"\"\"
    ...
"""
    ]
)
# Stub for the Atom class. The constructor arguments are documented in a
# numpydoc "Parameters" section, matching the numpydoc "Examples" sections
# used by every other stub docstring in this script.
out.append(
    """class Atom:
    \"\"\"Represents an atom with its position, element, velocity, and tag.

    An Atom object contains all the information about a single atom, including its position,
    element type, velocity (if available), and tag.

    Parameters
    ----------
    position : np.ndarray
        Array of shape (3,) representing the position of the atom
    element : Element
        Chemical element of the atom
    velocity : np.ndarray, optional
        Array of shape (3,) representing the velocity of the atom
    tag : int, optional
        Tag for the atom, can be used for filtering or grouping atoms

    Examples
    --------
    >>> import fastatomstruct as fs
    >>> import numpy as np
    >>> position = np.array([0.0, 0.0, 0.0])
    >>> element = fs.Element.H
    >>> velocity = np.array([1.0, 0.0, 0.0])
    >>> tag = 1
    >>> atom = fs.Atom(position, element, velocity, tag)
    \"\"\"
    ...
"""
)
# Stub for the Atoms container class, with a walk-through example covering
# construction, positions, velocities, elements/cell and derived properties.
out += [
    """class Atoms:
    \"\"\"A collection of atoms with associated cell and properties.

    The `Atoms` class represents a collection of atoms with associated cell dimensions
    and optional properties like velocities and stress tensor. It provides methods for
    accessing and manipulating atomic positions, velocities, elements, and cell parameters.

    Examples
    --------
    Create a simple hydrogen molecule:

    >>> import fastatomstruct as fs
    >>> import numpy as np
    >>> positions = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]])
    >>> cell = np.eye(3) * 10.0
    >>> elements = [fs.Element.H, fs.Element.H]
    >>> atoms = fs.Atoms(positions, cell, elements)

    Get and set positions:

    >>> pos = atoms.get_positions()
    >>> pos[0] = [1.0, 0.0, 0.0]
    >>> atoms.set_positions(pos)

    Set velocities:

    >>> vel = np.zeros_like(positions)
    >>> vel[0] = [0.1, 0.0, 0.0]
    >>> atoms.set_velocities(vel)

    Get elements and cell:

    >>> elements = atoms.get_elements()
    >>> cell = atoms.get_cell()

    Calculate properties:

    >>> from fastatomstruct import FilterNone
    >>> filter_none = FilterNone()
    >>> temperature = atoms.temperature(filter_none)
    >>> volume = atoms.volume()
    >>> density = atoms.density(filter_none)
    \"\"\"
    ...
"""
]
# Stub for Element's property accessors (radii, Jmol colour, str()).
# NOTE(review): this appends a second top-level "class Element" to the stub;
# an "Element(Enum)" class is already emitted earlier in this script, and the
# later declaration presumably shadows it for type checkers — consider
# merging the two into a single definition. Note this text ends right after
# the closing docstring quotes (no trailing newline and no "..." body).
out.extend(
    [
        """class Element:
    \"\"\"Element class representing chemical elements with their properties.

    This class provides methods to access the properties of elements such as their atomic number, mass, covalent radius, and van der Waals radius.
    It also includes methods for converting elements to strings and retrieving their Jmol colors.

    Examples
    --------
    .. plot::
        :include-source: True
        :context: reset

        >>> import fastatomstruct as fs
        >>> hydrogen = fs.Element.H
        >>> print(str(hydrogen))
        >>> print(hydrogen.get_covalent_radius())
        >>> print(hydrogen.get_vdw_radius())
        >>> print(hydrogen.get_jmol_color())
    \"\"\""""
    ]
)
# Stub for GenericParticle. From Python this is a class, not a Rust struct.
# The example only imports what it uses, and the stub gets a "..." body like
# its sibling class stubs.
out.append(
    """class GenericParticle:
    \"\"\"A generic particle implementation with custom name and mass.

    This class provides a flexible way to represent particles that are not
    included in the predefined :code:`Element` enum. It's useful for custom atom types,
    virtual sites, or other simulation entities.

    Examples
    --------
    >>> from fastatomstruct import GenericParticle
    >>> custom_particle = GenericParticle("Dummy", 12.0)
    \"\"\"
    ...
"""
)
# Stub for the TimeAxis enum (linear vs. logarithmic time axis).
out += [
    """class TimeAxis:
    \"\"\"An enum representing a time axis for iterating over items.

    The :code:`TimeAxis` enum represents a time axis for iterating over a slice of items.
    It can be either :code:`Linear` or :code:`Logarithmic`, with the latter increasing the
    time step exponentially with each iteration.

    Variants
    --------
    - `Linear(float)` - A linear time axis with a constant time step.
    - `Logarithmic(float)` - A logarithmic time axis with an initial time step.

    Examples
    --------
    >>> from fastatomstruct import TimeAxis
    >>> linear = TimeAxis.Linear(1.0)
    >>> log = TimeAxis.Logarithmic(1.0)
    \"\"\"
    ...
"""
]
# Stub for the CutoffMode enum. The variants are listed in a numpydoc
# "Attributes" section so the docstring uses one section style throughout.
# The example code is indented under the ".. plot::" directive, since
# reST directive content must be indented relative to the directive.
out.append(
    """class CutoffMode:
    \"\"\"Represents the cutoff mode for distance calculations.

    This enum allows for different strategies to determine the cutoff distance
    between atoms based on their elements or tags.

    Attributes
    ----------
    Fixed : float
        Single default cutoff for all atom pairs.
    Elements : dict
        One cutoff per element, with a fallback.
    Tags : dict
        One cutoff per tag, with a fallback.
    PairElements : dict
        One cutoff per pair of elements, with a fallback.
    PairTags : dict
        One cutoff per pair of tags, with a fallback.
    CovalentRadiiScaled : float
        Covalent radii scaled by a factor.
    VdWRadiiScaled : float
        Van der Waals radii scaled by a factor.

    Examples
    --------
    .. plot::

        >>> from fastatomstruct import Element
        >>> from fastatomstruct import CutoffMode
        >>> element_map = {Element.H: 1.0, Element.O: 1.5}
        >>> cutoff_mode = CutoffMode.Elements(map=element_map, fallback=2.0)
        >>> cutoff_mode = CutoffMode.Fixed(2.0)
        >>> cutoff_mode = CutoffMode.PairElements(map={(Element.H, Element.O): 1.0}, fallback=2.0)
    \"\"\"
    ...
"""
)
# Emit a "def" stub for each native function, built from pyo3_stubgen's
# genentry() output for that function.
for func in (
    fs.convert_ase,
    fs.get_num_threads,
    fs.set_num_threads,
    fs.q_l_global,
    fs.q_l,
    fs.q_l_dot,
    fs.q_tetrahedral,
    fs.bond_length_ratio,
    fs.bond_length_ratio_list,
    fs.altbc,
    fs.tbc,
    fs.angular_three_body_correlation,
    fs.temporal_altbc,
    fs.temporal_tbc,
    fs.temporal_angular_three_body_correlation,
    fs.distances,
    fs.all_distances,
    fs.r_theta_phi,
    fs.distance_vectors,
    fs.all_distance_vectors,
    fs.neighbour_lists,
    fs.find_bonds,
    fs.find_bonds_with_vec,
    fs.coordination_numbers,
    fs.bond_angle_distribution,
    fs.radial_distribution_function,
    fs.mean_squared_displacement_single,
    fs.squared_displacement_single,
    fs.non_gaussian_alpha2,
    fs.non_gaussian_alpha2_single,
    fs.incoherent_intermediate_scattering,
    fs.coherent_intermediate_scattering,
    fs.overlap_q,
    fs.overlap_q_self,
    fs.overlap_q_self_atomic,
    fs.overlap_q_distinct,
    fs.overlap_q_single,
    fs.overlap_q_single_self,
    fs.overlap_q_single_distinct,
    fs.fourpoint_susceptibility,
    fs.fourpoint_susceptibility_self,
    fs.vacf,
    fs.viscosity,
    fs.viscosity_average,
):
    entry_lines = genentry(func).splitlines()
    # Rebuild the entry from fixed positions in genentry's output:
    # - line 2, minus its first three characters, becomes the "def ...:" line;
    # - line 4 opens the docstring;
    # - everything from line 5 on is kept verbatim;
    # - lines 0, 1 and 3 are dropped.
    def_line = "\ndef" + entry_lines[2][3:] + ":"
    doc_open = ' """' + entry_lines[4]
    out.append("\n".join([def_line, doc_open] + entry_lines[5:]))
    # Body of the stubbed function.
    out.append(" ...\n")
# Assemble the stub text and write it next to the compiled extension module.
pyi = "\n".join(out)
# Explicit UTF-8 keeps the output independent of the platform's default
# encoding (docstrings pulled in via genentry may contain non-ASCII text).
with open("python/fastatomstruct/fastatomstruct.pyi", "w", encoding="utf-8") as stub_file:
    stub_file.write(pyi)

# Let pyright derive the package stub, move its __init__.pyi into the package
# and clean up the rest of its output. A non-zero exit status stops the script
# here instead of surfacing later as a confusing error in the file moves below.
if os.system("pyright --createstub fastatomstruct") != 0:
    raise SystemExit("pyright --createstub fastatomstruct failed")
# os.replace overwrites an existing destination on every platform; os.rename
# raises FileExistsError on Windows once __init__.pyi already exists.
os.replace("typings/fastatomstruct/__init__.pyi", "python/fastatomstruct/__init__.pyi")
os.remove("typings/fastatomstruct/fastatomstruct.pyi")
os.removedirs("typings/fastatomstruct")
# Reformat the generated stubs; report failure (e.g. a syntax error in the
# generated text) instead of ignoring black's exit status.
if os.system("black python") != 0:
    raise SystemExit("black python failed")