1#[derive(Debug, Clone)]
19pub struct ScoredTile {
20 pub id: String,
21 pub question: String,
22 pub answer: String,
23 pub domain: String,
24 pub score: f64,
25 pub priority: Priority,
26}
27
/// Priority tier of a tile. `P0` is ranked ahead of `P1`, which is ranked ahead of `P2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Priority {
    /// Highest tier.
    P0,
    /// Middle tier.
    P1,
    /// Lowest tier.
    P2,
}

impl Default for Priority {
    /// Defaults to the lowest tier, [`Priority::P2`].
    fn default() -> Self {
        Priority::P2
    }
}
31
32#[derive(Debug, Clone)]
34pub struct PromptConfig {
35 pub max_tokens: usize,
37 pub inject_deadband: bool,
39 pub format: TileFormat,
41 pub system_prefix: String,
43 pub include_domain: bool,
45}
46
47impl Default for PromptConfig {
48 fn default() -> Self {
49 PromptConfig {
50 max_tokens: 4096,
51 inject_deadband: true,
52 format: TileFormat::Structured,
53 system_prefix: String::new(),
54 include_domain: true,
55 }
56 }
57}
58
/// Rendering style applied to every tile in the prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TileFormat {
    /// `Q: <question>` and `A: <answer>` on two lines, optionally prefixed by `[domain]`.
    Structured,
    /// A `## <question>` heading (optionally followed by ` (domain)`), a blank line,
    /// then the answer.
    Markdown,
    /// A single-line JSON object with the keys `id`, `q`, `a` and `score`.
    Json,
    /// A single line of the form `id: score | question → answer`.
    Compact,
}
70
71pub struct PromptAssembler;
73
74impl PromptAssembler {
75 pub fn build(tiles: &[ScoredTile], query: &str, config: &PromptConfig) -> (String, BuildStats) {
78 let mut stats = BuildStats::default();
79 let mut parts = Vec::new();
80
81 if !config.system_prefix.is_empty() {
83 parts.push(config.system_prefix.clone());
84 stats.system_tokens = estimate_tokens(&config.system_prefix);
85 }
86
87 let mut sorted: Vec<&ScoredTile> = tiles.iter().collect();
89 sorted.sort_by(|a, b| {
90 let pa = match a.priority { Priority::P0 => 0, Priority::P1 => 1, Priority::P2 => 2 };
91 let pb = match b.priority { Priority::P0 => 0, Priority::P1 => 1, Priority::P2 => 2 };
92 pa.cmp(&pb).then(b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal))
93 });
94
95 let budget = config.max_tokens.saturating_sub(stats.system_tokens);
96 let mut used = 0usize;
97
98 for tile in &sorted {
99 let formatted = Self::format_tile(tile, config);
100 let tile_tokens = estimate_tokens(&formatted);
101 if used + tile_tokens > budget {
102 stats.excluded += 1;
103 continue;
104 }
105 parts.push(formatted);
106 used += tile_tokens;
107 stats.tiles_included += 1;
108 stats.tile_tokens += tile_tokens;
109 match tile.priority {
110 Priority::P0 => stats.p0_count += 1,
111 Priority::P1 => stats.p1_count += 1,
112 Priority::P2 => stats.p2_count += 1,
113 }
114 }
115
116 if config.inject_deadband {
118 let deadband = Self::deadband_section(tiles, &stats);
119 if !deadband.is_empty() {
120 let db_tokens = estimate_tokens(&deadband);
121 parts.push(deadband);
122 stats.deadband_tokens = db_tokens;
123 }
124 }
125
126 let query_line = format!("\n\nQuery: {}", query);
128 stats.query_tokens = estimate_tokens(&query_line);
129 parts.push(query_line);
130
131 stats.total_tokens = stats.system_tokens + stats.tile_tokens + stats.deadband_tokens + stats.query_tokens;
132
133 (parts.join("\n\n"), stats)
134 }
135
136 fn format_tile(tile: &ScoredTile, config: &PromptConfig) -> String {
137 match config.format {
138 TileFormat::Structured => {
139 let domain_tag = if config.include_domain { format!("[{}]", tile.domain) } else { String::new() };
140 format!("{}Q: {}\nA: {}", domain_tag, tile.question, tile.answer)
141 }
142 TileFormat::Markdown => {
143 let domain_tag = if config.include_domain { format!(" ({})", tile.domain) } else { String::new() };
144 format!("## {}{}\n\n{}", tile.question, domain_tag, tile.answer)
145 }
146 TileFormat::Json => {
147 format!(r#"{{"id":"{}","q":"{}","a":"{}","score":{:.3}}}"#, tile.id, tile.question, tile.answer, tile.score)
148 }
149 TileFormat::Compact => {
150 format!("{}: {:.2} | {} → {}", tile.id, tile.score, tile.question, tile.answer)
151 }
152 }
153 }
154
155 fn deadband_section(tiles: &[ScoredTile], stats: &BuildStats) -> String {
157 if stats.tiles_included == 0 { return String::new(); }
158 let covered_domains: std::collections::HashSet<&str> =
159 tiles.iter().take(stats.tiles_included).map(|t| t.domain.as_str()).collect();
160 if covered_domains.is_empty() { return String::new(); }
161
162 let p0_negatives: Vec<&str> = tiles.iter()
164 .filter(|t| t.priority == Priority::P0 && !covered_domains.contains(t.domain.as_str()))
165 .map(|t| t.domain.as_str())
166 .collect();
167
168 if p0_negatives.is_empty() { return String::new(); }
169
170 let unique: std::collections::HashSet<&str> = p0_negatives.into_iter().collect();
171 let warnings: Vec<String> = unique.iter().map(|d| format!("- ⚠️ P0 gap: no coverage for [{}]", d)).collect();
172 format!("## Deadband Warnings\n\n{}", warnings.join("\n"))
173 }
174}
175
/// Accounting produced alongside an assembled prompt.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BuildStats {
    /// Number of tiles placed in the prompt.
    pub tiles_included: usize,
    /// Number of tiles skipped because they did not fit the remaining budget.
    pub excluded: usize,
    /// Included tiles with priority `P0`.
    pub p0_count: usize,
    /// Included tiles with priority `P1`.
    pub p1_count: usize,
    /// Included tiles with priority `P2`.
    pub p2_count: usize,
    /// Estimated tokens of the system prefix (0 when there is none).
    pub system_tokens: usize,
    /// Sum of the estimated tokens of all included tiles.
    pub tile_tokens: usize,
    /// Estimated tokens of the deadband section (0 when it is absent).
    pub deadband_tokens: usize,
    /// Estimated tokens of the query line.
    pub query_tokens: usize,
    /// Sum of the system, tile, deadband and query token estimates.
    pub total_tokens: usize,
}
190
/// Roughly estimates the token count of `text` as one token per four UTF-8 bytes.
///
/// The result is rounded up so that any non-empty text costs at least one token;
/// rounding down would let very short strings count as free against a budget.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let bytes = text.len();
    bytes / BYTES_PER_TOKEN + usize::from(bytes % BYTES_PER_TOKEN != 0)
}
194
/// Test-only convenience constructor that builds a [`ScoredTile`] from string slices.
///
/// Gated on `cfg(test)` because only the unit tests call it; without the gate it is
/// dead code in regular builds.
#[cfg(test)]
fn make_tile(
    id: &str,
    q: &str,
    a: &str,
    domain: &str,
    score: f64,
    priority: Priority,
) -> ScoredTile {
    ScoredTile {
        id: id.into(),
        question: q.into(),
        answer: a.into(),
        domain: domain.into(),
        score,
        priority,
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// A single tile and the query both end up in the prompt.
    #[test]
    fn test_basic_assembly() {
        let tiles = vec![make_tile(
            "t1",
            "What is PLATO?",
            "Training pipeline.",
            "plato",
            0.9,
            Priority::P2,
        )];
        let config = PromptConfig::default();
        let (prompt, stats) = PromptAssembler::build(&tiles, "tell me about PLATO", &config);
        assert!(prompt.contains("What is PLATO?"));
        assert!(prompt.contains("Training pipeline."));
        assert!(prompt.contains("tell me about PLATO"));
        assert_eq!(stats.tiles_included, 1);
        assert_eq!(stats.excluded, 0);
    }

    /// Priority outranks score: P0 precedes P1 precedes P2 regardless of input order.
    #[test]
    fn test_priority_sorting() {
        let tiles = vec![
            make_tile("p2", "Low priority", "Answer P2", "misc", 0.5, Priority::P2),
            make_tile("p1", "Medium priority", "Answer P1", "safety", 0.7, Priority::P1),
            make_tile("p0", "High priority", "Answer P0", "critical", 0.3, Priority::P0),
        ];
        let config = PromptConfig::default();
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        let p0_pos = prompt.find("High priority").unwrap();
        let p1_pos = prompt.find("Medium priority").unwrap();
        let p2_pos = prompt.find("Low priority").unwrap();
        assert!(p0_pos < p1_pos);
        assert!(p1_pos < p2_pos);
        assert_eq!(stats.p0_count, 1);
        assert_eq!(stats.p1_count, 1);
        assert_eq!(stats.p2_count, 1);
    }

    /// An oversized tile is skipped while a later, smaller tile still fits.
    #[test]
    fn test_budget_exclusion() {
        let tiles = vec![
            make_tile("big", &"A".repeat(20000), &"B".repeat(20000), "x", 0.9, Priority::P2),
            make_tile("small", "Q?", "A.", "y", 0.5, Priority::P2),
        ];
        let config = PromptConfig { max_tokens: 100, ..Default::default() };
        let (_, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert_eq!(stats.excluded, 1);
        assert_eq!(stats.tiles_included, 1);
    }

    /// Without any P0 tile there is nothing to warn about.
    #[test]
    fn test_deadband_injection() {
        let tiles = vec![make_tile(
            "safe",
            "Safe topic",
            "Safe answer",
            "safe_domain",
            0.9,
            Priority::P2,
        )];
        let config = PromptConfig { inject_deadband: true, ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert_eq!(stats.deadband_tokens, 0);
        assert!(!prompt.contains("Deadband Warnings"));
    }

    /// A P0 tile that cannot fit leaves its domain uncovered and triggers a warning.
    #[test]
    fn test_deadband_p0_gap() {
        let tiles = vec![
            make_tile(
                "safe",
                "Safe topic question here",
                "Safe answer here",
                "safe_domain",
                0.9,
                Priority::P2,
            ),
            make_tile(
                "gap",
                "Critical safety question that is very long and will not fit in the tiny budget window provided here for testing purposes",
                "Critical answer also very long",
                "unsafe_domain",
                0.1,
                Priority::P0,
            ),
        ];
        // The rendered P0 tile is 172 bytes (43 estimated tokens) and the P2 tile is
        // 60 bytes (15 tokens). A 20-token budget therefore admits only the P2 tile,
        // which is the scenario this test is meant to exercise.
        let config = PromptConfig { max_tokens: 20, inject_deadband: true, ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert_eq!(stats.excluded, 1);
        assert_eq!(stats.tiles_included, 1);
        assert_eq!(stats.p0_count, 0);
        assert!(stats.deadband_tokens > 0);
        assert!(prompt.contains("Deadband Warnings"));
        assert!(prompt.contains("[unsafe_domain]"));
        assert!(!prompt.contains("Critical safety question"));
    }

    /// Disabling injection suppresses the section entirely.
    #[test]
    fn test_no_deadband_when_disabled() {
        let tiles = vec![make_tile("gap", "Missing", "None", "gap_domain", 0.1, Priority::P0)];
        let config = PromptConfig { inject_deadband: false, ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert_eq!(stats.deadband_tokens, 0);
        assert!(!prompt.contains("Deadband"));
    }

    #[test]
    fn test_structured_format() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "dom", 0.9, Priority::P2)];
        let config = PromptConfig {
            format: TileFormat::Structured,
            include_domain: true,
            ..Default::default()
        };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains("[dom]"));
        assert!(prompt.contains("Q: Q?"));
        assert!(prompt.contains("A: A."));
    }

    #[test]
    fn test_markdown_format() {
        let tiles = vec![make_tile(
            "t1",
            "What is flux?",
            "Bytecode runtime.",
            "flux",
            0.9,
            Priority::P2,
        )];
        let config = PromptConfig {
            format: TileFormat::Markdown,
            include_domain: true,
            ..Default::default()
        };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains("## What is flux?"));
        assert!(prompt.contains("(flux)"));
    }

    #[test]
    fn test_compact_format() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "d", 0.85, Priority::P2)];
        let config = PromptConfig { format: TileFormat::Compact, ..Default::default() };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains("t1: 0.85"));
        assert!(prompt.contains("Q? → A."));
    }

    #[test]
    fn test_json_format() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "d", 0.9, Priority::P2)];
        let config = PromptConfig { format: TileFormat::Json, ..Default::default() };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains(r#""id":"t1""#));
        assert!(prompt.contains(r#""score":0.900"#));
    }

    #[test]
    fn test_system_prefix() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "d", 0.9, Priority::P2)];
        let config = PromptConfig {
            system_prefix: "You are a helpful PLATO assistant.".into(),
            ..Default::default()
        };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.starts_with("You are a helpful PLATO assistant."));
        assert!(stats.system_tokens > 0);
    }

    #[test]
    fn test_empty_tiles() {
        let config = PromptConfig::default();
        let (prompt, stats) = PromptAssembler::build(&[], "test", &config);
        assert!(prompt.contains("Query: test"));
        assert_eq!(stats.tiles_included, 0);
    }

    /// The total is exactly the sum of the four per-section estimates.
    #[test]
    fn test_token_accounting() {
        let tiles = vec![make_tile("t1", "Short Q", "Short A", "d", 0.9, Priority::P2)];
        let config = PromptConfig::default();
        let (_, stats) = PromptAssembler::build(&tiles, "test query", &config);
        assert!(stats.total_tokens > 0);
        assert_eq!(
            stats.total_tokens,
            stats.system_tokens + stats.tile_tokens + stats.deadband_tokens + stats.query_tokens
        );
    }
}