1use crate::tensor::{Tensor, TensorData, TensorShape};
2use pot_o_core::{TribeError, TribeResult};
3
4pub trait TensorOp: Send + Sync {
6 fn name(&self) -> &str;
7 fn execute(&self, input: &Tensor) -> TribeResult<Tensor>;
8}
9
10pub fn parse_operation(op_type: &str) -> TribeResult<Box<dyn TensorOp>> {
11 match op_type {
12 "matrix_multiply" => Ok(Box::new(MatrixMultiply)),
13 "convolution" => Ok(Box::new(Convolution::default())),
14 "relu" => Ok(Box::new(ActivationFunction::ReLU)),
15 "sigmoid" => Ok(Box::new(ActivationFunction::Sigmoid)),
16 "tanh" => Ok(Box::new(ActivationFunction::Tanh)),
17 "dot_product" => Ok(Box::new(VectorOp::DotProduct)),
18 "normalize" => Ok(Box::new(VectorOp::Normalize)),
19 _ => Err(TribeError::TensorError(format!(
20 "Unknown operation: {op_type}"
21 ))),
22 }
23}
24
25pub struct MatrixMultiply;
27
28impl TensorOp for MatrixMultiply {
29 fn name(&self) -> &str {
30 "matrix_multiply"
31 }
32
33 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
34 let data = input.data.as_f32();
35 let n = (data.len() as f64).sqrt() as usize;
36 if n == 0 {
37 return Ok(Tensor::zeros(TensorShape::new(vec![0])));
38 }
39 let size = n * n;
40 let a: Vec<f32> = data.iter().copied().take(size).collect();
41 let mut result = vec![0.0f32; size];
42 for i in 0..n {
43 for j in 0..n {
44 let mut sum = 0.0f32;
45 for k in 0..n {
46 let ai = a.get(i * n + k).copied().unwrap_or(0.0);
47 let bj = a.get(k * n + j).copied().unwrap_or(0.0);
48 sum += ai * bj;
49 }
50 result[i * n + j] = sum;
51 }
52 }
53 Tensor::new(TensorShape::new(vec![n, n]), TensorData::F32(result))
54 }
55}
56
/// One-dimensional convolution parameterised by its kernel weights.
pub struct Convolution {
    /// Weights multiplied element-wise against each window of the input.
    pub kernel: Vec<f32>,
}

impl Default for Convolution {
    /// Builds a convolution with the three-weight kernel `[0.25, 0.5, 0.25]`.
    fn default() -> Self {
        let kernel: Vec<f32> = vec![0.25, 0.5, 0.25];
        Convolution { kernel }
    }
}
69
70impl TensorOp for Convolution {
71 fn name(&self) -> &str {
72 "convolution"
73 }
74
75 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
76 let data = input.data.as_f32();
77 let klen = self.kernel.len();
78 if data.len() < klen {
79 return Ok(input.clone());
80 }
81 let out_len = data.len() - klen + 1;
82 let mut result = Vec::with_capacity(out_len);
83 for i in 0..out_len {
84 let mut sum = 0.0f32;
85 for (j, &kv) in self.kernel.iter().enumerate() {
86 sum += data[i + j] * kv;
87 }
88 result.push(sum);
89 }
90 Tensor::new(TensorShape::new(vec![out_len]), TensorData::F32(result))
91 }
92}
93
94#[derive(Debug, Clone)]
95pub enum ActivationFunction {
96 ReLU,
97 Sigmoid,
98 Tanh,
99}
100
101impl TensorOp for ActivationFunction {
102 fn name(&self) -> &str {
103 match self {
104 Self::ReLU => "relu",
105 Self::Sigmoid => "sigmoid",
106 Self::Tanh => "tanh",
107 }
108 }
109
110 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
111 let data = input.data.as_f32();
112 let result: Vec<f32> = match self {
113 Self::ReLU => data.iter().map(|&x| x.max(0.0)).collect(),
114 Self::Sigmoid => data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
115 Self::Tanh => data.iter().map(|&x| x.tanh()).collect(),
116 };
117 Tensor::new(input.shape.clone(), TensorData::F32(result))
118 }
119}
120
121#[derive(Debug, Clone)]
122pub enum VectorOp {
123 DotProduct,
124 Normalize,
125}
126
127impl TensorOp for VectorOp {
128 fn name(&self) -> &str {
129 match self {
130 Self::DotProduct => "dot_product",
131 Self::Normalize => "normalize",
132 }
133 }
134
135 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
136 let data = input.data.as_f32();
137 match self {
138 Self::DotProduct => {
139 let half = data.len() / 2;
140 let dot: f32 = data[..half]
141 .iter()
142 .zip(data[half..half * 2].iter())
143 .map(|(a, b)| a * b)
144 .sum();
145 Tensor::new(TensorShape::new(vec![1]), TensorData::F32(vec![dot]))
146 }
147 Self::Normalize => {
148 let magnitude: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
149 let result = if magnitude > f32::EPSILON {
150 data.iter().map(|x| x / magnitude).collect()
151 } else {
152 data.clone()
153 };
154 Tensor::new(input.shape.clone(), TensorData::F32(result))
155 }
156 }
157 }
158}