1use pot_o_core::{TribeError, TribeResult};
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6
/// Shape of a tensor, stored as one extent per axis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorShape {
    /// Extent of each axis, in order. The element count of the shape is the
    /// product of these values (see [`TensorShape::total_elements`]).
    pub dims: Vec<usize>,
}
13
14impl TensorShape {
15 pub fn new(dims: Vec<usize>) -> Self {
17 Self { dims }
18 }
19
20 pub fn total_elements(&self) -> usize {
22 self.dims.iter().product()
23 }
24
25 pub fn is_matrix(&self) -> bool {
27 self.dims.len() == 2
28 }
29}
30
/// Flat element storage for a tensor, in one of two element types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TensorData {
    /// 32-bit floating-point elements.
    F32(Vec<f32>),
    /// Unsigned 8-bit elements; [`TensorData::as_f32`] converts each one to
    /// `byte / 255.0`.
    U8(Vec<u8>),
}
39
40impl TensorData {
41 pub fn len(&self) -> usize {
42 match self {
43 Self::F32(v) => v.len(),
44 Self::U8(v) => v.len(),
45 }
46 }
47
48 pub fn is_empty(&self) -> bool {
49 self.len() == 0
50 }
51
52 pub fn as_f32(&self) -> Vec<f32> {
53 match self {
54 Self::F32(v) => v.clone(),
55 Self::U8(v) => v.iter().map(|&b| b as f32 / 255.0).collect(),
56 }
57 }
58
59 pub fn to_bytes(&self) -> Vec<u8> {
60 match self {
61 Self::F32(v) => v.iter().flat_map(|f| f.to_le_bytes()).collect(),
62 Self::U8(v) => v.clone(),
63 }
64 }
65}
66
/// A tensor: a shape paired with flat element storage.
///
/// [`Tensor::new`] checks that the element count of `data` matches `shape`.
/// Both fields are public, so a value assembled directly is not checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tensor {
    /// Per-axis extents of the tensor.
    pub shape: TensorShape,
    /// Flat element storage.
    pub data: TensorData,
}
75
76impl Tensor {
77 pub fn new(shape: TensorShape, data: TensorData) -> TribeResult<Self> {
79 let expected = shape.total_elements();
80 let actual = data.len();
81 if actual != expected {
82 return Err(TribeError::TensorError(format!(
83 "Shape expects {expected} elements but data has {actual}"
84 )));
85 }
86 Ok(Self { shape, data })
87 }
88
89 pub fn zeros(shape: TensorShape) -> Self {
91 let n = shape.total_elements();
92 Self {
93 shape,
94 data: TensorData::F32(vec![0.0; n]),
95 }
96 }
97
98 pub fn from_slot_hash(hash_bytes: &[u8]) -> Self {
100 let floats: Vec<f32> = hash_bytes.iter().map(|&b| b as f32 / 255.0).collect();
101 let n = floats.len();
102 Self {
103 shape: TensorShape::new(vec![n]),
104 data: TensorData::F32(floats),
105 }
106 }
107
108 pub fn calculate_hash(&self) -> String {
109 let mut hasher = Sha256::new();
110 hasher.update(self.data.to_bytes());
111 for d in &self.shape.dims {
112 hasher.update(d.to_le_bytes());
113 }
114 hex::encode(hasher.finalize())
115 }
116
117 pub fn clamp_dimensions(&self, max_dim: usize) -> Self {
119 let floats = self.data.as_f32();
120 let clamped_len = floats.len().min(max_dim * max_dim);
121 let clamped_data: Vec<f32> = floats.into_iter().take(clamped_len).collect();
122 let new_shape = self.shape.dims.iter().map(|&d| d.min(max_dim)).collect();
123 Self {
124 shape: TensorShape::new(new_shape),
125 data: TensorData::F32(clamped_data),
126 }
127 }
128
129 pub fn byte_size(&self) -> usize {
131 self.data.to_bytes().len()
132 }
133}