Auto-generated module
🤖 Generated with SplitRS
Functions
- batch_atomic_energies - Predict atomic energies for a batch of (descriptor, atomic_number) pairs.
- batch_forward - Run a forward pass for every sample in batch.
- clip_gradients_by_norm - Clip a collection of gradient slices in-place so that their combined L2 norm does not exceed max_norm.
- compute_forces_batch - Run the network on every position and return a force vector per atom.
- compute_gradient_norm - Compute the L2 norm (Frobenius norm) of a concatenated gradient vector.
- cross_entropy_loss - Compute cross-entropy loss between predictions and one-hot target.
- huber_loss - Huber loss between prediction and target.
- huber_loss_grad - Gradient of Huber loss w.r.t. predictions.
- l2_regularisation - Compute L2 regularisation loss contribution.
- l2_regularisation_grad - Compute L2 regularisation gradient contribution (adds lambda * w to each element).
- load_weights_from_buffer - Load weights from a flat f32 buffer into a network, partitioning by layer sizes.
- mean_huber_loss - Mean Huber loss over a batch.
- mse_loss - Mean squared error loss between predictions and targets.
- neural_potential_energy - Compute a scalar potential energy by summing the first network output over all atoms.
- save_weights_to_buffer - Serialize network weights to a flat f32 buffer.
- scaled_dot_product_attention - Compute scaled dot-product attention.
- softmax - Compute softmax of a vector.
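
The listing above gives only one-line summaries, so the following is a minimal sketch of three of the described operations: softmax, Huber loss, and gradient clipping by combined L2 norm. It does not reproduce this module's code. The `_sketch` function names, the argument types (`&[f32]`, `Vec<f32>`), the `delta` threshold parameter, and the choice to return the pre-clipping norm are all assumptions made for illustration; the real signatures may differ.

```rust
/// Softmax of a vector (hypothetical signature).
/// Subtracting the maximum first keeps `exp` from overflowing.
fn softmax_sketch(x: &[f32]) -> Vec<f32> {
    if x.is_empty() {
        return Vec::new();
    }
    let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Huber loss for one prediction/target pair (hypothetical signature).
/// Quadratic while the residual is at most `delta`, linear beyond it.
fn huber_loss_sketch(pred: f32, target: f32, delta: f32) -> f32 {
    let r = (pred - target).abs();
    if r <= delta {
        0.5 * r * r
    } else {
        delta * (r - 0.5 * delta)
    }
}

/// L2 norm of several gradient buffers treated as one concatenated vector.
fn gradient_norm_sketch(grads: &[Vec<f32>]) -> f32 {
    grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|v| v * v)
        .sum::<f32>()
        .sqrt()
}

/// Rescale all gradients in place so their combined L2 norm is at most
/// `max_norm`. Returns the norm measured before clipping (an assumption;
/// the listed function may return nothing).
fn clip_by_norm_sketch(grads: &mut [Vec<f32>], max_norm: f32) -> f32 {
    let norm = gradient_norm_sketch(grads);
    if norm > max_norm && norm > 0.0 {
        let scale = max_norm / norm;
        for g in grads.iter_mut() {
            for v in g.iter_mut() {
                *v *= scale;
            }
        }
    }
    norm
}

fn main() {
    let probs = softmax_sketch(&[1.0, 2.0, 3.0]);
    println!("softmax: {:?}", probs);

    // Residual 3.0 exceeds delta = 1.0, so the linear branch applies: 2.5.
    println!("huber: {}", huber_loss_sketch(3.0, 0.0, 1.0));

    // Combined norm is sqrt(3^2 + 4^2) = 5 before clipping.
    let mut grads = vec![vec![3.0, 0.0], vec![0.0, 4.0]];
    let before = clip_by_norm_sketch(&mut grads, 1.0);
    println!("norm before: {}, after: {}", before, gradient_norm_sketch(&grads));
}
```

Clipping uses a single scale factor for every buffer, which preserves the direction of the overall gradient and shrinks only its magnitude.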