1use crate::tensor::{Tensor, TensorData, TensorShape};
4use pot_o_core::{TribeError, TribeResult};
5
6pub trait TensorOp: Send + Sync {
8 fn name(&self) -> &str;
9 fn execute(&self, input: &Tensor) -> TribeResult<Tensor>;
10}
11
12pub fn parse_operation(op_type: &str) -> TribeResult<Box<dyn TensorOp>> {
14 match op_type {
15 "matrix_multiply" => Ok(Box::new(MatrixMultiply)),
16 "convolution" => Ok(Box::new(Convolution::default())),
17 "relu" => Ok(Box::new(ActivationFunction::ReLU)),
18 "sigmoid" => Ok(Box::new(ActivationFunction::Sigmoid)),
19 "tanh" => Ok(Box::new(ActivationFunction::Tanh)),
20 "dot_product" => Ok(Box::new(VectorOp::DotProduct)),
21 "normalize" => Ok(Box::new(VectorOp::Normalize)),
22 _ => Err(TribeError::TensorError(format!(
23 "Unknown operation: {op_type}"
24 ))),
25 }
26}
27
28pub struct MatrixMultiply;
30
31impl TensorOp for MatrixMultiply {
32 fn name(&self) -> &str {
33 "matrix_multiply"
34 }
35
36 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
37 let data = input.data.as_f32();
38 let n = (data.len() as f64).sqrt() as usize;
39 if n == 0 {
40 return Ok(Tensor::zeros(TensorShape::new(vec![0])));
41 }
42 let size = n * n;
43 let a: Vec<f32> = data.iter().copied().take(size).collect();
44 let mut result = vec![0.0f32; size];
45 for i in 0..n {
46 for j in 0..n {
47 let mut sum = 0.0f32;
48 for k in 0..n {
49 let ai = a.get(i * n + k).copied().unwrap_or(0.0);
50 let bj = a.get(k * n + j).copied().unwrap_or(0.0);
51 sum += ai * bj;
52 }
53 result[i * n + j] = sum;
54 }
55 }
56 Tensor::new(TensorShape::new(vec![n, n]), TensorData::F32(result))
57 }
58}
59
60pub struct Convolution {
62 pub kernel: Vec<f32>,
63}
64
65impl Default for Convolution {
66 fn default() -> Self {
67 Self {
68 kernel: vec![0.25, 0.5, 0.25],
69 }
70 }
71}
72
73impl TensorOp for Convolution {
74 fn name(&self) -> &str {
75 "convolution"
76 }
77
78 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
79 let data = input.data.as_f32();
80 let klen = self.kernel.len();
81 if data.len() < klen {
82 return Ok(input.clone());
83 }
84 let out_len = data.len() - klen + 1;
85 let mut result = Vec::with_capacity(out_len);
86 for i in 0..out_len {
87 let mut sum = 0.0f32;
88 for (j, &kv) in self.kernel.iter().enumerate() {
89 sum += data[i + j] * kv;
90 }
91 result.push(sum);
92 }
93 Tensor::new(TensorShape::new(vec![out_len]), TensorData::F32(result))
94 }
95}
96
97#[derive(Debug, Clone)]
98pub enum ActivationFunction {
99 ReLU,
100 Sigmoid,
101 Tanh,
102}
103
104impl TensorOp for ActivationFunction {
105 fn name(&self) -> &str {
106 match self {
107 Self::ReLU => "relu",
108 Self::Sigmoid => "sigmoid",
109 Self::Tanh => "tanh",
110 }
111 }
112
113 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
114 let data = input.data.as_f32();
115 let result: Vec<f32> = match self {
116 Self::ReLU => data.iter().map(|&x| x.max(0.0)).collect(),
117 Self::Sigmoid => data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
118 Self::Tanh => data.iter().map(|&x| x.tanh()).collect(),
119 };
120 Tensor::new(input.shape.clone(), TensorData::F32(result))
121 }
122}
123
124#[derive(Debug, Clone)]
125pub enum VectorOp {
126 DotProduct,
127 Normalize,
128}
129
130impl TensorOp for VectorOp {
131 fn name(&self) -> &str {
132 match self {
133 Self::DotProduct => "dot_product",
134 Self::Normalize => "normalize",
135 }
136 }
137
138 fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
139 let data = input.data.as_f32();
140 match self {
141 Self::DotProduct => {
142 let half = data.len() / 2;
143 let dot: f32 = data[..half]
144 .iter()
145 .zip(data[half..half * 2].iter())
146 .map(|(a, b)| a * b)
147 .sum();
148 Tensor::new(TensorShape::new(vec![1]), TensorData::F32(vec![dot]))
149 }
150 Self::Normalize => {
151 let magnitude: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
152 let result = if magnitude > f32::EPSILON {
153 data.iter().map(|x| x / magnitude).collect()
154 } else {
155 data.clone()
156 };
157 Tensor::new(input.shape.clone(), TensorData::F32(result))
158 }
159 }
160 }
161}