1use crate::prompt_plan::{PromptAtom, PromptPlan, PromptTaskKind, PromptTaskPlan};
2use serde::Serialize;
3use std::ops::Range;
4
/// Minimal tokenizer interface needed to turn a prompt plan and input text
/// into token ids.
pub trait TokenizerLike {
    /// Splits `text` into tokens and returns their ids in order.
    fn tokenize(&self, text: &str) -> Vec<u32>;

    /// Looks up the id of a single token string.
    ///
    /// Returns `None` when `token` has no id in this tokenizer.
    fn token_id(&self, token: &str) -> Option<u32>;
}
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
11pub enum TokenRole {
12 PromptSpecial,
13 PromptText,
14 InputText,
15}
16
17#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
18pub struct TokenSpan {
19 pub start: usize,
20 pub end: usize,
21}
22
23impl TokenSpan {
24 pub fn range(&self) -> Range<usize> {
25 self.start..self.end
26 }
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
30pub struct TokenizedPromptTask {
31 pub kind: PromptTaskKind,
32 pub name: String,
33 pub token_span: TokenSpan,
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
37pub struct TokenPlan {
38 pub prompt_token_ids: Vec<u32>,
39 pub prompt_roles: Vec<TokenRole>,
40 pub prompt_tasks: Vec<TokenizedPromptTask>,
41
42 pub input_token_ids: Vec<u32>,
43 pub input_roles: Vec<TokenRole>,
44
45 pub combined_token_ids: Vec<u32>,
46 pub combined_roles: Vec<TokenRole>,
47
48 pub prompt_len: usize,
49 pub input_len: usize,
50 pub total_len: usize,
51}
52
53#[derive(Debug, thiserror::Error)]
54pub enum TokenPlanError {
55 #[error("missing special token in tokenizer vocab: {token}")]
56 MissingSpecialToken { token: String },
57
58 #[error("prompt roles length mismatch")]
59 PromptRoleLengthMismatch,
60
61 #[error("input roles length mismatch")]
62 InputRoleLengthMismatch,
63}
64
65fn encode_atom<T: TokenizerLike>(
66 tokenizer: &T,
67 atom: &PromptAtom,
68) -> Result<(Vec<u32>, Vec<TokenRole>), TokenPlanError> {
69 match atom {
70 PromptAtom::Special(tok) => {
71 let raw = tok.as_str();
72 let id =
73 tokenizer
74 .token_id(raw)
75 .ok_or_else(|| TokenPlanError::MissingSpecialToken {
76 token: raw.to_string(),
77 })?;
78 Ok((vec![id], vec![TokenRole::PromptSpecial]))
79 }
80 PromptAtom::Text(text) => {
81 let ids = tokenizer.tokenize(text);
82 let roles = vec![TokenRole::PromptText; ids.len()];
83 Ok((ids, roles))
84 }
85 }
86}
87
88fn encode_prompt_task<T: TokenizerLike>(
89 tokenizer: &T,
90 task: &PromptTaskPlan,
91 offset: usize,
92) -> Result<(Vec<u32>, Vec<TokenRole>, TokenizedPromptTask), TokenPlanError> {
93 let mut ids = Vec::new();
94 let mut roles = Vec::new();
95
96 for atom in &task.atoms {
97 let (atom_ids, atom_roles) = encode_atom(tokenizer, atom)?;
98 ids.extend(atom_ids);
99 roles.extend(atom_roles);
100 }
101
102 let span = TokenSpan {
103 start: offset,
104 end: offset + ids.len(),
105 };
106
107 let tokenized_task = TokenizedPromptTask {
108 kind: task.kind.clone(),
109 name: task.name.to_string(),
110 token_span: span,
111 };
112
113 Ok((ids, roles, tokenized_task))
114}
115
116impl TokenPlan {
117 pub fn from_prompt_plan<T: TokenizerLike>(
118 prompt_plan: &PromptPlan,
119 input_text: &str,
120 tokenizer: &T,
121 ) -> Result<Self, TokenPlanError> {
122 let mut prompt_token_ids = Vec::new();
123 let mut prompt_roles = Vec::new();
124 let mut prompt_tasks = Vec::new();
125
126 for task in &prompt_plan.tasks {
127 let offset = prompt_token_ids.len();
128 let (ids, roles, tokenized_task) = encode_prompt_task(tokenizer, task, offset)?;
129 prompt_token_ids.extend(ids);
130 prompt_roles.extend(roles);
131 prompt_tasks.push(tokenized_task);
132 }
133
134 if prompt_token_ids.len() != prompt_roles.len() {
135 return Err(TokenPlanError::PromptRoleLengthMismatch);
136 }
137
138 let input_token_ids = tokenizer.tokenize(input_text);
139 let input_roles = vec![TokenRole::InputText; input_token_ids.len()];
140
141 if input_token_ids.len() != input_roles.len() {
142 return Err(TokenPlanError::InputRoleLengthMismatch);
143 }
144
145 let mut combined_token_ids =
146 Vec::with_capacity(prompt_token_ids.len() + input_token_ids.len());
147 combined_token_ids.extend(prompt_token_ids.iter().copied());
148 combined_token_ids.extend(input_token_ids.iter().copied());
149
150 let mut combined_roles = Vec::with_capacity(prompt_roles.len() + input_roles.len());
151 combined_roles.extend(prompt_roles.iter().copied());
152 combined_roles.extend(input_roles.iter().copied());
153
154 Ok(Self {
155 prompt_len: prompt_token_ids.len(),
156 input_len: input_token_ids.len(),
157 total_len: combined_token_ids.len(),
158 prompt_token_ids,
159 prompt_roles,
160 prompt_tasks,
161 input_token_ids,
162 input_roles,
163 combined_token_ids,
164 combined_roles,
165 })
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expanded::ExpandedSchema;
    use crate::lifted::LiftedSchema;
    use crate::normalized::NormalizedSchema;
    use crate::prompt_plan::PromptPlan;
    use crate::task_plan::TaskPlan;
    use std::collections::BTreeMap;

    /// Tokenizer stub: special tokens come from a fixed vocab, and text is
    /// split on whitespace with each word mapped to `word.len() + 100`.
    struct MockTokenizer {
        vocab: BTreeMap<String, u32>,
    }

    impl MockTokenizer {
        fn new() -> Self {
            let mut vocab = BTreeMap::new();
            vocab.insert("[P]".to_string(), 1);
            vocab.insert("[E]".to_string(), 2);
            vocab.insert("[C]".to_string(), 3);
            vocab.insert("[L]".to_string(), 4);
            vocab.insert("[SEP]".to_string(), 5);
            Self { vocab }
        }
    }

    impl TokenizerLike for MockTokenizer {
        fn tokenize(&self, text: &str) -> Vec<u32> {
            text.split_whitespace()
                .map(|s| s.len() as u32 + 100)
                .collect()
        }

        fn token_id(&self, token: &str) -> Option<u32> {
            self.vocab.get(token).copied()
        }
    }

    #[test]
    fn token_plan_builds_combined_sequence() {
        let s = r#"
        {
            "entities": ["gene", "disease"],
            "relations": [
                { "associated_with": { "head": "gene", "tail": "disease" } }
            ]
        }
        "#;

        let s2 = NormalizedSchema::from_json_str(s).unwrap();
        let s3 = ExpandedSchema::try_from(s2).unwrap();
        let s4 = LiftedSchema::try_from(s3).unwrap();
        let tp = TaskPlan::try_from(s4).unwrap();
        let pp = PromptPlan::try_from(tp).unwrap();

        let tok = MockTokenizer::new();
        let plan =
            TokenPlan::from_prompt_plan(&pp, "gene is associated with disease", &tok).unwrap();

        assert!(plan.prompt_len > 0);
        assert!(plan.input_len > 0);
        assert_eq!(plan.total_len, plan.prompt_len + plan.input_len);
        assert_eq!(plan.combined_token_ids.len(), plan.combined_roles.len());

        // The mock maps each whitespace-separated word to `len + 100`.
        assert_eq!(plan.input_token_ids, vec![104, 102, 110, 104, 107]);

        // The combined sequence is the prompt followed by the input.
        assert_eq!(
            &plan.combined_token_ids[plan.prompt_len..],
            plan.input_token_ids.as_slice()
        );
        assert!(plan.combined_roles[plan.prompt_len..]
            .iter()
            .all(|role| *role == TokenRole::InputText));
    }
}