Skip to main content

ie_schema/
prompt_plan.rs

1use crate::normalized::ExpandedName;
2use crate::task_plan::{
3    ClassificationTaskPlan, EntityTaskPlan, PlannedTask, RelationTaskPlan, StructureTaskPlan,
4    TaskPlan,
5};
6use serde::Serialize;
7use std::convert::TryFrom;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
10pub enum PromptSpecialToken {
11    Prompt,    // [P]
12    Entity,    // [E]
13    Child,     // [C]
14    Label,     // [L]
15    Separator, // [SEP]
16}
17
18impl PromptSpecialToken {
19    pub fn as_str(&self) -> &'static str {
20        match self {
21            Self::Prompt => "[P]",
22            Self::Entity => "[E]",
23            Self::Child => "[C]",
24            Self::Label => "[L]",
25            Self::Separator => "[SEP]",
26        }
27    }
28}
29
30#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
31pub enum PromptAtom {
32    Special(PromptSpecialToken),
33    Text(String),
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
37pub enum PromptTaskKind {
38    Entity,
39    Relation,
40    Structure,
41    Classification,
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
45pub struct PromptTaskPlan {
46    pub kind: PromptTaskKind,
47    pub name: ExpandedName,
48    pub atoms: Vec<PromptAtom>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
52pub struct PromptPlan {
53    pub tasks: Vec<PromptTaskPlan>,
54    pub flat_prompt: Vec<PromptAtom>,
55}
56
57#[derive(Debug, thiserror::Error)]
58pub enum PromptPlanError {
59    #[error("empty entity task")]
60    EmptyEntityTask,
61
62    #[error("empty structure task: {name}")]
63    EmptyStructureTask { name: String },
64
65    #[error("empty classification task: {name}")]
66    EmptyClassificationTask { name: String },
67}
68
69fn entity_task_to_prompt(task: &EntityTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
70    if task.entities.is_empty() {
71        return Err(PromptPlanError::EmptyEntityTask);
72    }
73
74    let mut atoms = Vec::new();
75    atoms.push(PromptAtom::Special(PromptSpecialToken::Prompt));
76    atoms.push(PromptAtom::Text("entities".to_string()));
77
78    for entity in &task.entities {
79        atoms.push(PromptAtom::Special(PromptSpecialToken::Entity));
80        atoms.push(PromptAtom::Text(entity.to_string()));
81    }
82
83    atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));
84
85    Ok(PromptTaskPlan {
86        kind: PromptTaskKind::Entity,
87        name: ExpandedName::new("entities".to_string()),
88        atoms,
89    })
90}
91
92fn relation_task_to_prompt(task: &RelationTaskPlan) -> PromptTaskPlan {
93    let atoms = vec![
94        PromptAtom::Special(PromptSpecialToken::Prompt),
95        PromptAtom::Text(task.relation.to_string()),
96        PromptAtom::Special(PromptSpecialToken::Child),
97        PromptAtom::Text(task.head.to_string()),
98        PromptAtom::Special(PromptSpecialToken::Child),
99        PromptAtom::Text(task.tail.to_string()),
100        PromptAtom::Special(PromptSpecialToken::Separator),
101    ];
102
103    PromptTaskPlan {
104        kind: PromptTaskKind::Relation,
105        name: task.relation.clone(),
106        atoms,
107    }
108}
109
110fn structure_task_to_prompt(task: &StructureTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
111    if task.children.is_empty() {
112        return Err(PromptPlanError::EmptyStructureTask {
113            name: task.structure.to_string(),
114        });
115    }
116
117    let mut atoms = Vec::new();
118    atoms.push(PromptAtom::Special(PromptSpecialToken::Prompt));
119    atoms.push(PromptAtom::Text(task.structure.to_string()));
120
121    for child in &task.children {
122        atoms.push(PromptAtom::Special(PromptSpecialToken::Child));
123        atoms.push(PromptAtom::Text(child.property.to_string()));
124    }
125
126    atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));
127
128    Ok(PromptTaskPlan {
129        kind: PromptTaskKind::Structure,
130        name: task.structure.clone(),
131        atoms,
132    })
133}
134
135fn classification_task_to_prompt(
136    task: &ClassificationTaskPlan,
137) -> Result<PromptTaskPlan, PromptPlanError> {
138    if task.labels.is_empty() {
139        return Err(PromptPlanError::EmptyClassificationTask {
140            name: task.task.to_string(),
141        });
142    }
143
144    let mut atoms = Vec::new();
145    atoms.push(PromptAtom::Special(PromptSpecialToken::Prompt));
146    atoms.push(PromptAtom::Text(task.task.to_string()));
147
148    for label in &task.labels {
149        atoms.push(PromptAtom::Special(PromptSpecialToken::Label));
150        atoms.push(PromptAtom::Text(label.to_string()));
151    }
152
153    atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));
154
155    Ok(PromptTaskPlan {
156        kind: PromptTaskKind::Classification,
157        name: task.task.clone(),
158        atoms,
159    })
160}
161
162impl TryFrom<TaskPlan> for PromptPlan {
163    type Error = PromptPlanError;
164
165    fn try_from(value: TaskPlan) -> Result<Self, Self::Error> {
166        let mut tasks = Vec::new();
167
168        for task in &value.tasks {
169            let planned = match task {
170                PlannedTask::Entity(x) => entity_task_to_prompt(x)?,
171                PlannedTask::Relation(x) => relation_task_to_prompt(x),
172                PlannedTask::Structure(x) => structure_task_to_prompt(x)?,
173                PlannedTask::Classification(x) => classification_task_to_prompt(x)?,
174            };
175            tasks.push(planned);
176        }
177
178        let mut flat_prompt = Vec::new();
179        for task in &tasks {
180            flat_prompt.extend(task.atoms.clone());
181        }
182
183        Ok(Self { tasks, flat_prompt })
184    }
185}
186
187impl PromptPlan {
188    pub fn render_debug_string(&self) -> String {
189        self.flat_prompt
190            .iter()
191            .map(|atom| match atom {
192                PromptAtom::Special(tok) => tok.as_str().to_string(),
193                PromptAtom::Text(s) => s.clone(),
194            })
195            .collect::<Vec<_>>()
196            .join(" ")
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::task_plan::TaskPlan;

    /// Runs a small JSON schema through every pipeline stage up to the prompt
    /// plan and checks that the rendered debug string contains the expected
    /// entity, relation and classification fragments.
    #[test]
    fn prompt_plan_renders_expected_debug_string() {
        let schema_json = r#"
        {
            "entities": ["gene", "disease", "sentiment", "positive", "negative"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"]
                }
            ]
        }
        "#;

        // Each stage consumes the output of the previous one.
        let normalized = NormalizedSchema::from_json_str(schema_json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        let prompt_plan = PromptPlan::try_from(task_plan).unwrap();

        let rendered = prompt_plan.render_debug_string();
        assert!(rendered.contains("[P] entities [E] disease"));
        assert!(rendered.contains("[P] associated_with [C] gene [C] disease [SEP]"));
        assert!(rendered.contains("[P] sentiment [L] positive [L] negative [SEP]"));
    }
}