
// flodl/tensor.rs

//! Tensor — immutable, chainable wrapper around a libtorch tensor.
//!
//! Every tensor owns its C++ handle and frees it on drop. This is the
//! entire VRAM management story — no GC, no scopes, no finalizers.
//!
//! Operations are chainable and return `Result<Tensor>`:
//!
//! ```ignore
//! let z = a.add(&b)?.relu()?.sum()?;
//! ```

use std::ffi::{c_char, c_void, CStr};
use std::fmt;
use std::ptr;
use std::sync::atomic::{AtomicU64, Ordering};

use flodl_sys::{self as ffi, FlodlTensor};

/// Global counter of live C++ Tensor handles. Incremented on creation,
/// decremented on Drop. If this grows over time during training, there
/// is a Tensor handle leak. If it stays stable but RSS grows, the leak
/// is inside libtorch internals (not a handle leak).
static LIVE_TENSOR_COUNT: AtomicU64 = AtomicU64::new(0);

/// Element data type of a tensor. Maps to PyTorch's `torch.dtype`.
///
/// Float32 is the default. Use Float16/BFloat16 for mixed precision,
/// Int64 for indices and labels, Float64 when extra precision is needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DType {
    Float16 = ffi::FLODL_FLOAT16,
    BFloat16 = ffi::FLODL_BFLOAT16,
    Float32 = ffi::FLODL_FLOAT32,
    Float64 = ffi::FLODL_FLOAT64,
    Int32 = ffi::FLODL_INT32,
    Int64 = ffi::FLODL_INT64,
}

impl DType {
    fn from_raw(v: i32) -> Self {
        match v {
            ffi::FLODL_FLOAT16 => DType::Float16,
            ffi::FLODL_BFLOAT16 => DType::BFloat16,
            ffi::FLODL_FLOAT32 => DType::Float32,
            ffi::FLODL_FLOAT64 => DType::Float64,
            ffi::FLODL_INT32 => DType::Int32,
            ffi::FLODL_INT64 => DType::Int64,
            // Unknown raw values fall back to Float32 rather than panicking.
            _ => DType::Float32,
        }
    }

    /// Size of one element in bytes.
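    ///
    /// For example:
    ///
    /// ```ignore
    /// assert_eq!(DType::Float32.element_size(), 4);
    /// assert_eq!(DType::BFloat16.element_size(), 2);
    /// ```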
    pub fn element_size(self) -> usize {
        match self {
            DType::Float16 | DType::BFloat16 => 2,
            DType::Float32 | DType::Int32 => 4,
            DType::Float64 | DType::Int64 => 8,
        }
    }
}

/// Device represents where a tensor's data lives.
///
/// `Device::CPU` is the host. `Device::CUDA(n)` is GPU index `n`.
/// Most single-GPU code uses `Device::CUDA(0)`.
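///
/// For example:
///
/// ```ignore
/// let dev = Device::CUDA(0);
/// assert!(dev.is_cuda());
/// assert_eq!(dev.to_string(), "cuda");
/// ```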
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    CPU,
    CUDA(u8),
}

impl Device {
    /// Convert to (device_type, device_index) for FFI calls.
    pub(crate) fn to_ffi(self) -> (i32, i32) {
        match self {
            Device::CPU => (ffi::FLODL_CPU, 0),
            Device::CUDA(idx) => (ffi::FLODL_CUDA, idx as i32),
        }
    }

    /// Reconstruct from FFI (device_type, device_index).
    pub(crate) fn from_ffi(device_type: i32, device_index: i32) -> Self {
        match device_type {
            ffi::FLODL_CUDA => Device::CUDA(device_index as u8),
            _ => Device::CPU,
        }
    }

    /// Whether this is a CUDA device.
    pub fn is_cuda(&self) -> bool {
        matches!(self, Device::CUDA(_))
    }

    /// Device index (0 for CPU, GPU index for CUDA).
    pub fn index(&self) -> u8 {
        match self {
            Device::CPU => 0,
            Device::CUDA(idx) => *idx,
        }
    }
}

impl fmt::Display for Device {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Device::CPU => write!(f, "cpu"),
            Device::CUDA(0) => write!(f, "cuda"),
            Device::CUDA(idx) => write!(f, "cuda:{}", idx),
        }
    }
}

/// Error type for tensor operations.
#[derive(Debug, Clone)]
pub struct TensorError(String);

impl TensorError {
    pub fn new(msg: &str) -> Self {
        TensorError(msg.to_string())
    }
}

impl fmt::Display for TensorError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl std::error::Error for TensorError {}

pub type Result<T> = std::result::Result<T, TensorError>;

/// Convert a C error string to Result. Frees the C string.
pub(crate) fn check_err(err: *mut c_char) -> Result<()> {
    if err.is_null() {
        Ok(())
    } else {
        let msg = unsafe { CStr::from_ptr(err) }
            .to_string_lossy()
            .into_owned();
        unsafe { ffi::flodl_free_string(err) };
        Err(TensorError(msg))
    }
}

/// Options for tensor creation.
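///
/// A minimal sketch: Float32 on the CPU unless a field is overridden.
///
/// ```ignore
/// let opts = TensorOptions { device: Device::CUDA(0), ..Default::default() };
/// ```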
#[derive(Debug, Clone, Copy)]
pub struct TensorOptions {
    pub dtype: DType,
    pub device: Device,
}

impl Default for TensorOptions {
    fn default() -> Self {
        Self {
            dtype: DType::Float32,
            device: Device::CPU,
        }
    }
}

/// A tensor wrapping a libtorch C++ tensor.
///
/// Owns the underlying C++ handle. When dropped, the C++ tensor is
/// freed immediately — including any GPU memory. This is the entire
/// VRAM management story.
///
/// Operations are chainable and return `Result<Tensor>`:
///
/// ```ignore
/// let y = x.matmul(&w)?.add(&b)?.relu()?;
/// ```
pub struct Tensor {
    handle: FlodlTensor,
}

// Safety: libtorch tensors are reference-counted internally and
// thread-safe for read access. Mutations go through the shim which
// creates new tensors.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}

impl Drop for Tensor {
    fn drop(&mut self) {
        if !self.handle.is_null() {
            LIVE_TENSOR_COUNT.fetch_sub(1, Ordering::Relaxed);
            unsafe { ffi::flodl_free_tensor(self.handle) };
        }
    }
}

impl Clone for Tensor {
    /// Shallow clone: creates a new C++ Tensor handle sharing the same
    /// TensorImpl (and thus the same data storage). Cheap — just bumps
    /// libtorch's internal refcount.
    fn clone(&self) -> Self {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_shallow_clone(self.handle, &mut handle) };
        // Reuse check_err for the message/free logic; Clone cannot return
        // a Result, so a failed shallow clone is a panic.
        if let Err(e) = check_err(err) {
            panic!("tensor clone failed: {}", e);
        }
        Self::from_raw(handle)
    }
}

impl Tensor {
    /// Wrap a raw handle. The Tensor takes ownership.
    fn from_raw(handle: FlodlTensor) -> Self {
        debug_assert!(!handle.is_null());
        LIVE_TENSOR_COUNT.fetch_add(1, Ordering::Relaxed);
        Self { handle }
    }

    /// Wrap a raw handle (crate-visible). The Tensor takes ownership.
    ///
    /// # Safety
    /// Caller must ensure the handle is valid and not owned elsewhere.
    pub(crate) unsafe fn from_raw_handle(handle: FlodlTensor) -> Self {
        Self::from_raw(handle)
    }

    /// Access the raw handle (for passing to FFI in sibling modules).
    pub(crate) fn raw(&self) -> FlodlTensor {
        self.handle
    }

    // --- Creation ---

    /// Create a tensor filled with zeros.
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[2, 3], TensorOptions::default())?;
    /// assert_eq!(t.shape(), vec![2, 3]);
    /// ```
    pub fn zeros(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_zeros(
                shape.as_mut_ptr(),
                shape.len() as i32,
                opts.dtype as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Create a tensor filled with ones. Like `torch.ones()`.
    ///
    /// ```ignore
    /// let t = Tensor::ones(&[2, 3], TensorOptions::default())?;
    /// ```
    pub fn ones(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_ones(
                shape.as_mut_ptr(),
                shape.len() as i32,
                opts.dtype as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Create a tensor from f32 data.
    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], Device::CPU)?;
    /// assert_eq!(t.shape(), vec![2, 2]);
    /// ```
    pub fn from_f32(data: &[f32], shape: &[i64], device: Device) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe {
            ffi::flodl_from_blob(
                data.as_ptr() as *mut c_void,
                shape.as_mut_ptr(),
                shape.len() as i32,
                DType::Float32 as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Create a Float64 tensor from f64 data. Use when full double precision
    /// is needed (e.g. loss accumulation, high-precision metrics).
    pub fn from_f64(data: &[f64], shape: &[i64], device: Device) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe {
            ffi::flodl_from_blob(
                data.as_ptr() as *mut c_void,
                shape.as_mut_ptr(),
                shape.len() as i32,
                DType::Float64 as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Create an Int64 tensor from i64 data. Commonly used for class labels,
    /// token indices, and any integer indexing (e.g. `cross_entropy_loss` targets).
    pub fn from_i64(data: &[i64], shape: &[i64], device: Device) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe {
            ffi::flodl_from_blob(
                data.as_ptr() as *mut c_void,
                shape.as_mut_ptr(),
                shape.len() as i32,
                DType::Int64 as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    // --- Metadata ---

    /// Number of dimensions (rank). Like `tensor.ndim` in PyTorch.
    pub fn ndim(&self) -> usize {
        unsafe { ffi::flodl_ndim(self.handle) as usize }
    }

    /// Shape of each dimension as a Vec. Like `tensor.shape` in PyTorch.
    pub fn shape(&self) -> Vec<i64> {
        let n = self.ndim();
        (0..n)
            .map(|i| unsafe { ffi::flodl_shape(self.handle, i as i32) })
            .collect()
    }

    /// Total number of elements (product of all dimensions). Like `tensor.numel()`.
    pub fn numel(&self) -> i64 {
        unsafe { ffi::flodl_numel(self.handle) }
    }

    /// Element data type of this tensor. Like `tensor.dtype` in PyTorch.
    pub fn dtype(&self) -> DType {
        DType::from_raw(unsafe { ffi::flodl_dtype(self.handle) })
    }

    /// Device where this tensor's data resides (CPU or CUDA). Like `tensor.device`.
    pub fn device(&self) -> Device {
        let dt = unsafe { ffi::flodl_device_type(self.handle) };
        let di = unsafe { ffi::flodl_device_index(self.handle) };
        Device::from_ffi(dt, di)
    }

    // --- Data access ---

    /// Copy tensor data to a `Vec<f32>`. Transparently moves to CPU first
    /// if the tensor lives on CUDA. Non-f32 dtypes are cast via libtorch.
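    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[1.0, 2.0], &[2], Device::CPU)?;
    /// assert_eq!(t.to_f32_vec()?, vec![1.0, 2.0]);
    /// ```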
    pub fn to_f32_vec(&self) -> Result<Vec<f32>> {
        let n = self.numel() as usize;
        let mut buf = vec![0f32; n];
        let bytes = (n * 4) as i64;
        let err = unsafe {
            ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
        };
        check_err(err)?;
        Ok(buf)
    }

    /// Copy tensor data to a `Vec<f64>`. Moves to CPU if needed.
    /// Float64 tensors are copied at full precision. All other dtypes
    /// go through f32 (lossless for f16/bf16, and the best f32 can offer).
    pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
        if self.dtype() == DType::Float64 {
            let n = self.numel() as usize;
            let mut buf = vec![0.0f64; n];
            let bytes = (n * 8) as i64;
            let err = unsafe {
                ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
            };
            check_err(err)?;
            Ok(buf)
        } else {
            let f32s = self.to_f32_vec()?;
            Ok(f32s.into_iter().map(|v| v as f64).collect())
        }
    }

    /// Copy tensor data to a `Vec<i64>`. Moves to CPU if needed.
    /// Intended for Int64 tensors (indices, labels).
    pub fn to_i64_vec(&self) -> Result<Vec<i64>> {
        let n = self.numel() as usize;
        let mut buf = vec![0i64; n];
        let bytes = (n * 8) as i64;
        let err = unsafe {
            ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
        };
        check_err(err)?;
        Ok(buf)
    }

    /// Extract a scalar value as f64. Like PyTorch's `.item()`.
    ///
    /// The tensor must contain exactly one element (any shape is fine,
    /// e.g. `[1]`, `[1, 1]`, or `[]`). Returns an error otherwise.
    /// Preserves full precision for Float64 tensors.
    ///
    /// ```ignore
    /// let loss_val = loss_tensor.item()?;
    /// println!("loss: {:.4}", loss_val);
    /// ```
    pub fn item(&self) -> Result<f64> {
        if self.numel() != 1 {
            return Err(TensorError::new(&format!(
                "item() requires exactly 1 element, got {} (shape {:?})",
                self.numel(), self.shape()
            )));
        }
        if self.dtype() == DType::Float64 {
            let mut buf = [0.0f64; 1];
            let err = unsafe {
                ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, 8)
            };
            check_err(err)?;
            Ok(buf[0])
        } else {
            let mut buf = [0.0f32; 1];
            let err = unsafe {
                ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, 4)
            };
            check_err(err)?;
            Ok(buf[0] as f64)
        }
    }

    // --- Arithmetic (chainable) ---

    /// Element-wise addition. Shapes must be broadcastable.
    ///
    /// ```ignore
    /// let c = a.add(&b)?; // [2, 3] + [2, 3] → [2, 3]
    /// ```
    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_add(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise subtraction. Shapes must be broadcastable.
    pub fn sub(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sub(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise (Hadamard) multiplication. Shapes must be broadcastable.
    /// For matrix multiplication, use [`matmul`](Self::matmul).
    pub fn mul(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_mul(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Matrix multiplication.
    ///
    /// ```ignore
    /// // [batch, M, K] @ [batch, K, N] → [batch, M, N]
    /// let c = a.matmul(&b)?;
    /// ```
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_matmul(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Multiply every element by a scalar. Like `tensor * 0.5` in PyTorch.
    pub fn mul_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_mul_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Activations ---

    /// ReLU activation: max(0, x).
    pub fn relu(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_relu(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Sigmoid activation: 1 / (1 + exp(-x)).
    pub fn sigmoid(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sigmoid(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Reductions ---

    /// Sum of all elements (scalar result).
    pub fn sum(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sum(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Mean of all elements (scalar result).
    pub fn mean(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_mean(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Flatten dimensions `[start_dim..=end_dim]` into one.
    pub fn flatten(&self, start_dim: i32, end_dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_flatten(self.handle, start_dim, end_dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Additional arithmetic ---

    /// Element-wise division.
    pub fn div(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_div(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Negate every element.
    pub fn neg(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_neg(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Add a scalar to every element.
    pub fn add_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_add_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Divide every element by a scalar.
    pub fn div_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_div_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Activations ---

    /// Tanh activation: element-wise hyperbolic tangent.
    pub fn tanh(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_tanh_op(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Element-wise math ---

    /// Element-wise exponential.
    pub fn exp(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_exp(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise natural logarithm.
    pub fn log(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_log(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise square root.
    pub fn sqrt(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sqrt(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise absolute value.
    pub fn abs(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_abs(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Upper triangle of a matrix (or batch of matrices).
    /// Elements below the `diagonal`-th diagonal are zeroed.
    /// `diagonal=0` keeps the main diagonal; `diagonal=1` excludes it.
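    ///
    /// A sketch of a causal attention mask (`t` is an illustrative sequence length):
    ///
    /// ```ignore
    /// let mask = Tensor::ones(&[t, t], TensorOptions::default())?.triu(1)?;
    /// ```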
    pub fn triu(&self, diagonal: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_triu(self.handle, diagonal, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Raise every element to a scalar exponent.
    pub fn pow_scalar(&self, exponent: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_pow_scalar(self.handle, exponent, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Reductions ---

    /// Sum along a dimension.
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_sum_dim(self.handle, dim, keepdim as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Clamp all elements to `[min, max]`.
    pub fn clamp(&self, min: f64, max: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_clamp(self.handle, min, max, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Comparisons ---

    /// Element-wise greater-than comparison against a scalar.
    pub fn gt_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_gt_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Shape operations ---

    /// Reshape to a new shape (must have same total elements).
    /// Use -1 for one inferred dimension.
    ///
    /// ```ignore
    /// let flat = t.reshape(&[-1])?; // [2, 3] → [6]
    /// ```
    pub fn reshape(&self, shape: &[i64]) -> Result<Tensor> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_reshape(self.handle, shape.as_mut_ptr(), shape.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Swap two dimensions.
    ///
    /// ```ignore
    /// let t = x.transpose(0, 1)?; // [M, N] → [N, M]
    /// ```
    pub fn transpose(&self, dim0: i32, dim1: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_transpose(self.handle, dim0, dim1, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Broadcast to a larger shape.
    pub fn expand(&self, shape: &[i64]) -> Result<Tensor> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_expand(self.handle, shape.as_mut_ptr(), shape.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Slicing and indexing ---

    /// Narrow (slice) along a dimension: returns a view.
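    ///
    /// ```ignore
    /// // Rows 2..5 of a [10, D] tensor → [3, D].
    /// let rows = t.narrow(0, 2, 3)?;
    /// ```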
    pub fn narrow(&self, dim: i32, start: i64, length: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_narrow(self.handle, dim, start, length, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Scatter a narrow slice back into a tensor (for narrow backward).
    pub fn narrow_scatter(&self, src: &Tensor, dim: i32, start: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_narrow_scatter(self.handle, src.handle, dim, start, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Concatenate two tensors along a dimension.
    pub fn cat(&self, other: &Tensor, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_cat2(self.handle, other.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Concatenate multiple tensors along an existing dimension.
    ///
    /// All tensors must have the same shape except in the concatenation dimension.
    /// Uses a single kernel launch regardless of the number of tensors.
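    ///
    /// ```ignore
    /// // Three [N, D] tensors → one [3N, D] tensor.
    /// let all = Tensor::cat_many(&[&a, &b, &c], 0)?;
    /// ```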
    pub fn cat_many(tensors: &[&Tensor], dim: i32) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(TensorError::new("cat_many: empty tensor list"));
        }
        let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
        let mut result: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_cat(handles.as_mut_ptr(), handles.len() as i32, dim, &mut result)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(result))
    }

    /// Stack tensors along a new dimension.
    ///
    /// All tensors must have the same shape. A new dimension is inserted at `dim`.
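    ///
    /// ```ignore
    /// // Two [D] tensors → one [2, D] tensor.
    /// let batch = Tensor::stack(&[&a, &b], 0)?;
    /// ```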
    pub fn stack(tensors: &[&Tensor], dim: i32) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(TensorError::new("stack: empty tensor list"));
        }
        let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
        let mut result: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_stack(handles.as_mut_ptr(), handles.len() as i32, dim, &mut result)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(result))
    }

    /// Softmax along a dimension.
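    ///
    /// A sketch (`logits` is an illustrative `[N, C]` float tensor; each row
    /// of the result sums to 1):
    ///
    /// ```ignore
    /// let probs = logits.softmax(1)?;
    /// ```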
    pub fn softmax(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_softmax(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Log-softmax along a dimension (numerically stable).
    pub fn log_softmax(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_log_softmax(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// GELU activation (native libtorch).
    pub fn gelu(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_gelu(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// SiLU activation (native libtorch).
    pub fn silu(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_silu(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Native layer normalization. Returns (output, mean, rstd).
    pub fn native_layer_norm(
        &self, weight: &Tensor, bias: &Tensor, normalized_size: i64, eps: f64,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        let mut out: FlodlTensor = ptr::null_mut();
        let mut mean: FlodlTensor = ptr::null_mut();
        let mut rstd: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_native_layer_norm(
                self.handle, weight.handle, bias.handle,
                normalized_size, eps,
                &mut out, &mut mean, &mut rstd,
            )
        };
        check_err(err)?;
        Ok((Tensor::from_raw(out), Tensor::from_raw(mean), Tensor::from_raw(rstd)))
    }

    /// Permute dimensions.
    pub fn permute(&self, dims: &[i64]) -> Result<Tensor> {
        let mut dims = dims.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_permute(self.handle, dims.as_mut_ptr(), dims.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Select a single index along a dimension (reduces that dim).
    pub fn select(&self, dim: i32, index: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_select(self.handle, dim, index, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Mean along a dimension.
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_mean_dim(self.handle, dim, keepdim as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Select rows/elements along a dimension using an index tensor.
    pub fn index_select(&self, dim: i32, index: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_index_select(self.handle, dim, index.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Scatter-add src into self along dim at positions given by index.
    pub fn index_add(&self, dim: i32, index: &Tensor, src: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_index_add(self.handle, dim, index.handle, src.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Like constructors ---

    /// Create a tensor of zeros with the same shape, dtype, and device as `t`.
    pub fn zeros_like(t: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_zeros_like(t.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Create a tensor of ones with the same shape, dtype, and device as `t`.
    pub fn ones_like(t: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ones_like(t.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Random ---

    /// Create a tensor with uniform random values in [0, 1).
    pub fn rand(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_rand(
                shape.as_mut_ptr(), shape.len() as i32,
                opts.dtype as i32, dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Create a tensor with standard normal random values (mean=0, std=1).
    pub fn randn(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_randn(
                shape.as_mut_ptr(), shape.len() as i32,
                opts.dtype as i32, dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    // --- Convolution (many args unavoidable — maps 1:1 to libtorch C API) ---

    /// 2D convolution. Pass `None` for `bias` to skip the bias add.
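    ///
    /// A sketch (shapes are illustrative: `x` is `[N, C_in, H, W]`,
    /// `weight` is `[C_out, C_in, kH, kW]`):
    ///
    /// ```ignore
    /// let y = x.conv2d(&w, Some(&b), [1, 1], [1, 1], [1, 1], 1)?;
    /// ```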
    #[allow(clippy::too_many_arguments)]
    pub fn conv2d(
        &self, weight: &Tensor, bias: Option<&Tensor>,
        stride: [i64; 2], padding: [i64; 2], dilation: [i64; 2], groups: i64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let mut stride = stride;
        let mut padding = padding;
        let mut dilation = dilation;
        let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
        let err = unsafe {
            ffi::flodl_conv2d(
                self.handle, weight.handle, bias_handle,
                stride.as_mut_ptr(), padding.as_mut_ptr(), dilation.as_mut_ptr(),
                groups, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Transposed 2D convolution.
    #[allow(clippy::too_many_arguments)]
    pub fn conv_transpose2d(
        &self, weight: &Tensor, bias: Option<&Tensor>,
        stride: [i64; 2], padding: [i64; 2], output_padding: [i64; 2],
        dilation: [i64; 2], groups: i64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let mut stride = stride;
        let mut padding = padding;
        let mut output_padding = output_padding;
        let mut dilation = dilation;
        let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
        let err = unsafe {
            ffi::flodl_conv_transpose2d(
                self.handle, weight.handle, bias_handle,
                stride.as_mut_ptr(), padding.as_mut_ptr(),
                output_padding.as_mut_ptr(), dilation.as_mut_ptr(),
                groups, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Fused ops ---

    /// Fused linear: `y = input @ weight^T + bias` (single ATen kernel).
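    ///
    /// ```ignore
    /// // x: [N, in], weight: [out, in], bias: [out] → y: [N, out]
    /// let y = x.linear(&w, Some(&b))?;
    /// ```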
    pub fn linear(&self, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
        let err = unsafe {
            ffi::flodl_linear(self.handle, weight.handle, bias_handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused GRU cell: single ATen `gru_cell` kernel.
    /// Returns new hidden state h'.
    #[allow(clippy::too_many_arguments)]
    pub fn gru_cell(
        &self, hx: &Tensor,
        w_ih: &Tensor, w_hh: &Tensor,
        b_ih: &Tensor, b_hh: &Tensor,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_gru_cell(
                self.handle, hx.handle,
                w_ih.handle, w_hh.handle,
                b_ih.handle, b_hh.handle,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused LSTM cell: single ATen `lstm_cell` kernel.
    /// Returns `(h', c')`.
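    ///
    /// A sketch (shapes are illustrative: input `[N, I]`, `hx`/`cx` `[N, H]`,
    /// weights in libtorch's `lstm_cell` layout):
    ///
    /// ```ignore
    /// let (h, c) = x.lstm_cell(&hx, &cx, &w_ih, &w_hh, &b_ih, &b_hh)?;
    /// ```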
    #[allow(clippy::too_many_arguments)]
    pub fn lstm_cell(
        &self, hx: &Tensor, cx: &Tensor,
        w_ih: &Tensor, w_hh: &Tensor,
        b_ih: &Tensor, b_hh: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let mut h_out: FlodlTensor = ptr::null_mut();
        let mut c_out: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_lstm_cell(
                self.handle, hx.handle, cx.handle,
                w_ih.handle, w_hh.handle,
                b_ih.handle, b_hh.handle,
                &mut h_out, &mut c_out,
            )
        };
        check_err(err)?;
        Ok((Tensor::from_raw(h_out), Tensor::from_raw(c_out)))
    }

    // --- Fused loss functions ---

    /// Fused MSE loss: single libtorch kernel.
    /// reduction: 0=None, 1=Mean, 2=Sum.
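    ///
    /// ```ignore
    /// let loss = pred.mse_loss(&target, 1)?; // 1 = mean reduction
    /// ```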
    pub fn mse_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_mse_loss(self.handle, target.handle, reduction, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused cross-entropy loss: single libtorch kernel.
    /// pred: [N,C] logits. target: [N] Int64 indices or [N,C] Float probs.
    /// reduction: 0=None, 1=Mean, 2=Sum.
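    ///
    /// A sketch (values are illustrative; an `ignore_index` that matches no
    /// label disables it, and `label_smoothing = 0.0` disables smoothing):
    ///
    /// ```ignore
    /// // logits: [N, C] Float32, labels: [N] Int64
    /// let loss = logits.cross_entropy_loss(&labels, 1, -100, 0.0)?;
    /// println!("loss = {:.4}", loss.item()?);
    /// ```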
    #[allow(clippy::too_many_arguments)]
    pub fn cross_entropy_loss(
        &self, target: &Tensor, reduction: i64,
        ignore_index: i64, label_smoothing: f64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_cross_entropy_loss(
                self.handle, target.handle,
                reduction, ignore_index, label_smoothing,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused BCE with logits loss: single libtorch kernel.
    /// Numerically stable binary cross-entropy from raw logits.
    /// reduction: 0=None, 1=Mean, 2=Sum.
    pub fn bce_with_logits_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_bce_with_logits_loss(
                self.handle, target.handle, reduction, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused L1 loss: single libtorch kernel.
    /// reduction: 0=None, 1=Mean, 2=Sum.
    pub fn l1_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_l1_loss(self.handle, target.handle, reduction, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused Smooth L1 (Huber) loss: single libtorch kernel.
    /// reduction: 0=None, 1=Mean, 2=Sum. beta: transition point.
    pub fn smooth_l1_loss(&self, target: &Tensor, reduction: i64, beta: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_smooth_l1_loss(
                self.handle, target.handle, reduction, beta, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused KL divergence loss: single libtorch kernel.
    /// input: log-probabilities. target: probabilities.
    /// reduction: 0=None, 1=Mean, 2=Sum, 5=BatchMean.
    pub fn kl_div_loss(&self, target: &Tensor, reduction: i64, log_target: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_kl_div_loss(
                self.handle, target.handle, reduction, log_target as i32, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Fused batch normalization ---

    /// Fused batch normalization: single libtorch kernel.
    /// When training=true, updates running_mean/running_var in-place.
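    ///
    /// A sketch (`x` is `[N, C, H, W]`; hyperparameters are illustrative):
    ///
    /// ```ignore
    /// let y = x.batch_norm(
    ///     Some(&gamma), Some(&beta),
    ///     Some(&running_mean), Some(&running_var),
    ///     /*training=*/ true, /*momentum=*/ 0.1, /*eps=*/ 1e-5,
    /// )?;
    /// ```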
    #[allow(clippy::too_many_arguments)]
    pub fn batch_norm(
        &self, weight: Option<&Tensor>, bias: Option<&Tensor>,
        running_mean: Option<&Tensor>, running_var: Option<&Tensor>,
        training: bool, momentum: f64, eps: f64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let w = weight.map_or(ptr::null_mut(), |t| t.handle);
        let b = bias.map_or(ptr::null_mut(), |t| t.handle);
        let rm = running_mean.map_or(ptr::null_mut(), |t| t.handle);
        let rv = running_var.map_or(ptr::null_mut(), |t| t.handle);
        let err = unsafe {
            ffi::flodl_batch_norm(
                self.handle, w, b, rm, rv,
                training as i32, momentum, eps, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Fused dropout ---

    /// Fused dropout: single libtorch kernel with inverted scaling.
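    ///
    /// ```ignore
    /// let y = x.dropout(0.1, true)?; // identity when training == false
    /// ```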
    pub fn dropout(&self, p: f64, training: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_dropout(self.handle, p, training as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Fused 2D feature dropout: drops entire channels.
    pub fn feature_dropout(&self, p: f64, training: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_feature_dropout(self.handle, p, training as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // --- Missing wrappers for existing shims ---

    /// Create evenly spaced values.
    pub fn linspace(start: f64, end: f64, steps: i64, opts: TensorOptions) -> Result<Self> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_linspace(start, end, steps, opts.dtype as i32, dt, di, &mut handle)
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Create a range of values [start, end) with given step.
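    ///
    /// ```ignore
    /// // 0, 1, 2, 3, 4 as Float32 on the CPU.
    /// let t = Tensor::arange(0.0, 5.0, 1.0, TensorOptions::default())?;
    /// ```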
    pub fn arange(start: f64, end: f64, step: f64, opts: TensorOptions) -> Result<Self> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_arange(start, end, step, opts.dtype as i32, dt, di, &mut handle)
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Minimum of all elements (scalar result).
    pub fn min(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_min(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Maximum of all elements (scalar result).
    pub fn max(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_max(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// L2 (Frobenius) norm of all elements.
    pub fn norm(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_norm(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Minimum along a dimension (values only).
    pub fn min_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_min_dim(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Maximum along a dimension (values only).
    pub fn max_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_max_dim(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Argmax along a dimension.
    pub fn argmax(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_argmax(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise greater-than-or-equal comparison against a scalar.
    pub fn ge_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ge_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise less-than-or-equal comparison against a scalar.
    pub fn le_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_le_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise less-than comparison against a scalar.
    pub fn lt_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_lt_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Scatter a selected index back into a tensor.
    pub fn select_scatter(&self, src: &Tensor, dim: i32, index: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_select_scatter(self.handle, src.handle, dim, index, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise conditional select: picks from `x` where `condition` is
    /// nonzero, from `y` elsewhere. Like `torch.where(condition, x, y)`.
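    ///
    /// A sketch: clip negatives to zero by hand (the mask from `gt_scalar`
    /// is a 0/1 tensor):
    ///
    /// ```ignore
    /// let mask = x.gt_scalar(0.0)?;
    /// let y = Tensor::where_cond(&mask, &x, &Tensor::zeros_like(&x)?)?;
    /// ```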
    pub fn where_cond(condition: &Tensor, x: &Tensor, y: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_where(condition.handle, x.handle, y.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Squeeze (remove) a dimension of size 1.
    pub fn squeeze(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_squeeze(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Unsqueeze (insert) a dimension of size 1.
    pub fn unsqueeze(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_unsqueeze(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Adaptive average pooling to target spatial size.
    pub fn adaptive_avg_pool2d(&self, output_size: [i64; 2]) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let mut os = output_size;
        let err = unsafe {
            ffi::flodl_adaptive_avg_pool2d(self.handle, os.as_mut_ptr(), &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Grid sampling (bilinear/nearest interpolation).
    pub fn grid_sample(
        &self, grid: &Tensor, mode: i32, padding_mode: i32, align_corners: bool,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_grid_sample(
                self.handle, grid.handle, mode, padding_mode,
                align_corners as i32, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Cast to a different dtype.
    pub fn to_dtype(&self, dtype: DType) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_to_dtype(self.handle, dtype as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Check if all elements are finite (no inf/nan).
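    ///
    /// A sketch of a mixed-precision guard (`grads` is a hypothetical
    /// gradient tensor):
    ///
    /// ```ignore
    /// if !grads.all_finite()? {
    ///     // overflow: skip this optimizer step
    /// }
    /// ```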
1337    pub fn all_finite(&self) -> Result<bool> {
1338        let mut result: i32 = 0;
1339        let err = unsafe { ffi::flodl_all_finite(self.handle, &mut result) };
1340        check_err(err)?;
1341        Ok(result != 0)
1342    }
1343
1344    // --- Comparison (tensor-tensor) ---
1345
1346    /// Element-wise greater-than (returns float mask: 0.0 or 1.0).
1347    pub fn gt(&self, other: &Tensor) -> Result<Tensor> {
1348        let mut handle: FlodlTensor = ptr::null_mut();
1349        let err = unsafe { ffi::flodl_gt_tensor(self.handle, other.handle, &mut handle) };
1350        check_err(err)?;
1351        Ok(Tensor::from_raw(handle))
1352    }
1353
1354    /// Element-wise less-than (returns float mask: 0.0 or 1.0).
1355    pub fn lt(&self, other: &Tensor) -> Result<Tensor> {
1356        let mut handle: FlodlTensor = ptr::null_mut();
1357        let err = unsafe { ffi::flodl_lt_tensor(self.handle, other.handle, &mut handle) };
1358        check_err(err)?;
1359        Ok(Tensor::from_raw(handle))
1360    }
1361
1362    /// Element-wise greater-than-or-equal (returns float mask: 0.0 or 1.0).
1363    pub fn ge(&self, other: &Tensor) -> Result<Tensor> {
1364        let mut handle: FlodlTensor = ptr::null_mut();
1365        let err = unsafe { ffi::flodl_ge_tensor(self.handle, other.handle, &mut handle) };
1366        check_err(err)?;
1367        Ok(Tensor::from_raw(handle))
1368    }
1369
1370    /// Element-wise less-than-or-equal (returns float mask: 0.0 or 1.0).
1371    pub fn le(&self, other: &Tensor) -> Result<Tensor> {
1372        let mut handle: FlodlTensor = ptr::null_mut();
1373        let err = unsafe { ffi::flodl_le_tensor(self.handle, other.handle, &mut handle) };
1374        check_err(err)?;
1375        Ok(Tensor::from_raw(handle))
1376    }
1377
1378    /// Element-wise equality. Returns a mask (0.0 or 1.0) in the input's
1379    /// dtype for float inputs, or Float32 for integer/bool inputs.
1380    pub fn eq_tensor(&self, other: &Tensor) -> Result<Tensor> {
1381        let mut handle: FlodlTensor = ptr::null_mut();
1382        let err = unsafe { ffi::flodl_eq_tensor(self.handle, other.handle, &mut handle) };
1383        check_err(err)?;
1384        Ok(Tensor::from_raw(handle))
1385    }
1386
1387    /// Element-wise not-equal. Returns a mask (0.0 or 1.0) in the input's
1388    /// dtype for float inputs, or Float32 for integer/bool inputs.
1389    pub fn ne_tensor(&self, other: &Tensor) -> Result<Tensor> {
1390        let mut handle: FlodlTensor = ptr::null_mut();
1391        let err = unsafe { ffi::flodl_ne_tensor(self.handle, other.handle, &mut handle) };
1392        check_err(err)?;
1393        Ok(Tensor::from_raw(handle))
1394    }
1395
1396    // --- Additional reductions ---
1397
1398    /// Argmin along a dimension.
1399    pub fn argmin(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
1400        let mut handle: FlodlTensor = ptr::null_mut();
1401        let err = unsafe { ffi::flodl_argmin(self.handle, dim, keepdim as i32, &mut handle) };
1402        check_err(err)?;
1403        Ok(Tensor::from_raw(handle))
1404    }
1405
1406    /// Variance of all elements (Bessel-corrected).
1407    pub fn var(&self) -> Result<Tensor> {
1408        let mut handle: FlodlTensor = ptr::null_mut();
1409        let err = unsafe { ffi::flodl_var(self.handle, &mut handle) };
1410        check_err(err)?;
1411        Ok(Tensor::from_raw(handle))
1412    }
1413
1414    /// Standard deviation of all elements (Bessel-corrected).
1415    #[allow(clippy::should_implement_trait)]
1416    pub fn std(&self) -> Result<Tensor> {
1417        let mut handle: FlodlTensor = ptr::null_mut();
1418        let err = unsafe { ffi::flodl_std_op(self.handle, &mut handle) };
1419        check_err(err)?;
1420        Ok(Tensor::from_raw(handle))
1421    }
1422
1423    /// Variance along a dimension (Bessel-corrected).
1424    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
1425        let mut handle: FlodlTensor = ptr::null_mut();
1426        let err = unsafe { ffi::flodl_var_dim(self.handle, dim, keepdim as i32, &mut handle) };
1427        check_err(err)?;
1428        Ok(Tensor::from_raw(handle))
1429    }
1430
1431    /// Standard deviation along a dimension (Bessel-corrected).
1432    pub fn std_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
1433        let mut handle: FlodlTensor = ptr::null_mut();
1434        let err = unsafe { ffi::flodl_std_dim(self.handle, dim, keepdim as i32, &mut handle) };
1435        check_err(err)?;
1436        Ok(Tensor::from_raw(handle))
1437    }
1438
1439    // --- Element-wise math (trig, rounding, sign) ---
1440
1441    /// Element-wise sine.
1442    pub fn sin(&self) -> Result<Tensor> {
1443        let mut handle: FlodlTensor = ptr::null_mut();
1444        let err = unsafe { ffi::flodl_sin(self.handle, &mut handle) };
1445        check_err(err)?;
1446        Ok(Tensor::from_raw(handle))
1447    }
1448
1449    /// Element-wise cosine.
1450    pub fn cos(&self) -> Result<Tensor> {
1451        let mut handle: FlodlTensor = ptr::null_mut();
1452        let err = unsafe { ffi::flodl_cos(self.handle, &mut handle) };
1453        check_err(err)?;
1454        Ok(Tensor::from_raw(handle))
1455    }
1456
1457    /// Element-wise sign (-1, 0, or +1).
1458    pub fn sign(&self) -> Result<Tensor> {
1459        let mut handle: FlodlTensor = ptr::null_mut();
1460        let err = unsafe { ffi::flodl_sign(self.handle, &mut handle) };
1461        check_err(err)?;
1462        Ok(Tensor::from_raw(handle))
1463    }
1464
1465    /// Element-wise floor.
1466    pub fn floor(&self) -> Result<Tensor> {
1467        let mut handle: FlodlTensor = ptr::null_mut();
1468        let err = unsafe { ffi::flodl_floor(self.handle, &mut handle) };
1469        check_err(err)?;
1470        Ok(Tensor::from_raw(handle))
1471    }
1472
1473    /// Element-wise ceiling.
1474    pub fn ceil(&self) -> Result<Tensor> {
1475        let mut handle: FlodlTensor = ptr::null_mut();
1476        let err = unsafe { ffi::flodl_ceil(self.handle, &mut handle) };
1477        check_err(err)?;
1478        Ok(Tensor::from_raw(handle))
1479    }
1480
1481    /// Element-wise rounding to nearest integer.
1482    pub fn round(&self) -> Result<Tensor> {
1483        let mut handle: FlodlTensor = ptr::null_mut();
1484        let err = unsafe { ffi::flodl_round(self.handle, &mut handle) };
1485        check_err(err)?;
1486        Ok(Tensor::from_raw(handle))
1487    }
1488
1489    /// Element-wise reciprocal (1/x).
1490    pub fn reciprocal(&self) -> Result<Tensor> {
1491        let mut handle: FlodlTensor = ptr::null_mut();
1492        let err = unsafe { ffi::flodl_reciprocal(self.handle, &mut handle) };
1493        check_err(err)?;
1494        Ok(Tensor::from_raw(handle))
1495    }
1496
1497    // --- Advanced indexing ---
1498
1499    /// Gather values along a dimension using an index tensor.
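    /// For `dim = 1`, `out[i][j] = self[i][index[i][j]]`. A small sketch
    /// (shapes and values illustrative):
    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[10.0, 20.0, 30.0, 40.0], &[2, 2], Device::CPU)?;
    /// let idx = Tensor::from_i64(&[1, 0], &[2, 1], Device::CPU)?;
    /// let picked = t.gather(1, &idx)?; // [[20.0], [30.0]]
    /// ```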
1500    pub fn gather(&self, dim: i32, index: &Tensor) -> Result<Tensor> {
1501        let mut handle: FlodlTensor = ptr::null_mut();
1502        let err = unsafe {
1503            ffi::flodl_gather(self.handle, dim, index.handle, &mut handle)
1504        };
1505        check_err(err)?;
1506        Ok(Tensor::from_raw(handle))
1507    }
1508
1509    /// Scatter-add: accumulate src into self at index positions along dim.
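    /// For `dim = 1`, `out[i][index[i][j]] += src[i][j]`, starting from a
    /// copy of `self`. Sketch (values illustrative; `opts` as elsewhere):
    ///
    /// ```ignore
    /// let base = Tensor::zeros(&[1, 3], opts)?;
    /// let idx = Tensor::from_i64(&[0, 2], &[1, 2], Device::CPU)?;
    /// let src = Tensor::from_f32(&[1.0, 2.0], &[1, 2], Device::CPU)?;
    /// let out = base.scatter_add(1, &idx, &src)?; // [[1.0, 0.0, 2.0]]
    /// ```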
1510    pub fn scatter_add(&self, dim: i32, index: &Tensor, src: &Tensor) -> Result<Tensor> {
1511        let mut handle: FlodlTensor = ptr::null_mut();
1512        let err = unsafe {
1513            ffi::flodl_scatter_add(self.handle, dim, index.handle, src.handle, &mut handle)
1514        };
1515        check_err(err)?;
1516        Ok(Tensor::from_raw(handle))
1517    }
1518
1519    // --- Sorting ---
1520
1521    /// Top-k values and indices along a dimension. Returns (values, indices).
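    /// Sketch (values illustrative):
    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[3.0, 1.0, 4.0], &[3], Device::CPU)?;
    /// let (vals, idxs) = t.topk(2, 0, true, true)?;
    /// // vals = [4.0, 3.0], idxs = [2, 0]
    /// ```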
1522    pub fn topk(&self, k: i64, dim: i32, largest: bool, sorted: bool) -> Result<(Tensor, Tensor)> {
1523        let mut values: FlodlTensor = ptr::null_mut();
1524        let mut indices: FlodlTensor = ptr::null_mut();
1525        let err = unsafe {
1526            ffi::flodl_topk(
1527                self.handle, k, dim, largest as i32, sorted as i32,
1528                &mut values, &mut indices,
1529            )
1530        };
1531        check_err(err)?;
1532        Ok((Tensor::from_raw(values), Tensor::from_raw(indices)))
1533    }
1534
1535    /// Sort along a dimension. Returns (sorted_values, indices).
1536    pub fn sort(&self, dim: i32, descending: bool) -> Result<(Tensor, Tensor)> {
1537        let mut values: FlodlTensor = ptr::null_mut();
1538        let mut indices: FlodlTensor = ptr::null_mut();
1539        let err = unsafe {
1540            ffi::flodl_sort(self.handle, dim, descending as i32, &mut values, &mut indices)
1541        };
1542        check_err(err)?;
1543        Ok((Tensor::from_raw(values), Tensor::from_raw(indices)))
1544    }
1545
1546    // --- Tensor creation (additional) ---
1547
1548    /// Create an identity matrix of size n x n.
1549    pub fn eye(n: i64, opts: TensorOptions) -> Result<Self> {
1550        let mut handle: FlodlTensor = ptr::null_mut();
1551        let (dt, di) = opts.device.to_ffi();
1552        let err = unsafe {
1553            ffi::flodl_eye(n, opts.dtype as i32, dt, di, &mut handle)
1554        };
1555        check_err(err)?;
1556        Ok(Self::from_raw(handle))
1557    }
1558
1559    /// Create a tensor filled with a scalar value.
1560    pub fn full(shape: &[i64], value: f64, opts: TensorOptions) -> Result<Self> {
1561        let mut shape = shape.to_vec();
1562        let mut handle: FlodlTensor = ptr::null_mut();
1563        let (dt, di) = opts.device.to_ffi();
1564        let err = unsafe {
1565            ffi::flodl_full(
1566                shape.as_mut_ptr(), shape.len() as i32, value,
1567                opts.dtype as i32, dt, di, &mut handle,
1568            )
1569        };
1570        check_err(err)?;
1571        Ok(Self::from_raw(handle))
1572    }
1573
1574    // --- Shape operations (additional) ---
1575
1576    /// Split tensor into batches of `batch_size` along dimension 0.
1577    /// The last batch may be smaller if the tensor size isn't evenly divisible.
1578    ///
1579    /// ```ignore
1580    /// let data = Tensor::randn(&[100, 4], opts)?;
1581    /// for batch in data.batches(32)? {
1582    ///     let x = Variable::new(batch, false);
1583    ///     // ...
1584    /// }
    /// ```
    ///
    /// Panics if `batch_size` is not positive (a non-positive size would
    /// otherwise loop forever).
    pub fn batches(&self, batch_size: i64) -> Result<Vec<Tensor>> {
        assert!(batch_size > 0, "batch_size must be positive");
        let n = self.shape()[0];
        let mut result = Vec::new();
        let mut start = 0i64;
        while start < n {
            let len = batch_size.min(n - start);
            result.push(self.narrow(0, start, len)?);
            start += len;
        }
        Ok(result)
    }
1597
1598    /// Split tensor into chunks along a dimension.
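    /// Sketch (matches the unit test below):
    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], Device::CPU)?;
    /// let parts = t.chunk(3, 0)?; // [1, 2], [3, 4], [5, 6]
    /// ```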
1599    pub fn chunk(&self, chunks: i32, dim: i32) -> Result<Vec<Tensor>> {
1600        let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
1601        let mut count: i32 = 0;
1602        let err = unsafe {
1603            ffi::flodl_chunk(self.handle, chunks, dim, &mut results_ptr, &mut count)
1604        };
1605        check_err(err)?;
1606        let mut tensors = Vec::with_capacity(count as usize);
1607        for i in 0..count as usize {
1608            let handle = unsafe { *results_ptr.add(i) };
1609            tensors.push(Tensor::from_raw(handle));
1610        }
1611        if !results_ptr.is_null() {
1612            // Free the C-allocated array (tensors are now owned by Rust).
1613            // flodl_free_string is just free() — safe for any malloc'd pointer.
1614            unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
1615        }
1616        Ok(tensors)
1617    }
1618
1619    /// Repeat the tensor along each dimension.
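    /// Sketch (values illustrative):
    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[1.0, 2.0], &[2], Device::CPU)?;
    /// let r = t.repeat(&[3])?; // [1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
    /// ```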
1620    pub fn repeat(&self, repeats: &[i64]) -> Result<Tensor> {
1621        let mut repeats = repeats.to_vec();
1622        let mut handle: FlodlTensor = ptr::null_mut();
1623        let err = unsafe {
1624            ffi::flodl_repeat(self.handle, repeats.as_mut_ptr(), repeats.len() as i32, &mut handle)
1625        };
1626        check_err(err)?;
1627        Ok(Tensor::from_raw(handle))
1628    }
1629
1630    /// Constant-value padding. Padding format matches PyTorch: [left, right, top, bottom, ...].
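    /// Sketch: one zero on the left, two on the right of a 1-D tensor:
    ///
    /// ```ignore
    /// let t = Tensor::from_f32(&[1.0, 2.0], &[2], Device::CPU)?;
    /// let p = t.pad(&[1, 2], 0.0)?; // [0.0, 1.0, 2.0, 0.0, 0.0]
    /// ```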
1631    pub fn pad(&self, padding: &[i64], value: f64) -> Result<Tensor> {
1632        let mut padding = padding.to_vec();
1633        let mut handle: FlodlTensor = ptr::null_mut();
1634        let err = unsafe {
1635            ffi::flodl_pad(
1636                self.handle, padding.as_mut_ptr(), padding.len() as i32,
1637                value, &mut handle,
1638            )
1639        };
1640        check_err(err)?;
1641        Ok(Tensor::from_raw(handle))
1642    }
1643
1644    /// Insert multiple dimensions of size 1.
    /// Dims are sorted ascending and applied sequentially. Returns an
    /// unchanged clone if `dims` is empty.
    pub fn unsqueeze_many(&self, dims: &[i32]) -> Result<Tensor> {
        let mut sorted = dims.to_vec();
        sorted.sort_unstable();
        // Guard the empty case: indexing sorted[0] would panic.
        let Some((&first, rest)) = sorted.split_first() else {
            return Ok(self.clone());
        };
        let mut t = self.unsqueeze(first)?;
        for &d in rest {
            t = t.unsqueeze(d)?;
        }
        Ok(t)
    }
1655
1656    /// Compute meshgrid from a slice of 1-D tensors (always "ij" indexing).
1657    pub fn meshgrid(tensors: &[&Tensor]) -> Result<Vec<Tensor>> {
1658        let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
1659        let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
1660        let mut count: i32 = 0;
1661        let err = unsafe {
1662            ffi::flodl_meshgrid(
1663                handles.as_mut_ptr(), handles.len() as i32,
1664                &mut results_ptr, &mut count,
1665            )
1666        };
1667        check_err(err)?;
1668        let mut out = Vec::with_capacity(count as usize);
1669        for i in 0..count as usize {
1670            let handle = unsafe { *results_ptr.add(i) };
1671            out.push(Tensor::from_raw(handle));
1672        }
1673        if !results_ptr.is_null() {
1674            unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
1675        }
1676        Ok(out)
1677    }
1678
1679    /// Pairwise L2 distance between rows of two batched matrices.
1680    /// Input shapes: `[B, P, D]` and `[B, R, D]` -> output `[B, P, R]`.
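    /// Sketch: the distance between points `[0, 0]` and `[3, 4]` is 5:
    ///
    /// ```ignore
    /// let x = Tensor::from_f32(&[0.0, 0.0], &[1, 1, 2], Device::CPU)?;
    /// let y = Tensor::from_f32(&[3.0, 4.0], &[1, 1, 2], Device::CPU)?;
    /// let d = x.cdist(&y)?; // shape [1, 1, 1], value 5.0
    /// ```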
1681    pub fn cdist(&self, other: &Tensor) -> Result<Tensor> {
1682        self.cdist_p(other, 2.0)
1683    }
1684
1685    /// Pairwise distance with custom p-norm.
1686    pub fn cdist_p(&self, other: &Tensor, p: f64) -> Result<Tensor> {
1687        let mut handle: FlodlTensor = ptr::null_mut();
1688        let err = unsafe { ffi::flodl_cdist(self.handle, other.handle, p, &mut handle) };
1689        check_err(err)?;
1690        Ok(Tensor::from_raw(handle))
1691    }
1692
1693    // --- Device ---
1694
1695    /// Move this tensor to a different device (CPU or CUDA).
1696    /// Returns a new tensor; the original is unchanged.
1697    ///
1698    /// ```ignore
1699    /// let gpu = t.to_device(Device::CUDA(0))?;
1700    /// ```
1701    pub fn to_device(&self, device: Device) -> Result<Tensor> {
1702        let mut handle: FlodlTensor = ptr::null_mut();
1703        let (dt, di) = device.to_ffi();
1704        let err = unsafe { ffi::flodl_to_device(self.handle, dt, di, &mut handle) };
1705        check_err(err)?;
1706        Ok(Tensor::from_raw(handle))
1707    }
1708
1709    /// Move this tensor to the same device as `other`.
1710    /// No-op (returns a clone) if both are already on the same device.
1711    ///
1712    /// ```ignore
1713    /// let x = x.to_device_of(&weights)?;  // ensure same device
1714    /// ```
1715    pub fn to_device_of(&self, other: &Tensor) -> Result<Tensor> {
1716        let target = other.device();
1717        if self.device() == target {
1718            return Ok(self.clone());
1719        }
1720        self.to_device(target)
1721    }
1722
1723    // --- Autograd ---
1724
1725    /// Set requires_grad on this tensor. Returns a new tensor that shares
1726    /// storage but has the grad flag set. This enables libtorch's native
1727    /// autograd tracking for all subsequent operations.
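    /// A minimal end-to-end sketch (loss = w², so dloss/dw = 2w):
    ///
    /// ```ignore
    /// let w = Tensor::from_f32(&[2.0], &[1], Device::CPU)?
    ///     .set_requires_grad(true)?;
    /// let loss = w.mul(&w)?.sum()?;
    /// loss.backward()?;
    /// let g = w.grad().unwrap(); // [4.0]
    /// ```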
1728    pub fn set_requires_grad(&self, requires_grad: bool) -> Result<Tensor> {
1729        let mut handle: FlodlTensor = ptr::null_mut();
1730        let err = unsafe {
1731            ffi::flodl_set_requires_grad(self.handle, requires_grad as i32, &mut handle)
1732        };
1733        check_err(err)?;
1734        Ok(Tensor::from_raw(handle))
1735    }
1736
1737    /// Check whether this tensor requires gradient computation.
1738    pub fn requires_grad(&self) -> bool {
1739        unsafe { ffi::flodl_requires_grad(self.handle) != 0 }
1740    }
1741
1742    /// Run backward pass from this scalar tensor. Populates .grad() on
1743    /// all leaf tensors in the computation graph.
1744    pub fn backward(&self) -> Result<()> {
1745        let err = unsafe { ffi::flodl_backward(self.handle) };
1746        check_err(err)
1747    }
1748
1749    /// Get the accumulated gradient for this tensor, if any.
1750    /// Returns None if no gradient has been computed.
1751    pub fn grad(&self) -> Option<Tensor> {
1752        let mut handle: FlodlTensor = ptr::null_mut();
1753        let err = unsafe { ffi::flodl_grad(self.handle, &mut handle) };
1754        if !err.is_null() {
1755            unsafe { ffi::flodl_free_string(err) };
1756            return None;
1757        }
1758        if handle.is_null() {
1759            None
1760        } else {
1761            Some(Tensor::from_raw(handle))
1762        }
1763    }
1764
1765    /// Replace the gradient tensor (for gradient clipping / unscaling).
1766    pub fn set_grad(&self, grad: &Tensor) -> Result<()> {
1767        let err = unsafe { ffi::flodl_set_grad(self.handle, grad.handle) };
1768        check_err(err)
1769    }
1770
1771    /// Zero out the accumulated gradient.
1772    pub fn zero_grad(&self) -> Result<()> {
1773        let err = unsafe { ffi::flodl_zero_grad(self.handle) };
1774        check_err(err)
1775    }
1776
1777    /// Null out the gradient pointer instead of zeroing the data.
1778    /// No CUDA kernel — just resets the grad tensor to undefined.
    /// This matches PyTorch's `zero_grad(set_to_none=True)`, available
    /// since 1.7 and the default since 2.0.
1780    pub fn zero_grad_set_to_none(&self) {
1781        unsafe { ffi::flodl_zero_grad_set_to_none(self.handle) }
1782    }
1783
1784    /// Fused clip_grad_norm: compute global L2 norm across all param grads
1785    /// and scale in-place if it exceeds max_norm. Single C++ call.
1786    /// Returns the original total norm before clipping.
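    /// Typical use between `backward()` and the optimizer step
    /// (`params` is an illustrative slice of parameter tensors):
    ///
    /// ```ignore
    /// let norm = Tensor::clip_grad_norm_fused(&params, 1.0)?;
    /// // `norm` is the pre-clip global L2 norm — useful for spike logging.
    /// ```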
1787    pub fn clip_grad_norm_fused(params: &[Tensor], max_norm: f64) -> Result<f64> {
1788        if params.is_empty() {
1789            return Ok(0.0);
1790        }
1791        let mut handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
1792        let mut total_norm: f64 = 0.0;
1793        let err = unsafe {
1794            ffi::flodl_clip_grad_norm(
1795                handles.as_mut_ptr(),
1796                handles.len() as i32,
1797                max_norm,
1798                &mut total_norm,
1799            )
1800        };
1801        check_err(err)?;
1802        Ok(total_norm)
1803    }
1804
1805    /// Whether this tensor is a leaf in the autograd graph.
1806    /// A tensor is a leaf if it was created by the user (not by an op)
1807    /// or if it doesn't require grad.
1808    pub fn is_leaf(&self) -> bool {
1809        unsafe { ffi::flodl_is_leaf(self.handle) != 0 }
1810    }
1811
1812    /// Count unique autograd nodes reachable from this tensor's grad_fn.
1813    /// Returns 0 for leaf tensors or tensors without gradient tracking.
1814    /// This is the number of backward operations libtorch will execute.
1815    pub fn autograd_node_count(&self) -> i64 {
1816        unsafe { ffi::flodl_autograd_node_count(self.handle) }
1817    }
1818
1819    /// Detach from the computation graph. Returns a new tensor that shares
1820    /// storage but has no autograd history.
1821    pub fn detach(&self) -> Result<Tensor> {
1822        let mut handle: FlodlTensor = ptr::null_mut();
1823        let err = unsafe { ffi::flodl_detach(self.handle, &mut handle) };
1824        check_err(err)?;
1825        Ok(Tensor::from_raw(handle))
1826    }
1827
1828    /// In-place detach: sever the grad_fn chain on this tensor without
1829    /// allocating a new handle. After this call the tensor's autograd_meta
1830    /// no longer references any C++ Node objects, allowing the autograd
1831    /// graph to be freed immediately rather than when the tensor is dropped.
1832    pub fn detach_(&self) -> Result<()> {
1833        let err = unsafe { ffi::flodl_detach_(self.handle) };
1834        check_err(err)
1835    }
1836
1837    // --- In-place operations ---
1838
1839    /// In-place add: self += other
1840    pub fn add_(&self, other: &Tensor) -> Result<()> {
1841        let err = unsafe { ffi::flodl_add_(self.handle, other.handle) };
1842        check_err(err)
1843    }
1844
1845    /// In-place subtract: self -= other
1846    pub fn sub_(&self, other: &Tensor) -> Result<()> {
1847        let err = unsafe { ffi::flodl_sub_(self.handle, other.handle) };
1848        check_err(err)
1849    }
1850
1851    /// In-place scalar multiply: self *= scalar
1852    pub fn mul_scalar_(&self, scalar: f64) -> Result<()> {
1853        let err = unsafe { ffi::flodl_mul_scalar_(self.handle, scalar) };
1854        check_err(err)
1855    }
1856
1857    /// In-place scalar add: self += scalar
1858    pub fn add_scalar_(&self, scalar: f64) -> Result<()> {
1859        let err = unsafe { ffi::flodl_add_scalar_(self.handle, scalar) };
1860        check_err(err)
1861    }
1862
1863    /// In-place zero: self = 0
1864    pub fn zero_(&self) -> Result<()> {
1865        let err = unsafe { ffi::flodl_zero_(self.handle) };
1866        check_err(err)
1867    }
1868
1869    /// Fused Adam/AdamW step: updates param, m, and v tensors in-place.
1871    ///
1872    /// Performs the full Adam update in a single FFI call (~5 kernel launches
1873    /// instead of ~16), eliminating temporary tensor allocations.
1874    ///
1875    /// - `self` — parameter tensor (updated in-place)
1876    /// - `grad` — gradient (read-only)
1877    /// - `m`, `v` — moment buffers (updated in-place)
1878    /// - `weight_decay` — 0.0 for Adam, >0 for AdamW (decoupled)
1879    /// - `step` — timestep for bias correction
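    ///
    /// Sketch of one step (`param`/`grad` illustrative; hyperparameters
    /// typical defaults; `m`/`v` persist across steps):
    ///
    /// ```ignore
    /// let m = Tensor::zeros_like(&param)?;
    /// let v = Tensor::zeros_like(&param)?;
    /// param.adam_step(&grad, &m, &v, 1e-3, 0.9, 0.999, 1e-8, 0.01, 1)?;
    /// ```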
    #[allow(clippy::too_many_arguments)]
    pub fn adam_step(
1881        &self, grad: &Tensor, m: &Tensor, v: &Tensor,
1882        lr: f64, beta1: f64, beta2: f64, eps: f64,
1883        weight_decay: f64, step: i64,
1884    ) -> Result<()> {
1885        let err = unsafe {
1886            ffi::flodl_adam_step(
1887                self.handle, grad.handle, m.handle, v.handle,
1888                lr, beta1, beta2, eps, weight_decay, step,
1889            )
1890        };
1891        check_err(err)
1892    }
1893
1894    // --- Batched Adam step ---
1895
1896    /// Perform Adam/AdamW update on all params in one C++ loop.
1897    /// Eliminates per-param FFI overhead. `lrs[i]` supports per-group LR.
1898    #[allow(clippy::too_many_arguments)]
1899    pub fn adam_step_batched(
1900        params: &[Tensor], grads: &[Tensor], ms: &[Tensor], vs: &[Tensor],
1901        lrs: &mut [f64], beta1: f64, beta2: f64, eps: f64,
1902        weight_decay: f64, step: i64,
1903    ) -> Result<()> {
1904        let count = params.len() as i32;
1905        let mut p_handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
1906        let mut g_handles: Vec<FlodlTensor> = grads.iter().map(|t| t.handle).collect();
1907        let mut m_handles: Vec<FlodlTensor> = ms.iter().map(|t| t.handle).collect();
1908        let mut v_handles: Vec<FlodlTensor> = vs.iter().map(|t| t.handle).collect();
1909        let err = unsafe {
1910            ffi::flodl_adam_step_batched(
1911                p_handles.as_mut_ptr(), g_handles.as_mut_ptr(),
1912                m_handles.as_mut_ptr(), v_handles.as_mut_ptr(),
1913                lrs.as_mut_ptr(), count,
1914                beta1, beta2, eps, weight_decay, step,
1915            )
1916        };
1917        check_err(err)
1918    }
1919
1920    // --- Pinned memory ---
1921
1922    /// Copy this CPU tensor into page-locked (pinned) memory.
1923    ///
1924    /// Pinned memory enables async CPU→GPU transfers via `cudaMemcpyAsync`.
1925    /// Only valid for CPU tensors. Returns a new tensor in pinned memory.
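    /// Sketch of a staged host-to-device upload (`data` illustrative):
    ///
    /// ```ignore
    /// let staged = Tensor::from_f32(&data, &[data.len() as i64], Device::CPU)?
    ///     .pin_memory()?;
    /// let gpu = staged.to_device(Device::CUDA(0))?;
    /// ```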
1926    pub fn pin_memory(&self) -> Result<Tensor> {
1927        let mut handle: FlodlTensor = ptr::null_mut();
1928        let err = unsafe { ffi::flodl_pin_memory(self.handle, &mut handle) };
1929        check_err(err)?;
1930        Ok(Tensor::from_raw(handle))
1931    }
1932
1933    /// Returns true if this tensor is stored in pinned (page-locked) memory.
1934    pub fn is_pinned(&self) -> bool {
1935        unsafe { ffi::flodl_is_pinned(self.handle) != 0 }
1936    }
1937}
1938
1939impl fmt::Debug for Tensor {
1940    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1941        write!(
1942            f,
1943            "Tensor({:?}, {:?}, {:?})",
1944            self.shape(),
1945            self.dtype(),
1946            self.device()
1947        )
1948    }
1949}
1950
1951/// Returns true if CUDA is available.
1952///
1953/// On Linux, this also ensures CUDA libraries are loaded (they can be
1954/// dropped by the linker's `--as-needed` flag since no Rust code
1955/// directly references symbols in `libtorch_cuda.so`).
1956pub fn cuda_available() -> bool {
1957    // flodl_force_cuda_link references c10::cuda::device_count(),
1958    // creating a real symbol dependency on c10_cuda.so. This prevents
1959    // --as-needed from dropping CUDA libs. The call is cheap (no-op on
1960    // non-CUDA builds since the symbol resolves to a stub returning 0).
1961    unsafe { let _ = ffi::flodl_force_cuda_link(); }
1962    unsafe { ffi::flodl_cuda_is_available() != 0 }
1963}
1964
1965/// Returns the number of CUDA devices.
1966pub fn cuda_device_count() -> i32 {
1967    unsafe { ffi::flodl_cuda_device_count() }
1968}
1969
1970/// Query CUDA memory usage for a specific device.
1971/// Returns `(used_bytes, total_bytes)` or an error if CUDA is not available.
1972pub fn cuda_memory_info_idx(device_index: i32) -> Result<(u64, u64)> {
1973    let mut used: u64 = 0;
1974    let mut total: u64 = 0;
1975    check_err(unsafe { ffi::flodl_cuda_mem_info(device_index, &mut used, &mut total) })?;
1976    Ok((used, total))
1977}
1978
1979/// Query CUDA memory usage for device 0.
1980/// Returns `(used_bytes, total_bytes)` or an error if CUDA is not available.
1981pub fn cuda_memory_info() -> Result<(u64, u64)> {
1982    cuda_memory_info_idx(0)
1983}
1984
1985/// Query bytes currently handed out by the CUDA caching allocator on a specific device.
1986///
1987/// This is the Rust equivalent of `torch.cuda.memory_allocated()`. It can exceed
1988/// physical VRAM when unified memory spills to host RAM.
1989pub fn cuda_allocated_bytes_idx(device_index: i32) -> Result<u64> {
1990    let mut allocated: u64 = 0;
1991    check_err(unsafe { ffi::flodl_cuda_alloc_bytes(device_index, &mut allocated) })?;
1992    Ok(allocated)
1993}
1994
1995/// Query bytes currently handed out by the CUDA caching allocator on device 0.
1996pub fn cuda_allocated_bytes() -> Result<u64> {
1997    cuda_allocated_bytes_idx(0)
1998}
1999
2000/// Query GPU utilization percentage (0-100) via NVML.
2001/// Returns `None` if NVML is not available or the query fails.
2002pub fn cuda_utilization() -> Option<u32> {
2003    cuda_utilization_idx(0)
2004}
2005
2006/// Query GPU utilization percentage for a specific device (0-100) via NVML.
2007pub fn cuda_utilization_idx(device_index: i32) -> Option<u32> {
2008    let val = unsafe { ffi::flodl_cuda_utilization(device_index) };
2009    if val >= 0 { Some(val as u32) } else { None }
2010}
2011
2012/// Set the current CUDA device.
2013pub fn set_current_cuda_device(device_index: u8) {
2014    unsafe { ffi::flodl_set_current_device(device_index as i32) };
2015}
2016
2017/// Get the current CUDA device index.
2018pub fn current_cuda_device() -> u8 {
2019    unsafe { ffi::flodl_get_current_device() as u8 }
2020}
2021
2022/// Synchronize a CUDA device (wait for all pending work to complete).
2023pub fn cuda_synchronize(device_index: u8) {
2024    unsafe { ffi::flodl_cuda_synchronize(device_index as i32) };
2025}
2026
2027/// Returns the GPU device name for the given index (e.g. "NVIDIA GeForce GTX 1060 6GB").
2028pub fn cuda_device_name_idx(device: i32) -> Option<String> {
2029    let mut buf = [0i8; 256];
2030    let err = unsafe { ffi::flodl_cuda_device_name(device, buf.as_mut_ptr(), 256) };
2031    if err.is_null() {
2032        let name = unsafe { CStr::from_ptr(buf.as_ptr()) }
2033            .to_string_lossy()
2034            .into_owned();
2035        Some(name)
2036    } else {
2037        unsafe { ffi::flodl_free_string(err) };
2038        None
2039    }
2040}
2041
2042/// Returns the GPU device name for device 0 (e.g. "NVIDIA GeForce GTX 1060 6GB").
2043pub fn cuda_device_name() -> Option<String> {
2044    cuda_device_name_idx(0)
2045}
2046
2047/// Information about a CUDA device.
2048#[derive(Debug, Clone)]
2049pub struct DeviceInfo {
2050    /// Device index (0-based).
2051    pub index: u8,
2052    /// Device name (e.g. "NVIDIA GeForce GTX 1060 6GB").
2053    pub name: String,
2054    /// Total device memory in bytes.
2055    pub total_memory: u64,
2056}
2057
2058/// Enumerate all available CUDA devices.
2059pub fn cuda_devices() -> Vec<DeviceInfo> {
2060    let n = cuda_device_count();
2061    (0..n).filter_map(|i| {
2062        let name = cuda_device_name_idx(i)?;
2063        let total_memory = cuda_memory_info_idx(i).map(|(_, t)| t).unwrap_or(0);
2064        Some(DeviceInfo { index: i as u8, name, total_memory })
2065    }).collect()
2066}
2067
2068/// One-line hardware summary for dashboard headers.
2069///
2070/// Returns something like:
2071/// `"CPU: AMD Ryzen 9 5900X (64GB) | GPU: NVIDIA GeForce GTX 1060 (6GB)"`
2072pub fn hardware_summary() -> String {
2073    let cpu = cpu_model_name().unwrap_or_else(|| "Unknown CPU".into());
2074    let threads = cpu_thread_count();
2075    let ram = total_ram_gb();
2076    let mut s = format!("{} ({} threads, {}GB)", cpu, threads, ram);
2077
2078    if cuda_available() {
2079        let n = cuda_device_count();
2080        for i in 0..n {
2081            if let Some(gpu) = cuda_device_name_idx(i) {
2082                let vram_str = cuda_memory_info_idx(i)
2083                    .map(|(_, total)| format!(" ({}GB)", total / (1024 * 1024 * 1024)))
2084                    .unwrap_or_default();
2085                let _ = std::fmt::Write::write_fmt(&mut s, format_args!(
2086                    " | {}{}", gpu, vram_str
2087                ));
2088            }
2089        }
2090    }
2091    s
2092}
2093
2094/// Count logical CPU threads from /proc/cpuinfo (Linux).
2095fn cpu_thread_count() -> usize {
2096    std::fs::read_to_string("/proc/cpuinfo")
2097        .ok()
2098        .map(|s| s.lines().filter(|l| l.starts_with("processor")).count())
2099        .unwrap_or(1)
2100}
2101
2102/// Read CPU model name from /proc/cpuinfo (Linux).
fn cpu_model_name() -> Option<String> {
    let info = std::fs::read_to_string("/proc/cpuinfo").ok()?;
    // Avoids let-chains, so this builds on older editions/toolchains too.
    info.lines()
        .find(|l| l.starts_with("model name"))
        .and_then(|l| l.split(':').nth(1))
        .map(|v| v.trim().to_string())
}
2112
2113/// Total physical RAM in GB (Linux).
2114fn total_ram_gb() -> u64 {
2115    std::fs::read_to_string("/proc/meminfo")
2116        .ok()
2117        .and_then(|s| {
2118            for line in s.lines() {
2119                if line.starts_with("MemTotal:") {
2120                    let kb: u64 = line.split_whitespace().nth(1)?.parse().ok()?;
2121                    return Some(kb / (1024 * 1024));
2122                }
2123            }
2124            None
2125        })
2126        .unwrap_or(0)
2127}
2128
2129/// Enable or disable cuDNN benchmark mode.
2130///
2131/// When enabled, cuDNN will benchmark multiple convolution algorithms
2132/// on the first call and cache the fastest. Benefits fixed-size workloads
2133/// (FBRL, fixed image dims) with 5-10% speedup. Can hurt dynamic-shape
2134/// workloads due to warmup cost. Off by default — users opt in.
2135pub fn set_cudnn_benchmark(enable: bool) {
2136    unsafe { ffi::flodl_set_cudnn_benchmark(enable as i32) }
2137}
2138
2139/// Ask glibc to return free memory to the OS (Linux only).
2140///
2141/// Returns `true` if memory was actually released. Useful for
2142/// distinguishing allocator fragmentation from real leaks:
2143/// if RSS drops after calling this, the growth was fragmentation.
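/// Sketch of a fragmentation check, using `rss_kb()` from this module:
///
/// ```ignore
/// let before = rss_kb();
/// let released = malloc_trim();
/// eprintln!("trim released={released}, rss {before} -> {} kB", rss_kb());
/// ```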
2144pub fn malloc_trim() -> bool {
2145    unsafe { ffi::flodl_malloc_trim() != 0 }
2146}
2147
2148/// Number of live C++ Tensor handles (created but not yet dropped).
2149/// If this grows over time during training, there is a handle leak.
2150/// If it stays stable but RSS grows, the leak is inside libtorch.
2151pub fn live_tensor_count() -> u64 {
2152    LIVE_TENSOR_COUNT.load(Ordering::Relaxed)
2153}
2154
/// Read current process RSS in kilobytes (Linux only; assumes 4 KiB pages,
/// since /proc/self/statm reports page counts).
/// Returns 0 on non-Linux or if /proc/self/statm is unreadable.
2157pub fn rss_kb() -> usize {
2158    std::fs::read_to_string("/proc/self/statm")
2159        .ok()
2160        .and_then(|s| s.split_whitespace().nth(1)?.parse::<usize>().ok())
2161        .map(|pages| pages * 4)
2162        .unwrap_or(0)
2163}
2164
2165/// Returns the device to use in tests: CUDA when compiled with `--features cuda`
2166/// and a GPU is available, CPU otherwise.
2167#[cfg(test)]
2168pub fn test_device() -> Device {
2169    use std::sync::Once;
2170    static PRINT: Once = Once::new();
2171    let dev = if cfg!(feature = "cuda") && cuda_available() { Device::CUDA(0) } else { Device::CPU };
2172    PRINT.call_once(|| eprintln!("\n*** flodl test device: {} ***\n", dev));
2173    dev
2174}
2175
2176/// Returns `TensorOptions` for tests (Float32 on `test_device()`).
2177#[cfg(test)]
2178pub fn test_opts() -> TensorOptions {
2179    TensorOptions { dtype: DType::Float32, device: test_device() }
2180}
2181
2182#[cfg(test)]
2183mod tests {
2184    use super::*;
2185
2186    #[test]
2187    fn test_zeros() {
2188        let t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
2189        assert_eq!(t.shape(), vec![2, 3]);
2190        assert_eq!(t.dtype(), DType::Float32);
2191        assert_eq!(t.device(), test_device());
2192        assert_eq!(t.numel(), 6);
2193
2194        let data = t.to_f32_vec().unwrap();
2195        assert_eq!(data, vec![0.0; 6]);
2196    }
2197
2198    #[test]
2199    fn test_from_f32() {
2200        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2201        assert_eq!(t.shape(), vec![3]);
2202        let data = t.to_f32_vec().unwrap();
2203        assert_eq!(data, vec![1.0, 2.0, 3.0]);
2204    }
2205
2206    #[test]
2207    fn test_add() {
2208        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2209        let b = Tensor::from_f32(&[4.0, 5.0, 6.0], &[3], test_device()).unwrap();
2210        let c = a.add(&b).unwrap();
2211        assert_eq!(c.to_f32_vec().unwrap(), vec![5.0, 7.0, 9.0]);
2212    }
2213
2214    #[test]
2215    fn test_matmul() {
2216        let a = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2217        let b = Tensor::from_f32(&[5.0, 6.0, 7.0, 8.0], &[2, 2], test_device()).unwrap();
2218        let c = a.matmul(&b).unwrap();
2219        assert_eq!(c.to_f32_vec().unwrap(), vec![19.0, 22.0, 43.0, 50.0]);
2220    }
2221
2222    #[test]
2223    fn test_chaining() {
2224        let a = Tensor::from_f32(&[1.0, -2.0, 3.0], &[3], test_device()).unwrap();
2225        let b = Tensor::from_f32(&[1.0, 1.0, 1.0], &[3], test_device()).unwrap();
2226        let result = a.add(&b).unwrap().relu().unwrap().sum().unwrap();
2227        // [1+1, -2+1, 3+1] = [2, -1, 4] -> relu -> [2, 0, 4] -> sum -> 6
2228        let val = result.item().unwrap();
2229        assert!((val - 6.0).abs() < 1e-5);
2230    }
2231
2232    #[test]
2233    fn test_drop_frees_memory() {
2234        // Create and immediately drop — verifies Drop doesn't crash.
2235        let _ = Tensor::zeros(&[1000, 1000], test_opts()).unwrap();
2236        // If Drop is broken, this would leak or crash.
2237    }
2238
2239    #[test]
2240    fn test_debug_format() {
2241        let t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
2242        let s = format!("{:?}", t);
2243        assert!(s.contains("[2, 3]"));
2244        assert!(s.contains("Float32"));
2245    }
2246
2247    #[test]
2248    fn test_div_scalar() {
2249        let t = Tensor::from_f32(&[6.0, 9.0], &[2], test_device()).unwrap();
2250        let r = t.div_scalar(3.0).unwrap();
2251        let data = r.to_f32_vec().unwrap();
2252        assert!((data[0] - 2.0).abs() < 1e-5);
2253        assert!((data[1] - 3.0).abs() < 1e-5);
2254    }
2255
2256    #[test]
2257    fn test_mean() {
2258        let t = Tensor::from_f32(&[2.0, 4.0, 6.0], &[3], test_device()).unwrap();
2259        let m = t.mean().unwrap();
2260        assert!((m.item().unwrap() - 4.0).abs() < 1e-5);
2261    }
2262
2263    #[test]
2264    fn test_flatten() {
2265        let t = Tensor::ones(&[2, 3, 4], test_opts()).unwrap();
2266        let f = t.flatten(1, 2).unwrap();
2267        assert_eq!(f.shape(), vec![2, 12]);
2268    }
2269
2270    #[test]
2271    fn test_stack() {
2272        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2273        let b = Tensor::from_f32(&[3.0, 4.0], &[2], test_device()).unwrap();
2274        let c = Tensor::from_f32(&[5.0, 6.0], &[2], test_device()).unwrap();
2275
2276        // Stack along dim 0: [3, 2]
2277        let s = Tensor::stack(&[&a, &b, &c], 0).unwrap();
2278        assert_eq!(s.shape(), vec![3, 2]);
2279        let data = s.to_f32_vec().unwrap();
2280        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2281
2282        // Stack along dim 1: [2, 3]
2283        let s1 = Tensor::stack(&[&a, &b, &c], 1).unwrap();
2284        assert_eq!(s1.shape(), vec![2, 3]);
2285        let data1 = s1.to_f32_vec().unwrap();
2286        assert_eq!(data1, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
2287    }
2288
2289    #[test]
2290    fn test_ones_from_f64_from_i64() {
2291        let o = Tensor::ones(&[2, 3], test_opts()).unwrap();
2292        assert_eq!(o.to_f32_vec().unwrap(), vec![1.0; 6]);
2293
2294        let f = Tensor::from_f64(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2295        assert_eq!(f.dtype(), DType::Float64);
2296        assert_eq!(f.to_f64_vec().unwrap(), vec![1.0, 2.0, 3.0]);
2297
2298        let i = Tensor::from_i64(&[10, 20, 30], &[3], test_device()).unwrap();
2299        assert_eq!(i.dtype(), DType::Int64);
2300        assert_eq!(i.to_i64_vec().unwrap(), vec![10, 20, 30]);
2301    }
2302
2303    #[test]
2304    fn test_sub_mul_div() {
2305        let a = Tensor::from_f32(&[6.0, 8.0], &[2], test_device()).unwrap();
2306        let b = Tensor::from_f32(&[2.0, 3.0], &[2], test_device()).unwrap();
2307        assert_eq!(a.sub(&b).unwrap().to_f32_vec().unwrap(), vec![4.0, 5.0]);
2308        assert_eq!(a.mul(&b).unwrap().to_f32_vec().unwrap(), vec![12.0, 24.0]);
2309        let d = a.div(&b).unwrap().to_f32_vec().unwrap();
2310        assert!((d[0] - 3.0).abs() < 1e-5);
2311        assert!((d[1] - 8.0 / 3.0).abs() < 1e-5);
2312    }
2313
2314    #[test]
2315    fn test_scalar_ops() {
2316        let t = Tensor::from_f32(&[2.0, 4.0], &[2], test_device()).unwrap();
2317        assert_eq!(t.add_scalar(1.0).unwrap().to_f32_vec().unwrap(), vec![3.0, 5.0]);
2318        assert_eq!(t.mul_scalar(3.0).unwrap().to_f32_vec().unwrap(), vec![6.0, 12.0]);
2319        assert_eq!(t.neg().unwrap().to_f32_vec().unwrap(), vec![-2.0, -4.0]);
2320    }
2321
2322    #[test]
2323    fn test_exp_log_sqrt_abs_pow() {
2324        let t = Tensor::from_f32(&[1.0, 4.0], &[2], test_device()).unwrap();
2325        let e = t.exp().unwrap().to_f32_vec().unwrap();
2326        assert!((e[0] - 1.0_f32.exp()).abs() < 1e-5);
2327
2328        let l = t.log().unwrap().to_f32_vec().unwrap();
2329        assert!((l[1] - 4.0_f32.ln()).abs() < 1e-5);
2330
2331        let s = t.sqrt().unwrap().to_f32_vec().unwrap();
2332        assert!((s[1] - 2.0).abs() < 1e-5);
2333
2334        let a = Tensor::from_f32(&[-3.0, 5.0], &[2], test_device()).unwrap();
2335        assert_eq!(a.abs().unwrap().to_f32_vec().unwrap(), vec![3.0, 5.0]);
2336
2337        let p = t.pow_scalar(2.0).unwrap().to_f32_vec().unwrap();
2338        assert!((p[0] - 1.0).abs() < 1e-5);
2339        assert!((p[1] - 16.0).abs() < 1e-5);
2340    }
2341
2342    #[test]
2343    fn test_clamp() {
2344        let t = Tensor::from_f32(&[-1.0, 0.5, 2.0], &[3], test_device()).unwrap();
2345        let c = t.clamp(0.0, 1.0).unwrap().to_f32_vec().unwrap();
2346        assert_eq!(c, vec![0.0, 0.5, 1.0]);
2347    }
2348
2349    #[test]
2350    fn test_sum_dim_mean_dim() {
2351        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2352        let s = t.sum_dim(1, false).unwrap().to_f32_vec().unwrap();
2353        assert_eq!(s, vec![3.0, 7.0]);
2354
2355        let m = t.mean_dim(0, false).unwrap().to_f32_vec().unwrap();
2356        assert!((m[0] - 2.0).abs() < 1e-5);
2357        assert!((m[1] - 3.0).abs() < 1e-5);
2358    }
2359
2360    #[test]
2361    fn test_norm() {
2362        let t = Tensor::from_f32(&[3.0, 4.0], &[2], test_device()).unwrap();
2363        let n = t.norm().unwrap().item().unwrap();
2364        assert!((n - 5.0).abs() < 1e-5);
2365    }
2366
2367    #[test]
2368    fn test_reshape_transpose_narrow_select() {
2369        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
2370        let r = t.reshape(&[3, 2]).unwrap();
2371        assert_eq!(r.shape(), vec![3, 2]);
2372        assert_eq!(r.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2373
2374        let tr = t.transpose(0, 1).unwrap();
2375        assert_eq!(tr.shape(), vec![3, 2]);
2376        assert_eq!(tr.to_f32_vec().unwrap(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
2377
2378        let n = t.narrow(1, 0, 2).unwrap();
2379        assert_eq!(n.shape(), vec![2, 2]);
2380        assert_eq!(n.to_f32_vec().unwrap(), vec![1.0, 2.0, 4.0, 5.0]);
2381
2382        let s = t.select(0, 1).unwrap();
2383        assert_eq!(s.shape(), vec![3]);
2384        assert_eq!(s.to_f32_vec().unwrap(), vec![4.0, 5.0, 6.0]);
2385    }
2386
2387    #[test]
2388    fn test_permute_expand() {
2389        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
2390        let p = t.permute(&[1, 0]).unwrap();
2391        assert_eq!(p.shape(), vec![3, 2]);
2392
2393        let s = Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], test_device()).unwrap();
2394        let e = s.expand(&[4, 3]).unwrap();
2395        assert_eq!(e.shape(), vec![4, 3]);
2396        let data = e.to_f32_vec().unwrap();
2397        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
2398    }
2399
2400    #[test]
2401    fn test_cat_many() {
2402        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2403        let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], test_device()).unwrap();
2404        let c = Tensor::from_f32(&[6.0], &[1], test_device()).unwrap();
2405
2406        // Concatenate 3 tensors along dim 0
2407        let result = Tensor::cat_many(&[&a, &b, &c], 0).unwrap();
2408        assert_eq!(result.shape(), vec![6]);
2409        assert_eq!(result.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2410
2411        // 2D: concat along dim 1
2412        let x = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2413        let y = Tensor::from_f32(&[5.0, 6.0], &[2, 1], test_device()).unwrap();
2414        let z = Tensor::from_f32(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[2, 3], test_device()).unwrap();
2415        let result2 = Tensor::cat_many(&[&x, &y, &z], 1).unwrap();
2416        assert_eq!(result2.shape(), vec![2, 6]);
2417        assert_eq!(
2418            result2.to_f32_vec().unwrap(),
2419            vec![1.0, 2.0, 5.0, 7.0, 8.0, 9.0, 3.0, 4.0, 6.0, 10.0, 11.0, 12.0]
2420        );
2421
2422        // Single tensor — should just return a copy
2423        let single = Tensor::cat_many(&[&a], 0).unwrap();
2424        assert_eq!(single.to_f32_vec().unwrap(), vec![1.0, 2.0]);
2425
2426        // Empty list — should error
2427        let empty: Vec<&Tensor> = vec![];
2428        assert!(Tensor::cat_many(&empty, 0).is_err());
2429    }
2430
2431    #[test]
2432    fn test_cat_index_select_index_add() {
2433        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2434        let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], test_device()).unwrap();
2435        let c = a.cat(&b, 0).unwrap();
2436        assert_eq!(c.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
2437
2438        let t = Tensor::from_f32(&[10.0, 20.0, 30.0, 40.0, 50.0], &[5], test_device()).unwrap();
2439        let idx = Tensor::from_i64(&[0, 2, 4], &[3], test_device()).unwrap();
2440        let sel = t.index_select(0, &idx).unwrap();
2441        assert_eq!(sel.to_f32_vec().unwrap(), vec![10.0, 30.0, 50.0]);
2442
2443        let base = Tensor::zeros(&[5], test_opts()).unwrap();
2444        let src = Tensor::from_f32(&[1.0, 1.0, 1.0], &[3], test_device()).unwrap();
2445        let r = base.index_add(0, &idx, &src).unwrap();
2446        let data = r.to_f32_vec().unwrap();
2447        assert!((data[0] - 1.0).abs() < 1e-5);
2448        assert!((data[2] - 1.0).abs() < 1e-5);
2449        assert!((data[4] - 1.0).abs() < 1e-5);
2450    }
2451
2452    #[test]
2453    fn test_narrow_scatter_select_scatter() {
2454        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], test_device()).unwrap();
2455        let src = Tensor::from_f32(&[10.0, 20.0], &[2], test_device()).unwrap();
2456        let ns = t.narrow_scatter(&src, 0, 1).unwrap();
2457        assert_eq!(ns.to_f32_vec().unwrap(), vec![1.0, 10.0, 20.0, 4.0]);
2458
2459        let t2 = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
2460        let row = Tensor::from_f32(&[10.0, 20.0, 30.0], &[3], test_device()).unwrap();
2461        let ss = t2.select_scatter(&row, 0, 0).unwrap();
2462        assert_eq!(ss.to_f32_vec().unwrap(), vec![10.0, 20.0, 30.0, 4.0, 5.0, 6.0]);
2463    }
2464
2465    #[test]
2466    fn test_activations() {
2467        let t = Tensor::from_f32(&[-1.0, 0.0, 1.0], &[3], test_device()).unwrap();
2468        assert_eq!(t.relu().unwrap().to_f32_vec().unwrap(), vec![0.0, 0.0, 1.0]);
2469
2470        let sig = t.sigmoid().unwrap().to_f32_vec().unwrap();
2471        assert!((sig[2] - 0.7310586).abs() < 1e-5);
2472
2473        let th = t.tanh().unwrap().to_f32_vec().unwrap();
2474        assert!((th[2] - 1.0_f32.tanh()).abs() < 1e-5);
2475
2476        // gelu/silu just check they don't crash and return right shape
2477        assert_eq!(t.gelu().unwrap().shape(), vec![3]);
2478        assert_eq!(t.silu().unwrap().shape(), vec![3]);
2479    }
2480
2481    #[test]
2482    fn test_softmax_log_softmax() {
2483        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2484        let sm = t.softmax(0).unwrap().to_f32_vec().unwrap();
2485        let total: f32 = sm.iter().sum();
2486        assert!((total - 1.0).abs() < 1e-5);
2487        assert!(sm[2] > sm[1] && sm[1] > sm[0]);
2488
2489        let lsm = t.log_softmax(0).unwrap().to_f32_vec().unwrap();
2490        assert!(lsm[0] < 0.0 && lsm[1] < 0.0 && lsm[2] < 0.0);
2491    }
2492
2493    #[test]
2494    fn test_eq_ne_tensor() {
2495        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2496        let b = Tensor::from_f32(&[1.0, 5.0, 3.0], &[3], test_device()).unwrap();
2497
2498        let eq = a.eq_tensor(&b).unwrap().to_f32_vec().unwrap();
2499        assert_eq!(eq, vec![1.0, 0.0, 1.0]);
2500
2501        let ne = a.ne_tensor(&b).unwrap().to_f32_vec().unwrap();
2502        assert_eq!(ne, vec![0.0, 1.0, 0.0]);
2503    }
2504
2505    #[test]
2506    fn test_gt_lt_ge_le_tensor() {
2507        let a = Tensor::from_f32(&[1.0, 3.0, 2.0], &[3], test_device()).unwrap();
2508        let b = Tensor::from_f32(&[2.0, 2.0, 2.0], &[3], test_device()).unwrap();
2509
2510        assert_eq!(a.gt(&b).unwrap().to_f32_vec().unwrap(), vec![0.0, 1.0, 0.0]);
2511        assert_eq!(a.lt(&b).unwrap().to_f32_vec().unwrap(), vec![1.0, 0.0, 0.0]);
2512        assert_eq!(a.ge(&b).unwrap().to_f32_vec().unwrap(), vec![0.0, 1.0, 1.0]);
2513        assert_eq!(a.le(&b).unwrap().to_f32_vec().unwrap(), vec![1.0, 0.0, 1.0]);
2514    }
2515
2516    #[test]
2517    fn test_sign_floor_ceil_round() {
2518        let t = Tensor::from_f32(&[-2.7, 0.0, 1.3], &[3], test_device()).unwrap();
2519        assert_eq!(t.sign().unwrap().to_f32_vec().unwrap(), vec![-1.0, 0.0, 1.0]);
2520        assert_eq!(t.floor().unwrap().to_f32_vec().unwrap(), vec![-3.0, 0.0, 1.0]);
2521        assert_eq!(t.ceil().unwrap().to_f32_vec().unwrap(), vec![-2.0, 0.0, 2.0]);
2522
2523        let r = Tensor::from_f32(&[-0.6, 0.4, 1.5], &[3], test_device()).unwrap();
2524        let rv = r.round().unwrap().to_f32_vec().unwrap();
2525        assert!((rv[0] - (-1.0)).abs() < 1e-5);
2526        assert!((rv[1] - 0.0).abs() < 1e-5);
2527        assert!((rv[2] - 2.0).abs() < 1e-5);
2528    }
2529
2530    #[test]
2531    fn test_argmin() {
2532        let t = Tensor::from_f32(&[3.0, 1.0, 2.0], &[3], test_device()).unwrap();
2533        let idx = t.argmin(0, false).unwrap().to_i64_vec().unwrap();
2534        assert_eq!(idx, vec![1]);
2535    }
2536
2537    #[test]
2538    fn test_var_std() {
2539        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2540        // Bessel: var = ((1-2)²+(2-2)²+(3-2)²)/2 = 1.0
2541        assert!((t.var().unwrap().item().unwrap() - 1.0).abs() < 1e-5);
2542        assert!((t.std().unwrap().item().unwrap() - 1.0).abs() < 1e-5);
2543
2544        // dim variant: [[1,2],[3,4]] var along dim=1 = [0.5, 0.5]
2545        let t2 = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2546        let vd = t2.var_dim(1, false).unwrap().to_f32_vec().unwrap();
2547        assert!((vd[0] - 0.5).abs() < 1e-5);
2548        assert!((vd[1] - 0.5).abs() < 1e-5);
2549    }
2550
2551    #[test]
2552    fn test_sin_cos_reciprocal() {
2553        let t = Tensor::from_f32(&[0.0, 1.0], &[2], test_device()).unwrap();
2554        let s = t.sin().unwrap().to_f32_vec().unwrap();
2555        assert!((s[0] - 0.0).abs() < 1e-5);
2556        assert!((s[1] - 1.0_f32.sin()).abs() < 1e-5);
2557
2558        let c = t.cos().unwrap().to_f32_vec().unwrap();
2559        assert!((c[0] - 1.0).abs() < 1e-5);
2560        assert!((c[1] - 1.0_f32.cos()).abs() < 1e-5);
2561
2562        let r = Tensor::from_f32(&[2.0, 5.0], &[2], test_device()).unwrap();
2563        let rec = r.reciprocal().unwrap().to_f32_vec().unwrap();
2564        assert!((rec[0] - 0.5).abs() < 1e-5);
2565        assert!((rec[1] - 0.2).abs() < 1e-5);
2566    }
2567
2568    #[test]
2569    fn test_eye_full() {
2570        let eye = Tensor::eye(3, test_opts()).unwrap();
2571        assert_eq!(eye.shape(), vec![3, 3]);
2572        let data = eye.to_f32_vec().unwrap();
2573        assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
2574
2575        let f = Tensor::full(&[2, 3], 7.0, test_opts()).unwrap();
2576        assert_eq!(f.shape(), vec![2, 3]);
2577        assert_eq!(f.to_f32_vec().unwrap(), vec![7.0; 6]);
2578    }
2579
2580    #[test]
2581    fn test_gather_scatter_add() {
2582        // gather: pick elements by index
2583        let t = Tensor::from_f32(&[10.0, 20.0, 30.0, 40.0], &[2, 2], test_device()).unwrap();
2584        let idx = Tensor::from_i64(&[1, 0, 0, 1], &[2, 2], test_device()).unwrap();
2585        let g = t.gather(1, &idx).unwrap().to_f32_vec().unwrap();
2586        assert_eq!(g, vec![20.0, 10.0, 30.0, 40.0]);
2587
2588        // scatter_add: accumulate into base at positions
2589        let base = Tensor::zeros(&[2, 3], test_opts()).unwrap();
2590        let src = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2591        let idx2 = Tensor::from_i64(&[0, 2, 1, 0], &[2, 2], test_device()).unwrap();
2592        let sa = base.scatter_add(1, &idx2, &src).unwrap();
2593        let data = sa.to_f32_vec().unwrap();
2594        // Row 0: pos 0 += 1.0, pos 2 += 2.0 → [1, 0, 2]
2595        // Row 1: pos 1 += 3.0, pos 0 += 4.0 → [4, 3, 0]
2596        assert!((data[0] - 1.0).abs() < 1e-5);
2597        assert!((data[2] - 2.0).abs() < 1e-5);
2598        assert!((data[3] - 4.0).abs() < 1e-5);
2599        assert!((data[4] - 3.0).abs() < 1e-5);
2600    }
2601
2602    #[test]
2603    fn test_topk_sort() {
2604        let t = Tensor::from_f32(&[3.0, 1.0, 4.0, 1.0, 5.0], &[5], test_device()).unwrap();
2605        let (vals, idxs) = t.topk(3, 0, true, true).unwrap();
2606        assert_eq!(vals.to_f32_vec().unwrap(), vec![5.0, 4.0, 3.0]);
2607        let idx_data = idxs.to_i64_vec().unwrap();
2608        assert_eq!(idx_data, vec![4, 2, 0]);
2609
2610        let (svals, sidxs) = t.sort(0, false).unwrap();
2611        assert_eq!(svals.to_f32_vec().unwrap(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
2612        let si = sidxs.to_i64_vec().unwrap();
2613        assert_eq!(si[4], 4); // 5.0 was at index 4
2614    }
2615
2616    #[test]
2617    fn test_chunk_repeat_pad() {
2618        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], test_device()).unwrap();
2619        let chunks = t.chunk(3, 0).unwrap();
2620        assert_eq!(chunks.len(), 3);
2621        assert_eq!(chunks[0].to_f32_vec().unwrap(), vec![1.0, 2.0]);
2622        assert_eq!(chunks[1].to_f32_vec().unwrap(), vec![3.0, 4.0]);
2623        assert_eq!(chunks[2].to_f32_vec().unwrap(), vec![5.0, 6.0]);
2624
2625        let s = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2626        let rep = s.repeat(&[3]).unwrap();
2627        assert_eq!(rep.to_f32_vec().unwrap(), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
2628
2629        let pad = s.pad(&[1, 2], 0.0).unwrap();
2630        assert_eq!(pad.shape(), vec![5]);
2631        assert_eq!(pad.to_f32_vec().unwrap(), vec![0.0, 1.0, 2.0, 0.0, 0.0]);
2632    }
2633
2634    #[test]
2635    fn test_zeros_like_ones_like() {
2636        let t = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2637        let zl = Tensor::zeros_like(&t).unwrap();
2638        assert_eq!(zl.to_f32_vec().unwrap(), vec![0.0, 0.0]);
2639        assert_eq!(zl.dtype(), DType::Float32);
2640
2641        let ol = Tensor::ones_like(&t).unwrap();
2642        assert_eq!(ol.to_f32_vec().unwrap(), vec![1.0, 1.0]);
2643    }
2644
2645    #[test]
2646    fn test_unsqueeze_many() {
2647        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2648        let u = t.unsqueeze_many(&[1, 2]).unwrap();
2649        assert_eq!(u.shape(), vec![3, 1, 1]);
2650        // Should match sequential unsqueeze
2651        let u2 = t.unsqueeze(1).unwrap().unsqueeze(2).unwrap();
2652        assert_eq!(u.shape(), u2.shape());
2653        assert_eq!(u.to_f32_vec().unwrap(), u2.to_f32_vec().unwrap());
2654    }
2655
2656    #[test]
2657    fn test_meshgrid() {
2658        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2659        let b = Tensor::from_f32(&[4.0, 5.0], &[2], test_device()).unwrap();
2660        let grids = Tensor::meshgrid(&[&a, &b]).unwrap();
2661        assert_eq!(grids.len(), 2);
2662        assert_eq!(grids[0].shape(), vec![3, 2]);
2663        assert_eq!(grids[1].shape(), vec![3, 2]);
2664        // Grid 0: rows repeat a values
2665        assert_eq!(grids[0].to_f32_vec().unwrap(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
2666        // Grid 1: cols repeat b values
2667        assert_eq!(grids[1].to_f32_vec().unwrap(), vec![4.0, 5.0, 4.0, 5.0, 4.0, 5.0]);
2668    }
2669
2670    #[test]
2671    fn test_cdist() {
2672        // Two 2D points: [0,0] and [3,4] -> distance = 5
2673        let x = Tensor::from_f32(&[0.0, 0.0], &[1, 1, 2], test_device()).unwrap();
2674        let y = Tensor::from_f32(&[3.0, 4.0], &[1, 1, 2], test_device()).unwrap();
2675        let d = x.cdist(&y).unwrap();
2676        assert_eq!(d.shape(), vec![1, 1, 1]);
2677        assert!((d.item().unwrap() - 5.0).abs() < 1e-4);
2678    }
2679
2680    #[test]
2681    fn test_cdist_p1() {
2682        // L1: |3| + |4| = 7
2683        let x = Tensor::from_f32(&[0.0, 0.0], &[1, 1, 2], test_device()).unwrap();
2684        let y = Tensor::from_f32(&[3.0, 4.0], &[1, 1, 2], test_device()).unwrap();
2685        let d = x.cdist_p(&y, 1.0).unwrap();
2686        assert!((d.item().unwrap() - 7.0).abs() < 1e-4);
2687    }
2688
2689    #[test]
2690    fn test_from_i64_device() {
2691        let t = Tensor::from_i64(&[1, 2, 3], &[3], test_device()).unwrap();
2692        assert_eq!(t.device(), test_device());
2693        assert_eq!(t.dtype(), DType::Int64);
2694        assert_eq!(t.to_i64_vec().unwrap(), vec![1, 2, 3]);
2695    }
2696
2697    #[test]
2698    fn test_pin_memory() {
2699        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], Device::CPU).unwrap();
2700        assert!(!t.is_pinned(), "regular CPU tensor should not be pinned");
2701
2702        if cuda_available() {
2703            let pinned = t.pin_memory().unwrap();
2704            assert!(pinned.is_pinned(), "pin_memory() result should be pinned");
2705            assert_eq!(pinned.device(), Device::CPU, "pinned tensor should stay on CPU");
2706            assert_eq!(pinned.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0],
2707                "data should be preserved after pinning");
2708        } else {
2709            // pin_memory requires CUDA — verify it returns an error on CPU-only
2710            assert!(t.pin_memory().is_err(),
2711                "pin_memory should fail without CUDA");
2712        }
2713    }
2714
2715    #[test]
2716    fn test_adam_step_basic() {
2717        // Basic smoke test for the fused adam_step at tensor level
2718        let param = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2719        let grad = Tensor::from_f32(&[0.5, 0.5], &[2], test_device()).unwrap();
2720        let m = Tensor::zeros(&[2], test_opts()).unwrap();
2721        let v = Tensor::zeros(&[2], test_opts()).unwrap();
2722
2723        param.adam_step(&grad, &m, &v, 0.001, 0.9, 0.999, 1e-8, 0.0, 1).unwrap();
2724
2725        let p = param.to_f32_vec().unwrap();
2726        assert!(p[0] < 1.0, "param[0] should decrease");
2727        assert!(p[1] < 2.0, "param[1] should decrease");
2728        // m and v should be non-zero after the step
2729        let m_data = m.to_f32_vec().unwrap();
2730        let v_data = v.to_f32_vec().unwrap();
2731        assert!(m_data[0] > 0.0, "m should be updated");
2732        assert!(v_data[0] > 0.0, "v should be updated");
2733    }
2734
2735    // --- Device model tests ---
2736
2737    #[test]
2738    fn test_device_enum_basics() {
2739        assert_eq!(Device::CPU, Device::CPU);
2740        assert_eq!(Device::CUDA(0), Device::CUDA(0));
2741        assert_ne!(Device::CUDA(0), Device::CUDA(1));
2742        assert_ne!(Device::CPU, Device::CUDA(0));
2743
2744        assert!(!Device::CPU.is_cuda());
2745        assert!(Device::CUDA(0).is_cuda());
2746        assert!(Device::CUDA(1).is_cuda());
2747
2748        assert_eq!(Device::CPU.index(), 0);
2749        assert_eq!(Device::CUDA(0).index(), 0);
2750        assert_eq!(Device::CUDA(1).index(), 1);
2751    }
2752
2753    #[test]
2754    fn test_device_display() {
2755        assert_eq!(format!("{}", Device::CPU), "cpu");
2756        assert_eq!(format!("{}", Device::CUDA(0)), "cuda");
2757        assert_eq!(format!("{}", Device::CUDA(1)), "cuda:1");
2758    }
2759
2760    #[test]
2761    fn test_device_ffi_roundtrip() {
2762        let devices = [Device::CPU, Device::CUDA(0), Device::CUDA(1), Device::CUDA(7)];
2763        for dev in &devices {
2764            let (dt, di) = dev.to_ffi();
2765            let back = Device::from_ffi(dt, di);
2766            assert_eq!(*dev, back, "FFI roundtrip failed for {:?}", dev);
2767        }
2768    }
2769
2770    #[test]
2771    fn test_device_hash() {
2772        use std::collections::HashSet;
2773        let mut set = HashSet::new();
2774        set.insert(Device::CPU);
2775        set.insert(Device::CUDA(0));
2776        set.insert(Device::CUDA(1));
2777        assert_eq!(set.len(), 3);
2778        assert!(set.contains(&Device::CPU));
2779        assert!(set.contains(&Device::CUDA(0)));
2780        assert!(set.contains(&Device::CUDA(1)));
2781    }
2782
2783    // --- Send + Sync compile-time checks ---
2784
2785    #[test]
2786    fn test_tensor_is_send_sync() {
2787        fn assert_send_sync<T: Send + Sync>() {}
2788        assert_send_sync::<Tensor>();
2789    }
2790}