Skip to main content

ai3_lib/
tensor.rs

1use pot_o_core::{TribeError, TribeResult};
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct TensorShape {
7    pub dims: Vec<usize>,
8}
9
10impl TensorShape {
11    pub fn new(dims: Vec<usize>) -> Self {
12        Self { dims }
13    }
14
15    pub fn total_elements(&self) -> usize {
16        self.dims.iter().product()
17    }
18
19    pub fn is_matrix(&self) -> bool {
20        self.dims.len() == 2
21    }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub enum TensorData {
26    F32(Vec<f32>),
27    U8(Vec<u8>),
28}
29
30impl TensorData {
31    pub fn len(&self) -> usize {
32        match self {
33            Self::F32(v) => v.len(),
34            Self::U8(v) => v.len(),
35        }
36    }
37
38    pub fn is_empty(&self) -> bool {
39        self.len() == 0
40    }
41
42    pub fn as_f32(&self) -> Vec<f32> {
43        match self {
44            Self::F32(v) => v.clone(),
45            Self::U8(v) => v.iter().map(|&b| b as f32 / 255.0).collect(),
46        }
47    }
48
49    pub fn to_bytes(&self) -> Vec<u8> {
50        match self {
51            Self::F32(v) => v.iter().flat_map(|f| f.to_le_bytes()).collect(),
52            Self::U8(v) => v.clone(),
53        }
54    }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Tensor {
59    pub shape: TensorShape,
60    pub data: TensorData,
61}
62
63impl Tensor {
64    pub fn new(shape: TensorShape, data: TensorData) -> TribeResult<Self> {
65        let expected = shape.total_elements();
66        let actual = data.len();
67        if actual != expected {
68            return Err(TribeError::TensorError(format!(
69                "Shape expects {expected} elements but data has {actual}"
70            )));
71        }
72        Ok(Self { shape, data })
73    }
74
75    pub fn zeros(shape: TensorShape) -> Self {
76        let n = shape.total_elements();
77        Self {
78            shape,
79            data: TensorData::F32(vec![0.0; n]),
80        }
81    }
82
83    pub fn from_slot_hash(hash_bytes: &[u8]) -> Self {
84        let floats: Vec<f32> = hash_bytes.iter().map(|&b| b as f32 / 255.0).collect();
85        let n = floats.len();
86        Self {
87            shape: TensorShape::new(vec![n]),
88            data: TensorData::F32(floats),
89        }
90    }
91
92    pub fn calculate_hash(&self) -> String {
93        let mut hasher = Sha256::new();
94        hasher.update(self.data.to_bytes());
95        for d in &self.shape.dims {
96            hasher.update(d.to_le_bytes());
97        }
98        hex::encode(hasher.finalize())
99    }
100
101    /// Clamp tensor dimensions to a maximum size (for ESP compatibility).
102    pub fn clamp_dimensions(&self, max_dim: usize) -> Self {
103        let floats = self.data.as_f32();
104        let clamped_len = floats.len().min(max_dim * max_dim);
105        let clamped_data: Vec<f32> = floats.into_iter().take(clamped_len).collect();
106        let new_shape = self.shape.dims.iter().map(|&d| d.min(max_dim)).collect();
107        Self {
108            shape: TensorShape::new(new_shape),
109            data: TensorData::F32(clamped_data),
110        }
111    }
112
113    pub fn byte_size(&self) -> usize {
114        self.data.to_bytes().len()
115    }
116}