Skip to main content

burn_mpsgraph/ops/
float_tensor.rs

1use burn_backend::ops::FloatTensorOps;
2use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_backend::{DType, Distribution, ExecutionError, Scalar, TensorData};
4use burn_std::{FloatDType, Shape, Slice};
5use std::future::Future;
6
7use crate::bridge::{self, burn_to_mps_dtype};
8use crate::ffi::{self};
9use crate::{MpsGraph, MpsGraphTensor};
10
11// ── Macros for common op patterns ──────────────────────────────────────────
12
/// Generates a unary float-op method named `$fn_name`.
///
/// The generated method hands its tensor to `bridge::run_unary` and builds the
/// graph node with `ffi::graph_unary`, using `$sel` as the MPSGraph selector
/// (i.e. `graph.$sel(tensor, nil)`).
macro_rules! unary_op {
    ($fn_name:ident, $sel:expr) => {
        fn $fn_name(tensor: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
            bridge::run_unary(&tensor, |graph, input| unsafe {
                // `graph` and `input` are the handles supplied by `bridge::run_unary`.
                ffi::graph_unary(graph, $sel, input)
            })
        }
    };
}
21
/// Generates a tensor/tensor float-op method named `$fn_name`.
///
/// Both operands go through `bridge::run_binary`; the node is created with
/// `ffi::graph_binary` and selector `$sel` (i.e. `graph.$sel(lhs, rhs, nil)`).
macro_rules! binary_op {
    ($fn_name:ident, $sel:expr) => {
        fn $fn_name(
            lhs: FloatTensor<MpsGraph>,
            rhs: FloatTensor<MpsGraph>,
        ) -> FloatTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, primary, secondary| unsafe {
                ffi::graph_binary(graph, $sel, primary, secondary)
            })
        }
    };
}
30
/// Generates a tensor/scalar float-op method named `$fn_name`.
///
/// The scalar is read as `f64`, materialised as a graph constant with the
/// tensor's own dtype, and combined with the tensor through `ffi::graph_binary`
/// using selector `$sel` (tensor is the primary operand, constant the secondary).
macro_rules! scalar_op {
    ($fn_name:ident, $sel:expr) => {
        fn $fn_name(lhs: FloatTensor<MpsGraph>, rhs: Scalar) -> FloatTensor<MpsGraph> {
            let value = rhs.elem::<f64>();
            let dtype = burn_to_mps_dtype(lhs.dtype);
            bridge::run_unary_ctx(&lhs, |graph, input| unsafe {
                let constant = ffi::graph_constant_scalar(graph, value, dtype);
                ffi::graph_binary(graph, $sel, input, constant)
            })
        }
    };
}
42
/// Generates a tensor/tensor comparison method named `$fn_name`.
///
/// Same shape as `binary_op!`, except that the generated method is declared to
/// return a `BoolTensor`. The node is built with `ffi::graph_binary` and `$sel`.
macro_rules! cmp_op {
    ($fn_name:ident, $sel:expr) => {
        fn $fn_name(
            lhs: FloatTensor<MpsGraph>,
            rhs: FloatTensor<MpsGraph>,
        ) -> BoolTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, primary, secondary| unsafe {
                ffi::graph_binary(graph, $sel, primary, secondary)
            })
        }
    };
}
51
/// Generates a tensor/scalar comparison method named `$fn_name`.
///
/// The scalar is read as `f64` and turned into a graph constant carrying the
/// tensor's dtype; the comparison node is built with `ffi::graph_binary` and
/// `$sel`. The generated method is declared to return a `BoolTensor`.
macro_rules! cmp_scalar_op {
    ($fn_name:ident, $sel:expr) => {
        fn $fn_name(lhs: FloatTensor<MpsGraph>, rhs: Scalar) -> BoolTensor<MpsGraph> {
            let value = rhs.elem::<f64>();
            let dtype = burn_to_mps_dtype(lhs.dtype);
            bridge::run_unary_ctx(&lhs, |graph, input| unsafe {
                let constant = ffi::graph_constant_scalar(graph, value, dtype);
                ffi::graph_binary(graph, $sel, input, constant)
            })
        }
    };
}
62
// MPSGraph selector strings shared by the binary / scalar / comparison macros.

/// Selector used for element-wise addition.
pub(crate) const ADD: &str = "additionWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for element-wise subtraction.
pub(crate) const SUB: &str = "subtractionWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for element-wise multiplication.
pub(crate) const MUL: &str = "multiplicationWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for element-wise division.
pub(crate) const DIV: &str = "divisionWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for element-wise modulo (backs the `remainder` ops).
pub(crate) const MOD: &str = "moduloWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for element-wise power.
pub(crate) const POW: &str = "powerWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for the `==` comparison.
pub(crate) const EQ: &str = "equalWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for the `>` comparison.
pub(crate) const GT: &str = "greaterThanWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for the `>=` comparison.
pub(crate) const GTE: &str = "greaterThanOrEqualToWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for the `<` comparison.
pub(crate) const LT: &str = "lessThanWithPrimaryTensor:secondaryTensor:name:";
/// Selector used for the `<=` comparison.
pub(crate) const LTE: &str = "lessThanOrEqualToWithPrimaryTensor:secondaryTensor:name:";
74
75impl FloatTensorOps<MpsGraph> for MpsGraph {
76    fn float_from_data(data: TensorData, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
77        bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
78    }
79
80    fn float_random(shape: Shape, distribution: Distribution, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
81        let n = shape.num_elements();
82        let mut buf = vec![0f32; n];
83        let mut rng = crate::ops::get_seeded_rng();
84        use rand::Rng;
85        match distribution {
86            Distribution::Default => buf.iter_mut().for_each(|v| *v = rng.gen_range(0.0..1.0)),
87            Distribution::Bernoulli(p) => buf.iter_mut().for_each(|v| *v = if rng.gen_range(0.0..1.0) < p { 1.0 } else { 0.0 }),
88            Distribution::Uniform(lo, hi) => buf.iter_mut().for_each(|v| *v = rng.gen_range(lo as f32..hi as f32)),
89            Distribution::Normal(mu, sigma) => buf.iter_mut().for_each(|v| {
90                let u1: f64 = rng.gen_range(1e-7..1.0);
91                let u2: f64 = rng.gen_range(0.0..1.0);
92                *v = (mu + sigma * (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()) as f32;
93            }),
94        }
95        let bytes = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, n * 4) };
96        bridge::tensor_from_bytes(bytes, shape, burn_backend::DType::F32, *device)
97    }
98
99    fn float_into_data(t: FloatTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
100        async move {
101            let bytes = bridge::tensor_to_bytes(&t);
102            Ok(TensorData::from_bytes_vec(bytes, t.shape.clone(), t.dtype))
103        }
104    }
105
106    fn float_device(t: &FloatTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
107
108    fn float_to_device(t: FloatTensor<MpsGraph>, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
109        { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *device } }
110    }
111
112    fn float_into_int(t: FloatTensor<MpsGraph>) -> IntTensor<MpsGraph> {
113        let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe {
114            ffi::graph_cast(g, ph, ffi::MPSDataType::INT32)
115        });
116        r.dtype = burn_backend::DType::I32;
117        r
118    }
119
120    fn float_empty(shape: Shape, device: &Device<MpsGraph>, dtype: FloatDType) -> FloatTensor<MpsGraph> {
121        let dt: burn_backend::DType = dtype.into();
122        bridge::tensor_zeros(shape, dt, *device)
123    }
124
125    // ── Arithmetic ──────────────────────────────────────────────────────
    // Tensor/tensor variants are generated by `binary_op!`, tensor/scalar
    // variants by `scalar_op!`; each pair shares one selector constant.
    binary_op!(float_add, ADD);
    scalar_op!(float_add_scalar, ADD);
    binary_op!(float_sub, SUB);
    scalar_op!(float_sub_scalar, SUB);
    binary_op!(float_mul, MUL);
    scalar_op!(float_mul_scalar, MUL);
    binary_op!(float_div, DIV);
    scalar_op!(float_div_scalar, DIV);
    // Remainder is mapped onto the MPSGraph modulo selector.
    binary_op!(float_remainder, MOD);
    scalar_op!(float_remainder_scalar, MOD);
    binary_op!(float_powf, POW);
137
138    fn float_powf_scalar_impl(t: FloatTensor<MpsGraph>, v: Scalar) -> FloatTensor<MpsGraph> {
139        bridge::run_unary_ctx(&t, |g, ph| unsafe {
140            let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
141            ffi::graph_binary(g, POW, ph, s)
142        })
143    }
144
145    fn float_matmul(a: FloatTensor<MpsGraph>, b: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
146        bridge::run_binary(&a, &b, |g, pa, pb| unsafe { ffi::graph_matmul(g, pa, pb) })
147    }
148
149    fn float_recip(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
150        unary_op_impl(&t, "reciprocalWithTensor:name:")
151    }
152
153    fn float_cross(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
154        let sl = |shape: &Shape, idx: usize| -> Vec<Slice> {
155            (0..shape.num_dims()).map(|d| {
156                if d == dim { Slice::new(idx as isize, Some((idx+1) as isize), 1) }
157                else        { Slice::new(0, Some(shape[d] as isize), 1) }
158            }).collect()
159        };
160        let (a0,a1,a2) = (Self::float_slice(lhs.clone(), &sl(&lhs.shape,0)),
161                           Self::float_slice(lhs.clone(), &sl(&lhs.shape,1)),
162                           Self::float_slice(lhs.clone(), &sl(&lhs.shape,2)));
163        let (b0,b1,b2) = (Self::float_slice(rhs.clone(), &sl(&rhs.shape,0)),
164                           Self::float_slice(rhs.clone(), &sl(&rhs.shape,1)),
165                           Self::float_slice(rhs.clone(), &sl(&rhs.shape,2)));
166        let c0 = Self::float_sub(Self::float_mul(a1.clone(),b2.clone()), Self::float_mul(a2.clone(),b1.clone()));
167        let c1 = Self::float_sub(Self::float_mul(a2,b0.clone()), Self::float_mul(a0.clone(),b2));
168        let c2 = Self::float_sub(Self::float_mul(a0,b1), Self::float_mul(a1,b0));
169        Self::float_cat(vec![c0,c1,c2], dim)
170    }
171
172    // ── Comparisons ─────────────────────────────────────────────────────
    // Tensor/tensor comparisons are generated by `cmp_op!`, tensor/scalar
    // (`*_elem`) comparisons by `cmp_scalar_op!`; both return `BoolTensor`.
    cmp_op!(float_equal, EQ);
    cmp_scalar_op!(float_equal_elem, EQ);
    cmp_op!(float_greater, GT);
    cmp_scalar_op!(float_greater_elem, GT);
    cmp_op!(float_greater_equal, GTE);
    cmp_scalar_op!(float_greater_equal_elem, GTE);
    cmp_op!(float_lower, LT);
    cmp_scalar_op!(float_lower_elem, LT);
    cmp_op!(float_lower_equal, LTE);
    cmp_scalar_op!(float_lower_equal_elem, LTE);
183
184    // ── Unary math ──────────────────────────────────────────────────────
    // Each line generates one element-wise method via `unary_op!`, pairing the
    // Burn op name with the MPSGraph selector that implements it.
    unary_op!(float_exp,   "exponentWithTensor:name:");
    unary_op!(float_log,   "logarithmWithTensor:name:");
    unary_op!(float_sqrt,  "squareRootWithTensor:name:");
    unary_op!(float_abs,   "absoluteWithTensor:name:");
    unary_op!(float_cos,   "cosWithTensor:name:");
    unary_op!(float_sin,   "sinWithTensor:name:");
    unary_op!(float_tan,   "tanWithTensor:name:");
    unary_op!(float_cosh,  "coshWithTensor:name:");
    unary_op!(float_sinh,  "sinhWithTensor:name:");
    unary_op!(float_tanh,  "tanhWithTensor:name:");
    unary_op!(float_acos,  "acosWithTensor:name:");
    unary_op!(float_acosh, "acoshWithTensor:name:");
    unary_op!(float_asin,  "asinWithTensor:name:");
    unary_op!(float_asinh, "asinhWithTensor:name:");
    unary_op!(float_atan,  "atanWithTensor:name:");
    unary_op!(float_atanh, "atanhWithTensor:name:");
    unary_op!(float_erf,   "erfWithTensor:name:");
    unary_op!(float_floor, "floorWithTensor:name:");
    unary_op!(float_ceil,  "ceilWithTensor:name:");
    // NOTE(review): `rint` presumably rounds ties to even — confirm this matches
    // the rounding mode Burn expects from `float_round`.
    unary_op!(float_round, "rintWithTensor:name:");
205
206    fn float_atan2(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
207        bridge::run_binary(&lhs, &rhs, |g,a,b| unsafe { ffi::graph_binary(g, "atan2WithPrimaryTensor:secondaryTensor:name:", a, b) })
208    }
209
210    fn float_log1p(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
211        Self::float_log(Self::float_add_scalar(t, 1.0f32.into()))
212    }
213
214    fn float_trunc(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
215        bridge::run_unary_ctx(&t, |g, ph| unsafe {
216            let abs = ffi::graph_unary(g, "absoluteWithTensor:name:", ph);
217            let fl  = ffi::graph_unary(g, "floorWithTensor:name:", abs);
218            let sgn = ffi::graph_unary(g, "signWithTensor:name:", ph);
219            ffi::graph_binary(g, MUL, sgn, fl)
220        })
221    }
222
223    // ── Shape ops ───────────────────────────────────────────────────────
224
225    fn float_swap_dims(t: FloatTensor<MpsGraph>, d1: usize, d2: usize) -> FloatTensor<MpsGraph> {
226        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_transpose(g, ph, d1, d2) })
227    }
228
229    fn float_permute(t: FloatTensor<MpsGraph>, axes: &[usize]) -> FloatTensor<MpsGraph> {
230        let perm_ns = unsafe { ffi::ns_usize_array(axes) };
231        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_permute(g, ph, perm_ns) })
232    }
233
234    fn float_flip(t: FloatTensor<MpsGraph>, axes: &[usize]) -> FloatTensor<MpsGraph> {
235        let nd = t.shape.num_dims();
236        let shape = &t.shape;
237        let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize - 1 } else { 0 }).collect();
238        let ends:   Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize) - 1 } else { shape[d] as isize }).collect();
239        let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
240        bridge::run_unary_ctx(&t, |g, ph| unsafe {
241            ffi::graph_slice_masked(g, ph,
242                ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends),
243                ffi::ns_isize_array(&strides), 0, 0, 0)
244        })
245    }
246
247    fn float_reshape(t: FloatTensor<MpsGraph>, shape: Shape) -> FloatTensor<MpsGraph> {
248        let ns = bridge::shape_to_ns(&shape);
249        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reshape(g, ph, ns) })
250    }
251
252    fn float_expand(t: FloatTensor<MpsGraph>, shape: Shape) -> FloatTensor<MpsGraph> {
253        let ns = bridge::shape_to_ns(&shape);
254        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_broadcast(g, ph, ns) })
255    }
256
257    fn float_slice(t: FloatTensor<MpsGraph>, slices: &[Slice]) -> FloatTensor<MpsGraph> {
258        let (sa, ea, st) = bridge::slices_to_ns(slices, &t.shape);
259        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_slice(g, ph, sa, ea, st) })
260    }
261
262    fn float_slice_assign(t: FloatTensor<MpsGraph>, slices: &[Slice], value: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
263        let (sa, ea, st) = bridge::slices_to_ns(slices, &t.shape);
264        bridge::run_binary_ctx(&t, &value, |g, pd, pu| unsafe {
265            ffi::graph_slice_update(g, pd, pu, sa, ea, st)
266        })
267    }
268
269    fn float_cat(tensors: Vec<FloatTensor<MpsGraph>>, dim: usize) -> FloatTensor<MpsGraph> {
270        if tensors.len() == 1 { return tensors.into_iter().next().unwrap(); }
271        let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
272        bridge::run_multi_ctx(&refs, tensors[0].device, |g, phs| unsafe {
273            let arr = ffi::ns_array(phs);
274            ffi::graph_concat(g, arr, dim as isize)
275        })
276    }
277
278    // ── Gather / Scatter ────────────────────────────────────────────────
279
280    fn float_gather(dim: usize, t: FloatTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
281        bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g, a, b, dim, 0) })
282    }
283
284    fn float_scatter_add(dim: usize, t: FloatTensor<MpsGraph>, idx: IntTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
285        bridge::run_multi_ctx(&[&t, &idx, &val], t.device, |g, phs| unsafe {
286            ffi::graph_scatter_along(g, dim as isize, phs[0], phs[2], phs[1], ffi::MPSGraphScatterMode::ADD)
287        })
288    }
289
290    fn float_select(t: FloatTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
291        Self::float_gather(dim, t, idx)
292    }
293
294    fn float_select_add(t: FloatTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
295        Self::float_scatter_add(dim, t, idx, val)
296    }
297
298    // ── Mask ────────────────────────────────────────────────────────────
299
300    fn float_mask_where(t: FloatTensor<MpsGraph>, mask: BoolTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
301        bridge::run_ternary(&t, &mask, &val, |g, pt, pm, pv| unsafe { ffi::graph_select(g, pm, pv, pt) })
302    }
303
304    fn float_mask_fill(t: FloatTensor<MpsGraph>, mask: BoolTensor<MpsGraph>, val: Scalar) -> FloatTensor<MpsGraph> {
305        bridge::run_binary_ctx(&t, &mask, |g, pt, pm| unsafe {
306            let s = ffi::graph_constant_scalar(g, val.elem::<f64>(), burn_to_mps_dtype(t.dtype));
307            ffi::graph_select(g, pm, s, pt)
308        })
309    }
310
311    // ── Reductions ──────────────────────────────────────────────────────
312
313    fn float_sum(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
314        let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
315        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_sum(g, ph, ffi::ns_isize_array(&axes)) })
316    }
317
318    fn float_sum_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
319        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_sum_axis(g, ph, dim as isize) })
320    }
321
322    fn float_mean_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
323        let n = t.shape[dim] as f64;
324        let sum = Self::float_sum_dim(t, dim);
325        Self::float_div_scalar(sum, Scalar::from(n as f32))
326    }
327
328    fn float_argmax(t: FloatTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
329        let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argmax(g, ph, dim as isize) });
330        r.dtype = DType::I32; r
331    }
332
333    fn float_argmin(t: FloatTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
334        let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argmin(g, ph, dim as isize) });
335        r.dtype = DType::I32; r
336    }
337
338    // ── Cumulative ──────────────────────────────────────────────────────
339
340    fn float_cumsum(t: FloatTensor<MpsGraph>, dim: usize)  -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumsum(g,ph,dim as isize) }) }
341    fn float_cumprod(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumprod(g,ph,dim as isize) }) }
342    fn float_cummin(t: FloatTensor<MpsGraph>, dim: usize)  -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummin(g,ph,dim as isize) }) }
343    fn float_cummax(t: FloatTensor<MpsGraph>, dim: usize)  -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummax(g,ph,dim as isize) }) }
344
345    // ── Sort ────────────────────────────────────────────────────────────
346
347    fn float_sort(t: FloatTensor<MpsGraph>, dim: usize, desc: bool) -> FloatTensor<MpsGraph> {
348        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_sort(g, ph, dim as isize, desc) })
349    }
350
351    fn float_argsort(t: FloatTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
352        let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argsort(g, ph, dim as isize, desc) });
353        r.dtype = DType::I32; r
354    }
355
356    // ── Cast ────────────────────────────────────────────────────────────
357
358    fn float_cast(t: FloatTensor<MpsGraph>, dtype: FloatDType) -> FloatTensor<MpsGraph> {
359        let dt: DType = dtype.into();
360        if t.dtype == dt { return t; }
361        let mps = burn_to_mps_dtype(dt);
362        let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_cast(g, ph, mps) });
363        r.dtype = dt; r
364    }
365
366    // ── Reductions with native MPSGraph ops ─────────────────────────────
367
368    fn float_prod(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
369        let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
370        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_prod(g, ph, ffi::ns_isize_array(&axes)) })
371    }
372
373    fn float_prod_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
374        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_prod_axis(g, ph, dim as isize) })
375    }
376
377    fn float_max_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
378        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_max_axis(g, ph, dim as isize) })
379    }
380
381    fn float_min_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
382        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_min_axis(g, ph, dim as isize) })
383    }
384
385    // ── Unfold ──────────────────────────────────────────────────────────
386
387    fn float_unfold(t: FloatTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> FloatTensor<MpsGraph> {
388        let dim_size = t.shape[dim];
389        let n_win = (dim_size.saturating_sub(size)) / step + 1;
390        let mut windows = Vec::with_capacity(n_win);
391        for i in 0..n_win {
392            let start = i * step;
393            let slices: Vec<Slice> = (0..t.shape.num_dims()).map(|d| {
394                if d == dim { Slice::new(start as isize, Some((start+size) as isize), 1) }
395                else        { Slice::new(0, Some(t.shape[d] as isize), 1) }
396            }).collect();
397            let w = Self::float_slice(t.clone(), &slices);
398            let mut dims: Vec<usize> = (0..w.shape.num_dims()).map(|d| w.shape[d]).collect();
399            dims[dim] = 1; dims.push(size);
400            windows.push(Self::float_reshape(w, Shape::from(dims)));
401        }
402        Self::float_cat(windows, dim)
403    }
404}
405
406/// Helper for standalone unary ops.
407fn unary_op_impl(t: &MpsGraphTensor, sel: &'static str) -> MpsGraphTensor {
408    bridge::run_unary_ctx(t, |g, ph| unsafe { ffi::graph_unary(g, sel, ph) })
409}