1use crate::normalized::ExpandedName;
2use crate::task_plan::{
3 ClassificationTaskPlan, EntityTaskPlan, PlannedTask, RelationTaskPlan, StructureTaskPlan,
4 TaskPlan,
5};
6use serde::Serialize;
7use std::convert::TryFrom;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
10pub enum PromptSpecialToken {
11 Prompt, Entity, Child, Label, Separator, }
17
18impl PromptSpecialToken {
19 pub fn as_str(&self) -> &'static str {
20 match self {
21 Self::Prompt => "[P]",
22 Self::Entity => "[E]",
23 Self::Child => "[C]",
24 Self::Label => "[L]",
25 Self::Separator => "[SEP]",
26 }
27 }
28}
29
30#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
31pub enum PromptAtom {
32 Special(PromptSpecialToken),
33 Text(String),
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
37pub enum PromptTaskKind {
38 Entity,
39 Relation,
40 Structure,
41 Classification,
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
45pub struct PromptTaskPlan {
46 pub kind: PromptTaskKind,
47 pub name: ExpandedName,
48 pub atoms: Vec<PromptAtom>,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
52pub struct PromptPlan {
53 pub tasks: Vec<PromptTaskPlan>,
54 pub flat_prompt: Vec<PromptAtom>,
55}
56
57#[derive(Debug, thiserror::Error)]
58pub enum PromptPlanError {
59 #[error("empty entity task")]
60 EmptyEntityTask,
61
62 #[error("empty structure task: {name}")]
63 EmptyStructureTask { name: String },
64
65 #[error("empty classification task: {name}")]
66 EmptyClassificationTask { name: String },
67}
68
69fn entity_task_to_prompt(task: &EntityTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
70 if task.entities.is_empty() {
71 return Err(PromptPlanError::EmptyEntityTask);
72 }
73
74 let mut atoms = Vec::new();
75 atoms.push(PromptAtom::Special(PromptSpecialToken::Prompt));
76 atoms.push(PromptAtom::Text("entities".to_string()));
77
78 for entity in &task.entities {
79 atoms.push(PromptAtom::Special(PromptSpecialToken::Entity));
80 atoms.push(PromptAtom::Text(entity.to_string()));
81 }
82
83 atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));
84
85 Ok(PromptTaskPlan {
86 kind: PromptTaskKind::Entity,
87 name: ExpandedName::new("entities".to_string()),
88 atoms,
89 })
90}
91
92fn relation_task_to_prompt(task: &RelationTaskPlan) -> PromptTaskPlan {
93 let atoms = vec![
94 PromptAtom::Special(PromptSpecialToken::Prompt),
95 PromptAtom::Text(task.relation.to_string()),
96 PromptAtom::Special(PromptSpecialToken::Child),
97 PromptAtom::Text(task.head.to_string()),
98 PromptAtom::Special(PromptSpecialToken::Child),
99 PromptAtom::Text(task.tail.to_string()),
100 PromptAtom::Special(PromptSpecialToken::Separator),
101 ];
102
103 PromptTaskPlan {
104 kind: PromptTaskKind::Relation,
105 name: task.relation.clone(),
106 atoms,
107 }
108}
109
110fn structure_task_to_prompt(task: &StructureTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
111 if task.children.is_empty() {
112 return Err(PromptPlanError::EmptyStructureTask {
113 name: task.structure.to_string(),
114 });
115 }
116
117 let mut atoms = Vec::new();
118 atoms.push(PromptAtom::Special(PromptSpecialToken::Prompt));
119 atoms.push(PromptAtom::Text(task.structure.to_string()));
120
121 for child in &task.children {
122 atoms.push(PromptAtom::Special(PromptSpecialToken::Child));
123 atoms.push(PromptAtom::Text(child.property.to_string()));
124 }
125
126 atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));
127
128 Ok(PromptTaskPlan {
129 kind: PromptTaskKind::Structure,
130 name: task.structure.clone(),
131 atoms,
132 })
133}
134
135fn classification_task_to_prompt(
136 task: &ClassificationTaskPlan,
137) -> Result<PromptTaskPlan, PromptPlanError> {
138 if task.labels.is_empty() {
139 return Err(PromptPlanError::EmptyClassificationTask {
140 name: task.task.to_string(),
141 });
142 }
143
144 let mut atoms = Vec::new();
145 atoms.push(PromptAtom::Special(PromptSpecialToken::Prompt));
146 atoms.push(PromptAtom::Text(task.task.to_string()));
147
148 for label in &task.labels {
149 atoms.push(PromptAtom::Special(PromptSpecialToken::Label));
150 atoms.push(PromptAtom::Text(label.to_string()));
151 }
152
153 atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));
154
155 Ok(PromptTaskPlan {
156 kind: PromptTaskKind::Classification,
157 name: task.task.clone(),
158 atoms,
159 })
160}
161
162impl TryFrom<TaskPlan> for PromptPlan {
163 type Error = PromptPlanError;
164
165 fn try_from(value: TaskPlan) -> Result<Self, Self::Error> {
166 let mut tasks = Vec::new();
167
168 for task in &value.tasks {
169 let planned = match task {
170 PlannedTask::Entity(x) => entity_task_to_prompt(x)?,
171 PlannedTask::Relation(x) => relation_task_to_prompt(x),
172 PlannedTask::Structure(x) => structure_task_to_prompt(x)?,
173 PlannedTask::Classification(x) => classification_task_to_prompt(x)?,
174 };
175 tasks.push(planned);
176 }
177
178 let mut flat_prompt = Vec::new();
179 for task in &tasks {
180 flat_prompt.extend(task.atoms.clone());
181 }
182
183 Ok(Self { tasks, flat_prompt })
184 }
185}
186
187impl PromptPlan {
188 pub fn render_debug_string(&self) -> String {
189 self.flat_prompt
190 .iter()
191 .map(|atom| match atom {
192 PromptAtom::Special(tok) => tok.as_str().to_string(),
193 PromptAtom::Text(s) => s.clone(),
194 })
195 .collect::<Vec<_>>()
196 .join(" ")
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::task_plan::TaskPlan;

    /// Pushes a JSON schema through every pipeline stage and returns the prompt plan.
    ///
    /// Stages: normalized -> expanded -> lifted -> task plan -> prompt plan.
    fn prompt_plan_from_json(schema_json: &str) -> PromptPlan {
        let normalized = NormalizedSchema::from_json_str(schema_json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        PromptPlan::try_from(task_plan).unwrap()
    }

    #[test]
    fn prompt_plan_renders_expected_debug_string() {
        let schema_json = r#"
        {
            "entities": ["gene", "disease", "sentiment", "positive", "negative"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"]
                }
            ]
        }
        "#;

        let rendered = prompt_plan_from_json(schema_json).render_debug_string();

        // NOTE(review): this expects "disease" immediately after the entity header even
        // though "gene" is listed first in the JSON. Presumably an earlier stage reorders
        // the entities; confirm.
        assert!(rendered.contains("[P] entities [E] disease"));
        assert!(rendered.contains("[P] associated_with [C] gene [C] disease [SEP]"));
        assert!(rendered.contains("[P] sentiment [L] positive [L] negative [SEP]"));
    }
}