#![allow(clippy::upper_case_acronyms)]
use crate::download::{language::Language, AvailableOnnxModel, ModelUrl};
/// Machine-comprehension models selectable for download.
///
/// Unit variants identify a single model file; tuple variants wrap a
/// family enum that picks the concrete file within that family.
#[derive(Debug, Clone)]
pub enum MachineComprehension {
    /// The `bidaf-9.onnx` model (bidirectional attention flow).
    BiDAF,
    /// The `bertsquad-10.onnx` model.
    BERTSquad,
    /// A model from the RoBERTa family; see [`RoBERTa`].
    RoBERTa(RoBERTa),
    /// A model from the GPT-2 family; see [`GPT2`].
    GPT2(GPT2),
}
/// Members of the RoBERTa model family.
#[derive(Debug, Clone)]
pub enum RoBERTa {
    /// The `roberta-base-11.onnx` model.
    RoBERTaBase,
    /// The `roberta-sequence-classification-9.onnx` model.
    RoBERTaSequenceClassification,
}
/// Members of the GPT-2 model family.
#[derive(Debug, Clone)]
pub enum GPT2 {
    /// The `gpt2-10.onnx` model.
    GPT2,
    /// The `gpt2-lm-head-10.onnx` model.
    GPT2LmHead,
}
/// URL lookup for every machine-comprehension model.
impl ModelUrl for MachineComprehension {
    /// Returns the URL string associated with the selected model.
    ///
    /// Standalone models map directly to a fixed URL; family variants
    /// forward the lookup to the wrapped family enum.
    fn fetch_url(&self) -> &'static str {
        // NOTE(review): both URLs point at the `master` branch of
        // `onnx/models` — confirm they still resolve upstream.
        const BIDAF_URL: &str = "https://github.com/onnx/models/raw/master/text/machine_comprehension/bidirectional_attention_flow/model/bidaf-9.onnx";
        const BERT_SQUAD_URL: &str = "https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx";

        match self {
            Self::BiDAF => BIDAF_URL,
            Self::BERTSquad => BERT_SQUAD_URL,
            Self::RoBERTa(inner) => inner.fetch_url(),
            Self::GPT2(inner) => inner.fetch_url(),
        }
    }
}
/// URL lookup for the RoBERTa family.
impl ModelUrl for RoBERTa {
    /// Returns the URL string associated with the selected RoBERTa model.
    fn fetch_url(&self) -> &'static str {
        // NOTE(review): both URLs point at the `master` branch of
        // `onnx/models` — confirm they still resolve upstream.
        const BASE_URL: &str = "https://github.com/onnx/models/raw/master/text/machine_comprehension/roberta/model/roberta-base-11.onnx";
        const SEQUENCE_CLASSIFICATION_URL: &str = "https://github.com/onnx/models/raw/master/text/machine_comprehension/roberta/model/roberta-sequence-classification-9.onnx";

        match self {
            Self::RoBERTaBase => BASE_URL,
            Self::RoBERTaSequenceClassification => SEQUENCE_CLASSIFICATION_URL,
        }
    }
}
/// URL lookup for the GPT-2 family.
impl ModelUrl for GPT2 {
    /// Returns the URL string associated with the selected GPT-2 model.
    fn fetch_url(&self) -> &'static str {
        // NOTE(review): both URLs point at the `master` branch of
        // `onnx/models` — confirm they still resolve upstream.
        const GPT2_URL: &str = "https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-10.onnx";
        const GPT2_LM_HEAD_URL: &str = "https://github.com/onnx/models/raw/master/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx";

        match self {
            Self::GPT2 => GPT2_URL,
            Self::GPT2LmHead => GPT2_LM_HEAD_URL,
        }
    }
}
/// Lifts a machine-comprehension model into the top-level model enum.
impl From<MachineComprehension> for AvailableOnnxModel {
    /// Wraps `model` in `Language::MachineComprehension`, then in
    /// `AvailableOnnxModel::Language`.
    fn from(model: MachineComprehension) -> Self {
        let language = Language::MachineComprehension(model);
        Self::Language(language)
    }
}
/// Lifts a RoBERTa family member into the top-level model enum.
impl From<RoBERTa> for AvailableOnnxModel {
    /// Wraps `model` layer by layer: `MachineComprehension::RoBERTa`,
    /// then `Language::MachineComprehension`, then
    /// `AvailableOnnxModel::Language`.
    fn from(model: RoBERTa) -> Self {
        let comprehension = MachineComprehension::RoBERTa(model);
        let language = Language::MachineComprehension(comprehension);
        Self::Language(language)
    }
}
/// Lifts a GPT-2 family member into the top-level model enum.
impl From<GPT2> for AvailableOnnxModel {
    /// Wraps `model` layer by layer: `MachineComprehension::GPT2`,
    /// then `Language::MachineComprehension`, then
    /// `AvailableOnnxModel::Language`.
    fn from(model: GPT2) -> Self {
        let comprehension = MachineComprehension::GPT2(model);
        let language = Language::MachineComprehension(comprehension);
        Self::Language(language)
    }
}