// plato_tile_prompt/lib.rs

//! plato-tile-prompt — Tile Context → LLM Prompt Assembly
//!
//! Takes scored tiles and assembles them into a prompt for LLM inference.
//! Handles budget management, deadband injection, and format selection.
//!
//! ## Why
//! A model doesn't receive raw tiles. It receives a prompt assembled from
//! relevant tiles, formatted for its context window, with deadband warnings
//! injected when approaching negative space.
//!
//! ## API
//! ```ignore
//! let config = PromptConfig::default();
//! let (prompt, stats) = PromptAssembler::build(&scored_tiles, query, &config);
//! ```

17/// A scored tile ready for prompt assembly.
18#[derive(Debug, Clone)]
19pub struct ScoredTile {
20    pub id: String,
21    pub question: String,
22    pub answer: String,
23    pub domain: String,
24    pub score: f64,
25    pub priority: Priority,
26}
27
/// Priority class of a tile; `P0` is the highest.
///
/// The assembler places `P0` tiles before `P1` and `P1` before `P2`, and only
/// `P0` tiles are considered when generating deadband warnings.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Priority {
    /// Highest priority; missing domain coverage is reported as a deadband gap.
    P0,
    /// Medium priority.
    P1,
    /// Lowest priority (the default).
    P2,
}

impl Default for Priority {
    /// Tiles default to the lowest priority, `P2`.
    fn default() -> Self {
        Priority::P2
    }
}

32/// Prompt assembly configuration.
33#[derive(Debug, Clone)]
34pub struct PromptConfig {
35    /// Maximum tokens for the assembled prompt (approximate, uses chars/4).
36    pub max_tokens: usize,
37    /// Include deadband warnings in the prompt.
38    pub inject_deadband: bool,
39    /// Format style for tiles.
40    pub format: TileFormat,
41    /// System prompt prefix.
42    pub system_prefix: String,
43    /// Whether to include domain tags.
44    pub include_domain: bool,
45}
46
47impl Default for PromptConfig {
48    fn default() -> Self {
49        PromptConfig {
50            max_tokens: 4096,
51            inject_deadband: true,
52            format: TileFormat::Structured,
53            system_prefix: String::new(),
54            include_domain: true,
55        }
56    }
57}
58
/// Rendering style applied to each tile in the prompt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TileFormat {
    /// `Q: ... A: ...` format, optionally prefixed with a `[domain]` tag.
    Structured,
    /// Markdown: the question as a `##` header, the answer as the body.
    Markdown,
    /// One JSON object per tile with the keys `id`, `q`, `a` and `score`
    /// (the objects are not wrapped in an array).
    Json,
    /// Compact: `id: score | question → answer`.
    Compact,
}

/// The prompt assembler.
///
/// A stateless unit type; all functionality is exposed as associated
/// functions, so no instance is required.
#[derive(Debug, Clone, Copy, Default)]
pub struct PromptAssembler;

74impl PromptAssembler {
75    /// Build a prompt from scored tiles.
76    /// Returns (prompt, stats) where stats tracks what was included/excluded.
77    pub fn build(tiles: &[ScoredTile], query: &str, config: &PromptConfig) -> (String, BuildStats) {
78        let mut stats = BuildStats::default();
79        let mut parts = Vec::new();
80
81        // System prefix
82        if !config.system_prefix.is_empty() {
83            parts.push(config.system_prefix.clone());
84            stats.system_tokens = estimate_tokens(&config.system_prefix);
85        }
86
87        // Sort by priority (P0 first) then score descending
88        let mut sorted: Vec<&ScoredTile> = tiles.iter().collect();
89        sorted.sort_by(|a, b| {
90            let pa = match a.priority { Priority::P0 => 0, Priority::P1 => 1, Priority::P2 => 2 };
91            let pb = match b.priority { Priority::P0 => 0, Priority::P1 => 1, Priority::P2 => 2 };
92            pa.cmp(&pb).then(b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal))
93        });
94
95        let budget = config.max_tokens.saturating_sub(stats.system_tokens);
96        let mut used = 0usize;
97
98        for tile in &sorted {
99            let formatted = Self::format_tile(tile, config);
100            let tile_tokens = estimate_tokens(&formatted);
101            if used + tile_tokens > budget {
102                stats.excluded += 1;
103                continue;
104            }
105            parts.push(formatted);
106            used += tile_tokens;
107            stats.tiles_included += 1;
108            stats.tile_tokens += tile_tokens;
109            match tile.priority {
110                Priority::P0 => stats.p0_count += 1,
111                Priority::P1 => stats.p1_count += 1,
112                Priority::P2 => stats.p2_count += 1,
113            }
114        }
115
116        // Deadband injection
117        if config.inject_deadband {
118            let deadband = Self::deadband_section(tiles, &stats);
119            if !deadband.is_empty() {
120                let db_tokens = estimate_tokens(&deadband);
121                parts.push(deadband);
122                stats.deadband_tokens = db_tokens;
123            }
124        }
125
126        // Query suffix
127        let query_line = format!("\n\nQuery: {}", query);
128        stats.query_tokens = estimate_tokens(&query_line);
129        parts.push(query_line);
130
131        stats.total_tokens = stats.system_tokens + stats.tile_tokens + stats.deadband_tokens + stats.query_tokens;
132
133        (parts.join("\n\n"), stats)
134    }
135
136    fn format_tile(tile: &ScoredTile, config: &PromptConfig) -> String {
137        match config.format {
138            TileFormat::Structured => {
139                let domain_tag = if config.include_domain { format!("[{}]", tile.domain) } else { String::new() };
140                format!("{}Q: {}\nA: {}", domain_tag, tile.question, tile.answer)
141            }
142            TileFormat::Markdown => {
143                let domain_tag = if config.include_domain { format!(" ({})", tile.domain) } else { String::new() };
144                format!("## {}{}\n\n{}", tile.question, domain_tag, tile.answer)
145            }
146            TileFormat::Json => {
147                format!(r#"{{"id":"{}","q":"{}","a":"{}","score":{:.3}}}"#, tile.id, tile.question, tile.answer, tile.score)
148            }
149            TileFormat::Compact => {
150                format!("{}: {:.2} | {} → {}", tile.id, tile.score, tile.question, tile.answer)
151            }
152        }
153    }
154
155    /// Generate deadband warnings for domains NOT covered by included tiles.
156    fn deadband_section(tiles: &[ScoredTile], stats: &BuildStats) -> String {
157        if stats.tiles_included == 0 { return String::new(); }
158        let covered_domains: std::collections::HashSet<&str> =
159            tiles.iter().take(stats.tiles_included).map(|t| t.domain.as_str()).collect();
160        if covered_domains.is_empty() { return String::new(); }
161
162        // Count P0 negatives (tiles that SHOULD be here but aren't)
163        let p0_negatives: Vec<&str> = tiles.iter()
164            .filter(|t| t.priority == Priority::P0 && !covered_domains.contains(t.domain.as_str()))
165            .map(|t| t.domain.as_str())
166            .collect();
167
168        if p0_negatives.is_empty() { return String::new(); }
169
170        let unique: std::collections::HashSet<&str> = p0_negatives.into_iter().collect();
171        let warnings: Vec<String> = unique.iter().map(|d| format!("- ⚠️ P0 gap: no coverage for [{}]", d)).collect();
172        format!("## Deadband Warnings\n\n{}", warnings.join("\n"))
173    }
174}
175
/// Build statistics reported by the prompt assembler.
///
/// All token figures are estimates, not exact tokenizer counts.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BuildStats {
    /// Number of tiles placed in the prompt.
    pub tiles_included: usize,
    /// Number of tiles skipped because they did not fit the budget.
    pub excluded: usize,
    /// Included tiles with priority P0.
    pub p0_count: usize,
    /// Included tiles with priority P1.
    pub p1_count: usize,
    /// Included tiles with priority P2.
    pub p2_count: usize,
    /// Estimated tokens of the system prefix (0 when there is none).
    pub system_tokens: usize,
    /// Estimated tokens of all included tiles.
    pub tile_tokens: usize,
    /// Estimated tokens of the deadband section (0 when none was emitted).
    pub deadband_tokens: usize,
    /// Estimated tokens of the query line.
    pub query_tokens: usize,
    /// Sum of the system, tile, deadband and query token estimates.
    pub total_tokens: usize,
}

/// Roughly estimate the token count of `text`: about one token per four
/// bytes of UTF-8, rounded up so that non-empty text never costs zero tokens.
fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    // Ceiling division written so it cannot overflow.
    bytes / 4 + usize::from(bytes % 4 != 0)
}

/// Test helper: build a [`ScoredTile`] from borrowed strings.
///
/// Compiled only for tests, where it is used, so non-test builds do not carry
/// an unused function.
#[cfg(test)]
fn make_tile(id: &str, q: &str, a: &str, domain: &str, score: f64, priority: Priority) -> ScoredTile {
    ScoredTile {
        id: id.into(),
        question: q.into(),
        answer: a.into(),
        domain: domain.into(),
        score,
        priority,
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for prompt assembly: ordering, budgeting, deadband
    //! warnings, tile formats and token accounting.
    use super::*;

    #[test]
    fn test_basic_assembly() {
        let tiles = vec![
            make_tile("t1", "What is PLATO?", "Training pipeline.", "plato", 0.9, Priority::P2),
        ];
        let config = PromptConfig::default();
        let (prompt, stats) = PromptAssembler::build(&tiles, "tell me about PLATO", &config);
        assert!(prompt.contains("What is PLATO?"));
        assert!(prompt.contains("Training pipeline."));
        assert!(prompt.contains("tell me about PLATO"));
        assert_eq!(stats.tiles_included, 1);
        assert_eq!(stats.excluded, 0);
    }

    #[test]
    fn test_priority_sorting() {
        let tiles = vec![
            make_tile("p2", "Low priority", "Answer P2", "misc", 0.5, Priority::P2),
            make_tile("p1", "Medium priority", "Answer P1", "safety", 0.7, Priority::P1),
            make_tile("p0", "High priority", "Answer P0", "critical", 0.3, Priority::P0),
        ];
        let config = PromptConfig::default();
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        // P0 should appear first despite lowest score
        let p0_pos = prompt.find("High priority").unwrap();
        let p1_pos = prompt.find("Medium priority").unwrap();
        let p2_pos = prompt.find("Low priority").unwrap();
        assert!(p0_pos < p1_pos);
        assert!(p1_pos < p2_pos);
        assert_eq!(stats.p0_count, 1);
        assert_eq!(stats.p1_count, 1);
        assert_eq!(stats.p2_count, 1);
    }

    #[test]
    fn test_budget_exclusion() {
        let tiles = vec![
            make_tile("big", &"A".repeat(20000), &"B".repeat(20000), "x", 0.9, Priority::P2),
            make_tile("small", "Q?", "A.", "y", 0.5, Priority::P2),
        ];
        let config = PromptConfig { max_tokens: 100, ..Default::default() };
        let (_, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert!(stats.excluded >= 1);
        assert!(stats.tiles_included >= 1);
    }

    #[test]
    fn test_deadband_injection() {
        let tiles = vec![
            make_tile("safe", "Safe topic", "Safe answer", "safe_domain", 0.9, Priority::P2),
        ];
        let config = PromptConfig { inject_deadband: true, ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        // No P0 gaps → no deadband warnings
        assert_eq!(stats.deadband_tokens, 0);
        assert!(!prompt.contains("Deadband"));
    }

    #[test]
    fn test_deadband_p0_gap() {
        // The P0 tile is far larger than the whole budget, so it is excluded
        // while the small P2 tile still fits — leaving its domain uncovered.
        let tiles = vec![
            make_tile("safe", "Safe topic question here", "Safe answer here", "safe_domain", 0.9, Priority::P2),
            make_tile("gap", &"Critical safety question. ".repeat(40), "Critical answer", "unsafe_domain", 0.1, Priority::P0),
        ];
        let config = PromptConfig { max_tokens: 50, inject_deadband: true, ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        // P0 tile excluded by budget → deadband warning should fire
        assert_eq!(stats.excluded, 1);
        assert_eq!(stats.tiles_included, 1);
        assert_eq!(stats.p0_count, 0);
        assert!(stats.deadband_tokens > 0);
        assert!(prompt.contains("Deadband Warnings"));
        assert!(prompt.contains("[unsafe_domain]"));
    }

    #[test]
    fn test_no_deadband_when_disabled() {
        let tiles = vec![
            make_tile("gap", "Missing", "None", "gap_domain", 0.1, Priority::P0),
        ];
        let config = PromptConfig { inject_deadband: false, ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert_eq!(stats.deadband_tokens, 0);
        assert!(!prompt.contains("Deadband"));
    }

    #[test]
    fn test_structured_format() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "dom", 0.9, Priority::P2)];
        let config = PromptConfig { format: TileFormat::Structured, include_domain: true, ..Default::default() };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains("[dom]"));
        assert!(prompt.contains("Q: Q?"));
        assert!(prompt.contains("A: A."));
    }

    #[test]
    fn test_markdown_format() {
        let tiles = vec![make_tile("t1", "What is flux?", "Bytecode runtime.", "flux", 0.9, Priority::P2)];
        let config = PromptConfig { format: TileFormat::Markdown, include_domain: true, ..Default::default() };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains("## What is flux?"));
        assert!(prompt.contains("(flux)"));
    }

    #[test]
    fn test_compact_format() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "d", 0.85, Priority::P2)];
        let config = PromptConfig { format: TileFormat::Compact, ..Default::default() };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains("t1: 0.85"));
        assert!(prompt.contains("Q? → A."));
    }

    #[test]
    fn test_json_format() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "d", 0.9, Priority::P2)];
        let config = PromptConfig { format: TileFormat::Json, ..Default::default() };
        let (prompt, _) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.contains(r#""id":"t1""#));
        assert!(prompt.contains(r#""score":0.900"#));
    }

    #[test]
    fn test_system_prefix() {
        let tiles = vec![make_tile("t1", "Q?", "A.", "d", 0.9, Priority::P2)];
        let config = PromptConfig { system_prefix: "You are a helpful PLATO assistant.".into(), ..Default::default() };
        let (prompt, stats) = PromptAssembler::build(&tiles, "test", &config);
        assert!(prompt.starts_with("You are a helpful PLATO assistant."));
        assert!(stats.system_tokens > 0);
    }

    #[test]
    fn test_empty_tiles() {
        let config = PromptConfig::default();
        let (prompt, stats) = PromptAssembler::build(&[], "test", &config);
        assert!(prompt.contains("Query: test"));
        assert_eq!(stats.tiles_included, 0);
    }

    #[test]
    fn test_token_accounting() {
        let tiles = vec![
            make_tile("t1", "Short Q", "Short A", "d", 0.9, Priority::P2),
        ];
        let config = PromptConfig::default();
        let (_, stats) = PromptAssembler::build(&tiles, "test query", &config);
        assert!(stats.total_tokens > 0);
        assert_eq!(stats.total_tokens, stats.system_tokens + stats.tile_tokens + stats.deadband_tokens + stats.query_tokens);
    }
}