milli_core/prompt/
mod.rs

1mod context;
2mod document;
3pub(crate) mod error;
4mod fields;
5mod template_checker;
6
7use std::cell::RefCell;
8use std::convert::TryFrom;
9use std::fmt::Debug;
10use std::num::NonZeroUsize;
11
12use bumpalo::Bump;
13use document::ParseableDocument;
14use error::{NewPromptError, RenderPromptError};
15use fields::{BorrowedFields, OwnedFields};
16
17use self::context::Context;
18use self::document::Document;
19use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
20use crate::update::del_add::DelAdd;
21use crate::GlobalFieldsIdsMap;
22
23pub struct Prompt {
24    template: liquid::Template,
25    template_text: String,
26    max_bytes: Option<NonZeroUsize>,
27}
28
29#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
30pub struct PromptData {
31    pub template: String,
32    pub max_bytes: Option<NonZeroUsize>,
33}
34
35impl From<Prompt> for PromptData {
36    fn from(value: Prompt) -> Self {
37        Self { template: value.template_text, max_bytes: value.max_bytes }
38    }
39}
40
41impl TryFrom<PromptData> for Prompt {
42    type Error = NewPromptError;
43
44    fn try_from(value: PromptData) -> Result<Self, Self::Error> {
45        Prompt::new(value.template, value.max_bytes)
46    }
47}
48
49impl Clone for Prompt {
50    fn clone(&self) -> Self {
51        let template_text = self.template_text.clone();
52        Self {
53            template: new_template(&template_text).unwrap(),
54            template_text,
55            max_bytes: self.max_bytes,
56        }
57    }
58}
59
60fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> {
61    liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text)
62}
63
64fn default_template() -> liquid::Template {
65    new_template(default_template_text()).unwrap()
66}
67
/// Source of the built-in template.
///
/// It emits one `name: value` line for every searchable field whose value is not `nil`.
fn default_template_text() -> &'static str {
    concat!(
        "{% for field in fields %}",
        "{% if field.is_searchable and field.value != nil %}",
        "{{ field.name }}: {{ field.value }}\n",
        "{% endif %}",
        "{% endfor %}"
    )
}
75
/// Default cap, in bytes, applied to rendered prompts: 400.
pub fn default_max_bytes() -> NonZeroUsize {
    const DEFAULT_MAX_BYTES: usize = 400;
    NonZeroUsize::new(DEFAULT_MAX_BYTES).unwrap()
}
79
80impl Default for Prompt {
81    fn default() -> Self {
82        Self {
83            template: default_template(),
84            template_text: default_template_text().into(),
85            max_bytes: Some(default_max_bytes()),
86        }
87    }
88}
89
90impl Default for PromptData {
91    fn default() -> Self {
92        Self { template: default_template_text().into(), max_bytes: Some(default_max_bytes()) }
93    }
94}
95
96impl Prompt {
97    pub fn new(template: String, max_bytes: Option<NonZeroUsize>) -> Result<Self, NewPromptError> {
98        let this = Self {
99            template: liquid::ParserBuilder::with_stdlib()
100                .build()
101                .unwrap()
102                .parse(&template)
103                .map_err(NewPromptError::cannot_parse_template)?,
104            template_text: template,
105            max_bytes,
106        };
107
108        // render template with special object that's OK with `doc.*` and `fields.*`
109        this.template
110            .render(&template_checker::TemplateChecker)
111            .map_err(NewPromptError::invalid_fields_in_template)?;
112
113        Ok(this)
114    }
115
116    pub fn render_document<
117        'a,       // lifetime of the borrow of the document
118        'doc: 'a, // lifetime of the allocator, will live for an entire chunk of documents
119    >(
120        &self,
121        external_docid: &str,
122        document: impl crate::update::new::document::Document<'a> + Debug,
123        field_id_map: &RefCell<GlobalFieldsIdsMap>,
124        doc_alloc: &'doc Bump,
125    ) -> Result<&'doc str, RenderPromptError> {
126        let document = ParseableDocument::new(document, doc_alloc);
127        let fields = BorrowedFields::new(&document, field_id_map, doc_alloc);
128        let context = Context::new(&document, &fields);
129        let mut rendered = bumpalo::collections::Vec::with_capacity_in(
130            self.max_bytes.unwrap_or_else(default_max_bytes).get(),
131            doc_alloc,
132        );
133        self.template.render_to(&mut rendered, &context).map_err(|liquid_error| {
134            RenderPromptError::missing_context_with_external_docid(
135                external_docid.to_owned(),
136                liquid_error,
137            )
138        })?;
139        Ok(std::str::from_utf8(rendered.into_bump_slice())
140            .expect("render can only write UTF-8 because all inputs and processing preserve utf-8"))
141    }
142
143    pub fn render_kvdeladd(
144        &self,
145        document: &obkv::KvReaderU16,
146        side: DelAdd,
147        field_id_map: &FieldIdMapWithMetadata,
148    ) -> Result<String, RenderPromptError> {
149        let document = Document::new(document, side, field_id_map.as_fields_ids_map());
150        let fields = OwnedFields::new(&document, field_id_map);
151        let context = Context::new(&document, &fields);
152
153        let mut rendered =
154            self.template.render(&context).map_err(RenderPromptError::missing_context)?;
155        if let Some(max_bytes) = self.max_bytes {
156            truncate(&mut rendered, max_bytes.get());
157        }
158        Ok(rendered)
159    }
160}
161
/// Shortens `s` in place to at most `max_bytes` bytes.
///
/// The cut is placed at the last UTF-8 character boundary at or before `max_bytes`, so a
/// multi-byte character is never split. Strings already within the limit are left untouched.
fn truncate(s: &mut String, max_bytes: usize) {
    if s.len() <= max_bytes {
        return;
    }
    // 0 is always a char boundary, so `find` cannot fail; `unwrap_or(0)` only spells that out.
    let cut = (0..=max_bytes).rev().find(|&idx| s.is_char_boundary(idx)).unwrap_or(0);
    s.truncate(cut);
}
173
#[cfg(test)]
mod test {
    use super::Prompt;
    use crate::error::FaultSource;
    use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
    use crate::prompt::truncate;

    #[test]
    fn default_template() {
        // does not panic
        Prompt::default();
    }

    #[test]
    fn empty_template() {
        Prompt::new("".into(), None).unwrap();
    }

    #[test]
    fn template_ok() {
        Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None).unwrap();
    }

    #[test]
    fn template_syntax() {
        assert!(matches!(
            Prompt::new("{{doc.title: {{doc.overview}}".into(), None),
            Err(NewPromptError {
                kind: NewPromptErrorKind::CannotParseTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_missing_doc() {
        assert!(matches!(
            Prompt::new("{{title}}: {{overview}}".into(), None),
            Err(NewPromptError {
                kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_nested_doc() {
        Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None).unwrap();
    }

    #[test]
    fn template_fields() {
        Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None).unwrap();
    }

    #[test]
    fn template_fields_ok() {
        Prompt::new(
            "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(),
            None,
        )
        .unwrap();
    }

    #[test]
    fn template_fields_invalid() {
        assert!(matches!(
            // intentionally garbled field
            Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None),
            Err(NewPromptError {
                kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_truncation() {
        let mut s = "インテル ザー ビーグル".to_string();

        truncate(&mut s, 42);
        assert_eq!(s, "インテル ザー ビーグル");

        assert_eq!(s.len(), 32);
        truncate(&mut s, 32);
        assert_eq!(s, "インテル ザー ビーグル");

        truncate(&mut s, 31);
        assert_eq!(s, "インテル ザー ビーグ");
        truncate(&mut s, 30);
        assert_eq!(s, "インテル ザー ビーグ");
        truncate(&mut s, 28);
        assert_eq!(s, "インテル ザー ビー");
        truncate(&mut s, 26);
        assert_eq!(s, "インテル ザー ビー");
        truncate(&mut s, 25);
        assert_eq!(s, "インテル ザー ビ");

        assert_eq!("イ".len(), 3);
        truncate(&mut s, 3);
        assert_eq!(s, "イ");
        truncate(&mut s, 2);
        assert_eq!(s, "");
    }

    #[test]
    fn truncation_edge_cases() {
        // An empty string is never changed, even with a zero limit.
        let mut empty = String::new();
        truncate(&mut empty, 0);
        assert_eq!(empty, "");

        // A zero limit always yields an empty string.
        let mut s = "イ".to_string();
        truncate(&mut s, 0);
        assert_eq!(s, "");

        // Mixed ASCII and multi-byte characters: cut back to the last whole character.
        let mut s = "aイb".to_string();
        truncate(&mut s, 4);
        assert_eq!(s, "aイ");
        truncate(&mut s, 3);
        assert_eq!(s, "a");
    }
}