dynamo_llm/gguf/
gguf_metadata.rs1use akin::akin;
29use anyhow::Result;
30use anyhow::ensure;
31use candle_core::quantized::gguf_file;
32use std::collections::HashMap;
33use tracing::warn;
34
35use crate::gguf::Content;
36
/// Read-only accessors for a model's architectural dimensions.
pub trait ModelConfigLike {
    /// Maximum sequence (context) length.
    fn max_seq_len(&self) -> usize;

    /// Number of transformer layers (blocks).
    fn num_layers(&self) -> usize;

    /// Hidden (embedding) size.
    fn hidden_size(&self) -> usize;

    /// Number of attention (query) heads.
    fn num_attn_heads(&self) -> usize;

    /// Number of key/value heads.
    fn num_kv_heads(&self) -> usize;

    /// Per-head key dimension.
    fn k_head_dim(&self) -> usize;

    /// Per-head value dimension.
    fn v_head_dim(&self) -> usize;
}
46
/// Model dimensions read from GGUF metadata.
///
/// Each field mirrors one `{arch}.*` metadata key read by the
/// `From<&Content>` conversion. The two head-length fields are `None` when the
/// corresponding key is absent.
#[derive(Debug)]
// NOTE(review): every field is read by the `ModelConfigLike` impl in this file;
// confirm whether this allow is still needed.
#[allow(dead_code)]
pub struct ContentConfig {
    /// From `{arch}.context_length`.
    max_seq_len: usize,
    /// From `{arch}.embedding_length`.
    hidden_size: usize,
    /// From `{arch}.attention.head_count`.
    num_attn_heads: usize,
    /// From `{arch}.attention.head_count_kv`.
    num_kv_heads: usize,
    /// From `{arch}.block_count`.
    num_layers: usize,
    /// From `{arch}.attention.key_length`, when present.
    key_length: Option<usize>,
    /// From `{arch}.attention.value_length`, when present.
    value_length: Option<usize>,
}
58
59#[allow(clippy::cast_possible_truncation)]
60impl From<&Content> for ContentConfig {
61 fn from(value: &Content) -> Self {
62 let metadata = value.get_metadata();
63 let arch = metadata["general.architecture"].to_string().unwrap();
64 Self {
65 max_seq_len: metadata[&format!("{arch}.context_length")]
66 .to_u64()
67 .unwrap() as usize,
68 hidden_size: metadata[&format!("{arch}.embedding_length")]
69 .to_u64()
70 .unwrap() as usize,
71 num_attn_heads: metadata[&format!("{arch}.attention.head_count")]
72 .to_u64()
73 .unwrap() as usize,
74 num_kv_heads: metadata[&format!("{arch}.attention.head_count_kv")]
75 .to_u64()
76 .unwrap() as usize,
77 num_layers: metadata[&format!("{arch}.block_count")].to_u64().unwrap() as usize,
78 key_length: metadata
79 .get(&format!("{arch}.attention.key_length"))
80 .map(|x| x.to_u64().unwrap() as usize),
81 value_length: metadata
82 .get(&format!("{arch}.attention.value_length"))
83 .map(|x| x.to_u64().unwrap() as usize),
84 }
85 }
86}
87
88impl ModelConfigLike for ContentConfig {
89 fn max_seq_len(&self) -> usize {
90 self.max_seq_len
91 }
92 fn hidden_size(&self) -> usize {
93 self.hidden_size
94 }
95 fn num_attn_heads(&self) -> usize {
96 self.num_attn_heads
97 }
98 fn num_kv_heads(&self) -> usize {
99 self.num_kv_heads
100 }
101 fn num_layers(&self) -> usize {
102 self.num_layers
103 }
104 fn k_head_dim(&self) -> usize {
105 self.key_length
106 .unwrap_or(self.hidden_size / self.num_attn_heads)
107 }
108 fn v_head_dim(&self) -> usize {
109 self.value_length
110 .unwrap_or(self.hidden_size / self.num_attn_heads)
111 }
112}
113
114pub struct ContentMetadata<'a> {
115 pub path_prefix: &'a str,
116 pub metadata: &'a HashMap<String, gguf_file::Value>,
117}
118
119impl ContentMetadata<'_> {
120 pub fn get_value<T: TryFromValue>(&self, field_name: &str) -> Result<T, anyhow::Error> {
122 let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
123 let value = self.metadata.get(&prop_key).cloned();
124
125 value
128 .try_value_into()
129 .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
130 }
131
132 pub fn has_required_keys(&self, fields: &[&str]) -> Result<()> {
134 let mut all_props_are_present = true;
135
136 for field_name in fields {
137 let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
138
139 if !self.metadata.contains_key(&prop_key) {
140 all_props_are_present = false;
141 warn!("Expected GGUF metadata to have key: `{prop_key}`");
142 }
143 }
144
145 ensure!(all_props_are_present, "Tokenizer is missing required props");
146 Ok(())
147 }
148}
149
150pub trait TryFromValue {
153 fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error>
154 where
155 Self: Sized;
156}
157
// One `TryFromValue` impl per supported scalar type, plus `String`.
//
// `akin!` expands the `impl` template once for each entry of `types`. It
// substitutes `*types` with that entry and `*to_type` with the expression at
// the same position in `to_type`. Each expression calls the matching
// `gguf_file::Value` accessor on the `value` parameter of the generated
// `try_from_value`. For the `String` entry, `.cloned()` turns the
// accessor's borrowed result into an owned value.
//
// The accessor's own error is discarded (`|_|`) and replaced with a
// "value is not a ..." message.
akin! {
    let &types = [String, bool, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64];
    let &to_type = [
        value.to_string().cloned(),
        value.to_bool(),
        value.to_f32(),
        value.to_f64(),
        value.to_i8(),
        value.to_i16(),
        value.to_i32(),
        value.to_i64(),
        value.to_u8(),
        value.to_u16(),
        value.to_u32(),
        value.to_u64(),
    ];

    impl TryFromValue for *types {
        fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error> {
            *to_type.or_else(|_| candle_core::bail!("value is not a `*types`"))
        }
    }
}
184
185impl<T: TryFromValue> TryFromValue for Vec<T> {
187 fn try_from_value(value_vec: gguf_file::Value) -> Result<Self, candle_core::Error> {
188 value_vec
189 .to_vec()
190 .or_else(|_| candle_core::bail!("value is not a `Vec`"))?
191 .clone()
192 .into_iter()
193 .map(|item| T::try_from_value(item))
194 .collect()
195 }
196}
197
198pub trait TryValueInto<T>: Sized {
199 fn try_value_into(self) -> Result<T, candle_core::Error>;
200}
201
202impl<T: TryFromValue> TryValueInto<T> for gguf_file::Value {
203 fn try_value_into(self) -> Result<T, candle_core::Error> {
204 T::try_from_value(self)
205 }
206}
207
208impl<T: TryFromValue> TryValueInto<T> for Option<gguf_file::Value> {
209 fn try_value_into(self) -> Result<T, candle_core::Error> {
210 match self {
211 Some(value) => value.try_value_into(),
212 None => candle_core::bail!("Expected `Option<gguf_file::Value>` to contain a value"),
213 }
214 }
215}