use serde_json::Value;
use flodl::{Result, TensorError};
/// Reads a required integer field from a `config.json` object.
///
/// # Errors
///
/// Returns a [`TensorError`] when `key` is absent or `null`, or when it is
/// present but its value is not representable as an `i64` (a string, a float,
/// or an integer outside the `i64` range). The two situations produce distinct
/// messages so a mistyped value is not reported as a missing field.
pub(crate) fn required_i64(v: &Value, key: &str) -> Result<i64> {
    // `null` is treated like absence: it carries no usable value and
    // historically produced the "missing" message.
    match v.get(key).filter(|x| !x.is_null()) {
        None => Err(TensorError::new(&format!(
            "config.json missing required integer field: {key}",
        ))),
        Some(raw) => raw.as_i64().ok_or_else(|| {
            TensorError::new(&format!(
                "config.json field {key} must be an i64 integer, got {raw}",
            ))
        }),
    }
}
/// Reads a required string field from a `config.json` object, borrowing the
/// text from `v`.
///
/// # Errors
///
/// Returns a [`TensorError`] when `key` is absent or `null`, or when it is
/// present but not a JSON string. The two situations produce distinct messages
/// so a mistyped value is not reported as a missing field.
pub(crate) fn required_string<'a>(v: &'a Value, key: &str) -> Result<&'a str> {
    // `null` is treated like absence: it carries no usable value and
    // historically produced the "missing" message.
    match v.get(key).filter(|x| !x.is_null()) {
        None => Err(TensorError::new(&format!(
            "config.json missing required string field: {key}",
        ))),
        Some(raw) => raw.as_str().ok_or_else(|| {
            TensorError::new(&format!(
                "config.json field {key} must be a string, got {raw}",
            ))
        }),
    }
}
/// Reads an optional integer field.
///
/// Falls back to `default` when `key` is absent or its value is not
/// representable as an `i64`.
pub(crate) fn optional_i64(v: &Value, key: &str, default: i64) -> i64 {
    match v.get(key).and_then(|field| field.as_i64()) {
        Some(n) => n,
        None => default,
    }
}
/// Reads an optional integer field without a default.
///
/// Yields `None` when `key` is absent or its value (including `null`) is not
/// representable as an `i64`.
pub(crate) fn optional_i64_or_none(v: &Value, key: &str) -> Option<i64> {
    let field = v.get(key)?;
    field.as_i64()
}
/// Reads an optional floating-point field.
///
/// Falls back to `default` when `key` is absent or its value is not numeric.
pub(crate) fn optional_f64(v: &Value, key: &str, default: f64) -> f64 {
    if let Some(x) = v.get(key).and_then(|field| field.as_f64()) {
        x
    } else {
        default
    }
}
/// Reads an optional boolean field.
///
/// Falls back to `default` when `key` is absent or its value is not a JSON
/// boolean.
pub(crate) fn optional_bool(v: &Value, key: &str, default: bool) -> bool {
    let parsed = v.get(key).and_then(|field| field.as_bool());
    match parsed {
        Some(flag) => flag,
        None => default,
    }
}
/// Parses the optional `id2label` mapping into a dense label list ordered by id.
///
/// Returns `Ok(None)` when `id2label` is absent or is not a JSON object.
/// Otherwise every key must parse as an `i64` and every value must be a JSON
/// string. Once sorted, the ids must be exactly `0, 1, …, N-1`. The labels are
/// returned in that id order.
///
/// # Errors
///
/// Returns a [`TensorError`] in any of these cases:
/// - a key is not an integer;
/// - a value is not a string;
/// - the ids are not contiguous from zero. This covers gaps, negative ids, and
///   distinct keys that parse to the same id (e.g. `"0"` and `"00"`).
pub(crate) fn parse_id2label(v: &Value) -> Result<Option<Vec<String>>> {
    let table = match v.get("id2label").and_then(|x| x.as_object()) {
        None => return Ok(None),
        Some(table) => table,
    };

    // Validate each entry in map iteration order; the first bad entry aborts.
    let mut entries = table
        .iter()
        .map(|(k, label)| -> Result<(i64, String)> {
            let id = k.parse::<i64>().map_err(|_| {
                TensorError::new(&format!(
                    "config.json: id2label key {k:?} is not an integer",
                ))
            })?;
            let text = label.as_str().ok_or_else(|| {
                TensorError::new(&format!(
                    "config.json: id2label[{k}] is not a string",
                ))
            })?;
            Ok((id, text.to_string()))
        })
        .collect::<Result<Vec<(i64, String)>>>()?;

    entries.sort_by_key(|&(id, _)| id);

    // After sorting, position i must hold id i.
    let first_mismatch = entries
        .iter()
        .enumerate()
        .find(|(idx, (id, _))| *id != *idx as i64);
    if let Some((idx, (id, _))) = first_mismatch {
        return Err(TensorError::new(&format!(
            "config.json: id2label must have contiguous ids 0..N, \
             but index {idx} has id {id}",
        )));
    }

    let labels = entries.into_iter().map(|(_, label)| label).collect();
    Ok(Some(labels))
}
/// Resolves the classifier label count.
///
/// An integer `num_labels` field takes precedence. Otherwise the length of
/// the already-parsed `id2label` list is used. Returns `None` when neither
/// source is available.
pub(crate) fn parse_num_labels(v: &Value, id2label: Option<&[String]>) -> Option<i64> {
    if let Some(explicit) = v.get("num_labels").and_then(|x| x.as_i64()) {
        return Some(explicit);
    }
    id2label.map(|labels| labels.len() as i64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn required_i64_reads_or_errors() {
        let v: Value = serde_json::from_str(r#"{"a": 42}"#).unwrap();
        assert_eq!(required_i64(&v, "a").unwrap(), 42);
        assert!(required_i64(&v, "missing").is_err());
    }

    #[test]
    fn required_i64_rejects_non_integer_values() {
        let v: Value =
            serde_json::from_str(r#"{"s": "42", "f": 4.5, "n": null}"#).unwrap();
        assert!(required_i64(&v, "s").is_err(), "strings must not coerce");
        assert!(required_i64(&v, "f").is_err(), "non-integral floats must error");
        assert!(required_i64(&v, "n").is_err(), "null must error");
    }

    #[test]
    fn required_string_reads_or_errors() {
        let v: Value = serde_json::from_str(r#"{"model_type": "bert", "num": 7}"#).unwrap();
        assert_eq!(required_string(&v, "model_type").unwrap(), "bert");
        assert!(required_string(&v, "missing").is_err());
        assert!(
            required_string(&v, "num").is_err(),
            "non-string values must error"
        );
    }

    #[test]
    fn optional_i64_falls_back() {
        let v: Value = serde_json::from_str(r#"{"a": 42}"#).unwrap();
        assert_eq!(optional_i64(&v, "a", 7), 42);
        assert_eq!(optional_i64(&v, "missing", 7), 7);
    }

    #[test]
    fn optional_i64_or_none_treats_absence_as_none() {
        let v: Value = serde_json::from_str(r#"{"a": 0, "b": null}"#).unwrap();
        assert_eq!(optional_i64_or_none(&v, "a"), Some(0));
        assert_eq!(optional_i64_or_none(&v, "b"), None);
        assert_eq!(optional_i64_or_none(&v, "missing"), None);
    }

    #[test]
    fn optional_f64_and_bool_defaults() {
        let v: Value = serde_json::from_str(r#"{"x": 0.5, "b": true}"#).unwrap();
        assert!((optional_f64(&v, "x", 1.0) - 0.5).abs() < 1e-12);
        assert!((optional_f64(&v, "missing", 1.0) - 1.0).abs() < 1e-12);
        assert!(optional_bool(&v, "b", false));
        assert!(!optional_bool(&v, "missing", false));
    }

    #[test]
    fn optional_f64_accepts_integer_json() {
        // Integer literals such as `"eps": 2` must still be read as floats.
        let v: Value = serde_json::from_str(r#"{"x": 2}"#).unwrap();
        assert!((optional_f64(&v, "x", 0.0) - 2.0).abs() < 1e-12);
    }

    #[test]
    fn parse_id2label_orders_and_rejects_gaps() {
        let v: Value = serde_json::from_str(
            r#"{"id2label": {"2": "c", "0": "a", "1": "b"}}"#,
        ).unwrap();
        let out = parse_id2label(&v).unwrap().unwrap();
        assert_eq!(out, vec!["a", "b", "c"]);
        let gap: Value = serde_json::from_str(
            r#"{"id2label": {"0": "a", "2": "c"}}"#,
        ).unwrap();
        let err = parse_id2label(&gap).unwrap_err();
        assert!(format!("{err}").contains("contiguous"), "got: {err}");
    }

    #[test]
    fn parse_id2label_rejects_non_integer_key() {
        let v: Value = serde_json::from_str(r#"{"id2label": {"zero": "a"}}"#).unwrap();
        let err = parse_id2label(&v).unwrap_err();
        assert!(format!("{err}").contains("is not an integer"), "got: {err}");
    }

    #[test]
    fn parse_id2label_rejects_non_string_label() {
        let v: Value = serde_json::from_str(r#"{"id2label": {"0": 1}}"#).unwrap();
        let err = parse_id2label(&v).unwrap_err();
        assert!(format!("{err}").contains("is not a string"), "got: {err}");
    }

    #[test]
    fn parse_id2label_rejects_duplicate_and_negative_ids() {
        // "0" and "00" are distinct JSON keys but both parse to id 0.
        let dup: Value =
            serde_json::from_str(r#"{"id2label": {"0": "a", "00": "b"}}"#).unwrap();
        let err = parse_id2label(&dup).unwrap_err();
        assert!(format!("{err}").contains("contiguous"), "got: {err}");

        let neg: Value =
            serde_json::from_str(r#"{"id2label": {"-1": "x", "0": "a"}}"#).unwrap();
        let err = parse_id2label(&neg).unwrap_err();
        assert!(format!("{err}").contains("contiguous"), "got: {err}");
    }

    #[test]
    fn parse_id2label_absent_is_none() {
        let v: Value = serde_json::from_str(r#"{}"#).unwrap();
        assert!(parse_id2label(&v).unwrap().is_none());
    }

    #[test]
    fn parse_id2label_non_object_is_none() {
        let v: Value = serde_json::from_str(r#"{"id2label": null}"#).unwrap();
        assert!(parse_id2label(&v).unwrap().is_none());
    }

    #[test]
    fn parse_num_labels_explicit_wins() {
        let v: Value = serde_json::from_str(r#"{"num_labels": 5}"#).unwrap();
        let labels = vec!["a".to_string(), "b".to_string()];
        assert_eq!(parse_num_labels(&v, Some(&labels)), Some(5));
    }

    #[test]
    fn parse_num_labels_falls_back_to_id2label_len() {
        let v: Value = serde_json::from_str(r#"{}"#).unwrap();
        let labels = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        assert_eq!(parse_num_labels(&v, Some(&labels)), Some(3));
    }

    #[test]
    fn parse_num_labels_null_falls_back_to_id2label_len() {
        let v: Value = serde_json::from_str(r#"{"num_labels": null}"#).unwrap();
        let labels = vec!["a".to_string(), "b".to_string()];
        assert_eq!(parse_num_labels(&v, Some(&labels)), Some(2));
    }

    #[test]
    fn parse_num_labels_none_when_both_missing() {
        let v: Value = serde_json::from_str(r#"{}"#).unwrap();
        assert_eq!(parse_num_labels(&v, None), None);
    }
}