entrenar/hf_pipeline/fetcher/
types.rs1use serde::{Deserialize, Serialize};
6use std::path::PathBuf;
7
/// Serialization format of a model weight file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WeightFormat {
    /// Weights stored in a `.safetensors` file.
    SafeTensors,
    /// Weights stored in a `.gguf` file.
    GGUF {
        /// Quantization label. `from_filename` fills in `"unknown"` because it
        /// looks only at the file name, never at the file contents.
        quant_type: String,
    },
    /// Weights stored in a `.bin` file; the only variant `is_safe` reports as
    /// unsafe.
    PyTorchBin,
    /// Weights stored in a `.onnx` file.
    ONNX,
}
20
21impl WeightFormat {
22 #[must_use]
24 pub fn from_filename(filename: &str) -> Option<Self> {
25 if filename.ends_with(".safetensors") {
26 Some(Self::SafeTensors)
27 } else if filename.ends_with(".gguf") {
28 Some(Self::GGUF { quant_type: "unknown".into() })
29 } else if filename.ends_with(".bin") {
30 Some(Self::PyTorchBin)
31 } else if filename.ends_with(".onnx") {
32 Some(Self::ONNX)
33 } else {
34 None
35 }
36 }
37
38 #[must_use]
40 pub fn is_safe(&self) -> bool {
41 matches!(self, Self::SafeTensors | Self::GGUF { .. } | Self::ONNX)
42 }
43}
44
45#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
49pub enum Architecture {
50 BERT { num_layers: usize, hidden_size: usize, num_attention_heads: usize },
52 GPT2 { num_layers: usize, hidden_size: usize, num_attention_heads: usize },
54 Llama {
56 num_layers: usize,
57 hidden_size: usize,
58 num_attention_heads: usize,
59 intermediate_size: usize,
60 },
61 T5 { encoder_layers: usize, decoder_layers: usize, hidden_size: usize },
63 Custom { config: serde_json::Value },
65}
66
67impl Architecture {
68 #[must_use]
70 pub fn param_count(&self) -> u64 {
71 match self {
72 Self::BERT { num_layers, hidden_size, num_attention_heads: _ } => {
73 let per_layer = 4 * (*hidden_size as u64).pow(2) + 4 * (*hidden_size as u64).pow(2);
75 per_layer * (*num_layers as u64)
76 }
77 Self::GPT2 { num_layers, hidden_size, .. } => {
78 let per_layer = 4 * (*hidden_size as u64).pow(2) + 4 * (*hidden_size as u64).pow(2);
79 per_layer * (*num_layers as u64)
80 }
81 Self::Llama { num_layers, hidden_size, intermediate_size, .. } => {
82 let attn = 4 * (*hidden_size as u64).pow(2);
83 let ffn = 2 * (*hidden_size as u64) * (*intermediate_size as u64);
84 (attn + ffn) * (*num_layers as u64)
85 }
86 Self::T5 { encoder_layers, decoder_layers, hidden_size } => {
87 let per_layer = 8 * (*hidden_size as u64).pow(2);
88 per_layer * ((*encoder_layers + *decoder_layers) as u64)
89 }
90 Self::Custom { .. } => 0, }
92 }
93}
94
95#[derive(Debug)]
97pub struct ModelArtifact {
98 pub path: PathBuf,
100 pub format: WeightFormat,
102 pub architecture: Option<Architecture>,
104 pub sha256: Option<String>,
106}