// Module: stdlib/nlp/text_classifier.tern
// Purpose: Text Classification
// Author: RFI-IRFOS
// Ref: https://ternlang.com
// Predicts class labels for text.
// Parameters of a two-stage text classifier: an encoder matrix followed by a
// classification head. Both are applied as left-multiplications in
// classify_text_trit (encoder first, then head).
struct TextClassifier {
// 4x4 ternary weights mapping a 4x1 text embedding to a 4x1 context vector.
encoder: trittensor<4 x 4>,
// 4x4 ternary weights mapping the 4x1 context vector to 4x1 class logits.
// NOTE(review): the 4 output rows presumably correspond to 4 class labels — confirm.
head: trittensor<4 x 4>
}
// Runs a text embedding through the classifier and returns per-class logits.
//
// Computes `head * (encoder * text_embed)` in two explicit steps so that each
// matrix product carries its own @sparseskip annotation.
//
// Params:
//   model      - encoder/head weight pair to apply.
//   text_embed - 4x1 ternary embedding of the input text.
// Returns:
//   4x1 ternary logit column produced by the head.
fn classify_text_trit(model: TextClassifier, text_embed: trittensor<4 x 1>) -> trittensor<4 x 1> {
    // Stage 1: encode the raw embedding into a hidden representation.
    @sparseskip
    let hidden: trittensor<4 x 1> = model.encoder * text_embed;
    // Stage 2: project the hidden representation onto the class logits.
    @sparseskip
    let scores: trittensor<4 x 1> = model.head * hidden;
    return scores;
}
// Reduces a 4x1 logit column to a single trit label.
//
// NOTE(review): despite the name, no softmax or argmax is computed. Only the
// first row, logits[0, 0], is read, and the other three rows are ignored.
// Confirm this placeholder behavior is intended before relying on it.
//
// Params:
//   logits - 4x1 ternary logits, e.g. as returned by classify_text_trit.
// Returns:
//   logits[0, 0] unchanged. A `tend` result signals that the class is uncertain.
fn softmax_label_trit(logits: trittensor<4 x 1>) -> trit {
    // The first logit already is the label. A `tend` value needs no special
    // case because it passes through unchanged.
    let val: trit = logits[0, 0];
    return val;
}