// ie-schema 0.1.5
//
// A flexible schema specification and parser for information extraction tasks.
use crate::token_plan::{TokenPlan, TokenRole};
use burn::prelude::*;
use serde::Serialize;

/// Half-open `[start, end)` token range of a single prompt task, kept as plain
/// `usize` offsets in [`TensorPlanMetadata`] next to the tensor form.
///
/// When built by `TensorPlan::from_token_plan`, spans satisfy
/// `start <= end <= prompt_len`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
pub struct PromptTaskTensorSpan {
    /// First token index of the task (inclusive).
    pub start: usize,
    /// One past the last token index of the task (exclusive).
    pub end: usize,
}

/// Plain-Rust description of a [`TensorPlan`]: the sequence lengths it was
/// built with and the prompt-task spans as `usize` ranges (convenient for
/// serialization and inspection without touching tensors).
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct TensorPlanMetadata {
    /// Number of prompt tokens, copied from the source token plan.
    pub prompt_len: usize,
    /// Number of input-text tokens, copied from the source token plan.
    pub input_len: usize,
    /// Full sequence length; the second dimension of the `[1, total_len]`
    /// sequence tensors.
    pub total_len: usize,
    /// One span per prompt task, in the same order as the rows of the
    /// `prompt_task_spans` tensor.
    pub prompt_task_spans: Vec<PromptTaskTensorSpan>,
}

/// Burn tensors for a single (batch size 1) sequence built from a
/// [`TokenPlan`], together with [`TensorPlanMetadata`] describing them.
#[derive(Debug)]
pub struct TensorPlan<B>
where
    B: Backend,
{
    /// Token ids of the combined sequence. Shape: `[1, total_len]`.
    pub input_ids: Tensor<B, 2, Int>,
    /// All ones (no padding is produced). Shape: `[1, total_len]`.
    pub attention_mask: Tensor<B, 2, Int>,
    /// Per-token role ids as produced by `role_to_id` (1, 2 or 3).
    /// Shape: `[1, total_len]`.
    pub token_role_ids: Tensor<B, 2, Int>,

    /// Shape: `[num_prompt_tasks, 2]`.
    /// Each row holds a half-open `[start, end)` token range.
    pub prompt_task_spans: Tensor<B, 2, Int>,

    /// Lengths and spans in plain-Rust form.
    pub metadata: TensorPlanMetadata,
}

#[derive(Debug, thiserror::Error)]
pub enum TensorPlanError {
    #[error("sequence too long for i64 conversion")]
    SequenceTooLong,

    #[error("invalid prompt task span: start={start}, end={end}, total={total}")]
    InvalidPromptTaskSpan {
        start: usize,
        end: usize,
        total: usize,
    },
}

fn role_to_id(role: TokenRole) -> i64 {
    match role {
        TokenRole::PromptSpecial => 1,
        TokenRole::PromptText => 2,
        TokenRole::InputText => 3,
    }
}

impl<B: Backend> TensorPlan<B> {
    pub fn from_token_plan(
        token_plan: &TokenPlan,
        device: &B::Device,
    ) -> Result<Self, TensorPlanError> {
        for task in &token_plan.prompt_tasks {
            if task.token_span.start > task.token_span.end
                || task.token_span.end > token_plan.prompt_len
            {
                return Err(TensorPlanError::InvalidPromptTaskSpan {
                    start: task.token_span.start,
                    end: task.token_span.end,
                    total: token_plan.prompt_len,
                });
            }
        }

        let input_ids_data: Vec<i64> = token_plan
            .combined_token_ids
            .iter()
            .map(|&x| x as i64)
            .collect();

        let attention_mask_data: Vec<i64> = vec![1; token_plan.combined_token_ids.len()];

        let token_role_ids_data: Vec<i64> = token_plan
            .combined_roles
            .iter()
            .copied()
            .map(role_to_id)
            .collect();

        let mut prompt_task_spans_flat = Vec::with_capacity(token_plan.prompt_tasks.len() * 2);
        let mut prompt_task_spans_meta = Vec::with_capacity(token_plan.prompt_tasks.len());

        for task in &token_plan.prompt_tasks {
            prompt_task_spans_flat.push(task.token_span.start as i64);
            prompt_task_spans_flat.push(task.token_span.end as i64);

            prompt_task_spans_meta.push(PromptTaskTensorSpan {
                start: task.token_span.start,
                end: task.token_span.end,
            });
        }

        let input_ids = Tensor::<B, 1, Int>::from_data(input_ids_data.as_slice(), device)
            .reshape([1, token_plan.total_len]);
        let attention_mask = Tensor::<B, 1, Int>::from_data(attention_mask_data.as_slice(), device)
            .reshape([1, token_plan.total_len]);
        let token_role_ids = Tensor::<B, 1, Int>::from_data(token_role_ids_data.as_slice(), device)
            .reshape([1, token_plan.total_len]);

        let prompt_task_spans =
            Tensor::<B, 1, Int>::from_data(prompt_task_spans_flat.as_slice(), device)
                .reshape([token_plan.prompt_tasks.len(), 2]);

        let metadata = TensorPlanMetadata {
            prompt_len: token_plan.prompt_len,
            input_len: token_plan.input_len,
            total_len: token_plan.total_len,
            prompt_task_spans: prompt_task_spans_meta,
        };

        Ok(Self {
            input_ids,
            attention_mask,
            token_role_ids,
            prompt_task_spans,
            metadata,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::prompt_plan::PromptPlan;
    use crate::task_plan::TaskPlan;
    use crate::token_plan::{TokenPlan, TokenizerLike};
    use burn::backend::NdArray;
    use std::collections::BTreeMap;

    /// Input text shared by the tests below.
    const INPUT_TEXT: &str = "gene is associated with disease";

    /// Deterministic tokenizer: special tokens map to fixed ids and every
    /// whitespace-separated word maps to `100 + word length`.
    struct MockTokenizer {
        vocab: BTreeMap<String, u32>,
    }

    impl MockTokenizer {
        fn new() -> Self {
            let mut vocab = BTreeMap::new();
            vocab.insert("[P]".to_string(), 1);
            vocab.insert("[E]".to_string(), 2);
            vocab.insert("[C]".to_string(), 3);
            vocab.insert("[L]".to_string(), 4);
            vocab.insert("[SEP]".to_string(), 5);
            Self { vocab }
        }
    }

    impl TokenizerLike for MockTokenizer {
        fn tokenize(&self, text: &str) -> Vec<u32> {
            text.split_whitespace()
                .map(|s| s.len() as u32 + 100)
                .collect()
        }

        fn token_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
    }

    /// Runs the schema pipeline up to a `PromptPlan` for a small
    /// gene/disease schema with one relation.
    fn build_prompt_plan() -> PromptPlan {
        let s = r#"
        {
            "entities": ["gene", "disease"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ]
        }
        "#;

        let s2 = NormalizedSchema::from_json_str(s).unwrap();
        let s3 = ExpandedSchema::try_from(s2).unwrap();
        let s4 = LiftedSchema::try_from(s3).unwrap();
        let tp = TaskPlan::try_from(s4).unwrap();
        PromptPlan::try_from(tp).unwrap()
    }

    #[test]
    fn tensor_plan_builds_burn_tensors() {
        type B = NdArray;

        let pp = build_prompt_plan();
        let tokenizer = MockTokenizer::new();
        let token_plan = TokenPlan::from_prompt_plan(&pp, INPUT_TEXT, &tokenizer).unwrap();

        let device = Default::default();
        let tensor_plan = TensorPlan::<B>::from_token_plan(&token_plan, &device).unwrap();

        // Metadata mirrors the token plan.
        assert_eq!(tensor_plan.metadata.total_len, token_plan.total_len);
        assert_eq!(tensor_plan.metadata.prompt_len, token_plan.prompt_len);
        assert_eq!(tensor_plan.metadata.input_len, token_plan.input_len);
        assert_eq!(
            tensor_plan.metadata.prompt_task_spans.len(),
            token_plan.prompt_tasks.len()
        );
        for (meta, task) in tensor_plan
            .metadata
            .prompt_task_spans
            .iter()
            .zip(token_plan.prompt_tasks.iter())
        {
            assert_eq!(meta.start, task.token_span.start);
            assert_eq!(meta.end, task.token_span.end);
        }

        // Sequence tensors are a batch of one over the whole sequence; the span
        // tensor has one `[start, end)` row per prompt task.
        let row_dims = [1, token_plan.total_len];
        assert_eq!(tensor_plan.input_ids.dims(), row_dims);
        assert_eq!(tensor_plan.attention_mask.dims(), row_dims);
        assert_eq!(tensor_plan.token_role_ids.dims(), row_dims);
        assert_eq!(
            tensor_plan.prompt_task_spans.dims(),
            [token_plan.prompt_tasks.len(), 2]
        );
    }

    #[test]
    fn tensor_plan_rejects_span_past_prompt() {
        type B = NdArray;

        let pp = build_prompt_plan();
        let tokenizer = MockTokenizer::new();
        let mut token_plan = TokenPlan::from_prompt_plan(&pp, INPUT_TEXT, &tokenizer).unwrap();

        // Push the first span's end one token past the prompt.
        let prompt_len = token_plan.prompt_len;
        let task = token_plan
            .prompt_tasks
            .first_mut()
            .expect("schema should yield at least one prompt task");
        task.token_span.end = prompt_len + 1;

        let device = Default::default();
        let result = TensorPlan::<B>::from_token_plan(&token_plan, &device);
        assert!(matches!(
            result,
            Err(TensorPlanError::InvalidPromptTaskSpan { end, total, .. })
                if end == prompt_len + 1 && total == prompt_len
        ));
    }
}