// ie_schema/task_plan.rs

1use crate::lifted::{LiftedClassification, LiftedJsonStructure, LiftedRelation, LiftedSchema};
2use crate::normalized::ExpandedName;
3use serde::Serialize;
4use std::collections::BTreeMap;
5use std::convert::TryFrom;
6
7/// Semantic planning layer between LiftedSchema and prompt/token generation.
8///
9/// Responsibilities:
10/// - compile lifted schema into explicit task units
11/// - preserve deterministic ordering
12/// - keep task types separate
13/// - avoid tokenizer / tensor concerns
14///
15/// Non-responsibilities:
16/// - vocabulary lookup
17/// - token IDs
18/// - string formatting for final prompts
19/// - tensor construction
20
21#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
22pub enum TaskKind {
23    Entity,
24    Relation,
25    Structure,
26    Classification,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
30pub struct EntityTaskPlan {
31    /// Entities to extract directly from text.
32    pub entities: Vec<ExpandedName>,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
36pub struct RelationTaskPlan {
37    /// Relation/task name.
38    pub relation: ExpandedName,
39
40    /// Head entity reference.
41    pub head: ExpandedName,
42
43    /// Tail entity reference.
44    pub tail: ExpandedName,
45
46    /// Optional human-readable description carried forward for prompt rendering.
47    pub description: Option<String>,
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
51pub struct StructureChildPlan {
52    /// Property name in the structure.
53    pub property: ExpandedName,
54
55    /// Optional choices for closed-set properties.
56    pub choices: Vec<ExpandedName>,
57
58    /// Optional description to support later prompt rendering.
59    pub description: Option<String>,
60}
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
63pub struct StructureTaskPlan {
64    /// Structure/task name.
65    pub structure: ExpandedName,
66
67    /// Ordered child/property definitions.
68    pub children: Vec<StructureChildPlan>,
69}
70
71#[derive(Debug, Clone, PartialEq, Serialize)]
72pub struct ClassificationTaskPlan {
73    /// Classification task entity.
74    pub task: ExpandedName,
75
76    /// Allowed label entities.
77    pub labels: Vec<ExpandedName>,
78
79    pub threshold: Option<f64>,
80    pub multi_label: bool,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize)]
84pub enum PlannedTask {
85    Entity(EntityTaskPlan),
86    Relation(RelationTaskPlan),
87    Structure(StructureTaskPlan),
88    Classification(ClassificationTaskPlan),
89}
90
91#[derive(Debug, Clone, PartialEq, Serialize, Default)]
92pub struct TaskPlan {
93    /// Canonical entity registry from LiftedSchema.
94    pub entities: BTreeMap<ExpandedName, TaskEntityDef>,
95
96    /// Planned tasks in deterministic execution / rendering order.
97    pub tasks: Vec<PlannedTask>,
98}
99
100#[derive(Debug, Clone, PartialEq, Serialize)]
101pub struct TaskEntityDef {
102    pub name: ExpandedName,
103    pub description: Option<String>,
104    pub threshold: Option<f64>,
105    pub dtype: Option<String>,
106}
107
108#[derive(Debug, thiserror::Error)]
109pub enum TaskPlanError {
110    #[error("relation task missing acquired endpoints: {name}")]
111    RelationMissingEndpoints { name: String },
112
113    #[error("referenced entity not found in registry: {name}")]
114    MissingEntity { name: String },
115
116    #[error("duplicate structure child {child} in structure {structure}")]
117    DuplicateStructureChild { structure: String, child: String },
118}
119
120fn dtype_to_string(dtype: &super::normalized::DType) -> String {
121    match dtype {
122        super::normalized::DType::String => "string".to_string(),
123        super::normalized::DType::Int => "int".to_string(),
124        super::normalized::DType::Float => "float".to_string(),
125        super::normalized::DType::Bool => "bool".to_string(),
126    }
127}
128
129fn build_entity_registry(schema: &LiftedSchema) -> BTreeMap<ExpandedName, TaskEntityDef> {
130    schema
131        .entities
132        .iter()
133        .map(|(name, entity)| {
134            (
135                name.clone(),
136                TaskEntityDef {
137                    name: name.clone(),
138                    description: entity.description.clone(),
139                    threshold: entity.threshold,
140                    dtype: entity.dtype.as_ref().map(dtype_to_string),
141                },
142            )
143        })
144        .collect()
145}
146
147fn ensure_entity_exists(
148    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
149    name: &ExpandedName,
150) -> Result<(), TaskPlanError> {
151    if registry.contains_key(name) {
152        Ok(())
153    } else {
154        Err(TaskPlanError::MissingEntity {
155            name: name.to_string(),
156        })
157    }
158}
159
160fn entity_task_from_schema(
161    schema: &LiftedSchema,
162    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
163) -> Result<Option<PlannedTask>, TaskPlanError> {
164    if schema.entities.is_empty() {
165        return Ok(None);
166    }
167
168    let mut entities: Vec<ExpandedName> = schema.entities.keys().cloned().collect();
169    entities.sort();
170
171    for entity in &entities {
172        ensure_entity_exists(registry, entity)?;
173    }
174
175    Ok(Some(PlannedTask::Entity(EntityTaskPlan { entities })))
176}
177
178fn relation_task_from_relation(
179    rel: &LiftedRelation,
180    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
181) -> Result<PlannedTask, TaskPlanError> {
182    match rel {
183        LiftedRelation::EmptyAcquired { name, description } => {
184            Err(TaskPlanError::RelationMissingEndpoints {
185                name: format!(
186                    "{}{}",
187                    name,
188                    description
189                        .as_ref()
190                        .map(|d| format!(" ({d})"))
191                        .unwrap_or_default()
192                ),
193            })
194        }
195        LiftedRelation::EntityAcquired {
196            name,
197            description,
198            head,
199            tail,
200        } => {
201            ensure_entity_exists(registry, head)?;
202            ensure_entity_exists(registry, tail)?;
203
204            Ok(PlannedTask::Relation(RelationTaskPlan {
205                relation: name.clone(),
206                head: head.clone(),
207                tail: tail.clone(),
208                description: description.clone(),
209            }))
210        }
211    }
212}
213
214fn structure_task_from_structure(
215    js: &LiftedJsonStructure,
216    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
217) -> Result<PlannedTask, TaskPlanError> {
218    let mut children = Vec::with_capacity(js.props.len());
219
220    for (property, prop) in &js.props {
221        for choice in &prop.choices {
222            ensure_entity_exists(registry, choice)?;
223        }
224
225        children.push(StructureChildPlan {
226            property: property.clone(),
227            choices: prop.choices.clone(),
228            description: prop.description.clone(),
229        });
230    }
231
232    Ok(PlannedTask::Structure(StructureTaskPlan {
233        structure: js.name.clone(),
234        children,
235    }))
236}
237
238fn classification_task_from_classification(
239    cls: &LiftedClassification,
240    registry: &BTreeMap<ExpandedName, TaskEntityDef>,
241) -> Result<PlannedTask, TaskPlanError> {
242    ensure_entity_exists(registry, &cls.task)?;
243    for label in &cls.labels {
244        ensure_entity_exists(registry, label)?;
245    }
246
247    Ok(PlannedTask::Classification(ClassificationTaskPlan {
248        task: cls.task.clone(),
249        labels: cls.labels.clone(),
250        threshold: cls.threshold,
251        multi_label: cls.multi_label,
252    }))
253}
254
255impl TryFrom<LiftedSchema> for TaskPlan {
256    type Error = TaskPlanError;
257
258    fn try_from(schema: LiftedSchema) -> Result<Self, Self::Error> {
259        let registry = build_entity_registry(&schema);
260
261        let mut tasks = Vec::new();
262
263        if let Some(entity_task) = entity_task_from_schema(&schema, &registry)? {
264            tasks.push(entity_task);
265        }
266
267        for rel in &schema.relations {
268            tasks.push(relation_task_from_relation(rel, &registry)?);
269        }
270
271        for js in &schema.json_structures {
272            tasks.push(structure_task_from_structure(js, &registry)?);
273        }
274
275        for cls in &schema.classifications {
276            tasks.push(classification_task_from_classification(cls, &registry)?);
277        }
278
279        Ok(TaskPlan {
280            entities: registry,
281            tasks,
282        })
283    }
284}
285
286impl TaskPlan {
287    pub fn entity_tasks(&self) -> impl Iterator<Item = &EntityTaskPlan> {
288        self.tasks.iter().filter_map(|t| match t {
289            PlannedTask::Entity(x) => Some(x),
290            _ => None,
291        })
292    }
293
294    pub fn relation_tasks(&self) -> impl Iterator<Item = &RelationTaskPlan> {
295        self.tasks.iter().filter_map(|t| match t {
296            PlannedTask::Relation(x) => Some(x),
297            _ => None,
298        })
299    }
300
301    pub fn structure_tasks(&self) -> impl Iterator<Item = &StructureTaskPlan> {
302        self.tasks.iter().filter_map(|t| match t {
303            PlannedTask::Structure(x) => Some(x),
304            _ => None,
305        })
306    }
307
308    pub fn classification_tasks(&self) -> impl Iterator<Item = &ClassificationTaskPlan> {
309        self.tasks.iter().filter_map(|t| match t {
310            PlannedTask::Classification(x) => Some(x),
311            _ => None,
312        })
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;

    /// Runs the normalize -> expand -> lift -> plan pipeline on a JSON schema,
    /// panicking if any stage fails.
    fn plan_from_json(json: &str) -> TaskPlan {
        let normalized = NormalizedSchema::from_json_str(json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        TaskPlan::try_from(lifted).unwrap()
    }

    /// A schema using all four sections yields exactly one task of each kind.
    #[test]
    fn task_plan_builds_entity_relation_structure_and_classification_tasks() {
        let schema_json = r#"
        {
            "entities": [
                "gene::str::0.9::gene symbol",
                "disease::str::0.8::disease entity",
                "patient",
                "record",
                "positive",
                "negative",
                "sentiment"
            ],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ],
            "json_structures": [
                {
                    "name": "Patient Record",
                    "status": {
                        "choices": ["positive", "negative"]
                    }
                }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "multi_label": false
                }
            ]
        }
        "#;

        let plan = plan_from_json(schema_json);

        let counts = (
            plan.entity_tasks().count(),
            plan.relation_tasks().count(),
            plan.structure_tasks().count(),
            plan.classification_tasks().count(),
        );
        assert_eq!(counts, (1, 1, 1, 1));
    }

    /// A relation whose endpoints are not registered is rejected, and the head
    /// is the endpoint that gets reported.
    #[test]
    fn task_plan_relation_requires_registered_entities() {
        let relation = LiftedRelation::EntityAcquired {
            name: ExpandedName::new("associated_with".to_string()),
            description: None,
            head: ExpandedName::new("gene".to_string()),
            tail: ExpandedName::new("disease".to_string()),
        };
        let lifted = LiftedSchema {
            entities: BTreeMap::new(),
            json_structures: vec![],
            relations: vec![relation],
            classifications: vec![],
        };

        match TaskPlan::try_from(lifted).unwrap_err() {
            TaskPlanError::MissingEntity { name } => assert_eq!(name, "gene"),
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Classification settings and label order survive planning unchanged.
    #[test]
    fn task_plan_classification_uses_entity_refs() {
        let schema_json = r#"
        {
            "entities": ["sentiment", "positive", "negative"],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"],
                    "multi_label": true,
                    "threshold": 0.6
                }
            ]
        }
        "#;

        let plan = plan_from_json(schema_json);
        let classification = plan.classification_tasks().next().unwrap();

        assert_eq!(classification.task.as_str(), "sentiment");
        assert_eq!(classification.labels.len(), 2);
        assert_eq!(classification.labels[0].as_str(), "positive");
        assert_eq!(classification.labels[1].as_str(), "negative");
        assert_eq!(classification.threshold, Some(0.6));
        assert!(classification.multi_label);
    }
}