// ie_schema/lifted.rs

1use crate::expanded::{
2    ExpandedClassification, ExpandedEntity, ExpandedJsonStructure, ExpandedRelation,
3    ExpandedSchema, ExpandedStructureProperty,
4};
5use crate::normalized::{DType, ExpandedName};
6use serde::Serialize;
7use std::collections::BTreeMap;
8use std::convert::TryFrom;
9
10/// Lifted schema.
11///
12/// Goals:
13/// - move all entity definitions into a single top-level entity map
14/// - replace nested entity payloads with entity-name references
15/// - preserve structure / relation / classification topology
16
17#[derive(Debug, Clone, PartialEq, Serialize)]
18pub struct LiftedStructureProperty {
19    pub choices: Vec<ExpandedName>,
20    pub description: Option<String>,
21    pub value: Option<String>,
22    pub dtype: Option<DType>,
23    pub validator: Option<super::normalized::Validator>,
24    pub threshold: Option<f64>,
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize)]
28pub struct LiftedJsonStructure {
29    pub name: ExpandedName,
30    pub props: BTreeMap<ExpandedName, LiftedStructureProperty>,
31}
32
33#[derive(Debug, Clone, PartialEq, Serialize)]
34pub enum LiftedRelation {
35    EmptyAcquired {
36        name: ExpandedName,
37        description: Option<String>,
38    },
39    EntityAcquired {
40        name: ExpandedName,
41        description: Option<String>,
42        head: ExpandedName,
43        tail: ExpandedName,
44    },
45}
46
47#[derive(Debug, Clone, PartialEq, Serialize)]
48pub struct LiftedClassification {
49    pub task: ExpandedName,
50    pub labels: Vec<ExpandedName>,
51    pub threshold: Option<f64>,
52    pub multi_label: bool,
53}
54
55#[derive(Debug, Clone, PartialEq, Serialize, Default)]
56pub struct LiftedSchema {
57    pub entities: BTreeMap<ExpandedName, ExpandedEntity>,
58    pub json_structures: Vec<LiftedJsonStructure>,
59    pub relations: Vec<LiftedRelation>,
60    pub classifications: Vec<LiftedClassification>,
61}
62
63#[derive(Debug, thiserror::Error)]
64pub enum SchemaLiftError {
65    #[error("conflicting entity definitions for {name}")]
66    ConflictingEntityDefinition { name: String },
67
68    #[error("duplicate property key in structure {structure}: {property}")]
69    DuplicateStructureProperty { structure: String, property: String },
70
71    #[error("duplicate label description key in classification {task}: {label}")]
72    DuplicateLabelDescription { task: String, label: String },
73}
74
75#[derive(Debug, Default)]
76struct EntityRegistry {
77    entities: BTreeMap<ExpandedName, ExpandedEntity>,
78}
79
80impl EntityRegistry {
81    fn intern(&mut self, entity: ExpandedEntity) -> Result<ExpandedName, SchemaLiftError> {
82        let name = entity.name.clone();
83
84        match self.entities.get_mut(&name) {
85            None => {
86                self.entities.insert(name.clone(), entity);
87                Ok(name)
88            }
89            Some(existing) => {
90                let merged = merge_entities(existing, &entity)?;
91                *existing = merged;
92                Ok(name)
93            }
94        }
95    }
96
97    fn into_map(self) -> BTreeMap<ExpandedName, ExpandedEntity> {
98        self.entities
99    }
100}
101
102fn merge_entities(
103    a: &ExpandedEntity,
104    b: &ExpandedEntity,
105) -> Result<ExpandedEntity, SchemaLiftError> {
106    let name = a.name.clone();
107    let description = merge_field(&a.description, &b.description, &name)?;
108    let dtype = merge_field(&a.dtype, &b.dtype, &name)?;
109    let validator = merge_field(&a.validator, &b.validator, &name)?;
110    let threshold = merge_threshold(&a.threshold, &b.threshold, &name)?;
111
112    Ok(ExpandedEntity {
113        name,
114        description,
115        dtype,
116        validator,
117        threshold,
118    })
119}
120
121fn merge_field<T: PartialEq + Clone>(
122    a: &Option<T>,
123    b: &Option<T>,
124    name: &ExpandedName,
125) -> Result<Option<T>, SchemaLiftError> {
126    match (a, b) {
127        (None, None) => Ok(None),
128        (Some(v), None) => Ok(Some(v.clone())),
129        (None, Some(v)) => Ok(Some(v.clone())),
130        (Some(v1), Some(v2)) => {
131            if v1 == v2 {
132                Ok(Some(v1.clone()))
133            } else {
134                Err(SchemaLiftError::ConflictingEntityDefinition {
135                    name: name.to_string(),
136                })
137            }
138        }
139    }
140}
141
142fn merge_threshold(
143    a: &Option<f64>,
144    b: &Option<f64>,
145    name: &ExpandedName,
146) -> Result<Option<f64>, SchemaLiftError> {
147    match (a, b) {
148        (None, None) => Ok(None),
149        (Some(v), None) => Ok(Some(*v)),
150        (None, Some(v)) => Ok(Some(*v)),
151        (Some(v1), Some(v2)) => {
152            if (v1 - v2).abs() < f64::EPSILON {
153                Ok(Some(*v1))
154            } else {
155                Err(SchemaLiftError::ConflictingEntityDefinition {
156                    name: name.to_string(),
157                })
158            }
159        }
160    }
161}
162
163fn lift_structure_property(
164    prop: ExpandedStructureProperty,
165    registry: &mut EntityRegistry,
166) -> Result<LiftedStructureProperty, SchemaLiftError> {
167    let mut choices = Vec::with_capacity(prop.choices.len());
168    for entity in prop.choices {
169        choices.push(registry.intern(entity)?);
170    }
171
172    Ok(LiftedStructureProperty {
173        choices,
174        description: prop.description,
175        value: prop.value,
176        dtype: prop.dtype,
177        validator: prop.validator,
178        threshold: prop.threshold,
179    })
180}
181
182fn lift_json_structure(
183    js: ExpandedJsonStructure,
184    registry: &mut EntityRegistry,
185) -> Result<LiftedJsonStructure, SchemaLiftError> {
186    let structure_name = js.name.clone();
187    let mut props = BTreeMap::new();
188
189    for (prop_name, prop) in js.props {
190        if props.contains_key(&prop_name) {
191            return Err(SchemaLiftError::DuplicateStructureProperty {
192                structure: structure_name.to_string(),
193                property: prop_name.to_string(),
194            });
195        }
196
197        props.insert(prop_name, lift_structure_property(prop, registry)?);
198    }
199
200    Ok(LiftedJsonStructure {
201        name: structure_name,
202        props,
203    })
204}
205
206fn lift_relation(
207    rel: ExpandedRelation,
208    registry: &mut EntityRegistry,
209) -> Result<LiftedRelation, SchemaLiftError> {
210    match rel {
211        ExpandedRelation::EmptyAcquired { name, description } => {
212            Ok(LiftedRelation::EmptyAcquired { name, description })
213        }
214        ExpandedRelation::EntityAcquired {
215            name,
216            description,
217            head,
218            tail,
219        } => {
220            let head_name = registry.intern(*head)?;
221            let tail_name = registry.intern(*tail)?;
222            Ok(LiftedRelation::EntityAcquired {
223                name,
224                description,
225                head: head_name,
226                tail: tail_name,
227            })
228        }
229    }
230}
231
232fn lift_classification(
233    cls: ExpandedClassification,
234    registry: &mut EntityRegistry,
235) -> Result<LiftedClassification, SchemaLiftError> {
236    let task_name = registry.intern(cls.task)?;
237
238    let mut labels = Vec::with_capacity(cls.labels.len());
239    for label in cls.labels {
240        labels.push(registry.intern(label)?);
241    }
242
243    // Optional enrichment only; no retention in LiftedSchema.
244    for (_label_key, entity) in cls.label_descriptions {
245        registry.intern(entity)?;
246    }
247
248    Ok(LiftedClassification {
249        task: task_name,
250        labels,
251        threshold: cls.threshold,
252        multi_label: cls.multi_label,
253    })
254}
255
256impl TryFrom<ExpandedSchema> for LiftedSchema {
257    type Error = SchemaLiftError;
258
259    fn try_from(value: ExpandedSchema) -> Result<Self, Self::Error> {
260        let mut registry = EntityRegistry::default();
261
262        // First seed the registry with top-level entities.
263        for entity in value.entities {
264            registry.intern(entity)?;
265        }
266
267        let mut json_structures = Vec::with_capacity(value.json_structures.len());
268        for js in value.json_structures {
269            json_structures.push(lift_json_structure(js, &mut registry)?);
270        }
271
272        let mut relations = Vec::with_capacity(value.relations.len());
273        for rel in value.relations {
274            relations.push(lift_relation(rel, &mut registry)?);
275        }
276
277        let mut classifications = Vec::with_capacity(value.classifications.len());
278        for cls in value.classifications {
279            classifications.push(lift_classification(cls, &mut registry)?);
280        }
281
282        Ok(LiftedSchema {
283            entities: registry.into_map(),
284            json_structures,
285            relations,
286            classifications,
287        })
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::normalized::{DType, NormalizedSchema};

    /// Shorthand for building an [`ExpandedName`] from a literal.
    fn expanded_name(raw: &str) -> ExpandedName {
        ExpandedName::new(raw.to_string())
    }

    /// Runs the normalize -> expand -> lift pipeline, panicking on any failure.
    fn lift_json(json: &str) -> LiftedSchema {
        let normalized = NormalizedSchema::from_json_str(json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        LiftedSchema::try_from(expanded).unwrap()
    }

    /// A string-typed `gene` entity with no validator.
    fn gene_entity(threshold: f64, description: &str) -> ExpandedEntity {
        ExpandedEntity {
            name: expanded_name("gene"),
            dtype: Some(DType::String),
            validator: None,
            threshold: Some(threshold),
            description: Some(description.to_string()),
        }
    }

    /// An expanded schema declaring only the given top-level entities.
    fn schema_with_entities(entities: Vec<ExpandedEntity>) -> ExpandedSchema {
        ExpandedSchema {
            entities,
            json_structures: vec![],
            relations: vec![],
            classifications: vec![],
        }
    }

    #[test]
    fn lifted_lifts_nested_entities_from_relations_and_classifications() {
        let lifted = lift_json(
            r#"
        {
            "entities": [
                "gene::str::0.9::gene symbol"
            ],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease::str::0.8::disease entity" } }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "label_descriptions": {
                        "positive": "positive",
                        "negative": "negative"
                    }
                }
            ]
        }
        "#,
        );

        for expected in ["gene", "disease", "sentiment", "positive", "negative"] {
            assert!(
                lifted.entities.contains_key(&expanded_name(expected)),
                "entity `{expected}` was not lifted"
            );
        }

        assert_eq!(lifted.relations.len(), 1);
        match &lifted.relations[0] {
            LiftedRelation::EntityAcquired { head, tail, .. } => {
                assert_eq!(head.as_str(), "gene");
                assert_eq!(tail.as_str(), "disease");
            }
            other => panic!("unexpected relation: {other:?}"),
        }

        assert_eq!(lifted.classifications.len(), 1);
        let classification = &lifted.classifications[0];
        assert_eq!(classification.task.as_str(), "sentiment");
        let label_names: Vec<_> = classification
            .labels
            .iter()
            .map(|label| label.as_str())
            .collect();
        assert_eq!(label_names, ["positive", "negative"]);
    }

    #[test]
    fn lifted_lifts_structure_choices() {
        let lifted = lift_json(
            r#"
        {
            "relations": [
                { "contains": { "head": "patient", "tail": "record" } }
            ],
            "json_structures": [
                {
                    "name": "Patient Record",
                    "status": {
                        "choices": [
                            "active::str::0.7::active status",
                            "inactive::str::0.7::inactive status"
                        ]
                    }
                }
            ]
        }
        "#,
        );

        for expected in ["active", "inactive"] {
            assert!(
                lifted.entities.contains_key(&expanded_name(expected)),
                "entity `{expected}` was not lifted"
            );
        }

        let status = lifted.json_structures[0]
            .props
            .get(&expanded_name("status"))
            .unwrap();
        let choice_names: Vec<_> = status
            .choices
            .iter()
            .map(|choice| choice.as_str())
            .collect();
        assert_eq!(choice_names, ["active", "inactive"]);
    }

    #[test]
    fn lifted_rejects_conflicting_entity_definitions() {
        let expanded = schema_with_entities(vec![
            gene_entity(0.5, "first"),
            gene_entity(0.9, "second"),
        ]);

        match LiftedSchema::try_from(expanded).unwrap_err() {
            SchemaLiftError::ConflictingEntityDefinition { name } => assert_eq!(name, "gene"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn lifted_deduplicates_identical_entity_definitions() {
        let gene = gene_entity(0.5, "gene symbol");
        let lifted =
            LiftedSchema::try_from(schema_with_entities(vec![gene.clone(), gene])).unwrap();

        assert_eq!(lifted.entities.len(), 1);
        assert!(lifted.entities.contains_key(&expanded_name("gene")));
    }
}
460            s4.entities
461                .contains_key(&ExpandedName::new("gene".to_string()))
462        );
463    }
464}