Skip to main content

lash_sansio/
prompt.rs

1use std::hash::{Hash, Hasher};
2use std::sync::{Arc, Mutex};
3
4use crate::{ExecutionMode, PromptContext, PromptContribution, PromptTemplate};
5
/// 64-bit digest standing in for one prompt input.
///
/// Values are produced with `DefaultHasher`, so they are only meant to be
/// compared against other fingerprints produced by the same code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PromptFingerprint(u64);

impl PromptFingerprint {
    /// Digests any `Hash` value with a freshly constructed `DefaultHasher`.
    fn from_hashable(value: impl Hash) -> Self {
        let mut digest = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(&value, &mut digest);
        let bits = digest.finish();
        Self(bits)
    }

    /// Mixes the stored digest into `state`, exactly as hashing the inner
    /// `u64` would.
    fn write(self, state: &mut impl Hasher) {
        let Self(bits) = self;
        Hash::hash(&bits, state);
    }
}
20
21#[derive(Clone, Debug)]
22pub struct PromptContributionSet {
23    contributions: Arc<Vec<PromptContribution>>,
24    fingerprint: PromptFingerprint,
25}
26
27impl PromptContributionSet {
28    pub fn new(contributions: Vec<PromptContribution>) -> Self {
29        let contributions = Arc::new(merge_prompt_contributions(contributions));
30        let fingerprint = fingerprint_contributions(&contributions);
31        Self {
32            contributions,
33            fingerprint,
34        }
35    }
36
37    pub fn empty() -> Self {
38        Self::new(Vec::new())
39    }
40
41    pub fn as_arc(&self) -> Arc<Vec<PromptContribution>> {
42        Arc::clone(&self.contributions)
43    }
44
45    pub fn as_slice(&self) -> &[PromptContribution] {
46        &self.contributions
47    }
48
49    pub fn fingerprint(&self) -> PromptFingerprint {
50        self.fingerprint
51    }
52}
53
54impl Default for PromptContributionSet {
55    fn default() -> Self {
56        Self::empty()
57    }
58}
59
60#[derive(Clone, Debug)]
61pub struct PromptBuildInput {
62    pub mode: ExecutionMode,
63    pub template: PromptTemplate,
64    pub template_fingerprint: PromptFingerprint,
65    pub execution_prompt: Arc<str>,
66    pub execution_prompt_fingerprint: PromptFingerprint,
67    pub tool_names: Arc<Vec<String>>,
68    pub tool_names_fingerprint: PromptFingerprint,
69    pub omitted_tool_count: usize,
70    pub contributions: PromptContributionSet,
71}
72
73#[derive(Clone, Debug)]
74pub struct PreparedPrompt {
75    pub context: PromptContext,
76    pub system_prompt: Arc<str>,
77}
78
/// One-entry memo of the most recently rendered system prompt.
///
/// The entry is keyed by a 64-bit hash of the build inputs. When a call
/// arrives with the same key as the stored entry, the stored text is handed
/// back and `PromptTemplate::render` is not run again; any other key
/// overwrites the slot after rendering.
#[derive(Default)]
pub struct PromptCache {
    /// `None` until the first render is stored; afterwards `(key, prompt)`.
    inner: Mutex<Option<(u64, Arc<str>)>>,
}

impl PromptCache {
    /// Creates a cache with nothing stored.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(None),
        }
    }
}

/// Prints only the type name; the stored entry is left out of the output.
impl std::fmt::Debug for PromptCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PromptCache");
        builder.finish_non_exhaustive()
    }
}
101
102pub fn build_prompt(input: PromptBuildInput) -> PreparedPrompt {
103    build_prompt_cached(input, None)
104}
105
106pub fn build_prompt_cached(input: PromptBuildInput, cache: Option<&PromptCache>) -> PreparedPrompt {
107    let context = PromptContext {
108        mode: input.mode.clone(),
109        execution_prompt: Arc::clone(&input.execution_prompt),
110        tool_names: Arc::clone(&input.tool_names),
111        omitted_tool_count: input.omitted_tool_count,
112        contributions: input.contributions.as_arc(),
113    };
114    let key = cache.map(|_| hash_prompt_inputs(&input, &context));
115    if let (Some(cache), Some(key)) = (cache, key)
116        && let Some(cached) = cache.inner.lock().ok().and_then(|guard| {
117            guard
118                .as_ref()
119                .filter(|(k, _)| *k == key)
120                .map(|(_, v)| Arc::clone(v))
121        })
122    {
123        return PreparedPrompt {
124            context,
125            system_prompt: cached,
126        };
127    }
128    let system_prompt: Arc<str> = Arc::from(input.template.render(&context));
129    if let (Some(cache), Some(key)) = (cache, key)
130        && let Ok(mut guard) = cache.inner.lock()
131    {
132        *guard = Some((key, Arc::clone(&system_prompt)));
133    }
134    PreparedPrompt {
135        context,
136        system_prompt,
137    }
138}
139
140pub fn prompt_template_fingerprint(template: &PromptTemplate) -> PromptFingerprint {
141    PromptFingerprint::from_hashable(template)
142}
143
144pub fn prompt_text_fingerprint(text: &str) -> PromptFingerprint {
145    PromptFingerprint::from_hashable(text)
146}
147
148pub fn prompt_tool_names_fingerprint(tool_names: &[String]) -> PromptFingerprint {
149    PromptFingerprint::from_hashable(tool_names)
150}
151
152fn hash_prompt_inputs(input: &PromptBuildInput, context: &PromptContext) -> u64 {
153    let mut hasher = std::collections::hash_map::DefaultHasher::new();
154    input.template_fingerprint.write(&mut hasher);
155    context.mode.hash(&mut hasher);
156    input.execution_prompt_fingerprint.write(&mut hasher);
157    input.tool_names_fingerprint.write(&mut hasher);
158    context.omitted_tool_count.hash(&mut hasher);
159    input.contributions.fingerprint().write(&mut hasher);
160    hasher.finish()
161}
162
163fn fingerprint_contributions(contributions: &[PromptContribution]) -> PromptFingerprint {
164    let mut hasher = std::collections::hash_map::DefaultHasher::new();
165    for contribution in contributions {
166        contribution.slot.hash(&mut hasher);
167        contribution.priority.hash(&mut hasher);
168        contribution.title.hash(&mut hasher);
169        contribution.content.hash(&mut hasher);
170    }
171    PromptFingerprint(hasher.finish())
172}
173
174fn merge_prompt_contributions(contributions: Vec<PromptContribution>) -> Vec<PromptContribution> {
175    let mut merged = contributions
176        .into_iter()
177        .filter_map(normalize_contribution)
178        .collect::<Vec<_>>();
179
180    merged.sort_by(|left, right| {
181        slot_order(left.slot)
182            .cmp(&slot_order(right.slot))
183            .then(left.priority.cmp(&right.priority))
184            .then_with(|| left.title.cmp(&right.title))
185            .then_with(|| left.content.cmp(&right.content))
186    });
187
188    // Duplicates are adjacent after the sort, so `dedup_by` on &str
189    // refs drops them without cloning anything.
190    merged.dedup_by(|a, b| {
191        slot_order(a.slot) == slot_order(b.slot)
192            && a.priority == b.priority
193            && a.title.as_deref() == b.title.as_deref()
194            && a.content == b.content
195    });
196    merged
197}
198
199fn normalize_contribution(mut contribution: PromptContribution) -> Option<PromptContribution> {
200    contribution.content = Arc::from(contribution.content.trim());
201    if contribution.content.is_empty() {
202        return None;
203    }
204    contribution.title = contribution
205        .title
206        .as_deref()
207        .map(str::trim)
208        .filter(|title| !title.is_empty())
209        .map(Arc::from);
210    Some(contribution)
211}
212
213fn slot_order(slot: crate::PromptSlot) -> usize {
214    match slot {
215        crate::PromptSlot::Intro => 0,
216        crate::PromptSlot::Execution => 1,
217        crate::PromptSlot::CliAutonomousIntro => 2,
218        crate::PromptSlot::CliAutonomousExecution => 3,
219        crate::PromptSlot::CliRlmExecution => 4,
220        crate::PromptSlot::Guidance => 5,
221        crate::PromptSlot::ProjectInstructions => 6,
222        crate::PromptSlot::RuntimeContext => 7,
223        crate::PromptSlot::Environment => 8,
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        PromptBuiltin, PromptContribution, PromptLayer, PromptSlot, PromptTemplate,
        PromptTemplateEntry, PromptTemplateSection, default_prompt_template, resolve_prompt_layers,
    };

    /// Assembles a standard-mode `PromptBuildInput`, deriving every
    /// fingerprint from the value it accompanies.
    fn input(
        template: PromptTemplate,
        execution_prompt: &str,
        tool_names: Vec<String>,
        omitted_tool_count: usize,
        contributions: Vec<PromptContribution>,
    ) -> PromptBuildInput {
        let template_fingerprint = prompt_template_fingerprint(&template);
        let execution_prompt_fingerprint = prompt_text_fingerprint(execution_prompt);
        let tool_names_fingerprint = prompt_tool_names_fingerprint(&tool_names);

        PromptBuildInput {
            mode: crate::ExecutionMode::standard(),
            template,
            template_fingerprint,
            execution_prompt: Arc::from(execution_prompt),
            execution_prompt_fingerprint,
            tool_names: Arc::new(tool_names),
            tool_names_fingerprint,
            omitted_tool_count,
            contributions: PromptContributionSet::new(contributions),
        }
    }

    /// Duplicate contributions collapse, and all distinct text is rendered.
    #[test]
    fn build_prompt_renders_template_from_merged_context() {
        let contributions = vec![
            PromptContribution::guidance("Repo", "Follow repo rules."),
            PromptContribution::guidance("Repo", "Follow repo rules."),
            PromptContribution::project_instructions("Be careful."),
        ];
        let prepared = build_prompt(input(
            default_prompt_template(),
            "Use tools.",
            vec!["read_file".to_owned()],
            0,
            contributions,
        ));

        for expected in ["Use tools.", "Follow repo rules.", "Be careful."] {
            assert!(
                prepared.system_prompt.contains(expected),
                "system prompt is missing {expected:?}"
            );
        }
        assert_eq!(prepared.context.contributions.len(), 2);
    }

    /// Two builds from equal inputs share one rendered allocation.
    #[test]
    fn build_prompt_cached_reuses_arc_on_identical_inputs() {
        fn fresh_input() -> PromptBuildInput {
            input(
                default_prompt_template(),
                "Use tools.",
                vec!["read_file".to_owned()],
                0,
                vec![PromptContribution::guidance("Repo", "Follow repo rules.")],
            )
        }

        let cache = PromptCache::new();
        let first = build_prompt_cached(fresh_input(), Some(&cache));
        let second = build_prompt_cached(fresh_input(), Some(&cache));

        assert!(Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
    }

    /// A different execution prompt misses the cache and renders new text.
    #[test]
    fn build_prompt_cached_renders_again_when_inputs_change() {
        let cache = PromptCache::new();
        let render_with = |execution_prompt: &str| {
            build_prompt_cached(
                input(
                    default_prompt_template(),
                    execution_prompt,
                    vec!["read_file".to_owned()],
                    0,
                    Vec::new(),
                ),
                Some(&cache),
            )
        };

        let first = render_with("Use tools.");
        let second = render_with("Use other tools.");

        assert!(!Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
        assert_ne!(first.system_prompt, second.system_prompt);
    }

    /// One untitled section: the given text, then the execution-instructions
    /// builtin.
    fn template_with_text(text: &str) -> PromptTemplate {
        let entries = vec![
            PromptTemplateEntry::text(text),
            PromptTemplateEntry::builtin(PromptBuiltin::ExecutionInstructions),
        ];
        PromptTemplate::new(vec![PromptTemplateSection::untitled(entries)])
    }

    /// Projects contributions down to their content strings, in order.
    fn content(contributions: &[PromptContribution]) -> Vec<&str> {
        let mut texts = Vec::new();
        for contribution in contributions {
            texts.push(&*contribution.content);
        }
        texts
    }

    #[test]
    fn prompt_layers_use_later_template() {
        let core = PromptLayer::with_template(template_with_text("core"));
        let session = PromptLayer::with_template(template_with_text("session"));
        let resolved = resolve_prompt_layers([&core, &session]);

        let context = PromptContext {
            mode: crate::ExecutionMode::standard(),
            execution_prompt: Arc::from("execute"),
            ..PromptContext::default()
        };
        let rendered = resolved.template.render(&context);

        assert!(rendered.contains("session"));
        assert!(!rendered.contains("core"));
    }

    #[test]
    fn prompt_layers_append_inherited_slot_content() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Core", "core"));
        let session = PromptLayer::new()
            .with_contribution(PromptContribution::guidance("Session", "session"));

        let resolved = resolve_prompt_layers([&core, &session]);

        assert_eq!(content(&resolved.contributions), ["core", "session"]);
    }

    #[test]
    fn prompt_layers_clear_one_slot_without_touching_others() {
        let core = PromptLayer::new()
            .with_contribution(PromptContribution::guidance("Guide", "guide"))
            .with_contribution(PromptContribution::project_instructions("project"));
        let session = PromptLayer::new().with_cleared_slot(PromptSlot::Guidance);

        let resolved = resolve_prompt_layers([&core, &session]);

        assert_eq!(content(&resolved.contributions), ["project"]);
    }

    #[test]
    fn prompt_layers_replace_slot_and_normalize_contribution_slot() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Guide", "old"));
        let replacement = [PromptContribution::project_instructions("new")];
        let session = PromptLayer::new().with_replaced_slot(PromptSlot::Guidance, replacement);

        let resolved = resolve_prompt_layers([&core, &session]);

        assert_eq!(content(&resolved.contributions), ["new"]);
        assert_eq!(resolved.contributions[0].slot, PromptSlot::Guidance);
    }

    #[test]
    fn prompt_layers_allow_later_append_after_replace() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Guide", "old"));
        let replacement = [PromptContribution::guidance("New", "new")];
        let session = PromptLayer::new().with_replaced_slot(PromptSlot::Guidance, replacement);
        let turn =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Turn", "turn"));

        let resolved = resolve_prompt_layers([&core, &session, &turn]);

        assert_eq!(content(&resolved.contributions), ["new", "turn"]);
    }
}