// candle_transformers/models/stable_diffusion/utils.rs
use candle::{Device, Result, Tensor};

3pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
4 if steps == 0 {
5 Tensor::from_vec(Vec::<f64>::new(), steps, &Device::Cpu)
6 } else if steps == 1 {
7 Tensor::from_vec(vec![start], steps, &Device::Cpu)
8 } else {
9 let delta = (stop - start) / (steps - 1) as f64;
10 let vs = (0..steps)
11 .map(|step| start + step as f64 * delta)
12 .collect::<Vec<_>>();
13 Tensor::from_vec(vs, steps, &Device::Cpu)
14 }
15}
16
/// Stateful one-dimensional linear interpolator over sampled points,
/// with a cached interval index to accelerate repeated nearby lookups
/// (similar in spirit to GSL's interpolation accelerator).
struct LinearInterpolator<'x, 'y> {
    // Sample x-coordinates; the search logic assumes these are sorted
    // ascending — TODO confirm with callers.
    xp: &'x [f64],
    // Sample values f(xp); presumably the same length as `xp` — verify
    // against callers, as no length check is performed here.
    fp: &'y [f64],
    // Lower index of the interval found by the previous lookup, used as
    // the starting hint for the next search.
    cache: usize,
}
23
impl LinearInterpolator<'_, '_> {
    /// Finds the index `i` of the interval bracketing `x`
    /// (`xp[i] <= x < xp[i + 1]`), starting from the cached hint.
    ///
    /// Assumes `x` lies within `[xp[0], xp[last]]` (enforced by `eval`)
    /// and that `xp` is sorted ascending — TODO confirm at call sites.
    fn accel_find(&mut self, x: f64) -> usize {
        let xidx = self.cache;
        if x < self.xp[xidx] {
            // x is left of the cached interval: binary-search the prefix.
            // `partition_point` counts elements < x; step back one to land
            // on the interval's lower bound (saturating at index 0).
            self.cache = self.xp[0..xidx].partition_point(|o| *o < x);
            self.cache = self.cache.saturating_sub(1);
        } else if x >= self.xp[xidx + 1] {
            // x is right of the cached interval: binary-search the suffix,
            // then shift the local result back to a full-slice index.
            self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx;
            self.cache = self.cache.saturating_sub(1);
        }
        // Otherwise the cached interval already brackets x; reuse it.

        self.cache
    }

    /// Evaluates the piecewise-linear interpolant at `x`.
    ///
    /// Returns NaN when `x` falls outside the sampled range (no
    /// extrapolation) or when the bracketing interval has zero width.
    fn eval(&mut self, x: f64) -> f64 {
        if x < self.xp[0] || x > self.xp[self.xp.len() - 1] {
            return f64::NAN;
        }

        let idx = self.accel_find(x);

        let x_l = self.xp[idx];
        let x_h = self.xp[idx + 1];
        let y_l = self.fp[idx];
        let y_h = self.fp[idx + 1];
        let dx = x_h - x_l;
        if dx > 0.0 {
            // Standard linear interpolation between the bracketing samples.
            y_l + (x - x_l) / dx * (y_h - y_l)
        } else {
            // Degenerate (zero-width) interval: result is undefined.
            f64::NAN
        }
    }
}
57
58pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> {
59 let mut interpolator = LinearInterpolator { xp, fp, cache: 0 };
60 x.iter().map(|&x| interpolator.eval(x)).collect()
61}