1use pot_o_core::{TribeError, TribeResult};
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct TensorShape {
7 pub dims: Vec<usize>,
8}
9
10impl TensorShape {
11 pub fn new(dims: Vec<usize>) -> Self {
12 Self { dims }
13 }
14
15 pub fn total_elements(&self) -> usize {
16 self.dims.iter().product()
17 }
18
19 pub fn is_matrix(&self) -> bool {
20 self.dims.len() == 2
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub enum TensorData {
26 F32(Vec<f32>),
27 U8(Vec<u8>),
28}
29
30impl TensorData {
31 pub fn len(&self) -> usize {
32 match self {
33 Self::F32(v) => v.len(),
34 Self::U8(v) => v.len(),
35 }
36 }
37
38 pub fn is_empty(&self) -> bool {
39 self.len() == 0
40 }
41
42 pub fn as_f32(&self) -> Vec<f32> {
43 match self {
44 Self::F32(v) => v.clone(),
45 Self::U8(v) => v.iter().map(|&b| b as f32 / 255.0).collect(),
46 }
47 }
48
49 pub fn to_bytes(&self) -> Vec<u8> {
50 match self {
51 Self::F32(v) => v.iter().flat_map(|f| f.to_le_bytes()).collect(),
52 Self::U8(v) => v.clone(),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Tensor {
59 pub shape: TensorShape,
60 pub data: TensorData,
61}
62
63impl Tensor {
64 pub fn new(shape: TensorShape, data: TensorData) -> TribeResult<Self> {
65 let expected = shape.total_elements();
66 let actual = data.len();
67 if actual != expected {
68 return Err(TribeError::TensorError(format!(
69 "Shape expects {expected} elements but data has {actual}"
70 )));
71 }
72 Ok(Self { shape, data })
73 }
74
75 pub fn zeros(shape: TensorShape) -> Self {
76 let n = shape.total_elements();
77 Self {
78 shape,
79 data: TensorData::F32(vec![0.0; n]),
80 }
81 }
82
83 pub fn from_slot_hash(hash_bytes: &[u8]) -> Self {
84 let floats: Vec<f32> = hash_bytes.iter().map(|&b| b as f32 / 255.0).collect();
85 let n = floats.len();
86 Self {
87 shape: TensorShape::new(vec![n]),
88 data: TensorData::F32(floats),
89 }
90 }
91
92 pub fn calculate_hash(&self) -> String {
93 let mut hasher = Sha256::new();
94 hasher.update(self.data.to_bytes());
95 for d in &self.shape.dims {
96 hasher.update(d.to_le_bytes());
97 }
98 hex::encode(hasher.finalize())
99 }
100
101 pub fn clamp_dimensions(&self, max_dim: usize) -> Self {
103 let floats = self.data.as_f32();
104 let clamped_len = floats.len().min(max_dim * max_dim);
105 let clamped_data: Vec<f32> = floats.into_iter().take(clamped_len).collect();
106 let new_shape = self.shape.dims.iter().map(|&d| d.min(max_dim)).collect();
107 Self {
108 shape: TensorShape::new(new_shape),
109 data: TensorData::F32(clamped_data),
110 }
111 }
112
113 pub fn byte_size(&self) -> usize {
114 self.data.to_bytes().len()
115 }
116}