use crate::normalized::ExpandedName;
use crate::task_plan::{
ClassificationTaskPlan, EntityTaskPlan, PlannedTask, RelationTaskPlan, StructureTaskPlan,
TaskPlan,
};
use serde::Serialize;
use std::convert::TryFrom;
/// Marker tokens that delimit the parts of a prompt.
///
/// Each variant renders to a fixed bracketed string through
/// [`PromptSpecialToken::as_str`]:
///
/// * `Prompt` (`[P]`) opens a task and is followed by the task name.
/// * `Entity` (`[E]`) precedes each entity name of an entity task.
/// * `Child` (`[C]`) precedes the head and tail of a relation task and each
///   child property of a structure task.
/// * `Label` (`[L]`) precedes each label of a classification task.
/// * `Separator` (`[SEP]`) closes a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum PromptSpecialToken {
Prompt, Entity, Child, Label, Separator, }
impl PromptSpecialToken {
pub fn as_str(&self) -> &'static str {
match self {
Self::Prompt => "[P]",
Self::Entity => "[E]",
Self::Child => "[C]",
Self::Label => "[L]",
Self::Separator => "[SEP]",
}
}
}
/// One element of a prompt sequence.
///
/// A prompt is a flat list of atoms; each atom is either a structural marker
/// or a piece of free text (a task name, entity name, label, ...).
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum PromptAtom {
/// A structural marker such as `[P]` or `[SEP]`.
Special(PromptSpecialToken),
/// Literal text placed between markers.
Text(String),
}
/// The category of task a [`PromptTaskPlan`] was built from.
///
/// There is one variant per `PlannedTask` variant handled by
/// `PromptPlan::try_from`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum PromptTaskKind {
/// Built from an entity task plan.
Entity,
/// Built from a relation task plan.
Relation,
/// Built from a structure task plan.
Structure,
/// Built from a classification task plan.
Classification,
}
/// The prompt form of a single planned task.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct PromptTaskPlan {
/// Which category of task this prompt was built from.
pub kind: PromptTaskKind,
/// Name of the task. Entity tasks use the fixed name `entities`; the other
/// kinds reuse the relation, structure or classification task name.
pub name: ExpandedName,
/// The task's atoms: a `[P]` marker and the task name, the task's
/// marker/text pairs, and a closing `[SEP]` marker.
pub atoms: Vec<PromptAtom>,
}
/// A full prompt plan: the per-task prompts plus their flattened concatenation.
///
/// The derived `Default` yields a plan with no tasks and an empty prompt.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
pub struct PromptPlan {
/// The prompt tasks, in the order they were planned.
pub tasks: Vec<PromptTaskPlan>,
/// When built via `TryFrom<TaskPlan>`, the atoms of every entry in `tasks`
/// concatenated in task order.
pub flat_prompt: Vec<PromptAtom>,
}
/// Errors raised while turning a task plan into a prompt plan.
///
/// Every variant reports a task that has nothing to enumerate.
#[derive(Debug, thiserror::Error)]
pub enum PromptPlanError {
/// The entity task lists no entities.
#[error("empty entity task")]
EmptyEntityTask,
/// The structure task called `name` has no children.
#[error("empty structure task: {name}")]
EmptyStructureTask { name: String },
/// The classification task called `name` has no labels.
#[error("empty classification task: {name}")]
EmptyClassificationTask { name: String },
}
/// Builds the prompt task that enumerates every entity of `task`.
///
/// The atom sequence is `[P] entities`, then `[E] <entity>` for each entry of
/// `task.entities` in iteration order, then `[SEP]`. The resulting task is
/// always named `entities`.
///
/// # Errors
///
/// Returns [`PromptPlanError::EmptyEntityTask`] when `task.entities` is empty.
fn entity_task_to_prompt(task: &EntityTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
    if task.entities.is_empty() {
        return Err(PromptPlanError::EmptyEntityTask);
    }

    // Header: prompt marker followed by the fixed task name.
    let mut sequence = vec![
        PromptAtom::Special(PromptSpecialToken::Prompt),
        PromptAtom::Text(String::from("entities")),
    ];

    // One marker/text pair per entity.
    for item in &task.entities {
        sequence.extend([
            PromptAtom::Special(PromptSpecialToken::Entity),
            PromptAtom::Text(item.to_string()),
        ]);
    }

    sequence.push(PromptAtom::Special(PromptSpecialToken::Separator));

    Ok(PromptTaskPlan {
        kind: PromptTaskKind::Entity,
        name: ExpandedName::new(String::from("entities")),
        atoms: sequence,
    })
}
/// Builds the prompt task for a single relation.
///
/// The atom sequence is `[P] <relation> [C] <head> [C] <tail> [SEP]`, and the
/// resulting task is named after the relation. This conversion cannot fail.
fn relation_task_to_prompt(task: &RelationTaskPlan) -> PromptTaskPlan {
    let relation_text = PromptAtom::Text(task.relation.to_string());
    let head_text = PromptAtom::Text(task.head.to_string());
    let tail_text = PromptAtom::Text(task.tail.to_string());

    // Head and tail are both introduced by the same child marker.
    let child_marker = PromptAtom::Special(PromptSpecialToken::Child);

    PromptTaskPlan {
        kind: PromptTaskKind::Relation,
        name: task.relation.clone(),
        atoms: vec![
            PromptAtom::Special(PromptSpecialToken::Prompt),
            relation_text,
            child_marker.clone(),
            head_text,
            child_marker,
            tail_text,
            PromptAtom::Special(PromptSpecialToken::Separator),
        ],
    }
}
/// Builds the prompt task that enumerates the children of a structure.
///
/// The atom sequence is `[P] <structure>`, then `[C] <property>` for each
/// entry of `task.children` in iteration order, then `[SEP]`. The resulting
/// task is named after the structure.
///
/// # Errors
///
/// Returns [`PromptPlanError::EmptyStructureTask`] carrying the structure name
/// when `task.children` is empty.
fn structure_task_to_prompt(task: &StructureTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
    if task.children.is_empty() {
        let name = task.structure.to_string();
        return Err(PromptPlanError::EmptyStructureTask { name });
    }

    // Header: prompt marker followed by the structure name.
    let mut sequence = vec![
        PromptAtom::Special(PromptSpecialToken::Prompt),
        PromptAtom::Text(task.structure.to_string()),
    ];

    // One marker/text pair per child property.
    for member in &task.children {
        sequence.extend([
            PromptAtom::Special(PromptSpecialToken::Child),
            PromptAtom::Text(member.property.to_string()),
        ]);
    }

    sequence.push(PromptAtom::Special(PromptSpecialToken::Separator));

    Ok(PromptTaskPlan {
        kind: PromptTaskKind::Structure,
        name: task.structure.clone(),
        atoms: sequence,
    })
}
/// Builds the prompt task that enumerates the labels of a classification.
///
/// The atom sequence is `[P] <task>`, then `[L] <label>` for each entry of
/// `task.labels` in iteration order, then `[SEP]`. The resulting task is named
/// after the classification task.
///
/// # Errors
///
/// Returns [`PromptPlanError::EmptyClassificationTask`] carrying the task name
/// when `task.labels` is empty.
fn classification_task_to_prompt(
    task: &ClassificationTaskPlan,
) -> Result<PromptTaskPlan, PromptPlanError> {
    if task.labels.is_empty() {
        let name = task.task.to_string();
        return Err(PromptPlanError::EmptyClassificationTask { name });
    }

    // Header: prompt marker followed by the classification task name.
    let mut sequence = vec![
        PromptAtom::Special(PromptSpecialToken::Prompt),
        PromptAtom::Text(task.task.to_string()),
    ];

    // One marker/text pair per label.
    for choice in &task.labels {
        sequence.extend([
            PromptAtom::Special(PromptSpecialToken::Label),
            PromptAtom::Text(choice.to_string()),
        ]);
    }

    sequence.push(PromptAtom::Special(PromptSpecialToken::Separator));

    Ok(PromptTaskPlan {
        kind: PromptTaskKind::Classification,
        name: task.task.clone(),
        atoms: sequence,
    })
}
impl TryFrom<TaskPlan> for PromptPlan {
type Error = PromptPlanError;
fn try_from(value: TaskPlan) -> Result<Self, Self::Error> {
let mut tasks = Vec::new();
for task in &value.tasks {
let planned = match task {
PlannedTask::Entity(x) => entity_task_to_prompt(x)?,
PlannedTask::Relation(x) => relation_task_to_prompt(x),
PlannedTask::Structure(x) => structure_task_to_prompt(x)?,
PlannedTask::Classification(x) => classification_task_to_prompt(x)?,
};
tasks.push(planned);
}
let mut flat_prompt = Vec::new();
for task in &tasks {
flat_prompt.extend(task.atoms.clone());
}
Ok(Self { tasks, flat_prompt })
}
}
impl PromptPlan {
    /// Renders `flat_prompt` as a single human-readable line.
    ///
    /// Special tokens are rendered through [`PromptSpecialToken::as_str`] and
    /// text atoms verbatim; consecutive atoms are separated by one space. An
    /// empty prompt renders as the empty string.
    pub fn render_debug_string(&self) -> String {
        // Borrow each atom's text instead of copying it into a fresh `String`;
        // `join` performs the single allocation for the final result.
        self.flat_prompt
            .iter()
            .map(|atom| match atom {
                PromptAtom::Special(token) => token.as_str(),
                PromptAtom::Text(text) => text.as_str(),
            })
            .collect::<Vec<&str>>()
            .join(" ")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::task_plan::TaskPlan;

    /// Runs a small schema through every stage (normalized → expanded → lifted
    /// → task plan → prompt plan) and checks that the rendered prompt contains
    /// the expected entity, relation and classification fragments.
    #[test]
    fn prompt_plan_renders_expected_debug_string() {
        let schema_json = r#"
{
"entities": ["gene", "disease", "sentiment", "positive", "negative"],
"relations": [
{ "associated_with": { "head": "gene", "tail": "disease" } }
],
"classifications": [
{
"task": "sentiment",
"labels": ["positive", "negative"]
}
]
}
"#;

        // Each stage consumes the output of the previous one.
        let normalized = NormalizedSchema::from_json_str(schema_json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        let prompt_plan = PromptPlan::try_from(task_plan).unwrap();

        let rendered = prompt_plan.render_debug_string();

        assert!(rendered.contains("[P] entities [E] disease"));
        assert!(rendered.contains("[P] associated_with [C] gene [C] disease [SEP]"));
        assert!(rendered.contains("[P] sentiment [L] positive [L] negative [SEP]"));
    }
}