ie-schema 0.1.5

A flexible schema specification and parser for information extraction tasks.
Documentation
use crate::normalized::ExpandedName;
use crate::task_plan::{
    ClassificationTaskPlan, EntityTaskPlan, PlannedTask, RelationTaskPlan, StructureTaskPlan,
    TaskPlan,
};
use serde::Serialize;
use std::convert::TryFrom;

/// Marker tokens that structure a prompt's atom sequence.
///
/// Each variant has a fixed bracketed spelling, returned by
/// [`PromptSpecialToken::as_str`]. The per-variant notes describe how the task
/// builders in this file use each marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum PromptSpecialToken {
    /// Opens a task's atom sequence; the task name follows it.
    Prompt,    // [P]
    /// Precedes each entity name in the entity task.
    Entity,    // [E]
    /// Precedes a relation's head and tail, and each child property of a structure task.
    Child,     // [C]
    /// Precedes each label of a classification task.
    Label,     // [L]
    /// Closes a task's atom sequence.
    Separator, // [SEP]
}

impl PromptSpecialToken {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Prompt => "[P]",
            Self::Entity => "[E]",
            Self::Child => "[C]",
            Self::Label => "[L]",
            Self::Separator => "[SEP]",
        }
    }
}

/// One element of a prompt sequence: either a structural marker or a piece of text.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum PromptAtom {
    /// A structural marker such as `[P]` or `[SEP]`.
    Special(PromptSpecialToken),
    /// Literal text; the builders in this file use it for task, entity, property and label names.
    Text(String),
}

/// The kind of task a [`PromptTaskPlan`] was built from.
///
/// There is one variant per `PlannedTask` variant handled by the
/// `TryFrom<TaskPlan>` conversion in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub enum PromptTaskKind {
    /// Built from an entity task.
    Entity,
    /// Built from a relation task.
    Relation,
    /// Built from a structure task.
    Structure,
    /// Built from a classification task.
    Classification,
}

/// The prompt representation of a single task.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct PromptTaskPlan {
    /// Which kind of task this plan was built from.
    pub kind: PromptTaskKind,
    /// Name of the task. The builders in this file clone it from the relation,
    /// structure or classification task, and use the fixed name `entities` for
    /// the entity task.
    pub name: ExpandedName,
    /// The task's atoms in emission order. Every builder in this file starts
    /// the sequence with `[P]` and ends it with `[SEP]`.
    pub atoms: Vec<PromptAtom>,
}

/// Prompt plans for a set of tasks, together with their flattened atom sequence.
///
/// The `TryFrom<TaskPlan>` impl in this file is the conversion that fills both
/// fields; the derived `Default` leaves both empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
pub struct PromptPlan {
    /// Per-task prompt plans. The conversion produces one per input task, in input order.
    pub tasks: Vec<PromptTaskPlan>,
    /// All atoms of `tasks` concatenated in task order, as built by the conversion.
    /// Both fields are public, so hand-built values are not guaranteed to agree.
    pub flat_prompt: Vec<PromptAtom>,
}

/// Errors produced while converting a task plan into a [`PromptPlan`].
///
/// Every variant reports a task that has nothing to put in its prompt.
#[derive(Debug, thiserror::Error)]
pub enum PromptPlanError {
    /// The entity task lists no entities.
    #[error("empty entity task")]
    EmptyEntityTask,

    /// The structure task called `name` has no children.
    #[error("empty structure task: {name}")]
    EmptyStructureTask { name: String },

    /// The classification task called `name` has no labels.
    #[error("empty classification task: {name}")]
    EmptyClassificationTask { name: String },
}

/// Builds the prompt plan for the entity task.
///
/// The atom sequence is `[P] entities`, then one `[E] <entity>` pair for each
/// entry of `task.entities` in iteration order, then `[SEP]`. The resulting
/// plan is always named `entities`.
///
/// # Errors
///
/// Returns [`PromptPlanError::EmptyEntityTask`] when `task.entities` is empty.
fn entity_task_to_prompt(task: &EntityTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
    if task.entities.is_empty() {
        return Err(PromptPlanError::EmptyEntityTask);
    }

    // Header first: the prompt marker followed by the fixed task name.
    let mut atoms = vec![
        PromptAtom::Special(PromptSpecialToken::Prompt),
        PromptAtom::Text("entities".to_string()),
    ];

    // Each entity contributes a marker/name pair.
    for entity in &task.entities {
        atoms.extend([
            PromptAtom::Special(PromptSpecialToken::Entity),
            PromptAtom::Text(entity.to_string()),
        ]);
    }

    atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));

    let name = ExpandedName::new("entities".to_string());
    Ok(PromptTaskPlan {
        kind: PromptTaskKind::Entity,
        name,
        atoms,
    })
}

/// Builds the prompt plan for one relation task.
///
/// The atom sequence is always the seven atoms
/// `[P] <relation> [C] <head> [C] <tail> [SEP]`, and the plan is named after
/// the relation. This conversion cannot fail.
fn relation_task_to_prompt(task: &RelationTaskPlan) -> PromptTaskPlan {
    let relation_text = PromptAtom::Text(task.relation.to_string());
    let head_text = PromptAtom::Text(task.head.to_string());
    let tail_text = PromptAtom::Text(task.tail.to_string());

    PromptTaskPlan {
        kind: PromptTaskKind::Relation,
        name: task.relation.clone(),
        atoms: vec![
            PromptAtom::Special(PromptSpecialToken::Prompt),
            relation_text,
            // Head and tail share the child marker; their order tells them apart.
            PromptAtom::Special(PromptSpecialToken::Child),
            head_text,
            PromptAtom::Special(PromptSpecialToken::Child),
            tail_text,
            PromptAtom::Special(PromptSpecialToken::Separator),
        ],
    }
}

/// Builds the prompt plan for one structure task.
///
/// The atom sequence is `[P] <structure>`, then one `[C] <property>` pair for
/// each entry of `task.children` in iteration order, then `[SEP]`. The plan is
/// named after the structure.
///
/// # Errors
///
/// Returns [`PromptPlanError::EmptyStructureTask`] carrying the structure's
/// name when `task.children` is empty.
fn structure_task_to_prompt(task: &StructureTaskPlan) -> Result<PromptTaskPlan, PromptPlanError> {
    if task.children.is_empty() {
        let name = task.structure.to_string();
        return Err(PromptPlanError::EmptyStructureTask { name });
    }

    // Header first: the prompt marker followed by the structure's name.
    let mut atoms = vec![
        PromptAtom::Special(PromptSpecialToken::Prompt),
        PromptAtom::Text(task.structure.to_string()),
    ];

    // Each child contributes a marker/property pair.
    for child in &task.children {
        atoms.extend([
            PromptAtom::Special(PromptSpecialToken::Child),
            PromptAtom::Text(child.property.to_string()),
        ]);
    }

    atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));

    Ok(PromptTaskPlan {
        kind: PromptTaskKind::Structure,
        name: task.structure.clone(),
        atoms,
    })
}

/// Builds the prompt plan for one classification task.
///
/// The atom sequence is `[P] <task>`, then one `[L] <label>` pair for each
/// entry of `task.labels` in iteration order, then `[SEP]`. The plan is named
/// after the classification task.
///
/// # Errors
///
/// Returns [`PromptPlanError::EmptyClassificationTask`] carrying the task's
/// name when `task.labels` is empty.
fn classification_task_to_prompt(
    task: &ClassificationTaskPlan,
) -> Result<PromptTaskPlan, PromptPlanError> {
    if task.labels.is_empty() {
        let name = task.task.to_string();
        return Err(PromptPlanError::EmptyClassificationTask { name });
    }

    // Header first: the prompt marker followed by the task's name.
    let mut atoms = vec![
        PromptAtom::Special(PromptSpecialToken::Prompt),
        PromptAtom::Text(task.task.to_string()),
    ];

    // Each label contributes a marker/label pair.
    for label in &task.labels {
        atoms.extend([
            PromptAtom::Special(PromptSpecialToken::Label),
            PromptAtom::Text(label.to_string()),
        ]);
    }

    atoms.push(PromptAtom::Special(PromptSpecialToken::Separator));

    Ok(PromptTaskPlan {
        kind: PromptTaskKind::Classification,
        name: task.task.clone(),
        atoms,
    })
}

impl TryFrom<TaskPlan> for PromptPlan {
    type Error = PromptPlanError;

    /// Converts every planned task into its prompt form and flattens the result.
    ///
    /// `tasks` holds one [`PromptTaskPlan`] per entry of `value.tasks`, in the
    /// same order. `flat_prompt` is the concatenation of those plans' atoms in
    /// that order.
    ///
    /// # Errors
    ///
    /// Stops at the first task that cannot be converted and returns its
    /// [`PromptPlanError`] (an entity, structure or classification task with
    /// nothing to list). Relation tasks never fail.
    fn try_from(value: TaskPlan) -> Result<Self, Self::Error> {
        let mut tasks = Vec::new();

        for task in &value.tasks {
            let planned = match task {
                PlannedTask::Entity(x) => entity_task_to_prompt(x)?,
                PlannedTask::Relation(x) => relation_task_to_prompt(x),
                PlannedTask::Structure(x) => structure_task_to_prompt(x)?,
                PlannedTask::Classification(x) => classification_task_to_prompt(x)?,
            };
            tasks.push(planned);
        }

        // Size the flat sequence once, then clone atoms straight into it. This
        // avoids a throwaway copy of each task's `atoms` and repeated regrowth.
        let atom_count: usize = tasks.iter().map(|task| task.atoms.len()).sum();
        let mut flat_prompt = Vec::with_capacity(atom_count);
        for task in &tasks {
            flat_prompt.extend_from_slice(&task.atoms);
        }

        Ok(Self { tasks, flat_prompt })
    }
}

impl PromptPlan {
    /// Renders `flat_prompt` as one space-separated line for inspection.
    ///
    /// Special tokens are written with their bracketed spelling (see
    /// [`PromptSpecialToken::as_str`]) and text atoms verbatim. Text is not
    /// escaped, so a text atom containing spaces or bracketed markers is
    /// indistinguishable from separate atoms in the output. An empty
    /// `flat_prompt` yields an empty string.
    pub fn render_debug_string(&self) -> String {
        // Borrow each atom as `&str` instead of allocating a `String` per atom;
        // only the list of borrows and the joined result are allocated.
        let pieces: Vec<&str> = self
            .flat_prompt
            .iter()
            .map(|atom| match atom {
                PromptAtom::Special(token) => token.as_str(),
                PromptAtom::Text(text) => text.as_str(),
            })
            .collect();

        pieces.join(" ")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::task_plan::TaskPlan;

    /// Runs a small schema through every pipeline stage (normalized, expanded,
    /// lifted, task plan, prompt plan) and checks that the rendered prompt
    /// contains the expected entity, relation and classification fragments.
    #[test]
    fn prompt_plan_renders_expected_debug_string() {
        let schema_json = r#"
        {
            "entities": ["gene", "disease", "sentiment", "positive", "negative"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ],
            "classifications": [
                {
                    "task": "sentiment",
                    "labels": ["positive", "negative"]
                }
            ]
        }
        "#;

        let normalized = NormalizedSchema::from_json_str(schema_json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        let prompt_plan = PromptPlan::try_from(task_plan).unwrap();

        let rendered = prompt_plan.render_debug_string();

        // Entity section: only the start is pinned, not the full entity list.
        assert!(rendered.contains("[P] entities [E] disease"));
        // Relation section: head then tail, both behind the child marker.
        assert!(rendered.contains("[P] associated_with [C] gene [C] disease [SEP]"));
        // Classification section: labels in schema order.
        assert!(rendered.contains("[P] sentiment [L] positive [L] negative [SEP]"));
    }
}