/// The model identity that embeddings are produced with, as resolved by
/// `derive` / `derive_quiet` from either the active model or a local fallback.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmbedderIdentity {
    /// Model name (e.g. `"voyage-context-3"`).
    pub name: String,
    /// Model dimension — presumably the embedding vector length; confirm with callers.
    pub dim: u32,
    /// Element type label (e.g. `"float"`, `"int8"`).
    pub dtype: String,
}
/// Model identity obtained from the active-model fetch.
///
/// When one is supplied to `derive` / `derive_quiet`, it takes precedence
/// over the local-config `FallbackIdentity`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ActiveModelIdentity {
    /// Model name.
    pub name: String,
    /// Model dimension — presumably the embedding vector length; confirm with callers.
    pub dim: u32,
    /// Element type label.
    pub dtype: String,
}
/// Borrowed model identity taken from local configuration.
///
/// It is used by `derive` / `derive_quiet` only when no active model
/// identity is available.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FallbackIdentity<'a> {
    /// Model name from local config.
    pub name: &'a str,
    /// Model dimension from local config.
    pub dim: u32,
    /// Element type label from local config.
    pub dtype: &'a str,
}
/// Resolves the embedder identity and warns when it has to fall back.
///
/// If `active` is `Some`, the returned identity is a copy of it. Otherwise
/// the identity is copied from `fallback`, and a `tracing` warning tagged
/// with `which` is emitted. Use `derive_quiet` to suppress that warning.
#[must_use]
pub fn derive(
    which: &str,
    active: Option<&ActiveModelIdentity>,
    fallback: &FallbackIdentity<'_>,
) -> EmbedderIdentity {
    let warn_on_fallback = true;
    derive_inner(which, active, fallback, warn_on_fallback)
}
/// Same resolution as `derive`, but without logging.
///
/// Returns a copy of `active` when present, otherwise a copy of `fallback`.
/// No warning is emitted in the fallback case.
#[must_use]
pub fn derive_quiet(
    which: &str,
    active: Option<&ActiveModelIdentity>,
    fallback: &FallbackIdentity<'_>,
) -> EmbedderIdentity {
    let warn_on_fallback = false;
    derive_inner(which, active, fallback, warn_on_fallback)
}
/// Shared implementation behind `derive` and `derive_quiet`.
///
/// * `which` — label attached to the fallback warning; it is not used otherwise.
/// * `active` — preferred identity; when `Some`, its fields are copied verbatim.
/// * `fallback` — local-config identity, copied when `active` is `None`.
/// * `warn_on_fallback` — when true, the fallback path emits a `tracing::warn!`.
fn derive_inner(
    which: &str,
    active: Option<&ActiveModelIdentity>,
    fallback: &FallbackIdentity<'_>,
    warn_on_fallback: bool,
) -> EmbedderIdentity {
    match active {
        Some(model) => EmbedderIdentity {
            name: model.name.to_string(),
            dim: model.dim,
            dtype: model.dtype.to_string(),
        },
        None => {
            if warn_on_fallback {
                // The fallback may not match the model the stored corpus was
                // embedded with, so surface it loudly.
                tracing::warn!(
                    which,
                    fallback_model = fallback.name,
                    fallback_dim = fallback.dim,
                    fallback_dtype = fallback.dtype,
                    "active-model fetch unavailable; embedding with local config as a fallback \
                     (vectors are NOT guaranteed to match the corpus model)"
                );
            }
            EmbedderIdentity {
                name: String::from(fallback.name),
                dim: fallback.dim,
                dtype: String::from(fallback.dtype),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Local-config identity that deliberately differs from `active()` in every field.
    fn fallback() -> FallbackIdentity<'static> {
        FallbackIdentity {
            name: "config-model",
            dim: 256,
            dtype: "int8",
        }
    }

    /// Active-model identity used across tests.
    fn active() -> ActiveModelIdentity {
        ActiveModelIdentity {
            name: "voyage-context-3".to_owned(),
            dim: 1024,
            dtype: "float".to_owned(),
        }
    }

    fn expected_active() -> EmbedderIdentity {
        EmbedderIdentity {
            name: "voyage-context-3".to_owned(),
            dim: 1024,
            dtype: "float".to_owned(),
        }
    }

    fn expected_fallback() -> EmbedderIdentity {
        EmbedderIdentity {
            name: "config-model".to_owned(),
            dim: 256,
            dtype: "int8".to_owned(),
        }
    }

    #[test]
    fn fixtures_diverge_in_every_field() {
        // Guard: the preference tests are only meaningful if the fixtures differ.
        let (a, f) = (active(), fallback());
        assert_ne!(a.name, f.name);
        assert_ne!(a.dim, f.dim);
        assert_ne!(a.dtype, f.dtype);
    }

    #[test]
    fn derive_prefers_active_over_divergent_config() {
        let active = active();
        let id = derive("general", Some(&active), &fallback());
        assert_eq!(id, expected_active());
    }

    #[test]
    fn derive_falls_back_to_config_when_active_absent() {
        let id = derive("code", None, &fallback());
        assert_eq!(id, expected_fallback());
    }

    #[test]
    fn derive_quiet_prefers_active_over_divergent_config() {
        let active = active();
        let id = derive_quiet("general", Some(&active), &fallback());
        assert_eq!(id, expected_active());
    }

    #[test]
    fn derive_quiet_falls_back_to_config_when_active_absent() {
        let id = derive_quiet("code", None, &fallback());
        assert_eq!(id, expected_fallback());
    }

    #[test]
    fn derive_and_derive_quiet_agree() {
        // The two entry points differ only in logging, never in the result.
        let active = active();
        assert_eq!(
            derive("general", Some(&active), &fallback()),
            derive_quiet("general", Some(&active), &fallback())
        );
        assert_eq!(
            derive("code", None, &fallback()),
            derive_quiet("code", None, &fallback())
        );
    }
}