use flodl::nn::Parameter;
use flodl::{Result, TensorError};
/// A dot-separated key path assembled one validated segment at a time.
///
/// Values are created with `HfPath::new` / `HfPath::try_new` and extended with
/// `sub` / `try_sub`. Each of those passes the new segment through
/// `validate_segment`, which rejects empty segments and segments containing
/// `/` or `.`.
// NOTE(review): "Hf" presumably refers to Hugging Face checkpoint key naming — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HfPath {
/// The segments accepted so far, joined with `.`.
path: String,
}
impl HfPath {
    /// Starts a path at `root`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid HfPath root"` when `root` fails
    /// `validate_segment` (empty, or contains `/` or `.`). Use
    /// [`HfPath::try_new`] to get the error back instead.
    pub fn new(root: impl Into<String>) -> Self {
        Self::try_new(root).expect("invalid HfPath root")
    }

    /// Fallible counterpart of [`HfPath::new`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_segment` when `root` is empty
    /// or contains `/` or `.`.
    pub fn try_new(root: impl Into<String>) -> Result<Self> {
        let path = root.into();
        validate_segment(&path)?;
        Ok(Self { path })
    }

    /// Returns a new path with `segment` appended after a `.`; `self` is left
    /// untouched.
    ///
    /// `segment` is converted with `ToString`, so integers can be passed
    /// directly.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid HfPath segment"` when the stringified segment
    /// fails `validate_segment`.
    pub fn sub<S: ToString>(&self, segment: S) -> Self {
        self.try_sub(segment).expect("invalid HfPath segment")
    }

    /// Fallible counterpart of [`HfPath::sub`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_segment` when the stringified
    /// segment is empty or contains `/` or `.`.
    pub fn try_sub<S: ToString>(&self, segment: S) -> Result<Self> {
        // Appending a segment and producing a leaf key build the same string;
        // only the return type differs.
        let path = self.try_leaf(&segment.to_string())?;
        Ok(Self { path })
    }

    /// Returns the dotted key `"<path>.<name>"` as an owned `String`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid HfPath leaf"` when `name` fails
    /// `validate_segment`.
    pub fn leaf(&self, name: &str) -> String {
        self.try_leaf(name).expect("invalid HfPath leaf")
    }

    /// Fallible counterpart of [`HfPath::leaf`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_segment` when `name` is empty
    /// or contains `/` or `.`.
    pub fn try_leaf(&self, name: &str) -> Result<String> {
        validate_segment(name)?;
        Ok(format!("{}.{name}", self.path))
    }

    /// Borrows the dotted path built so far.
    pub fn as_str(&self) -> &str {
        self.path.as_str()
    }
}
/// Displays the path as its dotted string, the same text `as_str` returns.
impl std::fmt::Display for HfPath {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Call the trait method by its full path: method-call syntax
        // (`self.path.fmt(f)`) only resolves when `Display` is imported, and
        // becomes ambiguous if `Debug` is imported too. Delegating to the
        // `String` implementation keeps width / alignment / precision flags
        // working.
        std::fmt::Display::fmt(&self.path, f)
    }
}
/// Prepends `"{prefix}."` to the name of every parameter.
///
/// Each `variable` is moved over unchanged and the input order is kept. The
/// prefix is not validated, so it is joined verbatim with a single `.`.
pub fn prefix_params(prefix: &str, params: Vec<Parameter>) -> Vec<Parameter> {
    let mut renamed = Vec::with_capacity(params.len());
    for Parameter { variable, name } in params {
        renamed.push(Parameter {
            variable,
            name: format!("{prefix}.{name}"),
        });
    }
    renamed
}
/// Converts a flodl key into its dotted form by turning the **last** `/` into
/// a `.`.
///
/// Earlier slashes are left as they are, and a key without any `/` is
/// returned unchanged (as an owned copy).
pub fn hf_key_from_flodl_key(flodl_key: &str) -> String {
    let mut hf_key = flodl_key.to_owned();
    if let Some(slash_at) = flodl_key.rfind('/') {
        // '/' is a single byte, so the one-byte range is a whole character.
        hf_key.replace_range(slash_at..=slash_at, ".");
    }
    hf_key
}
/// Checks that `seg` can be used as a single path segment.
///
/// A segment is rejected when it is empty, contains `/`, or contains `.`.
/// The checks run in that order, so a segment containing both separators is
/// reported for the `/`.
///
/// # Errors
///
/// Returns a `TensorError` whose message names the violated rule.
fn validate_segment(seg: &str) -> Result<()> {
    // Pick the first rule that is violated, if any.
    let complaint = if seg.is_empty() {
        Some(String::from("HfPath segment must not be empty"))
    } else if seg.contains('/') {
        Some(format!("HfPath segment {seg:?} must not contain '/'"))
    } else if seg.contains('.') {
        Some(format!(
            "HfPath segment {seg:?} must not contain '.' (use .sub() to add segments)"
        ))
    } else {
        None
    };

    match complaint {
        Some(message) => Err(TensorError::new(&message)),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- building paths -------------------------------------------------

    #[test]
    fn root_and_sub_compose_dotted() {
        let root = HfPath::new("bert");
        let embeddings = root.sub("embeddings");
        assert_eq!(embeddings.as_str(), "bert.embeddings");
    }

    #[test]
    fn leaf_appends_final_segment() {
        let layer0 = HfPath::new("bert").sub("encoder").sub("layer").sub(0);
        let key = layer0.leaf("attention");
        assert_eq!(key, "bert.encoder.layer.0.attention");
    }

    #[test]
    fn integer_segments_via_to_string() {
        let layers = HfPath::new("bert").sub("layer");
        (0..3).for_each(|idx| {
            let expected = format!("bert.layer.{idx}");
            assert_eq!(layers.sub(idx).as_str(), expected);
        });
    }

    #[test]
    fn sub_is_immutable_returns_new() {
        let root = HfPath::new("bert");
        let (a, b) = (root.sub("a"), root.sub("b"));
        // The parent must be unchanged after deriving two children from it.
        assert_eq!(
            [root.as_str(), a.as_str(), b.as_str()],
            ["bert", "bert.a", "bert.b"],
        );
    }

    #[test]
    fn full_bert_self_attention_path() {
        let layer0 = HfPath::new("bert").sub("encoder").sub("layer").sub(0);
        let self_attention = layer0.sub("attention").sub("self");
        assert_eq!(
            self_attention.leaf("query"),
            "bert.encoder.layer.0.attention.self.query",
        );
    }

    // --- validation -----------------------------------------------------

    #[test]
    fn try_new_rejects_empty_dot_slash() {
        for bad in ["", "a.b", "a/b"] {
            assert!(
                HfPath::try_new(bad).is_err(),
                "root {bad:?} should be rejected"
            );
        }
    }

    #[test]
    fn try_sub_rejects_invalid_segments() {
        let root = HfPath::new("bert");
        for bad in ["", "foo.bar", "foo/bar"] {
            assert!(
                root.try_sub(bad).is_err(),
                "segment {bad:?} should be rejected"
            );
        }
    }

    #[test]
    #[should_panic(expected = "invalid HfPath root")]
    fn new_panics_on_empty_root() {
        let _rejected = HfPath::new("");
    }

    #[test]
    fn try_leaf_rejects_dots_in_name() {
        let rejected = HfPath::new("bert").try_leaf("foo.bar");
        assert!(rejected.is_err());
    }

    // --- flodl key -> dotted key ----------------------------------------

    #[test]
    fn hf_key_conversion_swaps_last_slash() {
        let cases = [
            (
                "bert.embeddings.word_embeddings/weight",
                "bert.embeddings.word_embeddings.weight",
            ),
            ("bert.pooler.dense/bias", "bert.pooler.dense.bias"),
        ];
        for (flodl_key, expected) in cases {
            assert_eq!(hf_key_from_flodl_key(flodl_key), expected);
        }
    }

    #[test]
    fn hf_key_conversion_only_last_slash() {
        let converted = hf_key_from_flodl_key("a/b/c");
        assert_eq!(converted, "a/b.c");
    }

    #[test]
    fn hf_key_conversion_no_slash_passthrough() {
        let converted = hf_key_from_flodl_key("plain_name");
        assert_eq!(converted, "plain_name");
    }

    // --- Display --------------------------------------------------------

    #[test]
    fn display_matches_as_str() {
        let encoder = HfPath::new("bert").sub("encoder");
        assert_eq!(encoder.to_string(), "bert.encoder");
    }
}