1use crate::expanded::{
2 ExpandedClassification, ExpandedEntity, ExpandedJsonStructure, ExpandedRelation,
3 ExpandedSchema, ExpandedStructureProperty,
4};
5use crate::normalized::{DType, ExpandedName};
6use serde::Serialize;
7use std::collections::BTreeMap;
8use std::convert::TryFrom;
9
/// A JSON-structure property after lifting.
///
/// The entities that were listed inline as `choices` have been moved into
/// [`LiftedSchema::entities`] and are referenced here only by name; every
/// other field is carried over unchanged from the expanded property.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct LiftedStructureProperty {
    /// Names of the interned choice entities, in input order. Each name is a
    /// key of [`LiftedSchema::entities`].
    pub choices: Vec<ExpandedName>,
    /// Copied unchanged from the expanded property.
    pub description: Option<String>,
    /// Copied unchanged from the expanded property.
    pub value: Option<String>,
    /// Copied unchanged from the expanded property.
    pub dtype: Option<DType>,
    /// Copied unchanged from the expanded property.
    // NOTE(review): referenced via `super::normalized` while the file imports
    // from `crate::normalized`; these agree only if this module is a direct
    // child of the crate root — confirm.
    pub validator: Option<super::normalized::Validator>,
    /// Copied unchanged from the expanded property.
    pub threshold: Option<f64>,
}
26
/// A JSON structure whose properties have all been lifted
/// (see [`LiftedStructureProperty`]).
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct LiftedJsonStructure {
    /// Structure name, carried over unchanged from the expanded structure.
    pub name: ExpandedName,
    /// Lifted properties keyed by property name.
    pub props: BTreeMap<ExpandedName, LiftedStructureProperty>,
}
32
/// A relation after lifting: any entities it embedded are replaced by names
/// that key into [`LiftedSchema::entities`].
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum LiftedRelation {
    /// A relation that carries no head/tail entities; copied through as-is.
    EmptyAcquired {
        name: ExpandedName,
        description: Option<String>,
    },
    /// A relation whose head and tail entities were interned during lifting.
    EntityAcquired {
        name: ExpandedName,
        description: Option<String>,
        /// Name of the interned head entity.
        head: ExpandedName,
        /// Name of the interned tail entity.
        tail: ExpandedName,
    },
}
46
/// A classification task after lifting: the task and its labels are interned
/// entities referenced by name.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct LiftedClassification {
    /// Name of the interned task entity.
    pub task: ExpandedName,
    /// Names of the interned label entities, in input order.
    pub labels: Vec<ExpandedName>,
    /// Copied unchanged from the expanded classification.
    pub threshold: Option<f64>,
    /// Copied unchanged from the expanded classification.
    pub multi_label: bool,
}
54
/// Result of lifting an [`ExpandedSchema`] (via `TryFrom`).
///
/// Every entity definition found anywhere in the schema — top level,
/// structure choices, relation heads/tails, classification tasks, labels and
/// label descriptions — is collected once in `entities`. The other sections
/// refer to those entities by name.
#[derive(Debug, Clone, PartialEq, Serialize, Default)]
pub struct LiftedSchema {
    /// All entities, deduplicated by name; repeated definitions are merged.
    pub entities: BTreeMap<ExpandedName, ExpandedEntity>,
    /// Lifted JSON structures, in input order.
    pub json_structures: Vec<LiftedJsonStructure>,
    /// Lifted relations, in input order.
    pub relations: Vec<LiftedRelation>,
    /// Lifted classifications, in input order.
    pub classifications: Vec<LiftedClassification>,
}
62
/// Errors raised while converting an [`ExpandedSchema`] into a [`LiftedSchema`].
#[derive(Debug, thiserror::Error)]
pub enum SchemaLiftError {
    /// Two definitions of the same entity name disagree on a field that both
    /// of them set (description, dtype, validator or threshold).
    #[error("conflicting entity definitions for {name}")]
    ConflictingEntityDefinition { name: String },

    /// A JSON structure yielded the same property key more than once.
    #[error("duplicate property key in structure {structure}: {property}")]
    DuplicateStructureProperty { structure: String, property: String },

    // NOTE(review): not constructed anywhere in the lifting code visible here;
    // label-description keys are currently ignored — confirm whether a
    // duplicate check was intended.
    /// A classification yielded the same label-description key more than once.
    #[error("duplicate label description key in classification {task}: {label}")]
    DuplicateLabelDescription { task: String, label: String },
}
74
75#[derive(Debug, Default)]
76struct EntityRegistry {
77 entities: BTreeMap<ExpandedName, ExpandedEntity>,
78}
79
80impl EntityRegistry {
81 fn intern(&mut self, entity: ExpandedEntity) -> Result<ExpandedName, SchemaLiftError> {
82 let name = entity.name.clone();
83
84 match self.entities.get_mut(&name) {
85 None => {
86 self.entities.insert(name.clone(), entity);
87 Ok(name)
88 }
89 Some(existing) => {
90 let merged = merge_entities(existing, &entity)?;
91 *existing = merged;
92 Ok(name)
93 }
94 }
95 }
96
97 fn into_map(self) -> BTreeMap<ExpandedName, ExpandedEntity> {
98 self.entities
99 }
100}
101
102fn merge_entities(
103 a: &ExpandedEntity,
104 b: &ExpandedEntity,
105) -> Result<ExpandedEntity, SchemaLiftError> {
106 let name = a.name.clone();
107 let description = merge_field(&a.description, &b.description, &name)?;
108 let dtype = merge_field(&a.dtype, &b.dtype, &name)?;
109 let validator = merge_field(&a.validator, &b.validator, &name)?;
110 let threshold = merge_threshold(&a.threshold, &b.threshold, &name)?;
111
112 Ok(ExpandedEntity {
113 name,
114 description,
115 dtype,
116 validator,
117 threshold,
118 })
119}
120
121fn merge_field<T: PartialEq + Clone>(
122 a: &Option<T>,
123 b: &Option<T>,
124 name: &ExpandedName,
125) -> Result<Option<T>, SchemaLiftError> {
126 match (a, b) {
127 (None, None) => Ok(None),
128 (Some(v), None) => Ok(Some(v.clone())),
129 (None, Some(v)) => Ok(Some(v.clone())),
130 (Some(v1), Some(v2)) => {
131 if v1 == v2 {
132 Ok(Some(v1.clone()))
133 } else {
134 Err(SchemaLiftError::ConflictingEntityDefinition {
135 name: name.to_string(),
136 })
137 }
138 }
139 }
140}
141
142fn merge_threshold(
143 a: &Option<f64>,
144 b: &Option<f64>,
145 name: &ExpandedName,
146) -> Result<Option<f64>, SchemaLiftError> {
147 match (a, b) {
148 (None, None) => Ok(None),
149 (Some(v), None) => Ok(Some(*v)),
150 (None, Some(v)) => Ok(Some(*v)),
151 (Some(v1), Some(v2)) => {
152 if (v1 - v2).abs() < f64::EPSILON {
153 Ok(Some(*v1))
154 } else {
155 Err(SchemaLiftError::ConflictingEntityDefinition {
156 name: name.to_string(),
157 })
158 }
159 }
160 }
161}
162
163fn lift_structure_property(
164 prop: ExpandedStructureProperty,
165 registry: &mut EntityRegistry,
166) -> Result<LiftedStructureProperty, SchemaLiftError> {
167 let mut choices = Vec::with_capacity(prop.choices.len());
168 for entity in prop.choices {
169 choices.push(registry.intern(entity)?);
170 }
171
172 Ok(LiftedStructureProperty {
173 choices,
174 description: prop.description,
175 value: prop.value,
176 dtype: prop.dtype,
177 validator: prop.validator,
178 threshold: prop.threshold,
179 })
180}
181
182fn lift_json_structure(
183 js: ExpandedJsonStructure,
184 registry: &mut EntityRegistry,
185) -> Result<LiftedJsonStructure, SchemaLiftError> {
186 let structure_name = js.name.clone();
187 let mut props = BTreeMap::new();
188
189 for (prop_name, prop) in js.props {
190 if props.contains_key(&prop_name) {
191 return Err(SchemaLiftError::DuplicateStructureProperty {
192 structure: structure_name.to_string(),
193 property: prop_name.to_string(),
194 });
195 }
196
197 props.insert(prop_name, lift_structure_property(prop, registry)?);
198 }
199
200 Ok(LiftedJsonStructure {
201 name: structure_name,
202 props,
203 })
204}
205
206fn lift_relation(
207 rel: ExpandedRelation,
208 registry: &mut EntityRegistry,
209) -> Result<LiftedRelation, SchemaLiftError> {
210 match rel {
211 ExpandedRelation::EmptyAcquired { name, description } => {
212 Ok(LiftedRelation::EmptyAcquired { name, description })
213 }
214 ExpandedRelation::EntityAcquired {
215 name,
216 description,
217 head,
218 tail,
219 } => {
220 let head_name = registry.intern(*head)?;
221 let tail_name = registry.intern(*tail)?;
222 Ok(LiftedRelation::EntityAcquired {
223 name,
224 description,
225 head: head_name,
226 tail: tail_name,
227 })
228 }
229 }
230}
231
232fn lift_classification(
233 cls: ExpandedClassification,
234 registry: &mut EntityRegistry,
235) -> Result<LiftedClassification, SchemaLiftError> {
236 let task_name = registry.intern(cls.task)?;
237
238 let mut labels = Vec::with_capacity(cls.labels.len());
239 for label in cls.labels {
240 labels.push(registry.intern(label)?);
241 }
242
243 for (_label_key, entity) in cls.label_descriptions {
245 registry.intern(entity)?;
246 }
247
248 Ok(LiftedClassification {
249 task: task_name,
250 labels,
251 threshold: cls.threshold,
252 multi_label: cls.multi_label,
253 })
254}
255
256impl TryFrom<ExpandedSchema> for LiftedSchema {
257 type Error = SchemaLiftError;
258
259 fn try_from(value: ExpandedSchema) -> Result<Self, Self::Error> {
260 let mut registry = EntityRegistry::default();
261
262 for entity in value.entities {
264 registry.intern(entity)?;
265 }
266
267 let mut json_structures = Vec::with_capacity(value.json_structures.len());
268 for js in value.json_structures {
269 json_structures.push(lift_json_structure(js, &mut registry)?);
270 }
271
272 let mut relations = Vec::with_capacity(value.relations.len());
273 for rel in value.relations {
274 relations.push(lift_relation(rel, &mut registry)?);
275 }
276
277 let mut classifications = Vec::with_capacity(value.classifications.len());
278 for cls in value.classifications {
279 classifications.push(lift_classification(cls, &mut registry)?);
280 }
281
282 Ok(LiftedSchema {
283 entities: registry.into_map(),
284 json_structures,
285 relations,
286 classifications,
287 })
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::normalized::{DType, NormalizedSchema};

    // End-to-end check over the parse -> expand -> lift pipeline. Entities
    // embedded in relations (head/tail) and in classifications (task, labels,
    // label descriptions) must end up in the top-level entity table. The lifted
    // relation and classification must refer to them by name.
    #[test]
    fn lifted_lifts_nested_entities_from_relations_and_classifications() {
        // `gene` appears both at top level and as a relation head; `disease`
        // appears only as a relation tail.
        let s = r#"
        {
            "entities": [
                "gene::str::0.9::gene symbol"
            ],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease::str::0.8::disease entity" } }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "label_descriptions": {
                        "positive": "positive",
                        "negative": "negative"
                    }
                }
            ]
        }
        "#;

        let s2 = NormalizedSchema::from_json_str(s).unwrap();
        let s3 = ExpandedSchema::try_from(s2).unwrap();
        let s4 = LiftedSchema::try_from(s3).unwrap();

        // Every entity, wherever it was declared, is present in the table.
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("gene".to_string()))
        );
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("disease".to_string()))
        );
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("sentiment".to_string()))
        );
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("positive".to_string()))
        );
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("negative".to_string()))
        );

        // The relation now refers to its endpoints by name.
        assert_eq!(s4.relations.len(), 1);
        match &s4.relations[0] {
            LiftedRelation::EntityAcquired { head, tail, .. } => {
                assert_eq!(head.as_str(), "gene");
                assert_eq!(tail.as_str(), "disease");
            }
            other => panic!("unexpected relation: {other:?}"),
        }

        // Labels keep their input order.
        assert_eq!(s4.classifications.len(), 1);
        let cls = &s4.classifications[0];
        assert_eq!(cls.task.as_str(), "sentiment");
        assert_eq!(cls.labels.len(), 2);
        assert_eq!(cls.labels[0].as_str(), "positive");
        assert_eq!(cls.labels[1].as_str(), "negative");
    }

    // Choice entities listed under a JSON-structure property are hoisted into
    // the entity table, and the property keeps their names in input order.
    #[test]
    fn lifted_lifts_structure_choices() {
        let s = r#"
        {
            "relations": [
                { "contains": { "head": "patient", "tail": "record" } }
            ],
            "json_structures": [
                {
                    "name": "Patient Record",
                    "status": {
                        "choices": [
                            "active::str::0.7::active status",
                            "inactive::str::0.7::inactive status"
                        ]
                    }
                }
            ]
        }
        "#;

        let s2 = NormalizedSchema::from_json_str(s).unwrap();
        let s3 = ExpandedSchema::try_from(s2).unwrap();
        let s4 = LiftedSchema::try_from(s3).unwrap();

        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("active".to_string()))
        );
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("inactive".to_string()))
        );

        let js = &s4.json_structures[0];
        let status_prop = js
            .props
            .get(&ExpandedName::new("status".to_string()))
            .unwrap();

        assert_eq!(status_prop.choices.len(), 2);
        assert_eq!(status_prop.choices[0].as_str(), "active");
        assert_eq!(status_prop.choices[1].as_str(), "inactive");
    }

    // Two top-level definitions of `gene` that differ in threshold and
    // description must be rejected, and the error must name the entity.
    #[test]
    fn lifted_rejects_conflicting_entity_definitions() {
        let s3 = ExpandedSchema {
            entities: vec![
                ExpandedEntity {
                    name: ExpandedName::new("gene".to_string()),
                    dtype: Some(DType::String),
                    validator: None,
                    threshold: Some(0.5),
                    description: Some("first".to_string()),
                },
                ExpandedEntity {
                    name: ExpandedName::new("gene".to_string()),
                    dtype: Some(DType::String),
                    validator: None,
                    threshold: Some(0.9),
                    description: Some("second".to_string()),
                },
            ],
            json_structures: vec![],
            relations: vec![],
            classifications: vec![],
        };

        let err = LiftedSchema::try_from(s3).unwrap_err();
        match err {
            SchemaLiftError::ConflictingEntityDefinition { name } => {
                assert_eq!(name, "gene");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    // Repeating an identical definition is not a conflict; it collapses to a
    // single table entry.
    #[test]
    fn lifted_deduplicates_identical_entity_definitions() {
        let entity = ExpandedEntity {
            name: ExpandedName::new("gene".to_string()),
            dtype: Some(DType::String),
            validator: None,
            threshold: Some(0.5),
            description: Some("gene symbol".to_string()),
        };

        let s3 = ExpandedSchema {
            entities: vec![entity.clone(), entity],
            json_structures: vec![],
            relations: vec![],
            classifications: vec![],
        };

        let s4 = LiftedSchema::try_from(s3).unwrap();
        assert_eq!(s4.entities.len(), 1);
        assert!(
            s4.entities
                .contains_key(&ExpandedName::new("gene".to_string()))
        );
    }
}