Skip to main content

ie_schema/
token_plan.rs

1use crate::prompt_plan::{PromptAtom, PromptPlan, PromptTaskKind, PromptTaskPlan};
2use serde::Serialize;
3use std::ops::Range;
4
/// Minimal tokenizer interface needed to build a [`TokenPlan`].
pub trait TokenizerLike {
    /// Looks up the id of a single vocabulary entry, or `None` when `token` is
    /// not in the vocabulary.
    ///
    /// Used to resolve special prompt tokens; a `None` makes
    /// `TokenPlan::from_prompt_plan` fail with
    /// `TokenPlanError::MissingSpecialToken`.
    fn token_id(&self, token: &str) -> Option<u32>;

    /// Converts `text` into a sequence of token ids.
    ///
    /// Called once per prompt text atom and once for the input text. The
    /// outputs of separate calls are concatenated as-is.
    /// NOTE(review): an implementation that adds BOS/EOS markers on every call
    /// would therefore insert them between prompt atoms; confirm implementors
    /// do not.
    fn tokenize(&self, text: &str) -> Vec<u32>;
}
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
11pub enum TokenRole {
12    PromptSpecial,
13    PromptText,
14    InputText,
15}
16
17#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
18pub struct TokenSpan {
19    pub start: usize,
20    pub end: usize,
21}
22
23impl TokenSpan {
24    pub fn range(&self) -> Range<usize> {
25        self.start..self.end
26    }
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
30pub struct TokenizedPromptTask {
31    pub kind: PromptTaskKind,
32    pub name: String,
33    pub token_span: TokenSpan,
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
37pub struct TokenPlan {
38    pub prompt_token_ids: Vec<u32>,
39    pub prompt_roles: Vec<TokenRole>,
40    pub prompt_tasks: Vec<TokenizedPromptTask>,
41
42    pub input_token_ids: Vec<u32>,
43    pub input_roles: Vec<TokenRole>,
44
45    pub combined_token_ids: Vec<u32>,
46    pub combined_roles: Vec<TokenRole>,
47
48    pub prompt_len: usize,
49    pub input_len: usize,
50    pub total_len: usize,
51}
52
53#[derive(Debug, thiserror::Error)]
54pub enum TokenPlanError {
55    #[error("missing special token in tokenizer vocab: {token}")]
56    MissingSpecialToken { token: String },
57
58    #[error("prompt roles length mismatch")]
59    PromptRoleLengthMismatch,
60
61    #[error("input roles length mismatch")]
62    InputRoleLengthMismatch,
63}
64
65fn encode_atom<T: TokenizerLike>(
66    tokenizer: &T,
67    atom: &PromptAtom,
68) -> Result<(Vec<u32>, Vec<TokenRole>), TokenPlanError> {
69    match atom {
70        PromptAtom::Special(tok) => {
71            let raw = tok.as_str();
72            let id =
73                tokenizer
74                    .token_id(raw)
75                    .ok_or_else(|| TokenPlanError::MissingSpecialToken {
76                        token: raw.to_string(),
77                    })?;
78            Ok((vec![id], vec![TokenRole::PromptSpecial]))
79        }
80        PromptAtom::Text(text) => {
81            let ids = tokenizer.tokenize(text);
82            let roles = vec![TokenRole::PromptText; ids.len()];
83            Ok((ids, roles))
84        }
85    }
86}
87
88fn encode_prompt_task<T: TokenizerLike>(
89    tokenizer: &T,
90    task: &PromptTaskPlan,
91    offset: usize,
92) -> Result<(Vec<u32>, Vec<TokenRole>, TokenizedPromptTask), TokenPlanError> {
93    let mut ids = Vec::new();
94    let mut roles = Vec::new();
95
96    for atom in &task.atoms {
97        let (atom_ids, atom_roles) = encode_atom(tokenizer, atom)?;
98        ids.extend(atom_ids);
99        roles.extend(atom_roles);
100    }
101
102    let span = TokenSpan {
103        start: offset,
104        end: offset + ids.len(),
105    };
106
107    let tokenized_task = TokenizedPromptTask {
108        kind: task.kind.clone(),
109        name: task.name.to_string(),
110        token_span: span,
111    };
112
113    Ok((ids, roles, tokenized_task))
114}
115
116impl TokenPlan {
117    pub fn from_prompt_plan<T: TokenizerLike>(
118        prompt_plan: &PromptPlan,
119        input_text: &str,
120        tokenizer: &T,
121    ) -> Result<Self, TokenPlanError> {
122        let mut prompt_token_ids = Vec::new();
123        let mut prompt_roles = Vec::new();
124        let mut prompt_tasks = Vec::new();
125
126        for task in &prompt_plan.tasks {
127            let offset = prompt_token_ids.len();
128            let (ids, roles, tokenized_task) = encode_prompt_task(tokenizer, task, offset)?;
129            prompt_token_ids.extend(ids);
130            prompt_roles.extend(roles);
131            prompt_tasks.push(tokenized_task);
132        }
133
134        if prompt_token_ids.len() != prompt_roles.len() {
135            return Err(TokenPlanError::PromptRoleLengthMismatch);
136        }
137
138        let input_token_ids = tokenizer.tokenize(input_text);
139        let input_roles = vec![TokenRole::InputText; input_token_ids.len()];
140
141        if input_token_ids.len() != input_roles.len() {
142            return Err(TokenPlanError::InputRoleLengthMismatch);
143        }
144
145        let mut combined_token_ids =
146            Vec::with_capacity(prompt_token_ids.len() + input_token_ids.len());
147        combined_token_ids.extend(prompt_token_ids.iter().copied());
148        combined_token_ids.extend(input_token_ids.iter().copied());
149
150        let mut combined_roles = Vec::with_capacity(prompt_roles.len() + input_roles.len());
151        combined_roles.extend(prompt_roles.iter().copied());
152        combined_roles.extend(input_roles.iter().copied());
153
154        Ok(Self {
155            prompt_len: prompt_token_ids.len(),
156            input_len: input_token_ids.len(),
157            total_len: combined_token_ids.len(),
158            prompt_token_ids,
159            prompt_roles,
160            prompt_tasks,
161            input_token_ids,
162            input_roles,
163            combined_token_ids,
164            combined_roles,
165        })
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::prompt_plan::{PromptAtom, PromptPlan};
    use crate::task_plan::TaskPlan;
    use std::collections::BTreeMap;

    /// Schema with two entity types and one relation between them.
    const SCHEMA_JSON: &str = r#"
    {
        "entities": ["gene", "disease"],
        "relations": [
            { "associated_with": { "head": "gene", "tail": "disease" } }
        ]
    }
    "#;

    /// Five whitespace-separated words, so the mock tokenizer yields five input ids.
    const INPUT_TEXT: &str = "gene is associated with disease";

    /// Largest id the mock vocabulary assigns to a special token.
    const MAX_SPECIAL_ID: u32 = 5;

    /// Offset added to word lengths by the mock `tokenize`; every text id exceeds it.
    const TEXT_ID_BASE: u32 = 100;

    /// Deterministic stand-in tokenizer.
    ///
    /// Special tokens map to ids `1..=MAX_SPECIAL_ID`. `tokenize` emits one id
    /// per whitespace-separated word: the word's byte length plus
    /// `TEXT_ID_BASE`. Text ids therefore never collide with special ids.
    struct MockTokenizer {
        vocab: BTreeMap<String, u32>,
    }

    impl MockTokenizer {
        fn new() -> Self {
            let mut vocab = BTreeMap::new();
            vocab.insert("[P]".to_string(), 1);
            vocab.insert("[E]".to_string(), 2);
            vocab.insert("[C]".to_string(), 3);
            vocab.insert("[L]".to_string(), 4);
            vocab.insert("[SEP]".to_string(), MAX_SPECIAL_ID);
            Self { vocab }
        }

        /// A tokenizer whose vocabulary contains no special tokens at all.
        fn without_specials() -> Self {
            Self {
                vocab: BTreeMap::new(),
            }
        }
    }

    impl TokenizerLike for MockTokenizer {
        fn tokenize(&self, text: &str) -> Vec<u32> {
            text.split_whitespace()
                .map(|s| s.len() as u32 + TEXT_ID_BASE)
                .collect()
        }

        fn token_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
    }

    /// Runs the full schema pipeline to obtain the prompt plan under test.
    fn build_prompt_plan() -> PromptPlan {
        let normalized = NormalizedSchema::from_json_str(SCHEMA_JSON).unwrap();
        let expanded = ExpandedSchema::try_from(normalized).unwrap();
        let lifted = LiftedSchema::try_from(expanded).unwrap();
        let task_plan = TaskPlan::try_from(lifted).unwrap();
        PromptPlan::try_from(task_plan).unwrap()
    }

    /// Builds a token plan for `pp` and `INPUT_TEXT` with the full mock vocabulary.
    fn build_token_plan(pp: &PromptPlan) -> TokenPlan {
        TokenPlan::from_prompt_plan(pp, INPUT_TEXT, &MockTokenizer::new()).unwrap()
    }

    #[test]
    fn token_plan_builds_combined_sequence() {
        let plan = build_token_plan(&build_prompt_plan());

        assert!(plan.prompt_len > 0);
        assert!(plan.input_len > 0);
        assert_eq!(plan.total_len, plan.prompt_len + plan.input_len);
        assert_eq!(plan.combined_token_ids.len(), plan.combined_roles.len());

        // Stored lengths agree with the vectors they describe.
        assert_eq!(plan.prompt_token_ids.len(), plan.prompt_len);
        assert_eq!(plan.prompt_roles.len(), plan.prompt_len);
        assert_eq!(plan.input_token_ids.len(), plan.input_len);
        assert_eq!(plan.input_roles.len(), plan.input_len);
        assert_eq!(plan.combined_token_ids.len(), plan.total_len);

        // The combined sequence is exactly the prompt followed by the input.
        assert_eq!(
            plan.combined_token_ids,
            [plan.prompt_token_ids.as_slice(), plan.input_token_ids.as_slice()].concat()
        );
        assert_eq!(
            plan.combined_roles,
            [plan.prompt_roles.as_slice(), plan.input_roles.as_slice()].concat()
        );
    }

    #[test]
    fn input_section_is_the_tokenized_input_text() {
        let tok = MockTokenizer::new();
        let plan = TokenPlan::from_prompt_plan(&build_prompt_plan(), INPUT_TEXT, &tok).unwrap();

        assert_eq!(plan.input_len, 5);
        assert_eq!(plan.input_token_ids, tok.tokenize(INPUT_TEXT));
        assert!(plan.input_roles.iter().all(|&r| r == TokenRole::InputText));
    }

    #[test]
    fn prompt_roles_match_token_origin() {
        let plan = build_token_plan(&build_prompt_plan());

        for (&id, &role) in plan.prompt_token_ids.iter().zip(&plan.prompt_roles) {
            match role {
                TokenRole::PromptSpecial => {
                    assert!(id <= MAX_SPECIAL_ID, "special role on text id {}", id)
                }
                TokenRole::PromptText => {
                    assert!(id > TEXT_ID_BASE, "text role on special id {}", id)
                }
                TokenRole::InputText => panic!("input role inside the prompt section"),
            }
        }
    }

    #[test]
    fn prompt_task_spans_tile_the_prompt() {
        let pp = build_prompt_plan();
        let plan = build_token_plan(&pp);

        assert_eq!(plan.prompt_tasks.len(), pp.tasks.len());

        // Spans must be contiguous, in order, and cover the whole prompt section.
        let mut cursor = 0;
        for task in &plan.prompt_tasks {
            let span = &task.token_span;
            assert_eq!(
                span.start, cursor,
                "task {} does not start where the previous task ended",
                task.name
            );
            assert!(span.start <= span.end);
            cursor = span.end;
        }
        assert_eq!(cursor, plan.prompt_len);
    }

    #[test]
    fn missing_special_token_is_reported() {
        let pp = build_prompt_plan();

        let mut uses_special = false;
        for task in &pp.tasks {
            for atom in &task.atoms {
                if matches!(atom, PromptAtom::Special(_)) {
                    uses_special = true;
                }
            }
        }

        let result =
            TokenPlan::from_prompt_plan(&pp, INPUT_TEXT, &MockTokenizer::without_specials());
        if uses_special {
            assert!(matches!(
                result,
                Err(TokenPlanError::MissingSpecialToken { .. })
            ));
        } else {
            // Without special atoms the vocabulary is never consulted.
            assert!(result.is_ok());
        }
    }
}