use crate::lifted::{LiftedClassification, LiftedJsonStructure, LiftedRelation, LiftedSchema};
use crate::normalized::ExpandedName;
use serde::Serialize;
use std::collections::{BTreeMap, BTreeSet};
use std::convert::TryFrom;
6
7#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
22pub enum TaskKind {
23 Entity,
24 Relation,
25 Structure,
26 Classification,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
30pub struct EntityTaskPlan {
31 pub entities: Vec<ExpandedName>,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
36pub struct RelationTaskPlan {
37 pub relation: ExpandedName,
39
40 pub head: ExpandedName,
42
43 pub tail: ExpandedName,
45
46 pub description: Option<String>,
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
51pub struct StructureChildPlan {
52 pub property: ExpandedName,
54
55 pub choices: Vec<ExpandedName>,
57
58 pub description: Option<String>,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
63pub struct StructureTaskPlan {
64 pub structure: ExpandedName,
66
67 pub children: Vec<StructureChildPlan>,
69}
70
71#[derive(Debug, Clone, PartialEq, Serialize)]
72pub struct ClassificationTaskPlan {
73 pub task: ExpandedName,
75
76 pub labels: Vec<ExpandedName>,
78
79 pub threshold: Option<f64>,
80 pub multi_label: bool,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize)]
84pub enum PlannedTask {
85 Entity(EntityTaskPlan),
86 Relation(RelationTaskPlan),
87 Structure(StructureTaskPlan),
88 Classification(ClassificationTaskPlan),
89}
90
91#[derive(Debug, Clone, PartialEq, Serialize, Default)]
92pub struct TaskPlan {
93 pub entities: BTreeMap<ExpandedName, TaskEntityDef>,
95
96 pub tasks: Vec<PlannedTask>,
98}
99
100#[derive(Debug, Clone, PartialEq, Serialize)]
101pub struct TaskEntityDef {
102 pub name: ExpandedName,
103 pub description: Option<String>,
104 pub threshold: Option<f64>,
105 pub dtype: Option<String>,
106}
107
108#[derive(Debug, thiserror::Error)]
109pub enum TaskPlanError {
110 #[error("relation task missing acquired endpoints: {name}")]
111 RelationMissingEndpoints { name: String },
112
113 #[error("referenced entity not found in registry: {name}")]
114 MissingEntity { name: String },
115
116 #[error("duplicate structure child {child} in structure {structure}")]
117 DuplicateStructureChild { structure: String, child: String },
118}
119
120fn dtype_to_string(dtype: &super::normalized::DType) -> String {
121 match dtype {
122 super::normalized::DType::String => "string".to_string(),
123 super::normalized::DType::Int => "int".to_string(),
124 super::normalized::DType::Float => "float".to_string(),
125 super::normalized::DType::Bool => "bool".to_string(),
126 }
127}
128
129fn build_entity_registry(schema: &LiftedSchema) -> BTreeMap<ExpandedName, TaskEntityDef> {
130 schema
131 .entities
132 .iter()
133 .map(|(name, entity)| {
134 (
135 name.clone(),
136 TaskEntityDef {
137 name: name.clone(),
138 description: entity.description.clone(),
139 threshold: entity.threshold,
140 dtype: entity.dtype.as_ref().map(dtype_to_string),
141 },
142 )
143 })
144 .collect()
145}
146
147fn ensure_entity_exists(
148 registry: &BTreeMap<ExpandedName, TaskEntityDef>,
149 name: &ExpandedName,
150) -> Result<(), TaskPlanError> {
151 if registry.contains_key(name) {
152 Ok(())
153 } else {
154 Err(TaskPlanError::MissingEntity {
155 name: name.to_string(),
156 })
157 }
158}
159
160fn entity_task_from_schema(
161 schema: &LiftedSchema,
162 registry: &BTreeMap<ExpandedName, TaskEntityDef>,
163) -> Result<Option<PlannedTask>, TaskPlanError> {
164 if schema.entities.is_empty() {
165 return Ok(None);
166 }
167
168 let mut entities: Vec<ExpandedName> = schema.entities.keys().cloned().collect();
169 entities.sort();
170
171 for entity in &entities {
172 ensure_entity_exists(registry, entity)?;
173 }
174
175 Ok(Some(PlannedTask::Entity(EntityTaskPlan { entities })))
176}
177
178fn relation_task_from_relation(
179 rel: &LiftedRelation,
180 registry: &BTreeMap<ExpandedName, TaskEntityDef>,
181) -> Result<PlannedTask, TaskPlanError> {
182 match rel {
183 LiftedRelation::EmptyAcquired { name, description } => {
184 Err(TaskPlanError::RelationMissingEndpoints {
185 name: format!(
186 "{}{}",
187 name,
188 description
189 .as_ref()
190 .map(|d| format!(" ({d})"))
191 .unwrap_or_default()
192 ),
193 })
194 }
195 LiftedRelation::EntityAcquired {
196 name,
197 description,
198 head,
199 tail,
200 } => {
201 ensure_entity_exists(registry, head)?;
202 ensure_entity_exists(registry, tail)?;
203
204 Ok(PlannedTask::Relation(RelationTaskPlan {
205 relation: name.clone(),
206 head: head.clone(),
207 tail: tail.clone(),
208 description: description.clone(),
209 }))
210 }
211 }
212}
213
214fn structure_task_from_structure(
215 js: &LiftedJsonStructure,
216 registry: &BTreeMap<ExpandedName, TaskEntityDef>,
217) -> Result<PlannedTask, TaskPlanError> {
218 let mut children = Vec::with_capacity(js.props.len());
219
220 for (property, prop) in &js.props {
221 for choice in &prop.choices {
222 ensure_entity_exists(registry, choice)?;
223 }
224
225 children.push(StructureChildPlan {
226 property: property.clone(),
227 choices: prop.choices.clone(),
228 description: prop.description.clone(),
229 });
230 }
231
232 Ok(PlannedTask::Structure(StructureTaskPlan {
233 structure: js.name.clone(),
234 children,
235 }))
236}
237
238fn classification_task_from_classification(
239 cls: &LiftedClassification,
240 registry: &BTreeMap<ExpandedName, TaskEntityDef>,
241) -> Result<PlannedTask, TaskPlanError> {
242 ensure_entity_exists(registry, &cls.task)?;
243 for label in &cls.labels {
244 ensure_entity_exists(registry, label)?;
245 }
246
247 Ok(PlannedTask::Classification(ClassificationTaskPlan {
248 task: cls.task.clone(),
249 labels: cls.labels.clone(),
250 threshold: cls.threshold,
251 multi_label: cls.multi_label,
252 }))
253}
254
255impl TryFrom<LiftedSchema> for TaskPlan {
256 type Error = TaskPlanError;
257
258 fn try_from(schema: LiftedSchema) -> Result<Self, Self::Error> {
259 let registry = build_entity_registry(&schema);
260
261 let mut tasks = Vec::new();
262
263 if let Some(entity_task) = entity_task_from_schema(&schema, ®istry)? {
264 tasks.push(entity_task);
265 }
266
267 for rel in &schema.relations {
268 tasks.push(relation_task_from_relation(rel, ®istry)?);
269 }
270
271 for js in &schema.json_structures {
272 tasks.push(structure_task_from_structure(js, ®istry)?);
273 }
274
275 for cls in &schema.classifications {
276 tasks.push(classification_task_from_classification(cls, ®istry)?);
277 }
278
279 Ok(TaskPlan {
280 entities: registry,
281 tasks,
282 })
283 }
284}
285
286impl TaskPlan {
287 pub fn entity_tasks(&self) -> impl Iterator<Item = &EntityTaskPlan> {
288 self.tasks.iter().filter_map(|t| match t {
289 PlannedTask::Entity(x) => Some(x),
290 _ => None,
291 })
292 }
293
294 pub fn relation_tasks(&self) -> impl Iterator<Item = &RelationTaskPlan> {
295 self.tasks.iter().filter_map(|t| match t {
296 PlannedTask::Relation(x) => Some(x),
297 _ => None,
298 })
299 }
300
301 pub fn structure_tasks(&self) -> impl Iterator<Item = &StructureTaskPlan> {
302 self.tasks.iter().filter_map(|t| match t {
303 PlannedTask::Structure(x) => Some(x),
304 _ => None,
305 })
306 }
307
308 pub fn classification_tasks(&self) -> impl Iterator<Item = &ClassificationTaskPlan> {
309 self.tasks.iter().filter_map(|t| match t {
310 PlannedTask::Classification(x) => Some(x),
311 _ => None,
312 })
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;

    /// A schema exercising all four sections yields exactly one task of each kind.
    #[test]
    fn task_plan_builds_entity_relation_structure_and_classification_tasks() {
        let s = r#"
        {
            "entities": [
                "gene::str::0.9::gene symbol",
                "disease::str::0.8::disease entity",
                "patient",
                "record",
                "positive",
                "negative",
                "sentiment"
            ],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ],
            "json_structures": [
                {
                    "name": "Patient Record",
                    "status": {
                        "choices": ["positive", "negative"]
                    }
                }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "multi_label": false
                }
            ]
        }
        "#;

        // Run the full pipeline: normalized -> expanded -> lifted -> plan.
        let s2 = NormalizedSchema::from_json_str(s).unwrap();
        let s3 = ExpandedSchema::try_from(s2).unwrap();
        let s4 = LiftedSchema::try_from(s3).unwrap();
        let plan = TaskPlan::try_from(s4).unwrap();

        assert_eq!(plan.entity_tasks().count(), 1);
        assert_eq!(plan.relation_tasks().count(), 1);
        assert_eq!(plan.structure_tasks().count(), 1);
        assert_eq!(plan.classification_tasks().count(), 1);
    }

    /// A relation whose endpoints are not registered entities is rejected,
    /// and the head is the first endpoint reported.
    #[test]
    fn task_plan_relation_requires_registered_entities() {
        let s4 = LiftedSchema {
            entities: BTreeMap::new(),
            json_structures: vec![],
            relations: vec![LiftedRelation::EntityAcquired {
                name: ExpandedName::new("associated_with".to_string()),
                description: None,
                head: ExpandedName::new("gene".to_string()),
                tail: ExpandedName::new("disease".to_string()),
            }],
            classifications: vec![],
        };

        let err = TaskPlan::try_from(s4).unwrap_err();
        match err {
            TaskPlanError::MissingEntity { name } => assert_eq!(name, "gene"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Classification task, labels, threshold and multi-label flag survive planning.
    #[test]
    fn task_plan_classification_uses_entity_refs() {
        let s = r#"
        {
            "entities": ["sentiment", "positive", "negative"],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "multi_label": true,
                    "threshold": 0.6
                }
            ]
        }
        "#;

        let s2 = NormalizedSchema::from_json_str(s).unwrap();
        let s3 = ExpandedSchema::try_from(s2).unwrap();
        let s4 = LiftedSchema::try_from(s3).unwrap();
        let plan = TaskPlan::try_from(s4).unwrap();

        let cls = plan.classification_tasks().next().unwrap();
        assert_eq!(cls.task.as_str(), "sentiment");
        assert_eq!(cls.labels.len(), 2);
        assert_eq!(cls.labels[0].as_str(), "positive");
        assert_eq!(cls.labels[1].as_str(), "negative");
        assert_eq!(cls.threshold, Some(0.6));
        assert!(cls.multi_label);
    }
}