moshi_db/
conditioner.rs

1use crate::nn::{
2    linear, MaybeQuantizedEmbedding as Embedding, MaybeQuantizedLinear as Linear,
3    MaybeQuantizedVarBuilder as VarBuilder,
4};
5use candle::{DType, Result, Tensor};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, serde::Deserialize)]
9pub struct LutConfig {
10    pub n_bins: usize,
11    pub dim: usize,
12    pub possible_values: Vec<String>,
13}
14
15#[derive(Debug, Clone, serde::Deserialize)]
16pub struct ContinuousAttributeConfig {
17    pub dim: usize,
18    pub scale_factor: f32,
19    pub max_period: f32,
20}
21
22#[derive(Debug, Clone, serde::Deserialize)]
23#[serde(tag = "type")]
24pub enum ConditionerConfig {
25    Lut(LutConfig),
26    ContinuousAttribute(ContinuousAttributeConfig),
27}
28
29pub type Config = HashMap<String, ConditionerConfig>;
30
31#[derive(Debug, Clone)]
32pub struct LutConditioner {
33    embed: Embedding,
34    output_proj: Linear,
35    #[allow(unused)]
36    learnt_padding: Tensor,
37    possible_values: HashMap<String, usize>,
38}
39
40impl LutConditioner {
41    pub fn new(output_dim: usize, cfg: &LutConfig, vb: VarBuilder) -> Result<Self> {
42        let embed = Embedding::new(cfg.n_bins + 1, cfg.dim, vb.pp("embed"))?;
43        let output_proj = linear(cfg.dim, output_dim, false, vb.pp("output_proj"))?;
44        let learnt_padding = vb.get_as_tensor((1, 1, output_dim), "learnt_padding")?;
45        let possible_values: HashMap<String, usize> =
46            cfg.possible_values.iter().enumerate().map(|(i, v)| (v.to_string(), i)).collect();
47        Ok(Self { embed, output_proj, learnt_padding, possible_values })
48    }
49
50    pub fn condition(&self, value: &str) -> Result<Condition> {
51        let idx = match self.possible_values.get(value) {
52            None => candle::bail!("unknown value for lut conditioner '{value}'"),
53            Some(idx) => *idx,
54        };
55        let cond = Tensor::from_vec(vec![idx as u32], (1, 1), self.embed.embeddings().device())?
56            .apply(&self.embed)?
57            .apply(&self.output_proj)?;
58        Ok(Condition::AddToInput(cond))
59    }
60}
61
62#[derive(Debug, Clone)]
63pub struct ContinuousAttributeConditioner {
64    scale_factor: f32,
65    max_period: f32,
66    dim: usize,
67    output_proj: Linear,
68    #[allow(unused)]
69    learnt_padding: Tensor,
70    device: candle::Device,
71}
72
73impl ContinuousAttributeConditioner {
74    pub fn new(output_dim: usize, cfg: &ContinuousAttributeConfig, vb: VarBuilder) -> Result<Self> {
75        let output_proj = linear(cfg.dim, output_dim, false, vb.pp("output_proj"))?;
76        let learnt_padding = vb.get_as_tensor((1, 1, output_dim), "learnt_padding")?;
77        Ok(Self {
78            scale_factor: cfg.scale_factor,
79            max_period: cfg.max_period,
80            dim: cfg.dim,
81            output_proj,
82            learnt_padding,
83            device: vb.device().clone(),
84        })
85    }
86
87    // `positions` should have shape (b, t, 1), the output will be (b, t, dim)
88    pub fn create_sin_embeddings(&self, positions: &Tensor, dtype: DType) -> Result<Tensor> {
89        let dev = positions.device();
90        let half_dim = self.dim / 2;
91        let positions = positions.to_dtype(dtype)?;
92        let adim: Vec<_> = (0..half_dim)
93            .map(|i| 1f32 / self.max_period.powf(i as f32 / (half_dim - 1) as f32))
94            .collect();
95        let adim = Tensor::from_vec(adim, (1, 1, ()), dev)?;
96        let freqs = positions.broadcast_mul(&adim)?;
97        let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], candle::D::Minus1)?;
98        Ok(pos_emb)
99    }
100
101    // TODO(laurent): should we support different values per batch element?
102    pub fn condition(&self, value: f32) -> Result<Condition> {
103        let value = value * self.scale_factor;
104        let positions = Tensor::full(value, (1, 1, 1), &self.device)?;
105        let cond = self
106            .create_sin_embeddings(&positions, DType::F32)?
107            .to_dtype(self.output_proj.dtype())?
108            .apply(&self.output_proj)?;
109        Ok(Condition::AddToInput(cond))
110    }
111}
112
113#[derive(Debug, Clone)]
114pub enum Conditioner {
115    Lut(LutConditioner),
116    ContinuousAttribute(ContinuousAttributeConditioner),
117}
118
119#[derive(Debug, Clone)]
120pub struct ConditionProvider {
121    conditioners: HashMap<String, Conditioner>,
122}
123
124#[derive(Debug, Clone)]
125pub enum Condition {
126    AddToInput(Tensor),
127}
128
129impl ConditionProvider {
130    pub fn new(output_dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
131        let vb = vb.pp("conditioners");
132        let mut conditioners = HashMap::new();
133        for (conditioner_name, conditioner_cfg) in cfg.iter() {
134            let vb = vb.pp(conditioner_name);
135            let conditioner = match conditioner_cfg {
136                ConditionerConfig::Lut(cfg) => {
137                    Conditioner::Lut(LutConditioner::new(output_dim, cfg, vb)?)
138                }
139                ConditionerConfig::ContinuousAttribute(cfg) => Conditioner::ContinuousAttribute(
140                    ContinuousAttributeConditioner::new(output_dim, cfg, vb)?,
141                ),
142            };
143            conditioners.insert(conditioner_name.to_string(), conditioner);
144        }
145        Ok(Self { conditioners })
146    }
147
148    pub fn condition_lut(&self, name: &str, value: &str) -> Result<Condition> {
149        let lut = match self.conditioners.get(name) {
150            None => candle::bail!("unknown conditioner {name}"),
151            Some(Conditioner::Lut(l)) => l,
152            Some(_) => candle::bail!("cannot use conditioner with a str value {name}"),
153        };
154        let cond = lut.condition(value)?;
155        Ok(cond)
156    }
157
158    pub fn condition_cont(&self, name: &str, value: f32) -> Result<Condition> {
159        let c = match self.conditioners.get(name) {
160            None => candle::bail!("unknown conditioner {name}"),
161            Some(Conditioner::ContinuousAttribute(c)) => c,
162            Some(_) => candle::bail!("cannot use conditioner with a str value {name}"),
163        };
164        let cond = c.condition(value)?;
165        Ok(cond)
166    }
167
168    pub fn learnt_padding(&self, name: &str) -> Result<Condition> {
169        let c = match self.conditioners.get(name) {
170            None => candle::bail!("unknown conditioner {name}"),
171            Some(Conditioner::ContinuousAttribute(c)) => c.learnt_padding.clone(),
172            Some(Conditioner::Lut(c)) => c.learnt_padding.clone(),
173        };
174        Ok(Condition::AddToInput(c))
175    }
176}