use std::collections::HashMap;
use burn::{
module::Module,
nn::{Initializer, Linear, LinearConfig},
prelude::Backend,
tensor::{
activation::log_softmax,
backend::AutodiffBackend,
Int, Tensor, TensorData,
},
};
use crate::{
config::{MAX_ANSWER_TOKENS, PATCH_SIZE},
data::batch::Sample,
model::{
encoder::{TransformerCnnEncoder, TransformerCnnEncoderConfig},
llm::llama_cpp::LlamaCppBackend,
},
};
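/// The gradient-carrying pieces of the model: the time-series encoder and the
/// linear head that maps pooled encodings to a vocabulary-sized logit bias.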
#[derive(Module, Debug)]
pub struct TrainableComponents<B: Backend> {
pub encoder: TransformerCnnEncoder<B>,
pub logit_head: Linear<B>,
}
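/// Couples the trainable components with a frozen llama.cpp LLM: each sample's
/// time series are encoded into a logit-bias vector that is added to the LLM's
/// output logits during scoring and generation.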
pub struct OpenTslmSp<B: Backend> {
pub trainable: TrainableComponents<B>,
pub patch_size: usize,
pub n_vocab: usize,
pub enc_out_dim: usize,
}
impl<B: Backend> OpenTslmSp<B> {
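    /// Builds the encoder from its default config and a bias-free logit head
    /// sized to the LLM's vocabulary, with small normal initialization.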
pub fn new(llm: &LlamaCppBackend, device: &B::Device) -> Self {
let enc_cfg = TransformerCnnEncoderConfig::default();
let enc_out = enc_cfg.output_dim;
let encoder = enc_cfg.init(device);
let logit_head = LinearConfig::new(enc_out, llm.n_vocab)
.with_bias(false)
.with_initializer(Initializer::Normal { mean: 0.0, std: 0.01 })
.init::<B>(device);
Self {
trainable: TrainableComponents { encoder, logit_head },
patch_size: PATCH_SIZE,
n_vocab: llm.n_vocab,
enc_out_dim: enc_out,
}
}
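    /// Encodes all time series in the batch and reduces them to one logit-bias
    /// row per sample, returning a `[batch, n_vocab]` tensor: per-series patch
    /// embeddings are mean-pooled, projected through the logit head, and then
    /// averaged across each sample's series. Samples without series (or a
    /// batch with none at all) yield zero rows.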
pub fn encode_to_logit_bias(
&self,
batch: &[Sample],
device: &B::Device,
) -> Tensor<B, 2> {
let b = batch.len();
let n_vocab = self.n_vocab;
let mut all_ts: Vec<&Vec<f32>> = Vec::new();
let mut counts: Vec<usize> = Vec::with_capacity(b);
for sample in batch {
counts.push(sample.time_series.len());
for ts in &sample.time_series {
all_ts.push(ts);
}
}
let total_ts: usize = counts.iter().sum();
if total_ts == 0 {
return Tensor::zeros([b, n_vocab], device);
}
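        // One encoder pass over every series in the batch: mean-pool the patch
        // dimension, then project each pooled vector to a vocab-sized bias.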
let ts_tensor = pad_ts_batch_refs(&all_ts, self.patch_size, device);
let encoded = self.trainable.encoder.forward(ts_tensor);
let [_, _n_patches, enc_dim] = encoded.dims();
let pooled = encoded.mean_dim(1).reshape([total_ts, enc_dim]);
let all_biases = self.trainable.logit_head.forward(pooled);
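        // Re-split the flattened biases by sample, averaging across each
        // sample's series.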
let mut rows: Vec<Tensor<B, 2>> = Vec::with_capacity(b);
let mut offset = 0usize;
for &count in &counts {
if count == 0 {
rows.push(Tensor::zeros([1, n_vocab], device));
} else {
let slice = all_biases
.clone()
                    .slice([offset..(offset + count), 0..n_vocab]);
                let avg = slice.mean_dim(0).reshape([1, n_vocab]);
                rows.push(avg);
offset += count;
}
}
        Tensor::cat(rows, 0)
    }
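    /// Generates an answer for each sample: formats and tokenizes the prompt,
    /// then decodes with the llama.cpp backend, passing the sample's logit-bias
    /// vector to the backend's generation call.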
pub fn generate(
&self,
batch: &[Sample],
llm: &LlamaCppBackend,
max_tokens: usize,
device: &B::Device,
) -> Vec<String> {
let biases = self.encode_to_logit_bias(batch, device);
batch
.iter()
.enumerate()
.map(|(i, sample)| {
let bias_vec: Vec<f32> = biases
.clone()
.slice([i..(i + 1), 0..self.n_vocab])
.reshape([self.n_vocab])
.to_data()
.to_vec::<f32>()
.unwrap_or_default();
let prompt_text = format_prompt(sample);
let prompt_tokens = llm.tokenize(&prompt_text, true).unwrap_or_default();
let generated = llm
.generate(&prompt_tokens, max_tokens, Some(&bias_vec))
.unwrap_or_default();
llm.detokenize(&generated)
})
.collect()
}
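    /// Batch-mean negative log-likelihood of the answer tokens; a thin wrapper
    /// over `loss_and_acc_inner` that skips metric tracking.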
pub fn compute_loss(
&self,
batch: &[Sample],
llm: &LlamaCppBackend,
device: &B::Device,
) -> Tensor<B, 1>
where
B: AutodiffBackend,
{
self.loss_and_acc_inner(batch, llm, device, false).0
}
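    /// Like `compute_loss`, but also returns token-level accuracy and macro
    /// recall over the target token classes.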
pub fn compute_loss_and_metrics(
&self,
batch: &[Sample],
llm: &LlamaCppBackend,
device: &B::Device,
) -> (Tensor<B, 1>, f64, f64)
where
B: AutodiffBackend,
{
self.loss_and_acc_inner(batch, llm, device, true)
}
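    /// Shared loss/metrics path. For each sample, the frozen LLM supplies base
    /// logits at the answer positions, the trainable bias is added, and the
    /// mean NLL of the target tokens is accumulated; the final loss averages
    /// over the samples that scored successfully. When `compute_acc` is false,
    /// the returned accuracy and macro recall are both 0.0.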
fn loss_and_acc_inner(
&self,
batch: &[Sample],
llm: &LlamaCppBackend,
device: &B::Device,
compute_acc: bool,
) -> (Tensor<B, 1>, f64, f64)
where
B: AutodiffBackend,
{
let biases = self.encode_to_logit_bias(batch, device);
let n_vocab = self.n_vocab;
let mut total_loss = Tensor::<B, 1>::zeros([1], device);
let mut n_counted = 0usize;
let mut correct: usize = 0;
let mut total_tok: usize = 0;
        let mut class_tp: HashMap<usize, usize> = HashMap::new();
        let mut class_sup: HashMap<usize, usize> = HashMap::new();
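        // Per-sample: score the answer tokens under the frozen LLM's logits
        // plus this sample's bias; samples that fail to tokenize or score are
        // skipped.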
for (i, sample) in batch.iter().enumerate() {
let bias_i = biases
.clone()
.slice([i..(i + 1), 0..n_vocab])
.reshape([n_vocab]);
let prompt_text = format_prompt(sample);
let prompt_tokens = match llm.tokenize(&prompt_text, true) {
Ok(t) if !t.is_empty() => t,
_ => continue,
};
let answer_tokens = match llm.tokenize(&sample.answer, false) {
Ok(t) if !t.is_empty() => t,
_ => continue,
};
let answer_tokens: Vec<_> = answer_tokens.into_iter()
.take(MAX_ANSWER_TOKENS)
.collect();
let answer_len = answer_tokens.len();
let base_logits = match llm.answer_logits(&prompt_tokens, &answer_tokens) {
Ok(v) => v,
Err(e) => { tracing::warn!("answer_logits: {e}"); continue; }
};
if base_logits.len() != answer_len { continue; }
let flat_base: Vec<f32> = base_logits.into_iter().flatten().collect();
let base_t = Tensor::<B, 2>::from_data(
TensorData::new(flat_base, [answer_len, n_vocab]),
device,
);
let adjusted = base_t + bias_i.reshape([1, n_vocab]).expand([answer_len, n_vocab]);
if compute_acc {
                let pred_ids: Vec<i32> = adjusted
                    .clone()
                    .argmax(1)
                    .to_data()
                    // Convert first: the backend's integer element type may not be i32.
                    .convert::<i32>()
                    .to_vec::<i32>()
                    .unwrap_or_else(|e| {
                        tracing::warn!("argmax to_vec failed: {e}");
                        vec![]
                    });
for (pos, tok) in answer_tokens.iter().enumerate() {
                        let target = tok.0;
                        let pred = pred_ids.get(pos).copied().unwrap_or(-1);
*class_sup.entry(target as usize).or_insert(0) += 1;
if pred == target {
correct += 1;
*class_tp.entry(target as usize).or_insert(0) += 1;
}
total_tok += 1;
}
}
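            // Per-sample loss: mean negative log-probability of the target
            // answer tokens.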
let log_probs = log_softmax(adjusted, 1);
let target_ids: Vec<i64> = answer_tokens.iter().map(|t| t.0 as i64).collect();
let target_t = Tensor::<B, 2, Int>::from_data(
TensorData::new(target_ids, [answer_len, 1]),
device,
);
let selected_lp = log_probs.gather(1, target_t).reshape([answer_len]);
let n_t = Tensor::<B, 1>::from_data(
TensorData::new(vec![answer_len as f32], [1]),
device,
);
total_loss = total_loss + selected_lp.sum().neg() / n_t;
n_counted += 1;
}
let loss = if n_counted == 0 {
Tensor::<B, 1>::zeros([1], device)
} else {
let n_t = Tensor::<B, 1>::from_data(
TensorData::new(vec![n_counted as f32], [1]),
device,
);
total_loss / n_t
};
let accuracy = if total_tok == 0 { 0.0 } else { correct as f64 / total_tok as f64 };
let macro_recall = if class_sup.is_empty() {
0.0
} else {
let sum: f64 = class_sup.keys().map(|c| {
let tp = *class_tp.get(c).unwrap_or(&0) as f64;
let sup = *class_sup.get(c).unwrap_or(&1) as f64;
tp / sup
}).sum();
sum / class_sup.len() as f64
};
(loss, accuracy, macro_recall)
}
}
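/// Renders a sample into prompt text: the pre-prompt, then each series' caption
/// followed by up to its first 20 values (with an ellipsis when truncated),
/// then the post-prompt, all joined by newlines.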
pub fn format_prompt(sample: &Sample) -> String {
let mut parts = vec![sample.pre_prompt.clone()];
for (text, series) in sample
.time_series_text
.iter()
.zip(sample.time_series.iter())
{
parts.push(text.clone());
let n_show = 20.min(series.len());
let nums: String = series[..n_show]
.iter()
.map(|v| format!("{v:.3}"))
.collect::<Vec<_>>()
.join(", ");
let ellipsis = if series.len() > n_show { ", ..." } else { "" };
parts.push(format!("[{nums}{ellipsis}]"));
}
parts.push(sample.post_prompt.clone());
parts.join("\n")
}
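/// Owned-vector convenience wrapper around `pad_ts_batch_refs`.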
#[allow(dead_code)]
fn pad_ts_batch<B: Backend>(
ts_list: &[Vec<f32>],
patch_size: usize,
device: &B::Device,
) -> Tensor<B, 2> {
pad_ts_batch_refs(&ts_list.iter().collect::<Vec<_>>(), patch_size, device)
}
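/// Zero-pads every series to the batch maximum, rounded up to a whole number
/// of patches, and packs them into an `[n, padded_len]` tensor.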
fn pad_ts_batch_refs<B: Backend>(
ts_list: &[&Vec<f32>],
patch_size: usize,
device: &B::Device,
) -> Tensor<B, 2> {
let max_len = ts_list.iter().map(|v| v.len()).max().unwrap_or(0);
    // Round up to a whole number of patches; keep at least one patch so a
    // batch of all-empty series still yields a non-zero-width tensor.
    let padded_len = max_len.div_ceil(patch_size).max(1) * patch_size;
let n = ts_list.len();
let mut flat = vec![0.0f32; n * padded_len];
for (i, ts) in ts_list.iter().enumerate() {
let copy = ts.len().min(padded_len);
flat[i * padded_len..i * padded_len + copy].copy_from_slice(&ts[..copy]);
}
Tensor::<B, 2>::from_data(TensorData::new(flat, [n, padded_len]), device)
}