use serde_json::Value;
use flodl::nn::GeluApprox;
use flodl::{Result, TensorError};
pub(crate) fn required_i64(v: &Value, key: &str) -> Result<i64> {
v.get(key).and_then(|x| x.as_i64()).ok_or_else(|| {
TensorError::new(&format!(
"config.json missing required integer field: {key}",
))
})
}
pub(crate) fn required_string<'a>(v: &'a Value, key: &str) -> Result<&'a str> {
v.get(key).and_then(|x| x.as_str()).ok_or_else(|| {
TensorError::new(&format!(
"config.json missing required string field: {key}",
))
})
}
pub(crate) fn optional_i64(v: &Value, key: &str, default: i64) -> i64 {
v.get(key).and_then(|x| x.as_i64()).unwrap_or(default)
}
pub(crate) fn optional_i64_or_none(v: &Value, key: &str) -> Option<i64> {
v.get(key).and_then(|x| x.as_i64())
}
pub(crate) fn optional_f64(v: &Value, key: &str, default: f64) -> f64 {
v.get(key).and_then(|x| x.as_f64()).unwrap_or(default)
}
pub(crate) fn optional_bool(v: &Value, key: &str, default: bool) -> bool {
v.get(key).and_then(|x| x.as_bool()).unwrap_or(default)
}
pub(crate) fn parse_id2label(v: &Value) -> Result<Option<Vec<String>>> {
let obj = match v.get("id2label").and_then(|x| x.as_object()) {
Some(obj) => obj,
None => return Ok(None),
};
let mut pairs: Vec<(i64, String)> = Vec::with_capacity(obj.len());
for (k, val) in obj {
let id: i64 = k.parse().map_err(|_| {
TensorError::new(&format!(
"config.json: id2label key {k:?} is not an integer",
))
})?;
let label = val.as_str().ok_or_else(|| {
TensorError::new(&format!(
"config.json: id2label[{k}] is not a string",
))
})?;
pairs.push((id, label.to_string()));
}
pairs.sort_by_key(|(id, _)| *id);
for (idx, (id, _)) in pairs.iter().enumerate() {
if *id != idx as i64 {
return Err(TensorError::new(&format!(
"config.json: id2label must have contiguous ids 0..N, \
but index {idx} has id {id}",
)));
}
}
Ok(Some(pairs.into_iter().map(|(_, s)| s).collect()))
}
fn map_hidden_act(s: &str) -> Result<GeluApprox> {
match s {
"gelu" => Ok(GeluApprox::Exact),
"gelu_new" => Ok(GeluApprox::Tanh),
"gelu_pytorch_tanh" => Ok(GeluApprox::Tanh),
other => Err(TensorError::new(&format!(
"config.json: unsupported hidden_act = {other:?}. \
flodl-hf currently maps {{\"gelu\", \"gelu_new\", \
\"gelu_pytorch_tanh\"}} to flodl::nn::GeluApprox; \
other activations (e.g. \"relu\", \"silu\") are not \
yet wired through the transformer layer. File against \
flodl-hf with the failing checkpoint id."
))),
}
}
pub(crate) fn optional_hidden_act(
v: &Value, key: &str, default: &str,
) -> Result<GeluApprox> {
let raw = v.get(key).and_then(|x| x.as_str()).unwrap_or(default);
map_hidden_act(raw)
}
pub(crate) fn parse_num_labels(v: &Value, id2label: Option<&[String]>) -> Option<i64> {
v.get("num_labels")
.and_then(|x| x.as_i64())
.or_else(|| id2label.map(|v| v.len() as i64))
}
pub(crate) fn parse_architectures(v: &Value) -> Option<Vec<String>> {
let arr = v.get("architectures").and_then(|x| x.as_array())?;
let names: Vec<String> = arr
.iter()
.filter_map(|x| x.as_str().map(|s| s.to_string()))
.collect();
if names.is_empty() { None } else { Some(names) }
}
pub(crate) fn emit_hidden_act(act: GeluApprox) -> &'static str {
match act {
GeluApprox::Exact => "gelu",
GeluApprox::Tanh => "gelu_new",
}
}
pub(crate) fn emit_architectures(stored: Option<&[String]>, default_class: &str) -> Value {
let arr: Vec<Value> = match stored {
Some(names) if !names.is_empty() => {
names.iter().map(|s| Value::from(s.as_str())).collect()
}
_ => vec![Value::from(default_class)],
};
Value::Array(arr)
}
pub(crate) fn emit_id2label(out: &mut serde_json::Map<String, Value>, labels: Option<&[String]>) {
let Some(labels) = labels else { return };
let mut id2 = serde_json::Map::with_capacity(labels.len());
let mut lab2 = serde_json::Map::with_capacity(labels.len());
for (idx, name) in labels.iter().enumerate() {
id2.insert(idx.to_string(), Value::from(name.as_str()));
lab2.insert(name.clone(), Value::from(idx as i64));
}
out.insert("id2label".into(), Value::Object(id2));
out.insert("label2id".into(), Value::Object(lab2));
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn required_i64_reads_or_errors() {
let v: Value = serde_json::from_str(r#"{"a": 42}"#).unwrap();
assert_eq!(required_i64(&v, "a").unwrap(), 42);
assert!(required_i64(&v, "missing").is_err());
}
#[test]
fn required_string_reads_or_errors() {
let v: Value = serde_json::from_str(r#"{"model_type": "bert", "num": 7}"#).unwrap();
assert_eq!(required_string(&v, "model_type").unwrap(), "bert");
assert!(required_string(&v, "missing").is_err());
assert!(
required_string(&v, "num").is_err(),
"non-string values must error"
);
}
#[test]
fn optional_i64_falls_back() {
let v: Value = serde_json::from_str(r#"{"a": 42}"#).unwrap();
assert_eq!(optional_i64(&v, "a", 7), 42);
assert_eq!(optional_i64(&v, "missing", 7), 7);
}
#[test]
fn optional_i64_or_none_treats_absence_as_none() {
let v: Value = serde_json::from_str(r#"{"a": 0, "b": null}"#).unwrap();
assert_eq!(optional_i64_or_none(&v, "a"), Some(0));
assert_eq!(optional_i64_or_none(&v, "b"), None);
assert_eq!(optional_i64_or_none(&v, "missing"), None);
}
#[test]
fn optional_f64_and_bool_defaults() {
let v: Value = serde_json::from_str(r#"{"x": 0.5, "b": true}"#).unwrap();
assert!((optional_f64(&v, "x", 1.0) - 0.5).abs() < 1e-12);
assert!((optional_f64(&v, "missing", 1.0) - 1.0).abs() < 1e-12);
assert!(optional_bool(&v, "b", false));
assert!(!optional_bool(&v, "missing", false));
}
#[test]
fn parse_id2label_orders_and_rejects_gaps() {
let v: Value = serde_json::from_str(
r#"{"id2label": {"2": "c", "0": "a", "1": "b"}}"#,
).unwrap();
let out = parse_id2label(&v).unwrap().unwrap();
assert_eq!(out, vec!["a", "b", "c"]);
let gap: Value = serde_json::from_str(
r#"{"id2label": {"0": "a", "2": "c"}}"#,
).unwrap();
let err = parse_id2label(&gap).unwrap_err();
assert!(format!("{err}").contains("contiguous"), "got: {err}");
}
#[test]
fn parse_id2label_absent_is_none() {
let v: Value = serde_json::from_str(r#"{}"#).unwrap();
assert!(parse_id2label(&v).unwrap().is_none());
}
#[test]
fn parse_num_labels_explicit_wins() {
let v: Value = serde_json::from_str(r#"{"num_labels": 5}"#).unwrap();
let labels = vec!["a".to_string(), "b".to_string()];
assert_eq!(parse_num_labels(&v, Some(&labels)), Some(5));
}
#[test]
fn parse_num_labels_falls_back_to_id2label_len() {
let v: Value = serde_json::from_str(r#"{}"#).unwrap();
let labels = vec!["a".to_string(), "b".to_string(), "c".to_string()];
assert_eq!(parse_num_labels(&v, Some(&labels)), Some(3));
}
#[test]
fn parse_num_labels_none_when_both_missing() {
let v: Value = serde_json::from_str(r#"{}"#).unwrap();
assert_eq!(parse_num_labels(&v, None), None);
}
#[test]
fn emit_hidden_act_maps_both_variants() {
assert_eq!(emit_hidden_act(GeluApprox::Exact), "gelu");
assert_eq!(emit_hidden_act(GeluApprox::Tanh), "gelu_new");
}
#[test]
fn emit_hidden_act_round_trips_through_map() {
for act in [GeluApprox::Exact, GeluApprox::Tanh] {
let parsed = map_hidden_act(emit_hidden_act(act)).unwrap();
assert_eq!(parsed, act);
}
}
#[test]
fn emit_id2label_writes_both_maps() {
let mut m = serde_json::Map::new();
let labels = vec!["A".to_string(), "B".to_string(), "C".to_string()];
emit_id2label(&mut m, Some(&labels));
let id2 = m.get("id2label").unwrap().as_object().unwrap();
assert_eq!(id2.get("0").unwrap().as_str().unwrap(), "A");
assert_eq!(id2.get("1").unwrap().as_str().unwrap(), "B");
assert_eq!(id2.get("2").unwrap().as_str().unwrap(), "C");
let lab2 = m.get("label2id").unwrap().as_object().unwrap();
assert_eq!(lab2.get("A").unwrap().as_i64().unwrap(), 0);
assert_eq!(lab2.get("B").unwrap().as_i64().unwrap(), 1);
assert_eq!(lab2.get("C").unwrap().as_i64().unwrap(), 2);
}
#[test]
fn emit_id2label_skips_when_none() {
let mut m = serde_json::Map::new();
emit_id2label(&mut m, None);
assert!(m.is_empty());
}
#[test]
fn emit_id2label_round_trips_through_parse_id2label() {
let mut m = serde_json::Map::new();
let labels = vec!["LABEL_0".to_string(), "LABEL_1".to_string()];
emit_id2label(&mut m, Some(&labels));
let v = Value::Object(m);
let parsed = parse_id2label(&v).unwrap().unwrap();
assert_eq!(parsed, labels);
}
}