// Module: stdlib/models/cnn.tern
// Purpose: Convolutional Neural Network in Ternary
// Author: RFI-IRFOS
// Ref: https://ternlang.com
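// CNN-style model built from ternary tensor products: a feature-extraction
// stage (conv1) followed by a classification stage (dense1).
// NOTE(review): no pooling step appears in the functions below - confirm whether one is planned.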
// Weights of the ternary CNN model.
//   conv1  - 4x4 trit tensor; multiplied with the input image in feature_extract_trit.
//   dense1 - 4x4 trit tensor; multiplied with the extracted features in classify_trit.
struct TritCNN {
conv1: trittensor<4 x 4>,
dense1: trittensor<4 x 4>
}
// Feature-extraction stage: multiplies the model's conv1 weights (4x4) by the
// input image (4x1) and returns the resulting 4x1 trit tensor.
// NOTE(review): @sparseskip presumably lets the product skip zero-valued trits -
// confirm against the ternlang directive documentation.
fn feature_extract_trit(model: TritCNN, image: trittensor<4 x 1>) -> trittensor<4 x 1> {
    @sparseskip
    let extracted: trittensor<4 x 1> = model.conv1 * image;
    return extracted;
}
// Classification stage: multiplies the model's dense1 weights (4x4) by the
// feature vector (4x1) and returns the trit at position [0, 0] of the product
// as the prediction. Only that first entry is read; the other entries of the
// product are not consulted.
// NOTE(review): @sparseskip presumably lets the product skip zero-valued trits -
// confirm against the ternlang directive documentation.
fn classify_trit(model: TritCNN, features: trittensor<4 x 1>) -> trit {
    @sparseskip
    let scores: trittensor<4 x 1> = model.dense1 * features;
    let decision: trit = scores[0, 0];
    // The selected trit (affirm, tend or reject) is returned unchanged.
    return decision;
}
// End-to-end forward pass: runs feature extraction on the image, then feeds
// the extracted features to the classifier and returns its trit prediction.
fn cnn_forward(model: TritCNN, image: trittensor<4 x 1>) -> trit {
    let extracted: trittensor<4 x 1> = feature_extract_trit(model, image);
    return classify_trit(model, extracted);
}