use crate::prompt_plan::{PromptAtom, PromptPlan, PromptTaskKind, PromptTaskPlan};
use serde::Serialize;
use std::ops::Range;
/// Minimal tokenizer interface needed to build a [`TokenPlan`].
///
/// Kept as a trait so token planning can be driven either by a real tokenizer
/// or by a lightweight mock (the tests in this module use one).
pub trait TokenizerLike {
/// Converts `text` into a sequence of token ids.
///
/// Called for every prompt text atom and once for the caller's input text.
fn tokenize(&self, text: &str) -> Vec<u32>;
/// Looks up the id of a single vocabulary entry, or `None` if `token` is not
/// in the vocabulary.
///
/// Used to resolve special prompt tokens, each of which becomes exactly one id.
fn token_id(&self, token: &str) -> Option<u32>;
}
/// Records where each token in a [`TokenPlan`] came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
pub enum TokenRole {
/// A special prompt token resolved through [`TokenizerLike::token_id`].
PromptSpecial,
/// A token produced by tokenizing a prompt text atom.
PromptText,
/// A token produced by tokenizing the caller-supplied input text.
InputText,
}
/// Half-open `[start, end)` range of token positions.
///
/// For the span stored on a [`TokenizedPromptTask`], positions index into
/// [`TokenPlan::prompt_token_ids`]. Because the prompt is placed first, the same
/// positions are also valid in [`TokenPlan::combined_token_ids`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TokenSpan {
/// First position covered by the span (inclusive).
pub start: usize,
/// One past the last position covered by the span (exclusive).
pub end: usize,
}
impl TokenSpan {
    /// Returns this span as a half-open [`Range`] (`start..end`).
    ///
    /// The result can be used directly to slice the token or role vectors the
    /// span refers to; an empty span yields an empty range.
    pub fn range(&self) -> Range<usize> {
        let &Self { start, end } = self;
        Range { start, end }
    }
}
/// A prompt task after tokenization, together with the prompt tokens it owns.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TokenizedPromptTask {
/// Task kind, cloned from the originating `PromptTaskPlan`.
pub kind: PromptTaskKind,
/// Task name, copied from the originating `PromptTaskPlan`.
pub name: String,
/// Positions of this task's tokens within [`TokenPlan::prompt_token_ids`].
pub token_span: TokenSpan,
}
/// Token-level layout of a prompt plan followed by input text.
///
/// As built by [`TokenPlan::from_prompt_plan`]:
/// - every `*_token_ids` vector has a parallel `*_roles` vector of the same length;
/// - the combined sequence is the prompt tokens followed by the input tokens.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TokenPlan {
/// Token ids of all prompt tasks, concatenated in task order.
pub prompt_token_ids: Vec<u32>,
/// Role of each entry in `prompt_token_ids`.
pub prompt_roles: Vec<TokenRole>,
/// One entry per prompt task, in the same order as their tokens.
pub prompt_tasks: Vec<TokenizedPromptTask>,
/// Token ids of the input text.
pub input_token_ids: Vec<u32>,
/// Role of each entry in `input_token_ids` (all [`TokenRole::InputText`]).
pub input_roles: Vec<TokenRole>,
/// `prompt_token_ids` followed by `input_token_ids`.
pub combined_token_ids: Vec<u32>,
/// `prompt_roles` followed by `input_roles`.
pub combined_roles: Vec<TokenRole>,
/// Number of prompt tokens (`prompt_token_ids.len()`).
pub prompt_len: usize,
/// Number of input tokens (`input_token_ids.len()`).
pub input_len: usize,
/// Number of combined tokens (`prompt_len + input_len`).
pub total_len: usize,
}
/// Errors produced while building a [`TokenPlan`].
#[derive(Debug, thiserror::Error)]
pub enum TokenPlanError {
/// A special prompt token has no id in the tokenizer's vocabulary.
#[error("missing special token in tokenizer vocab: {token}")]
MissingSpecialToken { token: String },
/// The prompt id and role vectors ended up with different lengths.
///
/// This is a defensive check: the encoders in this module always emit ids and
/// roles together, so it is not expected to occur.
#[error("prompt roles length mismatch")]
PromptRoleLengthMismatch,
/// The input id and role vectors ended up with different lengths.
///
/// This is a defensive check: input roles are derived from the input id count,
/// so it is not expected to occur.
#[error("input roles length mismatch")]
InputRoleLengthMismatch,
}
/// Encodes one prompt atom into token ids plus a parallel vector of roles.
///
/// - `PromptAtom::Text` is run through [`TokenizerLike::tokenize`], and every
///   resulting id is tagged [`TokenRole::PromptText`].
/// - `PromptAtom::Special` is resolved to exactly one id via
///   [`TokenizerLike::token_id`] and tagged [`TokenRole::PromptSpecial`].
///
/// The two returned vectors always have the same length.
///
/// # Errors
///
/// Returns [`TokenPlanError::MissingSpecialToken`] when a special token is not
/// present in the tokenizer's vocabulary.
fn encode_atom<T: TokenizerLike>(
    tokenizer: &T,
    atom: &PromptAtom,
) -> Result<(Vec<u32>, Vec<TokenRole>), TokenPlanError> {
    match atom {
        PromptAtom::Text(text) => {
            let ids = tokenizer.tokenize(text);
            // One role per produced id keeps the vectors parallel.
            let roles: Vec<TokenRole> = ids.iter().map(|_| TokenRole::PromptText).collect();
            Ok((ids, roles))
        }
        PromptAtom::Special(tok) => {
            let literal = tok.as_str();
            tokenizer
                .token_id(literal)
                .map(|id| (vec![id], vec![TokenRole::PromptSpecial]))
                .ok_or_else(|| TokenPlanError::MissingSpecialToken {
                    token: literal.to_string(),
                })
        }
    }
}
/// Encodes every atom of `task`, in order, into one id vector and a parallel
/// role vector.
///
/// `offset` is the position at which this task's tokens will start in the
/// overall prompt. The returned [`TokenizedPromptTask`] records the half-open
/// span `[offset, offset + ids.len())`, together with the task's kind and name.
///
/// # Errors
///
/// Propagates the first error from [`encode_atom`], e.g.
/// [`TokenPlanError::MissingSpecialToken`].
fn encode_prompt_task<T: TokenizerLike>(
    tokenizer: &T,
    task: &PromptTaskPlan,
    offset: usize,
) -> Result<(Vec<u32>, Vec<TokenRole>, TokenizedPromptTask), TokenPlanError> {
    let mut token_ids: Vec<u32> = Vec::new();
    let mut token_roles: Vec<TokenRole> = Vec::new();
    for atom in &task.atoms {
        let (mut atom_ids, mut atom_roles) = encode_atom(tokenizer, atom)?;
        token_ids.append(&mut atom_ids);
        token_roles.append(&mut atom_roles);
    }

    // The span end must be captured before `token_ids` is moved into the result.
    let end = offset + token_ids.len();
    Ok((
        token_ids,
        token_roles,
        TokenizedPromptTask {
            kind: task.kind.clone(),
            name: task.name.to_string(),
            token_span: TokenSpan { start: offset, end },
        },
    ))
}
impl TokenPlan {
pub fn from_prompt_plan<T: TokenizerLike>(
prompt_plan: &PromptPlan,
input_text: &str,
tokenizer: &T,
) -> Result<Self, TokenPlanError> {
let mut prompt_token_ids = Vec::new();
let mut prompt_roles = Vec::new();
let mut prompt_tasks = Vec::new();
for task in &prompt_plan.tasks {
let offset = prompt_token_ids.len();
let (ids, roles, tokenized_task) = encode_prompt_task(tokenizer, task, offset)?;
prompt_token_ids.extend(ids);
prompt_roles.extend(roles);
prompt_tasks.push(tokenized_task);
}
if prompt_token_ids.len() != prompt_roles.len() {
return Err(TokenPlanError::PromptRoleLengthMismatch);
}
let input_token_ids = tokenizer.tokenize(input_text);
let input_roles = vec![TokenRole::InputText; input_token_ids.len()];
if input_token_ids.len() != input_roles.len() {
return Err(TokenPlanError::InputRoleLengthMismatch);
}
let mut combined_token_ids =
Vec::with_capacity(prompt_token_ids.len() + input_token_ids.len());
combined_token_ids.extend(prompt_token_ids.iter().copied());
combined_token_ids.extend(input_token_ids.iter().copied());
let mut combined_roles = Vec::with_capacity(prompt_roles.len() + input_roles.len());
combined_roles.extend(prompt_roles.iter().copied());
combined_roles.extend(input_roles.iter().copied());
Ok(Self {
prompt_len: prompt_token_ids.len(),
input_len: input_token_ids.len(),
total_len: combined_token_ids.len(),
prompt_token_ids,
prompt_roles,
prompt_tasks,
input_token_ids,
input_roles,
combined_token_ids,
combined_roles,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::prompt_plan::PromptPlan;
    use crate::task_plan::TaskPlan;
    use std::collections::BTreeMap;

    /// Whitespace tokenizer for tests.
    ///
    /// Each word maps to `100 + word length`, and special tokens are looked up
    /// in `vocab`.
    struct MockTokenizer {
        vocab: BTreeMap<String, u32>,
    }

    impl MockTokenizer {
        /// Tokenizer whose vocab holds `[P]`, `[E]`, `[C]`, `[L]` and `[SEP]`.
        fn new() -> Self {
            let mut vocab = BTreeMap::new();
            vocab.insert("[P]".to_string(), 1);
            vocab.insert("[E]".to_string(), 2);
            vocab.insert("[C]".to_string(), 3);
            vocab.insert("[L]".to_string(), 4);
            vocab.insert("[SEP]".to_string(), 5);
            Self { vocab }
        }

        /// Tokenizer with no special tokens, used to exercise the missing-token path.
        fn empty() -> Self {
            Self {
                vocab: BTreeMap::new(),
            }
        }
    }

    impl TokenizerLike for MockTokenizer {
        fn tokenize(&self, text: &str) -> Vec<u32> {
            text.split_whitespace()
                .map(|s| s.len() as u32 + 100)
                .collect()
        }
        fn token_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
    }

    /// Runs the schema pipeline on a small gene/disease schema.
    fn build_prompt_plan() -> PromptPlan {
        let s = r#"
        {
            "entities": ["gene", "disease"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ]
        }
        "#;
        let normalized = NormalizedSchema::from_json_str(s).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        PromptPlan::try_from(task_plan).unwrap()
    }

    #[test]
    fn token_plan_builds_combined_sequence() {
        let pp = build_prompt_plan();
        let tok = MockTokenizer::new();
        let plan =
            TokenPlan::from_prompt_plan(&pp, "gene is associated with disease", &tok).unwrap();

        assert!(plan.prompt_len > 0);
        assert!(plan.input_len > 0);
        assert_eq!(plan.total_len, plan.prompt_len + plan.input_len);
        assert_eq!(plan.combined_token_ids.len(), plan.combined_roles.len());
        assert_eq!(plan.total_len, plan.combined_token_ids.len());

        // The input is tokenized on its own: word lengths are 4, 2, 10, 4 and 7.
        assert_eq!(plan.input_token_ids, vec![104, 102, 110, 104, 107]);
        assert_eq!(plan.input_len, plan.input_token_ids.len());
        assert_eq!(plan.input_roles.len(), plan.input_token_ids.len());
        assert!(plan.input_roles.iter().all(|r| *r == TokenRole::InputText));

        // Prompt tokens never carry the input role, and the lengths line up.
        assert_eq!(plan.prompt_len, plan.prompt_token_ids.len());
        assert_eq!(plan.prompt_roles.len(), plan.prompt_token_ids.len());
        assert!(plan.prompt_roles.iter().all(|r| *r != TokenRole::InputText));

        // The combined sequence is exactly the prompt followed by the input.
        let (head_ids, tail_ids) = plan.combined_token_ids.split_at(plan.prompt_len);
        assert_eq!(head_ids, plan.prompt_token_ids.as_slice());
        assert_eq!(tail_ids, plan.input_token_ids.as_slice());
        let (head_roles, tail_roles) = plan.combined_roles.split_at(plan.prompt_len);
        assert_eq!(head_roles, plan.prompt_roles.as_slice());
        assert_eq!(tail_roles, plan.input_roles.as_slice());

        // Task spans tile the prompt contiguously and in order.
        let mut cursor = 0;
        for task in &plan.prompt_tasks {
            assert_eq!(task.token_span.start, cursor);
            assert!(task.token_span.end >= task.token_span.start);
            cursor = task.token_span.end;
        }
        assert_eq!(cursor, plan.prompt_len);
    }

    #[test]
    fn token_plan_reports_missing_special_token() {
        let pp = build_prompt_plan();
        let tok = MockTokenizer::empty();
        match TokenPlan::from_prompt_plan(&pp, "gene", &tok) {
            // Any missing token must be one that the full mock vocab provides.
            Err(TokenPlanError::MissingSpecialToken { token }) => {
                assert!(MockTokenizer::new().vocab.contains_key(&token));
            }
            Err(other) => panic!("unexpected error: {}", other),
            // Succeeding without a vocab is only possible if no special atoms exist.
            Ok(plan) => assert!(plan
                .prompt_roles
                .iter()
                .all(|r| *r != TokenRole::PromptSpecial)),
        }
    }
}