use crate::autograd::Tensor;
use crate::nn::Linear;
use crate::nn::Module;
use std::collections::HashMap;
/// Extension of [`Module`] for encoders whose base parameters can be
/// frozen during transfer learning (e.g. fine-tuning only task heads).
pub trait TransferEncoder: Module {
/// Freeze the base encoder so its parameters are excluded from training.
fn freeze_base(&mut self);
/// Unfreeze the base encoder, making its parameters trainable again.
fn unfreeze_base(&mut self);
/// Whether the base encoder is currently frozen.
fn is_frozen(&self) -> bool;
/// Run the encoder on `x` and return its output as a feature tensor.
fn get_features(&self, x: &Tensor) -> Tensor;
}
/// Wrapper that adds freeze/unfreeze support to any [`Module`].
#[derive(Debug)]
pub struct TransferableEncoder<M: Module> {
// The wrapped encoder module.
encoder: M,
// When true, `parameters()` / `parameters_mut()` return empty vectors,
// hiding the encoder's parameters from any optimizer built from them.
frozen: bool,
}
impl<M: Module> TransferableEncoder<M> {
    /// Wrap `encoder`, starting in the unfrozen (trainable) state.
    pub fn new(encoder: M) -> Self {
        Self { encoder, frozen: false }
    }

    /// Borrow the wrapped encoder.
    pub fn encoder(&self) -> &M {
        &self.encoder
    }

    /// Mutably borrow the wrapped encoder.
    pub fn encoder_mut(&mut self) -> &mut M {
        &mut self.encoder
    }
}
impl<M: Module> Module for TransferableEncoder<M> {
    /// Delegate the forward pass to the wrapped encoder.
    fn forward(&self, input: &Tensor) -> Tensor {
        self.encoder.forward(input)
    }

    /// Trainable parameters; empty while frozen, so an optimizer built
    /// from this list skips the base weights entirely.
    fn parameters(&self) -> Vec<&Tensor> {
        match self.frozen {
            true => Vec::new(),
            false => self.encoder.parameters(),
        }
    }

    /// Mutable view of trainable parameters; empty while frozen.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        if self.frozen {
            return Vec::new();
        }
        self.encoder.parameters_mut()
    }

    // Mode switching always reaches the inner encoder, frozen or not:
    // freezing controls gradients, not train/eval behavior.
    fn train(&mut self) {
        self.encoder.train();
    }

    fn eval(&mut self) {
        self.encoder.eval();
    }

    fn training(&self) -> bool {
        self.encoder.training()
    }
}
impl<M: Module> TransferEncoder for TransferableEncoder<M> {
    fn freeze_base(&mut self) {
        self.frozen = true;
    }

    fn unfreeze_base(&mut self) {
        self.frozen = false;
    }

    fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Features are simply the wrapped encoder's forward output.
    fn get_features(&self, x: &Tensor) -> Tensor {
        self.encoder.forward(x)
    }
}
/// Multi-task model: one shared encoder feeding per-task [`Linear`] heads.
#[derive(Debug)]
pub struct MultiTaskHead<E: TransferEncoder> {
// Encoder shared across all tasks.
shared_encoder: E,
// Per-task output layers, keyed by task name.
task_heads: HashMap<String, Linear>,
// Input width of every task head; heads are built as
// `Linear::new(feature_dim, output_dim)` and fed the encoder's features.
feature_dim: usize,
}
impl<E: TransferEncoder> MultiTaskHead<E> {
    /// Create a multi-task model around `shared_encoder` with no task
    /// heads registered yet. `feature_dim` is the input width of every
    /// head added later.
    pub fn new(shared_encoder: E, feature_dim: usize) -> Self {
        Self {
            shared_encoder,
            task_heads: HashMap::new(),
            feature_dim,
        }
    }

    /// Register (or replace) a head for `task_name` mapping
    /// `feature_dim -> output_dim`.
    pub fn add_task(&mut self, task_name: &str, output_dim: usize) {
        let head = Linear::new(self.feature_dim, output_dim);
        self.task_heads.insert(task_name.to_string(), head);
    }

    /// Remove and return the head for `task_name`, if one was registered.
    pub fn remove_task(&mut self, task_name: &str) -> Option<Linear> {
        self.task_heads.remove(task_name)
    }

    /// Names of all registered tasks (unordered: `HashMap` iteration).
    pub fn task_names(&self) -> Vec<&String> {
        let mut names = Vec::with_capacity(self.task_heads.len());
        names.extend(self.task_heads.keys());
        names
    }

    /// Encode `input` with the shared encoder.
    pub fn forward_shared(&self, input: &Tensor) -> Tensor {
        self.shared_encoder.get_features(input)
    }

    /// Apply the head for `task_name` to precomputed `features`.
    ///
    /// # Panics
    /// Panics if `task_name` was never registered via [`Self::add_task`].
    pub fn forward_task(&self, task_name: &str, features: &Tensor) -> Tensor {
        match self.task_heads.get(task_name) {
            Some(head) => head.forward(features),
            None => panic!("Unknown task: {task_name}"),
        }
    }

    /// Full pass: shared encoding followed by the task-specific head.
    pub fn forward_full(&self, task_name: &str, input: &Tensor) -> Tensor {
        let shared = self.forward_shared(input);
        self.forward_task(task_name, &shared)
    }

    /// Borrow the shared encoder.
    pub fn encoder(&self) -> &E {
        &self.shared_encoder
    }

    /// Mutably borrow the shared encoder.
    pub fn encoder_mut(&mut self) -> &mut E {
        &mut self.shared_encoder
    }

    /// Freeze the shared encoder's base parameters.
    pub fn freeze_encoder(&mut self) {
        self.shared_encoder.freeze_base();
    }

    /// Unfreeze the shared encoder's base parameters.
    pub fn unfreeze_encoder(&mut self) {
        self.shared_encoder.unfreeze_base();
    }
}
impl<E: TransferEncoder> Module for MultiTaskHead<E> {
    /// Not supported directly: a task name is required to select a head.
    ///
    /// # Panics
    /// Always panics; use `forward_full(task_name, input)` instead.
    fn forward(&self, _input: &Tensor) -> Tensor {
        panic!("MultiTaskHead requires task name. Use forward_full(task_name, input) instead.");
    }

    /// Parameters of the shared encoder (empty when frozen) followed by
    /// every task head's parameters. Head order follows `HashMap`
    /// iteration and is therefore unspecified.
    fn parameters(&self) -> Vec<&Tensor> {
        let mut params = self.shared_encoder.parameters();
        for head in self.task_heads.values() {
            params.extend(head.parameters());
        }
        params
    }

    /// Mutable counterpart of [`Self::parameters`].
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut params = self.shared_encoder.parameters_mut();
        for head in self.task_heads.values_mut() {
            params.extend(head.parameters_mut());
        }
        params
    }

    // Fix: train/eval/training were not overridden here, so switching a
    // MultiTaskHead between modes never reached the shared encoder or the
    // task heads — inconsistent with TransferableEncoder's Module impl,
    // which forwards all three to its submodule.
    fn train(&mut self) {
        self.shared_encoder.train();
        for head in self.task_heads.values_mut() {
            head.train();
        }
    }

    fn eval(&mut self) {
        self.shared_encoder.eval();
        for head in self.task_heads.values_mut() {
            head.eval();
        }
    }

    /// Report the shared encoder's mode as the model's mode.
    fn training(&self) -> bool {
        self.shared_encoder.training()
    }
}
/// Domain-adversarial adaptation components: an encoder plus a binary
/// domain discriminator (DANN-style setup).
#[derive(Debug)]
pub struct DomainAdapter<E: TransferEncoder> {
// Feature extractor shared between the task and the discriminator.
encoder: E,
// Single-output linear layer scoring which domain features came from.
discriminator: Linear,
// Scale for the gradient-reversal step. NOTE(review): only stored/read
// here — the reversal itself is not applied in this file; confirm the
// training loop uses it.
reversal_scale: f32,
}
impl<E: TransferEncoder> DomainAdapter<E> {
    /// Build an adapter around `encoder` with a fresh `feature_dim -> 1`
    /// domain discriminator and the given gradient-reversal scale.
    pub fn new(encoder: E, feature_dim: usize, reversal_scale: f32) -> Self {
        let discriminator = Linear::new(feature_dim, 1);
        Self {
            encoder,
            discriminator,
            reversal_scale,
        }
    }

    /// Encode `input` via the wrapped encoder's feature extractor.
    pub fn encode(&self, input: &Tensor) -> Tensor {
        self.encoder.get_features(input)
    }

    /// Score `features` with the domain discriminator (single logit).
    pub fn discriminate(&self, features: &Tensor) -> Tensor {
        self.discriminator.forward(features)
    }

    /// Current gradient-reversal scale.
    pub fn reversal_scale(&self) -> f32 {
        self.reversal_scale
    }

    /// Update the gradient-reversal scale (commonly annealed over training).
    pub fn set_reversal_scale(&mut self, scale: f32) {
        self.reversal_scale = scale;
    }

    /// Borrow the wrapped encoder.
    pub fn encoder(&self) -> &E {
        &self.encoder
    }

    /// Mutably borrow the wrapped encoder.
    pub fn encoder_mut(&mut self) -> &mut E {
        &mut self.encoder
    }
}
// Submodules, re-exported flat at this module's root.
mod lora;
pub use lora::*;
mod distillation;
pub use distillation::*;