// hanzo_quant/unquantized/mod.rs

1use std::{
2    borrow::Cow,
3    io::Cursor,
4    sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
9use hanzo_nn::Linear;
10
11use crate::{
12    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
13    generate_isq, generate_isq_imatrix,
14    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
15    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
16    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType,
17    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
18};
19
20#[derive(Debug)]
21pub struct UnquantLinear {
22    w: Tensor,
23    b: Option<Tensor>,
24    stats: Option<ImatrixLayerStats>,
25}
26
27impl QuantMethod for UnquantLinear {
28    fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
29    where
30        Self: Sized,
31    {
32        match method {
33            QuantMethodConfig::Gguf { .. }
34            | QuantMethodConfig::GptqAwq { .. }
35            | QuantMethodConfig::Hqq { .. }
36            | QuantMethodConfig::Dummy
37            | QuantMethodConfig::FP8 { .. }
38            | QuantMethodConfig::Bnb { .. }
39            | QuantMethodConfig::BlockwiseFP8 { .. }
40            | QuantMethodConfig::PerTensorFP8 { .. }
41            | QuantMethodConfig::Afq { .. }
42            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
43            QuantMethodConfig::Unquantized(l) => Ok(Self {
44                w: l.weight().clone(),
45                b: l.bias().cloned(),
46                stats: None,
47            }),
48        }
49    }
50
51    fn dequantize_w(&self) -> Result<Tensor> {
52        Ok(self.w.clone())
53    }
54
55    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
56        // Batch matrix multiplication
57        maybe_init_cublas_lt_wrapper(a.device().clone());
58
59        // Try custom GEMV for single-token decode (batch_size=1)
60        #[cfg(feature = "cuda")]
61        if crate::gemv::should_use_gemv(a, &self.w) {
62            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
63        }
64
65        let w = match *a.dims() {
66            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
67            [bsize, _, _] => self.w.broadcast_left(bsize)?,
68            _ => self.w.clone(),
69        };
70
71        if let Some(stats) = &self.stats {
72            stats.process(a)?;
73        }
74
75        if let Some(b) = self.b.as_ref() {
76            let mut tgt_shape = a.dims().to_vec();
77            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
78            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;
79
80            match a.device().location() {
81                DeviceLocation::Cuda { .. } => {
82                    // Try to use cublaslt, otherwise fallback to gemm
83                    if let (Device::Cuda(_), Some(cublaslt)) =
84                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
85                    {
86                        cublaslt
87                            .batch_matmul(
88                                a,
89                                &w,
90                                Some(&b.t()?.contiguous()?),
91                                None,
92                                Some(1.0),
93                                None,
94                                None,
95                            )?
96                            .t()
97                    } else {
98                        let matmul_result = a.matmul(&w.t()?)?;
99                        matmul_result.broadcast_add(&b)
100                    }
101                }
102                DeviceLocation::Metal { .. } => {
103                    let matmul_result = a.matmul(&w.t()?)?;
104                    matmul_result.broadcast_add(&b)
105                }
106                #[cfg(feature = "rocm")]
107                DeviceLocation::Rocm { .. } => {
108                    let matmul_result = a.matmul(&w.t()?)?;
109                    matmul_result.broadcast_add(&b)
110                }
111                #[cfg(feature = "vulkan")]
112                DeviceLocation::Vulkan { .. } => {
113                    let matmul_result = a.matmul(&w.t()?)?;
114                    matmul_result.broadcast_add(&b)
115                }
116                DeviceLocation::Cpu => {
117                    #[cfg(feature = "accelerate")]
118                    {
119                        let original_dtype = a.dtype();
120                        let a_f32 = a.to_dtype(DType::F32)?;
121                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
122                        let b_f32 = b.to_dtype(DType::F32)?;
123                        let matmul_result = a_f32.matmul(&w_f32)?;
124                        matmul_result
125                            .broadcast_add(&b_f32)?
126                            .to_dtype(original_dtype)
127                    }
128                    #[cfg(not(feature = "accelerate"))]
129                    {
130                        let matmul_result = a.matmul(&w.t()?)?;
131                        matmul_result.broadcast_add(&b)
132                    }
133                }
134            }
135        } else {
136            match a.device().location() {
137                DeviceLocation::Cuda { .. } => {
138                    if let (Device::Cuda(_), Some(cublaslt)) =
139                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
140                    {
141                        // cuBLAS batch_matmul requires 3D tensors, fall back to regular matmul for 2D.
142                        if a.rank() >= 3 && w.rank() >= 3 {
143                            cublaslt
144                                .batch_matmul(a, &w, None, None, None, None, None)?
145                                .t()
146                        } else {
147                            a.matmul(&w.t()?)
148                        }
149                    } else {
150                        a.matmul(&w.t()?)
151                    }
152                }
153                DeviceLocation::Metal { .. } => a.matmul(&w.t()?),
154                #[cfg(feature = "rocm")]
155                DeviceLocation::Rocm { .. } => a.matmul(&w.t()?),
156                #[cfg(feature = "vulkan")]
157                DeviceLocation::Vulkan { .. } => a.matmul(&w.t()?),
158                DeviceLocation::Cpu => {
159                    #[cfg(feature = "accelerate")]
160                    {
161                        let original_dtype = a.dtype();
162                        a.to_dtype(DType::F32)?
163                            .matmul(&w.t()?.to_dtype(DType::F32)?)?
164                            .to_dtype(original_dtype)
165                    }
166                    #[cfg(not(feature = "accelerate"))]
167                    {
168                        a.matmul(&w.t()?)
169                    }
170                }
171            }
172        }
173    }
174
175    fn gather_forward_raw(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
176        // Weights are [num_experts, out_features, in_features]
177        // For Metal path:
178        //   - a: (b_size, seq_len, 1, 1, hidden_dim) - 5D
179        //   - indices: (b_size, seq_len, num_experts_per_tok) - 3D
180        // For CUDA path:
181        //   - a: (num_tokens, 1, hidden_dim) - 3D
182        //   - indices: (num_tokens, num_experts_per_tok) - 2D
183
184        let w = &self.w;
185        let (_num_experts, out_features, _in_features) = w.dims3()?;
186
187        match a.dims() {
188            // Metal path: 5D input (b_size, seq_len, 1, 1, hidden_dim)
189            &[b_size, seq_len, 1, 1, hidden_dim] => {
190                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
191                // Flatten indices to select experts
192                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;
193
194                // Select expert weights: [b*s*k, out_features, in_features]
195                let selected_w = w.index_select(&flat_indices, 0)?;
196
197                // Reshape input: [b*s, hidden_dim]
198                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;
199
200                // For each token, we need to compute with each selected expert
201                // Broadcast a to match: [b*s, 1, hidden_dim] -> [b*s, k, hidden_dim]
202                let a_expanded = a_flat
203                    .unsqueeze(1)?
204                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
205                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;
206
207                // Matmul: [b*s*k, hidden_dim] @ [b*s*k, hidden_dim, out_features] -> [b*s*k, out_features]
208                let result = a_expanded
209                    .unsqueeze(1)?
210                    .matmul(&selected_w.transpose(1, 2)?)?
211                    .squeeze(1)?;
212
213                // Reshape back to [b, s, k, out_features]
214                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
215            }
216            // CUDA path: 3D input (num_tokens, 1, hidden_dim)
217            &[num_tokens, 1, hidden_dim] => {
218                let (_, num_experts_per_tok) = indices.dims2()?;
219
220                // Flatten indices
221                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;
222
223                // Select expert weights: [n*k, out_features, in_features]
224                let selected_w = w.index_select(&flat_indices, 0)?;
225
226                // Broadcast input: [n, 1, hidden] -> [n, k, hidden] -> [n*k, hidden]
227                let a_expanded = a
228                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
229                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;
230
231                // Matmul: [n*k, hidden] @ [n*k, hidden, out] -> [n*k, out]
232                let result = a_expanded
233                    .unsqueeze(1)?
234                    .matmul(&selected_w.transpose(1, 2)?)?
235                    .squeeze(1)?;
236
237                // Reshape to [n, k, out]
238                result.reshape((num_tokens, num_experts_per_tok, out_features))
239            }
240            &[num_tokens, num_experts_per_tok, hidden_dim] => {
241                let (indices_num_tokens, indices_num_experts_per_tok) = indices.dims2()?;
242                if num_tokens != indices_num_tokens
243                    || num_experts_per_tok != indices_num_experts_per_tok
244                {
245                    hanzo_ml::bail!(
246                        "UnquantLinear::gather_forward: input shape {:?} does not match indices shape {:?}",
247                        a.dims(),
248                        indices.dims()
249                    );
250                }
251
252                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;
253                let selected_w = w.index_select(&flat_indices, 0)?;
254                let a_flat = a.reshape((num_tokens * num_experts_per_tok, hidden_dim))?;
255
256                let result = a_flat
257                    .unsqueeze(1)?
258                    .matmul(&selected_w.transpose(1, 2)?)?
259                    .squeeze(1)?;
260
261                result.reshape((num_tokens, num_experts_per_tok, out_features))
262            }
263            dims => {
264                hanzo_ml::bail!(
265                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
266                    dims
267                );
268            }
269        }
270    }
271
272    fn quantized_act_type(&self) -> Option<DType> {
273        None
274    }
275
276    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
277        Ok(Arc::new(Self {
278            w: (&self.w + delta)?,
279            b: self.b.clone(),
280            stats: self.stats.clone(),
281        }))
282    }
283
284    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
285        (self.w.dtype(), self.w.device().clone())
286    }
287
288    fn apply_isq(
289        self: Arc<Self>,
290        dtype: Option<IsqType>,
291        device: Device,
292        n_quantized: &AtomicUsize,
293        imatrix_weight: Option<Vec<f32>>,
294        guard: QuantizeOntoGuard,
295    ) -> Result<Arc<dyn QuantMethod>> {
296        match dtype {
297            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
298            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
299                let _acquired_quantize_guard = guard.acquire(&device);
300                if imatrix_weight.is_some() {
301                    // TODO just warn?
302                    hanzo_ml::bail!("HQQ does not support imatrix.");
303                }
304
305                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
306                let bits = match dtype.unwrap() {
307                    IsqType::HQQ8 => HqqBits::Eight,
308                    IsqType::HQQ4 => HqqBits::Four,
309                    // IsqType::HQQ3 => HqqBits::Three,
310                    // IsqType::HQQ2 => HqqBits::Two,
311                    // IsqType::HQQ1 => HqqBits::One,
312                    _ => unreachable!(),
313                };
314                let cfg = HqqConfig {
315                    bits,
316                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
317                    axis: HqqAxis::Zero,
318                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
319                    round_zeros: false,
320                    channel_wise: true,
321                };
322                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
323                if let Some(bias) = &self.b {
324                    let bias = bias
325                        .to_device(&device)?
326                        .to_dtype(res.dtype_and_device().0)?;
327                    Ok(Arc::new(res.with_bias(bias)))
328                } else {
329                    Ok(Arc::new(res))
330                }
331            }
332            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
333                let _acquired_quantize_guard = guard.acquire(&device);
334                if imatrix_weight.is_some() {
335                    // TODO just warn?
336                    hanzo_ml::bail!("AFQ does not support imatrix.");
337                }
338
339                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
340                let bits = match dtype.unwrap() {
341                    IsqType::AFQ8 => AfqBits::Eight,
342                    IsqType::AFQ6 => AfqBits::Six,
343                    IsqType::AFQ4 => AfqBits::Four,
344                    IsqType::AFQ3 => AfqBits::Three,
345                    IsqType::AFQ2 => AfqBits::Two,
346                    _ => unreachable!(),
347                };
348
349                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
350                    weight: self.w.to_device(&device)?,
351                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
352                    bits,
353                    group_size: AfqGroupSize::default(),
354                })?))
355            }
356            Some(
357                IsqType::Q2K
358                | IsqType::Q3K
359                | IsqType::Q4K
360                | IsqType::Q4_0
361                | IsqType::Q4_1
362                | IsqType::Q5K
363                | IsqType::Q5_0
364                | IsqType::Q5_1
365                | IsqType::Q6K
366                | IsqType::Q8K
367                | IsqType::Q8_0
368                | IsqType::Q8_1,
369            ) => {
370                let dtype: GgmlDType = dtype.unwrap().try_into()?;
371                let res = if let Some(imatrix_weight) = imatrix_weight {
372                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
373                } else {
374                    generate_isq!(self.w, device, dtype, n_quantized, guard)
375                };
376                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
377                    q_weight: res,
378                    b: self
379                        .b
380                        .as_ref()
381                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
382                })?))
383            }
384            Some(IsqType::F8E4M3) => {
385                let _acquired_quantize_guard = guard.acquire(&device);
386                if imatrix_weight.is_some() {
387                    // TODO just warn?
388                    hanzo_ml::bail!("F8E4M3 does not support imatrix.");
389                }
390
391                let w = self.w.to_device(&device)?;
392                let b = if let Some(b) = &self.b {
393                    Some(b.to_device(&device)?)
394                } else {
395                    None
396                };
397                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
398                    lin: Linear::new(w, b),
399                    dtype: DType::F8E4M3,
400                })?))
401            }
402            Some(IsqType::MXFP4) => {
403                let _acquired_quantize_guard = guard.acquire(&device);
404                if imatrix_weight.is_some() {
405                    hanzo_ml::bail!("MXFP4 does not support imatrix.");
406                }
407
408                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
409                let w = self.w.to_device(&device)?;
410                let b = self.b.as_ref().map(|b| b.to_device(&device)).transpose()?;
411                crate::MXFP4Layer::quantize(&w, b, &device)
412            }
413            Some(IsqType::F8Q8) => {
414                let _acquired_quantize_guard = guard.acquire(&device);
415                if imatrix_weight.is_some() {
416                    hanzo_ml::bail!("F8Q8 does not support imatrix.");
417                }
418
419                let w = self.w.to_device(&device)?;
420                let b = if let Some(b) = &self.b {
421                    Some(b.to_device(&device)?)
422                } else {
423                    None
424                };
425                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
426            }
427            None => {
428                let _acquired_quantize_guard = guard.acquire(&device);
429                // Ignore imatrix altogether
430
431                let w = self.w.to_device(&device)?;
432                let b = if let Some(b) = &self.b {
433                    Some(b.to_device(&device)?)
434                } else {
435                    None
436                };
437                Ok(Arc::new(UnquantLinear::new(
438                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
439                )?))
440            }
441        }
442    }
443
444    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
445        Some((self.w.clone(), self.b.clone()))
446    }
447
448    fn begin_track_stats(&mut self) -> Result<()> {
449        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
450        Ok(())
451    }
452
453    fn end_track_stats(&self) -> Result<Tensor> {
454        if let Some(stats) = &self.stats {
455            let imatrix = stats.compute_imatrix()?;
456            stats.clear()?;
457            Ok(imatrix)
458        } else {
459            hanzo_ml::bail!("`{}` does not support tracking stats.", self.name())
460        }
461    }
462}
463
// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8
// -----------------------
// Whether bias data is included, u8 boolean (0 = no bias, nonzero = bias follows)
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL, only when the bias flag is set] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
477
478impl QuantizedSerde for UnquantLinear {
479    fn isq_serde_supported(&self) -> bool {
480        true
481    }
482    fn name(&self) -> &'static str {
483        "unquant-linear"
484    }
485    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
486        self.serialize_with_bias(self.b.clone())
487    }
488    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
489        let mut buffer = Vec::new();
490
491        // Version is always first!
492
493        buffer.extend(&UQFF_VERSION.to_le_bytes());
494
495        // ISQ type for unquant is 1
496        buffer.push(QuantizedSerdeType::Unquant as u8);
497
498        // Has bias
499        buffer.push(bias.is_some() as u8);
500
501        // Weight
502        serialize_tensor(&mut buffer, &self.w)?;
503
504        if let Some(bias) = &bias {
505            // Bias
506            serialize_tensor(&mut buffer, bias)?;
507        }
508
509        Ok(Cow::from(buffer))
510    }
511
512    fn deserialize(
513        data: Cow<[u8]>,
514        device: &Device,
515        _comm: &Arc<crate::Comm>,
516        guard: QuantizeOntoGuard,
517    ) -> Result<Arc<dyn QuantMethod>>
518    where
519        Self: Sized,
520    {
521        let mut buffer = Cursor::new(data);
522
523        let version = buffer.read_u32::<LittleEndian>()?;
524        if let Err(e) = version_is_compatible(version) {
525            return Err(hanzo_ml::Error::wrap(e));
526        }
527
528        let isq_type = buffer.read_u8()? as usize;
529        if isq_type != QuantizedSerdeType::Unquant as usize {
530            hanzo_ml::bail!(
531                "ISQ type ({isq_type}) doesn't match expected type {}",
532                QuantizedSerdeType::Unquant as usize
533            );
534        }
535
536        let has_bias = buffer.read_u8()? != 0;
537
538        let _acquired_load_guard = guard.acquire(device);
539        let w = deserialize_tensor(&mut buffer, device)?;
540
541        let b = if has_bias {
542            Some(deserialize_tensor(&mut buffer, device)?)
543        } else {
544            None
545        };
546
547        Ok(Arc::new(Self { w, b, stats: None }))
548    }
549    fn deserialize_ext_bias(
550        data: Cow<[u8]>,
551        device: &Device,
552        guard: QuantizeOntoGuard,
553    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
554    where
555        Self: Sized,
556    {
557        let mut buffer = Cursor::new(data);
558
559        let version = buffer.read_u32::<LittleEndian>()?;
560        if let Err(e) = version_is_compatible(version) {
561            return Err(hanzo_ml::Error::wrap(e));
562        }
563
564        let isq_type = buffer.read_u8()? as usize;
565        if isq_type != QuantizedSerdeType::Unquant as usize {
566            hanzo_ml::bail!(
567                "ISQ type ({isq_type}) doesn't match expected type {}",
568                QuantizedSerdeType::Unquant as usize
569            );
570        }
571
572        let has_bias = buffer.read_u8()? != 0;
573
574        let _acquired_load_guard = guard.acquire(device);
575        let w = deserialize_tensor(&mut buffer, device)?;
576
577        let b = if has_bias {
578            Some(deserialize_tensor(&mut buffer, device)?)
579        } else {
580            None
581        };
582
583        Ok((
584            Arc::new(Self {
585                w,
586                b: None,
587                stats: None,
588            }),
589            b,
590        ))
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Two experts of shape `[out_features = 2, in_features = 3]`:
    /// expert 0 = [[1, 0, 0], [0, 1, 0]], expert 1 = [[0, 0, 1], [1, 1, 1]].
    fn test_layer(device: &Device) -> Result<UnquantLinear> {
        let weight = Tensor::from_vec(
            vec![1f32, 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1.],
            (2, 2, 3),
            device,
        )?;
        <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(Linear::new(
            weight, None,
        )))
    }

    /// Dense layer with weight [[1, 0, 0], [0, 1, 1]] and bias [10, 20].
    fn biased_dense_layer(device: &Device) -> Result<UnquantLinear> {
        let weight = Tensor::from_vec(vec![1f32, 0., 0., 0., 1., 1.], (2, 3), device)?;
        let bias = Tensor::from_vec(vec![10f32, 20.], (2,), device)?;
        <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(Linear::new(
            weight,
            Some(bias),
        )))
    }

    #[test]
    fn gather_forward_expands_single_route_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = test_layer(&device)?;
        let input = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (2, 1, 3), &device)?;
        let indices = Tensor::from_vec(vec![0u32, 1, 1, 0], (2, 2), &device)?;

        let output = layer.gather_forward(&input, &indices)?;

        assert_eq!(output.dims(), &[2, 2, 2]);
        assert_eq!(
            output.flatten_all()?.to_vec1::<f32>()?,
            &[1., 2., 3., 6., 6., 15., 4., 5.]
        );
        Ok(())
    }

    #[test]
    fn gather_forward_accepts_per_route_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = test_layer(&device)?;
        let input = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (1, 2, 3), &device)?;
        let indices = Tensor::from_vec(vec![0u32, 1], (1, 2), &device)?;

        let output = layer.gather_forward(&input, &indices)?;

        assert_eq!(output.dims(), &[1, 2, 2]);
        assert_eq!(output.flatten_all()?.to_vec1::<f32>()?, &[1., 2., 6., 15.]);
        Ok(())
    }

    #[test]
    fn gather_forward_accepts_five_dim_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = test_layer(&device)?;
        let input = Tensor::from_vec(vec![1f32, 2., 3.], (1, 1, 1, 1, 3), &device)?;
        let indices = Tensor::from_vec(vec![0u32, 1], (1, 1, 2), &device)?;

        let output = layer.gather_forward(&input, &indices)?;

        assert_eq!(output.dims(), &[1, 1, 2, 2]);
        assert_eq!(output.flatten_all()?.to_vec1::<f32>()?, &[1., 2., 3., 6.]);
        Ok(())
    }

    #[test]
    fn gather_forward_rejects_mismatched_indices() -> Result<()> {
        let device = Device::Cpu;
        let layer = test_layer(&device)?;
        let input = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (1, 2, 3), &device)?;
        let indices = Tensor::from_vec(vec![0u32, 1, 0], (1, 3), &device)?;

        assert!(layer.gather_forward(&input, &indices).is_err());
        Ok(())
    }

    #[test]
    fn forward_adds_bias_for_rank2_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = biased_dense_layer(&device)?;
        let input = Tensor::from_vec(vec![1f32, 2., 3.], (1, 3), &device)?;

        let output = layer.forward_raw(&input)?;

        assert_eq!(output.dims(), &[1, 2]);
        assert_eq!(output.flatten_all()?.to_vec1::<f32>()?, &[11., 25.]);
        Ok(())
    }

    #[test]
    fn forward_adds_bias_for_rank3_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = biased_dense_layer(&device)?;
        let input = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (2, 1, 3), &device)?;

        let output = layer.forward_raw(&input)?;

        assert_eq!(output.dims(), &[2, 1, 2]);
        assert_eq!(
            output.flatten_all()?.to_vec1::<f32>()?,
            &[11., 25., 14., 31.]
        );
        Ok(())
    }
}