1use std::ffi::{c_void, CStr};
13use std::fmt;
14use std::ptr;
15use std::sync::atomic::{AtomicU64, Ordering};
16
17use flodl_sys::{self as ffi, FlodlTensor};
18
/// Global count of live `Tensor` handles, for leak diagnostics.
/// Incremented in `Tensor::from_raw`, decremented in `Drop`.
static LIVE_TENSOR_COUNT: AtomicU64 = AtomicU64::new(0);
24
/// Element type of a tensor; discriminants mirror the `FLODL_*` dtype
/// constants of the C API one-for-one (hence `#[repr(i32)]`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DType {
    Float16 = ffi::FLODL_FLOAT16,
    BFloat16 = ffi::FLODL_BFLOAT16,
    Float32 = ffi::FLODL_FLOAT32,
    Float64 = ffi::FLODL_FLOAT64,
    Int32 = ffi::FLODL_INT32,
    Int64 = ffi::FLODL_INT64,
}
39
impl DType {
    /// Maps a raw C dtype code back to a `DType`.
    ///
    /// NOTE(review): unknown codes silently fall back to `Float32`
    /// instead of reporting an error — confirm this is intentional.
    fn from_raw(v: i32) -> Self {
        match v {
            ffi::FLODL_FLOAT16 => DType::Float16,
            ffi::FLODL_BFLOAT16 => DType::BFloat16,
            ffi::FLODL_FLOAT32 => DType::Float32,
            ffi::FLODL_FLOAT64 => DType::Float64,
            ffi::FLODL_INT32 => DType::Int32,
            ffi::FLODL_INT64 => DType::Int64,
            _ => DType::Float32,
        }
    }

    /// Size in bytes of a single element of this dtype.
    pub fn element_size(self) -> usize {
        match self {
            DType::Float16 | DType::BFloat16 => 2,
            DType::Float32 | DType::Int32 => 4,
            DType::Float64 | DType::Int64 => 8,
        }
    }
}
62
/// Compute device a tensor lives on: host CPU or a CUDA GPU identified
/// by its device index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    CPU,
    CUDA(u8),
}
72
73impl Device {
74 pub(crate) fn to_ffi(self) -> (i32, i32) {
76 match self {
77 Device::CPU => (ffi::FLODL_CPU, 0),
78 Device::CUDA(idx) => (ffi::FLODL_CUDA, idx as i32),
79 }
80 }
81
82 pub(crate) fn from_ffi(device_type: i32, device_index: i32) -> Self {
84 match device_type {
85 ffi::FLODL_CUDA => Device::CUDA(device_index as u8),
86 _ => Device::CPU,
87 }
88 }
89
90 pub fn is_cuda(&self) -> bool {
92 matches!(self, Device::CUDA(_))
93 }
94
95 pub fn index(&self) -> u8 {
97 match self {
98 Device::CPU => 0,
99 Device::CUDA(idx) => *idx,
100 }
101 }
102}
103
104impl fmt::Display for Device {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Device::CPU => write!(f, "cpu"),
108 Device::CUDA(0) => write!(f, "cuda"),
109 Device::CUDA(idx) => write!(f, "cuda:{}", idx),
110 }
111 }
112}
113
/// Error carrying a message propagated from the flodl C API.
#[derive(Debug, Clone)]
pub struct TensorError(String);
117
118impl TensorError {
119 pub fn new(msg: &str) -> Self {
120 TensorError(msg.to_string())
121 }
122}
123
124impl fmt::Display for TensorError {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(f, "{}", self.0)
127 }
128}
129
// Marker impl: lets `TensorError` travel as `Box<dyn std::error::Error>`.
impl std::error::Error for TensorError {}

/// Crate-wide result alias used by every fallible tensor operation.
pub type Result<T> = std::result::Result<T, TensorError>;
133
134pub(crate) fn check_err(err: *mut i8) -> Result<()> {
136 if err.is_null() {
137 Ok(())
138 } else {
139 let msg = unsafe { CStr::from_ptr(err) }
140 .to_string_lossy()
141 .into_owned();
142 unsafe { ffi::flodl_free_string(err) };
143 Err(TensorError(msg))
144 }
145}
146
/// Creation options for new tensors: element dtype and target device.
#[derive(Debug, Clone, Copy)]
pub struct TensorOptions {
    pub dtype: DType,
    pub device: Device,
}
153
154impl Default for TensorOptions {
155 fn default() -> Self {
156 Self {
157 dtype: DType::Float32,
158 device: Device::CPU,
159 }
160 }
161}
162
/// Safe owning wrapper around a `FlodlTensor` handle from the C API.
///
/// The handle is freed in `Drop`; `Clone` goes through the C API's
/// shallow-clone entry point.
pub struct Tensor {
    // Owned handle; non-null is enforced by `from_raw`'s debug assertion.
    handle: FlodlTensor,
}
177
// SAFETY(review): these impls assume the flodl C library permits a tensor
// handle to be used/dropped from any thread and shared read-only across
// threads — confirm against the flodl documentation before relying on it.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
183
184impl Drop for Tensor {
185 fn drop(&mut self) {
186 if !self.handle.is_null() {
187 LIVE_TENSOR_COUNT.fetch_sub(1, Ordering::Relaxed);
188 unsafe { ffi::flodl_free_tensor(self.handle) };
189 }
190 }
191}
192
193impl Clone for Tensor {
194 fn clone(&self) -> Self {
198 let mut handle: FlodlTensor = ptr::null_mut();
199 let err = unsafe { ffi::flodl_shallow_clone(self.handle, &mut handle) };
200 if !err.is_null() {
201 let msg = unsafe { CStr::from_ptr(err) }
202 .to_string_lossy()
203 .into_owned();
204 unsafe { ffi::flodl_free_string(err) };
205 panic!("tensor clone failed: {}", msg);
206 }
207 Self::from_raw(handle)
208 }
209}
210
211impl Tensor {
    /// Wraps a raw handle, bumping the live-tensor counter.
    fn from_raw(handle: FlodlTensor) -> Self {
        debug_assert!(!handle.is_null());
        LIVE_TENSOR_COUNT.fetch_add(1, Ordering::Relaxed);
        Self { handle }
    }

    /// Wraps a raw handle produced elsewhere in the crate.
    ///
    /// # Safety
    /// `handle` must be a valid, non-null `FlodlTensor` whose ownership
    /// transfers to the returned `Tensor` (it will be freed on drop).
    pub(crate) unsafe fn from_raw_handle(handle: FlodlTensor) -> Self {
        Self::from_raw(handle)
    }

    /// Returns the raw handle without transferring ownership.
    pub(crate) fn raw(&self) -> FlodlTensor {
        self.handle
    }
231
    /// Creates a tensor filled with zeros.
    ///
    /// # Errors
    /// Propagates any error reported by the C API.
    pub fn zeros(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        // Shape is copied because the C API takes a mutable pointer.
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_zeros(
                shape.as_mut_ptr(),
                shape.len() as i32,
                opts.dtype as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Creates a tensor filled with ones.
    ///
    /// # Errors
    /// Propagates any error reported by the C API.
    pub fn ones(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_ones(
                shape.as_mut_ptr(),
                shape.len() as i32,
                opts.dtype as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Builds a `Float32` tensor from a host slice.
    ///
    /// NOTE(review): assumes `flodl_from_blob` copies the buffer rather
    /// than borrowing it — confirm; otherwise the returned tensor must
    /// not outlive `data`.
    pub fn from_f32(data: &[f32], shape: &[i64], device: Device) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe {
            ffi::flodl_from_blob(
                data.as_ptr() as *mut c_void,
                shape.as_mut_ptr(),
                shape.len() as i32,
                DType::Float32 as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Builds a `Float64` tensor from a host slice (see `from_f32` for
    /// the copy-vs-borrow caveat).
    pub fn from_f64(data: &[f64], shape: &[i64], device: Device) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe {
            ffi::flodl_from_blob(
                data.as_ptr() as *mut c_void,
                shape.as_mut_ptr(),
                shape.len() as i32,
                DType::Float64 as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Builds an `Int64` tensor from a host slice (see `from_f32` for
    /// the copy-vs-borrow caveat).
    pub fn from_i64(data: &[i64], shape: &[i64], device: Device) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe {
            ffi::flodl_from_blob(
                data.as_ptr() as *mut c_void,
                shape.as_mut_ptr(),
                shape.len() as i32,
                DType::Int64 as i32,
                dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }
342
    /// Number of dimensions (rank).
    pub fn ndim(&self) -> usize {
        unsafe { ffi::flodl_ndim(self.handle) as usize }
    }

    /// Dimension sizes, queried one axis at a time from the C API.
    pub fn shape(&self) -> Vec<i64> {
        let n = self.ndim();
        (0..n)
            .map(|i| unsafe { ffi::flodl_shape(self.handle, i as i32) })
            .collect()
    }

    /// Total number of elements.
    pub fn numel(&self) -> i64 {
        unsafe { ffi::flodl_numel(self.handle) }
    }

    /// Element dtype of this tensor.
    pub fn dtype(&self) -> DType {
        DType::from_raw(unsafe { ffi::flodl_dtype(self.handle) })
    }

    /// Device this tensor lives on.
    pub fn device(&self) -> Device {
        let dt = unsafe { ffi::flodl_device_type(self.handle) };
        let di = unsafe { ffi::flodl_device_index(self.handle) };
        Device::from_ffi(dt, di)
    }
374
    /// Copies the tensor contents into a host `Vec<f32>`.
    ///
    /// NOTE(review): assumes 4 bytes per element; whether
    /// `flodl_copy_data` converts non-f32 dtypes or fails on a byte-size
    /// mismatch is defined by the C API — confirm before calling on
    /// other dtypes.
    pub fn to_f32_vec(&self) -> Result<Vec<f32>> {
        let n = self.numel() as usize;
        let mut buf = vec![0f32; n];
        let bytes = (n * 4) as i64;
        let err = unsafe {
            ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
        };
        check_err(err)?;
        Ok(buf)
    }

    /// Copies the tensor contents into a host `Vec<f64>`.
    ///
    /// NOTE(review): non-f64 tensors round-trip through `to_f32_vec`,
    /// which is lossy for wide integer values (i64 above 2^24) — confirm
    /// this precision loss is acceptable for callers.
    pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
        if self.dtype() == DType::Float64 {
            let n = self.numel() as usize;
            let mut buf = vec![0.0f64; n];
            let bytes = (n * 8) as i64;
            let err = unsafe {
                ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
            };
            check_err(err)?;
            Ok(buf)
        } else {
            let f32s = self.to_f32_vec()?;
            Ok(f32s.into_iter().map(|v| v as f64).collect())
        }
    }

    /// Copies the tensor contents into a host `Vec<i64>` (assumes an
    /// 8-byte element layout; see the dtype caveat on `to_f32_vec`).
    pub fn to_i64_vec(&self) -> Result<Vec<i64>> {
        let n = self.numel() as usize;
        let mut buf = vec![0i64; n];
        let bytes = (n * 8) as i64;
        let err = unsafe {
            ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, bytes)
        };
        check_err(err)?;
        Ok(buf)
    }

    /// Extracts the single scalar value as `f64`.
    ///
    /// # Errors
    /// Fails when the tensor does not contain exactly one element, or
    /// when the C-side copy fails. Non-f64 dtypes are read through an
    /// f32 buffer (see dtype caveat on `to_f32_vec`).
    pub fn item(&self) -> Result<f64> {
        if self.numel() != 1 {
            return Err(TensorError::new(&format!(
                "item() requires exactly 1 element, got {} (shape {:?})",
                self.numel(), self.shape()
            )));
        }
        if self.dtype() == DType::Float64 {
            let mut buf = [0.0f64; 1];
            let err = unsafe {
                ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, 8)
            };
            check_err(err)?;
            Ok(buf[0])
        } else {
            let mut buf = [0.0f32; 1];
            let err = unsafe {
                ffi::flodl_copy_data(self.handle, buf.as_mut_ptr() as *mut c_void, 4)
            };
            check_err(err)?;
            Ok(buf[0] as f64)
        }
    }
455
    /// Element-wise `self + other` (broadcasting semantics are defined
    /// by the C API).
    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_add(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self - other`.
    pub fn sub(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sub(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self * other`.
    pub fn mul(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_mul(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Matrix product `self @ other`.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_matmul(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Multiplies every element by `scalar`.
    pub fn mul_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_mul_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Rectified linear unit: `max(x, 0)` element-wise.
    pub fn relu(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_relu(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Logistic sigmoid, element-wise.
    pub fn sigmoid(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sigmoid(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Sum of all elements (scalar tensor).
    pub fn sum(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sum(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Mean of all elements (scalar tensor).
    pub fn mean(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_mean(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Flattens dimensions `start_dim..=end_dim` into one.
    pub fn flatten(&self, start_dim: i32, end_dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_flatten(self.handle, start_dim, end_dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self / other`.
    pub fn div(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_div(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise negation.
    pub fn neg(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_neg(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Adds `scalar` to every element.
    pub fn add_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_add_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Divides every element by `scalar`.
    pub fn div_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_div_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Hyperbolic tangent, element-wise.
    pub fn tanh(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_tanh_op(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Natural exponential, element-wise.
    pub fn exp(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_exp(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Natural logarithm, element-wise.
    pub fn log(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_log(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Square root, element-wise.
    pub fn sqrt(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sqrt(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Absolute value, element-wise.
    pub fn abs(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_abs(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Upper-triangular part of the matrix, keeping the `diagonal`-th
    /// diagonal and above.
    pub fn triu(&self, diagonal: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_triu(self.handle, diagonal, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Raises every element to the power `exponent`.
    pub fn pow_scalar(&self, exponent: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_pow_scalar(self.handle, exponent, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Sum along `dim`, optionally keeping the reduced dimension.
    pub fn sum_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_sum_dim(self.handle, dim, keepdim as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Clamps every element into `[min, max]`.
    pub fn clamp(&self, min: f64, max: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_clamp(self.handle, min, max, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self > scalar` comparison mask.
    pub fn gt_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_gt_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
677
    /// Reshapes to `shape` (element count must match; the C API reports
    /// any mismatch as an error).
    pub fn reshape(&self, shape: &[i64]) -> Result<Tensor> {
        // Copied because the C API takes a mutable pointer.
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_reshape(self.handle, shape.as_mut_ptr(), shape.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Swaps dimensions `dim0` and `dim1`.
    pub fn transpose(&self, dim0: i32, dim1: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_transpose(self.handle, dim0, dim1, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Expands to `shape` (broadcast semantics defined by the C API).
    pub fn expand(&self, shape: &[i64]) -> Result<Tensor> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_expand(self.handle, shape.as_mut_ptr(), shape.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Slice of `length` elements along `dim` starting at `start`.
    pub fn narrow(&self, dim: i32, start: i64, length: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_narrow(self.handle, dim, start, length, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Returns a copy of `self` with `src` written into the narrow
    /// slice at (`dim`, `start`).
    pub fn narrow_scatter(&self, src: &Tensor, dim: i32, start: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_narrow_scatter(self.handle, src.handle, dim, start, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Concatenates `self` and `other` along `dim`.
    pub fn cat(&self, other: &Tensor, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_cat2(self.handle, other.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Concatenates an arbitrary number of tensors along `dim`.
    ///
    /// # Errors
    /// Fails on an empty list, or on any C-side error.
    pub fn cat_many(tensors: &[&Tensor], dim: i32) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(TensorError::new("cat_many: empty tensor list"));
        }
        let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
        let mut result: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_cat(handles.as_mut_ptr(), handles.len() as i32, dim, &mut result)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(result))
    }

    /// Stacks tensors along a new dimension `dim`.
    ///
    /// # Errors
    /// Fails on an empty list, or on any C-side error.
    pub fn stack(tensors: &[&Tensor], dim: i32) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(TensorError::new("stack: empty tensor list"));
        }
        let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
        let mut result: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_stack(handles.as_mut_ptr(), handles.len() as i32, dim, &mut result)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(result))
    }

    /// Softmax along `dim`.
    pub fn softmax(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_softmax(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Log-softmax along `dim`.
    pub fn log_softmax(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_log_softmax(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
797
    /// GELU activation, element-wise.
    pub fn gelu(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_gelu(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// SiLU (sigmoid-weighted linear unit) activation, element-wise.
    pub fn silu(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_silu(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Layer normalization returning `(output, mean, rstd)` as the
    /// underlying C call does.
    pub fn native_layer_norm(
        &self, weight: &Tensor, bias: &Tensor, normalized_size: i64, eps: f64,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        let mut out: FlodlTensor = ptr::null_mut();
        let mut mean: FlodlTensor = ptr::null_mut();
        let mut rstd: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_native_layer_norm(
                self.handle, weight.handle, bias.handle,
                normalized_size, eps,
                &mut out, &mut mean, &mut rstd,
            )
        };
        check_err(err)?;
        Ok((Tensor::from_raw(out), Tensor::from_raw(mean), Tensor::from_raw(rstd)))
    }

    /// Permutes the dimensions according to `dims`.
    pub fn permute(&self, dims: &[i64]) -> Result<Tensor> {
        // Copied because the C API takes a mutable pointer.
        let mut dims = dims.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_permute(self.handle, dims.as_mut_ptr(), dims.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Selects the slice at `index` along `dim` (removes that dim).
    pub fn select(&self, dim: i32, index: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_select(self.handle, dim, index, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Mean along `dim`, optionally keeping the reduced dimension.
    pub fn mean_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_mean_dim(self.handle, dim, keepdim as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Gathers slices along `dim` at the positions given by `index`.
    pub fn index_select(&self, dim: i32, index: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_index_select(self.handle, dim, index.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Adds `src` slices into `self` at positions `index` along `dim`.
    pub fn index_add(&self, dim: i32, index: &Tensor, src: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_index_add(self.handle, dim, index.handle, src.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
880
    /// Zeros tensor with the same shape/dtype/device as `t`.
    pub fn zeros_like(t: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_zeros_like(t.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Ones tensor with the same shape/dtype/device as `t`.
    pub fn ones_like(t: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ones_like(t.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Uniform random tensor (distribution parameters are defined by
    /// the C API).
    pub fn rand(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_rand(
                shape.as_mut_ptr(), shape.len() as i32,
                opts.dtype as i32, dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Normal random tensor (distribution parameters are defined by
    /// the C API).
    pub fn randn(shape: &[i64], opts: TensorOptions) -> Result<Self> {
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_randn(
                shape.as_mut_ptr(), shape.len() as i32,
                opts.dtype as i32, dt, di,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }
932
    /// 2-D convolution with optional bias.
    ///
    /// `stride`/`padding`/`dilation` are per-axis pairs; their exact
    /// semantics are those of the underlying C call.
    #[allow(clippy::too_many_arguments)]
    pub fn conv2d(
        &self, weight: &Tensor, bias: Option<&Tensor>,
        stride: [i64; 2], padding: [i64; 2], dilation: [i64; 2], groups: i64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        // Local mutable copies because the C API takes mutable pointers.
        let mut stride = stride;
        let mut padding = padding;
        let mut dilation = dilation;
        // Null handle signals "no bias" to the C side.
        let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
        let err = unsafe {
            ffi::flodl_conv2d(
                self.handle, weight.handle, bias_handle,
                stride.as_mut_ptr(), padding.as_mut_ptr(), dilation.as_mut_ptr(),
                groups, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// 2-D transposed convolution with optional bias.
    #[allow(clippy::too_many_arguments)]
    pub fn conv_transpose2d(
        &self, weight: &Tensor, bias: Option<&Tensor>,
        stride: [i64; 2], padding: [i64; 2], output_padding: [i64; 2],
        dilation: [i64; 2], groups: i64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let mut stride = stride;
        let mut padding = padding;
        let mut output_padding = output_padding;
        let mut dilation = dilation;
        let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
        let err = unsafe {
            ffi::flodl_conv_transpose2d(
                self.handle, weight.handle, bias_handle,
                stride.as_mut_ptr(), padding.as_mut_ptr(),
                output_padding.as_mut_ptr(), dilation.as_mut_ptr(),
                groups, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Affine transform `self @ weight^T + bias` (exact layout defined
    /// by the C API); `bias = None` passes a null handle.
    pub fn linear(&self, weight: &Tensor, bias: Option<&Tensor>) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let bias_handle = bias.map_or(ptr::null_mut(), |b| b.handle);
        let err = unsafe {
            ffi::flodl_linear(self.handle, weight.handle, bias_handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Single GRU cell step; returns the next hidden state.
    #[allow(clippy::too_many_arguments)]
    pub fn gru_cell(
        &self, hx: &Tensor,
        w_ih: &Tensor, w_hh: &Tensor,
        b_ih: &Tensor, b_hh: &Tensor,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_gru_cell(
                self.handle, hx.handle,
                w_ih.handle, w_hh.handle,
                b_ih.handle, b_hh.handle,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Single LSTM cell step; returns `(hidden, cell)` states.
    #[allow(clippy::too_many_arguments)]
    pub fn lstm_cell(
        &self, hx: &Tensor, cx: &Tensor,
        w_ih: &Tensor, w_hh: &Tensor,
        b_ih: &Tensor, b_hh: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let mut h_out: FlodlTensor = ptr::null_mut();
        let mut c_out: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_lstm_cell(
                self.handle, hx.handle, cx.handle,
                w_ih.handle, w_hh.handle,
                b_ih.handle, b_hh.handle,
                &mut h_out, &mut c_out,
            )
        };
        check_err(err)?;
        Ok((Tensor::from_raw(h_out), Tensor::from_raw(c_out)))
    }
1037
    /// Mean-squared-error loss against `target`; `reduction` is the
    /// C API's reduction code (meaning defined there).
    pub fn mse_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_mse_loss(self.handle, target.handle, reduction, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Cross-entropy loss; `ignore_index` / `label_smoothing`
    /// semantics are defined by the C API.
    #[allow(clippy::too_many_arguments)]
    pub fn cross_entropy_loss(
        &self, target: &Tensor, reduction: i64,
        ignore_index: i64, label_smoothing: f64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_cross_entropy_loss(
                self.handle, target.handle,
                reduction, ignore_index, label_smoothing,
                &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Binary cross-entropy on logits.
    pub fn bce_with_logits_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_bce_with_logits_loss(
                self.handle, target.handle, reduction, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// L1 (mean absolute error) loss.
    pub fn l1_loss(&self, target: &Tensor, reduction: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_l1_loss(self.handle, target.handle, reduction, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Smooth-L1 (Huber-style) loss with transition point `beta`.
    pub fn smooth_l1_loss(&self, target: &Tensor, reduction: i64, beta: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_smooth_l1_loss(
                self.handle, target.handle, reduction, beta, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// KL-divergence loss; `log_target` tells the C side whether
    /// `target` is already in log-space.
    pub fn kl_div_loss(&self, target: &Tensor, reduction: i64, log_target: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_kl_div_loss(
                self.handle, target.handle, reduction, log_target as i32, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1122
    /// Batch normalization; any of weight/bias/running stats may be
    /// omitted (passed as null handles to the C side).
    ///
    /// NOTE(review): whether running stats are updated in-place when
    /// `training` is set is defined by the C API — confirm.
    #[allow(clippy::too_many_arguments)]
    pub fn batch_norm(
        &self, weight: Option<&Tensor>, bias: Option<&Tensor>,
        running_mean: Option<&Tensor>, running_var: Option<&Tensor>,
        training: bool, momentum: f64, eps: f64,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let w = weight.map_or(ptr::null_mut(), |t| t.handle);
        let b = bias.map_or(ptr::null_mut(), |t| t.handle);
        let rm = running_mean.map_or(ptr::null_mut(), |t| t.handle);
        let rv = running_var.map_or(ptr::null_mut(), |t| t.handle);
        let err = unsafe {
            ffi::flodl_batch_norm(
                self.handle, w, b, rm, rv,
                training as i32, momentum, eps, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Dropout with probability `p`; a no-op when `training` is false
    /// (behavior defined by the C API).
    pub fn dropout(&self, p: f64, training: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_dropout(self.handle, p, training as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Channel-wise (feature) dropout variant.
    pub fn feature_dropout(&self, p: f64, training: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_feature_dropout(self.handle, p, training as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// `steps` evenly spaced values from `start` to `end` inclusive.
    pub fn linspace(start: f64, end: f64, steps: i64, opts: TensorOptions) -> Result<Self> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_linspace(start, end, steps, opts.dtype as i32, dt, di, &mut handle)
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Values from `start` up to (exclusive) `end` in increments of
    /// `step`.
    pub fn arange(start: f64, end: f64, step: f64, opts: TensorOptions) -> Result<Self> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_arange(start, end, step, opts.dtype as i32, dt, di, &mut handle)
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }
1193
    /// Global minimum (scalar tensor).
    pub fn min(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_min(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Global maximum (scalar tensor).
    pub fn max(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_max(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Norm of the tensor (which norm is defined by the C API —
    /// presumably Frobenius/L2; confirm).
    pub fn norm(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_norm(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Minimum along `dim`, optionally keeping the reduced dimension.
    pub fn min_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_min_dim(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Maximum along `dim`, optionally keeping the reduced dimension.
    pub fn max_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_max_dim(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Index of the maximum along `dim`.
    pub fn argmax(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_argmax(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self >= scalar` comparison mask.
    pub fn ge_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ge_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self <= scalar` comparison mask.
    pub fn le_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_le_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self < scalar` comparison mask.
    pub fn lt_scalar(&self, scalar: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_lt_scalar(self.handle, scalar, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1265
    /// Returns a copy of `self` with the slice at `index` along `dim` replaced
    /// by `src` (wraps `flodl_select_scatter`); `self` is not modified.
    pub fn select_scatter(&self, src: &Tensor, dim: i32, index: i64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_select_scatter(self.handle, src.handle, dim, index, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise select: where `condition` is true take `x`, else `y`
    /// (wraps `flodl_where`). Associated function, not a method.
    pub fn where_cond(condition: &Tensor, x: &Tensor, y: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_where(condition.handle, x.handle, y.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Removes dimension `dim` (wraps `flodl_squeeze`).
    pub fn squeeze(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_squeeze(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Inserts a size-1 dimension at `dim` (wraps `flodl_unsqueeze`).
    pub fn unsqueeze(&self, dim: i32) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_unsqueeze(self.handle, dim, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1301
    /// 2D adaptive average pooling to `output_size` = [H, W]
    /// (wraps `flodl_adaptive_avg_pool2d`).
    pub fn adaptive_avg_pool2d(&self, output_size: [i64; 2]) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        // Local copy because the FFI takes a *mut i64 even though it only reads.
        let mut os = output_size;
        let err = unsafe {
            ffi::flodl_adaptive_avg_pool2d(self.handle, os.as_mut_ptr(), &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Samples `self` at the coordinates in `grid` (wraps `flodl_grid_sample`).
    /// `mode` and `padding_mode` are raw backend enum codes — TODO(review):
    /// confirm the valid values against the flodl_sys constants.
    pub fn grid_sample(
        &self, grid: &Tensor, mode: i32, padding_mode: i32, align_corners: bool,
    ) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_grid_sample(
                self.handle, grid.handle, mode, padding_mode,
                align_corners as i32, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Casts to `dtype`, returning a new tensor (wraps `flodl_to_dtype`).
    pub fn to_dtype(&self, dtype: DType) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_to_dtype(self.handle, dtype as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// True when every element is finite (no NaN/Inf); wraps `flodl_all_finite`.
    pub fn all_finite(&self) -> Result<bool> {
        let mut result: i32 = 0;
        let err = unsafe { ffi::flodl_all_finite(self.handle, &mut result) };
        check_err(err)?;
        Ok(result != 0)
    }
1343
    /// Element-wise `self > other` comparison (wraps `flodl_gt_tensor`).
    pub fn gt(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_gt_tensor(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self < other` comparison (wraps `flodl_lt_tensor`).
    pub fn lt(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_lt_tensor(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self >= other` comparison (wraps `flodl_ge_tensor`).
    pub fn ge(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ge_tensor(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise `self <= other` comparison (wraps `flodl_le_tensor`).
    pub fn le(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_le_tensor(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise equality; named `eq_tensor` to avoid clashing with
    /// `PartialEq::eq` (wraps `flodl_eq_tensor`).
    pub fn eq_tensor(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_eq_tensor(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise inequality (wraps `flodl_ne_tensor`).
    pub fn ne_tensor(&self, other: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ne_tensor(self.handle, other.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1395
    /// Index of the minimum element along `dim` (wraps `flodl_argmin`).
    pub fn argmin(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_argmin(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Variance over all elements, returned as a scalar tensor (wraps `flodl_var`).
    pub fn var(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_var(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Standard deviation over all elements (wraps `flodl_std_op`).
    // Lint allowance kept from the original; the short name is intentional.
    #[allow(clippy::should_implement_trait)]
    pub fn std(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_std_op(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Variance along `dim` (wraps `flodl_var_dim`).
    pub fn var_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_var_dim(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Standard deviation along `dim` (wraps `flodl_std_dim`).
    pub fn std_dim(&self, dim: i32, keepdim: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_std_dim(self.handle, dim, keepdim as i32, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1438
    /// Element-wise sine (wraps `flodl_sin`).
    pub fn sin(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sin(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise cosine (wraps `flodl_cos`).
    pub fn cos(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_cos(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise sign (wraps `flodl_sign`).
    pub fn sign(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_sign(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise floor (wraps `flodl_floor`).
    pub fn floor(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_floor(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise ceiling (wraps `flodl_ceil`).
    pub fn ceil(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_ceil(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise rounding (wraps `flodl_round`; tie-breaking rule is
    /// backend-defined — TODO(review): confirm round-half-to-even vs away-from-zero).
    pub fn round(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_round(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Element-wise reciprocal 1/x (wraps `flodl_reciprocal`).
    pub fn reciprocal(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_reciprocal(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1496
    /// Gathers values along `dim` using `index` (wraps `flodl_gather`).
    pub fn gather(&self, dim: i32, index: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_gather(self.handle, dim, index.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Adds `src` values into a copy of `self` at positions given by `index`
    /// along `dim` (wraps `flodl_scatter_add`); `self` is not modified.
    pub fn scatter_add(&self, dim: i32, index: &Tensor, src: &Tensor) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_scatter_add(self.handle, dim, index.handle, src.handle, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Top-`k` (values, indices) along `dim` (wraps `flodl_topk`).
    /// `largest` picks max vs min; `sorted` asks for ordered output.
    pub fn topk(&self, k: i64, dim: i32, largest: bool, sorted: bool) -> Result<(Tensor, Tensor)> {
        let mut values: FlodlTensor = ptr::null_mut();
        let mut indices: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_topk(
                self.handle, k, dim, largest as i32, sorted as i32,
                &mut values, &mut indices,
            )
        };
        check_err(err)?;
        Ok((Tensor::from_raw(values), Tensor::from_raw(indices)))
    }

    /// Sorts along `dim`, returning (sorted values, original indices)
    /// (wraps `flodl_sort`).
    pub fn sort(&self, dim: i32, descending: bool) -> Result<(Tensor, Tensor)> {
        let mut values: FlodlTensor = ptr::null_mut();
        let mut indices: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_sort(self.handle, dim, descending as i32, &mut values, &mut indices)
        };
        check_err(err)?;
        Ok((Tensor::from_raw(values), Tensor::from_raw(indices)))
    }
1545
    /// n×n identity matrix with the given dtype/device (wraps `flodl_eye`).
    pub fn eye(n: i64, opts: TensorOptions) -> Result<Self> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_eye(n, opts.dtype as i32, dt, di, &mut handle)
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }

    /// Tensor of `shape` filled with `value` (wraps `flodl_full`).
    pub fn full(shape: &[i64], value: f64, opts: TensorOptions) -> Result<Self> {
        // Owned copy because the FFI takes a *mut i64 even though it only reads.
        let mut shape = shape.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = opts.device.to_ffi();
        let err = unsafe {
            ffi::flodl_full(
                shape.as_mut_ptr(), shape.len() as i32, value,
                opts.dtype as i32, dt, di, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Self::from_raw(handle))
    }
1573
1574 pub fn batches(&self, batch_size: i64) -> Result<Vec<Tensor>> {
1587 let n = self.shape()[0];
1588 let mut result = Vec::new();
1589 let mut start = 0i64;
1590 while start < n {
1591 let len = (batch_size).min(n - start);
1592 result.push(self.narrow(0, start, len)?);
1593 start += len;
1594 }
1595 Ok(result)
1596 }
1597
    /// Splits `self` into up to `chunks` pieces along `dim` (wraps `flodl_chunk`).
    /// The C side allocates an array of tensor handles; we copy them out and
    /// free the array.
    pub fn chunk(&self, chunks: i32, dim: i32) -> Result<Vec<Tensor>> {
        let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
        let mut count: i32 = 0;
        let err = unsafe {
            ffi::flodl_chunk(self.handle, chunks, dim, &mut results_ptr, &mut count)
        };
        check_err(err)?;
        let mut tensors = Vec::with_capacity(count as usize);
        for i in 0..count as usize {
            let handle = unsafe { *results_ptr.add(i) };
            tensors.push(Tensor::from_raw(handle));
        }
        if !results_ptr.is_null() {
            // NOTE(review): the handle array is released with `flodl_free_string`;
            // `meshgrid` does the same, so presumably the C side uses a single
            // generic free for all its allocations — confirm against flodl_sys.
            unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
        }
        Ok(tensors)
    }
1618
    /// Tiles `self` by `repeats` per dimension (wraps `flodl_repeat`).
    pub fn repeat(&self, repeats: &[i64]) -> Result<Tensor> {
        // Owned copy because the FFI takes a *mut i64 even though it only reads.
        let mut repeats = repeats.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_repeat(self.handle, repeats.as_mut_ptr(), repeats.len() as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Pads `self` with `value` (wraps `flodl_pad`). `padding` is presumably
    /// in (left, right) pairs per dimension, last dim first — TODO(review):
    /// confirm the backend's padding layout.
    pub fn pad(&self, padding: &[i64], value: f64) -> Result<Tensor> {
        let mut padding = padding.to_vec();
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_pad(
                self.handle, padding.as_mut_ptr(), padding.len() as i32,
                value, &mut handle,
            )
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }
1643
1644 pub fn unsqueeze_many(&self, dims: &[i32]) -> Result<Tensor> {
1647 let mut sorted = dims.to_vec();
1648 sorted.sort();
1649 let mut t = self.unsqueeze(sorted[0])?;
1650 for &d in &sorted[1..] {
1651 t = t.unsqueeze(d)?;
1652 }
1653 Ok(t)
1654 }
1655
    /// Builds coordinate grids from 1-D inputs (wraps `flodl_meshgrid`).
    /// Returns one tensor per input; indexing convention is backend-defined —
    /// TODO(review): confirm "ij" vs "xy".
    pub fn meshgrid(tensors: &[&Tensor]) -> Result<Vec<Tensor>> {
        let mut handles: Vec<FlodlTensor> = tensors.iter().map(|t| t.handle).collect();
        let mut results_ptr: *mut FlodlTensor = ptr::null_mut();
        let mut count: i32 = 0;
        let err = unsafe {
            ffi::flodl_meshgrid(
                handles.as_mut_ptr(), handles.len() as i32,
                &mut results_ptr, &mut count,
            )
        };
        check_err(err)?;
        let mut out = Vec::with_capacity(count as usize);
        for i in 0..count as usize {
            let handle = unsafe { *results_ptr.add(i) };
            out.push(Tensor::from_raw(handle));
        }
        if !results_ptr.is_null() {
            // NOTE(review): handle array freed via `flodl_free_string`, same as
            // `chunk` — presumably the backend's generic free; confirm.
            unsafe { ffi::flodl_free_string(results_ptr as *mut i8) };
        }
        Ok(out)
    }
1678
    /// Pairwise Euclidean distances between rows of `self` and `other`
    /// (p = 2 special case of [`Tensor::cdist_p`]).
    pub fn cdist(&self, other: &Tensor) -> Result<Tensor> {
        self.cdist_p(other, 2.0)
    }

    /// Pairwise p-norm distances (wraps `flodl_cdist`).
    pub fn cdist_p(&self, other: &Tensor, p: f64) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_cdist(self.handle, other.handle, p, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Copies `self` to `device`, returning a new tensor (wraps `flodl_to_device`).
    pub fn to_device(&self, device: Device) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let (dt, di) = device.to_ffi();
        let err = unsafe { ffi::flodl_to_device(self.handle, dt, di, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Moves `self` onto the same device as `other`; when already there,
    /// returns a clone instead of round-tripping through the FFI.
    pub fn to_device_of(&self, other: &Tensor) -> Result<Tensor> {
        let target = other.device();
        if self.device() == target {
            return Ok(self.clone());
        }
        self.to_device(target)
    }
1722
    /// Returns a tensor with autograd tracking switched on/off
    /// (wraps `flodl_set_requires_grad`).
    pub fn set_requires_grad(&self, requires_grad: bool) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe {
            ffi::flodl_set_requires_grad(self.handle, requires_grad as i32, &mut handle)
        };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Whether this tensor participates in autograd.
    pub fn requires_grad(&self) -> bool {
        unsafe { ffi::flodl_requires_grad(self.handle) != 0 }
    }

    /// Runs backpropagation from this tensor (wraps `flodl_backward`).
    pub fn backward(&self) -> Result<()> {
        let err = unsafe { ffi::flodl_backward(self.handle) };
        check_err(err)
    }

    /// The accumulated gradient, if any.
    ///
    /// Deliberately swallows FFI errors into `None` (freeing the error
    /// string) so callers can treat "no grad" and "grad unavailable" the
    /// same way.
    pub fn grad(&self) -> Option<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_grad(self.handle, &mut handle) };
        if !err.is_null() {
            unsafe { ffi::flodl_free_string(err) };
            return None;
        }
        if handle.is_null() {
            None
        } else {
            Some(Tensor::from_raw(handle))
        }
    }

    /// Overwrites this tensor's gradient with `grad` (wraps `flodl_set_grad`).
    pub fn set_grad(&self, grad: &Tensor) -> Result<()> {
        let err = unsafe { ffi::flodl_set_grad(self.handle, grad.handle) };
        check_err(err)
    }

    /// Zeroes this tensor's gradient in place (wraps `flodl_zero_grad`).
    pub fn zero_grad(&self) -> Result<()> {
        let err = unsafe { ffi::flodl_zero_grad(self.handle) };
        check_err(err)
    }

    /// Drops the gradient entirely instead of zeroing it; infallible at the
    /// FFI boundary (no error return).
    pub fn zero_grad_set_to_none(&self) {
        unsafe { ffi::flodl_zero_grad_set_to_none(self.handle) }
    }

    /// Clips the global gradient norm of `params` to `max_norm` in one fused
    /// FFI call, returning the pre-clip total norm. Empty `params` is a
    /// no-op returning 0.0.
    pub fn clip_grad_norm_fused(params: &[Tensor], max_norm: f64) -> Result<f64> {
        if params.is_empty() {
            return Ok(0.0);
        }
        let mut handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
        let mut total_norm: f64 = 0.0;
        let err = unsafe {
            ffi::flodl_clip_grad_norm(
                handles.as_mut_ptr(),
                handles.len() as i32,
                max_norm,
                &mut total_norm,
            )
        };
        check_err(err)?;
        Ok(total_norm)
    }
1804
    /// Whether this tensor is an autograd leaf (not produced by a tracked op).
    pub fn is_leaf(&self) -> bool {
        unsafe { ffi::flodl_is_leaf(self.handle) != 0 }
    }

    /// Number of nodes in this tensor's autograd graph (wraps
    /// `flodl_autograd_node_count`); useful for leak/graph-growth debugging.
    pub fn autograd_node_count(&self) -> i64 {
        unsafe { ffi::flodl_autograd_node_count(self.handle) }
    }

    /// Returns a tensor detached from the autograd graph (wraps `flodl_detach`).
    pub fn detach(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_detach(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    // NOTE: the in-place methods below take `&self`, not `&mut self` — the
    // mutation happens behind the C handle, so Rust's aliasing rules are not
    // enforcing exclusivity here. Callers must not rely on the borrow checker
    // to serialize in-place updates.

    /// In-place detach of this tensor from the autograd graph.
    pub fn detach_(&self) -> Result<()> {
        let err = unsafe { ffi::flodl_detach_(self.handle) };
        check_err(err)
    }

    /// In-place element-wise addition: `self += other`.
    pub fn add_(&self, other: &Tensor) -> Result<()> {
        let err = unsafe { ffi::flodl_add_(self.handle, other.handle) };
        check_err(err)
    }

    /// In-place element-wise subtraction: `self -= other`.
    pub fn sub_(&self, other: &Tensor) -> Result<()> {
        let err = unsafe { ffi::flodl_sub_(self.handle, other.handle) };
        check_err(err)
    }

    /// In-place scalar multiply: `self *= scalar`.
    pub fn mul_scalar_(&self, scalar: f64) -> Result<()> {
        let err = unsafe { ffi::flodl_mul_scalar_(self.handle, scalar) };
        check_err(err)
    }

    /// In-place scalar add: `self += scalar`.
    pub fn add_scalar_(&self, scalar: f64) -> Result<()> {
        let err = unsafe { ffi::flodl_add_scalar_(self.handle, scalar) };
        check_err(err)
    }

    /// In-place fill with zeros.
    pub fn zero_(&self) -> Result<()> {
        let err = unsafe { ffi::flodl_zero_(self.handle) };
        check_err(err)
    }
1868
    /// One fused Adam update applied in place to `self` (the parameter).
    /// `m`/`v` are the first/second moment buffers, updated in place;
    /// `step` is the 1-based step count used for bias correction —
    /// TODO(review): confirm the backend's bias-correction convention.
    #[allow(clippy::too_many_arguments)]
    pub fn adam_step(
        &self, grad: &Tensor, m: &Tensor, v: &Tensor,
        lr: f64, beta1: f64, beta2: f64, eps: f64,
        weight_decay: f64, step: i64,
    ) -> Result<()> {
        let err = unsafe {
            ffi::flodl_adam_step(
                self.handle, grad.handle, m.handle, v.handle,
                lr, beta1, beta2, eps, weight_decay, step,
            )
        };
        check_err(err)
    }

    /// Fused Adam update across many parameters in one FFI call, with a
    /// per-parameter learning rate in `lrs`.
    ///
    /// NOTE(review): the four tensor slices and `lrs` are assumed to all have
    /// `params.len()` entries; mismatched lengths would read out of bounds on
    /// the C side — consider validating before the call.
    #[allow(clippy::too_many_arguments)]
    pub fn adam_step_batched(
        params: &[Tensor], grads: &[Tensor], ms: &[Tensor], vs: &[Tensor],
        lrs: &mut [f64], beta1: f64, beta2: f64, eps: f64,
        weight_decay: f64, step: i64,
    ) -> Result<()> {
        let count = params.len() as i32;
        let mut p_handles: Vec<FlodlTensor> = params.iter().map(|t| t.handle).collect();
        let mut g_handles: Vec<FlodlTensor> = grads.iter().map(|t| t.handle).collect();
        let mut m_handles: Vec<FlodlTensor> = ms.iter().map(|t| t.handle).collect();
        let mut v_handles: Vec<FlodlTensor> = vs.iter().map(|t| t.handle).collect();
        let err = unsafe {
            ffi::flodl_adam_step_batched(
                p_handles.as_mut_ptr(), g_handles.as_mut_ptr(),
                m_handles.as_mut_ptr(), v_handles.as_mut_ptr(),
                lrs.as_mut_ptr(), count,
                beta1, beta2, eps, weight_decay, step,
            )
        };
        check_err(err)
    }

    /// Copies into page-locked host memory for faster host↔device transfers
    /// (wraps `flodl_pin_memory`).
    pub fn pin_memory(&self) -> Result<Tensor> {
        let mut handle: FlodlTensor = ptr::null_mut();
        let err = unsafe { ffi::flodl_pin_memory(self.handle, &mut handle) };
        check_err(err)?;
        Ok(Tensor::from_raw(handle))
    }

    /// Whether this tensor's storage is in pinned (page-locked) host memory.
    pub fn is_pinned(&self) -> bool {
        unsafe { ffi::flodl_is_pinned(self.handle) != 0 }
    }
1937}
1938
1939impl fmt::Debug for Tensor {
1940 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1941 write!(
1942 f,
1943 "Tensor({:?}, {:?}, {:?})",
1944 self.shape(),
1945 self.dtype(),
1946 self.device()
1947 )
1948 }
1949}
1950
/// True when the backend reports a usable CUDA runtime.
pub fn cuda_available() -> bool {
    // Presumably forces the CUDA objects to be linked in so the availability
    // probe below is meaningful — TODO(review): confirm against flodl_sys.
    unsafe { let _ = ffi::flodl_force_cuda_link(); }
    unsafe { ffi::flodl_cuda_is_available() != 0 }
}

/// Number of visible CUDA devices (0 when CUDA is unavailable).
pub fn cuda_device_count() -> i32 {
    unsafe { ffi::flodl_cuda_device_count() }
}

/// (used, total) memory in bytes for the given CUDA device.
pub fn cuda_memory_info_idx(device_index: i32) -> Result<(u64, u64)> {
    let mut used: u64 = 0;
    let mut total: u64 = 0;
    check_err(unsafe { ffi::flodl_cuda_mem_info(device_index, &mut used, &mut total) })?;
    Ok((used, total))
}

/// (used, total) memory in bytes for device 0.
pub fn cuda_memory_info() -> Result<(u64, u64)> {
    cuda_memory_info_idx(0)
}

/// Bytes currently allocated by the backend on the given device.
pub fn cuda_allocated_bytes_idx(device_index: i32) -> Result<u64> {
    let mut allocated: u64 = 0;
    check_err(unsafe { ffi::flodl_cuda_alloc_bytes(device_index, &mut allocated) })?;
    Ok(allocated)
}

/// Bytes currently allocated by the backend on device 0.
pub fn cuda_allocated_bytes() -> Result<u64> {
    cuda_allocated_bytes_idx(0)
}

/// GPU utilization for device 0; `None` when unavailable.
pub fn cuda_utilization() -> Option<u32> {
    cuda_utilization_idx(0)
}

/// GPU utilization for the given device; the C side signals
/// "unavailable" with a negative value, mapped to `None` here.
pub fn cuda_utilization_idx(device_index: i32) -> Option<u32> {
    let val = unsafe { ffi::flodl_cuda_utilization(device_index) };
    if val >= 0 { Some(val as u32) } else { None }
}

/// Sets the current CUDA device for subsequent backend calls.
pub fn set_current_cuda_device(device_index: u8) {
    unsafe { ffi::flodl_set_current_device(device_index as i32) };
}

/// Index of the current CUDA device.
pub fn current_cuda_device() -> u8 {
    unsafe { ffi::flodl_get_current_device() as u8 }
}

/// Blocks until all queued work on the given device has finished.
pub fn cuda_synchronize(device_index: u8) {
    unsafe { ffi::flodl_cuda_synchronize(device_index as i32) };
}
2026
/// Human-readable name of the given CUDA device, or `None` on error
/// (the FFI error string is freed, not surfaced).
///
/// NOTE(review): the buffer is `[i8; 256]`, which matches `c_char` on
/// x86_64 Linux but not on platforms where `c_char` is `u8` (e.g. aarch64) —
/// consider `std::os::raw::c_char`. Also assumes the C side always
/// NUL-terminates within 256 bytes — confirm.
pub fn cuda_device_name_idx(device: i32) -> Option<String> {
    let mut buf = [0i8; 256];
    let err = unsafe { ffi::flodl_cuda_device_name(device, buf.as_mut_ptr(), 256) };
    if err.is_null() {
        let name = unsafe { CStr::from_ptr(buf.as_ptr()) }
            .to_string_lossy()
            .into_owned();
        Some(name)
    } else {
        unsafe { ffi::flodl_free_string(err) };
        None
    }
}
2041
/// Name of CUDA device 0, or `None` on error.
pub fn cuda_device_name() -> Option<String> {
    cuda_device_name_idx(0)
}

/// Static description of one CUDA device, as reported at enumeration time.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    // Device index as used by `Device::CUDA(idx)`.
    pub index: u8,
    // Marketing name reported by the driver (e.g. "NVIDIA ...").
    pub name: String,
    // Total device memory in bytes (0 when the query failed).
    pub total_memory: u64,
}
2057
2058pub fn cuda_devices() -> Vec<DeviceInfo> {
2060 let n = cuda_device_count();
2061 (0..n).filter_map(|i| {
2062 let name = cuda_device_name_idx(i)?;
2063 let total_memory = cuda_memory_info_idx(i).map(|(_, t)| t).unwrap_or(0);
2064 Some(DeviceInfo { index: i as u8, name, total_memory })
2065 }).collect()
2066}
2067
2068pub fn hardware_summary() -> String {
2073 let cpu = cpu_model_name().unwrap_or_else(|| "Unknown CPU".into());
2074 let threads = cpu_thread_count();
2075 let ram = total_ram_gb();
2076 let mut s = format!("{} ({} threads, {}GB)", cpu, threads, ram);
2077
2078 if cuda_available() {
2079 let n = cuda_device_count();
2080 for i in 0..n {
2081 if let Some(gpu) = cuda_device_name_idx(i) {
2082 let vram_str = cuda_memory_info_idx(i)
2083 .map(|(_, total)| format!(" ({}GB)", total / (1024 * 1024 * 1024)))
2084 .unwrap_or_default();
2085 let _ = std::fmt::Write::write_fmt(&mut s, format_args!(
2086 " | {}{}", gpu, vram_str
2087 ));
2088 }
2089 }
2090 }
2091 s
2092}
2093
/// Number of logical CPUs.
///
/// Primary source is `/proc/cpuinfo` ("processor" entries), preserving the
/// original Linux behavior. When that file is missing or empty (non-Linux,
/// restricted containers) we now fall back to
/// `std::thread::available_parallelism` instead of hard-coding 1, and only
/// then to 1 as a last resort.
fn cpu_thread_count() -> usize {
    std::fs::read_to_string("/proc/cpuinfo")
        .ok()
        .map(|s| s.lines().filter(|l| l.starts_with("processor")).count())
        .filter(|&n| n > 0)
        .or_else(|| std::thread::available_parallelism().ok().map(|n| n.get()))
        .unwrap_or(1)
}
2101
/// CPU model string from `/proc/cpuinfo`'s first "model name" line,
/// trimmed; `None` when the file is unreadable or the line is absent
/// (e.g. non-Linux hosts).
fn cpu_model_name() -> Option<String> {
    let info = std::fs::read_to_string("/proc/cpuinfo").ok()?;
    info.lines()
        .filter(|line| line.starts_with("model name"))
        .find_map(|line| line.split(':').nth(1))
        .map(|value| value.trim().to_string())
}
2112
/// Total system RAM in whole GiB, read from `/proc/meminfo`'s `MemTotal`
/// line (reported in kB). Returns 0 when the file is unreadable or the
/// line is missing/unparseable.
fn total_ram_gb() -> u64 {
    let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") else {
        return 0;
    };
    meminfo
        .lines()
        .find(|line| line.starts_with("MemTotal:"))
        .and_then(|line| line.split_whitespace().nth(1))
        .and_then(|kb| kb.parse::<u64>().ok())
        .map(|kb| kb / (1024 * 1024))
        .unwrap_or(0)
}
2128
/// Toggles the backend's cuDNN autotuner (benchmark) mode.
pub fn set_cudnn_benchmark(enable: bool) {
    unsafe { ffi::flodl_set_cudnn_benchmark(enable as i32) }
}

/// Asks the allocator to return free heap pages to the OS; returns whether
/// anything was released (per the C side's report).
pub fn malloc_trim() -> bool {
    unsafe { ffi::flodl_malloc_trim() != 0 }
}

/// Number of currently live `Tensor` wrappers, tracked by the module-level
/// counter; useful for leak detection in tests.
pub fn live_tensor_count() -> u64 {
    LIVE_TENSOR_COUNT.load(Ordering::Relaxed)
}
2154
/// Resident set size of this process in KiB, from field 2 of
/// `/proc/self/statm` (resident page count). Returns 0 when the file is
/// unavailable (non-Linux). Assumes 4 KiB pages, as the original did —
/// a fixed `pages * 4` conversion.
pub fn rss_kb() -> usize {
    match std::fs::read_to_string("/proc/self/statm") {
        Ok(statm) => statm
            .split_whitespace()
            .nth(1)
            .and_then(|field| field.parse::<usize>().ok())
            .map_or(0, |pages| pages * 4),
        Err(_) => 0,
    }
}
2164
/// Test helper: picks CUDA(0) when the `cuda` feature is enabled and a
/// device is actually available, else CPU. Prints the choice once per
/// test run.
#[cfg(test)]
pub fn test_device() -> Device {
    use std::sync::Once;
    static PRINT: Once = Once::new();
    let dev = if cfg!(feature = "cuda") && cuda_available() { Device::CUDA(0) } else { Device::CPU };
    PRINT.call_once(|| eprintln!("\n*** flodl test device: {} ***\n", dev));
    dev
}

/// Test helper: default Float32 options on the test device.
#[cfg(test)]
pub fn test_opts() -> TensorOptions {
    TensorOptions { dtype: DType::Float32, device: test_device() }
}
2181
2182#[cfg(test)]
2183mod tests {
2184 use super::*;
2185
    // zeros(): correct metadata and all-zero contents.
    #[test]
    fn test_zeros() {
        let t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
        assert_eq!(t.shape(), vec![2, 3]);
        assert_eq!(t.dtype(), DType::Float32);
        assert_eq!(t.device(), test_device());
        assert_eq!(t.numel(), 6);

        let data = t.to_f32_vec().unwrap();
        assert_eq!(data, vec![0.0; 6]);
    }

    // from_f32 round-trips shape and contents.
    #[test]
    fn test_from_f32() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
        assert_eq!(t.shape(), vec![3]);
        let data = t.to_f32_vec().unwrap();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    // Element-wise addition.
    #[test]
    fn test_add() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
        let b = Tensor::from_f32(&[4.0, 5.0, 6.0], &[3], test_device()).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.to_f32_vec().unwrap(), vec![5.0, 7.0, 9.0]);
    }

    // 2x2 matrix multiply against hand-computed result.
    #[test]
    fn test_matmul() {
        let a = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
        let b = Tensor::from_f32(&[5.0, 6.0, 7.0, 8.0], &[2, 2], test_device()).unwrap();
        let c = a.matmul(&b).unwrap();
        assert_eq!(c.to_f32_vec().unwrap(), vec![19.0, 22.0, 43.0, 50.0]);
    }

    // Fluent chaining: add -> relu -> sum -> item.
    #[test]
    fn test_chaining() {
        let a = Tensor::from_f32(&[1.0, -2.0, 3.0], &[3], test_device()).unwrap();
        let b = Tensor::from_f32(&[1.0, 1.0, 1.0], &[3], test_device()).unwrap();
        let result = a.add(&b).unwrap().relu().unwrap().sum().unwrap();
        let val = result.item().unwrap();
        assert!((val - 6.0).abs() < 1e-5);
    }

    // Allocate-and-drop smoke test. NOTE(review): asserts nothing —
    // presumably relies on Drop not crashing; consider checking
    // live_tensor_count() before/after.
    #[test]
    fn test_drop_frees_memory() {
        let _ = Tensor::zeros(&[1000, 1000], test_opts()).unwrap();
    }

    // Debug formatting includes shape and dtype.
    #[test]
    fn test_debug_format() {
        let t = Tensor::zeros(&[2, 3], test_opts()).unwrap();
        let s = format!("{:?}", t);
        assert!(s.contains("[2, 3]"));
        assert!(s.contains("Float32"));
    }

    // Scalar division.
    #[test]
    fn test_div_scalar() {
        let t = Tensor::from_f32(&[6.0, 9.0], &[2], test_device()).unwrap();
        let r = t.div_scalar(3.0).unwrap();
        let data = r.to_f32_vec().unwrap();
        assert!((data[0] - 2.0).abs() < 1e-5);
        assert!((data[1] - 3.0).abs() < 1e-5);
    }
2255
    // Full-tensor mean.
    #[test]
    fn test_mean() {
        let t = Tensor::from_f32(&[2.0, 4.0, 6.0], &[3], test_device()).unwrap();
        let m = t.mean().unwrap();
        assert!((m.item().unwrap() - 4.0).abs() < 1e-5);
    }

    // flatten(start, end) merges the chosen dims.
    #[test]
    fn test_flatten() {
        let t = Tensor::ones(&[2, 3, 4], test_opts()).unwrap();
        let f = t.flatten(1, 2).unwrap();
        assert_eq!(f.shape(), vec![2, 12]);
    }

    // stack() along dim 0 and dim 1, checking both layout orders.
    #[test]
    fn test_stack() {
        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
        let b = Tensor::from_f32(&[3.0, 4.0], &[2], test_device()).unwrap();
        let c = Tensor::from_f32(&[5.0, 6.0], &[2], test_device()).unwrap();

        let s = Tensor::stack(&[&a, &b, &c], 0).unwrap();
        assert_eq!(s.shape(), vec![3, 2]);
        let data = s.to_f32_vec().unwrap();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let s1 = Tensor::stack(&[&a, &b, &c], 1).unwrap();
        assert_eq!(s1.shape(), vec![2, 3]);
        let data1 = s1.to_f32_vec().unwrap();
        assert_eq!(data1, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    // ones() plus f64/i64 constructors preserve dtype and values.
    #[test]
    fn test_ones_from_f64_from_i64() {
        let o = Tensor::ones(&[2, 3], test_opts()).unwrap();
        assert_eq!(o.to_f32_vec().unwrap(), vec![1.0; 6]);

        let f = Tensor::from_f64(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
        assert_eq!(f.dtype(), DType::Float64);
        assert_eq!(f.to_f64_vec().unwrap(), vec![1.0, 2.0, 3.0]);

        let i = Tensor::from_i64(&[10, 20, 30], &[3], test_device()).unwrap();
        assert_eq!(i.dtype(), DType::Int64);
        assert_eq!(i.to_i64_vec().unwrap(), vec![10, 20, 30]);
    }

    // Element-wise sub / mul / div.
    #[test]
    fn test_sub_mul_div() {
        let a = Tensor::from_f32(&[6.0, 8.0], &[2], test_device()).unwrap();
        let b = Tensor::from_f32(&[2.0, 3.0], &[2], test_device()).unwrap();
        assert_eq!(a.sub(&b).unwrap().to_f32_vec().unwrap(), vec![4.0, 5.0]);
        assert_eq!(a.mul(&b).unwrap().to_f32_vec().unwrap(), vec![12.0, 24.0]);
        let d = a.div(&b).unwrap().to_f32_vec().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-5);
        assert!((d[1] - 8.0 / 3.0).abs() < 1e-5);
    }

    // Scalar add/mul and negation.
    #[test]
    fn test_scalar_ops() {
        let t = Tensor::from_f32(&[2.0, 4.0], &[2], test_device()).unwrap();
        assert_eq!(t.add_scalar(1.0).unwrap().to_f32_vec().unwrap(), vec![3.0, 5.0]);
        assert_eq!(t.mul_scalar(3.0).unwrap().to_f32_vec().unwrap(), vec![6.0, 12.0]);
        assert_eq!(t.neg().unwrap().to_f32_vec().unwrap(), vec![-2.0, -4.0]);
    }

    // Unary math ops against std float references.
    #[test]
    fn test_exp_log_sqrt_abs_pow() {
        let t = Tensor::from_f32(&[1.0, 4.0], &[2], test_device()).unwrap();
        let e = t.exp().unwrap().to_f32_vec().unwrap();
        assert!((e[0] - 1.0_f32.exp()).abs() < 1e-5);

        let l = t.log().unwrap().to_f32_vec().unwrap();
        assert!((l[1] - 4.0_f32.ln()).abs() < 1e-5);

        let s = t.sqrt().unwrap().to_f32_vec().unwrap();
        assert!((s[1] - 2.0).abs() < 1e-5);

        let a = Tensor::from_f32(&[-3.0, 5.0], &[2], test_device()).unwrap();
        assert_eq!(a.abs().unwrap().to_f32_vec().unwrap(), vec![3.0, 5.0]);

        let p = t.pow_scalar(2.0).unwrap().to_f32_vec().unwrap();
        assert!((p[0] - 1.0).abs() < 1e-5);
        assert!((p[1] - 16.0).abs() < 1e-5);
    }
2341
    // clamp() to [0, 1] pins both ends.
    #[test]
    fn test_clamp() {
        let t = Tensor::from_f32(&[-1.0, 0.5, 2.0], &[3], test_device()).unwrap();
        let c = t.clamp(0.0, 1.0).unwrap().to_f32_vec().unwrap();
        assert_eq!(c, vec![0.0, 0.5, 1.0]);
    }

    // Per-dimension sum and mean reductions.
    #[test]
    fn test_sum_dim_mean_dim() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
        let s = t.sum_dim(1, false).unwrap().to_f32_vec().unwrap();
        assert_eq!(s, vec![3.0, 7.0]);

        let m = t.mean_dim(0, false).unwrap().to_f32_vec().unwrap();
        assert!((m[0] - 2.0).abs() < 1e-5);
        assert!((m[1] - 3.0).abs() < 1e-5);
    }

    // L2 norm of a 3-4-5 triangle.
    #[test]
    fn test_norm() {
        let t = Tensor::from_f32(&[3.0, 4.0], &[2], test_device()).unwrap();
        let n = t.norm().unwrap().item().unwrap();
        assert!((n - 5.0).abs() < 1e-5);
    }

    // View/slice ops: reshape keeps layout, transpose swaps, narrow and
    // select extract ranges/rows.
    #[test]
    fn test_reshape_transpose_narrow_select() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
        let r = t.reshape(&[3, 2]).unwrap();
        assert_eq!(r.shape(), vec![3, 2]);
        assert_eq!(r.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let tr = t.transpose(0, 1).unwrap();
        assert_eq!(tr.shape(), vec![3, 2]);
        assert_eq!(tr.to_f32_vec().unwrap(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        let n = t.narrow(1, 0, 2).unwrap();
        assert_eq!(n.shape(), vec![2, 2]);
        assert_eq!(n.to_f32_vec().unwrap(), vec![1.0, 2.0, 4.0, 5.0]);

        let s = t.select(0, 1).unwrap();
        assert_eq!(s.shape(), vec![3]);
        assert_eq!(s.to_f32_vec().unwrap(), vec![4.0, 5.0, 6.0]);
    }
    // permute() reorders dims; expand() broadcasts a size-1 dim.
    #[test]
    fn test_permute_expand() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
        let p = t.permute(&[1, 0]).unwrap();
        assert_eq!(p.shape(), vec![3, 2]);

        let s = Tensor::from_f32(&[1.0, 2.0, 3.0], &[1, 3], test_device()).unwrap();
        let e = s.expand(&[4, 3]).unwrap();
        assert_eq!(e.shape(), vec![4, 3]);
        let data = e.to_f32_vec().unwrap();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    // cat_many(): 1-D and 2-D concatenation, single input, and the empty
    // slice is rejected with an error.
    #[test]
    fn test_cat_many() {
        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
        let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], test_device()).unwrap();
        let c = Tensor::from_f32(&[6.0], &[1], test_device()).unwrap();

        let result = Tensor::cat_many(&[&a, &b, &c], 0).unwrap();
        assert_eq!(result.shape(), vec![6]);
        assert_eq!(result.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let x = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
        let y = Tensor::from_f32(&[5.0, 6.0], &[2, 1], test_device()).unwrap();
        let z = Tensor::from_f32(&[7.0, 8.0, 9.0, 10.0, 11.0, 12.0], &[2, 3], test_device()).unwrap();
        let result2 = Tensor::cat_many(&[&x, &y, &z], 1).unwrap();
        assert_eq!(result2.shape(), vec![2, 6]);
        assert_eq!(
            result2.to_f32_vec().unwrap(),
            vec![1.0, 2.0, 5.0, 7.0, 8.0, 9.0, 3.0, 4.0, 6.0, 10.0, 11.0, 12.0]
        );

        let single = Tensor::cat_many(&[&a], 0).unwrap();
        assert_eq!(single.to_f32_vec().unwrap(), vec![1.0, 2.0]);

        let empty: Vec<&Tensor> = vec![];
        assert!(Tensor::cat_many(&empty, 0).is_err());
    }

    // Pairwise cat, index_select by i64 indices, and index_add scatter.
    #[test]
    fn test_cat_index_select_index_add() {
        let a = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
        let b = Tensor::from_f32(&[3.0, 4.0, 5.0], &[3], test_device()).unwrap();
        let c = a.cat(&b, 0).unwrap();
        assert_eq!(c.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let t = Tensor::from_f32(&[10.0, 20.0, 30.0, 40.0, 50.0], &[5], test_device()).unwrap();
        let idx = Tensor::from_i64(&[0, 2, 4], &[3], test_device()).unwrap();
        let sel = t.index_select(0, &idx).unwrap();
        assert_eq!(sel.to_f32_vec().unwrap(), vec![10.0, 30.0, 50.0]);

        let base = Tensor::zeros(&[5], test_opts()).unwrap();
        let src = Tensor::from_f32(&[1.0, 1.0, 1.0], &[3], test_device()).unwrap();
        let r = base.index_add(0, &idx, &src).unwrap();
        let data = r.to_f32_vec().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-5);
        assert!((data[2] - 1.0).abs() < 1e-5);
        assert!((data[4] - 1.0).abs() < 1e-5);
    }

    // narrow_scatter and select_scatter produce modified copies.
    #[test]
    fn test_narrow_scatter_select_scatter() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[4], test_device()).unwrap();
        let src = Tensor::from_f32(&[10.0, 20.0], &[2], test_device()).unwrap();
        let ns = t.narrow_scatter(&src, 0, 1).unwrap();
        assert_eq!(ns.to_f32_vec().unwrap(), vec![1.0, 10.0, 20.0, 4.0]);

        let t2 = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], test_device()).unwrap();
        let row = Tensor::from_f32(&[10.0, 20.0, 30.0], &[3], test_device()).unwrap();
        let ss = t2.select_scatter(&row, 0, 0).unwrap();
        assert_eq!(ss.to_f32_vec().unwrap(), vec![10.0, 20.0, 30.0, 4.0, 5.0, 6.0]);
    }

    // relu/sigmoid/tanh numeric checks; gelu/silu shape-only smoke tests.
    #[test]
    fn test_activations() {
        let t = Tensor::from_f32(&[-1.0, 0.0, 1.0], &[3], test_device()).unwrap();
        assert_eq!(t.relu().unwrap().to_f32_vec().unwrap(), vec![0.0, 0.0, 1.0]);

        let sig = t.sigmoid().unwrap().to_f32_vec().unwrap();
        assert!((sig[2] - 0.7310586).abs() < 1e-5);

        let th = t.tanh().unwrap().to_f32_vec().unwrap();
        assert!((th[2] - 1.0_f32.tanh()).abs() < 1e-5);

        assert_eq!(t.gelu().unwrap().shape(), vec![3]);
        assert_eq!(t.silu().unwrap().shape(), vec![3]);
    }

    // softmax sums to 1 and preserves ordering; log_softmax is negative.
    #[test]
    fn test_softmax_log_softmax() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
        let sm = t.softmax(0).unwrap().to_f32_vec().unwrap();
        let total: f32 = sm.iter().sum();
        assert!((total - 1.0).abs() < 1e-5);
        assert!(sm[2] > sm[1] && sm[1] > sm[0]);

        let lsm = t.log_softmax(0).unwrap().to_f32_vec().unwrap();
        assert!(lsm[0] < 0.0 && lsm[1] < 0.0 && lsm[2] < 0.0);
    }
2492
2493 #[test]
2494 fn test_eq_ne_tensor() {
2495 let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2496 let b = Tensor::from_f32(&[1.0, 5.0, 3.0], &[3], test_device()).unwrap();
2497
2498 let eq = a.eq_tensor(&b).unwrap().to_f32_vec().unwrap();
2499 assert_eq!(eq, vec![1.0, 0.0, 1.0]);
2500
2501 let ne = a.ne_tensor(&b).unwrap().to_f32_vec().unwrap();
2502 assert_eq!(ne, vec![0.0, 1.0, 0.0]);
2503 }
2504
2505 #[test]
2506 fn test_gt_lt_ge_le_tensor() {
2507 let a = Tensor::from_f32(&[1.0, 3.0, 2.0], &[3], test_device()).unwrap();
2508 let b = Tensor::from_f32(&[2.0, 2.0, 2.0], &[3], test_device()).unwrap();
2509
2510 assert_eq!(a.gt(&b).unwrap().to_f32_vec().unwrap(), vec![0.0, 1.0, 0.0]);
2511 assert_eq!(a.lt(&b).unwrap().to_f32_vec().unwrap(), vec![1.0, 0.0, 0.0]);
2512 assert_eq!(a.ge(&b).unwrap().to_f32_vec().unwrap(), vec![0.0, 1.0, 1.0]);
2513 assert_eq!(a.le(&b).unwrap().to_f32_vec().unwrap(), vec![1.0, 0.0, 1.0]);
2514 }
2515
2516 #[test]
2517 fn test_sign_floor_ceil_round() {
2518 let t = Tensor::from_f32(&[-2.7, 0.0, 1.3], &[3], test_device()).unwrap();
2519 assert_eq!(t.sign().unwrap().to_f32_vec().unwrap(), vec![-1.0, 0.0, 1.0]);
2520 assert_eq!(t.floor().unwrap().to_f32_vec().unwrap(), vec![-3.0, 0.0, 1.0]);
2521 assert_eq!(t.ceil().unwrap().to_f32_vec().unwrap(), vec![-2.0, 0.0, 2.0]);
2522
2523 let r = Tensor::from_f32(&[-0.6, 0.4, 1.5], &[3], test_device()).unwrap();
2524 let rv = r.round().unwrap().to_f32_vec().unwrap();
2525 assert!((rv[0] - (-1.0)).abs() < 1e-5);
2526 assert!((rv[1] - 0.0).abs() < 1e-5);
2527 assert!((rv[2] - 2.0).abs() < 1e-5);
2528 }
2529
2530 #[test]
2531 fn test_argmin() {
2532 let t = Tensor::from_f32(&[3.0, 1.0, 2.0], &[3], test_device()).unwrap();
2533 let idx = t.argmin(0, false).unwrap().to_i64_vec().unwrap();
2534 assert_eq!(idx, vec![1]);
2535 }
2536
2537 #[test]
2538 fn test_var_std() {
2539 let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2540 assert!((t.var().unwrap().item().unwrap() - 1.0).abs() < 1e-5);
2542 assert!((t.std().unwrap().item().unwrap() - 1.0).abs() < 1e-5);
2543
2544 let t2 = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2546 let vd = t2.var_dim(1, false).unwrap().to_f32_vec().unwrap();
2547 assert!((vd[0] - 0.5).abs() < 1e-5);
2548 assert!((vd[1] - 0.5).abs() < 1e-5);
2549 }
2550
2551 #[test]
2552 fn test_sin_cos_reciprocal() {
2553 let t = Tensor::from_f32(&[0.0, 1.0], &[2], test_device()).unwrap();
2554 let s = t.sin().unwrap().to_f32_vec().unwrap();
2555 assert!((s[0] - 0.0).abs() < 1e-5);
2556 assert!((s[1] - 1.0_f32.sin()).abs() < 1e-5);
2557
2558 let c = t.cos().unwrap().to_f32_vec().unwrap();
2559 assert!((c[0] - 1.0).abs() < 1e-5);
2560 assert!((c[1] - 1.0_f32.cos()).abs() < 1e-5);
2561
2562 let r = Tensor::from_f32(&[2.0, 5.0], &[2], test_device()).unwrap();
2563 let rec = r.reciprocal().unwrap().to_f32_vec().unwrap();
2564 assert!((rec[0] - 0.5).abs() < 1e-5);
2565 assert!((rec[1] - 0.2).abs() < 1e-5);
2566 }
2567
2568 #[test]
2569 fn test_eye_full() {
2570 let eye = Tensor::eye(3, test_opts()).unwrap();
2571 assert_eq!(eye.shape(), vec![3, 3]);
2572 let data = eye.to_f32_vec().unwrap();
2573 assert_eq!(data, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
2574
2575 let f = Tensor::full(&[2, 3], 7.0, test_opts()).unwrap();
2576 assert_eq!(f.shape(), vec![2, 3]);
2577 assert_eq!(f.to_f32_vec().unwrap(), vec![7.0; 6]);
2578 }
2579
2580 #[test]
2581 fn test_gather_scatter_add() {
2582 let t = Tensor::from_f32(&[10.0, 20.0, 30.0, 40.0], &[2, 2], test_device()).unwrap();
2584 let idx = Tensor::from_i64(&[1, 0, 0, 1], &[2, 2], test_device()).unwrap();
2585 let g = t.gather(1, &idx).unwrap().to_f32_vec().unwrap();
2586 assert_eq!(g, vec![20.0, 10.0, 30.0, 40.0]);
2587
2588 let base = Tensor::zeros(&[2, 3], test_opts()).unwrap();
2590 let src = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &[2, 2], test_device()).unwrap();
2591 let idx2 = Tensor::from_i64(&[0, 2, 1, 0], &[2, 2], test_device()).unwrap();
2592 let sa = base.scatter_add(1, &idx2, &src).unwrap();
2593 let data = sa.to_f32_vec().unwrap();
2594 assert!((data[0] - 1.0).abs() < 1e-5);
2597 assert!((data[2] - 2.0).abs() < 1e-5);
2598 assert!((data[3] - 4.0).abs() < 1e-5);
2599 assert!((data[4] - 3.0).abs() < 1e-5);
2600 }
2601
2602 #[test]
2603 fn test_topk_sort() {
2604 let t = Tensor::from_f32(&[3.0, 1.0, 4.0, 1.0, 5.0], &[5], test_device()).unwrap();
2605 let (vals, idxs) = t.topk(3, 0, true, true).unwrap();
2606 assert_eq!(vals.to_f32_vec().unwrap(), vec![5.0, 4.0, 3.0]);
2607 let idx_data = idxs.to_i64_vec().unwrap();
2608 assert_eq!(idx_data, vec![4, 2, 0]);
2609
2610 let (svals, sidxs) = t.sort(0, false).unwrap();
2611 assert_eq!(svals.to_f32_vec().unwrap(), vec![1.0, 1.0, 3.0, 4.0, 5.0]);
2612 let si = sidxs.to_i64_vec().unwrap();
2613 assert_eq!(si[4], 4); }
2615
2616 #[test]
2617 fn test_chunk_repeat_pad() {
2618 let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], test_device()).unwrap();
2619 let chunks = t.chunk(3, 0).unwrap();
2620 assert_eq!(chunks.len(), 3);
2621 assert_eq!(chunks[0].to_f32_vec().unwrap(), vec![1.0, 2.0]);
2622 assert_eq!(chunks[1].to_f32_vec().unwrap(), vec![3.0, 4.0]);
2623 assert_eq!(chunks[2].to_f32_vec().unwrap(), vec![5.0, 6.0]);
2624
2625 let s = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2626 let rep = s.repeat(&[3]).unwrap();
2627 assert_eq!(rep.to_f32_vec().unwrap(), vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
2628
2629 let pad = s.pad(&[1, 2], 0.0).unwrap();
2630 assert_eq!(pad.shape(), vec![5]);
2631 assert_eq!(pad.to_f32_vec().unwrap(), vec![0.0, 1.0, 2.0, 0.0, 0.0]);
2632 }
2633
2634 #[test]
2635 fn test_zeros_like_ones_like() {
2636 let t = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2637 let zl = Tensor::zeros_like(&t).unwrap();
2638 assert_eq!(zl.to_f32_vec().unwrap(), vec![0.0, 0.0]);
2639 assert_eq!(zl.dtype(), DType::Float32);
2640
2641 let ol = Tensor::ones_like(&t).unwrap();
2642 assert_eq!(ol.to_f32_vec().unwrap(), vec![1.0, 1.0]);
2643 }
2644
2645 #[test]
2646 fn test_unsqueeze_many() {
2647 let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2648 let u = t.unsqueeze_many(&[1, 2]).unwrap();
2649 assert_eq!(u.shape(), vec![3, 1, 1]);
2650 let u2 = t.unsqueeze(1).unwrap().unsqueeze(2).unwrap();
2652 assert_eq!(u.shape(), u2.shape());
2653 assert_eq!(u.to_f32_vec().unwrap(), u2.to_f32_vec().unwrap());
2654 }
2655
2656 #[test]
2657 fn test_meshgrid() {
2658 let a = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap();
2659 let b = Tensor::from_f32(&[4.0, 5.0], &[2], test_device()).unwrap();
2660 let grids = Tensor::meshgrid(&[&a, &b]).unwrap();
2661 assert_eq!(grids.len(), 2);
2662 assert_eq!(grids[0].shape(), vec![3, 2]);
2663 assert_eq!(grids[1].shape(), vec![3, 2]);
2664 assert_eq!(grids[0].to_f32_vec().unwrap(), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
2666 assert_eq!(grids[1].to_f32_vec().unwrap(), vec![4.0, 5.0, 4.0, 5.0, 4.0, 5.0]);
2668 }
2669
2670 #[test]
2671 fn test_cdist() {
2672 let x = Tensor::from_f32(&[0.0, 0.0], &[1, 1, 2], test_device()).unwrap();
2674 let y = Tensor::from_f32(&[3.0, 4.0], &[1, 1, 2], test_device()).unwrap();
2675 let d = x.cdist(&y).unwrap();
2676 assert_eq!(d.shape(), vec![1, 1, 1]);
2677 assert!((d.item().unwrap() - 5.0).abs() < 1e-4);
2678 }
2679
2680 #[test]
2681 fn test_cdist_p1() {
2682 let x = Tensor::from_f32(&[0.0, 0.0], &[1, 1, 2], test_device()).unwrap();
2684 let y = Tensor::from_f32(&[3.0, 4.0], &[1, 1, 2], test_device()).unwrap();
2685 let d = x.cdist_p(&y, 1.0).unwrap();
2686 assert!((d.item().unwrap() - 7.0).abs() < 1e-4);
2687 }
2688
2689 #[test]
2690 fn test_from_i64_device() {
2691 let t = Tensor::from_i64(&[1, 2, 3], &[3], test_device()).unwrap();
2692 assert_eq!(t.device(), test_device());
2693 assert_eq!(t.dtype(), DType::Int64);
2694 assert_eq!(t.to_i64_vec().unwrap(), vec![1, 2, 3]);
2695 }
2696
2697 #[test]
2698 fn test_pin_memory() {
2699 let t = Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], Device::CPU).unwrap();
2700 assert!(!t.is_pinned(), "regular CPU tensor should not be pinned");
2701
2702 if cuda_available() {
2703 let pinned = t.pin_memory().unwrap();
2704 assert!(pinned.is_pinned(), "pin_memory() result should be pinned");
2705 assert_eq!(pinned.device(), Device::CPU, "pinned tensor should stay on CPU");
2706 assert_eq!(pinned.to_f32_vec().unwrap(), vec![1.0, 2.0, 3.0],
2707 "data should be preserved after pinning");
2708 } else {
2709 assert!(t.pin_memory().is_err(),
2711 "pin_memory should fail without CUDA");
2712 }
2713 }
2714
2715 #[test]
2716 fn test_adam_step_basic() {
2717 let param = Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap();
2719 let grad = Tensor::from_f32(&[0.5, 0.5], &[2], test_device()).unwrap();
2720 let m = Tensor::zeros(&[2], test_opts()).unwrap();
2721 let v = Tensor::zeros(&[2], test_opts()).unwrap();
2722
2723 param.adam_step(&grad, &m, &v, 0.001, 0.9, 0.999, 1e-8, 0.0, 1).unwrap();
2724
2725 let p = param.to_f32_vec().unwrap();
2726 assert!(p[0] < 1.0, "param[0] should decrease");
2727 assert!(p[1] < 2.0, "param[1] should decrease");
2728 let m_data = m.to_f32_vec().unwrap();
2730 let v_data = v.to_f32_vec().unwrap();
2731 assert!(m_data[0] > 0.0, "m should be updated");
2732 assert!(v_data[0] > 0.0, "v should be updated");
2733 }
2734
2735 #[test]
2738 fn test_device_enum_basics() {
2739 assert_eq!(Device::CPU, Device::CPU);
2740 assert_eq!(Device::CUDA(0), Device::CUDA(0));
2741 assert_ne!(Device::CUDA(0), Device::CUDA(1));
2742 assert_ne!(Device::CPU, Device::CUDA(0));
2743
2744 assert!(!Device::CPU.is_cuda());
2745 assert!(Device::CUDA(0).is_cuda());
2746 assert!(Device::CUDA(1).is_cuda());
2747
2748 assert_eq!(Device::CPU.index(), 0);
2749 assert_eq!(Device::CUDA(0).index(), 0);
2750 assert_eq!(Device::CUDA(1).index(), 1);
2751 }
2752
2753 #[test]
2754 fn test_device_display() {
2755 assert_eq!(format!("{}", Device::CPU), "cpu");
2756 assert_eq!(format!("{}", Device::CUDA(0)), "cuda");
2757 assert_eq!(format!("{}", Device::CUDA(1)), "cuda:1");
2758 }
2759
2760 #[test]
2761 fn test_device_ffi_roundtrip() {
2762 let devices = [Device::CPU, Device::CUDA(0), Device::CUDA(1), Device::CUDA(7)];
2763 for dev in &devices {
2764 let (dt, di) = dev.to_ffi();
2765 let back = Device::from_ffi(dt, di);
2766 assert_eq!(*dev, back, "FFI roundtrip failed for {:?}", dev);
2767 }
2768 }
2769
2770 #[test]
2771 fn test_device_hash() {
2772 use std::collections::HashSet;
2773 let mut set = HashSet::new();
2774 set.insert(Device::CPU);
2775 set.insert(Device::CUDA(0));
2776 set.insert(Device::CUDA(1));
2777 assert_eq!(set.len(), 3);
2778 assert!(set.contains(&Device::CPU));
2779 assert!(set.contains(&Device::CUDA(0)));
2780 assert!(set.contains(&Device::CUDA(1)));
2781 }
2782
2783 #[test]
2786 fn test_tensor_is_send_sync() {
2787 fn assert_send_sync<T: Send + Sync>() {}
2788 assert_send_sync::<Tensor>();
2789 }
2790}