1use crate::nn::{
2 linear, MaybeQuantizedEmbedding as Embedding, MaybeQuantizedLinear as Linear,
3 MaybeQuantizedVarBuilder as VarBuilder,
4};
5use candle::{DType, Result, Tensor};
6use std::collections::HashMap;
7
/// Configuration for a lookup-table conditioner: a finite set of string
/// values, each mapped to a learnt embedding.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LutConfig {
    // Number of embedding bins; the table allocates one extra row
    // (see `LutConditioner::new`, which uses `n_bins + 1`).
    pub n_bins: usize,
    // Embedding dimension, i.e. the input width of the output projection.
    pub dim: usize,
    // Accepted string values; the position in this list is the embedding index.
    pub possible_values: Vec<String>,
}
14
/// Configuration for a conditioner driven by a continuous scalar value,
/// encoded with sinusoidal embeddings before projection.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ContinuousAttributeConfig {
    // Dimension of the sinusoidal embedding (input width of the output projection).
    pub dim: usize,
    // Multiplier applied to the raw value before embedding.
    pub scale_factor: f32,
    // Maximum period used to derive the sinusoidal frequencies.
    pub max_period: f32,
}
21
/// Per-conditioner configuration, discriminated by a `"type"` field in the
/// serialized form (serde internally-tagged enum).
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(tag = "type")]
pub enum ConditionerConfig {
    Lut(LutConfig),
    ContinuousAttribute(ContinuousAttributeConfig),
}

/// Full conditioning configuration: conditioner name -> its config.
pub type Config = HashMap<String, ConditionerConfig>;
30
/// Lookup-table conditioner: maps a known string value to a learnt embedding,
/// projected to the model dimension.
#[derive(Debug, Clone)]
pub struct LutConditioner {
    embed: Embedding,
    output_proj: Linear,
    // Padding tensor of shape (1, 1, output_dim); only read through
    // `ConditionProvider::learnt_padding`, hence the allow(unused).
    #[allow(unused)]
    learnt_padding: Tensor,
    // Accepted value -> embedding index, built from `LutConfig::possible_values`.
    possible_values: HashMap<String, usize>,
}
39
40impl LutConditioner {
41 pub fn new(output_dim: usize, cfg: &LutConfig, vb: VarBuilder) -> Result<Self> {
42 let embed = Embedding::new(cfg.n_bins + 1, cfg.dim, vb.pp("embed"))?;
43 let output_proj = linear(cfg.dim, output_dim, false, vb.pp("output_proj"))?;
44 let learnt_padding = vb.get_as_tensor((1, 1, output_dim), "learnt_padding")?;
45 let possible_values: HashMap<String, usize> =
46 cfg.possible_values.iter().enumerate().map(|(i, v)| (v.to_string(), i)).collect();
47 Ok(Self { embed, output_proj, learnt_padding, possible_values })
48 }
49
50 pub fn condition(&self, value: &str) -> Result<Condition> {
51 let idx = match self.possible_values.get(value) {
52 None => candle::bail!("unknown value for lut conditioner '{value}'"),
53 Some(idx) => *idx,
54 };
55 let cond = Tensor::from_vec(vec![idx as u32], (1, 1), self.embed.embeddings().device())?
56 .apply(&self.embed)?
57 .apply(&self.output_proj)?;
58 Ok(Condition::AddToInput(cond))
59 }
60}
61
/// Conditioner for a continuous scalar attribute: the value is scaled,
/// encoded with sinusoidal embeddings, and projected to the model dimension.
#[derive(Debug, Clone)]
pub struct ContinuousAttributeConditioner {
    // Multiplier applied to the raw value before embedding.
    scale_factor: f32,
    // Maximum period used to derive the sinusoidal frequencies.
    max_period: f32,
    // Sinusoidal embedding dimension (input width of `output_proj`).
    dim: usize,
    output_proj: Linear,
    // Padding tensor of shape (1, 1, output_dim); only read through
    // `ConditionProvider::learnt_padding`, hence the allow(unused).
    #[allow(unused)]
    learnt_padding: Tensor,
    // Device on which the position tensor is created in `condition`.
    device: candle::Device,
}
72
73impl ContinuousAttributeConditioner {
74 pub fn new(output_dim: usize, cfg: &ContinuousAttributeConfig, vb: VarBuilder) -> Result<Self> {
75 let output_proj = linear(cfg.dim, output_dim, false, vb.pp("output_proj"))?;
76 let learnt_padding = vb.get_as_tensor((1, 1, output_dim), "learnt_padding")?;
77 Ok(Self {
78 scale_factor: cfg.scale_factor,
79 max_period: cfg.max_period,
80 dim: cfg.dim,
81 output_proj,
82 learnt_padding,
83 device: vb.device().clone(),
84 })
85 }
86
87 pub fn create_sin_embeddings(&self, positions: &Tensor, dtype: DType) -> Result<Tensor> {
89 let dev = positions.device();
90 let half_dim = self.dim / 2;
91 let positions = positions.to_dtype(dtype)?;
92 let adim: Vec<_> = (0..half_dim)
93 .map(|i| 1f32 / self.max_period.powf(i as f32 / (half_dim - 1) as f32))
94 .collect();
95 let adim = Tensor::from_vec(adim, (1, 1, ()), dev)?;
96 let freqs = positions.broadcast_mul(&adim)?;
97 let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], candle::D::Minus1)?;
98 Ok(pos_emb)
99 }
100
101 pub fn condition(&self, value: f32) -> Result<Condition> {
103 let value = value * self.scale_factor;
104 let positions = Tensor::full(value, (1, 1, 1), &self.device)?;
105 let cond = self
106 .create_sin_embeddings(&positions, DType::F32)?
107 .to_dtype(self.output_proj.dtype())?
108 .apply(&self.output_proj)?;
109 Ok(Condition::AddToInput(cond))
110 }
111}
112
/// A conditioner of either supported kind, as selected by `ConditionerConfig`.
#[derive(Debug, Clone)]
pub enum Conditioner {
    Lut(LutConditioner),
    ContinuousAttribute(ContinuousAttributeConditioner),
}
118
/// Named collection of conditioners, built from a [`Config`].
#[derive(Debug, Clone)]
pub struct ConditionProvider {
    // Conditioner name -> conditioner, keyed as in the config map.
    conditioners: HashMap<String, Conditioner>,
}
123
/// How a computed condition should be applied to the model.
#[derive(Debug, Clone)]
pub enum Condition {
    /// Tensor to be added to the model input activations.
    AddToInput(Tensor),
}
128
129impl ConditionProvider {
130 pub fn new(output_dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
131 let vb = vb.pp("conditioners");
132 let mut conditioners = HashMap::new();
133 for (conditioner_name, conditioner_cfg) in cfg.iter() {
134 let vb = vb.pp(conditioner_name);
135 let conditioner = match conditioner_cfg {
136 ConditionerConfig::Lut(cfg) => {
137 Conditioner::Lut(LutConditioner::new(output_dim, cfg, vb)?)
138 }
139 ConditionerConfig::ContinuousAttribute(cfg) => Conditioner::ContinuousAttribute(
140 ContinuousAttributeConditioner::new(output_dim, cfg, vb)?,
141 ),
142 };
143 conditioners.insert(conditioner_name.to_string(), conditioner);
144 }
145 Ok(Self { conditioners })
146 }
147
148 pub fn condition_lut(&self, name: &str, value: &str) -> Result<Condition> {
149 let lut = match self.conditioners.get(name) {
150 None => candle::bail!("unknown conditioner {name}"),
151 Some(Conditioner::Lut(l)) => l,
152 Some(_) => candle::bail!("cannot use conditioner with a str value {name}"),
153 };
154 let cond = lut.condition(value)?;
155 Ok(cond)
156 }
157
158 pub fn condition_cont(&self, name: &str, value: f32) -> Result<Condition> {
159 let c = match self.conditioners.get(name) {
160 None => candle::bail!("unknown conditioner {name}"),
161 Some(Conditioner::ContinuousAttribute(c)) => c,
162 Some(_) => candle::bail!("cannot use conditioner with a str value {name}"),
163 };
164 let cond = c.condition(value)?;
165 Ok(cond)
166 }
167
168 pub fn learnt_padding(&self, name: &str) -> Result<Condition> {
169 let c = match self.conditioners.get(name) {
170 None => candle::bail!("unknown conditioner {name}"),
171 Some(Conditioner::ContinuousAttribute(c)) => c.learnt_padding.clone(),
172 Some(Conditioner::Lut(c)) => c.learnt_padding.clone(),
173 };
174 Ok(Condition::AddToInput(c))
175 }
176}