Function train_classifier 

pub fn train_classifier(
    samples: &[TrainingSample],
    config: &Config,
    epochs: usize,
    lr: f64,
    device: &Device,
) -> Result<(NodeClassifier, f32)>

Trains a `NodeClassifier` on the labelled `samples`, running `epochs` full-batch steps. Each step uses every sample at once.
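To illustrate what "`epochs` full-batch steps" means, here is a sketch that trains a plain logistic model in dependency-free Rust instead of the real `NodeClassifier`. The `Sample` type and the `train_full_batch` function are hypothetical stand-ins. Plain gradient descent is an assumption; this page does not say which optimizer `train_classifier` uses.

```rust
/// Hypothetical labelled example: a feature vector and a 0/1 target.
/// The real `TrainingSample` type is not shown on this page.
struct Sample {
    x: Vec<f64>,
    y: f64,
}

fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical helper: runs `epochs` full-batch gradient-descent steps on a
/// logistic model and returns (weights, bias). Every step averages the
/// gradient over *all* samples, then updates the parameters once.
fn train_full_batch(samples: &[Sample], epochs: usize, lr: f64) -> (Vec<f64>, f64) {
    let dim = samples.first().map_or(0, |s| s.x.len());
    let mut w = vec![0.0; dim];
    let mut b = 0.0;
    if samples.is_empty() {
        return (w, b); // nothing to learn from
    }
    let n = samples.len() as f64;

    for _ in 0..epochs {
        let mut grad_w = vec![0.0; dim];
        let mut grad_b = 0.0;
        for s in samples {
            let z: f64 = w.iter().zip(&s.x).map(|(wi, xi)| wi * xi).sum::<f64>() + b;
            // For a sigmoid output, d(BCE)/dz simplifies to (p - y).
            let err = sigmoid(z) - s.y;
            for (g, xi) in grad_w.iter_mut().zip(&s.x) {
                *g += err * xi;
            }
            grad_b += err;
        }
        // Exactly one parameter update per epoch: this is the "full-batch" step.
        for (wi, g) in w.iter_mut().zip(&grad_w) {
            *wi -= lr * g / n;
        }
        b -= lr * grad_b / n;
    }
    (w, b)
}
```

With a mini-batch or stochastic scheme, one epoch would instead perform several updates. Here the number of updates equals `epochs`.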

Returns the trained classifier together with the final binary cross-entropy (BCE) loss. Returns an error if none of the samples carries usable ground truth.
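The following sketch shows the loss and the documented error case. It assumes "usable ground truth" means a sample whose label is `Some(_)`. `Scored` and `mean_bce` are hypothetical names. The real function's criterion for usable ground truth and its error type are not documented here.

```rust
/// Hypothetical sample: the model's predicted probability plus an optional
/// ground-truth label (`None` is assumed to mean "no usable ground truth").
struct Scored {
    prob: f32,
    label: Option<f32>,
}

/// Mean binary cross-entropy, -(y ln p + (1 - y) ln(1 - p)), over the samples
/// that carry a label. Errors when no sample is labelled, mirroring the
/// failure mode described above.
fn mean_bce(samples: &[Scored]) -> Result<f32, String> {
    const EPS: f32 = 1e-7; // keeps ln() away from ln(0)
    let labelled: Vec<(f32, f32)> = samples
        .iter()
        .filter_map(|s| s.label.map(|y| (s.prob.clamp(EPS, 1.0 - EPS), y)))
        .collect();
    if labelled.is_empty() {
        return Err("no sample carried usable ground truth".to_string());
    }
    let total: f32 = labelled
        .iter()
        .map(|&(p, y)| -(y * p.ln() + (1.0 - y) * (1.0 - p).ln()))
        .sum();
    Ok(total / labelled.len() as f32)
}
```

A prediction of 0.5 for any labelled sample contributes ln 2 to the mean, and unlabelled samples are simply skipped.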