use crate::token_plan::{TokenPlan, TokenRole};
use burn::prelude::*;
use serde::Serialize;
/// Token-index span of a single prompt task, as recorded in
/// [`TensorPlanMetadata::prompt_task_spans`].
///
/// Spans produced by [`TensorPlan::from_token_plan`] satisfy
/// `start <= end <= prompt_len`.
///
/// NOTE(review): `end` is allowed to equal the prompt length, which suggests a
/// half-open `[start, end)` range — confirm against the `TokenPlan` producer.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct PromptTaskTensorSpan {
    /// Token index at which the span begins.
    pub start: usize,
    /// Token index at which the span ends.
    pub end: usize,
}
/// Plain-Rust (serializable) description that accompanies the tensors of a
/// [`TensorPlan`].
///
/// [`TensorPlan::from_token_plan`] copies the three length fields verbatim
/// from the source `TokenPlan`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
pub struct TensorPlanMetadata {
    /// `prompt_len` of the source token plan; also the upper bound that task
    /// spans are validated against.
    pub prompt_len: usize,
    /// `input_len` of the source token plan.
    pub input_len: usize,
    /// `total_len` of the source token plan; the second dimension of the
    /// per-token tensors.
    pub total_len: usize,
    /// One span per prompt task, in the same order as the source plan's tasks.
    pub prompt_task_spans: Vec<PromptTaskTensorSpan>,
}
/// Burn tensor representation of a token plan, together with its metadata.
///
/// The shapes noted on the fields are the ones produced by
/// [`TensorPlan::from_token_plan`].
#[derive(Debug)]
pub struct TensorPlan<B>
where
    B: Backend,
{
    /// Token ids, shape `[1, total_len]`.
    pub input_ids: Tensor<B, 2, Int>,
    /// Attention mask, shape `[1, total_len]`; `from_token_plan` fills it with
    /// ones.
    pub attention_mask: Tensor<B, 2, Int>,
    /// Per-token role ids, shape `[1, total_len]`
    /// (1 = prompt special, 2 = prompt text, 3 = input text).
    pub token_role_ids: Tensor<B, 2, Int>,
    /// One `(start, end)` row per prompt task, shape `[num_prompt_tasks, 2]`.
    pub prompt_task_spans: Tensor<B, 2, Int>,
    /// Lengths and spans in plain Rust form.
    pub metadata: TensorPlanMetadata,
}
#[derive(Debug, thiserror::Error)]
pub enum TensorPlanError {
#[error("sequence too long for i64 conversion")]
SequenceTooLong,
#[error("invalid prompt task span: start={start}, end={end}, total={total}")]
InvalidPromptTaskSpan {
start: usize,
end: usize,
total: usize,
},
}
/// Maps a [`TokenRole`] to the integer id stored in the role-id tensor.
///
/// Ids start at 1; no role maps to 0.
/// NOTE(review): 0 is presumably left free for padding — confirm with the
/// model side.
fn role_to_id(role: TokenRole) -> i64 {
    const PROMPT_SPECIAL_ID: i64 = 1;
    const PROMPT_TEXT_ID: i64 = 2;
    const INPUT_TEXT_ID: i64 = 3;

    // Exhaustive on purpose: a new `TokenRole` variant must fail to compile
    // here until it is given an id.
    match role {
        TokenRole::PromptSpecial => PROMPT_SPECIAL_ID,
        TokenRole::PromptText => PROMPT_TEXT_ID,
        TokenRole::InputText => INPUT_TEXT_ID,
    }
}
impl<B: Backend> TensorPlan<B> {
pub fn from_token_plan(
token_plan: &TokenPlan,
device: &B::Device,
) -> Result<Self, TensorPlanError> {
for task in &token_plan.prompt_tasks {
if task.token_span.start > task.token_span.end
|| task.token_span.end > token_plan.prompt_len
{
return Err(TensorPlanError::InvalidPromptTaskSpan {
start: task.token_span.start,
end: task.token_span.end,
total: token_plan.prompt_len,
});
}
}
let input_ids_data: Vec<i64> = token_plan
.combined_token_ids
.iter()
.map(|&x| x as i64)
.collect();
let attention_mask_data: Vec<i64> = vec![1; token_plan.combined_token_ids.len()];
let token_role_ids_data: Vec<i64> = token_plan
.combined_roles
.iter()
.copied()
.map(role_to_id)
.collect();
let mut prompt_task_spans_flat = Vec::with_capacity(token_plan.prompt_tasks.len() * 2);
let mut prompt_task_spans_meta = Vec::with_capacity(token_plan.prompt_tasks.len());
for task in &token_plan.prompt_tasks {
prompt_task_spans_flat.push(task.token_span.start as i64);
prompt_task_spans_flat.push(task.token_span.end as i64);
prompt_task_spans_meta.push(PromptTaskTensorSpan {
start: task.token_span.start,
end: task.token_span.end,
});
}
let input_ids = Tensor::<B, 1, Int>::from_data(input_ids_data.as_slice(), device)
.reshape([1, token_plan.total_len]);
let attention_mask = Tensor::<B, 1, Int>::from_data(attention_mask_data.as_slice(), device)
.reshape([1, token_plan.total_len]);
let token_role_ids = Tensor::<B, 1, Int>::from_data(token_role_ids_data.as_slice(), device)
.reshape([1, token_plan.total_len]);
let prompt_task_spans =
Tensor::<B, 1, Int>::from_data(prompt_task_spans_flat.as_slice(), device)
.reshape([token_plan.prompt_tasks.len(), 2]);
let metadata = TensorPlanMetadata {
prompt_len: token_plan.prompt_len,
input_len: token_plan.input_len,
total_len: token_plan.total_len,
prompt_task_spans: prompt_task_spans_meta,
};
Ok(Self {
input_ids,
attention_mask,
token_role_ids,
prompt_task_spans,
metadata,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::prompt_plan::PromptPlan;
    use crate::task_plan::TaskPlan;
    use crate::token_plan::{TokenPlan, TokenizerLike};
    use burn::backend::NdArray;
    use std::collections::BTreeMap;

    /// Deterministic tokenizer stub: special tokens come from a fixed
    /// vocabulary, free text is split on whitespace and each word becomes
    /// `100 + word length`.
    struct MockTokenizer {
        vocab: BTreeMap<String, u32>,
    }

    impl MockTokenizer {
        /// Creates the stub with ids 1..=5 for the five special tokens.
        fn new() -> Self {
            let vocab = [("[P]", 1), ("[E]", 2), ("[C]", 3), ("[L]", 4), ("[SEP]", 5)]
                .iter()
                .map(|&(token, id)| (token.to_string(), id))
                .collect();
            Self { vocab }
        }
    }

    impl TokenizerLike for MockTokenizer {
        fn tokenize(&self, text: &str) -> Vec<u32> {
            text.split_whitespace()
                .map(|s| s.len() as u32 + 100)
                .collect()
        }

        fn token_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
    }

    /// Runs the whole schema -> token plan pipeline and checks that the
    /// resulting tensor plan mirrors the token plan in both its metadata and
    /// the shapes of its tensors.
    #[test]
    fn tensor_plan_builds_burn_tensors() {
        type B = NdArray;

        let schema_json = r#"
{
"entities": ["gene", "disease"],
"relations": [
{ "associated_with": { "head": "gene", "tail": "disease" } }
]
}
"#;
        let normalized = NormalizedSchema::from_json_str(schema_json).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        let prompt_plan = PromptPlan::try_from(task_plan).unwrap();

        let tokenizer = MockTokenizer::new();
        let token_plan = TokenPlan::from_prompt_plan(
            &prompt_plan,
            "gene is associated with disease",
            &tokenizer,
        )
        .unwrap();

        let device = Default::default();
        let tensor_plan = TensorPlan::<B>::from_token_plan(&token_plan, &device).unwrap();

        // Metadata is copied through unchanged.
        assert_eq!(tensor_plan.metadata.total_len, token_plan.total_len);
        assert_eq!(tensor_plan.metadata.prompt_len, token_plan.prompt_len);
        assert_eq!(tensor_plan.metadata.input_len, token_plan.input_len);
        assert_eq!(
            tensor_plan.metadata.prompt_task_spans.len(),
            token_plan.prompt_tasks.len()
        );

        // The tensors themselves carry a batch dimension of one and cover
        // every token / every prompt task.
        let expected_row_shape = [1, token_plan.total_len];
        assert_eq!(tensor_plan.input_ids.dims(), expected_row_shape);
        assert_eq!(tensor_plan.attention_mask.dims(), expected_row_shape);
        assert_eq!(tensor_plan.token_role_ids.dims(), expected_row_shape);
        assert_eq!(
            tensor_plan.prompt_task_spans.dims(),
            [token_plan.prompt_tasks.len(), 2]
        );
    }
}