Skip to main content

ronn_core/
tensor.rs

1//! Tensor implementation with Candle backend integration.
2//!
3//! This module provides the core Tensor type for RONN with seamless integration
4//! to the Candle tensor library for high-performance operations and GPU acceleration.
5
6use crate::ops::shape::ShapeOps;
7use crate::types::{DataType, Tensor as RonnTensor, TensorLayout};
8use anyhow::{Result, anyhow};
9use candle_core::{DType, Device, Module, Shape, Tensor as CandleTensor};
10
11/// Enhanced Tensor implementation with Candle backend.
12#[derive(Debug, Clone)]
13pub struct Tensor {
14    /// The underlying Candle tensor for computation.
15    candle_tensor: CandleTensor,
16    /// Original data type specification.
17    dtype: DataType,
18    /// Memory layout preference.
19    layout: TensorLayout,
20}
21
22impl Tensor {
23    /// Create a new tensor from raw data.
24    ///
25    /// # Arguments
26    /// * `data` - Raw tensor data
27    /// * `shape` - Tensor dimensions
28    /// * `dtype` - Data type specification
29    /// * `layout` - Memory layout preference
30    ///
31    /// # Example
32    /// ```rust
33    /// use ronn_core::tensor::Tensor;
34    /// use ronn_core::types::{DataType, TensorLayout};
35    ///
36    /// let data = vec![1.0, 2.0, 3.0, 4.0];
37    /// let tensor = Tensor::from_data(data, vec![2, 2], DataType::F32, TensorLayout::RowMajor)?;
38    /// # Ok::<(), Box<dyn std::error::Error>>(())
39    /// ```
40    pub fn from_data(
41        data: Vec<f32>,
42        shape: Vec<usize>,
43        dtype: DataType,
44        layout: TensorLayout,
45    ) -> Result<Self> {
46        let device = Device::Cpu;
47        let candle_shape = Shape::from_dims(&shape);
48
49        let candle_tensor = match dtype {
50            DataType::F32 => CandleTensor::from_vec(data, candle_shape, &device)?,
51            DataType::F16 => {
52                let f16_data: Vec<half::f16> = data.into_iter().map(half::f16::from_f32).collect();
53                CandleTensor::from_vec(f16_data, candle_shape, &device)?
54            }
55            DataType::BF16 => {
56                let bf16_data: Vec<half::bf16> =
57                    data.into_iter().map(half::bf16::from_f32).collect();
58                CandleTensor::from_vec(bf16_data, candle_shape, &device)?
59            }
60            DataType::F64 => {
61                let f64_data: Vec<f64> = data.into_iter().map(|x| x as f64).collect();
62                CandleTensor::from_vec(f64_data, candle_shape, &device)?
63            }
64            DataType::U8 => {
65                let u8_data: Vec<u8> = data.into_iter().map(|x| x as u8).collect();
66                CandleTensor::from_vec(u8_data, candle_shape, &device)?
67            }
68            DataType::U32 => {
69                let u32_data: Vec<u32> = data.into_iter().map(|x| x as u32).collect();
70                CandleTensor::from_vec(u32_data, candle_shape, &device)?
71            }
72            // For unsupported types, convert to F32
73            DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => {
74                CandleTensor::from_vec(data, candle_shape, &device)?
75            }
76        };
77
78        Ok(Self {
79            candle_tensor,
80            dtype,
81            layout,
82        })
83    }
84
85    /// Create a tensor filled with zeros.
86    ///
87    /// # Arguments
88    /// * `shape` - Tensor dimensions
89    /// * `dtype` - Data type specification
90    /// * `layout` - Memory layout preference
91    pub fn zeros(shape: Vec<usize>, dtype: DataType, layout: TensorLayout) -> Result<Self> {
92        let device = Device::Cpu;
93        let candle_dtype = dtype_to_candle(&dtype)?;
94        let candle_shape = Shape::from_dims(&shape);
95
96        let candle_tensor = CandleTensor::zeros(candle_shape, candle_dtype, &device)?;
97
98        Ok(Self {
99            candle_tensor,
100            dtype,
101            layout,
102        })
103    }
104
105    /// Create a tensor filled with ones.
106    ///
107    /// # Arguments
108    /// * `shape` - Tensor dimensions
109    /// * `dtype` - Data type specification
110    /// * `layout` - Memory layout preference
111    pub fn ones(shape: Vec<usize>, dtype: DataType, layout: TensorLayout) -> Result<Self> {
112        let device = Device::Cpu;
113        let candle_dtype = dtype_to_candle(&dtype)?;
114        let candle_shape = Shape::from_dims(&shape);
115
116        let candle_tensor = CandleTensor::ones(candle_shape, candle_dtype, &device)?;
117
118        Ok(Self {
119            candle_tensor,
120            dtype,
121            layout,
122        })
123    }
124
125    /// Create a tensor with random values from a uniform distribution.
126    pub fn rand(shape: Vec<usize>, dtype: DataType, layout: TensorLayout) -> Result<Self> {
127        let device = Device::Cpu;
128        let _candle_dtype = dtype_to_candle(&dtype)?;
129        let candle_shape = Shape::from_dims(&shape);
130
131        let candle_tensor = CandleTensor::rand(0.0, 1.0, candle_shape, &device)?;
132
133        Ok(Self {
134            candle_tensor,
135            dtype,
136            layout,
137        })
138    }
139
140    /// Get the shape of the tensor.
141    pub fn shape(&self) -> Vec<usize> {
142        self.candle_tensor.dims().to_vec()
143    }
144
145    /// Get the data type of the tensor.
146    pub fn dtype(&self) -> DataType {
147        self.dtype
148    }
149
150    /// Get the memory layout of the tensor.
151    pub fn layout(&self) -> TensorLayout {
152        self.layout
153    }
154
155    /// Get the number of dimensions.
156    pub fn ndim(&self) -> usize {
157        self.candle_tensor.dims().len()
158    }
159
160    /// Get the total number of elements.
161    pub fn numel(&self) -> usize {
162        self.candle_tensor.elem_count()
163    }
164
165    /// Get the device where the tensor is stored.
166    pub fn device(&self) -> &Device {
167        self.candle_tensor.device()
168    }
169
170    /// Convert tensor to CPU device.
171    pub fn to_cpu(&self) -> Result<Self> {
172        let cpu_tensor = self.candle_tensor.to_device(&Device::Cpu)?;
173        Ok(Self {
174            candle_tensor: cpu_tensor,
175            dtype: self.dtype,
176            layout: self.layout,
177        })
178    }
179
180    /// Convert tensor to GPU device (if available).
181    #[cfg(feature = "gpu")]
182    pub fn to_gpu(&self, device_id: usize) -> Result<Self> {
183        let gpu_device = Device::new_cuda(device_id)?;
184        let gpu_tensor = self.candle_tensor.to_device(&gpu_device)?;
185        Ok(Self {
186            candle_tensor: gpu_tensor,
187            dtype: self.dtype,
188            layout: self.layout,
189        })
190    }
191
192    /// Extract data as a vector of f32 values.
193    pub fn to_vec(&self) -> Result<Vec<f32>> {
194        // Flatten the tensor first if it's multi-dimensional
195        let flattened = if self.candle_tensor.dims().len() > 1 {
196            self.candle_tensor.flatten_all()?
197        } else {
198            self.candle_tensor.clone()
199        };
200
201        match self.dtype {
202            DataType::F32 | DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => {
203                let data: Vec<f32> = flattened.to_vec1()?;
204                Ok(data)
205            }
206            DataType::F16 => {
207                let data: Vec<half::f16> = flattened.to_vec1()?;
208                Ok(data.into_iter().map(|x| x.to_f32()).collect())
209            }
210            DataType::BF16 => {
211                let data: Vec<half::bf16> = flattened.to_vec1()?;
212                Ok(data.into_iter().map(|x| x.to_f32()).collect())
213            }
214            DataType::F64 => {
215                let data: Vec<f64> = flattened.to_vec1()?;
216                Ok(data.into_iter().map(|x| x as f32).collect())
217            }
218            DataType::U8 => {
219                let data: Vec<u8> = flattened.to_vec1()?;
220                Ok(data.into_iter().map(|x| x as f32).collect())
221            }
222            DataType::U32 => {
223                let data: Vec<u32> = flattened.to_vec1()?;
224                Ok(data.into_iter().map(|x| x as f32).collect())
225            }
226        }
227    }
228
229    /// Get the underlying Candle tensor for advanced operations.
230    pub fn candle_tensor(&self) -> &CandleTensor {
231        &self.candle_tensor
232    }
233
234    /// Create a Tensor from a Candle tensor.
235    pub fn from_candle(candle_tensor: CandleTensor, dtype: DataType, layout: TensorLayout) -> Self {
236        Self {
237            candle_tensor,
238            dtype,
239            layout,
240        }
241    }
242
243    /// Check if tensor shapes are broadcastable.
244    pub fn is_broadcastable_with(&self, other: &Tensor) -> bool {
245        let shape1 = self.shape();
246        let shape2 = other.shape();
247
248        // Pad shorter shape with 1s on the left
249        let max_len = shape1.len().max(shape2.len());
250        let mut padded1 = vec![1; max_len - shape1.len()];
251        let mut padded2 = vec![1; max_len - shape2.len()];
252        padded1.extend(shape1);
253        padded2.extend(shape2);
254
255        // Check compatibility dimension by dimension
256        for (d1, d2) in padded1.iter().zip(padded2.iter()) {
257            if *d1 != *d2 && *d1 != 1 && *d2 != 1 {
258                return false;
259            }
260        }
261        true
262    }
263
264    /// Compute broadcast shape for two tensors.
265    pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Vec<usize>> {
266        let max_len = shape1.len().max(shape2.len());
267        let mut padded1 = vec![1; max_len - shape1.len()];
268        let mut padded2 = vec![1; max_len - shape2.len()];
269        padded1.extend(shape1);
270        padded2.extend(shape2);
271
272        let mut result = Vec::with_capacity(max_len);
273        for (d1, d2) in padded1.iter().zip(padded2.iter()) {
274            match (d1, d2) {
275                (1, d) | (d, 1) => result.push(*d),
276                (d1, d2) if d1 == d2 => result.push(*d1),
277                (d1, d2) => {
278                    return Err(anyhow!(
279                        "Cannot broadcast shapes: dimension {} vs {}",
280                        d1,
281                        d2
282                    ));
283                }
284            }
285        }
286        Ok(result)
287    }
288
289    /// Convolution 2D operation (placeholder - uses Candle's conv2d).
290    pub fn conv2d(
291        &self,
292        weight: &Tensor,
293        bias: Option<&Tensor>,
294        strides: &[usize],
295        pads: &[usize],
296        dilations: &[usize],
297        groups: usize,
298    ) -> Result<Tensor> {
299        // Simplified implementation - full implementation would use candle_nn
300        let _ = (weight, bias, strides, pads, dilations, groups);
301        Err(anyhow!("conv2d not yet fully implemented"))
302    }
303
304    /// Max pooling 2D operation.
305    pub fn max_pool2d(
306        &self,
307        kernel_shape: &[usize],
308        strides: &[usize],
309        pads: &[usize],
310    ) -> Result<Tensor> {
311        let _ = (kernel_shape, strides, pads);
312        Err(anyhow!("max_pool2d not yet fully implemented"))
313    }
314
315    /// Average pooling 2D operation.
316    pub fn avg_pool2d(
317        &self,
318        kernel_shape: &[usize],
319        strides: &[usize],
320        pads: &[usize],
321    ) -> Result<Tensor> {
322        let _ = (kernel_shape, strides, pads);
323        Err(anyhow!("avg_pool2d not yet fully implemented"))
324    }
325
326    /// Batch normalization operation.
327    pub fn batch_norm(
328        &self,
329        scale: &Tensor,
330        bias: &Tensor,
331        mean: &Tensor,
332        var: &Tensor,
333        epsilon: f32,
334    ) -> Result<Tensor> {
335        let _ = (scale, bias, mean, var, epsilon);
336        Err(anyhow!("batch_norm not yet fully implemented"))
337    }
338
339    /// Get rank (number of dimensions).
340    pub fn rank(&self) -> usize {
341        self.ndim()
342    }
343
344    /// Convert to 1D vector (alias for to_vec).
345    pub fn to_vec1<T: candle_core::WithDType>(&self) -> Result<Vec<T>> {
346        let flattened = if self.candle_tensor.dims().len() > 1 {
347            self.candle_tensor.flatten_all()?
348        } else {
349            self.candle_tensor.clone()
350        };
351        Ok(flattened.to_vec1()?)
352    }
353
354    /// Stack tensors along a new dimension.
355    ///
356    /// # Arguments
357    /// * `tensors` - Slice of tensors to stack
358    /// * `dim` - Dimension along which to stack
359    ///
360    /// # Example
361    /// ```rust
362    /// use ronn_core::tensor::Tensor;
363    /// use ronn_core::types::{DataType, TensorLayout};
364    ///
365    /// let t1 = Tensor::from_data(vec![1.0, 2.0], vec![2], DataType::F32, TensorLayout::RowMajor)?;
366    /// let t2 = Tensor::from_data(vec![3.0, 4.0], vec![2], DataType::F32, TensorLayout::RowMajor)?;
367    /// let stacked = Tensor::stack(&[&t1, &t2], 0)?;
368    /// assert_eq!(stacked.shape(), vec![2, 2]);
369    /// # Ok::<(), Box<dyn std::error::Error>>(())
370    /// ```
371    pub fn stack(tensors: &[&Tensor], dim: usize) -> Result<Self> {
372        if tensors.is_empty() {
373            return Err(anyhow!("Cannot stack empty tensor list"));
374        }
375
376        let candle_tensors: Vec<_> = tensors.iter().map(|t| &t.candle_tensor).collect();
377        let stacked = CandleTensor::stack(&candle_tensors, dim)?;
378
379        Ok(Self {
380            candle_tensor: stacked,
381            dtype: tensors[0].dtype,
382            layout: tensors[0].layout,
383        })
384    }
385
386    /// Split tensor into chunks along an axis.
387    ///
388    /// # Arguments
389    /// * `num_chunks` - Number of chunks to split into
390    /// * `dim` - Dimension along which to split
391    ///
392    /// # Example
393    /// ```rust
394    /// use ronn_core::tensor::Tensor;
395    /// use ronn_core::types::{DataType, TensorLayout};
396    ///
397    /// let t = Tensor::from_data(
398    ///     vec![1.0, 2.0, 3.0, 4.0],
399    ///     vec![2, 2],
400    ///     DataType::F32,
401    ///     TensorLayout::RowMajor
402    /// )?;
403    /// let chunks = t.split(2, 0)?;
404    /// assert_eq!(chunks.len(), 2);
405    /// # Ok::<(), Box<dyn std::error::Error>>(())
406    /// ```
407    pub fn split(&self, num_chunks: usize, dim: usize) -> Result<Vec<Tensor>> {
408        if num_chunks == 0 {
409            return Err(anyhow!("Cannot split into 0 chunks"));
410        }
411
412        let shape = self.shape();
413        if dim >= shape.len() {
414            return Err(anyhow!(
415                "Dimension {} out of bounds for shape {:?}",
416                dim,
417                shape
418            ));
419        }
420
421        let dim_size = shape[dim];
422        if dim_size % num_chunks != 0 {
423            return Err(anyhow!(
424                "Dimension size {} not evenly divisible by {} chunks",
425                dim_size,
426                num_chunks
427            ));
428        }
429
430        let chunk_size = dim_size / num_chunks;
431        let mut chunks = Vec::with_capacity(num_chunks);
432
433        for i in 0..num_chunks {
434            let start = i * chunk_size;
435            let _end = start + chunk_size;
436            let chunk = self.candle_tensor.narrow(dim, start, chunk_size)?;
437            chunks.push(Self {
438                candle_tensor: chunk,
439                dtype: self.dtype,
440                layout: self.layout,
441            });
442        }
443
444        Ok(chunks)
445    }
446
447    /// Gather elements along an axis.
448    pub fn gather(&self, indices: &Tensor, dim: usize) -> Result<Tensor> {
449        let _ = (indices, dim);
450        Err(anyhow!("gather not yet fully implemented"))
451    }
452
453    /// Transpose with specific permutation.
454    pub fn transpose(&self, perm: &[usize]) -> Result<Tensor> {
455        let result = self.candle_tensor.permute(perm)?;
456        Ok(Tensor::from_candle(result, self.dtype, self.layout))
457    }
458
459    /// Layer normalization (critical for transformers).
460    ///
461    /// Normalizes the input across the specified axis.
462    ///
463    /// # Arguments
464    /// * `scale` - Optional scale parameter (gamma)
465    /// * `bias` - Optional bias parameter (beta)
466    /// * `epsilon` - Small constant for numerical stability
467    /// * `axis` - Axis to normalize over (default: -1 for last dimension)
468    ///
469    /// # Example
470    /// ```ignore
471    /// use ronn_core::tensor::Tensor;
472    /// use ronn_core::types::{DataType, TensorLayout};
473    ///
474    /// let input = Tensor::from_data(
475    ///     vec![1.0, 2.0, 3.0, 4.0],
476    ///     vec![2, 2],
477    ///     DataType::F32,
478    ///     TensorLayout::RowMajor
479    /// )?;
480    /// let normalized = input.layer_norm(None, None, 1e-5, 1)?;  // Use positive axis
481    /// # Ok::<(), Box<dyn std::error::Error>>(())
482    /// ```
483    pub fn layer_norm(
484        &self,
485        scale: Option<&Tensor>,
486        bias: Option<&Tensor>,
487        epsilon: f32,
488        axis: i32,
489    ) -> Result<Self> {
490        use candle_nn::LayerNorm;
491
492        let shape = self.shape();
493        let _normalized_shape = if axis == -1 {
494            vec![shape[shape.len() - 1]]
495        } else {
496            let axis_usize = if axis < 0 {
497                (shape.len() as i32 + axis) as usize
498            } else {
499                axis as usize
500            };
501            vec![shape[axis_usize]]
502        };
503
504        // Create layer norm config
505        // If scale and bias provided, use them
506        let normalized = if let (Some(s), Some(b)) = (scale, bias) {
507            let ln = LayerNorm::new(
508                s.candle_tensor.clone(),
509                b.candle_tensor.clone(),
510                epsilon as f64,
511            );
512            ln.forward(&self.candle_tensor)?
513        } else {
514            // Simple normalization without learnable parameters
515            let mean = self.candle_tensor.mean_keepdim(axis as usize)?;
516            let variance = self
517                .candle_tensor
518                .broadcast_sub(&mean)?
519                .sqr()?
520                .mean_keepdim(axis as usize)?;
521            let std = (variance + epsilon as f64)?.sqrt()?;
522            self.candle_tensor
523                .broadcast_sub(&mean)?
524                .broadcast_div(&std)?
525        };
526
527        Ok(Self::from_candle(normalized, self.dtype, self.layout))
528    }
529
530    /// Multi-head attention mechanism (critical for transformers).
531    ///
532    /// Computes scaled dot-product attention: softmax(Q·K^T / sqrt(d_k))·V
533    ///
534    /// # Arguments
535    /// * `key` - Key tensor
536    /// * `value` - Value tensor
537    /// * `num_heads` - Number of attention heads
538    /// * `mask` - Optional attention mask
539    ///
540    /// # Example
541    /// ```rust
542    /// use ronn_core::tensor::Tensor;
543    /// use ronn_core::types::{DataType, TensorLayout};
544    ///
545    /// let query = Tensor::from_data(
546    ///     vec![1.0; 64],
547    ///     vec![1, 8, 8],  // (batch, seq_len, d_model)
548    ///     DataType::F32,
549    ///     TensorLayout::RowMajor
550    /// )?;
551    /// let key = query.clone();
552    /// let value = query.clone();
553    ///
554    /// let output = query.attention(&key, &value, 2, None)?;
555    /// # Ok::<(), Box<dyn std::error::Error>>(())
556    /// ```
557    pub fn attention(
558        &self,
559        key: &Tensor,
560        value: &Tensor,
561        num_heads: usize,
562        mask: Option<&Tensor>,
563    ) -> Result<Self> {
564        let query = &self.candle_tensor;
565        let key = &key.candle_tensor;
566        let value = &value.candle_tensor;
567
568        // Get dimensions
569        let query_shape = query.dims();
570        if query_shape.len() != 3 {
571            return Err(anyhow!(
572                "Query must be 3D (batch, seq_len, d_model), got {:?}",
573                query_shape
574            ));
575        }
576
577        let batch_size = query_shape[0];
578        let seq_len = query_shape[1];
579        let d_model = query_shape[2];
580
581        if d_model % num_heads != 0 {
582            return Err(anyhow!(
583                "d_model ({}) must be divisible by num_heads ({})",
584                d_model,
585                num_heads
586            ));
587        }
588
589        let d_k = d_model / num_heads;
590
591        // Reshape Q, K, V for multi-head attention
592        // (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
593        let q = query
594            .reshape(&[batch_size, seq_len, num_heads, d_k])?
595            .transpose(1, 2)?;
596        let k = key
597            .reshape(&[batch_size, seq_len, num_heads, d_k])?
598            .transpose(1, 2)?;
599        let v = value
600            .reshape(&[batch_size, seq_len, num_heads, d_k])?
601            .transpose(1, 2)?;
602
603        // Compute attention scores: Q·K^T / sqrt(d_k)
604        let k_t = k.transpose(2, 3)?;
605        let scores = (q.matmul(&k_t)? / (d_k as f64).sqrt())?;
606
607        // Apply mask if provided
608        let scores = if let Some(m) = mask {
609            scores.broadcast_add(&m.candle_tensor)?
610        } else {
611            scores
612        };
613
614        // Apply softmax
615        let attention_weights = candle_nn::ops::softmax_last_dim(&scores)?;
616
617        // Apply attention to values: attention_weights·V
618        let output = attention_weights.matmul(&v)?;
619
620        // Reshape back: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
621        let output = output
622            .transpose(1, 2)?
623            .reshape(&[batch_size, seq_len, d_model])?;
624
625        Ok(Self::from_candle(output, self.dtype, self.layout))
626    }
627
628    /// Clip values to a range [min, max]
629    pub fn clip(&self, min: f32, max: f32) -> Result<Self> {
630        let result = self.candle_tensor.clamp(min, max)?;
631        Ok(Self::from_candle(result, self.dtype, self.layout))
632    }
633
634    /// Element-wise power operation with tensor exponent.
635    /// For scalar exponents, use the `ArithmeticOps::pow` trait method instead.
636    pub fn pow_tensor(&self, exponent: &Tensor) -> Result<Self> {
637        let result = self.candle_tensor.pow(&exponent.candle_tensor)?;
638        Ok(Self::from_candle(result, self.dtype, self.layout))
639    }
640
641    /// Element-wise square root
642    pub fn sqrt(&self) -> Result<Self> {
643        let result = self.candle_tensor.sqrt()?;
644        Ok(Self::from_candle(result, self.dtype, self.layout))
645    }
646
647    /// Element-wise exponential (e^x)
648    pub fn exp(&self) -> Result<Self> {
649        let result = self.candle_tensor.exp()?;
650        Ok(Self::from_candle(result, self.dtype, self.layout))
651    }
652
653    /// Element-wise natural logarithm
654    pub fn log(&self) -> Result<Self> {
655        let result = self.candle_tensor.log()?;
656        Ok(Self::from_candle(result, self.dtype, self.layout))
657    }
658
659    /// Element-wise negation
660    pub fn neg(&self) -> Result<Self> {
661        let result = self.candle_tensor.neg()?;
662        Ok(Self::from_candle(result, self.dtype, self.layout))
663    }
664
665    /// Element-wise absolute value
666    pub fn abs(&self) -> Result<Self> {
667        let result = self.candle_tensor.abs()?;
668        Ok(Self::from_candle(result, self.dtype, self.layout))
669    }
670
671    /// LeakyReLU activation: max(alpha * x, x)
672    pub fn leaky_relu(&self, alpha: f32) -> Result<Self> {
673        let scaled = self.candle_tensor.affine(alpha as f64, 0.0)?;
674        let result = self.candle_tensor.maximum(&scaled)?;
675        Ok(Self::from_candle(result, self.dtype, self.layout))
676    }
677
678    /// ELU activation: x if x > 0 else alpha * (exp(x) - 1)
679    pub fn elu(&self, alpha: f32) -> Result<Self> {
680        // ELU(x) = x if x > 0 else alpha * (exp(x) - 1)
681        let zero = self.candle_tensor.zeros_like()?;
682        let mask = self.candle_tensor.gt(&zero)?;
683
684        let positive_part = &self.candle_tensor;
685        let exp_part = self.candle_tensor.exp()?.affine(1.0, -1.0)?;
686        let negative_part = exp_part.affine(alpha as f64, 0.0)?;
687
688        let result = mask.where_cond(positive_part, &negative_part)?;
689        Ok(Self::from_candle(result, self.dtype, self.layout))
690    }
691
692    /// Swish/SiLU activation: x * sigmoid(x)
693    pub fn swish(&self) -> Result<Self> {
694        let sigmoid = candle_nn::ops::sigmoid(&self.candle_tensor)?;
695        let result = (&self.candle_tensor * &sigmoid)?;
696        Ok(Self::from_candle(result, self.dtype, self.layout))
697    }
698
699    /// Remove dimensions of size 1
700    pub fn squeeze(&self, axes: Option<Vec<usize>>) -> Result<Self> {
701        let shape = self.shape();
702        let new_shape: Vec<usize> = if let Some(axes) = axes {
703            // Remove specific axes
704            shape
705                .iter()
706                .enumerate()
707                .filter(|(i, dim)| !axes.contains(i) || **dim != 1)
708                .map(|(_, dim)| *dim)
709                .collect()
710        } else {
711            // Remove all dimensions of size 1
712            shape.iter().copied().filter(|dim| *dim != 1).collect()
713        };
714
715        if new_shape.is_empty() {
716            // If all dimensions were 1, keep at least one
717            return self.reshape(&[1]);
718        }
719
720        self.reshape(&new_shape)
721    }
722
723    /// Add dimensions of size 1
724    pub fn unsqueeze(&self, axes: &[usize]) -> Result<Self> {
725        let mut new_shape = self.shape();
726        let mut axes_sorted = axes.to_vec();
727        axes_sorted.sort_unstable();
728
729        for &axis in &axes_sorted {
730            // Check bounds before inserting
731            if axis > new_shape.len() {
732                return Err(anyhow!(
733                    "Unsqueeze axis {} is out of bounds for shape with {} dimensions",
734                    axis,
735                    new_shape.len()
736                ));
737            }
738            new_shape.insert(axis, 1);
739        }
740
741        self.reshape(&new_shape)
742    }
743
744    /// Reduce mean along axes
745    pub fn reduce_mean(&self, axes: &[usize], keepdims: bool) -> Result<Self> {
746        let mut result = self.candle_tensor.clone();
747
748        // Sort axes in descending order to maintain correct indices
749        let mut sorted_axes = axes.to_vec();
750        sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
751
752        for &axis in &sorted_axes {
753            result = result.mean_keepdim(axis)?;
754            if !keepdims {
755                result = result.squeeze(axis)?;
756            }
757        }
758
759        Ok(Self::from_candle(result, self.dtype, self.layout))
760    }
761
762    /// Reduce sum along axes
763    pub fn reduce_sum(&self, axes: &[usize], keepdims: bool) -> Result<Self> {
764        let mut result = self.candle_tensor.clone();
765
766        // Sort axes in descending order to maintain correct indices
767        let mut sorted_axes = axes.to_vec();
768        sorted_axes.sort_unstable_by(|a, b| b.cmp(a));
769
770        for &axis in &sorted_axes {
771            result = result.sum_keepdim(axis)?;
772            if !keepdims {
773                result = result.squeeze(axis)?;
774            }
775        }
776
777        Ok(Self::from_candle(result, self.dtype, self.layout))
778    }
779
780    /// Cast tensor to a different data type
781    pub fn cast(&self, to: DataType) -> Result<Self> {
782        let target_dtype = dtype_to_candle(&to)?;
783        let result = self.candle_tensor.to_dtype(target_dtype)?;
784        Ok(Self::from_candle(result, to, self.layout))
785    }
786
787    /// Convert tensor to a scalar f32 value
788    pub fn to_scalar_f32(&self) -> Result<f32> {
789        let value = self.candle_tensor.to_scalar::<f32>()?;
790        Ok(value)
791    }
792}
793
794/// Convert RONN DataType to Candle DType.
795fn dtype_to_candle(dtype: &DataType) -> Result<DType> {
796    match dtype {
797        DataType::F32 => Ok(DType::F32),
798        DataType::F16 => Ok(DType::F16),
799        DataType::BF16 => Ok(DType::BF16),
800        DataType::F64 => Ok(DType::F64),
801        DataType::U8 => Ok(DType::U8),
802        DataType::U32 => Ok(DType::U32),
803        // For unsupported types, use F32
804        DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => Ok(DType::F32),
805    }
806}
807
808/// Convert Candle DType to RONN DataType.
809#[allow(dead_code)]
810fn dtype_from_candle(dtype: DType) -> DataType {
811    match dtype {
812        DType::F32 => DataType::F32,
813        DType::F16 => DataType::F16,
814        DType::U8 => DataType::U8,
815        DType::U32 => DataType::U32,
816        DType::F64 => DataType::F64,
817        _ => DataType::F32, // Default fallback
818    }
819}
820
821/// Convert legacy RonnTensor to new Tensor implementation.
822impl From<RonnTensor> for Tensor {
823    fn from(legacy: RonnTensor) -> Self {
824        Self::from_data(legacy.data, legacy.shape, legacy.dtype, legacy.layout)
825            .expect("Failed to convert legacy tensor")
826    }
827}
828
829/// Convert new Tensor to legacy RonnTensor for compatibility.
830impl From<Tensor> for RonnTensor {
831    fn from(tensor: Tensor) -> Self {
832        let data = tensor.to_vec().expect("Failed to extract tensor data");
833        Self {
834            data,
835            shape: tensor.shape(),
836            dtype: tensor.dtype,
837            layout: tensor.layout,
838        }
839    }
840}
841
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a row-major tensor; shared by the tests below.
    fn row_major(data: Vec<f32>, shape: Vec<usize>, dtype: DataType) -> Result<Tensor> {
        Tensor::from_data(data, shape, dtype, TensorLayout::RowMajor)
    }

    #[test]
    fn test_tensor_creation() -> Result<()> {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = row_major(values.clone(), vec![2, 2], DataType::F32)?;

        assert_eq!(tensor.shape(), vec![2, 2]);
        assert_eq!(tensor.dtype(), DataType::F32);
        assert_eq!(tensor.numel(), 4);
        // An F32 round trip must be lossless.
        assert_eq!(tensor.to_vec()?, values);

        Ok(())
    }

    #[test]
    fn test_zeros_and_ones() -> Result<()> {
        let zeros = Tensor::zeros(vec![3, 3], DataType::F32, TensorLayout::RowMajor)?.to_vec()?;
        assert!(zeros.iter().all(|&x| x == 0.0));

        let ones = Tensor::ones(vec![2, 3], DataType::F32, TensorLayout::RowMajor)?.to_vec()?;
        assert!(ones.iter().all(|&x| x == 1.0));

        Ok(())
    }

    #[test]
    fn test_broadcasting() {
        let compatible: [(&[usize], &[usize], Vec<usize>); 2] = [
            (&[3, 1], &[1, 4], vec![3, 4]),
            (&[2, 3, 1], &[1, 4], vec![2, 3, 4]),
        ];
        for (lhs, rhs, expected) in &compatible {
            assert_eq!(&Tensor::broadcast_shape(lhs, rhs).unwrap(), expected);
        }

        assert!(Tensor::broadcast_shape(&[3, 2], &[2, 3]).is_err());
    }

    #[test]
    fn test_broadcastable_check() -> Result<()> {
        let zeros = |shape: Vec<usize>| Tensor::zeros(shape, DataType::F32, TensorLayout::RowMajor);
        let column = zeros(vec![3, 1])?;
        let row = zeros(vec![1, 4])?;
        let mismatched = zeros(vec![2, 3])?;

        assert!(column.is_broadcastable_with(&row));
        assert!(!column.is_broadcastable_with(&mismatched));

        Ok(())
    }

    #[test]
    fn test_data_type_conversions() -> Result<()> {
        // F16 has limited precision, so compare within a tolerance.
        let half_values = vec![1.5, 2.5, 3.5, 4.5];
        let half_tensor = row_major(half_values.clone(), vec![2, 2], DataType::F16)?;
        for (expected, actual) in half_values.iter().zip(half_tensor.to_vec()?) {
            assert!((expected - actual).abs() < 0.01);
        }

        // I8 is stored as f32, so negative values survive the round trip exactly.
        let int_tensor = row_major(vec![1.0, -2.0, 3.0, -4.0], vec![2, 2], DataType::I8)?;
        assert_eq!(int_tensor.to_vec()?, vec![1.0, -2.0, 3.0, -4.0]);

        Ok(())
    }

    #[test]
    fn test_device_operations() -> Result<()> {
        let tensor = Tensor::zeros(vec![2, 2], DataType::F32, TensorLayout::RowMajor)?;
        assert!(matches!(tensor.device(), Device::Cpu));

        // Moving an existing CPU tensor to the CPU keeps it there.
        let moved = tensor.to_cpu()?;
        assert!(matches!(moved.device(), Device::Cpu));

        Ok(())
    }
}