// entrenar/hf_pipeline/loader/teacher.rs
//! Teacher model trait for knowledge distillation.
2
3use crate::hf_pipeline::error::Result;
4use ndarray::Array2;
5
6use super::MemoryEstimate;
7
8/// Teacher model trait for distillation
9///
10/// Provides interface for frozen teacher models used in knowledge distillation.
11pub trait TeacherModel: Send + Sync {
12 /// Run forward pass, returning output logits
13 ///
14 /// # Arguments
15 ///
16 /// * `input` - Input tensor [batch_size, seq_len, hidden_size]
17 ///
18 /// # Returns
19 ///
20 /// Output logits [batch_size, seq_len, vocab_size]
21 fn forward(&self, input: &Array2<f32>) -> Result<Array2<f32>>;
22
23 /// Get intermediate hidden states for progressive distillation
24 ///
25 /// # Arguments
26 ///
27 /// * `input` - Input tensor
28 ///
29 /// # Returns
30 ///
31 /// Hidden states for each layer
32 fn hidden_states(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>>;
33
34 /// Get attention weights for attention transfer
35 ///
36 /// # Arguments
37 ///
38 /// * `input` - Input tensor
39 ///
40 /// # Returns
41 ///
42 /// Attention weights [batch, heads, seq, seq] for each layer
43 fn attention_weights(&self, input: &Array2<f32>) -> Result<Vec<Array2<f32>>>;
44
45 /// Estimate memory requirements
46 fn estimate_memory(&self, batch_size: usize, seq_len: usize) -> MemoryEstimate;
47
48 /// Get number of parameters
49 fn param_count(&self) -> u64;
50
51 /// Get number of layers
52 fn num_layers(&self) -> usize;
53
54 /// Get hidden size
55 fn hidden_size(&self) -> usize;
56}