ipfrs_core/tensor.rs

//! Tensor-aware block types for neural network data.
//!
//! This module provides specialized types for storing and managing tensor data
//! in a content-addressed manner. Tensors are the fundamental data structure
//! in machine learning frameworks like PyTorch and TensorFlow.
//!
//! # Example
//!
//! ```rust
//! use ipfrs_core::tensor::{TensorBlock, TensorDtype, TensorShape};
//! use bytes::Bytes;
//!
//! // Create a 2x3 f32 tensor
//! let shape = TensorShape::new(vec![2, 3]);
//! let data = Bytes::from(vec![
//!     0f32.to_le_bytes(), 1f32.to_le_bytes(),
//!     2f32.to_le_bytes(), 3f32.to_le_bytes(),
//!     4f32.to_le_bytes(), 5f32.to_le_bytes(),
//! ].concat());
//!
//! let tensor = TensorBlock::new(data, shape, TensorDtype::F32).unwrap();
//! assert_eq!(tensor.element_count(), 6);
//! ```

use crate::block::Block;
use crate::error::{Error, Result};
use bytes::Bytes;
use serde::{Deserialize, Serialize};

/// Supported tensor data types
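///
/// A quick illustration of querying per-element sizes and names (illustrative
/// sketch; the values come from `size_bytes` and `name` below):
///
/// ```rust
/// use ipfrs_core::tensor::TensorDtype;
///
/// assert_eq!(TensorDtype::F32.size_bytes(), 4);
/// assert_eq!(TensorDtype::F32.name(), "float32");
/// assert_eq!(TensorDtype::I64.size_bytes(), 8);
/// ```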
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorDtype {
    /// 32-bit floating point (IEEE 754)
    F32,
    /// 16-bit floating point (IEEE 754-2008)
    F16,
    /// 64-bit floating point (IEEE 754)
    F64,
    /// 8-bit signed integer
    I8,
    /// 32-bit signed integer
    I32,
    /// 64-bit signed integer
    I64,
    /// 8-bit unsigned integer
    U8,
    /// 32-bit unsigned integer
    U32,
    /// Boolean (1 byte)
    Bool,
}

impl TensorDtype {
    /// Get the size in bytes of this data type
    #[inline]
    pub fn size_bytes(&self) -> usize {
        match self {
            TensorDtype::F32 => 4,
            TensorDtype::F16 => 2,
            TensorDtype::F64 => 8,
            TensorDtype::I8 => 1,
            TensorDtype::I32 => 4,
            TensorDtype::I64 => 8,
            TensorDtype::U8 => 1,
            TensorDtype::U32 => 4,
            TensorDtype::Bool => 1,
        }
    }

    /// Get a human-readable name for this data type
    #[inline]
    pub fn name(&self) -> &'static str {
        match self {
            TensorDtype::F32 => "float32",
            TensorDtype::F16 => "float16",
            TensorDtype::F64 => "float64",
            TensorDtype::I8 => "int8",
            TensorDtype::I32 => "int32",
            TensorDtype::I64 => "int64",
            TensorDtype::U8 => "uint8",
            TensorDtype::U32 => "uint32",
            TensorDtype::Bool => "bool",
        }
    }
}

/// Tensor shape (dimensions)
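///
/// Illustrative usage of the shape helpers defined below (a small sketch):
///
/// ```rust
/// use ipfrs_core::tensor::TensorShape;
///
/// let shape = TensorShape::new(vec![2, 3]);
/// assert_eq!(shape.rank(), 2);
/// assert_eq!(shape.element_count(), 6);
/// assert!(shape.is_matrix());
///
/// // A scalar has no dimensions but still holds one element.
/// assert_eq!(TensorShape::scalar().element_count(), 1);
/// ```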
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TensorShape {
    dims: Vec<usize>,
}

impl TensorShape {
    /// Create a new tensor shape
    pub fn new(dims: Vec<usize>) -> Self {
        Self { dims }
    }

    /// Create a scalar (0-dimensional tensor)
    pub fn scalar() -> Self {
        Self { dims: vec![] }
    }

    /// Get the dimensions
    #[inline]
    pub fn dims(&self) -> &[usize] {
        &self.dims
    }

    /// Get the rank (number of dimensions)
    #[inline]
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Calculate the total number of elements
    #[inline]
    pub fn element_count(&self) -> usize {
        if self.dims.is_empty() {
            1
        } else {
            self.dims.iter().product()
        }
    }

    /// Check if this is a scalar
    #[inline]
    pub fn is_scalar(&self) -> bool {
        self.dims.is_empty()
    }

    /// Check if this is a vector (1D)
    #[inline]
    pub fn is_vector(&self) -> bool {
        self.dims.len() == 1
    }

    /// Check if this is a matrix (2D)
    #[inline]
    pub fn is_matrix(&self) -> bool {
        self.dims.len() == 2
    }
}

/// Tensor metadata
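///
/// Illustrative builder usage (a sketch based on the constructors below,
/// assuming `TensorMetadata` is exported from `ipfrs_core::tensor` alongside
/// the other types):
///
/// ```rust
/// use ipfrs_core::tensor::{TensorDtype, TensorMetadata, TensorShape};
///
/// let meta = TensorMetadata::new(TensorShape::new(vec![2, 3]), TensorDtype::F32)
///     .with_name("layer1.weight".to_string())
///     .with_metadata("requires_grad".to_string(), "true".to_string());
///
/// // 2 * 3 elements * 4 bytes per f32 = 24 bytes
/// assert_eq!(meta.expected_size(), 24);
/// ```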
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
    /// Tensor shape
    pub shape: TensorShape,
    /// Data type
    pub dtype: TensorDtype,
    /// Optional tensor name
    pub name: Option<String>,
    /// Optional additional metadata (e.g., gradient info, requires_grad)
    pub metadata: std::collections::BTreeMap<String, String>,
}

impl TensorMetadata {
    /// Create new tensor metadata
    pub fn new(shape: TensorShape, dtype: TensorDtype) -> Self {
        Self {
            shape,
            dtype,
            name: None,
            metadata: std::collections::BTreeMap::new(),
        }
    }

    /// Set the tensor name
    pub fn with_name(mut self, name: String) -> Self {
        self.name = Some(name);
        self
    }

    /// Add custom metadata
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }

    /// Calculate expected data size in bytes
    pub fn expected_size(&self) -> usize {
        self.shape.element_count() * self.dtype.size_bytes()
    }
}

/// A content-addressed tensor block
///
/// Combines a regular [`Block`] with tensor-specific metadata like shape and dtype.
/// This allows storing neural network weights, activations, and gradients in a
/// content-addressed manner.
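///
/// Illustrative round trip through the typed helpers defined on this type
/// (a sketch; see the constructors below):
///
/// ```rust
/// use ipfrs_core::tensor::{TensorBlock, TensorShape};
///
/// let weights = vec![0.1f32, 0.2, 0.3, 0.4];
/// let tensor = TensorBlock::from_f32_slice(&weights, TensorShape::new(vec![2, 2])).unwrap();
///
/// // The CID is derived from the raw tensor bytes by the underlying block.
/// let _cid = tensor.cid();
/// assert_eq!(tensor.to_f32_vec().unwrap(), weights);
/// ```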
#[derive(Debug, Clone)]
pub struct TensorBlock {
    /// The underlying data block
    block: Block,
    /// Tensor metadata
    metadata: TensorMetadata,
}

impl TensorBlock {
    /// Create a new tensor block
    ///
    /// # Arguments
    ///
    /// * `data` - Raw tensor data (little-endian element encoding, matching the
    ///   `from_*_slice` / `to_*_vec` helpers below)
    /// * `shape` - Tensor shape
    /// * `dtype` - Data type
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Data size doesn't match shape * dtype size
    /// - Block creation fails
    ///
    /// # Example
    ///
    /// ```rust
    /// use ipfrs_core::tensor::{TensorBlock, TensorDtype, TensorShape};
    /// use bytes::Bytes;
    ///
    /// let shape = TensorShape::new(vec![2, 2]);
    /// let data = Bytes::from(vec![1.0f32, 2.0, 3.0, 4.0]
    ///     .iter()
    ///     .flat_map(|f| f.to_le_bytes())
    ///     .collect::<Vec<u8>>());
    ///
    /// let tensor = TensorBlock::new(data, shape, TensorDtype::F32).unwrap();
    /// assert_eq!(tensor.element_count(), 4);
    /// ```
    pub fn new(data: Bytes, shape: TensorShape, dtype: TensorDtype) -> Result<Self> {
        let metadata = TensorMetadata::new(shape, dtype);

        // Validate data size
        let expected_size = metadata.expected_size();
        if data.len() != expected_size {
            return Err(Error::InvalidData(format!(
                "Tensor data size mismatch: expected {} bytes, got {}",
                expected_size,
                data.len()
            )));
        }

        // Create underlying block
        let block = Block::new(data)?;

        Ok(Self { block, metadata })
    }

    /// Create a tensor block with metadata
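    ///
    /// Illustrative usage with a named tensor (sketch only, assuming
    /// `TensorMetadata` is exported from `ipfrs_core::tensor`):
    ///
    /// ```rust
    /// use ipfrs_core::tensor::{TensorBlock, TensorDtype, TensorMetadata, TensorShape};
    /// use bytes::Bytes;
    ///
    /// let meta = TensorMetadata::new(TensorShape::new(vec![2]), TensorDtype::U8)
    ///     .with_name("mask".to_string());
    /// let tensor = TensorBlock::with_metadata(Bytes::from(vec![0u8, 1]), meta).unwrap();
    /// assert_eq!(tensor.metadata().name.as_deref(), Some("mask"));
    /// ```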
    pub fn with_metadata(data: Bytes, metadata: TensorMetadata) -> Result<Self> {
        let expected_size = metadata.expected_size();
        if data.len() != expected_size {
            return Err(Error::InvalidData(format!(
                "Tensor data size mismatch: expected {} bytes, got {}",
                expected_size,
                data.len()
            )));
        }

        let block = Block::new(data)?;
        Ok(Self { block, metadata })
    }

    /// Get the underlying block
    pub fn block(&self) -> &Block {
        &self.block
    }

    /// Get tensor metadata
    pub fn metadata(&self) -> &TensorMetadata {
        &self.metadata
    }

    /// Get the tensor shape
    pub fn shape(&self) -> &TensorShape {
        &self.metadata.shape
    }

    /// Get the data type
    pub fn dtype(&self) -> TensorDtype {
        self.metadata.dtype
    }

    /// Get the number of elements
    pub fn element_count(&self) -> usize {
        self.metadata.shape.element_count()
    }

    /// Get the CID of this tensor
    pub fn cid(&self) -> &crate::cid::Cid {
        self.block.cid()
    }

    /// Get the raw tensor data
    pub fn data(&self) -> &Bytes {
        self.block.data()
    }

    /// Consume and return the underlying block and metadata
    pub fn into_parts(self) -> (Block, TensorMetadata) {
        (self.block, self.metadata)
    }

    /// Verify the tensor block integrity
    pub fn verify(&self) -> Result<bool> {
        self.block.verify()
    }

    /// Reshape the tensor to a new shape (must have same element count)
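    ///
    /// Illustrative reshape; the underlying block (and hence its bytes) is
    /// reused unchanged (sketch):
    ///
    /// ```rust
    /// use ipfrs_core::tensor::{TensorBlock, TensorShape};
    ///
    /// let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    /// let tensor = TensorBlock::from_f32_slice(&data, TensorShape::new(vec![2, 3])).unwrap();
    /// let reshaped = tensor.reshape(TensorShape::new(vec![3, 2])).unwrap();
    ///
    /// assert_eq!(reshaped.shape().dims(), &[3, 2]);
    /// assert_eq!(reshaped.to_f32_vec().unwrap(), data);
    /// ```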
    pub fn reshape(&self, new_shape: TensorShape) -> Result<Self> {
        if new_shape.element_count() != self.element_count() {
            return Err(Error::InvalidInput(format!(
                "Cannot reshape tensor with {} elements to shape with {} elements",
                self.element_count(),
                new_shape.element_count()
            )));
        }

        let new_metadata = TensorMetadata {
            shape: new_shape,
            dtype: self.metadata.dtype,
            name: self.metadata.name.clone(),
            metadata: self.metadata.metadata.clone(),
        };

        Ok(Self {
            block: self.block.clone(),
            metadata: new_metadata,
        })
    }

    /// Get the size in bytes
    pub fn size_bytes(&self) -> usize {
        self.data().len()
    }

    /// Check if this is a scalar tensor (0-dimensional)
    pub fn is_scalar(&self) -> bool {
        self.shape().is_scalar()
    }

    /// Check if this is a vector (1-dimensional)
    pub fn is_vector(&self) -> bool {
        self.shape().is_vector()
    }

    /// Check if this is a matrix (2-dimensional)
    pub fn is_matrix(&self) -> bool {
        self.shape().is_matrix()
    }
}

/// Utility functions for creating tensors from typed data
impl TensorBlock {
    /// Create a tensor from a slice of f32 values
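    ///
    /// Elements are encoded little-endian; an illustrative round trip (sketch):
    ///
    /// ```rust
    /// use ipfrs_core::tensor::{TensorBlock, TensorDtype, TensorShape};
    ///
    /// let values = vec![1.0f32, 2.0, 3.0];
    /// let tensor = TensorBlock::from_f32_slice(&values, TensorShape::new(vec![3])).unwrap();
    /// assert_eq!(tensor.dtype(), TensorDtype::F32);
    /// assert_eq!(tensor.to_f32_vec().unwrap(), values);
    /// ```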
    pub fn from_f32_slice(data: &[f32], shape: TensorShape) -> Result<Self> {
        if data.len() != shape.element_count() {
            return Err(Error::InvalidInput(format!(
                "Data length {} doesn't match shape element count {}",
                data.len(),
                shape.element_count()
            )));
        }

        let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
        Self::new(Bytes::from(bytes), shape, TensorDtype::F32)
    }

    /// Create a tensor from a slice of f64 values
    pub fn from_f64_slice(data: &[f64], shape: TensorShape) -> Result<Self> {
        if data.len() != shape.element_count() {
            return Err(Error::InvalidInput(format!(
                "Data length {} doesn't match shape element count {}",
                data.len(),
                shape.element_count()
            )));
        }

        let bytes: Vec<u8> = data.iter().flat_map(|&f| f.to_le_bytes()).collect();
        Self::new(Bytes::from(bytes), shape, TensorDtype::F64)
    }

    /// Create a tensor from a slice of i32 values
    pub fn from_i32_slice(data: &[i32], shape: TensorShape) -> Result<Self> {
        if data.len() != shape.element_count() {
            return Err(Error::InvalidInput(format!(
                "Data length {} doesn't match shape element count {}",
                data.len(),
                shape.element_count()
            )));
        }

        let bytes: Vec<u8> = data.iter().flat_map(|&i| i.to_le_bytes()).collect();
        Self::new(Bytes::from(bytes), shape, TensorDtype::I32)
    }

    /// Create a tensor from a slice of i64 values
    pub fn from_i64_slice(data: &[i64], shape: TensorShape) -> Result<Self> {
        if data.len() != shape.element_count() {
            return Err(Error::InvalidInput(format!(
                "Data length {} doesn't match shape element count {}",
                data.len(),
                shape.element_count()
            )));
        }

        let bytes: Vec<u8> = data.iter().flat_map(|&i| i.to_le_bytes()).collect();
        Self::new(Bytes::from(bytes), shape, TensorDtype::I64)
    }

    /// Create a tensor from a slice of u8 values
    pub fn from_u8_slice(data: &[u8], shape: TensorShape) -> Result<Self> {
        if data.len() != shape.element_count() {
            return Err(Error::InvalidInput(format!(
                "Data length {} doesn't match shape element count {}",
                data.len(),
                shape.element_count()
            )));
        }

        Self::new(Bytes::copy_from_slice(data), shape, TensorDtype::U8)
    }

    /// Convert tensor data to a Vec of f32 values (if dtype is F32)
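    ///
    /// Conversion is dtype-checked; an illustrative mismatch (sketch):
    ///
    /// ```rust
    /// use ipfrs_core::tensor::{TensorBlock, TensorShape};
    ///
    /// let ints = TensorBlock::from_i32_slice(&[1, 2, 3], TensorShape::new(vec![3])).unwrap();
    /// assert!(ints.to_f32_vec().is_err());
    /// ```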
    pub fn to_f32_vec(&self) -> Result<Vec<f32>> {
        if self.dtype() != TensorDtype::F32 {
            return Err(Error::InvalidInput(format!(
                "Cannot convert {} tensor to f32",
                self.dtype().name()
            )));
        }

        let data = self.data();
        let mut result = Vec::with_capacity(self.element_count());

        for chunk in data.chunks_exact(4) {
            let bytes: [u8; 4] = chunk.try_into().unwrap();
            result.push(f32::from_le_bytes(bytes));
        }

        Ok(result)
    }

    /// Convert tensor data to a Vec of f64 values (if dtype is F64)
    pub fn to_f64_vec(&self) -> Result<Vec<f64>> {
        if self.dtype() != TensorDtype::F64 {
            return Err(Error::InvalidInput(format!(
                "Cannot convert {} tensor to f64",
                self.dtype().name()
            )));
        }

        let data = self.data();
        let mut result = Vec::with_capacity(self.element_count());

        for chunk in data.chunks_exact(8) {
            let bytes: [u8; 8] = chunk.try_into().unwrap();
            result.push(f64::from_le_bytes(bytes));
        }

        Ok(result)
    }

    /// Convert tensor data to a Vec of i32 values (if dtype is I32)
    pub fn to_i32_vec(&self) -> Result<Vec<i32>> {
        if self.dtype() != TensorDtype::I32 {
            return Err(Error::InvalidInput(format!(
                "Cannot convert {} tensor to i32",
                self.dtype().name()
            )));
        }

        let data = self.data();
        let mut result = Vec::with_capacity(self.element_count());

        for chunk in data.chunks_exact(4) {
            let bytes: [u8; 4] = chunk.try_into().unwrap();
            result.push(i32::from_le_bytes(bytes));
        }

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_dtype_sizes() {
        assert_eq!(TensorDtype::F32.size_bytes(), 4);
        assert_eq!(TensorDtype::F16.size_bytes(), 2);
        assert_eq!(TensorDtype::I8.size_bytes(), 1);
        assert_eq!(TensorDtype::I32.size_bytes(), 4);
    }

    #[test]
    fn test_tensor_shape() {
        let shape = TensorShape::new(vec![2, 3, 4]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.element_count(), 24);
        assert!(!shape.is_scalar());
        assert!(!shape.is_vector());
        assert!(!shape.is_matrix());

        let scalar = TensorShape::scalar();
        assert!(scalar.is_scalar());
        assert_eq!(scalar.element_count(), 1);
    }

    #[test]
    fn test_tensor_block_creation() {
        let shape = TensorShape::new(vec![2, 2]);
        let data: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();

        let tensor = TensorBlock::new(Bytes::from(data), shape, TensorDtype::F32).unwrap();

        assert_eq!(tensor.element_count(), 4);
        assert_eq!(tensor.dtype(), TensorDtype::F32);
        assert_eq!(tensor.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_size_validation() {
        let shape = TensorShape::new(vec![2, 2]);
        // Too small data (only 3 floats instead of 4)
        let data: Vec<u8> = [1.0f32, 2.0, 3.0]
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();

        let result = TensorBlock::new(Bytes::from(data), shape, TensorDtype::F32);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_metadata() {
        let shape = TensorShape::new(vec![10, 20]);
        let metadata = TensorMetadata::new(shape, TensorDtype::F32)
            .with_name("layer1.weight".to_string())
            .with_metadata("requires_grad".to_string(), "true".to_string());

        assert_eq!(metadata.name, Some("layer1.weight".to_string()));
        assert_eq!(metadata.expected_size(), 10 * 20 * 4); // 800 bytes
    }

    #[test]
    fn test_tensor_from_f32_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = TensorShape::new(vec![2, 3]);

        let tensor = TensorBlock::from_f32_slice(&data, shape).unwrap();
        assert_eq!(tensor.element_count(), 6);
        assert_eq!(tensor.dtype(), TensorDtype::F32);

        // Roundtrip test
        let recovered = tensor.to_f32_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_from_i32_slice() {
        let data = vec![10i32, 20, 30, 40];
        let shape = TensorShape::new(vec![2, 2]);

        let tensor = TensorBlock::from_i32_slice(&data, shape).unwrap();
        assert_eq!(tensor.element_count(), 4);

        let recovered = tensor.to_i32_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_reshape() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = TensorShape::new(vec![2, 3]);
        let tensor = TensorBlock::from_f32_slice(&data, shape).unwrap();

        // Reshape 2x3 to 3x2
        let reshaped = tensor.reshape(TensorShape::new(vec![3, 2])).unwrap();
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.element_count(), 6);

        // Verify data is preserved
        let recovered = reshaped.to_f32_vec().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_reshape_invalid() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let shape = TensorShape::new(vec![2, 2]);
        let tensor = TensorBlock::from_f32_slice(&data, shape).unwrap();

        // Try to reshape to incompatible shape
        let result = tensor.reshape(TensorShape::new(vec![3, 2])); // 6 elements != 4
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_type_checks() {
        let data = vec![1.0f32, 2.0];
        let tensor = TensorBlock::from_f32_slice(&data, TensorShape::new(vec![2])).unwrap();
        assert!(tensor.is_vector());
        assert!(!tensor.is_matrix());
        assert!(!tensor.is_scalar());

        let matrix = TensorBlock::from_f32_slice(&data, TensorShape::new(vec![1, 2])).unwrap();
        assert!(matrix.is_matrix());
    }

    #[test]
    fn test_tensor_to_vec_wrong_dtype() {
        let data = vec![1i32, 2, 3];
        let tensor = TensorBlock::from_i32_slice(&data, TensorShape::new(vec![3])).unwrap();

        // Try to convert i32 tensor to f32
        let result = tensor.to_f32_vec();
        assert!(result.is_err());
    }
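
    // Additional coverage sketch: u8 and f64 round trips plus a scalar tensor.
    // Only uses APIs defined in this module and mirrors the tests above.
    #[test]
    fn test_tensor_u8_f64_and_scalar() {
        let data = vec![1u8, 2, 3, 4];
        let tensor = TensorBlock::from_u8_slice(&data, TensorShape::new(vec![4])).unwrap();
        assert_eq!(tensor.dtype(), TensorDtype::U8);
        assert_eq!(&tensor.data()[..], &data[..]);

        let scalar = TensorBlock::from_f64_slice(&[2.5], TensorShape::scalar()).unwrap();
        assert!(scalar.is_scalar());
        assert_eq!(scalar.to_f64_vec().unwrap(), vec![2.5]);
    }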
}