Skip to main content

ai3_lib/
tensor.rs

1//! Tensor types: shape, data (F32/U8), and operations (hash, clamp for ESP).
2
3use pot_o_core::{TribeError, TribeResult};
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6
7/// Shape of a tensor (dimension sizes).
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct TensorShape {
10    /// Dimension sizes (e.g. [8, 8] for 8x8 matrix).
11    pub dims: Vec<usize>,
12}
13
14impl TensorShape {
15    /// Creates a shape from dimension sizes.
16    pub fn new(dims: Vec<usize>) -> Self {
17        Self { dims }
18    }
19
20    /// Total number of elements (product of dims).
21    pub fn total_elements(&self) -> usize {
22        self.dims.iter().product()
23    }
24
25    /// True if 2D (matrix).
26    pub fn is_matrix(&self) -> bool {
27        self.dims.len() == 2
28    }
29}
30
31/// Tensor element data (float or u8).
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub enum TensorData {
34    /// 32-bit float elements.
35    F32(Vec<f32>),
36    /// 8-bit unsigned elements.
37    U8(Vec<u8>),
38}
39
40impl TensorData {
41    pub fn len(&self) -> usize {
42        match self {
43            Self::F32(v) => v.len(),
44            Self::U8(v) => v.len(),
45        }
46    }
47
48    pub fn is_empty(&self) -> bool {
49        self.len() == 0
50    }
51
52    pub fn as_f32(&self) -> Vec<f32> {
53        match self {
54            Self::F32(v) => v.clone(),
55            Self::U8(v) => v.iter().map(|&b| b as f32 / 255.0).collect(),
56        }
57    }
58
59    pub fn to_bytes(&self) -> Vec<u8> {
60        match self {
61            Self::F32(v) => v.iter().flat_map(|f| f.to_le_bytes()).collect(),
62            Self::U8(v) => v.clone(),
63        }
64    }
65}
66
67/// A typed tensor with shape and data (used in challenges and mining results).
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct Tensor {
70    /// Shape (dimension sizes).
71    pub shape: TensorShape,
72    /// Element data.
73    pub data: TensorData,
74}
75
76impl Tensor {
77    /// Builds a tensor; errors if data length does not match shape.
78    pub fn new(shape: TensorShape, data: TensorData) -> TribeResult<Self> {
79        let expected = shape.total_elements();
80        let actual = data.len();
81        if actual != expected {
82            return Err(TribeError::TensorError(format!(
83                "Shape expects {expected} elements but data has {actual}"
84            )));
85        }
86        Ok(Self { shape, data })
87    }
88
89    /// Creates a tensor of zeros with the given shape.
90    pub fn zeros(shape: TensorShape) -> Self {
91        let n = shape.total_elements();
92        Self {
93            shape,
94            data: TensorData::F32(vec![0.0; n]),
95        }
96    }
97
98    /// Builds a 1D tensor from raw hash bytes (normalized to 0..1).
99    pub fn from_slot_hash(hash_bytes: &[u8]) -> Self {
100        let floats: Vec<f32> = hash_bytes.iter().map(|&b| b as f32 / 255.0).collect();
101        let n = floats.len();
102        Self {
103            shape: TensorShape::new(vec![n]),
104            data: TensorData::F32(floats),
105        }
106    }
107
108    pub fn calculate_hash(&self) -> String {
109        let mut hasher = Sha256::new();
110        hasher.update(self.data.to_bytes());
111        for d in &self.shape.dims {
112            hasher.update(d.to_le_bytes());
113        }
114        hex::encode(hasher.finalize())
115    }
116
117    /// Clamp tensor dimensions to a maximum size (for ESP compatibility).
118    pub fn clamp_dimensions(&self, max_dim: usize) -> Self {
119        let floats = self.data.as_f32();
120        let clamped_len = floats.len().min(max_dim * max_dim);
121        let clamped_data: Vec<f32> = floats.into_iter().take(clamped_len).collect();
122        let new_shape = self.shape.dims.iter().map(|&d| d.min(max_dim)).collect();
123        Self {
124            shape: TensorShape::new(new_shape),
125            data: TensorData::F32(clamped_data),
126        }
127    }
128
129    /// Serialized byte length of the tensor data.
130    pub fn byte_size(&self) -> usize {
131        self.data.to_bytes().len()
132    }
133}