use burn::backend::{Autodiff, Wgpu};
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::{DataLoaderBuilder, Dataset};
use burn::nn::{
conv::{Conv2d, Conv2dConfig},
loss::CrossEntropyLoss,
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
};
use burn::optim::AdamConfig;
use burn::prelude::*;
use burn::record::{
    BinBytesRecorder, BinFileRecorder, CompactRecorder, FullPrecisionSettings, Recorder,
};
use burn::tensor::backend::AutodiffBackend;
use burn::train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
};
use rand::seq::SliceRandom;
use rand::thread_rng;
/// Cap on the number of datapoints used in one training run.
pub const DEFAULT_MAX_DATAPOINTS: usize = 4000;
/// Default number of training epochs.
pub const DEFAULT_EPOCHS: usize = 6;
/// Share of the shuffled datapoints reserved for validation, in percent.
const VALIDATION_SET_PERCENTAGE: usize = 20;
/// Number of consecutive packets that make up one training sample.
const SAMPLE_TIMESPAN: usize = 250;

/// Bundled test data: `(packet_index, label)` datapoints plus raw 14-byte packets.
pub const TEST_DATASET: ([(usize, u8); 31100], [[u8; 14]; 59925]) =
    include!("data/test_dataset.rs");
/// Bundled test model weights in binary record format.
pub const TEST_MODEL: &[u8] = include_bytes!("data/test_model.bin");
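/// Collects raw packets and labeled datapoints and drives model training and
/// inference.
///
/// A minimal usage sketch (illustrative only; the packet bytes and label below
/// are placeholders):
///
/// ```ignore
/// let mut calib = CalibController::default();
/// calib.add_packet(vec![0u8; 14]);
/// calib.add_datapoint(Datapoint { packet_index: 0, label: 1 });
/// let model = calib.train(2, DEFAULT_EPOCHS, DEFAULT_MAX_DATAPOINTS)?;
/// ```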
#[derive(Clone, Default, Debug)]
pub struct CalibController {
pub dataset: PsyLinkDataset,
}
/// Default backend for training: WGPU with autodiff enabled.
pub type DefaultBackend = Autodiff<Wgpu>;
/// The model instantiated on the default backend.
pub type DefaultModel = Model<DefaultBackend>;
impl CalibController {
pub fn add_packet(&mut self, sample: Vec<u8>) {
self.dataset.all_packets.push(sample);
}
pub fn add_datapoint(&mut self, datapoint: Datapoint) {
self.dataset.datapoints.push(datapoint);
}
pub fn count_datapoints(&self) -> usize {
self.dataset.datapoints.len()
}
    pub fn has_datapoints(&self) -> bool {
        !self.dataset.datapoints.is_empty()
    }
pub fn reset(&mut self) {
self.dataset.datapoints.clear();
self.dataset.all_packets.clear();
}
    /// Returns the index that the next incoming packet will receive.
    pub fn get_current_index(&self) -> usize {
        self.dataset.all_packets.len()
    }
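    /// Wipes and recreates the artifact directory; errors (e.g. a missing
    /// directory) are deliberately ignored.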
fn create_artifact_dir(artifact_dir: &str) {
std::fs::remove_dir_all(artifact_dir).ok();
std::fs::create_dir_all(artifact_dir).ok();
}
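    /// Classifies the most recent window of packets, or returns `None` if too
    /// few packets have arrived to fill a sample.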
pub fn infer_latest(&self, model: DefaultModel) -> Option<i32> {
let item = self.dataset.get_latest()?;
Some(infer_item(model, item))
}
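    /// Trains a model with `action_count + 1` output classes on the collected
    /// dataset, writing artifacts to /tmp/psylink.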
pub fn train(
&self,
action_count: usize,
epochs: usize,
max_datapoints: usize,
) -> Result<DefaultModel, Box<dyn std::error::Error>> {
let device = burn::backend::wgpu::WgpuDevice::default();
let artifact_dir = "/tmp/psylink";
let mut model_config = ModelConfig::new();
model_config.num_classes = action_count + 1;
let mut training_config = TrainingConfig::new(model_config, AdamConfig::new());
training_config.num_epochs = epochs;
        self.train_with_backend::<DefaultBackend>(
            artifact_dir,
            training_config,
            max_datapoints,
            device,
        )
}
    fn train_with_backend<B: AutodiffBackend>(
&self,
artifact_dir: &str,
config: TrainingConfig,
max_datapoints: usize,
device: B::Device,
) -> Result<Model<B>, Box<dyn std::error::Error>> {
Self::create_artifact_dir(artifact_dir);
config
.save(format!("{artifact_dir}/config.json"))
.expect("Config should be saved successfully");
B::seed(config.seed);
println!("Dataset length: {}", self.dataset.len());
let (dataset_train, dataset_valid) = self.dataset.split_train_validate(max_datapoints);
        let batcher_train = TrainingBatcher::<B>::new(device.clone());
        // Validation runs on the inner backend, without autodiff.
        let batcher_valid = TrainingBatcher::<B::InnerBackend>::new(device.clone());
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
.build(dataset_train);
        let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
            .batch_size(config.batch_size)
            .shuffle(config.seed)
            .num_workers(config.num_workers)
            .build(dataset_valid);
let learner = LearnerBuilder::new(artifact_dir)
.metric_train_numeric(AccuracyMetric::new())
.metric_valid_numeric(AccuracyMetric::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(
config.model.init::<B>(&device),
config.optimizer.init(),
config.learning_rate,
);
        let model_trained = learner.fit(dataloader_train, dataloader_valid);
model_trained
.clone()
.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
.expect("Trained model should be saved successfully");
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model_trained
.clone()
.save_file(format!("{artifact_dir}/model_bin"), &recorder)
.expect("Should be able to save the model");
Ok(model_trained)
}
}
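/// A small convolutional classifier: two conv layers, adaptive average pooling
/// to a fixed 8x8 grid, then a two-layer fully connected head.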
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv1: Conv2d<B>,
conv2: Conv2d<B>,
pool: AdaptiveAvgPool2d,
dropout: Dropout,
linear1: Linear<B>,
linear2: Linear<B>,
activation: Relu,
}
impl<B: AutodiffBackend> TrainStep<TrainingBatch<B>, ClassificationOutput<B>> for Model<B> {
fn step(&self, batch: TrainingBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let item = self.forward_classification(batch.features, batch.targets);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> ValidStep<TrainingBatch<B>, ClassificationOutput<B>> for Model<B> {
fn step(&self, batch: TrainingBatch<B>) -> ClassificationOutput<B> {
self.forward_classification(batch.features, batch.targets)
}
}
impl<B: Backend> Model<B> {
    /// Forward pass: `[batch, height, width]` features to `[batch, num_classes]` logits.
    pub fn forward(&self, features: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = features.dims();
        // Insert a channel dimension for the convolutions: [batch, 1, height, width].
        let x = features.reshape([batch_size, 1, height, width]);
        let x = self.conv1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.conv2.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);
        // Pool to a fixed 8x8 grid so the flattened size is independent of the input size.
        let x = self.pool.forward(x);
        let x = x.reshape([batch_size, 16 * 8 * 8]);
        let x = self.linear1.forward(x);
        let x = self.dropout.forward(x);
        let x = self.activation.forward(x);
        self.linear2.forward(x)
    }
pub fn forward_classification(
&self,
features: Tensor<B, 3>,
targets: Tensor<B, 1, Int>,
) -> ClassificationOutput<B> {
let output = self.forward(features);
let loss =
CrossEntropyLoss::new(None, &output.device()).forward(output.clone(), targets.clone());
ClassificationOutput::new(loss, output, targets)
}
}
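/// Model hyperparameters; every field has a default, so `ModelConfig::new()`
/// takes no arguments.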
#[derive(Config, Debug)]
pub struct ModelConfig {
#[config(default = "2")]
num_classes: usize,
#[config(default = "32")]
hidden_size: usize,
#[config(default = "0.5")]
dropout: f64,
}
impl ModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
Model {
conv1: Conv2dConfig::new([1, 16], [5, 5]).init(device),
conv2: Conv2dConfig::new([16, 16], [5, 5]).init(device),
pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),
activation: Relu::new(),
linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),
linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),
dropout: DropoutConfig::new(self.dropout).init(),
}
}
}
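/// Full training configuration: model and optimizer configs plus run hyperparameters.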
#[derive(Config)]
pub struct TrainingConfig {
pub model: ModelConfig,
pub optimizer: AdamConfig,
#[config(default = 6)]
pub num_epochs: usize,
#[config(default = 32)]
pub batch_size: usize,
#[config(default = 8)]
pub num_workers: usize,
#[config(default = 42)]
pub seed: u64,
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
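/// A label attached to one position in the packet stream.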
#[derive(Clone, Default, Debug)]
pub struct Datapoint {
pub packet_index: usize,
pub label: u8,
}
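/// A window of consecutive packets plus the label of its final packet.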
#[derive(Clone, Default, Debug)]
pub struct TrainingSample {
pub features: Vec<Vec<u8>>,
pub label: u8,
}
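/// The recorded session: raw packets plus datapoints that reference them by index.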
#[derive(Clone, Default, Debug)]
pub struct PsyLinkDataset {
pub datapoints: Vec<Datapoint>,
pub all_packets: Vec<Vec<u8>>,
}
impl Dataset<TrainingSample> for PsyLinkDataset {
fn get(&self, index: usize) -> Option<TrainingSample> {
let datapoint = self.datapoints.get(index)?;
self.get_sample_from_packet_index(datapoint.packet_index, datapoint.label)
}
fn len(&self) -> usize {
self.datapoints.len()
}
}
impl PsyLinkDataset {
    /// Builds a sample from the SAMPLE_TIMESPAN packets ending at `packet_index`.
    fn get_sample_from_packet_index(
        &self,
        packet_index: usize,
        label: u8,
    ) -> Option<TrainingSample> {
        // Not enough history yet to fill a whole sample window.
        if packet_index < SAMPLE_TIMESPAN {
            return None;
        }
        let start = packet_index - (SAMPLE_TIMESPAN - 1);
        let end = packet_index;
        let packets = self.all_packets.get(start..=end)?;
        Some(TrainingSample {
            features: packets.to_vec(),
            label,
        })
    }
    /// Shuffles the datapoints, caps them at `max_datapoints`, and splits them
    /// into training and validation sets.
    fn split_train_validate(&self, max_datapoints: usize) -> (Self, Self) {
        let mut datapoints = self.datapoints.clone();
        let mut rng = thread_rng();
        datapoints.shuffle(&mut rng);
        datapoints.truncate(max_datapoints);
        // `split_off` leaves the first `validation_split_index` entries (the
        // validation set) in `datapoints`; the index is at most `len`, so this
        // cannot panic.
        let validation_split_index = (datapoints.len() * VALIDATION_SET_PERCENTAGE) / 100;
        let training_datapoints = datapoints.split_off(validation_split_index);
let train_dataset = PsyLinkDataset {
datapoints: training_datapoints,
all_packets: self.all_packets.clone(),
};
let validation_dataset = PsyLinkDataset {
datapoints,
all_packets: self.all_packets.clone(),
};
(train_dataset, validation_dataset)
}
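    /// Serializes the dataset as Rust source in the tuple-of-arrays form that
    /// `from_arrays` consumes (the same format as `data/test_dataset.rs`).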
pub fn to_string(&self) -> String {
let mut string = String::new();
string += "([\n";
for datapoint in &self.datapoints {
string += format!("({},{}),", datapoint.packet_index, datapoint.label).as_str();
}
string += "],\n[\n";
for packet in &self.all_packets {
string += "[";
for byte in packet {
string += byte.to_string().as_str();
string += ",";
}
string += "],\n";
}
string += "])\n";
string
}
    /// Builds a dataset from the const-array form used by `TEST_DATASET`.
    pub fn from_arrays(datapoints: &[(usize, u8)], all_packets: &[[u8; 14]]) -> Self {
        let datapoints: Vec<Datapoint> = datapoints
            .iter()
            .map(|&(packet_index, label)| Datapoint {
                packet_index,
                label,
            })
            .collect();
        let all_packets: Vec<Vec<u8>> = all_packets
            .iter()
            .map(|packet| packet.to_vec())
            .collect();
Self {
datapoints,
all_packets,
}
}
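    /// Builds a sample ending at the newest packet, with a placeholder label of 0.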
pub fn get_latest(&self) -> Option<TrainingSample> {
let last = self.all_packets.len().saturating_sub(1);
self.get_sample_from_packet_index(last, 0)
}
}
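/// One device-resident batch: features of shape [batch, SAMPLE_TIMESPAN, 14]
/// and integer class targets.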
#[derive(Clone, Debug)]
pub struct TrainingBatch<B: Backend> {
pub features: Tensor<B, 3>,
pub targets: Tensor<B, 1, Int>,
}
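/// Turns `TrainingSample`s into tensors on a given device.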
#[derive(Clone)]
pub struct TrainingBatcher<B: Backend> {
device: B::Device,
}
impl<B: Backend> TrainingBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}
impl<B: Backend> Batcher<TrainingSample, TrainingBatch<B>> for TrainingBatcher<B> {
    fn batch(&self, items: Vec<TrainingSample>) -> TrainingBatch<B> {
        // Flatten each sample's packets into one [SAMPLE_TIMESPAN, 14] tensor.
        let features = items
            .iter()
            .map(|item| Data::<u8, 2> {
                value: item.features.concat(),
                shape: Shape::<2> {
                    dims: [SAMPLE_TIMESPAN, 14],
                },
            })
            .map(|data| {
                Tensor::<B, 2>::from_data(data.convert(), &self.device)
                    .reshape([1, SAMPLE_TIMESPAN, 14])
            })
            .collect();
        let targets = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data(Data::from([item.label.elem()]), &self.device)
            })
            .collect();
        // Stack the per-item tensors along the new batch dimension.
        let features = Tensor::cat(features, 0).to_device(&self.device);
        let targets = Tensor::cat(targets, 0).to_device(&self.device);
        TrainingBatch { features, targets }
    }
}
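/// Loads the model whose config and weights are compiled into the binary.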
pub fn load_test_model() -> Model<DefaultBackend> {
let device = burn::backend::wgpu::WgpuDevice::default();
let config = TrainingConfig::load_binary(include_bytes!("data/test_model_config.json"))
.expect("Config should exist for the model");
let record = BinBytesRecorder::<FullPrecisionSettings>::default()
.load(TEST_MODEL.to_vec(), &device)
.expect("Should be able to load model the model weights from bytes");
let model = config
.model
.init::<DefaultBackend>(&device)
.load_record(record);
model
}
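/// Runs one sample through the model and returns the predicted class index.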
pub fn infer_item(model: Model<DefaultBackend>, item: TrainingSample) -> i32 {
let device = burn::backend::wgpu::WgpuDevice::default();
let batcher = TrainingBatcher::<DefaultBackend>::new(device.clone());
let batch = batcher.batch(vec![item]);
let output = model.forward(batch.features);
    // The predicted class is the index of the largest logit.
    output.argmax(1).flatten::<1>(0, 1).into_scalar()
}
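/// Trains a model on the bundled test dataset using the default settings.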
pub fn train() -> Result<(), Box<dyn std::error::Error>> {
let mut calib = CalibController::default();
calib.dataset = PsyLinkDataset::from_arrays(&TEST_DATASET.0, &TEST_DATASET.1);
calib.train(2, DEFAULT_EPOCHS, DEFAULT_MAX_DATAPOINTS)?;
Ok(())
}
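/// Runs the bundled test model over every sample of the bundled test dataset,
/// printing each prediction.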
pub fn infer() -> Result<(), Box<dyn std::error::Error>> {
let model = load_test_model();
let dataset = PsyLinkDataset::from_arrays(&TEST_DATASET.0, &TEST_DATASET.1);
for item in dataset.iter() {
let predicted = infer_item(model.clone(), item);
dbg!(predicted);
}
Ok(())
}