use std::sync::Mutex;
use std::{marker::PhantomData, path::PathBuf};
use log::debug;
use measure_time_macro::log_time;
use tch::{
nn::{self, OptimizerConfig},
vision::dataset::Dataset,
Device, Tensor,
};
use crate::{
clustering::{self},
errors::{DliError, DliResult},
model::{ModelDevice, ModelLayer, RetrainStrategy, TchBackend, TrainParams},
sampling,
structs::{FloatElement, LabelMethod},
types::ArraySlice,
ModelConfig,
};
/// Translates the backend-agnostic [`ModelDevice`] into the matching `tch` [`Device`].
///
/// `ModelDevice::Gpu(n)` becomes `Device::Cuda(n)` and `ModelDevice::Cpu` becomes
/// `Device::Cpu`.
fn to_tch_device(device: ModelDevice) -> Device {
    match device {
        ModelDevice::Gpu(ordinal) => Device::Cuda(ordinal),
        ModelDevice::Cpu => Device::Cpu,
    }
}
impl<F: FloatElement> crate::model::BaseModelBuilder<TchBackend, F> {
/// Validates the builder's settings and assembles a `tch`-backed [`Model`].
///
/// # Errors
///
/// * [`DliError::ModelCreation`] if quantization was requested (this backend does not
///   support it) or if `labels` / `input_nodes` is not strictly positive.
/// * [`DliError::MissingAttribute`] if `device`, `label_method`, `input_nodes` or
///   `labels` has not been set on the builder.
/// * Whatever error `build_varstore_and_model` reports, e.g. when the weights file
///   cannot be loaded.
pub fn build(&self) -> DliResult<Model<F>> {
    if self.quantize {
        return Err(DliError::ModelCreation(
            "Quantization is not supported for TchBackend",
        ));
    }

    let device = to_tch_device(self.device.ok_or(DliError::MissingAttribute("device"))?);
    let label_method = self
        .label_method
        .ok_or(DliError::MissingAttribute("label_method"))?;
    let input_nodes = self
        .input_nodes
        .ok_or(DliError::MissingAttribute("input_nodes"))?;
    let labels = self.labels.ok_or(DliError::MissingAttribute("labels"))?;

    // Report invalid dimensions through the error channel instead of panicking:
    // this function already returns a `DliResult`.
    if labels == 0 {
        return Err(DliError::ModelCreation("labels must be greater than 0"));
    }
    // A checked conversion avoids the silent wrap-around of `input_nodes as usize`
    // for negative values; zero is rejected because the model later divides by it.
    let input_shape = usize::try_from(input_nodes)
        .ok()
        .filter(|&nodes| nodes > 0)
        .ok_or(DliError::ModelCreation("input_nodes must be greater than 0"))?;

    let train_params = self.train_params.unwrap_or_default();
    let (vs, network) = Self::build_varstore_and_model(
        device,
        input_nodes,
        labels,
        &self.layers,
        &self.weights_path,
    )?;

    Ok(Model {
        model: Mutex::new(Box::new(network)),
        vs,
        labels,
        device,
        train_params,
        input_shape,
        label_method,
        layers: self.layers.clone(),
        seed: self.seed,
        _marker: PhantomData,
    })
}
/// Creates a fresh variable store on `device` and the sequential network that lives in it.
///
/// The network consists of the requested `layers` followed by one extra linear layer
/// that projects onto `labels` outputs. `ReLU` entries keep the current width, while
/// `Linear(n)` entries change it to `n`.
///
/// Linear layers are registered under the variable-store paths `"0"`, `"2"`, `"4"`, …
/// (twice the running count of linear layers).
/// NOTE(review): presumably this mirrors the indices of an alternating Linear/ReLU
/// stack in previously exported weight files — confirm before changing the scheme.
///
/// When `weights_path` is `Some`, the stored weights are loaded into the new variable
/// store, and a load failure is returned as an error.
fn build_varstore_and_model(
    device: Device,
    input_nodes: i64,
    labels: usize,
    layers: &[ModelLayer],
    weights_path: &Option<PathBuf>,
) -> DliResult<(nn::VarStore, nn::Sequential)> {
    let mut vs = nn::VarStore::new(device);
    let root = vs.root();

    let mut network = nn::seq();
    // Width of the activations flowing into the next layer.
    let mut width = input_nodes;
    // Number of linear layers added so far; drives the variable naming.
    let mut linear_count = 0;

    for layer in layers {
        network = match layer {
            ModelLayer::Linear(nodes) => {
                let out_width = *nodes as i64;
                let linear = nn::linear(
                    &root / format!("{i}", i = 2 * linear_count),
                    width,
                    out_width,
                    Default::default(),
                );
                linear_count += 1;
                width = out_width;
                network.add(linear)
            }
            ModelLayer::ReLU => network.add_fn(|xs| xs.relu()),
        };
    }

    // Output head: one unit per label.
    let head = nn::linear(
        &root / format!("{i}", i = 2 * linear_count),
        width,
        labels as i64,
        Default::default(),
    );
    let network = network.add(head);

    if let Some(path) = weights_path {
        vs.load(path)?;
    }
    Ok((vs, network))
}
}
/// A `tch`-backed classifier: a stack of linear/ReLU layers whose final linear layer
/// has `labels` outputs, turned into per-label scores by a softmax at prediction time.
#[derive(Debug)]
pub struct Model<F: FloatElement> {
/// The network itself; it is locked for every forward pass and replaced wholesale
/// when the model is reset.
model: Mutex<Box<dyn nn::Module>>,
/// Variable store owning the network's weights; it is what gets optimised during
/// training and written to disk by `dump`.
vs: nn::VarStore,
/// Number of output labels (width of the final linear layer).
labels: usize,
/// Device the variable store was created on; inputs are moved there before use.
device: Device,
/// Number of values that make up one input vector.
pub input_shape: usize,
/// Training hyper-parameters (epochs, batch size, sampling threshold, retrain strategy, …).
pub train_params: TrainParams,
/// Method handed to `clustering::compute_labels` to derive training targets.
label_method: LabelMethod,
/// Hidden-layer specification, kept so the network can be rebuilt and exported.
layers: Vec<ModelLayer>,
/// Seed passed to `sampling::sample` when drawing the training subset.
seed: u64,
/// Ties the model to the element type `F` without storing a value of it.
_marker: PhantomData<F>,
}
impl<F: FloatElement> crate::model::ModelInterface<F> for Model<F> {
type TensorType = Tensor;
/// Scores a single input tensor against every label.
///
/// The input is moved to the model's device when that device is not the CPU, pushed
/// through the network, and normalised with a softmax over the last dimension. The
/// result pairs each output position with its softmax value.
///
/// # Panics
///
/// Panics if the model mutex is poisoned, if the network output cannot be converted
/// into a `Vec<f32>`, or if more scores than `labels` are produced.
fn predict(&self, xs: &Tensor) -> DliResult<Vec<(usize, f32)>> {
    let input = if matches!(self.device, Device::Cpu) {
        xs.shallow_clone()
    } else {
        xs.to_device(self.device)
    };

    let probabilities = {
        let network = self.model.lock().unwrap();
        network.forward(&input).softmax(-1, tch::Kind::Float)
    };

    let scored: Vec<(usize, f32)> = tensor2vec(&probabilities)
        .into_iter()
        .enumerate()
        .collect();
    assert!(scored.len() <= self.labels);
    Ok(scored)
}
#[log_time]
fn predict_many(&self, xs: &[F]) -> DliResult<Vec<Vec<(usize, f32)>>> {
let xs_tensor = Tensor::from_slice(xs);
let xs_tensor = xs_tensor.view((
(xs.len() / self.input_shape) as i64,
self.input_shape as i64,
));
let res = self
.model
.lock()
.unwrap()
.forward(&xs_tensor)
.softmax(-1, tch::Kind::Float);
let predictions = tensor2vec(&res);
Ok(predictions
.chunks_exact(self.labels)
.map(|chunk| chunk.iter().enumerate().map(|(i, v)| (i, *v)).collect())
.collect())
}
/// Fits the network on a sampled subset of `xs`.
///
/// Steps:
/// 1. choose a sample size from the label count, the number of input vectors and
///    `train_params.threshold_samples`;
/// 2. draw that sample using the model's seed;
/// 3. derive one target label per sampled vector with `clustering::compute_labels`;
/// 4. run `train_params.epochs` passes of shuffled mini-batches, minimising the
///    cross-entropy between the network's logits and those targets with Adam.
///
/// # Errors
///
/// Returns an error if the Adam optimizer cannot be built for the variable store.
///
/// # Panics
///
/// Panics if the number of computed labels differs from the sample size, if the
/// model mutex is poisoned, or if `dataset` rejects the sampled data.
fn train(&mut self, xs: &ArraySlice) -> DliResult<()> {
    // Step size handed to Adam.
    const LEARNING_RATE: f64 = 1e-3;

    let sample_size = sampling::select_sample_size(
        self.labels,
        xs.len() / self.input_shape,
        self.train_params.threshold_samples,
    );
    debug!(sample_size = sample_size, total = xs.len() / self.input_shape ; "model:train");

    let xs = sampling::sample(xs, sample_size, self.input_shape, self.seed);
    let ys = clustering::compute_labels(
        &xs,
        &self.label_method,
        self.labels,
        self.input_shape,
        self.train_params.max_iters,
    );
    assert_eq!(ys.len(), sample_size);

    let dataset = self.dataset(&xs, &ys);
    // Propagate optimizer construction failures instead of panicking; `TchError`
    // already converts into `DliError` elsewhere in this file.
    let mut opt = nn::Adam::default().build(&self.vs, LEARNING_RATE)?;

    for _ in 0..self.train_params.epochs {
        for (batch_xs, batch_ys) in dataset
            .train_iter(self.train_params.batch_size as i64)
            .shuffle()
        {
            let loss = self
                .model
                .lock()
                .unwrap()
                .forward(&batch_xs)
                .cross_entropy_for_logits(&batch_ys);
            opt.backward_step(&loss);
        }
    }
    Ok(())
}
/// Re-fits the model according to `train_params.retrain_strategy`.
///
/// * `NoRetrain` – the model is left untouched and a debug line is logged.
/// * `FromScratch` – the weights are re-initialised and the model is trained again
///   on `xs`.
///
/// # Errors
///
/// Propagates any error from resetting or training the model.
// The parameter was previously named `_xs`, which marks it as unused even though the
// `FromScratch` branch reads it.
fn retrain(&mut self, xs: &ArraySlice) -> DliResult<()> {
    match self.train_params.retrain_strategy {
        RetrainStrategy::NoRetrain => {
            debug!("No retraining performed as per strategy.");
        }
        RetrainStrategy::FromScratch => {
            self.reset_model()?;
            self.train(xs)?;
        }
    }
    Ok(())
}
/// Writes the model's weights to `weights_filename` and returns the configuration
/// needed to rebuild the model from that file.
///
/// `quantize` is always `false` in the returned configuration because this backend
/// refuses to build quantized models.
///
/// # Errors
///
/// Returns an error if the variable store cannot be saved.
fn dump(&self, weights_filename: PathBuf) -> DliResult<ModelConfig> {
    self.vs.save(&weights_filename)?;
    Ok(ModelConfig {
        train_params: self.train_params,
        weights_path: Some(weights_filename),
        layers: self.layers.clone(),
        quantize: false,
        // Export the seed this model was built with rather than a fixed constant, so a
        // reloaded model samples its training data the same way.
        seed: self.seed,
    })
}
fn memory_usage(&self) -> usize {
std::mem::size_of::<Self>()
+ (self
.vs
.variables()
.into_values()
.map(|tensor| {
let numel = tensor.size().iter().product::<i64>() as u64;
let element_size = match tensor.kind() {
tch::Kind::Float => 4,
tch::Kind::Double => 8,
tch::Kind::Int64 => 8,
tch::Kind::Int => 4,
_ => 4, };
numel * element_size
})
.sum::<u64>() as usize)
}
fn vec2tensor(&self, xs: &[f32]) -> DliResult<Tensor> {
Ok(tch::Tensor::from_slice(xs))
}
}
impl<F: FloatElement> Model<F> {
pub fn reset_model(&mut self) -> DliResult<()> {
let (new_vs, new_model) =
crate::model::BaseModelBuilder::<TchBackend, F>::build_varstore_and_model(
self.device,
self.input_shape as i64,
self.labels,
&self.layers,
&None,
)?;
self.vs = new_vs;
*self.model.lock().unwrap() = Box::new(new_model);
Ok(())
}
/// Packs flat feature and label slices into a `tch` [`Dataset`] on the model's device.
///
/// `xs` is read as consecutive rows of `input_shape` values, and `ys` supplies one
/// label per row. Only the training tensors are populated; the test tensors are empty.
///
/// # Panics
///
/// Panics if `xs.len()` is not a multiple of `input_shape`, or if the number of rows
/// in `xs` differs from `ys.len()`.
pub fn dataset(&self, xs: &[f32], ys: &[i64]) -> Dataset {
    let total_queries = ys.len();
    assert!(xs.len().is_multiple_of(self.input_shape));
    assert!(xs.len() / self.input_shape == total_queries);

    let features = Tensor::from_slice(xs);
    let features = features.view((
        features.size()[0] / self.input_shape as i64,
        self.input_shape as i64,
    ));
    let features = match self.device {
        Device::Cpu => features,
        _ => features.to_device(self.device),
    };
    assert!(
        features.size()[0] as usize == total_queries,
        "{} != {total_queries}, {:?}",
        features.size()[0],
        features.size()
    );

    let targets = Tensor::from_slice(ys).to_kind(tch::Kind::Int64);
    let targets = match self.device {
        Device::Cpu => targets,
        _ => targets.to_device(self.device),
    };
    // Sanity checks on the tensors that will actually be trained on.
    assert!(features.size()[0] == targets.size()[0]);
    assert!(targets.kind() == tch::Kind::Int64);

    // The empty test tensors mirror the kind of their training counterpart, so the
    // label tensor is integer-typed rather than inheriting the feature kind.
    let device = features.device();
    let empty_images = Tensor::empty(0, (features.kind(), device));
    let empty_labels = Tensor::empty(0, (targets.kind(), device));

    Dataset {
        train_images: features,
        train_labels: targets,
        test_images: empty_images,
        test_labels: empty_labels,
        labels: self.labels as i64,
    }
}
}
/// Copies the contents of `tensor` into a `Vec<f32>`.
///
/// # Panics
///
/// Panics if `tch` cannot convert the tensor.
fn tensor2vec(tensor: &tch::Tensor) -> Vec<f32> {
    Vec::<f32>::try_from(tensor).unwrap()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `ModelDevice` variant must map onto the corresponding `tch` device.
    #[test]
    fn test_model_device_to_tch_device() {
        let cpu = to_tch_device(ModelDevice::Cpu);
        assert!(matches!(cpu, tch::Device::Cpu));

        let gpu = to_tch_device(ModelDevice::Gpu(0));
        assert!(matches!(gpu, tch::Device::Cuda(0)));
    }
}