//! ai3_lib/operations.rs — tensor operations for the .AI3 pipeline.

use crate::tensor::{Tensor, TensorData, TensorShape};
use pot_o_core::{TribeError, TribeResult};
/// Trait for all tensor operations (aligned with .AI3 TensorOp)
///
/// Implementors are `Send + Sync` so boxed operations can be shared
/// across threads.
pub trait TensorOp: Send + Sync {
    /// Stable string identifier for this operation (matches the keys
    /// accepted by `parse_operation`).
    fn name(&self) -> &str;
    /// Applies the operation to `input` and returns a new tensor;
    /// the input is never mutated.
    fn execute(&self, input: &Tensor) -> TribeResult<Tensor>;
}
9
10pub fn parse_operation(op_type: &str) -> TribeResult<Box<dyn TensorOp>> {
11    match op_type {
12        "matrix_multiply" => Ok(Box::new(MatrixMultiply)),
13        "convolution" => Ok(Box::new(Convolution::default())),
14        "relu" => Ok(Box::new(ActivationFunction::ReLU)),
15        "sigmoid" => Ok(Box::new(ActivationFunction::Sigmoid)),
16        "tanh" => Ok(Box::new(ActivationFunction::Tanh)),
17        "dot_product" => Ok(Box::new(VectorOp::DotProduct)),
18        "normalize" => Ok(Box::new(VectorOp::Normalize)),
19        _ => Err(TribeError::TensorError(format!(
20            "Unknown operation: {op_type}"
21        ))),
22    }
23}
24
/// Matrix multiplication (self-multiply for square-ish inputs)
///
/// Unit struct: the operation is stateless — the matrix is taken from the
/// input tensor itself at execute time.
pub struct MatrixMultiply;
27
28impl TensorOp for MatrixMultiply {
29    fn name(&self) -> &str {
30        "matrix_multiply"
31    }
32
33    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
34        let data = input.data.as_f32();
35        let n = (data.len() as f64).sqrt() as usize;
36        if n == 0 {
37            return Ok(Tensor::zeros(TensorShape::new(vec![0])));
38        }
39        let size = n * n;
40        let a: Vec<f32> = data.iter().copied().take(size).collect();
41        let mut result = vec![0.0f32; size];
42        for i in 0..n {
43            for j in 0..n {
44                let mut sum = 0.0f32;
45                for k in 0..n {
46                    let ai = a.get(i * n + k).copied().unwrap_or(0.0);
47                    let bj = a.get(k * n + j).copied().unwrap_or(0.0);
48                    sum += ai * bj;
49                }
50                result[i * n + j] = sum;
51            }
52        }
53        Tensor::new(TensorShape::new(vec![n, n]), TensorData::F32(result))
54    }
55}
56
/// 1D convolution with a small fixed kernel
pub struct Convolution {
    // Filter taps applied over a sliding window of the input.
    pub kernel: Vec<f32>,
}
61
62impl Default for Convolution {
63    fn default() -> Self {
64        Self {
65            kernel: vec![0.25, 0.5, 0.25],
66        }
67    }
68}
69
70impl TensorOp for Convolution {
71    fn name(&self) -> &str {
72        "convolution"
73    }
74
75    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
76        let data = input.data.as_f32();
77        let klen = self.kernel.len();
78        if data.len() < klen {
79            return Ok(input.clone());
80        }
81        let out_len = data.len() - klen + 1;
82        let mut result = Vec::with_capacity(out_len);
83        for i in 0..out_len {
84            let mut sum = 0.0f32;
85            for (j, &kv) in self.kernel.iter().enumerate() {
86                sum += data[i + j] * kv;
87            }
88            result.push(sum);
89        }
90        Tensor::new(TensorShape::new(vec![out_len]), TensorData::F32(result))
91    }
92}
93
/// Elementwise activation functions applied across the whole tensor.
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    // max(x, 0)
    ReLU,
    // 1 / (1 + e^(-x))
    Sigmoid,
    // Hyperbolic tangent.
    Tanh,
}
100
101impl TensorOp for ActivationFunction {
102    fn name(&self) -> &str {
103        match self {
104            Self::ReLU => "relu",
105            Self::Sigmoid => "sigmoid",
106            Self::Tanh => "tanh",
107        }
108    }
109
110    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
111        let data = input.data.as_f32();
112        let result: Vec<f32> = match self {
113            Self::ReLU => data.iter().map(|&x| x.max(0.0)).collect(),
114            Self::Sigmoid => data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
115            Self::Tanh => data.iter().map(|&x| x.tanh()).collect(),
116        };
117        Tensor::new(input.shape.clone(), TensorData::F32(result))
118    }
119}
120
/// Whole-vector operations over the flattened tensor data.
#[derive(Debug, Clone)]
pub enum VectorOp {
    // Dot product of the first and second halves of the buffer,
    // producing a single scalar.
    DotProduct,
    // L2 (Euclidean) normalization; near-zero vectors pass through unchanged.
    Normalize,
}
126
127impl TensorOp for VectorOp {
128    fn name(&self) -> &str {
129        match self {
130            Self::DotProduct => "dot_product",
131            Self::Normalize => "normalize",
132        }
133    }
134
135    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
136        let data = input.data.as_f32();
137        match self {
138            Self::DotProduct => {
139                let half = data.len() / 2;
140                let dot: f32 = data[..half]
141                    .iter()
142                    .zip(data[half..half * 2].iter())
143                    .map(|(a, b)| a * b)
144                    .sum();
145                Tensor::new(TensorShape::new(vec![1]), TensorData::F32(vec![dot]))
146            }
147            Self::Normalize => {
148                let magnitude: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
149                let result = if magnitude > f32::EPSILON {
150                    data.iter().map(|x| x / magnitude).collect()
151                } else {
152                    data.clone()
153                };
154                Tensor::new(input.shape.clone(), TensorData::F32(result))
155            }
156        }
157    }
158}