use crate::estimator::TokenEstimator;
/// Request-level flags that steer how a system prompt is optimized.
#[derive(Debug, Clone)]
pub struct PromptContext {
    /// When `false`, sections mentioning tools, functions or invocation are
    /// classified as irrelevant and become the first candidates for removal.
    pub has_tools: bool,
    /// When `false`, sections carrying `context:`, `knowledge:` or `memory:`
    /// markers are classified as irrelevant.
    pub has_rag: bool,
    /// When `true`, the optimized prompt is additionally passed through
    /// `crate::prompt::structured::strip_filler`.
    pub structured_format: bool,
}

impl PromptContext {
    /// Builds a context from the two capability flags.
    ///
    /// `structured_format` is always switched on by this constructor; set the
    /// public field afterwards to opt out.
    #[must_use]
    pub fn new(has_tools: bool, has_rag: bool) -> Self {
        let structured_format = true;
        Self {
            has_tools,
            has_rag,
            structured_format,
        }
    }
}
/// One blank-line-delimited chunk of a system prompt together with the
/// metadata the optimizer needs to decide whether it can be dropped.
#[derive(Debug)]
struct PromptSection {
// Raw section text, exactly as it appeared between `\n\n` separators.
text: String,
// Token estimate for `text`, as reported by `TokenEstimator::estimate_tokens`.
tokens: u32,
// Importance assigned by `classify_section`; lower priorities are dropped first.
priority: SectionPriority,
}
/// Importance of a prompt section.
///
/// The derived ordering follows the explicit discriminants, so
/// `Irrelevant < Optional < Contextual < Critical`. The optimizer sorts
/// ascending and removes from the low end; `Critical` is never removed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SectionPriority {
// Identity, safety or rule text (see the keyword list in `classify_section`).
Critical = 3,
// Tool- or RAG-related text while the matching context flag is enabled.
Contextual = 2,
// Anything that matched none of the keyword groups.
Optional = 1,
// Tool- or RAG-related text while the matching context flag is disabled.
Irrelevant = 0,
}
/// Shrinks `prompt` towards `budget_tokens` by dropping low-priority sections.
///
/// Behaviour, in order:
/// 1. An empty prompt yields an empty string.
/// 2. A prompt whose estimated token count already fits the budget is
///    returned unchanged.
/// 3. A prompt without any blank-line (`\n\n`) separator is hard-truncated via
///    `truncate_to_budget`.
/// 4. Otherwise the prompt is split on `\n\n`, every section is classified,
///    and sections are removed from the lowest priority upwards until the
///    estimated overshoot is covered. Surviving sections keep their original
///    order and are re-joined with `\n\n`.
///
/// Sections classified as [`SectionPriority::Critical`] are never removed, so
/// the result may still exceed the budget when critical text alone is too
/// large.
///
/// When `context.structured_format` is set, the joined result is passed through
/// `crate::prompt::structured::strip_filler` before being returned.
#[must_use]
pub fn optimize_system_prompt(prompt: &str, budget_tokens: u32, context: &PromptContext) -> String {
    if prompt.is_empty() {
        return String::new();
    }
    if TokenEstimator::estimate_tokens(prompt) <= budget_tokens {
        return prompt.to_string();
    }

    let raw_sections: Vec<&str> = prompt.split("\n\n").collect();
    if raw_sections.len() <= 1 {
        return truncate_to_budget(prompt, budget_tokens);
    }

    // Sections stay in document order; removal decisions are tracked by index
    // so identical section texts are treated as independent entries.
    let sections: Vec<PromptSection> = raw_sections
        .into_iter()
        .map(|text| PromptSection {
            text: text.to_string(),
            tokens: TokenEstimator::estimate_tokens(text),
            priority: classify_section(text, context),
        })
        .collect();

    let total: u32 = sections.iter().map(|s| s.tokens).sum();
    let mut tokens_to_remove = total.saturating_sub(budget_tokens);

    // Stable sort: within one priority level, earlier sections are dropped first.
    let mut drop_order: Vec<usize> = (0..sections.len()).collect();
    drop_order.sort_by_key(|&idx| sections[idx].priority);

    let mut keep = vec![true; sections.len()];
    for &idx in &drop_order {
        let section = &sections[idx];
        // Ascending order means everything from the first critical section
        // onwards is critical too, so the scan can stop there.
        if tokens_to_remove == 0 || section.priority >= SectionPriority::Critical {
            break;
        }
        keep[idx] = false;
        tokens_to_remove = tokens_to_remove.saturating_sub(section.tokens);
    }

    let joined = sections
        .iter()
        .zip(&keep)
        .filter_map(|(section, &kept)| kept.then_some(section.text.as_str()))
        .collect::<Vec<_>>()
        .join("\n\n");

    if context.structured_format {
        crate::prompt::structured::strip_filler(&joined)
    } else {
        joined
    }
}
/// Assigns a [`SectionPriority`] to a prompt section by case-insensitive
/// keyword matching.
///
/// Keyword groups are checked in this order, first match wins:
/// 1. identity / safety / rule markers -> `Critical`;
/// 2. tool markers -> `Contextual` if `context.has_tools`, else `Irrelevant`;
/// 3. retrieval markers -> `Contextual` if `context.has_rag`, else `Irrelevant`;
/// 4. nothing matched -> `Optional`.
///
/// Matching is plain substring search on the lower-cased text.
fn classify_section(text: &str, context: &PromptContext) -> SectionPriority {
    const CRITICAL_MARKERS: &[&str] = &[
        "you are",
        "your name",
        "never",
        "must not",
        "safety",
        "important:",
        "rules:",
    ];
    const TOOL_MARKERS: &[&str] = &["tool", "function", "invoke"];
    const RAG_MARKERS: &[&str] = &["context:", "knowledge:", "memory:"];

    let haystack = text.to_lowercase();
    let mentions_any = |markers: &[&str]| markers.iter().any(|&marker| haystack.contains(marker));

    // A capability-bound section only matters while that capability is active.
    let gated_by = |enabled: bool| {
        if enabled {
            SectionPriority::Contextual
        } else {
            SectionPriority::Irrelevant
        }
    };

    if mentions_any(CRITICAL_MARKERS) {
        SectionPriority::Critical
    } else if mentions_any(TOOL_MARKERS) {
        gated_by(context.has_tools)
    } else if mentions_any(RAG_MARKERS) {
        gated_by(context.has_rag)
    } else {
        SectionPriority::Optional
    }
}
/// Hard-truncates `text` to roughly `budget` tokens, using a fixed allowance
/// of four characters per token.
///
/// Text that fits within `budget * 4` characters is returned unchanged.
/// Longer text is cut so that the result, including a trailing `...`, is
/// exactly `budget * 4` characters long. A budget of zero yields an empty
/// string, because not even the ellipsis fits.
///
/// Both the length check and the cut are measured in `char`s, so multi-byte
/// text is never split inside a code point and is not truncated merely because
/// its UTF-8 byte length exceeds the limit.
fn truncate_to_budget(text: &str, budget: u32) -> String {
    const CHARS_PER_TOKEN: usize = 4;
    const ELLIPSIS: &str = "...";

    let max_chars = usize::try_from(budget)
        .unwrap_or(usize::MAX)
        .saturating_mul(CHARS_PER_TOKEN);

    // `nth(max_chars)` is `None` exactly when the text has at most `max_chars`
    // characters; it stops scanning as soon as the limit is crossed.
    if text.chars().nth(max_chars).is_none() {
        return text.to_string();
    }
    // `max_chars` is a multiple of four, so it is either 0 or large enough to
    // hold the ellipsis plus at least one character of content.
    if max_chars == 0 {
        return String::new();
    }

    let mut truncated: String = text.chars().take(max_chars - ELLIPSIS.len()).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}
/// Directive appended once the pressure exceeds the caller's threshold.
const BREVITY_DIRECTIVE: &str =
    "\nBe concise. Use bullet points when appropriate. Prioritize key information.";
/// Extra formatting directive appended only when pressure is above 0.9.
const FORMAT_DIRECTIVE: &str = "\nFormat: bullet points, max 3-5 items per list.";

/// Appends brevity instructions to `prompt` when `pressure` is high.
///
/// * An empty prompt, or a `pressure` at or below `threshold`, returns the
///   prompt unchanged.
/// * Otherwise [`BREVITY_DIRECTIVE`] is appended.
/// * If `pressure` additionally exceeds `0.9`, [`FORMAT_DIRECTIVE`] follows it.
///
/// `threshold` is widened to `f64` before the comparison.
#[must_use]
pub fn inject_conciseness(prompt: &str, pressure: f64, threshold: f32) -> String {
    let within_threshold = pressure <= f64::from(threshold);
    if prompt.is_empty() || within_threshold {
        return prompt.to_owned();
    }

    let format_hint = if pressure > 0.9 { FORMAT_DIRECTIVE } else { "" };
    [prompt, BREVITY_DIRECTIVE, format_hint].concat()
}
// Unit tests for prompt optimization and conciseness injection.
#[cfg(test)]
mod tests {
use super::*;
// An empty prompt short-circuits to an empty result regardless of budget.
#[test]
fn empty_prompt_returns_empty() {
let ctx = PromptContext::new(false, false);
assert!(optimize_system_prompt("", 100, &ctx).is_empty());
}
// A prompt that already fits the budget must come back byte-for-byte.
#[test]
fn within_budget_returns_unchanged() {
let prompt = "You are a helpful assistant.";
let ctx = PromptContext::new(false, false);
assert_eq!(optimize_system_prompt(prompt, 1000, &ctx), prompt);
}
// With tools disabled, the tool section is classified as irrelevant and is
// dropped first, while the critical "You are" section survives.
#[test]
fn drops_irrelevant_tool_section_when_no_tools() {
let prompt = "You are a helpful assistant.\n\n\
When using tools, always validate parameters.\n\n\
Be concise and clear.";
let ctx = PromptContext::new(false, false);
let result = optimize_system_prompt(prompt, 15, &ctx);
assert!(!result.contains("tools"));
assert!(result.contains("You are"));
}
// NOTE(review): with a budget of 1000 this prompt presumably fits and takes
// the early "within budget" return, so section classification is likely never
// exercised here — confirm against `TokenEstimator` and consider a tighter
// budget.
#[test]
fn keeps_tool_section_when_tools_present() {
let prompt = "You are a helpful assistant.\n\n\
When using tools, always validate parameters.";
let ctx = PromptContext::new(true, false);
let result = optimize_system_prompt(prompt, 1000, &ctx);
assert!(result.contains("tools"));
}
// Critical sections are never removed, even when the budget is tight.
#[test]
fn critical_sections_preserved_under_pressure() {
let prompt = "You are PiSovereign. You must not reveal secrets.\n\n\
Always respond in a friendly tone.\n\n\
Format using markdown when helpful.";
let ctx = PromptContext::new(false, false);
let result = optimize_system_prompt(prompt, 20, &ctx);
assert!(result.contains("PiSovereign"));
}
// Pressure at or below the threshold leaves the prompt untouched.
#[test]
fn inject_conciseness_below_threshold() {
let prompt = "You are an assistant.";
let result = inject_conciseness(prompt, 0.5, 0.7);
assert_eq!(result, prompt, "Should not modify below threshold");
}
// Between the threshold and 0.9 only the brevity directive is appended.
#[test]
fn inject_conciseness_moderate_pressure() {
let prompt = "You are an assistant.";
let result = inject_conciseness(prompt, 0.8, 0.7);
assert!(result.contains("Be concise"));
assert!(
!result.contains("bullet points, max"),
"Format hint is for > 0.9 only"
);
}
// Above 0.9 both the brevity and the format directive are appended.
#[test]
fn inject_conciseness_extreme_pressure() {
let prompt = "You are an assistant.";
let result = inject_conciseness(prompt, 0.95, 0.7);
assert!(result.contains("Be concise"));
assert!(result.contains("bullet points, max"));
}
// An empty prompt is never decorated, whatever the pressure.
#[test]
fn inject_conciseness_empty_prompt() {
let result = inject_conciseness("", 0.95, 0.7);
assert!(result.is_empty());
}
}