use crate::expanded::{
ExpandedClassification, ExpandedEntity, ExpandedJsonStructure, ExpandedRelation,
ExpandedSchema, ExpandedStructureProperty,
};
use crate::normalized::{DType, ExpandedName};
use serde::Serialize;
use std::collections::BTreeMap;
use std::convert::TryFrom;
/// A JSON-structure property after lifting.
///
/// `choices` holds only entity names; the choice entities themselves are
/// interned into [`LiftedSchema::entities`]. The remaining fields are copied
/// from the `ExpandedStructureProperty` as-is.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct LiftedStructureProperty {
    /// Names of the entities this property may take, in their original order.
    pub choices: Vec<ExpandedName>,
    /// Description, copied unchanged from the expanded property.
    pub description: Option<String>,
    /// Value, copied unchanged from the expanded property.
    pub value: Option<String>,
    /// Data type, copied unchanged from the expanded property.
    pub dtype: Option<DType>,
    /// Validator, copied unchanged from the expanded property.
    pub validator: Option<super::normalized::Validator>,
    /// Threshold, copied unchanged from the expanded property.
    pub threshold: Option<f64>,
}
/// A JSON structure whose properties have been lifted.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct LiftedJsonStructure {
    /// Name of the structure.
    pub name: ExpandedName,
    /// Lifted properties keyed by property name. A key appears at most once:
    /// repeated keys are rejected during lifting.
    pub props: BTreeMap<ExpandedName, LiftedStructureProperty>,
}
/// A relation after lifting: endpoint entities are replaced by their names.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum LiftedRelation {
    /// A relation without head/tail entities, copied from
    /// `ExpandedRelation::EmptyAcquired` as-is.
    EmptyAcquired {
        /// Relation name.
        name: ExpandedName,
        /// Optional relation description.
        description: Option<String>,
    },
    /// A relation between two entities referenced by name. The head and tail
    /// entities themselves are interned into [`LiftedSchema::entities`].
    EntityAcquired {
        /// Relation name.
        name: ExpandedName,
        /// Optional relation description.
        description: Option<String>,
        /// Name of the head entity.
        head: ExpandedName,
        /// Name of the tail entity.
        tail: ExpandedName,
    },
}
/// A classification task after lifting.
///
/// The task and its labels are referenced by entity name. Label-description
/// entities are interned into the schema's entity table but are not recorded
/// on this struct.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct LiftedClassification {
    /// Name of the task entity.
    pub task: ExpandedName,
    /// Names of the label entities, in their original order.
    pub labels: Vec<ExpandedName>,
    /// Threshold, copied unchanged from the expanded classification.
    pub threshold: Option<f64>,
    /// Copied unchanged from the expanded classification's `multi_label` flag.
    pub multi_label: bool,
}
/// A schema in which every nested entity has been hoisted into one name-keyed
/// entity table, so the other sections refer to entities only by name.
///
/// Built from an [`ExpandedSchema`] via `TryFrom`.
#[derive(Debug, Clone, PartialEq, Serialize, Default)]
pub struct LiftedSchema {
    /// Every entity in the schema, keyed by name, with repeated definitions of
    /// the same name merged. This includes top-level entities and those reached
    /// through structure choices, relation endpoints and classifications.
    pub entities: BTreeMap<ExpandedName, ExpandedEntity>,
    /// Lifted JSON structures, in input order.
    pub json_structures: Vec<LiftedJsonStructure>,
    /// Lifted relations, in input order.
    pub relations: Vec<LiftedRelation>,
    /// Lifted classifications, in input order.
    pub classifications: Vec<LiftedClassification>,
}
/// Errors raised while lifting an [`ExpandedSchema`] into a [`LiftedSchema`].
#[derive(Debug, thiserror::Error)]
pub enum SchemaLiftError {
    /// Two entity definitions with the same name disagree on description,
    /// dtype, validator or threshold.
    #[error("conflicting entity definitions for {name}")]
    ConflictingEntityDefinition { name: String },
    /// A JSON structure contains the same property key more than once.
    #[error("duplicate property key in structure {structure}: {property}")]
    DuplicateStructureProperty { structure: String, property: String },
    /// A classification lists the same label-description key more than once.
    #[error("duplicate label description key in classification {task}: {label}")]
    DuplicateLabelDescription { task: String, label: String },
}
/// Name-keyed accumulator used while lifting. It collects entities from every
/// part of the schema and merges repeated definitions of the same name.
#[derive(Debug, Default)]
struct EntityRegistry {
    /// Interned entities, keyed by entity name.
    entities: BTreeMap<ExpandedName, ExpandedEntity>,
}
impl EntityRegistry {
    /// Registers `entity` under its name and returns that name.
    ///
    /// The first definition seen for a name is stored as-is. Later definitions
    /// with the same name are merged into the stored one with `merge_entities`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `merge_entities`. In that case the already
    /// stored entity is left untouched.
    fn intern(&mut self, entity: ExpandedEntity) -> Result<ExpandedName, SchemaLiftError> {
        use std::collections::btree_map::Entry;

        let key = entity.name.clone();
        match self.entities.entry(key.clone()) {
            Entry::Vacant(slot) => {
                slot.insert(entity);
            }
            Entry::Occupied(mut slot) => {
                // Merge first, then overwrite. An error returns early and keeps
                // the previous value.
                let combined = merge_entities(slot.get(), &entity)?;
                *slot.get_mut() = combined;
            }
        }
        Ok(key)
    }

    /// Consumes the registry and returns every interned entity, keyed by name.
    fn into_map(self) -> BTreeMap<ExpandedName, ExpandedEntity> {
        let Self { entities } = self;
        entities
    }
}
/// Combines two definitions of the same entity field by field.
///
/// The result takes its name from `a`. For each optional field, a value
/// present on only one side wins, and values present on both sides must agree.
///
/// # Errors
///
/// Returns the first conflict reported by `merge_field` / `merge_threshold`,
/// checked in the order description, dtype, validator, threshold.
fn merge_entities(
    a: &ExpandedEntity,
    b: &ExpandedEntity,
) -> Result<ExpandedEntity, SchemaLiftError> {
    let key = &a.name;
    // Struct-literal fields are evaluated in the order written, so conflicts
    // are detected in the same order as before.
    Ok(ExpandedEntity {
        description: merge_field(&a.description, &b.description, key)?,
        dtype: merge_field(&a.dtype, &b.dtype, key)?,
        validator: merge_field(&a.validator, &b.validator, key)?,
        threshold: merge_threshold(&a.threshold, &b.threshold, key)?,
        name: key.clone(),
    })
}
/// Merges one optional field from two definitions of the entity `name`.
///
/// If only one side is set, its value is returned. If both are set, they must
/// compare equal and `a`'s value is returned. If neither is set, the result is
/// `None`.
///
/// # Errors
///
/// Returns [`SchemaLiftError::ConflictingEntityDefinition`] when both sides are
/// set and differ.
fn merge_field<T: PartialEq + Clone>(
    a: &Option<T>,
    b: &Option<T>,
    name: &ExpandedName,
) -> Result<Option<T>, SchemaLiftError> {
    if let (Some(left), Some(right)) = (a, b) {
        if left != right {
            return Err(SchemaLiftError::ConflictingEntityDefinition {
                name: name.to_string(),
            });
        }
    }
    // At this point at most one side is set, or both agree: prefer `a`.
    Ok(a.as_ref().or(b.as_ref()).cloned())
}
/// Merges two optional thresholds from definitions of the entity `name`.
///
/// If only one side is set, that value is returned. If both are set, they are
/// considered the same when any of these holds:
/// - they compare `==`, which covers identical infinities whose difference
///   would be NaN;
/// - both are NaN;
/// - they differ by less than [`f64::EPSILON`].
///
/// In that case `a`'s value is returned.
///
/// # Errors
///
/// Returns [`SchemaLiftError::ConflictingEntityDefinition`] when both
/// thresholds are set and not considered the same.
fn merge_threshold(
    a: &Option<f64>,
    b: &Option<f64>,
    name: &ExpandedName,
) -> Result<Option<f64>, SchemaLiftError> {
    match (a, b) {
        (None, None) => Ok(None),
        (Some(v), None) | (None, Some(v)) => Ok(Some(*v)),
        (Some(v1), Some(v2)) => {
            // The epsilon test alone rejects two identical infinite (or NaN)
            // thresholds, because `inf - inf` and `NaN - NaN` are NaN. Check
            // exact equality and NaN first so identical definitions still merge.
            let same = v1 == v2
                || (v1.is_nan() && v2.is_nan())
                || (*v1 - *v2).abs() < f64::EPSILON;
            if same {
                Ok(Some(*v1))
            } else {
                Err(SchemaLiftError::ConflictingEntityDefinition {
                    name: name.to_string(),
                })
            }
        }
    }
}
/// Lifts one structure property by interning each of its choice entities into
/// `registry` and keeping only their names.
///
/// The other fields are moved over unchanged.
///
/// # Errors
///
/// Stops at the first choice whose interning fails and returns that error.
/// Choices before it remain interned.
fn lift_structure_property(
    prop: ExpandedStructureProperty,
    registry: &mut EntityRegistry,
) -> Result<LiftedStructureProperty, SchemaLiftError> {
    let choices = prop
        .choices
        .into_iter()
        .map(|choice| registry.intern(choice))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(LiftedStructureProperty {
        choices,
        description: prop.description,
        value: prop.value,
        dtype: prop.dtype,
        validator: prop.validator,
        threshold: prop.threshold,
    })
}
fn lift_json_structure(
js: ExpandedJsonStructure,
registry: &mut EntityRegistry,
) -> Result<LiftedJsonStructure, SchemaLiftError> {
let structure_name = js.name.clone();
let mut props = BTreeMap::new();
for (prop_name, prop) in js.props {
if props.contains_key(&prop_name) {
return Err(SchemaLiftError::DuplicateStructureProperty {
structure: structure_name.to_string(),
property: prop_name.to_string(),
});
}
props.insert(prop_name, lift_structure_property(prop, registry)?);
}
Ok(LiftedJsonStructure {
name: structure_name,
props,
})
}
/// Lifts a relation.
///
/// For entity-bearing relations, the head and then the tail entity are
/// interned into `registry` and replaced by their names. Empty relations pass
/// through unchanged.
///
/// # Errors
///
/// Propagates any interning error. If interning the head fails, the tail is
/// not interned.
fn lift_relation(
    rel: ExpandedRelation,
    registry: &mut EntityRegistry,
) -> Result<LiftedRelation, SchemaLiftError> {
    let lifted = match rel {
        ExpandedRelation::EmptyAcquired { name, description } => {
            LiftedRelation::EmptyAcquired { name, description }
        }
        ExpandedRelation::EntityAcquired {
            name,
            description,
            head,
            tail,
        } => LiftedRelation::EntityAcquired {
            name,
            description,
            // Fields are evaluated in written order: head is interned before tail.
            head: registry.intern(*head)?,
            tail: registry.intern(*tail)?,
        },
    };
    Ok(lifted)
}
fn lift_classification(
cls: ExpandedClassification,
registry: &mut EntityRegistry,
) -> Result<LiftedClassification, SchemaLiftError> {
let task_name = registry.intern(cls.task)?;
let mut labels = Vec::with_capacity(cls.labels.len());
for label in cls.labels {
labels.push(registry.intern(label)?);
}
for (_label_key, entity) in cls.label_descriptions {
registry.intern(entity)?;
}
Ok(LiftedClassification {
task: task_name,
labels,
threshold: cls.threshold,
multi_label: cls.multi_label,
})
}
impl TryFrom<ExpandedSchema> for LiftedSchema {
type Error = SchemaLiftError;
fn try_from(value: ExpandedSchema) -> Result<Self, Self::Error> {
let mut registry = EntityRegistry::default();
for entity in value.entities {
registry.intern(entity)?;
}
let mut json_structures = Vec::with_capacity(value.json_structures.len());
for js in value.json_structures {
json_structures.push(lift_json_structure(js, &mut registry)?);
}
let mut relations = Vec::with_capacity(value.relations.len());
for rel in value.relations {
relations.push(lift_relation(rel, &mut registry)?);
}
let mut classifications = Vec::with_capacity(value.classifications.len());
for cls in value.classifications {
classifications.push(lift_classification(cls, &mut registry)?);
}
Ok(LiftedSchema {
entities: registry.into_map(),
json_structures,
relations,
classifications,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::normalized::{DType, NormalizedSchema};

    /// Runs the normalize -> expand -> lift pipeline on a JSON schema string,
    /// panicking if any stage fails.
    fn lift_json(json: &str) -> LiftedSchema {
        let normalized = NormalizedSchema::from_json_str(json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        LiftedSchema::try_from(expanded).unwrap()
    }

    /// Builds an `ExpandedName` from a string slice.
    fn key(name: &str) -> ExpandedName {
        ExpandedName::new(name.to_string())
    }

    /// Builds a string-typed `gene` entity with the given threshold and description.
    fn gene_entity(threshold: f64, description: &str) -> ExpandedEntity {
        ExpandedEntity {
            name: key("gene"),
            dtype: Some(DType::String),
            validator: None,
            threshold: Some(threshold),
            description: Some(description.to_string()),
        }
    }

    #[test]
    fn lifted_lifts_nested_entities_from_relations_and_classifications() {
        let json = r#"
{
"entities": [
"gene::str::0.9::gene symbol"
],
"relations": [
{ "associated_with": { "head": "gene", "tail": "disease::str::0.8::disease entity" } }
],
"classifications": [
{
"task": "sentiment",
"labels": ["positive", "negative"],
"label_descriptions": {
"positive": "positive",
"negative": "negative"
}
}
]
}
"#;
        let lifted = lift_json(json);

        for expected in ["gene", "disease", "sentiment", "positive", "negative"] {
            assert!(lifted.entities.contains_key(&key(expected)));
        }

        assert_eq!(lifted.relations.len(), 1);
        match &lifted.relations[0] {
            LiftedRelation::EntityAcquired { head, tail, .. } => {
                assert_eq!(head.as_str(), "gene");
                assert_eq!(tail.as_str(), "disease");
            }
            other => panic!("unexpected relation: {other:?}"),
        }

        assert_eq!(lifted.classifications.len(), 1);
        let cls = &lifted.classifications[0];
        assert_eq!(cls.task.as_str(), "sentiment");
        let labels: Vec<_> = cls.labels.iter().map(|l| l.as_str()).collect();
        assert_eq!(labels, ["positive", "negative"]);
    }

    #[test]
    fn lifted_lifts_structure_choices() {
        let json = r#"
{
"relations": [
{ "contains": { "head": "patient", "tail": "record" } }
],
"json_structures": [
{
"name": "Patient Record",
"status": {
"choices": [
"active::str::0.7::active status",
"inactive::str::0.7::inactive status"
]
}
}
]
}
"#;
        let lifted = lift_json(json);

        for expected in ["active", "inactive"] {
            assert!(lifted.entities.contains_key(&key(expected)));
        }

        let status = lifted.json_structures[0].props.get(&key("status")).unwrap();
        let choices: Vec<_> = status.choices.iter().map(|c| c.as_str()).collect();
        assert_eq!(choices, ["active", "inactive"]);
    }

    #[test]
    fn lifted_rejects_conflicting_entity_definitions() {
        let schema = ExpandedSchema {
            entities: vec![gene_entity(0.5, "first"), gene_entity(0.9, "second")],
            json_structures: vec![],
            relations: vec![],
            classifications: vec![],
        };
        match LiftedSchema::try_from(schema).unwrap_err() {
            SchemaLiftError::ConflictingEntityDefinition { name } => assert_eq!(name, "gene"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn lifted_deduplicates_identical_entity_definitions() {
        let entity = gene_entity(0.5, "gene symbol");
        let schema = ExpandedSchema {
            entities: vec![entity.clone(), entity],
            json_structures: vec![],
            relations: vec![],
            classifications: vec![],
        };
        let lifted = LiftedSchema::try_from(schema).unwrap();
        assert_eq!(lifted.entities.len(), 1);
        assert!(lifted.entities.contains_key(&key("gene")));
    }
}