use crate::lifted::{LiftedClassification, LiftedJsonStructure, LiftedRelation, LiftedSchema};
use crate::normalized::ExpandedName;
use serde::Serialize;
use std::collections::BTreeMap;
use std::convert::TryFrom;
/// Tag naming the four task families a plan can contain.
///
/// The variants mirror the variants of [`PlannedTask`] one-to-one. Because the
/// enum carries no data it also derives `Copy` and `Hash`, so it can be passed
/// by value and used as a set/map key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
pub enum TaskKind {
    /// Entity task (payload type: [`EntityTaskPlan`]).
    Entity,
    /// Relation task (payload type: [`RelationTaskPlan`]).
    Relation,
    /// JSON-structure task (payload type: [`StructureTaskPlan`]).
    Structure,
    /// Classification task (payload type: [`ClassificationTaskPlan`]).
    Classification,
}
/// Plan for the entity task: the set of entity names it covers.
///
/// When produced by `entity_task_from_schema`, the names are sorted and each
/// one has been checked against the entity registry.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct EntityTaskPlan {
    /// Entity names covered by this task.
    pub entities: Vec<ExpandedName>,
}
/// Plan for a single relation task linking a head entity to a tail entity.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct RelationTaskPlan {
    /// Name of the relation.
    pub relation: ExpandedName,
    /// Head endpoint; an entity name checked against the registry when planned.
    pub head: ExpandedName,
    /// Tail endpoint; an entity name checked against the registry when planned.
    pub tail: ExpandedName,
    /// Optional description carried over from the lifted relation.
    pub description: Option<String>,
}
/// One child property of a JSON-structure task.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct StructureChildPlan {
    /// Name of the property within the structure.
    pub property: ExpandedName,
    /// Entity names this property may take; each is checked against the
    /// registry when planned.
    pub choices: Vec<ExpandedName>,
    /// Optional description carried over from the lifted property.
    pub description: Option<String>,
}
/// Plan for one JSON-structure task and its child properties.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct StructureTaskPlan {
    /// Name of the structure.
    pub structure: ExpandedName,
    /// Child properties, in the order they were planned.
    pub children: Vec<StructureChildPlan>,
}
/// Plan for one classification task.
///
/// Only `PartialEq` is derived (no `Eq`) because `threshold` is a float.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ClassificationTaskPlan {
    /// Entity name identifying the classification task.
    pub task: ExpandedName,
    /// Candidate label entity names, in declaration order.
    pub labels: Vec<ExpandedName>,
    /// Optional threshold copied from the lifted classification.
    pub threshold: Option<f64>,
    /// Flag copied from the lifted classification; presumably allows more
    /// than one label per input — TODO confirm against the consumer.
    pub multi_label: bool,
}
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum PlannedTask {
Entity(EntityTaskPlan),
Relation(RelationTaskPlan),
Structure(StructureTaskPlan),
Classification(ClassificationTaskPlan),
}
/// Complete plan produced from a lifted schema.
#[derive(Clone, Debug, Default, PartialEq, Serialize)]
pub struct TaskPlan {
    /// Registry of every entity declared by the schema, keyed by name.
    pub entities: BTreeMap<ExpandedName, TaskEntityDef>,
    /// Planned tasks, in the order they were built.
    pub tasks: Vec<PlannedTask>,
}
/// Registry entry describing one declared entity.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct TaskEntityDef {
    /// Entity name (also the registry key).
    pub name: ExpandedName,
    /// Optional description copied from the lifted entity.
    pub description: Option<String>,
    /// Optional threshold copied from the lifted entity.
    pub threshold: Option<f64>,
    /// Lower-case dtype name ("string", "int", "float" or "bool") as filled in
    /// by `build_entity_registry`; `None` when the entity declares no dtype.
    pub dtype: Option<String>,
}
#[derive(Debug, thiserror::Error)]
pub enum TaskPlanError {
#[error("relation task missing acquired endpoints: {name}")]
RelationMissingEndpoints { name: String },
#[error("referenced entity not found in registry: {name}")]
MissingEntity { name: String },
#[error("duplicate structure child {child} in structure {structure}")]
DuplicateStructureChild { structure: String, child: String },
}
/// Renders a normalized dtype as its lower-case name.
fn dtype_to_string(dtype: &super::normalized::DType) -> String {
    use super::normalized::DType;

    let label = match dtype {
        DType::String => "string",
        DType::Int => "int",
        DType::Float => "float",
        DType::Bool => "bool",
    };
    label.to_owned()
}
/// Builds the name-keyed entity registry from the schema's declared entities.
///
/// Each entry copies the entity's description and threshold and renders its
/// optional dtype via `dtype_to_string`.
fn build_entity_registry(schema: &LiftedSchema) -> BTreeMap<ExpandedName, TaskEntityDef> {
    let mut registry = BTreeMap::new();
    for (key, lifted) in &schema.entities {
        let def = TaskEntityDef {
            name: key.clone(),
            description: lifted.description.clone(),
            threshold: lifted.threshold,
            dtype: lifted.dtype.as_ref().map(dtype_to_string),
        };
        registry.insert(key.clone(), def);
    }
    registry
}
/// Checks that `name` is a key of `registry`.
///
/// # Errors
///
/// Returns [`TaskPlanError::MissingEntity`] carrying `name` when it is absent.
fn ensure_entity_exists(
    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
    name: &ExpandedName,
) -> Result<(), TaskPlanError> {
    if !registry.contains_key(name) {
        return Err(TaskPlanError::MissingEntity {
            name: name.to_string(),
        });
    }
    Ok(())
}
/// Builds the single entity task covering every declared entity.
///
/// Returns `Ok(None)` when the schema declares no entities. Otherwise the
/// entity names are sorted and each is validated against `registry`.
///
/// # Errors
///
/// Returns [`TaskPlanError::MissingEntity`] for the first name not in `registry`.
fn entity_task_from_schema(
    schema: &LiftedSchema,
    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
) -> Result<Option<PlannedTask>, TaskPlanError> {
    let mut names: Vec<ExpandedName> = schema.entities.keys().cloned().collect();
    if names.is_empty() {
        return Ok(None);
    }
    names.sort();
    names
        .iter()
        .try_for_each(|name| ensure_entity_exists(registry, name))?;
    Ok(Some(PlannedTask::Entity(EntityTaskPlan { entities: names })))
}
/// Converts one lifted relation into a relation task.
///
/// # Errors
///
/// * [`TaskPlanError::RelationMissingEndpoints`] when the relation has no
///   head/tail. The error name is the relation name plus ` (description)` when
///   a description exists.
/// * [`TaskPlanError::MissingEntity`] when the head (checked first) or the tail
///   is not in `registry`.
fn relation_task_from_relation(
    rel: &LiftedRelation,
    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
) -> Result<PlannedTask, TaskPlanError> {
    let (name, description, head, tail) = match rel {
        LiftedRelation::EntityAcquired {
            name,
            description,
            head,
            tail,
        } => (name, description, head, tail),
        LiftedRelation::EmptyAcquired { name, description } => {
            // Include the description so the error pinpoints which relation failed.
            let suffix = match description {
                Some(text) => format!(" ({text})"),
                None => String::new(),
            };
            return Err(TaskPlanError::RelationMissingEndpoints {
                name: format!("{name}{suffix}"),
            });
        }
    };

    for endpoint in [head, tail] {
        ensure_entity_exists(registry, endpoint)?;
    }

    let plan = RelationTaskPlan {
        relation: name.clone(),
        head: head.clone(),
        tail: tail.clone(),
        description: description.clone(),
    };
    Ok(PlannedTask::Relation(plan))
}
fn structure_task_from_structure(
js: &LiftedJsonStructure,
registry: &BTreeMap<ExpandedName, TaskEntityDef>,
) -> Result<PlannedTask, TaskPlanError> {
let mut children = Vec::with_capacity(js.props.len());
for (property, prop) in &js.props {
for choice in &prop.choices {
ensure_entity_exists(registry, choice)?;
}
children.push(StructureChildPlan {
property: property.clone(),
choices: prop.choices.clone(),
description: prop.description.clone(),
});
}
Ok(PlannedTask::Structure(StructureTaskPlan {
structure: js.name.clone(),
children,
}))
}
/// Builds the classification task for one lifted classification.
///
/// # Errors
///
/// Returns [`TaskPlanError::MissingEntity`] for the first unregistered name,
/// checking the task name before the labels (in label order).
fn classification_task_from_classification(
    cls: &LiftedClassification,
    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
) -> Result<PlannedTask, TaskPlanError> {
    std::iter::once(&cls.task)
        .chain(&cls.labels)
        .try_for_each(|entity| ensure_entity_exists(registry, entity))?;

    let plan = ClassificationTaskPlan {
        task: cls.task.clone(),
        labels: cls.labels.clone(),
        threshold: cls.threshold,
        multi_label: cls.multi_label,
    };
    Ok(PlannedTask::Classification(plan))
}
impl TryFrom<LiftedSchema> for TaskPlan {
    type Error = TaskPlanError;

    /// Plans every task described by a lifted schema.
    ///
    /// Tasks are emitted in a fixed order:
    /// 1. the single entity task, when the schema declares any entities;
    /// 2. one task per relation;
    /// 3. one task per JSON structure;
    /// 4. one task per classification.
    ///
    /// Within each group the schema's own order is kept. The entity registry
    /// built from the schema becomes [`TaskPlan::entities`].
    ///
    /// # Errors
    ///
    /// Returns the first [`TaskPlanError`] reported by any task builder, for
    /// example [`TaskPlanError::MissingEntity`] when a referenced name is not a
    /// declared entity, or [`TaskPlanError::RelationMissingEndpoints`] for a
    /// relation without head/tail.
    fn try_from(schema: LiftedSchema) -> Result<Self, Self::Error> {
        let registry = build_entity_registry(&schema);
        // Upper bound on the task count: at most one entity task plus one
        // task per relation, structure and classification.
        let mut tasks = Vec::with_capacity(
            1 + schema.relations.len()
                + schema.json_structures.len()
                + schema.classifications.len(),
        );
        if let Some(entity_task) = entity_task_from_schema(&schema, &registry)? {
            tasks.push(entity_task);
        }
        for rel in &schema.relations {
            tasks.push(relation_task_from_relation(rel, &registry)?);
        }
        for js in &schema.json_structures {
            tasks.push(structure_task_from_structure(js, &registry)?);
        }
        for cls in &schema.classifications {
            tasks.push(classification_task_from_classification(cls, &registry)?);
        }
        Ok(TaskPlan {
            entities: registry,
            tasks,
        })
    }
}
impl TaskPlan {
    /// Iterates over the entity tasks, in plan order.
    pub fn entity_tasks(&self) -> impl Iterator<Item = &EntityTaskPlan> {
        self.tasks.iter().filter_map(|task| {
            if let PlannedTask::Entity(plan) = task {
                Some(plan)
            } else {
                None
            }
        })
    }

    /// Iterates over the relation tasks, in plan order.
    pub fn relation_tasks(&self) -> impl Iterator<Item = &RelationTaskPlan> {
        self.tasks.iter().filter_map(|task| {
            if let PlannedTask::Relation(plan) = task {
                Some(plan)
            } else {
                None
            }
        })
    }

    /// Iterates over the JSON-structure tasks, in plan order.
    pub fn structure_tasks(&self) -> impl Iterator<Item = &StructureTaskPlan> {
        self.tasks.iter().filter_map(|task| {
            if let PlannedTask::Structure(plan) = task {
                Some(plan)
            } else {
                None
            }
        })
    }

    /// Iterates over the classification tasks, in plan order.
    pub fn classification_tasks(&self) -> impl Iterator<Item = &ClassificationTaskPlan> {
        self.tasks.iter().filter_map(|task| {
            if let PlannedTask::Classification(plan) = task {
                Some(plan)
            } else {
                None
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;

    /// Runs a raw JSON schema through every pipeline stage
    /// (normalize -> expand -> lift -> plan), unwrapping each result.
    fn plan_from_json(json: &str) -> TaskPlan {
        let normalized = NormalizedSchema::from_json_str(json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        TaskPlan::try_from(lifted).unwrap()
    }

    #[test]
    fn task_plan_builds_entity_relation_structure_and_classification_tasks() {
        let plan = plan_from_json(
            r#"
{
"entities": [
"gene::str::0.9::gene symbol",
"disease::str::0.8::disease entity",
"patient",
"record",
"positive",
"negative",
"sentiment"
],
"relations": [
{ "associated_with": { "head": "gene", "tail": "disease" } }
],
"json_structures": [
{
"name": "Patient Record",
"status": {
"choices": ["positive", "negative"]
}
}
],
"classifications": [
{
"task": "sentiment",
"labels": ["positive", "negative"],
"multi_label": false
}
]
}
"#,
        );
        // Exactly one task of each family is expected for this schema.
        let counts = [
            plan.entity_tasks().count(),
            plan.relation_tasks().count(),
            plan.structure_tasks().count(),
            plan.classification_tasks().count(),
        ];
        assert_eq!(counts, [1, 1, 1, 1]);
    }

    #[test]
    fn task_plan_relation_requires_registered_entities() {
        let relation = LiftedRelation::EntityAcquired {
            name: ExpandedName::new("associated_with".to_string()),
            description: None,
            head: ExpandedName::new("gene".to_string()),
            tail: ExpandedName::new("disease".to_string()),
        };
        // No entities are registered, so the head endpoint is the first miss.
        let lifted = LiftedSchema {
            entities: BTreeMap::new(),
            json_structures: vec![],
            relations: vec![relation],
            classifications: vec![],
        };
        match TaskPlan::try_from(lifted).unwrap_err() {
            TaskPlanError::MissingEntity { name } => assert_eq!(name, "gene"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn task_plan_classification_uses_entity_refs() {
        let plan = plan_from_json(
            r#"
{
"entities": ["sentiment", "positive", "negative"],
"classifications": [
{
"task": "sentiment",
"labels": ["positive", "negative"],
"multi_label": true,
"threshold": 0.6
}
]
}
"#,
        );
        let classification = plan
            .classification_tasks()
            .next()
            .expect("schema declares one classification");
        assert_eq!(classification.task.as_str(), "sentiment");
        assert_eq!(classification.labels.len(), 2);
        assert_eq!(classification.labels[0].as_str(), "positive");
        assert_eq!(classification.labels[1].as_str(), "negative");
        assert_eq!(classification.threshold, Some(0.6));
        assert!(classification.multi_label);
    }
}