ie-schema 0.1.5

A flexible schema specification and parser for information extraction tasks.
Documentation
use crate::prompt_plan::{PromptAtom, PromptPlan, PromptTaskKind, PromptTaskPlan};
use serde::Serialize;
use std::ops::Range;

/// Minimal tokenizer interface needed to turn a prompt plan and input text into token ids.
///
/// Implementors provide two operations: splitting free text into ids, and looking up the id of a
/// single vocabulary entry (used for special prompt markers).
pub trait TokenizerLike {
    /// Converts `text` into a sequence of token ids.
    ///
    /// NOTE(review): `TokenPlan` concatenates the results of separate `tokenize` calls directly,
    /// so implementations presumably should not add their own BOS/EOS/CLS tokens — confirm.
    fn tokenize(&self, text: &str) -> Vec<u32>;
    /// Returns the id of the vocabulary entry `token`, or `None` if it is not in the vocabulary.
    fn token_id(&self, token: &str) -> Option<u32>;
}

/// Where a token in a [`TokenPlan`] sequence came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum TokenRole {
    /// A single special marker from the prompt, resolved via [`TokenizerLike::token_id`].
    PromptSpecial,
    /// A token produced by tokenizing a text atom of the prompt.
    PromptText,
    /// A token produced by tokenizing the caller-supplied input text.
    InputText,
}

/// Half-open token index range `[start, end)`.
///
/// For prompt tasks the span indexes `TokenPlan::prompt_token_ids`. The prompt is placed first in
/// the combined sequence, so the same span also indexes `TokenPlan::combined_token_ids`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TokenSpan {
    /// Index of the first token in the span.
    pub start: usize,
    /// One past the index of the last token in the span (exclusive bound).
    pub end: usize,
}

impl TokenSpan {
    /// Converts this span into a [`Range`] so it can be used directly for slicing,
    /// e.g. `&plan.prompt_token_ids[task.token_span.range()]`.
    pub fn range(&self) -> Range<usize> {
        Range {
            start: self.start,
            end: self.end,
        }
    }
}

/// A prompt task after tokenization: its identity plus the tokens it occupies.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TokenizedPromptTask {
    /// Kind copied from the source [`PromptTaskPlan`].
    pub kind: PromptTaskKind,
    /// Name copied from the source [`PromptTaskPlan`].
    pub name: String,
    /// Half-open span of this task's tokens within the prompt token sequence.
    pub token_span: TokenSpan,
}

/// Tokenized prompt plus input, with a role for every token.
///
/// As built by [`TokenPlan::from_prompt_plan`], every `*_token_ids` vector has a parallel
/// `*_roles` vector of the same length.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TokenPlan {
    /// Token ids of all prompt tasks, concatenated in plan order.
    pub prompt_token_ids: Vec<u32>,
    /// One role per prompt token (`PromptSpecial` or `PromptText`).
    pub prompt_roles: Vec<TokenRole>,
    /// One entry per prompt task, in plan order, with spans into `prompt_token_ids`.
    pub prompt_tasks: Vec<TokenizedPromptTask>,

    /// Token ids produced by tokenizing the input text.
    pub input_token_ids: Vec<u32>,
    /// One `InputText` role per input token.
    pub input_roles: Vec<TokenRole>,

    /// Prompt token ids followed by input token ids.
    pub combined_token_ids: Vec<u32>,
    /// Roles parallel to `combined_token_ids`.
    pub combined_roles: Vec<TokenRole>,

    /// Number of prompt tokens (`prompt_token_ids.len()`).
    pub prompt_len: usize,
    /// Number of input tokens (`input_token_ids.len()`).
    pub input_len: usize,
    /// Number of combined tokens (`combined_token_ids.len()`, i.e. `prompt_len + input_len`).
    pub total_len: usize,
}

/// Errors that can occur while building a [`TokenPlan`].
#[derive(Debug, thiserror::Error)]
pub enum TokenPlanError {
    /// A special prompt marker has no id in the tokenizer vocabulary; `token` is the marker text.
    #[error("missing special token in tokenizer vocab: {token}")]
    MissingSpecialToken { token: String },

    /// Prompt token ids and prompt roles have different lengths.
    ///
    /// Internal consistency check: prompt encoding emits exactly one role per id, so this is not
    /// expected to occur in practice.
    #[error("prompt roles length mismatch")]
    PromptRoleLengthMismatch,

    /// Input token ids and input roles have different lengths.
    ///
    /// Internal consistency check: input roles are built with one entry per input id.
    #[error("input roles length mismatch")]
    InputRoleLengthMismatch,
}

/// Encodes a single [`PromptAtom`] into token ids plus a parallel vector of roles.
///
/// The two returned vectors always have the same length:
/// - Text atoms go through [`TokenizerLike::tokenize`], and every resulting id is tagged
///   [`TokenRole::PromptText`].
/// - Special atoms map to exactly one id from the tokenizer vocabulary, tagged
///   [`TokenRole::PromptSpecial`].
///
/// # Errors
///
/// Returns [`TokenPlanError::MissingSpecialToken`] when a special atom's text has no id in the
/// tokenizer vocabulary.
fn encode_atom<T: TokenizerLike>(
    tokenizer: &T,
    atom: &PromptAtom,
) -> Result<(Vec<u32>, Vec<TokenRole>), TokenPlanError> {
    let encoded = match atom {
        PromptAtom::Text(text) => {
            let text_ids = tokenizer.tokenize(text);
            let text_roles = vec![TokenRole::PromptText; text_ids.len()];
            (text_ids, text_roles)
        }
        PromptAtom::Special(special) => {
            let marker = special.as_str();
            match tokenizer.token_id(marker) {
                Some(id) => (vec![id], vec![TokenRole::PromptSpecial]),
                None => {
                    return Err(TokenPlanError::MissingSpecialToken {
                        token: marker.to_string(),
                    });
                }
            }
        }
    };
    Ok(encoded)
}

/// Encodes every atom of one prompt task and records where the task lands in the prompt.
///
/// `offset` is the number of prompt tokens already emitted before this task. The returned
/// [`TokenizedPromptTask`] carries the half-open span `offset..offset + ids.len()`, and the
/// returned id and role vectors are the atoms' encodings concatenated in order.
///
/// # Errors
///
/// Propagates the first error from [`encode_atom`], i.e.
/// [`TokenPlanError::MissingSpecialToken`].
fn encode_prompt_task<T: TokenizerLike>(
    tokenizer: &T,
    task: &PromptTaskPlan,
    offset: usize,
) -> Result<(Vec<u32>, Vec<TokenRole>, TokenizedPromptTask), TokenPlanError> {
    let mut task_ids: Vec<u32> = Vec::new();
    let mut task_roles: Vec<TokenRole> = Vec::new();

    for atom in &task.atoms {
        let (mut atom_ids, mut atom_roles) = encode_atom(tokenizer, atom)?;
        task_ids.append(&mut atom_ids);
        task_roles.append(&mut atom_roles);
    }

    let end = offset + task_ids.len();
    Ok((
        task_ids,
        task_roles,
        TokenizedPromptTask {
            kind: task.kind.clone(),
            name: task.name.to_string(),
            token_span: TokenSpan { start: offset, end },
        },
    ))
}

impl TokenPlan {
    pub fn from_prompt_plan<T: TokenizerLike>(
        prompt_plan: &PromptPlan,
        input_text: &str,
        tokenizer: &T,
    ) -> Result<Self, TokenPlanError> {
        let mut prompt_token_ids = Vec::new();
        let mut prompt_roles = Vec::new();
        let mut prompt_tasks = Vec::new();

        for task in &prompt_plan.tasks {
            let offset = prompt_token_ids.len();
            let (ids, roles, tokenized_task) = encode_prompt_task(tokenizer, task, offset)?;
            prompt_token_ids.extend(ids);
            prompt_roles.extend(roles);
            prompt_tasks.push(tokenized_task);
        }

        if prompt_token_ids.len() != prompt_roles.len() {
            return Err(TokenPlanError::PromptRoleLengthMismatch);
        }

        let input_token_ids = tokenizer.tokenize(input_text);
        let input_roles = vec![TokenRole::InputText; input_token_ids.len()];

        if input_token_ids.len() != input_roles.len() {
            return Err(TokenPlanError::InputRoleLengthMismatch);
        }

        let mut combined_token_ids =
            Vec::with_capacity(prompt_token_ids.len() + input_token_ids.len());
        combined_token_ids.extend(prompt_token_ids.iter().copied());
        combined_token_ids.extend(input_token_ids.iter().copied());

        let mut combined_roles = Vec::with_capacity(prompt_roles.len() + input_roles.len());
        combined_roles.extend(prompt_roles.iter().copied());
        combined_roles.extend(input_roles.iter().copied());

        Ok(Self {
            prompt_len: prompt_token_ids.len(),
            input_len: input_token_ids.len(),
            total_len: combined_token_ids.len(),
            prompt_token_ids,
            prompt_roles,
            prompt_tasks,
            input_token_ids,
            input_roles,
            combined_token_ids,
            combined_roles,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::prompt_plan::PromptPlan;
    use crate::task_plan::TaskPlan;
    use std::collections::BTreeMap;

    /// Whitespace tokenizer with a fixed special-token vocabulary.
    ///
    /// Each whitespace-separated word becomes `100 + word.len()`, so ids are predictable and
    /// never collide with the special ids 1..=5.
    struct MockTokenizer {
        vocab: BTreeMap<String, u32>,
    }

    impl MockTokenizer {
        fn new() -> Self {
            let mut vocab = BTreeMap::new();
            vocab.insert("[P]".to_string(), 1);
            vocab.insert("[E]".to_string(), 2);
            vocab.insert("[C]".to_string(), 3);
            vocab.insert("[L]".to_string(), 4);
            vocab.insert("[SEP]".to_string(), 5);
            Self { vocab }
        }
    }

    impl TokenizerLike for MockTokenizer {
        fn tokenize(&self, text: &str) -> Vec<u32> {
            text.split_whitespace()
                .map(|s| s.len() as u32 + 100)
                .collect()
        }

        fn token_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
    }

    /// Runs the full schema pipeline on a small gene/disease schema and returns its prompt plan.
    fn gene_disease_prompt_plan() -> PromptPlan {
        let s = r#"
        {
            "entities": ["gene", "disease"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ]
        }
        "#;

        let normalized = NormalizedSchema::from_json_str(s).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let tasks = TaskPlan::try_from(lifted).unwrap();
        PromptPlan::try_from(tasks).unwrap()
    }

    #[test]
    fn token_plan_builds_combined_sequence() {
        let pp = gene_disease_prompt_plan();
        let tok = MockTokenizer::new();
        let plan =
            TokenPlan::from_prompt_plan(&pp, "gene is associated with disease", &tok).unwrap();

        assert!(plan.prompt_len > 0);
        assert!(plan.input_len > 0);
        assert_eq!(plan.total_len, plan.prompt_len + plan.input_len);
        assert_eq!(plan.combined_token_ids.len(), plan.combined_roles.len());

        // Input is tokenized word by word: gene, is, associated, with, disease.
        assert_eq!(plan.input_token_ids, vec![104, 102, 110, 104, 107]);
        assert_eq!(plan.input_len, 5);

        // Prompt comes first, input second, in both ids and roles.
        assert_eq!(
            &plan.combined_token_ids[..plan.prompt_len],
            plan.prompt_token_ids.as_slice()
        );
        assert_eq!(
            &plan.combined_token_ids[plan.prompt_len..],
            plan.input_token_ids.as_slice()
        );
        assert!(plan.combined_roles[..plan.prompt_len]
            .iter()
            .all(|r| *r != TokenRole::InputText));
        assert!(plan.combined_roles[plan.prompt_len..]
            .iter()
            .all(|r| *r == TokenRole::InputText));
    }

    #[test]
    fn prompt_task_spans_tile_the_prompt() {
        let pp = gene_disease_prompt_plan();
        let plan = TokenPlan::from_prompt_plan(&pp, "gene", &MockTokenizer::new()).unwrap();

        // Spans must be contiguous, start at 0 and end exactly at prompt_len.
        let mut expected_start = 0;
        for task in &plan.prompt_tasks {
            assert_eq!(task.token_span.start, expected_start);
            assert!(task.token_span.end >= task.token_span.start);
            expected_start = task.token_span.end;
        }
        assert_eq!(expected_start, plan.prompt_len);
    }

    #[test]
    fn empty_input_yields_prompt_only() {
        let pp = gene_disease_prompt_plan();
        let plan = TokenPlan::from_prompt_plan(&pp, "", &MockTokenizer::new()).unwrap();

        assert_eq!(plan.input_len, 0);
        assert!(plan.input_token_ids.is_empty());
        assert!(plan.input_roles.is_empty());
        assert_eq!(plan.total_len, plan.prompt_len);
        assert_eq!(plan.combined_token_ids, plan.prompt_token_ids);
        assert_eq!(plan.combined_roles, plan.prompt_roles);
    }
}