use std::collections::HashMap;
/// Zero-argument producer for a section's text.
///
/// Returning `None` is a legitimate result and is cached like any other value by
/// `resolve_system_prompt_sections`. The closure must be `Send + Sync`.
pub type ComputeFn = Box<dyn Fn() -> Option<String> + Send + Sync>;

/// A named piece of the system prompt whose text is produced on demand.
pub struct SystemPromptSection {
    /// Key under which the computed value is stored in the section cache.
    pub name: String,
    /// Producer invoked when the value is not taken from the cache.
    pub compute: ComputeFn,
    /// When `true`, any cached value for `name` is ignored and `compute` always runs.
    pub cache_break: bool,
}

// `#[derive(Debug)]` is not possible because `ComputeFn` is a boxed closure, so the
// closure is elided from the output instead.
impl std::fmt::Debug for SystemPromptSection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SystemPromptSection")
            .field("name", &self.name)
            .field("cache_break", &self.cache_break)
            .finish_non_exhaustive()
    }
}
/// Creates a cacheable section.
///
/// The section has `cache_break` set to `false`. When resolved, a value already stored
/// under `name` in the cache is returned instead of calling `compute`.
pub fn system_prompt_section(name: &str, compute: ComputeFn) -> SystemPromptSection {
    let cache_break = false;
    SystemPromptSection {
        cache_break,
        compute,
        name: String::from(name),
    }
}
/// Creates a section that always recomputes its value.
///
/// The section has `cache_break` set to `true`. When resolved, any cached value stored
/// under `name` is ignored and `compute` is invoked.
///
/// `_reason` is accepted but neither stored nor inspected. Presumably it exists so call
/// sites record why bypassing the cache is acceptable — confirm with callers.
pub fn dangerous_uncached_system_prompt_section(
    name: &str,
    compute: ComputeFn,
    _reason: &str,
) -> SystemPromptSection {
    let owned_name = name.to_owned();
    SystemPromptSection {
        compute,
        cache_break: true,
        name: owned_name,
    }
}
/// Resolves every section to its text, in order, consulting and updating `cache`.
///
/// For each section:
/// - If `cache_break` is `false` and `cache` already holds an entry for its `name`
///   (including a cached `None`), that entry is cloned and `compute` is not called.
/// - Otherwise `compute` is called, and its result is written to `cache` under `name`
///   (also for cache-break sections) and returned.
///
/// Because sections are handled sequentially, a later section sharing a name with an
/// earlier one sees the value the earlier one stored.
pub fn resolve_system_prompt_sections(
    sections: &[SystemPromptSection],
    cache: &mut HashMap<String, Option<String>>,
) -> Vec<Option<String>> {
    let mut resolved = Vec::with_capacity(sections.len());
    for section in sections {
        let hit = if section.cache_break {
            None
        } else {
            cache.get(&section.name).cloned()
        };
        let text = match hit {
            Some(cached) => cached,
            None => {
                let fresh = (section.compute)();
                cache.insert(section.name.clone(), fresh.clone());
                fresh
            }
        };
        resolved.push(text);
    }
    resolved
}
/// Removes every cached section value.
///
/// After this call, the next `resolve_system_prompt_sections` recomputes all sections,
/// because no names remain in the cache.
pub fn clear_system_prompt_sections(cache: &mut HashMap<String, Option<String>>) {
    // The parameter was previously named `_cache`. A leading underscore signals an
    // intentionally unused binding, which is misleading here since it is used.
    cache.clear();
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_system_prompt_section() {
        let compute = Box::new(|| Some("test prompt".to_string()));
        let section = system_prompt_section("test", compute);
        assert_eq!(section.name, "test");
        assert!(!section.cache_break);
    }

    #[test]
    fn test_uncached_section() {
        let compute = Box::new(|| Some("test prompt".to_string()));
        let section =
            dangerous_uncached_system_prompt_section("test", compute, "needs fresh value");
        assert_eq!(section.name, "test");
        assert!(section.cache_break);
    }

    #[test]
    fn test_resolve_with_cache() {
        let compute = Box::new(|| Some("computed value".to_string()));
        let section = system_prompt_section("test", compute);
        let mut cache = HashMap::new();
        let results = resolve_system_prompt_sections(&[section], &mut cache);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], Some("computed value".to_string()));
        // A cache miss must populate the cache with the computed value.
        assert_eq!(cache.get("test"), Some(&Some("computed value".to_string())));
    }

    #[test]
    fn test_cache_hit() {
        let compute = Box::new(|| Some("new value".to_string()));
        let section = system_prompt_section("test", compute);
        let mut cache = HashMap::new();
        cache.insert("test".to_string(), Some("cached value".to_string()));
        let results = resolve_system_prompt_sections(&[section], &mut cache);
        assert_eq!(results[0], Some("cached value".to_string()));
    }

    #[test]
    fn test_cached_none_is_a_hit() {
        // A stored `None` counts as a cached value, not as a miss.
        let section = system_prompt_section("test", Box::new(|| Some("new value".to_string())));
        let mut cache = HashMap::new();
        cache.insert("test".to_string(), None);
        let results = resolve_system_prompt_sections(&[section], &mut cache);
        assert_eq!(results[0], None);
    }

    #[test]
    fn test_cache_break_bypasses_and_overwrites_cache() {
        let section = dangerous_uncached_system_prompt_section(
            "test",
            Box::new(|| Some("fresh value".to_string())),
            "needs fresh value",
        );
        let mut cache = HashMap::new();
        cache.insert("test".to_string(), Some("stale value".to_string()));
        let results = resolve_system_prompt_sections(&[section], &mut cache);
        assert_eq!(results[0], Some("fresh value".to_string()));
        assert_eq!(cache.get("test"), Some(&Some("fresh value".to_string())));
    }

    #[test]
    fn test_hit_skips_compute_and_clear_forces_recompute() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let section = system_prompt_section(
            "test",
            Box::new(move || {
                counter.fetch_add(1, Ordering::SeqCst);
                Some("value".to_string())
            }),
        );
        let sections = [section];
        let mut cache = HashMap::new();

        let first = resolve_system_prompt_sections(&sections, &mut cache);
        let second = resolve_system_prompt_sections(&sections, &mut cache);
        assert_eq!(first, second);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        clear_system_prompt_sections(&mut cache);
        assert!(cache.is_empty());
        let third = resolve_system_prompt_sections(&sections, &mut cache);
        assert_eq!(third, first);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}