// ai3_lib/operations.rs

//! Tensor operations: matrix multiply, convolution, activations, vector ops (aligned with .AI3).

use crate::tensor::{Tensor, TensorData, TensorShape};
use pot_o_core::{TribeError, TribeResult};

/// Trait for all tensor operations (aligned with .AI3 TensorOp).
///
/// Implementors are `Send + Sync` so boxed operations can be shared across
/// threads (e.g. the `Box<dyn TensorOp>` values produced by `parse_operation`).
pub trait TensorOp: Send + Sync {
    /// Stable identifier for this operation; matches the strings accepted by
    /// `parse_operation`.
    fn name(&self) -> &str;
    /// Applies the operation to `input`, producing a new tensor.
    /// Returns a `TribeError` on failure (propagated from `Tensor::new`).
    fn execute(&self, input: &Tensor) -> TribeResult<Tensor>;
}

12/// Parses an operation type string into a boxed [`TensorOp`] implementation.
13pub fn parse_operation(op_type: &str) -> TribeResult<Box<dyn TensorOp>> {
14    match op_type {
15        "matrix_multiply" => Ok(Box::new(MatrixMultiply)),
16        "convolution" => Ok(Box::new(Convolution::default())),
17        "relu" => Ok(Box::new(ActivationFunction::ReLU)),
18        "sigmoid" => Ok(Box::new(ActivationFunction::Sigmoid)),
19        "tanh" => Ok(Box::new(ActivationFunction::Tanh)),
20        "dot_product" => Ok(Box::new(VectorOp::DotProduct)),
21        "normalize" => Ok(Box::new(VectorOp::Normalize)),
22        _ => Err(TribeError::TensorError(format!(
23            "Unknown operation: {op_type}"
24        ))),
25    }
26}
27
/// Matrix multiplication (self-multiply for square-ish inputs).
///
/// Stateless marker type; the operation squares the input buffer interpreted
/// as the largest n x n matrix that fits (see its `TensorOp::execute`).
pub struct MatrixMultiply;

31impl TensorOp for MatrixMultiply {
32    fn name(&self) -> &str {
33        "matrix_multiply"
34    }
35
36    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
37        let data = input.data.as_f32();
38        let n = (data.len() as f64).sqrt() as usize;
39        if n == 0 {
40            return Ok(Tensor::zeros(TensorShape::new(vec![0])));
41        }
42        let size = n * n;
43        let a: Vec<f32> = data.iter().copied().take(size).collect();
44        let mut result = vec![0.0f32; size];
45        for i in 0..n {
46            for j in 0..n {
47                let mut sum = 0.0f32;
48                for k in 0..n {
49                    let ai = a.get(i * n + k).copied().unwrap_or(0.0);
50                    let bj = a.get(k * n + j).copied().unwrap_or(0.0);
51                    sum += ai * bj;
52                }
53                result[i * n + j] = sum;
54            }
55        }
56        Tensor::new(TensorShape::new(vec![n, n]), TensorData::F32(result))
57    }
58}
59
/// 1D convolution with a small fixed kernel.
pub struct Convolution {
    // Filter taps applied as a sliding window over the input ("valid" mode);
    // see `Default` for the default 3-tap smoothing kernel.
    pub kernel: Vec<f32>,
}

65impl Default for Convolution {
66    fn default() -> Self {
67        Self {
68            kernel: vec![0.25, 0.5, 0.25],
69        }
70    }
71}
72
73impl TensorOp for Convolution {
74    fn name(&self) -> &str {
75        "convolution"
76    }
77
78    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
79        let data = input.data.as_f32();
80        let klen = self.kernel.len();
81        if data.len() < klen {
82            return Ok(input.clone());
83        }
84        let out_len = data.len() - klen + 1;
85        let mut result = Vec::with_capacity(out_len);
86        for i in 0..out_len {
87            let mut sum = 0.0f32;
88            for (j, &kv) in self.kernel.iter().enumerate() {
89                sum += data[i + j] * kv;
90            }
91            result.push(sum);
92        }
93        Tensor::new(TensorShape::new(vec![out_len]), TensorData::F32(result))
94    }
95}
96
/// Element-wise activation functions (see the `TensorOp` impl for formulas).
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    // max(x, 0)
    ReLU,
    // 1 / (1 + e^-x)
    Sigmoid,
    // hyperbolic tangent
    Tanh,
}

104impl TensorOp for ActivationFunction {
105    fn name(&self) -> &str {
106        match self {
107            Self::ReLU => "relu",
108            Self::Sigmoid => "sigmoid",
109            Self::Tanh => "tanh",
110        }
111    }
112
113    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
114        let data = input.data.as_f32();
115        let result: Vec<f32> = match self {
116            Self::ReLU => data.iter().map(|&x| x.max(0.0)).collect(),
117            Self::Sigmoid => data.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect(),
118            Self::Tanh => data.iter().map(|&x| x.tanh()).collect(),
119        };
120        Tensor::new(input.shape.clone(), TensorData::F32(result))
121    }
122}
123
/// Whole-buffer vector operations (see the `TensorOp` impl for semantics).
#[derive(Debug, Clone)]
pub enum VectorOp {
    // dot of the first and second halves of the buffer -> scalar tensor
    DotProduct,
    // scale to unit L2 norm (no-op for near-zero magnitude)
    Normalize,
}

130impl TensorOp for VectorOp {
131    fn name(&self) -> &str {
132        match self {
133            Self::DotProduct => "dot_product",
134            Self::Normalize => "normalize",
135        }
136    }
137
138    fn execute(&self, input: &Tensor) -> TribeResult<Tensor> {
139        let data = input.data.as_f32();
140        match self {
141            Self::DotProduct => {
142                let half = data.len() / 2;
143                let dot: f32 = data[..half]
144                    .iter()
145                    .zip(data[half..half * 2].iter())
146                    .map(|(a, b)| a * b)
147                    .sum();
148                Tensor::new(TensorShape::new(vec![1]), TensorData::F32(vec![dot]))
149            }
150            Self::Normalize => {
151                let magnitude: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
152                let result = if magnitude > f32::EPSILON {
153                    data.iter().map(|x| x / magnitude).collect()
154                } else {
155                    data.clone()
156                };
157                Tensor::new(input.shape.clone(), TensorData::F32(result))
158            }
159        }
160    }
161}