Module functions

Auto-generated module

🤖 Generated with SplitRS

Functions

batch_atomic_energies
Predict atomic energies for a batch of (descriptor, atomic_number) pairs.
batch_forward
Run a forward pass for every sample in the batch.
clip_gradients_by_norm
Clip a collection of gradient slices in-place so that their combined L2 norm does not exceed max_norm.
compute_forces_batch
Run the network on every position and return a force vector per atom.
compute_gradient_norm
Compute the L2 norm (Frobenius norm) of a concatenated gradient vector.
cross_entropy_loss
Compute the cross-entropy loss between predictions and a one-hot target.
huber_loss
Huber loss between prediction and target.
huber_loss_grad
Gradient of Huber loss w.r.t. predictions.
l2_regularisation
Compute the L2 regularisation contribution to the loss.
l2_regularisation_grad
Compute L2 regularisation gradient contribution (adds lambda * w to each element).
load_weights_from_buffer
Load weights from a flat f32 buffer into a network, partitioning by layer sizes.
mean_huber_loss
Mean Huber loss over a batch.
mse_loss
Mean squared error loss between predictions and targets.
neural_potential_energy
Compute a scalar potential energy by summing the first network output over all atoms.
save_weights_to_buffer
Serialize network weights to a flat f32 buffer.
scaled_dot_product_attention
Compute scaled dot-product attention.
softmax
Compute softmax of a vector.
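The regression losses listed above (mse_loss, huber_loss, huber_loss_grad, mean_huber_loss) follow standard definitions. The sketch below shows one plausible implementation. The signatures, the `delta` threshold parameter and the per-element gradient convention are assumptions for illustration, not this crate's actual API.

```rust
/// Hypothetical signature: mean squared error, (1/n) * sum (p_i - t_i)^2.
fn mse_loss(predictions: &[f32], targets: &[f32]) -> f32 {
    assert_eq!(predictions.len(), targets.len(), "length mismatch");
    if predictions.is_empty() {
        return 0.0;
    }
    let sum: f32 = predictions
        .iter()
        .zip(targets)
        .map(|(p, t)| (p - t) * (p - t))
        .sum();
    sum / predictions.len() as f32
}

/// Hypothetical signature: Huber loss for one residual r = prediction - target.
/// Quadratic (0.5 * r^2) when |r| <= delta, linear (delta * (|r| - 0.5 * delta)) beyond.
fn huber_loss(prediction: f32, target: f32, delta: f32) -> f32 {
    let r = prediction - target;
    if r.abs() <= delta {
        0.5 * r * r
    } else {
        delta * (r.abs() - 0.5 * delta)
    }
}

/// Hypothetical signature: element-wise derivative of the Huber loss w.r.t. each
/// prediction: r inside the quadratic region, delta * sign(r) outside it.
/// (Not divided by n; callers averaging the loss would scale by 1/n.)
fn huber_loss_grad(predictions: &[f32], targets: &[f32], delta: f32) -> Vec<f32> {
    predictions
        .iter()
        .zip(targets)
        .map(|(p, t)| {
            let r = p - t;
            if r.abs() <= delta { r } else { delta * r.signum() }
        })
        .collect()
}

/// Hypothetical signature: mean of huber_loss over a batch of pairs.
fn mean_huber_loss(predictions: &[f32], targets: &[f32], delta: f32) -> f32 {
    assert_eq!(predictions.len(), targets.len(), "length mismatch");
    if predictions.is_empty() {
        return 0.0;
    }
    let sum: f32 = predictions
        .iter()
        .zip(targets)
        .map(|(&p, &t)| huber_loss(p, t, delta))
        .sum();
    sum / predictions.len() as f32
}
```

The Huber loss is less sensitive to outliers than MSE: a residual of 3.0 with `delta = 1.0` costs 2.5 rather than the 4.5 that 0.5 * r^2 would give.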
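softmax and cross_entropy_loss are typically used together. Below is a minimal sketch, assuming that `cross_entropy_loss` receives probabilities (for example the output of `softmax`) rather than raw logits; the signatures and the `EPS` clamp are illustrative assumptions, not the crate's confirmed behaviour.

```rust
/// Hypothetical signature: numerically stable softmax. Subtracting the maximum
/// logit before exponentiating avoids overflow without changing the result.
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical signature: cross-entropy -sum_i t_i * ln(p_i) against a one-hot
/// target. Probabilities are clamped to EPS so ln(0) never produces -inf.
fn cross_entropy_loss(probs: &[f32], one_hot: &[f32]) -> f32 {
    assert_eq!(probs.len(), one_hot.len(), "length mismatch");
    const EPS: f32 = 1e-7;
    -probs
        .iter()
        .zip(one_hot)
        .map(|(&p, &t)| t * p.max(EPS).ln())
        .sum::<f32>()
}
```

With a one-hot target only the true class contributes, so the loss reduces to -ln(p) of that class; a uniform two-class prediction gives ln 2.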
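compute_gradient_norm and clip_gradients_by_norm implement global-norm gradient clipping: the L2 norm is taken over all gradient slices as if they were one concatenated vector, and every slice is rescaled by the same factor when that norm exceeds `max_norm`. The sketch below assumes slice-of-slices signatures and that the clipping function returns the pre-clip norm; both are illustrative choices, not the crate's documented API.

```rust
/// Hypothetical signature: L2 norm of all slices treated as one concatenated vector.
fn compute_gradient_norm(grads: &[&[f32]]) -> f32 {
    grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|&x| x * x)
        .sum::<f32>()
        .sqrt()
}

/// Hypothetical signature: rescale all slices in place so their combined L2 norm
/// is at most max_norm. Returns the norm measured before clipping.
fn clip_gradients_by_norm(grads: &mut [&mut [f32]], max_norm: f32) -> f32 {
    let total = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|&x| x * x)
        .sum::<f32>()
        .sqrt();
    if total > max_norm && total > 0.0 {
        // One shared scale factor preserves the direction of the full gradient.
        let scale = max_norm / total;
        for g in grads.iter_mut() {
            for x in g.iter_mut() {
                *x *= scale;
            }
        }
    }
    total
}
```

Scaling every slice by the same factor, rather than clipping each layer separately, keeps the update direction unchanged and only shortens its length.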
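The description of l2_regularisation_grad (adding lambda * w to each element) is consistent with a penalty of (lambda / 2) * sum w^2, whose derivative is exactly lambda * w. The sketch below assumes that convention for l2_regularisation; the factor of 1/2 and both signatures are inferences, not confirmed by the crate.

```rust
/// Hypothetical signature: L2 penalty (lambda / 2) * sum w^2.
/// The 1/2 factor is assumed so that the gradient is exactly lambda * w.
fn l2_regularisation(weights: &[f32], lambda: f32) -> f32 {
    0.5 * lambda * weights.iter().map(|&w| w * w).sum::<f32>()
}

/// Hypothetical signature: add the penalty gradient lambda * w to an existing
/// gradient buffer, element by element.
fn l2_regularisation_grad(grads: &mut [f32], weights: &[f32], lambda: f32) {
    assert_eq!(grads.len(), weights.len(), "length mismatch");
    for (g, &w) in grads.iter_mut().zip(weights) {
        *g += lambda * w;
    }
}
```

Accumulating into an existing buffer lets the penalty be applied after the data-loss gradient has already been computed, without allocating a second vector.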
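scaled_dot_product_attention conventionally computes softmax(Q K^T / sqrt(d_k)) V, where d_k is the key dimension. The sketch below uses row-major `Vec<Vec<f32>>` matrices and a local `row_softmax` helper; the matrix representation, the helper name and the signature are assumptions for illustration, not the crate's API.

```rust
/// Hypothetical helper: numerically stable softmax over one row of scores.
fn row_softmax(x: &[f32]) -> Vec<f32> {
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = x.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical signature: q is (n_q x d_k), k is (n_k x d_k), v is (n_k x d_v).
/// Returns an (n_q x d_v) matrix of attention-weighted value rows.
fn scaled_dot_product_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    assert_eq!(k.len(), v.len(), "keys and values must have the same count");
    let d_k = q.first().map_or(1, |row| row.len()).max(1) as f32;
    let scale = 1.0 / d_k.sqrt();
    let d_v = v.first().map_or(0, |row| row.len());
    q.iter()
        .map(|qi| {
            // Scaled dot product of this query with every key.
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            let weights = row_softmax(&scores);
            // Weighted sum of value rows.
            let mut out = vec![0.0f32; d_v];
            for (w, vj) in weights.iter().zip(v) {
                for (o, &x) in out.iter_mut().zip(vj) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}
```

Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax toward a near one-hot distribution with very small gradients.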
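save_weights_to_buffer and load_weights_from_buffer round-trip parameters through a flat `f32` buffer, with loading partitioned by layer sizes. The sketch below uses hypothetical `Network` and `Layer` types (weights followed by biases, layer by layer) and a `Result` return on size mismatch; the crate's real network type, parameter ordering and error handling may differ.

```rust
/// Hypothetical layer: flattened weight matrix plus bias vector.
#[derive(Debug, Clone, PartialEq)]
struct Layer {
    weights: Vec<f32>,
    biases: Vec<f32>,
}

/// Hypothetical network: an ordered list of layers.
#[derive(Debug, Clone, PartialEq)]
struct Network {
    layers: Vec<Layer>,
}

/// Serialize all parameters as [w0.., b0.., w1.., b1.., ...].
fn save_weights_to_buffer(net: &Network) -> Vec<f32> {
    let mut buf = Vec::new();
    for layer in &net.layers {
        buf.extend_from_slice(&layer.weights);
        buf.extend_from_slice(&layer.biases);
    }
    buf
}

/// Copy a flat buffer back into the network, using each layer's existing
/// sizes to decide where one parameter block ends and the next begins.
fn load_weights_from_buffer(net: &mut Network, buf: &[f32]) -> Result<(), String> {
    let expected: usize = net
        .layers
        .iter()
        .map(|l| l.weights.len() + l.biases.len())
        .sum();
    if buf.len() != expected {
        return Err(format!("buffer has {} values, network expects {}", buf.len(), expected));
    }
    let mut offset = 0;
    for layer in &mut net.layers {
        let n = layer.weights.len();
        layer.weights.copy_from_slice(&buf[offset..offset + n]);
        offset += n;
        let m = layer.biases.len();
        layer.biases.copy_from_slice(&buf[offset..offset + m]);
        offset += m;
    }
    Ok(())
}
```

Checking the total length before copying means a mismatched buffer is rejected without partially overwriting the network.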