prompt_graph_core/
templates.rs

1use std::collections::HashMap;
2
// Wasm-compatible template handling for prompts. It is implemented here once, rather than
// natively per target, to avoid building an equivalent implementation for every platform.
5use handlebars::{Handlebars, Path, Template};
6use handlebars::template::{Parameter, TemplateElement};
7use serde_json::{Map, Value};
8use serde_json::value::{Map as JsonMap};
9
10
11use anyhow::{Result};
12use crate::proto::serialized_value::Val;
13use crate::proto::{ChangeValue, PromptLibraryRecord, SerializedValue, SerializedValueArray, SerializedValueObject};
14
15
16// https://github.com/microsoft/guidance
17
18// TODO: support accessing a library of prompts injected as partials
19
20// https://github.com/sunng87/handlebars-rust/blob/23ca8d76bee783bf72f627b4c4995d1d11008d17/src/template.rs#L963
21// self.handlebars.register_template_string(name, template).unwrap();
22
23/// Verify that the template and included query paths are valid
24pub fn validate_template(template_str: &str, _query_paths: Vec<Vec<String>>) {
25    // let mut handlebars = Handlebars::new();
26    let template = Template::compile(template_str).unwrap();
27    let mut reference_paths = Vec::new();
28    traverse_ast(&template, &mut reference_paths, vec![]);
29    println!("{:?}", reference_paths);
30    // TODO: check that all query paths are satisfied by this template
31    // handlebars.register_template("test", template).unwrap();
32}
33
34#[derive(Debug, Clone)]
35struct ContextBlock {
36    name: Parameter,
37    params: Vec<Parameter>,
38}
39
40/// Traverse over every partial template in a Template (which can be a set of template partials) and validate that each
41/// partial template can be matched to a either 1) some template type that Handlebars recognizes
42/// or 2) a query path that can pull data out of the event log
43fn traverse_ast(template: &Template, reference_paths: &mut Vec<(Path, Vec<ContextBlock>)>, context: Vec<ContextBlock>) {
44    for el in &template.elements {
45        match el {
46            TemplateElement::RawString(_) => {}
47            TemplateElement::HtmlExpression(helper_block) |
48            TemplateElement::Expression(helper_block) |
49            TemplateElement::HelperBlock(helper_block) => {
50                let deref = *(helper_block.clone());
51                let _params = &deref.params;
52                match &deref.name {
53                    Parameter::Name(_name) => {
54                        // println!("name, {:?} - params {:?}", name, params);
55                        // reference_paths.push((None, context.clone()));
56                    }
57                    Parameter::Path(path) => {
58                        reference_paths.push((path.clone(), context.clone()));
59                    }
60                    Parameter::Literal(_) => {
61                    }
62                    Parameter::Subexpression(_) => {}
63                }
64                if let Some(next_template) = deref.template {
65                    let mut ctx = context.clone();
66                    ctx.extend(vec![ContextBlock {
67                        name: deref.name.clone(),
68                        params: deref.params.clone(),
69                    }]);
70                    traverse_ast(&next_template, reference_paths, ctx);
71                }
72            }
73            TemplateElement::DecoratorExpression(_) => {}
74            TemplateElement::DecoratorBlock(_) => {}
75            TemplateElement::PartialExpression(_) => {}
76            TemplateElement::PartialBlock(_) => {}
77            TemplateElement::Comment(_) => {}
78        }
79    }
80}
81
/// Placeholder for converting a template into a prompt; not implemented yet (currently a no-op).
// Not called yet; allowed so the unfinished stub does not trip `dead_code` warnings.
#[allow(dead_code)]
fn convert_template_to_prompt() {}
85
/// Placeholder for inferring a query from a template; not implemented yet (currently a no-op).
// Not called yet; allowed so the unfinished stub does not trip `dead_code` warnings.
#[allow(dead_code)]
fn infer_query_from_template() {}
89
/// Placeholder for extracting roles from a template; not implemented yet (currently a no-op).
// Not called yet; allowed so the unfinished stub does not trip `dead_code` warnings.
#[allow(dead_code)]
fn extract_roles_from_template() {}
93
94/// Recursively flatten a SerializedValue into a set of key paths and values
95pub fn flatten_value_keys(sval: SerializedValue, current_path: Vec<String>) -> Vec<(Vec<String>, Val)> {
96    let mut flattened = vec![];
97    match sval.val {
98        Some(Val::Object(a)) => {
99            for (key, value) in &a.values {
100                let mut path = current_path.clone();
101                path.push(key.clone());
102                flattened.extend(flatten_value_keys(value.clone(), path));
103            }
104        }
105        None => {},
106        x @ _ => { flattened.push((current_path.clone(), x.unwrap())) }
107    }
108    flattened
109}
110
111// TODO: fix the conversion to numbers
112/// Convert a SerializedValue into a serde_json::Value
113pub fn serialized_value_to_json_value(sval: &SerializedValue) -> Value {
114    match &sval.val {
115        Some(Val::Float(f)) => { Value::Number(f.to_string().parse().unwrap()) }
116        Some(Val::Number(n)) => { Value::Number(n.to_string().parse().unwrap()) }
117        Some(Val::String(s)) => { Value::String(s.to_string()) }
118        Some(Val::Boolean(b)) => { Value::Bool(*b) }
119        Some(Val::Array(a)) => {
120            Value::Array(a.values.iter().map(|v| serialized_value_to_json_value(v)).collect())
121        }
122        Some(Val::Object(a)) => {
123            Value::Object(a.values.iter().map(|(k, v)| (k.clone(), serialized_value_to_json_value(v))).collect())
124        }
125        _ => { Value::Null }
126    }
127}
128
129/// Convert a serde_json::Value into a SerializedValue
130pub fn json_value_to_serialized_value(jval: &Value) -> SerializedValue {
131    SerializedValue {
132        val: match jval {
133            Value::Number(n) => {
134                if n.is_i64() {
135                    Some(Val::Number(n.as_i64().unwrap() as i32))
136                } else if n.is_f64() {
137                    Some(Val::Float(n.as_f64().unwrap() as f32))
138                } else {
139                    panic!("Invalid number value")
140                }
141            }
142            Value::String(s) => Some(Val::String(s.clone())),
143            Value::Bool(b) => Some(Val::Boolean(*b)),
144            Value::Array(a) => {
145                Some(Val::Array(SerializedValueArray{ values: a.iter().map(|v| json_value_to_serialized_value(v)).collect()}))
146            }
147            Value::Object(o) => {
148                let mut map = HashMap::new();
149                for (k, v) in o {
150                    map.insert(k.clone(), json_value_to_serialized_value(v));
151                }
152                Some(Val::Object(SerializedValueObject{ values: map }))
153            }
154            Value::Null => None,
155            _ => panic!("Invalid value type"),
156        },
157    }
158}
159
160
161
162/// Recursively convert a path and value into a JSON map, where the path is split into nested keys that map down to the value
163fn query_path_to_json(path: &[String], val: &SerializedValue) -> Option<Map<String, Value>> {
164    let mut map = JsonMap::new();
165    if let Some((head, tail)) = path.split_first() {
166        if tail.is_empty() {
167            map.insert(head.clone(), serialized_value_to_json_value(val));
168        } else {
169            if let Some(created) = query_path_to_json(tail, val) {
170                map.insert(head.clone(), Value::Object(created));
171            }
172        }
173        Some(map)
174    } else {
175        None
176    }
177}
178
179/// Merge two JSON maps together, where the second map takes precedence over the first
180fn merge(a: &mut Value, b: Value) {
181    if let Value::Object(a) = a {
182        if let Value::Object(b) = b {
183            for (k, v) in b {
184                if v.is_null() {
185                    a.remove(&k);
186                }
187                else {
188                    merge(a.entry(k).or_insert(Value::Null), v);
189                }
190            }
191
192            return;
193        }
194    }
195
196    *a = b;
197}
198
199/// Convert a set of query paths into a JSON map where each path is split into nested keys that map down to their respective value
200fn query_paths_to_json(query_paths: &Vec<ChangeValue>) -> Value {
201    let mut m = Value::Object(JsonMap::new());
202    for change_value in query_paths {
203        let path = &change_value.path.as_ref().unwrap().address;
204        let val = &change_value.value.as_ref().unwrap();
205        if let Some(created) = query_path_to_json(path, val) {
206            merge(&mut m, Value::Object(created));
207        }
208        // Allow using unresolved paths as keys
209        if let Some((last, _)) = path.split_last() {
210            if let Some(created) = query_path_to_json(&vec![last.clone()], val) {
211                merge(&mut m, Value::Object(created));
212            }
213        }
214    }
215    m
216}
217
218// TODO: remove these unwraps
219// TODO: add an argument for passing a set of partials
220// TODO: implement block helpers for User and System prompts
221
222
223/// Render a template string, placing in partials (names that map to prompts in the prompt library) and values from the query paths
224/// as records of changes that are made to the event log
225pub fn render_template_prompt(template_str: &str, query_paths: &Vec<ChangeValue>, partials: &HashMap<String, PromptLibraryRecord>) -> Result<String> {
226    let mut reg = Handlebars::new();
227    for (name, prompt) in partials.iter() {
228        reg.register_partial(name, prompt.record.as_ref().unwrap().template.as_str()).unwrap();
229    }
230    reg.register_template_string("tpl_1", template_str).unwrap();
231    reg.register_escape_fn(handlebars::no_escape);
232    let render = reg.render("tpl_1", &query_paths_to_json(query_paths)).unwrap();
233    Ok(render)
234}
235
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_change_value;
    use crate::proto::UpsertPromptLibraryRecord;
    use serde_json::json;

    /// Build a change record that sets `path` to the string `value`.
    fn string_change(path: &[&str], value: &str) -> ChangeValue {
        create_change_value(
            path.iter().map(|segment| segment.to_string()).collect::<Vec<String>>(),
            Some(Val::String(value.to_string())),
            0,
        )
    }

    /// The `user.name` / `user.last_name` pair that most tests render against.
    fn john_panhuyzen() -> Vec<ChangeValue> {
        vec![
            string_change(&["user", "name"], "John"),
            string_change(&["user", "last_name"], "Panhuyzen"),
        ]
    }

    /// The query paths passed to every validation test.
    fn user_name_query() -> Vec<Vec<String>> {
        vec![vec!["user".to_string(), "name".to_string()]]
    }

    #[test]
    fn test_generating_json_map_from_paths() {
        let only_name = vec![string_change(&["user", "name"], "John")];
        assert_eq!(
            query_paths_to_json(&only_name),
            json!({
                "name": "John",
                "user": { "name": "John" },
            })
        );

        assert_eq!(
            query_paths_to_json(&john_panhuyzen()),
            json!({
                "name": "John",
                "last_name": "Panhuyzen",
                "user": { "name": "John", "last_name": "Panhuyzen" },
            })
        );
    }

    #[test]
    fn test_template_validation() {
        validate_template("Hello, {{name}}! {{user.name}}", user_name_query());
    }

    #[test]
    fn test_template_validation_eval_context() {
        validate_template("{{#with user}} {{name}} {{/with}}", user_name_query());
    }

    #[test]
    fn test_template_validation_eval_context_each() {
        validate_template("{{#each users}} {{name}} {{/each}}", user_name_query());
    }

    #[test]
    fn test_guidance_style_system_prompts() {
        validate_template(
            "\
                {{#system}}
                You are a helpful assistant. {{value}}
                {{/system}}
                {{#user}}
                    test
                {{/user}}
                {{#assistant}}
                    test
                {{/assistant}}
            ",
            user_name_query(),
        );
    }

    #[test]
    fn test_rendering_template() {
        let rendered = render_template_prompt(
            "Basic template {{user.name}}",
            &john_panhuyzen(),
            &HashMap::new(),
        );
        assert_eq!(rendered.unwrap(), "Basic template John");
    }

    #[test]
    fn test_rendering_template_clean_resolve() {
        let rendered = render_template_prompt(
            "Basic template {{name}} {{last_name}}",
            &john_panhuyzen(),
            &HashMap::new(),
        );
        assert_eq!(rendered.unwrap(), "Basic template John Panhuyzen");
    }

    #[test]
    fn test_rendering_template_with_partial() {
        let part = PromptLibraryRecord {
            record: Some(UpsertPromptLibraryRecord {
                template: "[{{user.name}} inside partial]".to_string(),
                name: "part".to_string(),
                id: "0".to_string(),
                description: None,
            }),
            version_counter: 0,
        };
        let mut partials = HashMap::new();
        partials.insert("part".to_string(), part);

        let rendered = render_template_prompt(
            "Basic template {{> part}}",
            &john_panhuyzen(),
            &partials,
        );
        assert_eq!(rendered.unwrap(), "Basic template [John inside partial]");
    }
}