pub struct ResponseDistiller {
    pub soft_loss: DistilLoss,
    pub hard_label_weight: f32,
    pub soft_label_weight: f32,
}
Response-based knowledge distillation.
Combines a hard-label cross-entropy loss with a soft-label distillation loss.
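Written out, the combined objective looks like the formula below. Here α = hard_label_weight, β = soft_label_weight, z_s and z_t are the student and teacher logits, y is the hard label, and n = n_classes. The expansion of the soft term assumes soft_loss is KL divergence at temperature τ with the common τ² rescaling. The crate's DistilLoss may define it differently.

```latex
\mathcal{L} = \alpha \, \mathrm{CE}\bigl(\operatorname{softmax}(z_s),\, y\bigr) + \beta \, \mathcal{L}_{\mathrm{soft}}(z_s, z_t),
\qquad \mathrm{CE}\bigl(\operatorname{softmax}(z_s),\, y\bigr) = -\log \operatorname{softmax}(z_s)_y

\mathcal{L}_{\mathrm{soft}}(z_s, z_t) = \tau^2 \sum_{i=1}^{n} p_i \log \frac{p_i}{q_i},
\qquad p = \operatorname{softmax}(z_t / \tau), \quad q = \operatorname{softmax}(z_s / \tau)
```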
Fields
soft_loss: DistilLoss
Distillation loss applied to soft targets.
hard_label_weight: f32
Weight for the hard-label (cross-entropy) term.
soft_label_weight: f32
Weight for the soft-label (distillation) term.
Implementations
impl ResponseDistiller
pub fn new(
    soft_loss: DistilLoss,
    hard_label_weight: f32,
    soft_label_weight: f32,
) -> Self
Create a response distiller.
Parameters
- soft_loss: distillation loss (e.g., KL divergence with temperature).
- hard_label_weight: weight α for the cross-entropy term.
- soft_label_weight: weight β for the distillation term.
pub fn pure_kl(temperature: f32) -> Self
Pure distillation (no hard labels): KL divergence at temperature τ, given by the temperature argument.
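The sketch below shows what a KL divergence at temperature τ computes. It assumes the direction KL(teacher ‖ student) and the conventional τ² rescaling, and this page confirms neither. The helpers softmax_with_temperature and kl_at_temperature are hypothetical, not part of the crate's API.

```rust
/// Numerically stable softmax of `logits / temperature`.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    // Subtract the largest scaled logit so exp() cannot overflow.
    let max = logits
        .iter()
        .map(|&z| z / temperature)
        .fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits
        .iter()
        .map(|&z| (z / temperature - max).exp())
        .collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical helper: KL(p_teacher || q_student), both softened at `temperature`.
/// The tau^2 factor is an assumed convention; the crate may not apply it.
fn kl_at_temperature(student_logits: &[f32], teacher_logits: &[f32], temperature: f32) -> f32 {
    let p = softmax_with_temperature(teacher_logits, temperature);
    let q = softmax_with_temperature(student_logits, temperature);
    let kl: f32 = p
        .iter()
        .zip(&q)
        .filter(|&(&pi, _)| pi > 0.0) // treat 0 * log(0 / q) as 0
        .map(|(&pi, &qi)| pi * (pi / qi).ln())
        .sum();
    kl * temperature * temperature
}
```

Raising τ flattens both distributions, which exposes how the teacher ranks the incorrect classes. The τ² factor is conventionally applied to keep the soft term's gradient scale comparable across temperatures.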
pub fn compute_loss(
    &self,
    student_logits: &[f32],
    teacher_logits: &[f32],
    hard_label: usize,
) -> QuantResult<f32>
Compute the combined distillation loss.
Parameters
- student_logits: unnormalised student output (length = n_classes).
- teacher_logits: unnormalised teacher output (same length).
- hard_label: integer ground-truth class index.
Errors
- QuantError::EmptyInput: either logit slice is empty.
- QuantError::TeacherStudentMismatch: the logit slices differ in length.
- QuantError::DimensionMismatch: hard_label ≥ n_classes.
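The validation and weighting described above can be sketched as follows. SketchDistiller, log_softmax, and the unit-variant QuantError are hypothetical local stand-ins. The real soft_loss is a DistilLoss, and here it is fixed to KL at temperature τ with τ² scaling. The order of the three checks is also an assumption.

```rust
/// Local stand-in for the crate's error type; the real variants may carry data.
#[derive(Debug, PartialEq)]
enum QuantError {
    EmptyInput,
    TeacherStudentMismatch,
    DimensionMismatch,
}

type QuantResult<T> = Result<T, QuantError>;

/// Hypothetical mirror of ResponseDistiller with the soft loss fixed to KL at `temperature`.
struct SketchDistiller {
    temperature: f32,
    hard_label_weight: f32,
    soft_label_weight: f32,
}

/// log(softmax(logits / temperature)) via a max-shifted log-sum-exp.
fn log_softmax(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|&z| z / temperature).collect();
    let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let lse = max + scaled.iter().map(|&s| (s - max).exp()).sum::<f32>().ln();
    scaled.iter().map(|&s| s - lse).collect()
}

impl SketchDistiller {
    fn compute_loss(
        &self,
        student_logits: &[f32],
        teacher_logits: &[f32],
        hard_label: usize,
    ) -> QuantResult<f32> {
        // Checks mirror the documented error list (order assumed).
        if student_logits.is_empty() || teacher_logits.is_empty() {
            return Err(QuantError::EmptyInput);
        }
        if student_logits.len() != teacher_logits.len() {
            return Err(QuantError::TeacherStudentMismatch);
        }
        if hard_label >= student_logits.len() {
            return Err(QuantError::DimensionMismatch);
        }
        // Hard term: cross-entropy of the student (temperature 1) against the true class.
        let hard = -log_softmax(student_logits, 1.0)[hard_label];
        // Soft term: KL(teacher || student) at temperature tau, scaled by tau^2 (assumed).
        let t = self.temperature;
        let log_q = log_softmax(student_logits, t);
        let log_p = log_softmax(teacher_logits, t);
        let kl: f32 = log_p
            .iter()
            .zip(&log_q)
            .map(|(&lp, &lq)| lp.exp() * (lp - lq))
            .sum();
        Ok(self.hard_label_weight * hard + self.soft_label_weight * kl * t * t)
    }
}
```

Working in log space avoids taking the logarithm of probabilities that have underflowed to zero. When the student and teacher logits are identical, the KL term vanishes and only α times the cross-entropy remains.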
pub fn compute_batch_loss(
    &self,
    student_batch: &[f32],
    teacher_batch: &[f32],
    hard_labels: &[usize],
    n_classes: usize,
) -> QuantResult<f32>
Compute distillation loss over a batch of examples.
Returns the average loss over all examples in the batch.
Parameters
- student_batch: [batch_size, n_classes] row-major student logits.
- teacher_batch: [batch_size, n_classes] row-major teacher logits.
- hard_labels: [batch_size] integer class labels.
- n_classes: number of output classes.
Errors
Propagates dimension and empty-input errors.
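The sketch below shows the row-major batching and averaging. It assumes each row is scored independently and the per-row losses are averaged. batch_mean_loss and BatchError are hypothetical. The per_example closure stands in for a per-row call such as compute_loss. The buffer-length check is an assumed form of the documented dimension errors.

```rust
/// Hypothetical error type for the sketch.
#[derive(Debug, PartialEq)]
enum BatchError {
    EmptyInput,
    DimensionMismatch,
}

/// Averages `per_example` over the rows of two [batch_size, n_classes] row-major buffers.
fn batch_mean_loss<F>(
    student_batch: &[f32],
    teacher_batch: &[f32],
    hard_labels: &[usize],
    n_classes: usize,
    mut per_example: F,
) -> Result<f32, BatchError>
where
    F: FnMut(&[f32], &[f32], usize) -> Result<f32, BatchError>,
{
    // Guard n_classes == 0 so chunks_exact below cannot panic.
    if n_classes == 0 || hard_labels.is_empty() {
        return Err(BatchError::EmptyInput);
    }
    let batch_size = hard_labels.len();
    // Each buffer must hold exactly batch_size rows of n_classes logits.
    let expected = batch_size * n_classes;
    if student_batch.len() != expected || teacher_batch.len() != expected {
        return Err(BatchError::DimensionMismatch);
    }
    let mut total = 0.0f32;
    // Row i occupies indices [i * n_classes, (i + 1) * n_classes).
    for ((s_row, t_row), &label) in student_batch
        .chunks_exact(n_classes)
        .zip(teacher_batch.chunks_exact(n_classes))
        .zip(hard_labels)
    {
        // Any per-row error is propagated unchanged.
        total += per_example(s_row, t_row, label)?;
    }
    Ok(total / batch_size as f32)
}
```

Passing the per-example loss as a closure keeps the row slicing separate from the loss itself. After the length check passes, chunks_exact yields exactly batch_size rows.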
Trait Implementations
impl Clone for ResponseDistiller
fn clone(&self) -> ResponseDistiller
Returns a duplicate of the value.
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
Auto Trait Implementations
impl Freeze for ResponseDistiller
impl RefUnwindSafe for ResponseDistiller
impl Send for ResponseDistiller
impl Sync for ResponseDistiller
impl Unpin for ResponseDistiller
impl UnsafeUnpin for ResponseDistiller
impl UnwindSafe for ResponseDistiller
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.