use crate::ModelState;
use crate::models::transformer::StreamingTransformer;
use crate::modules::mlp::{LayerNorm, ModulationParams, SimpleMLPAdaLN};
use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};
/// Decodes a latent by repeatedly applying the flow network, starting from `x_0`.
///
/// Each entry of `modulations` corresponds to one integration step. For every
/// step the flow direction is evaluated with `forward_step_cached` on the
/// running state and added to it, scaled by `1 / modulations.len()`.
///
/// If `modulations` is empty, no step is taken and a clone of `x_0` is returned.
///
/// # Errors
/// Propagates any error from `forward_step_cached` or from the tensor
/// arithmetic.
pub fn lsd_decode(
    flow_net: &SimpleMLPAdaLN,
    modulations: &[Vec<ModulationParams>],
    x_0: &Tensor,
) -> Result<Tensor> {
    let mut current = x_0.clone();
    let num_steps = modulations.len();
    // Uniform step size; only used inside the loop, so an empty slice never
    // observes the resulting infinity.
    let step_factor = 1.0 / num_steps as f64;
    for step_mod in modulations {
        let flow_dir = flow_net.forward_step_cached(&current, step_mod)?;
        current = (current + flow_dir.affine(step_factor, 0.0)?)?;
    }
    Ok(current)
}
/// Flow-based language model: a streaming transformer whose last output frame
/// conditions a flow network that decodes the next latent from sampled noise.
#[derive(Clone)]
pub struct FlowLMModel {
/// Flow network used to embed the conditioning frame, precompute per-step
/// modulations, and evaluate flow directions during decoding.
pub flow_net: SimpleMLPAdaLN,
/// Transformer run over the (optionally text-prefixed) projected input sequence.
pub transformer: StreamingTransformer,
/// Bias-free linear projection from `ldim` to `dim`, applied to the input sequence.
pub input_linear: Linear,
/// Layer norm applied to the transformer output before the heads.
pub out_norm: LayerNorm,
/// Linear head from `dim` to 1 whose output is compared against the EOS threshold.
pub out_eos: Linear,
/// Tensor of length `ldim` loaded under the name "bos_emb".
/// NOTE(review): presumably the beginning-of-sequence latent; not used in the visible code.
pub bos_emb: Tensor,
/// Tensor of length `ldim` loaded under the name "emb_mean".
/// NOTE(review): presumably latent normalisation mean; not used in the visible code.
pub emb_mean: Tensor,
/// Tensor of length `ldim` loaded under the name "emb_std".
/// NOTE(review): presumably latent normalisation std; not used in the visible code.
pub emb_std: Tensor,
/// Size of the latent dimension (input of `input_linear`, width of the sampled noise).
pub ldim: usize,
/// Model width (output of `input_linear`, input of `out_norm` / `out_eos`).
pub dim: usize,
/// Optional bound on the magnitude of sampled noise values; `None` disables
/// clamping. Initialised to `None` by `new`.
pub noise_clamp: Option<f32>,
}
fn sample_noise(
device: &candle_core::Device,
shape: (usize, usize),
temp: f32,
clamp: Option<f32>,
) -> Result<Tensor> {
let std = temp.sqrt();
match clamp {
None => Tensor::randn(0.0f32, std, shape, device),
Some(limit) => {
let count = shape.0 * shape.1;
let mut data = Vec::with_capacity(count);
let mut rng = rand::thread_rng();
let dist = rand_distr::Normal::new(0.0f32, std)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
while data.len() < count {
let v = rand_distr::Distribution::sample(&dist, &mut rng);
if v.abs() <= limit {
data.push(v);
}
}
Tensor::from_vec(data, shape, device)
}
}
}
impl FlowLMModel {
pub fn new(
flow_net: SimpleMLPAdaLN,
transformer: StreamingTransformer,
ldim: usize,
dim: usize,
vb: VarBuilder,
) -> Result<Self> {
let input_linear = candle_nn::linear_no_bias(ldim, dim, vb.pp("input_linear"))?;
let out_norm = LayerNorm::new(dim, 1e-5, true, vb.pp("out_norm"))?;
let out_eos = candle_nn::linear(dim, 1, vb.pp("out_eos"))?;
let bos_emb = vb.get(ldim, "bos_emb")?;
let emb_mean = vb.get(ldim, "emb_mean")?;
let emb_std = vb.get(ldim, "emb_std")?;
Ok(Self {
flow_net,
transformer,
input_linear,
out_norm,
out_eos,
bos_emb,
emb_mean,
emb_std,
ldim,
dim,
noise_clamp: None, })
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
sequence: &Tensor,
text_embeddings: &Tensor,
model_state: &mut ModelState,
time_embeddings: &Tensor,
temp: f32,
eos_threshold: f32,
step: usize,
) -> Result<(Tensor, bool)> {
let x = self.input_linear.forward(sequence)?;
let s_len = text_embeddings.dims()[1];
let transformer_out_pre_norm = if s_len > 0 {
let input = Tensor::cat(&[text_embeddings, &x], 1)?;
let mut out = self.transformer.forward(&input, model_state, step)?;
out = out.narrow(1, s_len, out.dims()[1] - s_len)?;
out
} else {
self.transformer.forward(&x, model_state, step)?
};
let transformer_out = self.out_norm.forward(&transformer_out_pre_norm)?;
let last_frame = transformer_out
.narrow(1, transformer_out.dims()[1] - 1, 1)?
.squeeze(1)?;
let eos_score = self
.out_eos
.forward(&last_frame)?
.squeeze(0)?
.squeeze(0)?
.to_scalar::<f32>()?;
let is_eos = eos_score > eos_threshold;
let noise = sample_noise(
last_frame.device(),
(last_frame.dims()[0], self.ldim),
temp,
self.noise_clamp,
)?;
let c_emb = self.flow_net.embed_condition(&last_frame)?;
let modulations = self
.flow_net
.precompute_modulations(&c_emb, time_embeddings)?;
let next_latent = lsd_decode(&self.flow_net, &modulations, &noise)?;
Ok((next_latent, is_eos))
}
}