#![allow(clippy::upper_case_acronyms)]
use crate::download::ModelUrl;
/// Machine-comprehension models whose ONNX files can be fetched.
///
/// Every variant resolves to exactly one download location through its
/// [`ModelUrl`] implementation. `RoBERTa` and `GPT2` are model families, so
/// they carry a nested enum that selects the concrete variant.
#[derive(Debug, Clone)]
pub enum MachineComprehension {
    /// Bidirectional Attention Flow model (`bidaf-9.onnx`).
    BiDAF,
    /// BERT-SQuAD model (`bertsquad-10.onnx`).
    BERTSquad,
    /// A member of the RoBERTa family; see [`RoBERTa`].
    RoBERTa(RoBERTa),
    /// A member of the GPT-2 family; see [`GPT2`].
    GPT2(GPT2),
}
/// Concrete RoBERTa models that can be downloaded.
#[derive(Debug, Clone)]
pub enum RoBERTa {
    /// Base RoBERTa model (`roberta-base-11.onnx`).
    RoBERTaBase,
    /// RoBERTa sequence-classification model (`roberta-sequence-classification-9.onnx`).
    RoBERTaSequenceClassification,
}
/// Concrete GPT-2 models that can be downloaded.
#[derive(Debug, Clone)]
pub enum GPT2 {
    /// Plain GPT-2 model (`gpt2-10.onnx`).
    GPT2,
    /// GPT-2 LM-head model (`gpt2-lm-head-10.onnx`).
    GPT2LmHead,
}
impl ModelUrl for MachineComprehension {
    /// Returns the download URL of the selected model.
    ///
    /// Model families (`RoBERTa`, `GPT2`) delegate to the nested variant's own
    /// [`ModelUrl`] implementation; the remaining models map to a fixed URL.
    ///
    /// NOTE(review): the URLs reference the `main` branch of onnx/models, which
    /// is mutable upstream — confirm they still resolve.
    fn fetch_url(&self) -> &'static str {
        match self {
            Self::RoBERTa(family) => family.fetch_url(),
            Self::GPT2(family) => family.fetch_url(),
            Self::BiDAF => "https://github.com/onnx/models/raw/main/text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.onnx",
            Self::BERTSquad => "https://github.com/onnx/models/raw/main/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx",
        }
    }
}
impl ModelUrl for RoBERTa {
    /// Returns the ONNX download URL for this RoBERTa variant.
    fn fetch_url(&self) -> &'static str {
        match *self {
            Self::RoBERTaSequenceClassification => "https://github.com/onnx/models/raw/main/text/machine_comprehension/roberta/model/roberta-sequence-classification-9.onnx",
            Self::RoBERTaBase => "https://github.com/onnx/models/raw/main/text/machine_comprehension/roberta/model/roberta-base-11.onnx",
        }
    }
}
impl ModelUrl for GPT2 {
    /// Returns the ONNX download URL for this GPT-2 variant.
    fn fetch_url(&self) -> &'static str {
        match *self {
            Self::GPT2LmHead => "https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx",
            Self::GPT2 => "https://github.com/onnx/models/raw/main/text/machine_comprehension/gpt-2/model/gpt2-10.onnx",
        }
    }
}