Skip to main content

lash_sansio/
prompt.rs

1use std::hash::{Hash, Hasher};
2use std::sync::{Arc, Mutex};
3
4use crate::{PromptContribution, PromptTemplate};
5
/// Opaque identity of a prompt input, used only to key in-memory caches.
///
/// The value comes from the standard library's `DefaultHasher`, so it is only
/// meaningful within the current process: never persist it, and never compare
/// it against fingerprints produced by another process.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PromptFingerprint(u64);

impl PromptFingerprint {
    /// Hashes `value` with a freshly constructed `DefaultHasher` and wraps the
    /// resulting digest.
    fn from_hashable(value: impl Hash) -> Self {
        use std::collections::hash_map::DefaultHasher;

        let mut state = DefaultHasher::new();
        value.hash(&mut state);
        let digest = state.finish();
        PromptFingerprint(digest)
    }

    /// Feeds the raw fingerprint into `state`, so several fingerprints can be
    /// combined into one composite hash.
    fn write(self, state: &mut impl Hasher) {
        let Self(raw) = self;
        raw.hash(state);
    }
}
25
26#[derive(Clone, Debug)]
27pub struct PromptContributionSet {
28    contributions: Arc<Vec<PromptContribution>>,
29    fingerprint: PromptFingerprint,
30}
31
32impl PromptContributionSet {
33    pub fn new(contributions: Vec<PromptContribution>) -> Self {
34        let contributions = Arc::new(merge_prompt_contributions(contributions));
35        let fingerprint = fingerprint_contributions(&contributions);
36        Self {
37            contributions,
38            fingerprint,
39        }
40    }
41
42    pub fn empty() -> Self {
43        Self::new(Vec::new())
44    }
45
46    pub fn as_arc(&self) -> Arc<Vec<PromptContribution>> {
47        Arc::clone(&self.contributions)
48    }
49
50    pub fn as_slice(&self) -> &[PromptContribution] {
51        &self.contributions
52    }
53
54    pub fn fingerprint(&self) -> PromptFingerprint {
55        self.fingerprint
56    }
57}
58
59impl Default for PromptContributionSet {
60    fn default() -> Self {
61        Self::empty()
62    }
63}
64
/// Everything `build_prompt` / `build_prompt_cached` need to render a system
/// prompt.
///
/// Each `*_fingerprint` field must describe its sibling value. The cache key
/// built by `build_prompt_cached` hashes these fingerprints (plus
/// `omitted_tool_count` and the contribution set's own fingerprint) rather
/// than the values themselves. A stale fingerprint can therefore make the
/// cache return a prompt that was rendered from different inputs.
#[derive(Clone, Debug)]
pub struct PromptBuildInput {
    /// Template rendered against the `PromptContext` derived from this input.
    pub template: PromptTemplate,
    /// Fingerprint of `template`, e.g. from `prompt_template_fingerprint`.
    pub template_fingerprint: PromptFingerprint,
    /// Copied into `PromptContext::execution_prompt`.
    pub execution_prompt: Arc<str>,
    /// Fingerprint of `execution_prompt`, e.g. from `prompt_text_fingerprint`.
    pub execution_prompt_fingerprint: PromptFingerprint,
    /// Copied into `PromptContext::tool_names`.
    pub tool_names: Arc<Vec<String>>,
    /// Fingerprint of `tool_names`, e.g. from `prompt_tool_names_fingerprint`.
    pub tool_names_fingerprint: PromptFingerprint,
    /// Copied into `PromptContext::omitted_tool_count` and hashed directly
    /// into the cache key. NOTE(review): presumably the number of tools not
    /// listed in `tool_names`; confirm with callers.
    pub omitted_tool_count: usize,
    /// Normalized contributions; the set carries its own precomputed
    /// fingerprint.
    pub contributions: PromptContributionSet,
}
76
77#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
78pub struct PromptContext {
79    #[serde(default)]
80    pub execution_prompt: Arc<str>,
81    pub tool_names: Arc<Vec<String>>,
82    pub omitted_tool_count: usize,
83    pub contributions: Arc<Vec<PromptContribution>>,
84}
85
86impl PromptContext {
87    pub fn has_tool(&self, tool_name: &str) -> bool {
88        self.tool_names.iter().any(|name| name == tool_name)
89    }
90}
91
/// Output of `build_prompt` / `build_prompt_cached`.
///
/// On a cache hit, `system_prompt` is the same `Arc` that was returned by the
/// earlier call that populated the cache.
#[derive(Clone, Debug)]
pub struct PreparedPrompt {
    /// Context assembled from the `PromptBuildInput`.
    pub context: PromptContext,
    /// Rendered system prompt text.
    pub system_prompt: Arc<str>,
}
97
/// Single-entry memo of the most recently rendered system prompt.
///
/// `build_prompt_cached` keys the entry with a hash of the prompt inputs: the
/// template, execution-prompt and tool-name fingerprints, the omitted tool
/// count, and the contribution-set fingerprint. A matching key returns the
/// stored `Arc<str>` without calling `PromptTemplate::render`. Any other key
/// re-renders and replaces the entry. Because only one prompt is kept, this
/// pays off when consecutive calls (e.g. successive turns of a session) pass
/// identical inputs.
#[derive(Default)]
pub struct PromptCache {
    inner: Mutex<Option<(u64, Arc<str>)>>,
}

impl PromptCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(None),
        }
    }

    /// Forgets the stored prompt, so the next cached build renders afresh.
    ///
    /// Best-effort: if the mutex is poisoned, nothing is cleared.
    pub fn clear(&self) {
        let Ok(mut entry) = self.inner.lock() else {
            return;
        };
        *entry = None;
    }
}

impl std::fmt::Debug for PromptCache {
    /// Prints `PromptCache { .. }` without locking or exposing the cached entry.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.debug_struct("PromptCache").finish_non_exhaustive()
    }
}
126
127pub fn build_prompt(input: PromptBuildInput) -> PreparedPrompt {
128    build_prompt_cached(input, None)
129}
130
131pub fn build_prompt_cached(input: PromptBuildInput, cache: Option<&PromptCache>) -> PreparedPrompt {
132    let context = PromptContext {
133        execution_prompt: Arc::clone(&input.execution_prompt),
134        tool_names: Arc::clone(&input.tool_names),
135        omitted_tool_count: input.omitted_tool_count,
136        contributions: input.contributions.as_arc(),
137    };
138    let key = cache.map(|_| hash_prompt_inputs(&input, &context));
139    if let (Some(cache), Some(key)) = (cache, key)
140        && let Some(cached) = cache.inner.lock().ok().and_then(|guard| {
141            guard
142                .as_ref()
143                .filter(|(k, _)| *k == key)
144                .map(|(_, v)| Arc::clone(v))
145        })
146    {
147        return PreparedPrompt {
148            context,
149            system_prompt: cached,
150        };
151    }
152    let system_prompt: Arc<str> = Arc::from(input.template.render(&context));
153    if let (Some(cache), Some(key)) = (cache, key)
154        && let Ok(mut guard) = cache.inner.lock()
155    {
156        *guard = Some((key, Arc::clone(&system_prompt)));
157    }
158    PreparedPrompt {
159        context,
160        system_prompt,
161    }
162}
163
164pub fn prompt_template_fingerprint(template: &PromptTemplate) -> PromptFingerprint {
165    PromptFingerprint::from_hashable(template)
166}
167
168pub fn prompt_text_fingerprint(text: &str) -> PromptFingerprint {
169    PromptFingerprint::from_hashable(text)
170}
171
172pub fn prompt_tool_names_fingerprint(tool_names: &[String]) -> PromptFingerprint {
173    PromptFingerprint::from_hashable(tool_names)
174}
175
176fn hash_prompt_inputs(input: &PromptBuildInput, context: &PromptContext) -> u64 {
177    let mut hasher = std::collections::hash_map::DefaultHasher::new();
178    input.template_fingerprint.write(&mut hasher);
179    input.execution_prompt_fingerprint.write(&mut hasher);
180    input.tool_names_fingerprint.write(&mut hasher);
181    context.omitted_tool_count.hash(&mut hasher);
182    input.contributions.fingerprint().write(&mut hasher);
183    hasher.finish()
184}
185
186fn fingerprint_contributions(contributions: &[PromptContribution]) -> PromptFingerprint {
187    let mut hasher = std::collections::hash_map::DefaultHasher::new();
188    for contribution in contributions {
189        contribution.slot.hash(&mut hasher);
190        contribution.priority.hash(&mut hasher);
191        contribution.title.hash(&mut hasher);
192        contribution.content.hash(&mut hasher);
193    }
194    PromptFingerprint(hasher.finish())
195}
196
197fn merge_prompt_contributions(contributions: Vec<PromptContribution>) -> Vec<PromptContribution> {
198    let mut merged = contributions
199        .into_iter()
200        .filter_map(normalize_contribution)
201        .collect::<Vec<_>>();
202
203    merged.sort_by(|left, right| {
204        slot_order(left.slot)
205            .cmp(&slot_order(right.slot))
206            .then(left.priority.cmp(&right.priority))
207            .then_with(|| left.title.cmp(&right.title))
208            .then_with(|| left.content.cmp(&right.content))
209    });
210
211    // Duplicates are adjacent after the sort, so `dedup_by` on &str
212    // refs drops them without cloning anything.
213    merged.dedup_by(|a, b| {
214        slot_order(a.slot) == slot_order(b.slot)
215            && a.priority == b.priority
216            && a.title.as_deref() == b.title.as_deref()
217            && a.content == b.content
218    });
219    merged
220}
221
222fn normalize_contribution(mut contribution: PromptContribution) -> Option<PromptContribution> {
223    contribution.content = Arc::from(contribution.content.trim());
224    if contribution.content.is_empty() {
225        return None;
226    }
227    contribution.title = contribution
228        .title
229        .as_deref()
230        .map(str::trim)
231        .filter(|title| !title.is_empty())
232        .map(Arc::from);
233    Some(contribution)
234}
235
236fn slot_order(slot: crate::PromptSlot) -> usize {
237    match slot {
238        crate::PromptSlot::Intro => 0,
239        crate::PromptSlot::Execution => 1,
240        crate::PromptSlot::Guidance => 2,
241        crate::PromptSlot::ProjectInstructions => 3,
242        crate::PromptSlot::RuntimeContext => 4,
243        crate::PromptSlot::Environment => 5,
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        PromptBuiltin, PromptContribution, PromptLayer, PromptSlot, PromptTemplate,
        PromptTemplateEntry, PromptTemplateSection, default_prompt_template, resolve_prompt_layers,
    };

    /// Builds a `PromptBuildInput` whose fingerprints are derived from the
    /// supplied values, so the cache key always matches the data.
    fn input(
        template: PromptTemplate,
        execution_prompt: &str,
        tool_names: Vec<String>,
        omitted_tool_count: usize,
        contributions: Vec<PromptContribution>,
    ) -> PromptBuildInput {
        let execution_prompt: Arc<str> = Arc::from(execution_prompt);
        let tool_names = Arc::new(tool_names);
        PromptBuildInput {
            template_fingerprint: prompt_template_fingerprint(&template),
            template,
            execution_prompt_fingerprint: prompt_text_fingerprint(&execution_prompt),
            execution_prompt,
            tool_names_fingerprint: prompt_tool_names_fingerprint(&tool_names),
            tool_names,
            omitted_tool_count,
            contributions: PromptContributionSet::new(contributions),
        }
    }

    #[test]
    fn build_prompt_renders_template_from_merged_context() {
        let prepared = build_prompt(input(
            default_prompt_template(),
            "Use tools.",
            vec!["read_file".to_string()],
            0,
            vec![
                PromptContribution::guidance("Repo", "Follow repo rules."),
                PromptContribution::guidance("Repo", "Follow repo rules."),
                PromptContribution::project_instructions("Be careful."),
            ],
        ));

        assert!(prepared.system_prompt.contains("Use tools."));
        assert!(prepared.system_prompt.contains("Follow repo rules."));
        assert!(prepared.system_prompt.contains("Be careful."));
        assert_eq!(prepared.context.contributions.len(), 2);
    }

    #[test]
    fn build_prompt_cached_reuses_arc_on_identical_inputs() {
        let cache = PromptCache::new();
        let inputs = || {
            input(
                default_prompt_template(),
                "Use tools.",
                vec!["read_file".to_string()],
                0,
                vec![PromptContribution::guidance("Repo", "Follow repo rules.")],
            )
        };
        let first = build_prompt_cached(inputs(), Some(&cache));
        let second = build_prompt_cached(inputs(), Some(&cache));
        assert!(Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
    }

    #[test]
    fn build_prompt_cached_renders_again_when_inputs_change() {
        let cache = PromptCache::new();
        let first = build_prompt_cached(
            input(
                default_prompt_template(),
                "Use tools.",
                vec!["read_file".to_string()],
                0,
                vec![],
            ),
            Some(&cache),
        );
        let second = build_prompt_cached(
            input(
                default_prompt_template(),
                "Use other tools.",
                vec!["read_file".to_string()],
                0,
                vec![],
            ),
            Some(&cache),
        );
        assert!(!Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
        assert_ne!(first.system_prompt, second.system_prompt);
    }

    fn template_with_text(text: &str) -> PromptTemplate {
        PromptTemplate::new(vec![PromptTemplateSection::untitled(vec![
            PromptTemplateEntry::text(text),
            PromptTemplateEntry::builtin(PromptBuiltin::ExecutionInstructions),
        ])])
    }

    fn content(contributions: &[PromptContribution]) -> Vec<&str> {
        contributions
            .iter()
            .map(|contribution| contribution.content.as_ref())
            .collect()
    }

    #[test]
    fn prompt_layers_use_later_template() {
        let core = PromptLayer::with_template(template_with_text("core"));
        let session = PromptLayer::with_template(template_with_text("session"));
        let resolved = resolve_prompt_layers([&core, &session]);

        let rendered = resolved.template.render(&PromptContext {
            execution_prompt: Arc::from("execute"),
            ..PromptContext::default()
        });
        assert!(rendered.contains("session"));
        assert!(!rendered.contains("core"));
    }

    #[test]
    fn prompt_layers_append_inherited_slot_content() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Core", "core"));
        let session = PromptLayer::new()
            .with_contribution(PromptContribution::guidance("Session", "session"));

        let resolved = resolve_prompt_layers([&core, &session]);
        assert_eq!(content(&resolved.contributions), vec!["core", "session"]);
    }

    #[test]
    fn prompt_layers_clear_one_slot_without_touching_others() {
        let core = PromptLayer::new()
            .with_contribution(PromptContribution::guidance("Guide", "guide"))
            .with_contribution(PromptContribution::project_instructions("project"));
        let session = PromptLayer::new().with_cleared_slot(PromptSlot::Guidance);

        let resolved = resolve_prompt_layers([&core, &session]);
        assert_eq!(content(&resolved.contributions), vec!["project"]);
    }

    #[test]
    fn prompt_layers_replace_slot_and_normalize_contribution_slot() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Guide", "old"));
        let session = PromptLayer::new().with_replaced_slot(
            PromptSlot::Guidance,
            [PromptContribution::project_instructions("new")],
        );

        let resolved = resolve_prompt_layers([&core, &session]);
        assert_eq!(content(&resolved.contributions), vec!["new"]);
        assert_eq!(resolved.contributions[0].slot, PromptSlot::Guidance);
    }

    #[test]
    fn prompt_layers_allow_later_append_after_replace() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Guide", "old"));
        let session = PromptLayer::new().with_replaced_slot(
            PromptSlot::Guidance,
            [PromptContribution::guidance("New", "new")],
        );
        let turn =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Turn", "turn"));

        let resolved = resolve_prompt_layers([&core, &session, &turn]);
        assert_eq!(content(&resolved.contributions), vec!["new", "turn"]);
    }

    /// Normalization trims content and titles, drops blank content, and turns
    /// blank titles into `None`.
    #[test]
    fn contribution_set_trims_and_drops_blank_entries() {
        let set = PromptContributionSet::new(vec![
            PromptContribution::guidance("  Repo  ", "  Follow repo rules.  "),
            PromptContribution::guidance("Blank", "   "),
            PromptContribution::guidance("   ", "untitled"),
        ]);

        let mut bodies = content(set.as_slice());
        bodies.sort_unstable();
        assert_eq!(bodies, vec!["Follow repo rules.", "untitled"]);

        let titled = set
            .as_slice()
            .iter()
            .find(|contribution| &*contribution.content == "Follow repo rules.")
            .expect("trimmed contribution should survive normalization");
        assert_eq!(titled.title.as_deref(), Some("Repo"));

        let untitled = set
            .as_slice()
            .iter()
            .find(|contribution| &*contribution.content == "untitled")
            .expect("contribution with a blank title should survive");
        assert_eq!(untitled.title.as_deref(), None);
    }

    /// Merging sorts and de-duplicates, so input order, duplicates, and
    /// surrounding whitespace do not change the fingerprint.
    #[test]
    fn contribution_set_fingerprint_ignores_order_and_duplicates() {
        let a = PromptContributionSet::new(vec![
            PromptContribution::guidance("Repo", "Follow repo rules."),
            PromptContribution::project_instructions("Be careful."),
        ]);
        let b = PromptContributionSet::new(vec![
            PromptContribution::project_instructions("  Be careful.  "),
            PromptContribution::guidance("Repo", "Follow repo rules."),
            PromptContribution::guidance("Repo", "Follow repo rules."),
        ]);
        assert_eq!(a.fingerprint(), b.fingerprint());
        assert_ne!(a.fingerprint(), PromptContributionSet::empty().fingerprint());
    }

    #[test]
    fn build_prompt_cached_renders_again_when_omitted_tool_count_changes() {
        let cache = PromptCache::new();
        let build = |omitted| {
            build_prompt_cached(
                input(
                    default_prompt_template(),
                    "Use tools.",
                    vec!["read_file".to_string()],
                    omitted,
                    vec![],
                ),
                Some(&cache),
            )
        };
        let first = build(0);
        let second = build(3);
        assert!(!Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
    }

    #[test]
    fn prompt_cache_clear_forces_rerender() {
        let cache = PromptCache::new();
        let inputs = || {
            input(
                default_prompt_template(),
                "Use tools.",
                vec!["read_file".to_string()],
                0,
                vec![],
            )
        };
        let first = build_prompt_cached(inputs(), Some(&cache));
        cache.clear();
        let second = build_prompt_cached(inputs(), Some(&cache));
        assert!(!Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
        assert_eq!(first.system_prompt, second.system_prompt);
    }

    #[test]
    fn prompt_context_has_tool_matches_exact_names() {
        let context = PromptContext {
            tool_names: Arc::new(vec!["read_file".to_string(), "write_file".to_string()]),
            ..PromptContext::default()
        };
        assert!(context.has_tool("read_file"));
        assert!(context.has_tool("write_file"));
        assert!(!context.has_tool("read"));
        assert!(!PromptContext::default().has_tool("read_file"));
    }
}
414
415        let resolved = resolve_prompt_layers([&core, &session, &turn]);
416        assert_eq!(content(&resolved.contributions), vec!["new", "turn"]);
417    }
418}