1mod context;
2mod document;
3pub(crate) mod error;
4mod fields;
5mod template_checker;
6
7use std::cell::RefCell;
8use std::convert::TryFrom;
9use std::fmt::Debug;
10use std::num::NonZeroUsize;
11
12use bumpalo::Bump;
13use document::ParseableDocument;
14use error::{NewPromptError, RenderPromptError};
15use fields::{BorrowedFields, OwnedFields};
16
17use self::context::Context;
18use self::document::Document;
19use crate::fields_ids_map::metadata::FieldIdMapWithMetadata;
20use crate::update::del_add::DelAdd;
21use crate::GlobalFieldsIdsMap;
22
23pub struct Prompt {
24 template: liquid::Template,
25 template_text: String,
26 max_bytes: Option<NonZeroUsize>,
27}
28
/// Serializable form of a [`Prompt`]: the template source text and its byte limit.
///
/// Converting back into a `Prompt` (via `TryFrom`) re-parses and re-validates `template`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PromptData {
    /// Liquid source text of the template.
    pub template: String,
    /// Optional upper bound, in bytes, on the rendered prompt.
    pub max_bytes: Option<NonZeroUsize>,
}
34
35impl From<Prompt> for PromptData {
36 fn from(value: Prompt) -> Self {
37 Self { template: value.template_text, max_bytes: value.max_bytes }
38 }
39}
40
41impl TryFrom<PromptData> for Prompt {
42 type Error = NewPromptError;
43
44 fn try_from(value: PromptData) -> Result<Self, Self::Error> {
45 Prompt::new(value.template, value.max_bytes)
46 }
47}
48
49impl Clone for Prompt {
50 fn clone(&self) -> Self {
51 let template_text = self.template_text.clone();
52 Self {
53 template: new_template(&template_text).unwrap(),
54 template_text,
55 max_bytes: self.max_bytes,
56 }
57 }
58}
59
60fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> {
61 liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text)
62}
63
64fn default_template() -> liquid::Template {
65 new_template(default_template_text()).unwrap()
66}
67
/// Source text of the built-in template.
///
/// For every field that is searchable and whose value is not nil, it emits one
/// `name: value` line.
fn default_template_text() -> &'static str {
    concat!(
        "{% for field in fields %}",
        "{% if field.is_searchable and field.value != nil %}",
        "{{ field.name }}: {{ field.value }}\n",
        "{% endif %}",
        "{% endfor %}"
    )
}
75
/// Default upper bound, in bytes, on the size of a rendered prompt (400).
pub fn default_max_bytes() -> NonZeroUsize {
    const DEFAULT_MAX_BYTES: usize = 400;
    NonZeroUsize::new(DEFAULT_MAX_BYTES).unwrap()
}
79
80impl Default for Prompt {
81 fn default() -> Self {
82 Self {
83 template: default_template(),
84 template_text: default_template_text().into(),
85 max_bytes: Some(default_max_bytes()),
86 }
87 }
88}
89
90impl Default for PromptData {
91 fn default() -> Self {
92 Self { template: default_template_text().into(), max_bytes: Some(default_max_bytes()) }
93 }
94}
95
96impl Prompt {
97 pub fn new(template: String, max_bytes: Option<NonZeroUsize>) -> Result<Self, NewPromptError> {
98 let this = Self {
99 template: liquid::ParserBuilder::with_stdlib()
100 .build()
101 .unwrap()
102 .parse(&template)
103 .map_err(NewPromptError::cannot_parse_template)?,
104 template_text: template,
105 max_bytes,
106 };
107
108 this.template
110 .render(&template_checker::TemplateChecker)
111 .map_err(NewPromptError::invalid_fields_in_template)?;
112
113 Ok(this)
114 }
115
116 pub fn render_document<
117 'a, 'doc: 'a, >(
120 &self,
121 external_docid: &str,
122 document: impl crate::update::new::document::Document<'a> + Debug,
123 field_id_map: &RefCell<GlobalFieldsIdsMap>,
124 doc_alloc: &'doc Bump,
125 ) -> Result<&'doc str, RenderPromptError> {
126 let document = ParseableDocument::new(document, doc_alloc);
127 let fields = BorrowedFields::new(&document, field_id_map, doc_alloc);
128 let context = Context::new(&document, &fields);
129 let mut rendered = bumpalo::collections::Vec::with_capacity_in(
130 self.max_bytes.unwrap_or_else(default_max_bytes).get(),
131 doc_alloc,
132 );
133 self.template.render_to(&mut rendered, &context).map_err(|liquid_error| {
134 RenderPromptError::missing_context_with_external_docid(
135 external_docid.to_owned(),
136 liquid_error,
137 )
138 })?;
139 Ok(std::str::from_utf8(rendered.into_bump_slice())
140 .expect("render can only write UTF-8 because all inputs and processing preserve utf-8"))
141 }
142
143 pub fn render_kvdeladd(
144 &self,
145 document: &obkv::KvReaderU16,
146 side: DelAdd,
147 field_id_map: &FieldIdMapWithMetadata,
148 ) -> Result<String, RenderPromptError> {
149 let document = Document::new(document, side, field_id_map.as_fields_ids_map());
150 let fields = OwnedFields::new(&document, field_id_map);
151 let context = Context::new(&document, &fields);
152
153 let mut rendered =
154 self.template.render(&context).map_err(RenderPromptError::missing_context)?;
155 if let Some(max_bytes) = self.max_bytes {
156 truncate(&mut rendered, max_bytes.get());
157 }
158 Ok(rendered)
159 }
160}
161
/// Shortens `s` in place to at most `max_bytes` bytes.
///
/// If `max_bytes` falls inside a multi-byte character, the cut backs off to the start of
/// that character, so the result is always valid UTF-8. Strings already within the limit
/// are left untouched.
fn truncate(s: &mut String, max_bytes: usize) {
    if s.len() <= max_bytes {
        return;
    }
    let cut = (0..=max_bytes).rev().find(|&idx| s.is_char_boundary(idx));
    if let Some(cut) = cut {
        s.truncate(cut);
    }
}
173
#[cfg(test)]
mod test {
    use super::Prompt;
    use crate::error::FaultSource;
    use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
    use crate::prompt::truncate;

    #[test]
    fn default_template() {
        // Building the default prompt must not panic.
        let _prompt = Prompt::default();
    }

    #[test]
    fn empty_template() {
        let result = Prompt::new(String::from(""), None);
        result.unwrap();
    }

    #[test]
    fn template_ok() {
        let source = "{{doc.title}}: {{doc.overview}}";
        Prompt::new(source.into(), None).unwrap();
    }

    #[test]
    fn template_syntax() {
        // A malformed output tag is a parse error, attributed to the user.
        let result = Prompt::new("{{doc.title: {{doc.overview}}".into(), None);
        assert!(matches!(
            result,
            Err(NewPromptError {
                kind: NewPromptErrorKind::CannotParseTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_missing_doc() {
        // Fields must be reached through `doc.`; bare names are rejected by the checker.
        let result = Prompt::new("{{title}}: {{overview}}".into(), None);
        assert!(matches!(
            result,
            Err(NewPromptError {
                kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_nested_doc() {
        let source = "{{doc.actor.firstName}}: {{doc.actor.lastName}}";
        Prompt::new(source.into(), None).unwrap();
    }

    #[test]
    fn template_fields() {
        let source = "{% for field in fields %}{{field}}{% endfor %}";
        Prompt::new(source.into(), None).unwrap();
    }

    #[test]
    fn template_fields_ok() {
        let source = "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}";
        Prompt::new(source.into(), None).unwrap();
    }

    #[test]
    fn template_fields_invalid() {
        // `vaelu` is a typo for `value` and must be reported as an invalid field.
        let result =
            Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None);
        assert!(matches!(
            result,
            Err(NewPromptError {
                kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
                fault: FaultSource::User
            })
        ));
    }

    #[test]
    fn template_truncation() {
        let mut s = "インテル ザー ビーグル".to_string();

        // A limit larger than the string leaves it untouched.
        truncate(&mut s, 42);
        assert_eq!(s, "インテル ザー ビーグル");

        assert_eq!(s.len(), 32);
        // Each step shortens the already-truncated string. A limit that falls inside a
        // multi-byte character backs off to the start of that character.
        for (limit, expected) in [
            (32, "インテル ザー ビーグル"),
            (31, "インテル ザー ビーグ"),
            (30, "インテル ザー ビーグ"),
            (28, "インテル ザー ビー"),
            (26, "インテル ザー ビー"),
            (25, "インテル ザー ビ"),
        ] {
            truncate(&mut s, limit);
            assert_eq!(s, expected);
        }

        assert_eq!("イ".len(), 3);
        for (limit, expected) in [(3, "イ"), (2, "")] {
            truncate(&mut s, limit);
            assert_eq!(s, expected);
        }
    }
}