Skip to main content

lash_sansio/
prompt.rs

1use std::hash::{Hash, Hasher};
2use std::sync::{Arc, Mutex};
3
4use crate::{PromptContribution, PromptTemplate};
5
/// In-process identity of a prompt input, used as a cache-key component.
///
/// The value is produced by Rust's default hasher on purpose: it is only
/// meant to key in-memory prompt caches that live inside a single process.
/// Never persist it, and never compare it with a fingerprint produced by a
/// different process.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PromptFingerprint(u64);

impl PromptFingerprint {
    /// Hashes `value` with a fresh default hasher and wraps the digest.
    fn from_hashable(value: impl Hash) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut state);
        let digest = state.finish();
        Self(digest)
    }

    /// Feeds the wrapped 64-bit digest into `state`, so several fingerprints
    /// can be folded into one combined hash.
    fn write(self, state: &mut impl Hasher) {
        state.write_u64(self.0);
    }
}
25
/// A shared list of prompt contributions paired with a fingerprint of its
/// contents.
///
/// `PromptContributionSet::new` passes the raw contributions through
/// `merge_prompt_contributions` and computes the fingerprint from the merged
/// list, so both fields always describe the same data. Cloning the set only
/// bumps the `Arc` reference count and copies the fingerprint.
#[derive(Clone, Debug)]
pub struct PromptContributionSet {
    /// Merged contribution list, shared behind an `Arc`.
    contributions: Arc<Vec<PromptContribution>>,
    /// Fingerprint computed from `contributions` at construction time.
    fingerprint: PromptFingerprint,
}
31
32impl PromptContributionSet {
33    pub fn new(contributions: Vec<PromptContribution>) -> Self {
34        let contributions = Arc::new(merge_prompt_contributions(contributions));
35        let fingerprint = fingerprint_contributions(&contributions);
36        Self {
37            contributions,
38            fingerprint,
39        }
40    }
41
42    pub fn empty() -> Self {
43        Self::new(Vec::new())
44    }
45
46    pub fn as_arc(&self) -> Arc<Vec<PromptContribution>> {
47        Arc::clone(&self.contributions)
48    }
49
50    pub fn as_slice(&self) -> &[PromptContribution] {
51        &self.contributions
52    }
53
54    pub fn fingerprint(&self) -> PromptFingerprint {
55        self.fingerprint
56    }
57}
58
59impl Default for PromptContributionSet {
60    fn default() -> Self {
61        Self::empty()
62    }
63}
64
/// Everything `build_prompt` / `build_prompt_cached` need to render a system
/// prompt.
///
/// Each value travels with a `PromptFingerprint`. The cache key used by
/// `build_prompt_cached` is derived from the fingerprint fields only (see
/// `hash_prompt_inputs`), never from the values themselves, so a fingerprint
/// that does not match its value can make the cache return a prompt rendered
/// from different inputs.
#[derive(Clone, Debug)]
pub struct PromptBuildInput {
    /// Template that is rendered into the system prompt.
    pub template: PromptTemplate,
    /// Fingerprint standing in for `template` in the cache key.
    pub template_fingerprint: PromptFingerprint,
    /// Execution prompt text, copied into the `PromptContext`.
    pub execution_prompt: Arc<str>,
    /// Fingerprint standing in for `execution_prompt` in the cache key.
    pub execution_prompt_fingerprint: PromptFingerprint,
    /// Tool names, copied into the `PromptContext`.
    pub tool_names: Arc<Vec<String>>,
    /// Fingerprint standing in for `tool_names` in the cache key.
    pub tool_names_fingerprint: PromptFingerprint,
    /// Merged contributions; the set carries its own fingerprint.
    pub contributions: PromptContributionSet,
}
75
/// The data handed to `PromptTemplate::render` when a system prompt is built.
///
/// All fields are `Arc`-backed, so cloning a context does not copy the
/// underlying text or lists.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct PromptContext {
    /// Execution prompt text. Marked `#[serde(default)]`, so it falls back to
    /// the default (empty) value when absent from deserialized input.
    #[serde(default)]
    pub execution_prompt: Arc<str>,
    /// Names of the tools; queried by `has_tool`.
    pub tool_names: Arc<Vec<String>>,
    /// Prompt contributions available to the template.
    pub contributions: Arc<Vec<PromptContribution>>,
}
83
84impl PromptContext {
85    pub fn has_tool(&self, tool_name: &str) -> bool {
86        self.tool_names.iter().any(|name| name == tool_name)
87    }
88}
89
/// Result of `build_prompt` / `build_prompt_cached`.
#[derive(Clone, Debug)]
pub struct PreparedPrompt {
    /// The context the system prompt was built for.
    pub context: PromptContext,
    /// Rendered system prompt text. When a cache hit occurs this `Arc` is
    /// shared with the cached entry instead of being rendered again.
    pub system_prompt: Arc<str>,
}
95
/// One-entry memo of the most recently rendered system prompt.
///
/// The entry is keyed by a hash of the build inputs. Consecutive turns in a
/// session usually pass identical inputs (template, the
/// turn-driver-preamble-derived `execution_prompt` / `tool_names`, and the
/// context contributions), so a single slot is hit again and again and
/// spares the section-by-section `Vec<String>::join` work done by
/// `PromptTemplate::render`.
#[derive(Default)]
pub struct PromptCache {
    /// `(input hash, rendered prompt)` of the last render, or `None` when
    /// nothing is cached.
    inner: Mutex<Option<(u64, Arc<str>)>>,
}

impl PromptCache {
    /// Creates a cache with nothing stored in it.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(None),
        }
    }

    /// Drops the cached entry, if any.
    ///
    /// A poisoned lock is ignored: the call then leaves the slot untouched.
    pub fn clear(&self) {
        let Ok(mut slot) = self.inner.lock() else {
            return;
        };
        *slot = None;
    }
}

/// Opaque `Debug` output: the cached prompt text is deliberately not printed.
impl std::fmt::Debug for PromptCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PromptCache");
        builder.finish_non_exhaustive()
    }
}
124
/// Renders the system prompt for `input` without consulting any cache.
///
/// Equivalent to calling `build_prompt_cached` with `None` as the cache.
pub fn build_prompt(input: PromptBuildInput) -> PreparedPrompt {
    build_prompt_cached(input, None)
}
128
129pub fn build_prompt_cached(input: PromptBuildInput, cache: Option<&PromptCache>) -> PreparedPrompt {
130    let context = PromptContext {
131        execution_prompt: Arc::clone(&input.execution_prompt),
132        tool_names: Arc::clone(&input.tool_names),
133        contributions: input.contributions.as_arc(),
134    };
135    let key = cache.map(|_| hash_prompt_inputs(&input));
136    if let (Some(cache), Some(key)) = (cache, key)
137        && let Some(cached) = cache.inner.lock().ok().and_then(|guard| {
138            guard
139                .as_ref()
140                .filter(|(k, _)| *k == key)
141                .map(|(_, v)| Arc::clone(v))
142        })
143    {
144        return PreparedPrompt {
145            context,
146            system_prompt: cached,
147        };
148    }
149    let system_prompt: Arc<str> = Arc::from(input.template.render(&context));
150    if let (Some(cache), Some(key)) = (cache, key)
151        && let Ok(mut guard) = cache.inner.lock()
152    {
153        *guard = Some((key, Arc::clone(&system_prompt)));
154    }
155    PreparedPrompt {
156        context,
157        system_prompt,
158    }
159}
160
/// Computes the process-local fingerprint of `template` from its `Hash`
/// implementation.
pub fn prompt_template_fingerprint(template: &PromptTemplate) -> PromptFingerprint {
    PromptFingerprint::from_hashable(template)
}
164
/// Computes the process-local fingerprint of a piece of prompt text.
pub fn prompt_text_fingerprint(text: &str) -> PromptFingerprint {
    PromptFingerprint::from_hashable(text)
}
168
/// Computes the process-local fingerprint of a list of tool names.
///
/// The slice is hashed as given, so the same names in a different order
/// produce a different fingerprint.
pub fn prompt_tool_names_fingerprint(tool_names: &[String]) -> PromptFingerprint {
    PromptFingerprint::from_hashable(tool_names)
}
172
173fn hash_prompt_inputs(input: &PromptBuildInput) -> u64 {
174    let mut hasher = std::collections::hash_map::DefaultHasher::new();
175    input.template_fingerprint.write(&mut hasher);
176    input.execution_prompt_fingerprint.write(&mut hasher);
177    input.tool_names_fingerprint.write(&mut hasher);
178    input.contributions.fingerprint().write(&mut hasher);
179    hasher.finish()
180}
181
182fn fingerprint_contributions(contributions: &[PromptContribution]) -> PromptFingerprint {
183    let mut hasher = std::collections::hash_map::DefaultHasher::new();
184    for contribution in contributions {
185        contribution.slot.hash(&mut hasher);
186        contribution.priority.hash(&mut hasher);
187        contribution.title.hash(&mut hasher);
188        contribution.content.hash(&mut hasher);
189    }
190    PromptFingerprint(hasher.finish())
191}
192
193fn merge_prompt_contributions(contributions: Vec<PromptContribution>) -> Vec<PromptContribution> {
194    let mut merged = contributions
195        .into_iter()
196        .filter_map(normalize_contribution)
197        .collect::<Vec<_>>();
198
199    merged.sort_by(|left, right| {
200        slot_order(left.slot)
201            .cmp(&slot_order(right.slot))
202            .then(left.priority.cmp(&right.priority))
203            .then_with(|| left.title.cmp(&right.title))
204            .then_with(|| left.content.cmp(&right.content))
205    });
206
207    // Duplicates are adjacent after the sort, so `dedup_by` on &str
208    // refs drops them without cloning anything.
209    merged.dedup_by(|a, b| {
210        slot_order(a.slot) == slot_order(b.slot)
211            && a.priority == b.priority
212            && a.title.as_deref() == b.title.as_deref()
213            && a.content == b.content
214    });
215    merged
216}
217
218fn normalize_contribution(mut contribution: PromptContribution) -> Option<PromptContribution> {
219    contribution.content = Arc::from(contribution.content.trim());
220    if contribution.content.is_empty() {
221        return None;
222    }
223    contribution.title = contribution
224        .title
225        .as_deref()
226        .map(str::trim)
227        .filter(|title| !title.is_empty())
228        .map(Arc::from);
229    Some(contribution)
230}
231
232fn slot_order(slot: crate::PromptSlot) -> usize {
233    match slot {
234        crate::PromptSlot::Intro => 0,
235        crate::PromptSlot::Execution => 1,
236        crate::PromptSlot::Guidance => 2,
237        crate::PromptSlot::ProjectInstructions => 3,
238        crate::PromptSlot::RuntimeContext => 4,
239        crate::PromptSlot::Environment => 5,
240    }
241}
242
// Unit tests for prompt building, the single-slot prompt cache, and the
// layer-resolution helpers re-exported from the crate root.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        PromptBuiltin, PromptContribution, PromptLayer, PromptSlot, PromptTemplate,
        PromptTemplateEntry, PromptTemplateSection, default_prompt_template, resolve_prompt_layers,
    };

    /// Assembles a `PromptBuildInput`, computing every fingerprint from the
    /// value it accompanies so the two can never disagree in these tests.
    fn input(
        template: PromptTemplate,
        execution_prompt: &str,
        tool_names: Vec<String>,
        contributions: Vec<PromptContribution>,
    ) -> PromptBuildInput {
        let execution_prompt: Arc<str> = Arc::from(execution_prompt);
        let tool_names = Arc::new(tool_names);
        PromptBuildInput {
            template_fingerprint: prompt_template_fingerprint(&template),
            template,
            execution_prompt_fingerprint: prompt_text_fingerprint(&execution_prompt),
            execution_prompt,
            tool_names_fingerprint: prompt_tool_names_fingerprint(&tool_names),
            tool_names,
            contributions: PromptContributionSet::new(contributions),
        }
    }

    // The rendered prompt contains the execution prompt and contribution
    // text, and the duplicated guidance entry is collapsed (3 inputs -> 2).
    #[test]
    fn build_prompt_renders_template_from_merged_context() {
        let prepared = build_prompt(input(
            default_prompt_template(),
            "Use tools.",
            vec!["read_file".to_string()],
            vec![
                PromptContribution::guidance("Repo", "Follow repo rules."),
                PromptContribution::guidance("Repo", "Follow repo rules."),
                PromptContribution::project_instructions("Be careful."),
            ],
        ));

        assert!(prepared.system_prompt.contains("Use tools."));
        assert!(prepared.system_prompt.contains("Follow repo rules."));
        assert!(prepared.system_prompt.contains("Be careful."));
        assert_eq!(prepared.context.contributions.len(), 2);
    }

    // Two builds with equal inputs must share the very same `Arc<str>`,
    // proving the second call was served from the cache.
    #[test]
    fn build_prompt_cached_reuses_arc_on_identical_inputs() {
        let cache = PromptCache::new();
        let inputs = || {
            input(
                default_prompt_template(),
                "Use tools.",
                vec!["read_file".to_string()],
                vec![PromptContribution::guidance("Repo", "Follow repo rules.")],
            )
        };
        let first = build_prompt_cached(inputs(), Some(&cache));
        let second = build_prompt_cached(inputs(), Some(&cache));
        assert!(Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
    }

    // Changing the execution prompt changes the cache key, so the second
    // call renders a new, different prompt.
    #[test]
    fn build_prompt_cached_renders_again_when_inputs_change() {
        let cache = PromptCache::new();
        let first = build_prompt_cached(
            input(
                default_prompt_template(),
                "Use tools.",
                vec!["read_file".to_string()],
                vec![],
            ),
            Some(&cache),
        );
        let second = build_prompt_cached(
            input(
                default_prompt_template(),
                "Use other tools.",
                vec!["read_file".to_string()],
                vec![],
            ),
            Some(&cache),
        );
        assert!(!Arc::ptr_eq(&first.system_prompt, &second.system_prompt));
        assert_ne!(first.system_prompt, second.system_prompt);
    }

    /// One untitled section holding the literal `text` followed by the
    /// built-in execution instructions entry.
    fn template_with_text(text: &str) -> PromptTemplate {
        PromptTemplate::new(vec![PromptTemplateSection::untitled(vec![
            PromptTemplateEntry::text(text),
            PromptTemplateEntry::builtin(PromptBuiltin::ExecutionInstructions),
        ])])
    }

    /// Projects contributions down to their content strings, preserving
    /// order, for compact assertions.
    fn content(contributions: &[PromptContribution]) -> Vec<&str> {
        contributions
            .iter()
            .map(|contribution| contribution.content.as_ref())
            .collect()
    }

    // When several layers carry a template, the later layer's template wins.
    #[test]
    fn prompt_layers_use_later_template() {
        let core = PromptLayer::with_template(template_with_text("core"));
        let session = PromptLayer::with_template(template_with_text("session"));
        let resolved = resolve_prompt_layers([&core, &session]);

        let rendered = resolved.template.render(&PromptContext {
            execution_prompt: Arc::from("execute"),
            ..PromptContext::default()
        });
        assert!(rendered.contains("session"));
        assert!(!rendered.contains("core"));
    }

    // Contributions from later layers are appended after inherited ones.
    #[test]
    fn prompt_layers_append_inherited_slot_content() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Core", "core"));
        let session = PromptLayer::new()
            .with_contribution(PromptContribution::guidance("Session", "session"));

        let resolved = resolve_prompt_layers([&core, &session]);
        assert_eq!(content(&resolved.contributions), vec!["core", "session"]);
    }

    // Clearing the guidance slot removes only guidance; the project
    // instructions from the earlier layer survive.
    #[test]
    fn prompt_layers_clear_one_slot_without_touching_others() {
        let core = PromptLayer::new()
            .with_contribution(PromptContribution::guidance("Guide", "guide"))
            .with_contribution(PromptContribution::project_instructions("project"));
        let session = PromptLayer::new().with_cleared_slot(PromptSlot::Guidance);

        let resolved = resolve_prompt_layers([&core, &session]);
        assert_eq!(content(&resolved.contributions), vec!["project"]);
    }

    // Replacing a slot discards inherited content, and a replacement built
    // for another slot is re-tagged with the slot being replaced.
    #[test]
    fn prompt_layers_replace_slot_and_normalize_contribution_slot() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Guide", "old"));
        let session = PromptLayer::new().with_replaced_slot(
            PromptSlot::Guidance,
            [PromptContribution::project_instructions("new")],
        );

        let resolved = resolve_prompt_layers([&core, &session]);
        assert_eq!(content(&resolved.contributions), vec!["new"]);
        assert_eq!(resolved.contributions[0].slot, PromptSlot::Guidance);
    }

    // A layer after the replacing layer can still append to that slot.
    #[test]
    fn prompt_layers_allow_later_append_after_replace() {
        let core =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Guide", "old"));
        let session = PromptLayer::new().with_replaced_slot(
            PromptSlot::Guidance,
            [PromptContribution::guidance("New", "new")],
        );
        let turn =
            PromptLayer::new().with_contribution(PromptContribution::guidance("Turn", "turn"));

        let resolved = resolve_prompt_layers([&core, &session, &turn]);
        assert_eq!(content(&resolved.contributions), vec!["new", "turn"]);
    }
}