use std::collections::BTreeMap;
use uni_common::core::schema::{IndexDefinition, Schema};
use uni_xervo::api::ModelTask;
use uni_xervo::traits::HeadSet;
/// Returns the text-embedding heads a model configured for `task` can produce.
///
/// * `EmbedHybrid` produces every head (`HeadSet::ALL`).
/// * `Embed`, `EmbedSparse` and `EmbedMultiVector` each produce exactly one
///   head: dense, sparse and multi-vector respectively.
/// * Any other task yields an empty set. This includes non-text embedders
///   such as image, audio and multimodal, and all non-embedding tasks.
pub fn text_embedding_heads(task: ModelTask) -> HeadSet {
    match task {
        ModelTask::EmbedHybrid => HeadSet::ALL,
        ModelTask::Embed => HeadSet::DENSE,
        ModelTask::EmbedSparse => HeadSet::SPARSE,
        ModelTask::EmbedMultiVector => HeadSet::MULTI_VECTOR,
        // A catch-all arm is kept on purpose: `ModelTask` lives in another
        // crate, and unknown/new tasks must default to "no text heads".
        _ => HeadSet::empty(),
    }
}
/// Reports whether `label.property` is declared in `schema` as a list of
/// vectors (`List(Vector { .. })`), i.e. a multi-vector column.
///
/// Returns `false` when the label is unknown, the property is not declared
/// on it, or the property has any other data type.
pub(crate) fn is_multivector_property(schema: &Schema, label: &str, property: &str) -> bool {
    let Some(meta) = schema
        .properties
        .get(label)
        .and_then(|label_props| label_props.get(property))
    else {
        return false;
    };
    match &meta.r#type {
        uni_common::DataType::List(inner) => {
            matches!(**inner, uni_common::DataType::Vector { .. })
        }
        _ => false,
    }
}
/// The embedding heads a single embedding alias must serve, together with
/// the indexed properties that need them.
#[derive(Debug, Clone)]
pub struct RequiredHeads {
    /// Combined set of heads this alias must produce. As populated by
    /// `required_embed_heads`, this is the union of every head in `columns`.
    pub heads: HeadSet,
    /// One `(property name, head)` pair per contributing index, in the order
    /// the indexes appear in the schema.
    ///
    /// Entries are not de-duplicated, so the same property name can appear
    /// more than once. For example, the same property may be indexed on
    /// different labels, or by both a vector and a sparse index.
    pub columns: Vec<(String, HeadSet)>,
}
/// Groups the schema's embedding-backed indexes by embedding alias and
/// records which heads each alias must produce.
///
/// * A vector index contributes `HeadSet::MULTI_VECTOR` when its property is
///   a list of vectors (see `is_multivector_property`), and `HeadSet::DENSE`
///   otherwise.
/// * A sparse index contributes `HeadSet::SPARSE`.
/// * Indexes without an `embedding_config`, and every other index kind, are
///   skipped.
///
/// The returned map is keyed, and therefore ordered, by alias. Within each
/// alias, `columns` keeps schema index order.
pub fn required_embed_heads(schema: &Schema) -> BTreeMap<String, RequiredHeads> {
    // Turn each relevant index into an (alias, property, head) request.
    // `?` on a missing embedding config drops that index.
    let requests = schema.indexes.iter().filter_map(|index| match index {
        IndexDefinition::Vector(vector) => {
            let embedding = vector.embedding_config.as_ref()?;
            let head = if is_multivector_property(schema, &vector.label, &vector.property) {
                HeadSet::MULTI_VECTOR
            } else {
                HeadSet::DENSE
            };
            Some((embedding.alias.clone(), vector.property.clone(), head))
        }
        IndexDefinition::Sparse(sparse) => {
            let embedding = sparse.embedding_config.as_ref()?;
            Some((embedding.alias.clone(), sparse.property.clone(), HeadSet::SPARSE))
        }
        _ => None,
    });

    let mut by_alias: BTreeMap<String, RequiredHeads> = BTreeMap::new();
    for (alias, column, head) in requests {
        let slot = by_alias.entry(alias).or_insert_with(|| RequiredHeads {
            heads: HeadSet::empty(),
            columns: Vec::new(),
        });
        slot.heads |= head;
        slot.columns.push((column, head));
    }
    by_alias
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each text-embedding task maps to its own head; hybrid maps to all.
    #[test]
    fn text_tasks_map_to_their_head() {
        let table = [
            (ModelTask::Embed, HeadSet::DENSE),
            (ModelTask::EmbedSparse, HeadSet::SPARSE),
            (ModelTask::EmbedMultiVector, HeadSet::MULTI_VECTOR),
            (ModelTask::EmbedHybrid, HeadSet::ALL),
        ];
        for (task, expected) in table {
            assert_eq!(text_embedding_heads(task), expected);
        }
    }

    /// The hybrid task's head set must include every individual head.
    #[test]
    fn hybrid_covers_every_single_head() {
        let hybrid = text_embedding_heads(ModelTask::EmbedHybrid);
        let singles = [HeadSet::DENSE, HeadSet::SPARSE, HeadSet::MULTI_VECTOR];
        if let Some(head) = singles.into_iter().find(|&h| !hybrid.contains(h)) {
            panic!("hybrid must cover {head:?}");
        }
    }

    /// Non-text embedders and non-embedding tasks must yield no heads.
    #[test]
    fn non_text_and_non_embed_tasks_map_to_no_heads() {
        let unrelated = [
            ModelTask::EmbedImage,
            ModelTask::EmbedAudio,
            ModelTask::EmbedMultimodal,
            ModelTask::Rerank,
            ModelTask::Generate,
            ModelTask::Raw,
            ModelTask::Nlp,
            ModelTask::DocumentExtract,
            ModelTask::Transcribe,
            ModelTask::Ocr,
        ];
        for task in unrelated {
            let heads = text_embedding_heads(task);
            assert!(
                heads.is_empty(),
                "task {task:?} must produce no text-embedding heads"
            );
        }
    }

    /// A dense-only alias must not claim sparse, multi-vector or mixed heads,
    /// while a hybrid alias does cover the mixed set.
    #[test]
    fn single_task_alias_rejects_a_foreign_head() {
        let dense_only = text_embedding_heads(ModelTask::Embed);
        let hybrid = text_embedding_heads(ModelTask::EmbedHybrid);
        let mixed = HeadSet::DENSE | HeadSet::SPARSE;

        assert!(dense_only.contains(HeadSet::DENSE));
        for foreign in [HeadSet::SPARSE, HeadSet::MULTI_VECTOR, mixed] {
            assert!(!dense_only.contains(foreign));
        }
        assert!(hybrid.contains(mixed));
    }
}