use std::error::Error;
use std::fmt::{Display, Formatter, Result as FmtResult};
use serde_derive::Deserialize;
use serde_json::{Error as SerdeJsonError, Value};
/// Libraries a model can be used with, as reported by
/// [`ModelConfigTrait::available_libraries`].
///
/// With no `#[serde(rename...)]` attributes, serde deserializes each variant
/// from its exact Rust name (e.g. `"PyTorch"`).
///
/// Every variant is a unit variant, so the type is cheap to copy and can be
/// used as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq)]
pub enum ModelLibraries {
    AdapterTransformers,
    AllenNLP,
    Asteroid,
    CoreML,
    Diffusers,
    ESPnet,
    Fairseq,
    Fastai,
    FastText,
    Flair,
    Flax,
    Graphcore,
    Habana,
    Jax,
    Joblib,
    Keras,
    MLAgents,
    NeMo,
    OpenCLIP,
    OpenVINO,
    Onnx,
    PaddleNLP,
    PaddlePaddle,
    PyannoteAudio,
    Pythae,
    PyTorch,
    Rust,
    Safetensors,
    SampleFactory,
    ScikitLearn,
    SentenceTransformers,
    Spacy,
    SpanMarker,
    Speechbrain,
    StableBaselines3,
    Stanza,
    TensorBoard,
    TensorFlow,
    TensorFlowTTS,
    TFLite,
    Timm,
    Transformers,
}
/// Errors returned by [`ModelConfigTrait::from_json`] implementations.
#[derive(Debug)]
pub enum ModelError {
    /// Wraps an error reported by `serde_json`.
    Json(SerdeJsonError),
    /// A field the configuration needed was absent; holds the field name.
    MissingField(String),
    /// The requested model is not supported yet; holds its identifier.
    ModelNotImplemented(String),
}
impl Display for ModelError {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
match self {
ModelError::Json(e) => write!(f, "JSON error: {}", e),
ModelError::MissingField(field) => write!(f, "Missing field: {}", field),
ModelError::ModelNotImplemented(model) => write!(
f,
"Model not implemented: {}.\
\nPlease open an issue on the GitHub repository: \
https://github.com/chainyo/aiha/issues",
model
),
}
}
}
impl Error for ModelError {}
impl From<SerdeJsonError> for ModelError {
fn from(error: SerdeJsonError) -> Self {
ModelError::Json(error)
}
}
/// Read-only view over a model configuration, plus a constructor from JSON.
///
/// Only [`ModelConfigTrait::from_json`] must be implemented. Every accessor has
/// a default that returns a placeholder (`0`, `""`, or an empty slice), so
/// implementors override just the values they actually provide.
pub trait ModelConfigTrait {
    /// Hidden size of the model. Defaults to `0`.
    fn hidden_size(&self) -> i32 {
        0
    }

    /// Intermediate size of the model. Defaults to `0`.
    fn intermediate_size(&self) -> i32 {
        0
    }

    /// Maximum number of position embeddings. Defaults to `0`.
    fn max_position_embeddings(&self) -> i32 {
        0
    }

    /// Number of attention heads. Defaults to `0`.
    fn num_attention_heads(&self) -> i32 {
        0
    }

    /// Number of hidden layers. Defaults to `0`.
    fn num_hidden_layers(&self) -> i32 {
        0
    }

    /// Model type identifier. Defaults to the empty string.
    fn model_type(&self) -> &str {
        ""
    }

    /// Libraries the model can be used with. Defaults to an empty slice.
    fn available_libraries(&self) -> &[ModelLibraries] {
        &[]
    }

    /// Builds a configuration from a JSON value.
    ///
    /// # Errors
    ///
    /// Implementations return a [`ModelError`] when `value` cannot be turned
    /// into `Self`.
    fn from_json(value: Value) -> Result<Self, ModelError>
    where
        Self: Sized;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config that overrides every trait method with fixed values.
    struct MockModelConfig;

    impl ModelConfigTrait for MockModelConfig {
        fn hidden_size(&self) -> i32 {
            1024
        }

        fn intermediate_size(&self) -> i32 {
            4096
        }

        fn max_position_embeddings(&self) -> i32 {
            512
        }

        fn num_attention_heads(&self) -> i32 {
            16
        }

        fn num_hidden_layers(&self) -> i32 {
            12
        }

        fn model_type(&self) -> &str {
            "mock"
        }

        fn available_libraries(&self) -> &[ModelLibraries] {
            &[ModelLibraries::PyTorch]
        }

        fn from_json(_value: Value) -> Result<Self, ModelError>
        where
            Self: Sized,
        {
            Ok(MockModelConfig)
        }
    }

    #[test]
    fn test_hub_model_config() {
        let config = MockModelConfig;
        assert_eq!(config.hidden_size(), 1024);
        assert_eq!(config.intermediate_size(), 4096);
        assert_eq!(config.max_position_embeddings(), 512);
        assert_eq!(config.num_attention_heads(), 16);
        assert_eq!(config.num_hidden_layers(), 12);
        assert_eq!(config.model_type(), "mock");
        assert_eq!(config.available_libraries(), &[ModelLibraries::PyTorch][..]);
        assert!(MockModelConfig::from_json(Value::Null).is_ok());
    }

    #[test]
    fn test_model_libraries_equality() {
        let lib1 = ModelLibraries::PyTorch;
        let lib2 = ModelLibraries::PyTorch;
        let lib3 = ModelLibraries::TensorFlow;
        assert_eq!(lib1, lib2);
        assert_ne!(lib1, lib3);
    }

    /// `{:?}` renders the bare variant name.
    #[test]
    fn test_model_libraries_debug() {
        let lib1 = ModelLibraries::PyTorch;
        let lib2 = ModelLibraries::TensorFlow;
        assert_eq!(format!("{:?}", lib1), "PyTorch");
        assert_eq!(format!("{:?}", lib2), "TensorFlow");
    }

    #[test]
    fn test_model_error_display() {
        assert_eq!(
            ModelError::MissingField("hidden_size".to_string()).to_string(),
            "Missing field: hidden_size"
        );
        assert_eq!(
            ModelError::ModelNotImplemented("mock".to_string()).to_string(),
            "Model not implemented: mock.\nPlease open an issue on the GitHub repository: \
             https://github.com/chainyo/aiha/issues"
        );
    }

    #[test]
    fn test_model_error_from_serde_json() {
        let json_err = serde_json::from_str::<Value>("{").unwrap_err();
        let err = ModelError::from(json_err);
        assert!(matches!(err, ModelError::Json(_)));
        assert!(err.to_string().starts_with("JSON error: "));
    }

    /// Maps every variant to its declaration index.
    ///
    /// The match deliberately has no wildcard arm. Adding a variant to
    /// `ModelLibraries` without updating this test is therefore a compile
    /// error, not a silently stale count.
    fn library_position(lib: &ModelLibraries) -> usize {
        match lib {
            ModelLibraries::AdapterTransformers => 0,
            ModelLibraries::AllenNLP => 1,
            ModelLibraries::Asteroid => 2,
            ModelLibraries::CoreML => 3,
            ModelLibraries::Diffusers => 4,
            ModelLibraries::ESPnet => 5,
            ModelLibraries::Fairseq => 6,
            ModelLibraries::Fastai => 7,
            ModelLibraries::FastText => 8,
            ModelLibraries::Flair => 9,
            ModelLibraries::Flax => 10,
            ModelLibraries::Graphcore => 11,
            ModelLibraries::Habana => 12,
            ModelLibraries::Jax => 13,
            ModelLibraries::Joblib => 14,
            ModelLibraries::Keras => 15,
            ModelLibraries::MLAgents => 16,
            ModelLibraries::NeMo => 17,
            ModelLibraries::OpenCLIP => 18,
            ModelLibraries::OpenVINO => 19,
            ModelLibraries::Onnx => 20,
            ModelLibraries::PaddleNLP => 21,
            ModelLibraries::PaddlePaddle => 22,
            ModelLibraries::PyannoteAudio => 23,
            ModelLibraries::Pythae => 24,
            ModelLibraries::PyTorch => 25,
            ModelLibraries::Rust => 26,
            ModelLibraries::Safetensors => 27,
            ModelLibraries::SampleFactory => 28,
            ModelLibraries::ScikitLearn => 29,
            ModelLibraries::SentenceTransformers => 30,
            ModelLibraries::Spacy => 31,
            ModelLibraries::SpanMarker => 32,
            ModelLibraries::Speechbrain => 33,
            ModelLibraries::StableBaselines3 => 34,
            ModelLibraries::Stanza => 35,
            ModelLibraries::TensorBoard => 36,
            ModelLibraries::TensorFlow => 37,
            ModelLibraries::TensorFlowTTS => 38,
            ModelLibraries::TFLite => 39,
            ModelLibraries::Timm => 40,
            ModelLibraries::Transformers => 41,
        }
    }

    #[test]
    fn test_model_libraries_exhaustiveness() {
        let libraries = vec![
            ModelLibraries::AdapterTransformers,
            ModelLibraries::AllenNLP,
            ModelLibraries::Asteroid,
            ModelLibraries::CoreML,
            ModelLibraries::Diffusers,
            ModelLibraries::ESPnet,
            ModelLibraries::Fairseq,
            ModelLibraries::Fastai,
            ModelLibraries::FastText,
            ModelLibraries::Flair,
            ModelLibraries::Flax,
            ModelLibraries::Graphcore,
            ModelLibraries::Habana,
            ModelLibraries::Jax,
            ModelLibraries::Joblib,
            ModelLibraries::Keras,
            ModelLibraries::MLAgents,
            ModelLibraries::NeMo,
            ModelLibraries::OpenCLIP,
            ModelLibraries::OpenVINO,
            ModelLibraries::Onnx,
            ModelLibraries::PaddleNLP,
            ModelLibraries::PaddlePaddle,
            ModelLibraries::PyannoteAudio,
            ModelLibraries::Pythae,
            ModelLibraries::PyTorch,
            ModelLibraries::Rust,
            ModelLibraries::Safetensors,
            ModelLibraries::SampleFactory,
            ModelLibraries::ScikitLearn,
            ModelLibraries::SentenceTransformers,
            ModelLibraries::Spacy,
            ModelLibraries::SpanMarker,
            ModelLibraries::Speechbrain,
            ModelLibraries::StableBaselines3,
            ModelLibraries::Stanza,
            ModelLibraries::TensorBoard,
            ModelLibraries::TensorFlow,
            ModelLibraries::TensorFlowTTS,
            ModelLibraries::TFLite,
            ModelLibraries::Timm,
            ModelLibraries::Transformers,
        ];
        assert_eq!(libraries.len(), 42);
        // Each list slot must map to its own index. Together with the length
        // check, this proves the list has no duplicates and covers every
        // `library_position` arm.
        for (expected, lib) in libraries.iter().enumerate() {
            assert_eq!(
                library_position(lib),
                expected,
                "{:?} is duplicated or out of order",
                lib
            );
        }
    }
}