Skip to main content

entrenar/hf_pipeline/loader/
teacher.rs

1//! Teacher model trait for distillation
2
3use crate::hf_pipeline::error::Result;
4use ndarray::Array2;
5
6use super::MemoryEstimate;
7
8/// Teacher model trait for distillation
9///
10/// Provides interface for frozen teacher models used in knowledge distillation.
11pub trait TeacherModel: Send + Sync {
12    /// Run forward pass, returning output logits
13    ///
14    /// # Arguments
15    ///
16    /// * `input` - Input tensor [batch_size, seq_len, hidden_size]
17    ///
18    /// # Returns
19    ///
20    /// Output logits [batch_size, seq_len, vocab_size]
21    fn forward(&self, input: &Array2<f32>) -> Result<Array2<f32>>;
22
23    /// Get intermediate hidden states for progressive distillation
24    ///
25    /// # Arguments
26    ///
27    /// * `input` - Input tensor
28    ///
29    /// # Returns
30    ///
31    /// Hidden states for each layer
32    fn hidden_states(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>>;
33
34    /// Get attention weights for attention transfer
35    ///
36    /// # Arguments
37    ///
38    /// * `input` - Input tensor
39    ///
40    /// # Returns
41    ///
42    /// Attention weights [batch, heads, seq, seq] for each layer
43    fn attention_weights(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>>;
44
45    /// Estimate memory requirements
46    fn estimate_memory(&self, batch_size: usize, seq_len: usize) -> MemoryEstimate;
47
48    /// Get number of parameters
49    fn param_count(&self) -> u64;
50
51    /// Get number of layers
52    fn num_layers(&self) -> usize;
53
54    /// Get hidden size
55    fn hidden_size(&self) -> usize;
56}