1use crate::ops::shape::ShapeOps;
7use crate::types::{DataType, Tensor as RonnTensor, TensorLayout};
8use anyhow::{Result, anyhow};
9use candle_core::{DType, Device, Module, Shape, Tensor as CandleTensor};
10
11#[derive(Debug, Clone)]
13pub struct Tensor {
14 candle_tensor: CandleTensor,
16 dtype: DataType,
18 layout: TensorLayout,
20}
21
22impl Tensor {
23 pub fn from_data(
41 data: Vec<f32>,
42 shape: Vec<usize>,
43 dtype: DataType,
44 layout: TensorLayout,
45 ) -> Result<Self> {
46 let device = Device::Cpu;
47 let candle_shape = Shape::from_dims(&shape);
48
49 let candle_tensor = match dtype {
50 DataType::F32 => CandleTensor::from_vec(data, candle_shape, &device)?,
51 DataType::F16 => {
52 let f16_data: Vec<half::f16> = data.into_iter().map(half::f16::from_f32).collect();
53 CandleTensor::from_vec(f16_data, candle_shape, &device)?
54 }
55 DataType::BF16 => {
56 let bf16_data: Vec<half::bf16> =
57 data.into_iter().map(half::bf16::from_f32).collect();
58 CandleTensor::from_vec(bf16_data, candle_shape, &device)?
59 }
60 DataType::F64 => {
61 let f64_data: Vec<f64> = data.into_iter().map(|x| x as f64).collect();
62 CandleTensor::from_vec(f64_data, candle_shape, &device)?
63 }
64 DataType::U8 => {
65 let u8_data: Vec<u8> = data.into_iter().map(|x| x as u8).collect();
66 CandleTensor::from_vec(u8_data, candle_shape, &device)?
67 }
68 DataType::U32 => {
69 let u32_data: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
70 CandleTensor::from_vec(u32_data, candle_shape, &device)?
71 }
72 DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => {
74 CandleTensor::from_vec(data, candle_shape, &device)?
75 }
76 };
77
78 Ok(Self {
79 candle_tensor,
80 dtype,
81 layout,
82 })
83 }
84
85 pub fn zeros(shape: Vec<usize>, dtype: DataType, layout: TensorLayout) -> Result<Self> {
92 let device = Device::Cpu;
93 let candle_dtype = dtype_to_candle(&dtype)?;
94 let candle_shape = Shape::from_dims(&shape);
95
96 let candle_tensor = CandleTensor::zeros(candle_shape, candle_dtype, &device)?;
97
98 Ok(Self {
99 candle_tensor,
100 dtype,
101 layout,
102 })
103 }
104
105 pub fn ones(shape: Vec<usize>, dtype: DataType, layout: TensorLayout) -> Result<Self> {
112 let device = Device::Cpu;
113 let candle_dtype = dtype_to_candle(&dtype)?;
114 let candle_shape = Shape::from_dims(&shape);
115
116 let candle_tensor = CandleTensor::ones(candle_shape, candle_dtype, &device)?;
117
118 Ok(Self {
119 candle_tensor,
120 dtype,
121 layout,
122 })
123 }
124
125 pub fn rand(shape: Vec<usize>, dtype: DataType, layout: TensorLayout) -> Result<Self> {
127 let device = Device::Cpu;
128 let _candle_dtype = dtype_to_candle(&dtype)?;
129 let candle_shape = Shape::from_dims(&shape);
130
131 let candle_tensor = CandleTensor::rand(0.0, 1.0, candle_shape, &device)?;
132
133 Ok(Self {
134 candle_tensor,
135 dtype,
136 layout,
137 })
138 }
139
140 pub fn shape(&self) -> Vec<usize> {
142 self.candle_tensor.dims().to_vec()
143 }
144
145 pub fn dtype(&self) -> DataType {
147 self.dtype
148 }
149
150 pub fn layout(&self) -> TensorLayout {
152 self.layout
153 }
154
155 pub fn ndim(&self) -> usize {
157 self.candle_tensor.dims().len()
158 }
159
160 pub fn numel(&self) -> usize {
162 self.candle_tensor.elem_count()
163 }
164
165 pub fn device(&self) -> &Device {
167 self.candle_tensor.device()
168 }
169
170 pub fn to_cpu(&self) -> Result<Self> {
172 let cpu_tensor = self.candle_tensor.to_device(&Device::Cpu)?;
173 Ok(Self {
174 candle_tensor: cpu_tensor,
175 dtype: self.dtype,
176 layout: self.layout,
177 })
178 }
179
180 #[cfg(feature = "gpu")]
182 pub fn to_gpu(&self, device_id: usize) -> Result<Self> {
183 let gpu_device = Device::new_cuda(device_id)?;
184 let gpu_tensor = self.candle_tensor.to_device(&gpu_device)?;
185 Ok(Self {
186 candle_tensor: gpu_tensor,
187 dtype: self.dtype,
188 layout: self.layout,
189 })
190 }
191
192 pub fn to_vec(&self) -> Result<Vec<f32>> {
194 let flattened = if self.candle_tensor.dims().len() > 1 {
196 self.candle_tensor.flatten_all()?
197 } else {
198 self.candle_tensor.clone()
199 };
200
201 match self.dtype {
202 DataType::F32 | DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => {
203 let data: Vec<f32> = flattened.to_vec1()?;
204 Ok(data)
205 }
206 DataType::F16 => {
207 let data: Vec<half::f16> = flattened.to_vec1()?;
208 Ok(data.into_iter().map(|x| x.to_f32()).collect())
209 }
210 DataType::BF16 => {
211 let data: Vec<half::bf16> = flattened.to_vec1()?;
212 Ok(data.into_iter().map(|x| x.to_f32()).collect())
213 }
214 DataType::F64 => {
215 let data: Vec<f64> = flattened.to_vec1()?;
216 Ok(data.into_iter().map(|x| x as f32).collect())
217 }
218 DataType::U8 => {
219 let data: Vec<u8> = flattened.to_vec1()?;
220 Ok(data.into_iter().map(|x| x as f32).collect())
221 }
222 DataType::U32 => {
223 let data: Vec<u32> = flattened.to_vec1()?;
224 Ok(data.into_iter().map(|x| x as f32).collect())
225 }
226 }
227 }
228
229 pub fn candle_tensor(&self) -> &CandleTensor {
231 &self.candle_tensor
232 }
233
234 pub fn from_candle(candle_tensor: CandleTensor, dtype: DataType, layout: TensorLayout) -> Self {
236 Self {
237 candle_tensor,
238 dtype,
239 layout,
240 }
241 }
242
243 pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
245 let shape1 = self.shape();
246 let shape2 = other.shape();
247
248 let max_len = shape1.len().max(shape2.len());
250 let mut padded1 = vec![1; max_len - shape1.len()];
251 let mut padded2 = vec![1; max_len - shape2.len()];
252 padded1.extend(shape1);
253 padded2.extend(shape2);
254
255 for (d1, d2) in padded1.iter().zip(padded2.iter()) {
257 if *d1 != *d2 && *d1 != 1 && *d2 != 1 {
258 return false;
259 }
260 }
261 true
262 }
263
264 pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Vec<usize>> {
266 let max_len = shape1.len().max(shape2.len());
267 let mut padded1 = vec![1; max_len - shape1.len()];
268 let mut padded2 = vec![1; max_len - shape2.len()];
269 padded1.extend(shape1);
270 padded2.extend(shape2);
271
272 let mut result = Vec::with_capacity(max_len);
273 for (d1, d2) in padded1.iter().zip(padded2.iter()) {
274 match (d1, d2) {
275 (1, d) | (d, 1) => result.push(*d),
276 (d1, d2) if d1 == d2 => result.push(*d1),
277 (d1, d2) => {
278 return Err(anyhow!(
279 "Cannot broadcast shapes: dimension {} vs {}",
280 d1,
281 d2
282 ));
283 }
284 }
285 }
286 Ok(result)
287 }
288
289 pub fn conv2d(
291 &self,
292 weight: &Tensor,
293 bias: Option<&Tensor>,
294 strides: &[usize],
295 pads: &[usize],
296 dilations: &[usize],
297 groups: usize,
298 ) -> Result<Tensor> {
299 let _ = (weight, bias, strides, pads, dilations, groups);
301 Err(anyhow!("conv2d not yet fully implemented"))
302 }
303
304 pub fn max_pool2d(
306 &self,
307 kernel_shape: &[usize],
308 strides: &[usize],
309 pads: &[usize],
310 ) -> Result<Tensor> {
311 let _ = (kernel_shape, strides, pads);
312 Err(anyhow!("max_pool2d not yet fully implemented"))
313 }
314
315 pub fn avg_pool2d(
317 &self,
318 kernel_shape: &[usize],
319 strides: &[usize],
320 pads: &[usize],
321 ) -> Result<Tensor> {
322 let _ = (kernel_shape, strides, pads);
323 Err(anyhow!("avg_pool2d not yet fully implemented"))
324 }
325
326 pub fn batch_norm(
328 &self,
329 scale: &Tensor,
330 bias: &Tensor,
331 mean: &Tensor,
332 var: &Tensor,
333 epsilon: f32,
334 ) -> Result<Tensor> {
335 let _ = (scale, bias, mean, var, epsilon);
336 Err(anyhow!("batch_norm not yet fully implemented"))
337 }
338
339 pub fn rank(&self) -> usize {
341 self.ndim()
342 }
343
344 pub fn to_vec1<T: candle_core::WithDType>(&self) -> Result<Vec<T>> {
346 let flattened = if self.candle_tensor.dims().len() > 1 {
347 self.candle_tensor.flatten_all()?
348 } else {
349 self.candle_tensor.clone()
350 };
351 Ok(flattened.to_vec1()?)
352 }
353
354 pub fn stack(tensors: &[&Tensor], dim: usize) -> Result<Self> {
372 if tensors.is_empty() {
373 return Err(anyhow!("Cannot stack empty tensor list"));
374 }
375
376 let candle_tensors: Vec<_> = tensors.iter().map(|t| &t.candle_tensor).collect();
377 let stacked = CandleTensor::stack(&candle_tensors, dim)?;
378
379 Ok(Self {
380 candle_tensor: stacked,
381 dtype: tensors[0].dtype,
382 layout: tensors[0].layout,
383 })
384 }
385
386 pub fn split(&self, num_chunks: usize, dim: usize) -> Result<Vec<Tensor>> {
408 if num_chunks == 0 {
409 return Err(anyhow!("Cannot split into 0 chunks"));
410 }
411
412 let shape = self.shape();
413 if dim >= shape.len() {
414 return Err(anyhow!(
415 "Dimension {} out of bounds for shape {:?}",
416 dim,
417 shape
418 ));
419 }
420
421 let dim_size = shape[dim];
422 if dim_size % num_chunks != 0 {
423 return Err(anyhow!(
424 "Dimension size {} not evenly divisible by {} chunks",
425 dim_size,
426 num_chunks
427 ));
428 }
429
430 let chunk_size = dim_size / num_chunks;
431 let mut chunks = Vec::with_capacity(num_chunks);
432
433 for i in 0..num_chunks {
434 let start = i * chunk_size;
435 let _end = start + chunk_size;
436 let chunk = self.candle_tensor.narrow(dim, start, chunk_size)?;
437 chunks.push(Self {
438 candle_tensor: chunk,
439 dtype: self.dtype,
440 layout: self.layout,
441 });
442 }
443
444 Ok(chunks)
445 }
446
447 pub fn gather(&self, indices: &Tensor, dim: usize) -> Result<Tensor> {
449 let _ = (indices, dim);
450 Err(anyhow!("gather not yet fully implemented"))
451 }
452
453 pub fn transpose(&self, perm: &[usize]) -> Result<Tensor> {
455 let result = self.candle_tensor.permute(perm)?;
456 Ok(Tensor::from_candle(result, self.dtype, self.layout))
457 }
458
459 pub fn layer_norm(
484 &self,
485 scale: Option<&Tensor>,
486 bias: Option<&Tensor>,
487 epsilon: f32,
488 axis: i32,
489 ) -> Result<Self> {
490 use candle_nn::LayerNorm;
491
492 let shape = self.shape();
493 let _normalized_shape = if axis == -1 {
494 vec![shape[shape.len() - 1]]
495 } else {
496 let axis_usize = if axis < 0 {
497 (shape.len() as i32 + axis) as usize
498 } else {
499 axis as usize
500 };
501 vec![shape[axis_usize]]
502 };
503
504 let normalized = if let (Some(s), Some(b)) = (scale, bias) {
507 let ln = LayerNorm::new(
508 s.candle_tensor.clone(),
509 b.candle_tensor.clone(),
510 epsilon as f64,
511 );
512 ln.forward(&self.candle_tensor)?
513 } else {
514 let mean = self.candle_tensor.mean_keepdim(axis as usize)?;
516 let variance = self
517 .candle_tensor
518 .broadcast_sub(&mean)?
519 .sqr()?
520 .mean_keepdim(axis as usize)?;
521 let std = (variance + epsilon as f64)?.sqrt()?;
522 self.candle_tensor
523 .broadcast_sub(&mean)?
524 .broadcast_div(&std)?
525 };
526
527 Ok(Self::from_candle(normalized, self.dtype, self.layout))
528 }
529
530 pub fn attention(
558 &self,
559 key: &Tensor,
560 value: &Tensor,
561 num_heads: usize,
562 mask: Option<&Tensor>,
563 ) -> Result<Self> {
564 let query = &self.candle_tensor;
565 let key = &key.candle_tensor;
566 let value = &value.candle_tensor;
567
568 let query_shape = query.dims();
570 if query_shape.len() != 3 {
571 return Err(anyhow!(
572 "Query must be 3D (batch, seq_len, d_model), got {:?}",
573 query_shape
574 ));
575 }
576
577 let batch_size = query_shape[0];
578 let seq_len = query_shape[1];
579 let d_model = query_shape[2];
580
581 if d_model % num_heads != 0 {
582 return Err(anyhow!(
583 "d_model ({}) must be divisible by num_heads ({})",
584 d_model,
585 num_heads
586 ));
587 }
588
589 let d_k = d_model / num_heads;
590
591 let q = query
594 .reshape(&[batch_size, seq_len, num_heads, d_k])?
595 .transpose(1, 2)?;
596 let k = key
597 .reshape(&[batch_size, seq_len, num_heads, d_k])?
598 .transpose(1, 2)?;
599 let v = value
600 .reshape(&[batch_size, seq_len, num_heads, d_k])?
601 .transpose(1, 2)?;
602
603 let k_t = k.transpose(2, 3)?;
605 let scores = (q.matmul(&k_t)? / (d_k as f64).sqrt())?;
606
607 let scores = if let Some(m) = mask {
609 scores.broadcast_add(&m.candle_tensor)?
610 } else {
611 scores
612 };
613
614 let attention_weights = candle_nn::ops::softmax_last_dim(&scores)?;
616
617 let output = attention_weights.matmul(&v)?;
619
620 let output = output
622 .transpose(1, 2)?
623 .reshape(&[batch_size, seq_len, d_model])?;
624
625 Ok(Self::from_candle(output, self.dtype, self.layout))
626 }
627
628 pub fn clip(&self, min: f32, max: f32) -> Result<Self> {
630 let result = self.candle_tensor.clamp(min, max)?;
631 Ok(Self::from_candle(result, self.dtype, self.layout))
632 }
633
634 pub fn pow_tensor(&self, exponent: &Tensor) -> Result<Self> {
637 let result = self.candle_tensor.pow(&exponent.candle_tensor)?;
638 Ok(Self::from_candle(result, self.dtype, self.layout))
639 }
640
641 pub fn sqrt(&self) -> Result<Self> {
643 let result = self.candle_tensor.sqrt()?;
644 Ok(Self::from_candle(result, self.dtype, self.layout))
645 }
646
647 pub fn exp(&self) -> Result<Self> {
649 let result = self.candle_tensor.exp()?;
650 Ok(Self::from_candle(result, self.dtype, self.layout))
651 }
652
653 pub fn log(&self) -> Result<Self> {
655 let result = self.candle_tensor.log()?;
656 Ok(Self::from_candle(result, self.dtype, self.layout))
657 }
658
659 pub fn neg(&self) -> Result<Self> {
661 let result = self.candle_tensor.neg()?;
662 Ok(Self::from_candle(result, self.dtype, self.layout))
663 }
664
665 pub fn abs(&self) -> Result<Self> {
667 let result = self.candle_tensor.abs()?;
668 Ok(Self::from_candle(result, self.dtype, self.layout))
669 }
670
671 pub fn leaky_relu(&self, alpha: f32) -> Result<Self> {
673 let scaled = self.candle_tensor.affine(alpha as f64, 0.0)?;
674 let result = self.candle_tensor.maximum(&scaled)?;
675 Ok(Self::from_candle(result, self.dtype, self.layout))
676 }
677
678 pub fn elu(&self, alpha: f32) -> Result<Self> {
680 let zero = self.candle_tensor.zeros_like()?;
682 let mask = self.candle_tensor.gt(&zero)?;
683
684 let positive_part = &self.candle_tensor;
685 let exp_part = self.candle_tensor.exp()?.affine(1.0, -1.0)?;
686 let negative_part = exp_part.affine(alpha as f64, 0.0)?;
687
688 let result = mask.where_cond(positive_part, &negative_part)?;
689 Ok(Self::from_candle(result, self.dtype, self.layout))
690 }
691
692 pub fn swish(&self) -> Result<Self> {
694 let sigmoid = candle_nn::ops::sigmoid(&self.candle_tensor)?;
695 let result = (&self.candle_tensor * &sigmoid)?;
696 Ok(Self::from_candle(result, self.dtype, self.layout))
697 }
698
699 pub fn squeeze(&self, axes: Option<Vec<usize>>) -> Result<Self> {
701 let shape = self.shape();
702 let new_shape: Vec<usize> = if let Some(axes) = axes {
703 shape
705 .iter()
706 .enumerate()
707 .filter(|(i, dim)| !axes.contains(i) || **dim != 1)
708 .map(|(_, dim)| *dim)
709 .collect()
710 } else {
711 shape.iter().copied().filter(|dim| *dim != 1).collect()
713 };
714
715 if new_shape.is_empty() {
716 return self.reshape(&[1]);
718 }
719
720 self.reshape(&new_shape)
721 }
722
723 pub fn unsqueeze(&self, axes: &[usize]) -> Result<Self> {
725 let mut new_shape = self.shape();
726 let mut axes_sorted = axes.to_vec();
727 axes_sorted.sort_unstable();
728
729 for &axis in &axes_sorted {
730 if axis > new_shape.len() {
732 return Err(anyhow!(
733 "Unsqueeze axis {} is out of bounds for shape with {} dimensions",
734 axis,
735 new_shape.len()
736 ));
737 }
738 new_shape.insert(axis, 1);
739 }
740
741 self.reshape(&new_shape)
742 }
743
744 pub fn reduce_mean(&self, axes: &[usize], keepdims: bool) -> Result<Self> {
746 let mut result = self.candle_tensor.clone();
747
748 let mut sorted_axes = axes.to_vec();
750 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
751
752 for &axis in &sorted_axes {
753 result = result.mean_keepdim(axis)?;
754 if !keepdims {
755 result = result.squeeze(axis)?;
756 }
757 }
758
759 Ok(Self::from_candle(result, self.dtype, self.layout))
760 }
761
762 pub fn reduce_sum(&self, axes: &[usize], keepdims: bool) -> Result<Self> {
764 let mut result = self.candle_tensor.clone();
765
766 let mut sorted_axes = axes.to_vec();
768 sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
769
770 for &axis in &sorted_axes {
771 result = result.sum_keepdim(axis)?;
772 if !keepdims {
773 result = result.squeeze(axis)?;
774 }
775 }
776
777 Ok(Self::from_candle(result, self.dtype, self.layout))
778 }
779
780 pub fn cast(&self, to: DataType) -> Result<Self> {
782 let target_dtype = dtype_to_candle(&to)?;
783 let result = self.candle_tensor.to_dtype(target_dtype)?;
784 Ok(Self::from_candle(result, to, self.layout))
785 }
786
787 pub fn to_scalar_f32(&self) -> Result<f32> {
789 let value = self.candle_tensor.to_scalar::<f32>()?;
790 Ok(value)
791 }
792}
793
794fn dtype_to_candle(dtype: &DataType) -> Result<DType> {
796 match dtype {
797 DataType::F32 => Ok(DType::F32),
798 DataType::F16 => Ok(DType::F16),
799 DataType::BF16 => Ok(DType::BF16),
800 DataType::F64 => Ok(DType::F64),
801 DataType::U8 => Ok(DType::U8),
802 DataType::U32 => Ok(DType::U32),
803 DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => Ok(DType::F32),
805 }
806}
807
808#[allow(dead_code)]
810fn dtype_from_candle(dtype: DType) -> DataType {
811 match dtype {
812 DType::F32 => DataType::F32,
813 DType::F16 => DataType::F16,
814 DType::U8 => DataType::U8,
815 DType::U32 => DataType::U32,
816 DType::F64 => DataType::F64,
817 _ => DataType::F32, }
819}
820
821impl From<RonnTensor> for Tensor {
823 fn from(legacy: RonnTensor) -> Self {
824 Self::from_data(legacy.data, legacy.shape, legacy.dtype, legacy.layout)
825 .expect("Failed to convert legacy tensor")
826 }
827}
828
829impl From<Tensor> for RonnTensor {
831 fn from(tensor: Tensor) -> Self {
832 let data = tensor.to_vec().expect("Failed to extract tensor data");
833 Self {
834 data,
835 shape: tensor.shape(),
836 dtype: tensor.dtype,
837 layout: tensor.layout,
838 }
839 }
840}
841
#[cfg(test)]
mod tests {
    use super::*;

    /// A tensor built from data reports the given shape, dtype and element count, and
    /// returns the same values from `to_vec`.
    #[test]
    fn test_tensor_creation() -> Result<()> {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(
            values.clone(),
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.dtype(), DataType::F32);
        assert_eq!(tensor.numel(), 4);
        assert_eq!(tensor.to_vec()?, values);

        Ok(())
    }

    /// `zeros` and `ones` fill every element with the expected constant.
    #[test]
    fn test_zeros_and_ones() -> Result<()> {
        let zeros_data =
            Tensor::zeros(vec![3, 3], DataType::F32, TensorLayout::RowMajor)?.to_vec()?;
        assert!(zeros_data.iter().all(|&x| x == 0.0));

        let ones_data =
            Tensor::ones(vec![2, 3], DataType::F32, TensorLayout::RowMajor)?.to_vec()?;
        assert!(ones_data.iter().all(|&x| x == 1.0));

        Ok(())
    }

    /// `broadcast_shape` combines compatible shapes and rejects incompatible ones.
    #[test]
    fn test_broadcasting() {
        let same_rank = Tensor::broadcast_shape(&[3, 1], &[1, 4]).unwrap();
        assert_eq!(same_rank, vec![3, 4]);

        let mixed_rank = Tensor::broadcast_shape(&[2, 3, 1], &[1, 4]).unwrap();
        assert_eq!(mixed_rank, vec![2, 3, 4]);

        let incompatible = Tensor::broadcast_shape(&[3, 2], &[2, 3]);
        assert!(incompatible.is_err());
    }

    /// `is_broadcastable_with` agrees with the broadcasting rules on concrete tensors.
    #[test]
    fn test_broadcastable_check() -> Result<()> {
        let column = Tensor::zeros(vec![3, 1], DataType::F32, TensorLayout::RowMajor)?;
        let row = Tensor::zeros(vec![1, 4], DataType::F32, TensorLayout::RowMajor)?;
        let mismatched = Tensor::zeros(vec![2, 3], DataType::F32, TensorLayout::RowMajor)?;

        assert!(column.is_broadcastable_with(&row));
        assert!(!column.is_broadcastable_with(&mismatched));

        Ok(())
    }

    /// Values survive a round trip through F16 (within tolerance) and I8 (exactly).
    #[test]
    fn test_data_type_conversions() -> Result<()> {
        let values = vec![1.5, 2.5, 3.5, 4.5];
        let half_tensor = Tensor::from_data(
            values.clone(),
            vec![2, 2],
            DataType::F16,
            TensorLayout::RowMajor,
        )?;
        let round_tripped = half_tensor.to_vec()?;

        // Half precision is lossy, so compare with a tolerance.
        for (expected, actual) in values.iter().zip(round_tripped.iter()) {
            assert!((expected - actual).abs() < 0.01);
        }

        let int_values = vec![1.0, -2.0, 3.0, -4.0];
        let int_tensor =
            Tensor::from_data(int_values, vec![2, 2], DataType::I8, TensorLayout::RowMajor)?;
        assert_eq!(int_tensor.to_vec()?, vec![1.0, -2.0, 3.0, -4.0]);

        Ok(())
    }

    /// Tensors are created on the CPU and `to_cpu` keeps them there.
    #[test]
    fn test_device_operations() -> Result<()> {
        let tensor = Tensor::zeros(vec![2, 2], DataType::F32, TensorLayout::RowMajor)?;
        assert!(matches!(tensor.device(), Device::Cpu));

        let moved = tensor.to_cpu()?;
        assert!(matches!(moved.device(), Device::Cpu));

        Ok(())
    }
}