Skip to main content

rlx_oneapi/
backend.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// SPDX-License-Identifier: GPL-3.0-only
5
6//! `OneApiExecutable` — compile an IR graph for the Intel oneAPI Level Zero
7//! backend and execute it.
8//!
9//! Two execution paths share one legalized graph (the rlx-vulkan primitive set,
10//! so the same rewrite/legalize decompositions apply):
11//!
12//! - [`run_host`](OneApiExecutable::run_host) — a value-map interpreter that
13//!   evaluates every node through the `rlx-cpu` reference. This is the path the
14//!   macOS dev box / CI take (no Level Zero device), and it makes the backend
15//!   fully correct without Intel hardware.
16//! - [`run_l0`](OneApiExecutable::run_l0) — the native path: a USM-shared f32
17//!   arena + per-op SPIR-V kernel dispatch (with a CPU host-fallback, against
18//!   the same arena, for ops with no native kernel yet). Selected only when a
19//!   live device *and* embedded kernels are both present — neither is true off
20//!   an Intel build host, so it is compiled-but-dormant here, pending hardware
21//!   validation on Arc / Data Center Max.
22
23use crate::device::oneapi_device;
24use crate::host::{self, HostBuf};
25use crate::kernels::kernels;
26use rlx_compile::memory::{BufferSlot, MemoryPlan};
27use rlx_ir::op::Activation;
28use rlx_ir::{DType, Dim, Graph, NodeId, Op, RngOptions, Shape};
29use std::collections::HashMap;
30use std::ffi::c_void;
31
32/// OpKinds this backend lowers (claim set). Identical to rlx-vulkan's: the
33/// rewrite pass decomposes everything else into this primitive set, and the
34/// CPU reference covers every entry, so a legalized graph always executes.
35pub const SUPPORTED_OPS: &[rlx_ir::OpKind] = {
36    use rlx_ir::OpKind::*;
37    &[
38        Input,
39        Param,
40        Constant,
41        Cast,
42        StopGradient,
43        Reshape, // structural / alias
44        Binary,
45        Compare,
46        Where,
47        Activation, // elementwise
48        MatMul,
49        Reduce,
50        Softmax, // contraction / reduction
51        LayerNorm,
52        RmsNorm,
53        LayerNorm2d, // normalization
54        Rope,
55        Attention, // transformer
56        // Claimed first-class; `compile_rng` runs `unfuse_attention_block`
57        // to lower it to the primitive chain above before legalization.
58        FusedAttentionBlock,
59        Transpose,
60        Narrow,
61        Concat,
62        Expand,
63        Gather,
64        Cumsum,
65        Reverse, // shape / indexing
66        ArgMax,
67        ArgMin,
68        Pool,
69        ResizeNearest2x,
70        Conv,          // reductions / vision
71        GroupedMatMul, // MoE
72        SelectiveScan, // SSM / Mamba
73        Im2Col,
74        ScatterAdd,
75        TopK, // vision / indexing / generation
76        Lstm,
77        Gru,
78        Rnn,
79        Mamba2,
80        GatedDeltaNet,
81        ConvTranspose2d,
82        Fft,
83        DequantMatMul,
84        DequantGroupedMatMul,
85        DequantMoEWeights, // GGUF quant
86        RngNormal,
87        RngUniform,
88        Sample, // RNG / generation
89    ]
90};
91
92/// Ops with a native OpenCL-C SPIR-V kernel under `kernels/`. Everything else
93/// routes to the CPU host-fallback on the native path. The set grows as kernels
94/// land (next: layernorm, rope, gather, reduce, attention, then oneMKL gemm).
95fn native_kernel(op: &Op) -> Option<&'static str> {
96    match op {
97        Op::Binary(_) => Some("binary"),
98        Op::Activation(_) => Some("unary"),
99        Op::MatMul => Some("matmul"),
100        Op::Softmax { .. } => Some("softmax"),
101        Op::RmsNorm { .. } => Some("rmsnorm"),
102        _ => None,
103    }
104}
105
/// Host-side copy of one named parameter, kept until a run uploads it.
#[derive(Clone)]
enum ParamVal {
    /// Float data supplied through `set_param`.
    F32(Vec<f32>),
    /// Raw bytes supplied through `set_param_bytes`; passed to ops as a byte
    /// buffer rather than widened to f32.
    Bytes(Vec<u8>),
}
111
/// A graph compiled for the oneAPI backend, together with the host-side state
/// needed to run it (parameters, output bookkeeping, options).
pub struct OneApiExecutable {
    /// Post-legalize, f32-uniform graph.
    graph: Graph,
    /// Parameter values keyed by `Op::Param` name; a param with no entry is
    /// treated as zeros on the host path and left unwritten on the native path.
    params: HashMap<String, ParamVal>,
    /// Graph outputs, captured at compile time.
    output_ids: Vec<NodeId>,
    /// Dtype of each output node, parallel to `output_ids`.
    output_dtypes: Vec<DType>,
    /// RNG options; stored and returned by the accessors, but not read by
    /// either execution path in this file.
    rng: RngOptions,
    /// Caller-supplied extent; stored only — not read by either execution path
    /// in this file.
    active_extent: Option<(usize, usize)>,
}
121
// SAFETY: NOTE(review) — the struct holds only owned host data (graph, param
// map, id/dtype vectors, options); no device handle is stored across calls.
// Soundness assumes `Graph` and `RngOptions` carry no thread-affine state —
// confirm against their definitions.
unsafe impl Send for OneApiExecutable {}
123
124impl OneApiExecutable {
125    pub fn compile(graph: Graph) -> Self {
126        Self::compile_rng(graph, RngOptions::default())
127    }
128
129    /// Legalize the graph to the native primitive set, then capture I/O maps.
130    pub fn compile_rng(graph: Graph, rng: RngOptions) -> Self {
131        use rlx_opt::pass::Pass as _;
132
133        let graph = rlx_opt::LowerControlFlow.run(graph);
134        // Decompose `FusedAttentionBlock` (claimed, but no monolithic
135        // kernel) to primitives before legalization. FAB-only; no-op when
136        // absent.
137        let graph = rlx_opt::unfuse::unfuse_attention_block(graph);
138        let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, SUPPORTED_OPS)
139            .unwrap_or_else(|errs| panic!("{}", rlx_opt::format_legalize_error("oneapi", &errs)));
140        let graph = rlx_opt::LegalizeBroadcast.run(graph);
141
142        let output_ids = graph.outputs.clone();
143        let output_dtypes = output_ids
144            .iter()
145            .map(|&id| graph.node(id).shape.dtype())
146            .collect();
147
148        Self {
149            graph,
150            params: HashMap::new(),
151            output_ids,
152            output_dtypes,
153            rng,
154            active_extent: None,
155        }
156    }
157
158    pub fn set_param(&mut self, name: &str, data: &[f32]) {
159        self.params
160            .insert(name.to_string(), ParamVal::F32(data.to_vec()));
161    }
162
163    pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
164        self.params
165            .insert(name.to_string(), ParamVal::Bytes(data.to_vec()));
166    }
167
168    pub fn output_dtypes(&self) -> Vec<DType> {
169        self.output_dtypes.clone()
170    }
171
    /// Record the caller's active extent (`None` clears it). The value is only
    /// stored; neither execution path in this file reads it.
    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
        self.active_extent = extent;
    }
175
    /// Replace the stored RNG options.
    pub fn set_rng(&mut self, rng: RngOptions) {
        self.rng = rng;
    }
179
    /// Currently stored RNG options (returned by value).
    pub fn rng(&self) -> RngOptions {
        self.rng
    }
183
184    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
185        self.run_read_outputs(inputs, None)
186    }
187
188    pub fn run_read_outputs(
189        &mut self,
190        inputs: &[(&str, &[f32])],
191        read_indices: Option<&[usize]>,
192    ) -> Vec<Vec<f32>> {
193        // Native dispatch only when a live device AND embedded kernels exist;
194        // otherwise the CPU-reference interpreter (the dev-box / CI path).
195        if oneapi_device().is_some() && kernels().is_some() {
196            self.run_l0(inputs, read_indices)
197        } else {
198            self.run_host(inputs, read_indices)
199        }
200    }
201
202    // ── dev-box path: whole-graph CPU reference interpreter ────────────────
203
204    fn run_host(&self, inputs: &[(&str, &[f32])], read_indices: Option<&[usize]>) -> Vec<Vec<f32>> {
205        let in_map: HashMap<&str, &[f32]> = inputs.iter().copied().collect();
206        let mut f32v: HashMap<NodeId, Vec<f32>> = HashMap::new();
207        let mut bytev: HashMap<NodeId, Vec<u8>> = HashMap::new();
208
209        for node in self.graph.nodes() {
210            let numel = node.shape.num_elements().unwrap_or(0);
211            match &node.op {
212                Op::Input { name } => {
213                    let v = in_map
214                        .get(name.as_str())
215                        .map(|s| s.to_vec())
216                        .unwrap_or_else(|| vec![0.0; numel]);
217                    f32v.insert(node.id, v);
218                }
219                Op::Param { name } => match self.params.get(name) {
220                    Some(ParamVal::F32(v)) => {
221                        f32v.insert(node.id, v.clone());
222                    }
223                    Some(ParamVal::Bytes(b)) => {
224                        bytev.insert(node.id, b.clone());
225                    }
226                    None => {
227                        f32v.insert(node.id, vec![0.0; numel]);
228                    }
229                },
230                Op::Constant { data } => {
231                    if matches!(node.shape.dtype(), DType::U8 | DType::I8) {
232                        bytev.insert(node.id, data.clone());
233                    } else {
234                        f32v.insert(node.id, widen_const_to_f32(data, node.shape.dtype()));
235                    }
236                }
237                _ => {
238                    let in_specs: Vec<(Shape, HostBuf)> = node
239                        .inputs
240                        .iter()
241                        .map(|&id| {
242                            let sh = self.graph.node(id).shape.clone();
243                            let buf = if let Some(b) = bytev.get(&id) {
244                                HostBuf::Bytes(b.clone())
245                            } else {
246                                HostBuf::F32(f32v.get(&id).cloned().unwrap_or_default())
247                            };
248                            (sh, buf)
249                        })
250                        .collect();
251                    let out = host::eval(&node.op, &node.shape, &in_specs);
252                    f32v.insert(node.id, out);
253                }
254            }
255        }
256
257        self.read_outputs(read_indices, |id, n| {
258            f32v.get(&id)
259                .map(|v| v[..n.min(v.len())].to_vec())
260                .unwrap_or_else(|| vec![0.0; n])
261        })
262    }
263
264    // ── native path: USM arena + per-op SPIR-V dispatch (HW-pending) ───────
265
266    fn run_l0(
267        &mut self,
268        inputs: &[(&str, &[f32])],
269        read_indices: Option<&[usize]>,
270    ) -> Vec<Vec<f32>> {
271        let dev = oneapi_device().expect("rlx-oneapi: no device");
272        let kerns = kernels().expect("rlx-oneapi: no kernels");
273
274        let plan = plan_f32_uniform(&self.graph, 64);
275        let arena = match crate::arena::Arena::from_plan(&plan) {
276            Ok(a) => a,
277            // Allocation failed on the device — fall back to the CPU path so we
278            // still return correct results rather than panic.
279            Err(_) => return self.run_host(inputs, read_indices),
280        };
281
282        // Upload constants, params, inputs into the USM arena.
283        for node in self.graph.nodes() {
284            match &node.op {
285                Op::Constant { data } if arena.has(node.id) && !data.is_empty() => {
286                    if matches!(node.shape.dtype(), DType::U8 | DType::I8) {
287                        arena.write_bytes(node.id, data);
288                    } else {
289                        arena.write_f32(node.id, &widen_const_to_f32(data, node.shape.dtype()));
290                    }
291                }
292                Op::Param { name } => match self.params.get(name) {
293                    Some(ParamVal::F32(v)) => arena.write_f32(node.id, v),
294                    Some(ParamVal::Bytes(b)) => arena.write_bytes(node.id, b),
295                    None => {}
296                },
297                _ => {}
298            }
299        }
300        let in_map: HashMap<&str, &[f32]> = inputs.iter().copied().collect();
301        for node in self.graph.nodes() {
302            if let Op::Input { name } = &node.op {
303                if let Some(data) = in_map.get(name.as_str()) {
304                    arena.write_f32(node.id, data);
305                }
306            }
307        }
308
309        // Execute node-by-node: native kernel where available, else CPU
310        // host-fallback against the (host-coherent) USM arena.
311        let list = dev.create_command_list().expect("rlx-oneapi: command list");
312        for node in self.graph.nodes() {
313            if matches!(
314                node.op,
315                Op::Input { .. }
316                    | Op::Param { .. }
317                    | Op::Constant { .. }
318                    | Op::Reshape { .. }
319                    | Op::Cast { .. }
320                    | Op::StopGradient
321            ) {
322                continue;
323            }
324            match native_kernel(&node.op) {
325                Some(name) => self.dispatch(dev, kerns, list, name, node, &arena),
326                None => {
327                    // Read inputs out of the arena, eval on CPU, write back.
328                    let in_specs: Vec<(Shape, HostBuf)> = node
329                        .inputs
330                        .iter()
331                        .map(|&id| {
332                            let sh = self.graph.node(id).shape.clone();
333                            let nn = sh.num_elements().unwrap_or(0);
334                            let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
335                                HostBuf::Bytes(arena.read_bytes(id, nn))
336                            } else {
337                                HostBuf::F32(arena.read_f32(id, nn))
338                            };
339                            (sh, buf)
340                        })
341                        .collect();
342                    let out = host::eval(&node.op, &node.shape, &in_specs);
343                    arena.write_f32(node.id, &out);
344                }
345            }
346        }
347        dev.execute_sync(list).expect("rlx-oneapi: execute");
348        unsafe {
349            let _ = (dev.lib.command_list_destroy)(list);
350        }
351
352        self.read_outputs(read_indices, |id, n| arena.read_f32(id, n))
353    }
354
355    /// Set kernel arguments (arg 0 = arena base pointer, then scalars) and
356    /// append a launch onto `list`. Arg layouts match `kernels/<name>.cl`.
357    fn dispatch(
358        &self,
359        dev: &crate::device::OneApiDevice,
360        kerns: &crate::kernels::Kernels,
361        list: crate::level_zero::CommandListHandle,
362        name: &str,
363        node: &rlx_ir::Node,
364        arena: &crate::arena::Arena,
365    ) {
366        let Some(kernel) = kerns.get(name) else {
367            return;
368        };
369        let off = |id: NodeId| arena.elem_offset(id);
370        let out = node.id;
371        let mut args: Vec<KArg> = vec![KArg::Ptr(arena.base_ptr())];
372        let (global, local): (usize, u32) = match &node.op {
373            Op::Binary(op) => {
374                let a = node.inputs[0];
375                let b = node.inputs[1];
376                let n = numel(&dims(&self.graph, out));
377                let an = numel(&dims(&self.graph, a));
378                let bn = numel(&dims(&self.graph, b));
379                args.extend([
380                    KArg::U32(n as u32),
381                    KArg::U32(off(a)),
382                    KArg::U32(off(b)),
383                    KArg::U32(off(out)),
384                    KArg::U32(if an == n { 0 } else { an as u32 }),
385                    KArg::U32(if bn == n { 0 } else { bn as u32 }),
386                    KArg::U32(binop_id(*op)),
387                ]);
388                (n, 256)
389            }
390            Op::Activation(act) => {
391                let x = node.inputs[0];
392                let n = numel(&dims(&self.graph, out));
393                args.extend([
394                    KArg::U32(n as u32),
395                    KArg::U32(off(x)),
396                    KArg::U32(off(out)),
397                    KArg::U32(act_id(*act)),
398                ]);
399                (n, 256)
400            }
401            Op::MatMul => {
402                let a = node.inputs[0];
403                let b = node.inputs[1];
404                let ad = dims(&self.graph, a);
405                let bd = dims(&self.graph, b);
406                let od = dims(&self.graph, out);
407                let (m, k) = (ad[ad.len() - 2], ad[ad.len() - 1]);
408                let n = bd[bd.len() - 1];
409                let batch = if od.len() > 2 {
410                    numel(&od[..od.len() - 2])
411                } else {
412                    1
413                };
414                let a_batch = if ad.len() > 2 {
415                    numel(&ad[..ad.len() - 2])
416                } else {
417                    1
418                };
419                let b_batch = if bd.len() > 2 {
420                    numel(&bd[..bd.len() - 2])
421                } else {
422                    1
423                };
424                let a_bs = if a_batch <= 1 { 0 } else { m * k };
425                let b_bs = if b_batch <= 1 { 0 } else { k * n };
426                args.extend([
427                    KArg::U32(m as u32),
428                    KArg::U32(k as u32),
429                    KArg::U32(n as u32),
430                    KArg::U32(off(a)),
431                    KArg::U32(off(b)),
432                    KArg::U32(off(out)),
433                    KArg::U32(batch as u32),
434                    KArg::U32(a_bs as u32),
435                    KArg::U32(b_bs as u32),
436                    KArg::U32((m * n) as u32),
437                ]);
438                (batch.max(1) * m * n, 64)
439            }
440            Op::Softmax { axis } => {
441                let x = node.inputs[0];
442                let xd = dims(&self.graph, x);
443                let ax = norm_axis(*axis, xd.len());
444                let axis_len = xd[ax];
445                let outer = numel(&xd[..ax]);
446                let inner = numel(&xd[ax + 1..]);
447                args.extend([
448                    KArg::U32(outer as u32),
449                    KArg::U32(axis_len as u32),
450                    KArg::U32(inner as u32),
451                    KArg::U32(off(x)),
452                    KArg::U32(off(out)),
453                ]);
454                (outer * inner, 256)
455            }
456            Op::RmsNorm { axis, eps } => {
457                let x = node.inputs[0];
458                let gamma = node.inputs[1];
459                let beta = node.inputs[2];
460                let xd = dims(&self.graph, x);
461                let ax = norm_axis(*axis, xd.len());
462                let n = xd[ax];
463                let rows = numel(&xd) / n.max(1);
464                args.extend([
465                    KArg::U32(rows as u32),
466                    KArg::U32(n as u32),
467                    KArg::U32(off(x)),
468                    KArg::U32(off(gamma)),
469                    KArg::U32(off(beta)),
470                    KArg::U32(off(out)),
471                    KArg::F32(*eps),
472                ]);
473                (rows, 64)
474            }
475            _ => return,
476        };
477
478        unsafe {
479            let _ = (dev.lib.kernel_set_group_size)(kernel, local, 1, 1);
480            for (i, a) in args.iter().enumerate() {
481                let (size, ptr) = a.as_arg();
482                let _ = (dev.lib.kernel_set_argument_value)(kernel, i as u32, size, ptr);
483            }
484            let groups = crate::level_zero::GroupCount {
485                group_count_x: ceil_div(global, local).max(1),
486                group_count_y: 1,
487                group_count_z: 1,
488            };
489            let _ = (dev.lib.command_list_append_launch_kernel)(
490                list,
491                kernel,
492                &groups,
493                std::ptr::null_mut(),
494                0,
495                std::ptr::null_mut(),
496            );
497            // Each kernel reads/writes the shared arena; barrier between launches.
498            let _ = (dev.lib.command_list_append_barrier)(
499                list,
500                std::ptr::null_mut(),
501                0,
502                std::ptr::null_mut(),
503            );
504        }
505    }
506
507    fn read_outputs(
508        &self,
509        read_indices: Option<&[usize]>,
510        mut read: impl FnMut(NodeId, usize) -> Vec<f32>,
511    ) -> Vec<Vec<f32>> {
512        let want: Vec<usize> = match read_indices {
513            Some(ix) => ix.to_vec(),
514            None => (0..self.output_ids.len()).collect(),
515        };
516        want.into_iter()
517            .filter_map(|i| {
518                let id = *self.output_ids.get(i)?;
519                let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
520                Some(read(id, n))
521            })
522            .collect()
523    }
524
525    /// Deep copy for the runtime's executable cache: fresh state with the same
526    /// legalized graph + uploaded params.
527    pub fn clone_for_cache(&self) -> Self {
528        Self {
529            graph: self.graph.clone(),
530            params: self.params.clone(),
531            output_ids: self.output_ids.clone(),
532            output_dtypes: self.output_dtypes.clone(),
533            rng: self.rng,
534            active_extent: self.active_extent,
535        }
536    }
537}
538
539// ── kernel-argument helper ─────────────────────────────────────────────────
540
/// One kernel argument: the arena base pointer or a 4-byte scalar.
enum KArg {
    Ptr(*mut c_void),
    U32(u32),
    F32(f32),
}

impl KArg {
    /// `(argSize, pArgValue)` for `zeKernelSetArgumentValue`.
    ///
    /// The returned pointer addresses the payload stored inside `self`, so it
    /// must be consumed before `self` drops — callers use it immediately
    /// inside the set-arg loop.
    fn as_arg(&self) -> (usize, *const c_void) {
        use std::mem::size_of;
        match self {
            KArg::Ptr(handle) => {
                let addr = (handle as *const *mut c_void).cast::<c_void>();
                (size_of::<*mut c_void>(), addr)
            }
            KArg::U32(word) => (size_of::<u32>(), (word as *const u32).cast::<c_void>()),
            KArg::F32(real) => (size_of::<f32>(), (real as *const f32).cast::<c_void>()),
        }
    }
}
562
563// ── memory plan (f32-uniform bump allocator; same as rlx-vulkan) ───────────
564
565fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
566    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
567    let mut schedule = Vec::with_capacity(graph.nodes().len());
568    let mut cursor = 0usize;
569    for node in graph.nodes() {
570        if matches!(
571            node.op,
572            Op::Reshape { .. } | Op::Cast { .. } | Op::StopGradient
573        ) {
574            if let Some(in_id) = node.inputs.first() {
575                if let Some(slot) = assignments.get(in_id) {
576                    let aliased = slot.clone();
577                    assignments.insert(node.id, aliased);
578                    schedule.push(node.id);
579                    continue;
580                }
581            }
582        }
583        let elems = node.shape.num_elements().unwrap_or(0);
584        let bytes = (elems * 4).max(4);
585        let aligned = bytes.div_ceil(align) * align;
586        assignments.insert(
587            node.id,
588            BufferSlot {
589                offset: cursor,
590                size: aligned,
591            },
592        );
593        schedule.push(node.id);
594        cursor += aligned;
595    }
596    MemoryPlan {
597        arena_size: cursor.max(align),
598        assignments,
599        schedule,
600    }
601}
602
603// ── small shape helpers (shared with the dispatch builder) ─────────────────
604
605fn dims(graph: &Graph, id: NodeId) -> Vec<usize> {
606    graph
607        .node(id)
608        .shape
609        .dims()
610        .iter()
611        .map(|d| match d {
612            Dim::Static(s) => *s,
613            _ => 0,
614        })
615        .collect()
616}
617
/// Element count of a shape: the product of its dims. An empty (scalar) shape
/// has one element; any zero dim makes the count zero.
fn numel(d: &[usize]) -> usize {
    d.iter().fold(1usize, |count, &extent| count * extent)
}
623
/// Resolve a possibly negative axis against `rank`.
///
/// Negative axes count from the end and clamp at 0; non-negative axes clamp at
/// `rank - 1` (0 when `rank` is 0).
fn norm_axis(axis: i32, rank: usize) -> usize {
    match axis {
        from_end if from_end < 0 => (rank as i32 + from_end).max(0) as usize,
        forward => {
            let last = rank.saturating_sub(1);
            (forward as usize).min(last)
        }
    }
}
631
/// `n / d` rounded up, computed in 64 bits and then narrowed to `u32`
/// (a quotient above `u32::MAX` is truncated). Panics when `d` is 0.
fn ceil_div(n: usize, d: u32) -> u32 {
    let (numerator, divisor) = (n as u64, u64::from(d));
    let whole = numerator / divisor;
    let rounded = if numerator % divisor != 0 { whole + 1 } else { whole };
    rounded as u32
}
635
636fn act_id(a: Activation) -> u32 {
637    match a {
638        Activation::Gelu => 0,
639        Activation::GeluApprox => 1,
640        Activation::Silu => 2,
641        Activation::Relu => 3,
642        Activation::Sigmoid => 4,
643        Activation::Tanh => 5,
644        Activation::Exp => 6,
645        Activation::Log => 7,
646        Activation::Sqrt => 8,
647        Activation::Rsqrt => 9,
648        Activation::Neg => 10,
649        Activation::Abs => 11,
650        Activation::Sin => 12,
651        Activation::Cos => 13,
652        Activation::Tan => 14,
653        Activation::Atan => 15,
654        Activation::Round => 16,
655    }
656}
657
658fn binop_id(op: rlx_ir::op::BinaryOp) -> u32 {
659    use rlx_ir::op::BinaryOp::*;
660    match op {
661        Add => 0,
662        Sub => 1,
663        Mul => 2,
664        Div => 3,
665        Max => 4,
666        Min => 5,
667        Pow => 6,
668    }
669}
670
671/// Widen a constant byte blob (any IR dtype) to f32 for the f32-uniform arena.
672fn widen_const_to_f32(data: &[u8], dt: DType) -> Vec<f32> {
673    match dt {
674        DType::F32 => data
675            .chunks_exact(4)
676            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
677            .collect(),
678        DType::F16 => data
679            .chunks_exact(2)
680            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
681            .collect(),
682        DType::BF16 => data
683            .chunks_exact(2)
684            .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
685            .collect(),
686        DType::F64 => data
687            .chunks_exact(8)
688            .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
689            .collect(),
690        DType::I64 => data
691            .chunks_exact(8)
692            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
693            .collect(),
694        DType::I32 | DType::U32 => data
695            .chunks_exact(4)
696            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
697            .collect(),
698        DType::I16 => data
699            .chunks_exact(2)
700            .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32)
701            .collect(),
702        DType::I8 => data.iter().map(|&b| b as i8 as f32).collect(),
703        DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
704        DType::C64 => data
705            .chunks_exact(4)
706            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
707            .collect(),
708    }
709}