// Module: stdlib/nn/feedforward.tern
// Purpose: Ternary Transformer Feedforward Layer
// Author: RFI-IRFOS
// Ref: https://ternlang.com
// Feedforward (FFN) block of a transformer: projects a 4x1 input up to an 8x1 hidden vector, then back down to 4x1.
struct FFNLayer {
expand_weights: trittensor<8 x 4>, // Up-projection (8x4): maps a 4x1 input to an 8x1 hidden vector; used by expand_trit
contract_weights: trittensor<4 x 8> // Down-projection (4x8): maps the 8x1 hidden vector back to 4x1; used by contract_trit
}
// Up-projection step of the FFN block.
// Multiplies the layer's 8x4 expand weights by the 4x1 input and returns the
// resulting 8x1 hidden vector.
// NOTE(review): @sparseskip is a compiler directive on the multiply below;
// presumably it lets the kernel skip zero trits — confirm against tern docs.
fn expand_trit(layer: FFNLayer, input: trittensor<4 x 1>) -> trittensor<8 x 1> {
    @sparseskip
    let hidden: trittensor<8 x 1> = layer.expand_weights * input;
    return hidden;
}
// Down-projection step of the FFN block.
// Multiplies the layer's 4x8 contract weights by the 8x1 hidden vector and
// returns the resulting 4x1 output, matching the input width of expand_trit.
// NOTE(review): @sparseskip is a compiler directive on the multiply below;
// presumably it lets the kernel skip zero trits — confirm against tern docs.
fn contract_trit(layer: FFNLayer, expanded: trittensor<8 x 1>) -> trittensor<4 x 1> {
    @sparseskip
    let projected: trittensor<4 x 1> = layer.contract_weights * expanded;
    return projected;
}
// Full forward pass of the FFN block: 4x1 -> 8x1 (expand_trit) -> 4x1 (contract_trit).
// NOTE(review): no activation is applied between the two projections. If tern's
// `*` on trittensors is a plain linear product, the block collapses to a single
// linear map; standard transformer FFNs insert a nonlinearity here — confirm intent.
fn ff_forward(layer: FFNLayer, input: trittensor<4 x 1>) -> trittensor<4 x 1> {
    let hidden: trittensor<8 x 1> = expand_trit(layer, input);
    let result: trittensor<4 x 1> = contract_trit(layer, hidden);
    return result;
}