1use lazy_static::lazy_static;
2use ragit_fs::{extension, join, parent, read_bytes, read_string};
3use regex::bytes::Regex;
4use serde::Serialize;
5use serde_json::Value;
6
7mod error;
8mod image;
9mod message;
10mod role;
11mod schema;
12mod util;
13
14pub use error::{Error, JsonType};
15pub use image::ImageType;
16pub use message::{Message, MessageContent};
17pub use role::{PdlRole, Role};
18pub use schema::{Schema, SchemaParseError, parse_schema, render_pdl_schema};
19pub use util::{decode_base64, encode_base64};
20
21lazy_static! {
22 static ref MEDIA_RE: Regex = Regex::new(r"^media\((.+)\)$").unwrap();
23 static ref RAW_MEDIA_RE: Regex = Regex::new(r"^raw_media\(([a-zA-Z0-9]+):([^:]+)\)$").unwrap();
24}
25
26pub fn into_context<T: Serialize>(v: &T) -> Result<tera::Context, Error> {
30 let v = serde_json::to_value(v)?;
31 let mut result = tera::Context::new();
32
33 match v {
34 Value::Object(object) => {
35 for (k, v) in object.iter() {
36 result.insert(k, v);
37 }
38
39 Ok(result)
40 },
41 _ => Err(Error::JsonTypeError {
42 expected: JsonType::Object,
43 got: (&v).into(),
44 }),
45 }
46}
47
48#[derive(Clone, Debug)]
49pub struct Pdl {
50 pub schema: Option<Schema>,
51 pub messages: Vec<Message>,
52}
53
54impl Pdl {
55 pub fn validate(&self) -> Result<(), Error> {
56 if self.messages.is_empty() {
57 return Err(Error::InvalidPdl(String::from("A pdl file is empty.")));
58 }
59
60 let mut after_user = false;
61 let mut after_assistant = false;
62
63 for (index, Message { role, .. }) in self.messages.iter().enumerate() {
64 match role {
65 Role::User => {
66 if after_user {
67 return Err(Error::InvalidPdl(String::from("<|user|> appeared twice in a row.")));
68 }
69
70 after_user = true;
71 after_assistant = false;
72 },
73 Role::Assistant => {
74 if after_assistant {
75 return Err(Error::InvalidPdl(String::from("<|assistant|> appeared twice in a row.")));
76 }
77
78 after_user = false;
79 after_assistant = true;
80 },
81 Role::System => {
82 if index != 0 {
83 return Err(Error::InvalidPdl(String::from("<|system|> must appear at top.")));
84 }
85 },
86 Role::Reasoning => {}, }
88 }
89
90 match self.messages.last() {
91 Some(Message { role: Role::Assistant, .. }) => {
92 return Err(Error::InvalidPdl(String::from("A pdl file ends with <|assistant|>.")));
93 },
94 _ => {},
95 }
96
97 Ok(())
98 }
99}
100
101pub fn parse_pdl_from_file(
102 path: &str,
103 context: &tera::Context,
104
105 strict_mode: bool,
107) -> Result<Pdl, Error> {
108 parse_pdl(
109 &read_string(path)?,
110 context,
111 &parent(path)?,
112 strict_mode,
113 )
114}
115
116pub fn parse_pdl(
117 s: &str,
118 context: &tera::Context,
119 curr_dir: &str,
120
121 strict_mode: bool,
123) -> Result<Pdl, Error> {
124 let mut renderer = tera::Tera::default();
125 renderer.autoescape_on(vec!["__tera_one_off"]);
126 renderer.set_escape_fn(escape_pdl_tokens);
127
128 let tera_rendered = match renderer.render_str(s, context) {
129 Ok(t) => t,
130 Err(e) => if strict_mode {
131 return Err(e.into());
132 } else {
133 s.to_string()
134 },
135 };
136
137 let mut messages = vec![];
138 let mut schema = None;
139 let mut curr_role = None;
140 let mut line_buffer = vec![];
141
142 let last_line = "<|assistant|>";
146
147 for line in tera_rendered.lines().chain(std::iter::once(last_line)) {
148 let trimmed = line.trim();
149
150 if trimmed.starts_with("<|") && trimmed.ends_with("|>") && trimmed.len() > 4 {
152 match trimmed.to_ascii_lowercase().get(2..(trimmed.len() - 2)).unwrap() {
153 t @ ("user" | "system" | "assistant" | "schema" | "reasoning") => {
154 if !line_buffer.is_empty() || curr_role.is_some() {
155 match curr_role {
156 Some(PdlRole::Schema) => match parse_schema(&line_buffer.join("\n")) {
157 Ok(s) => {
158 if schema.is_some() && strict_mode {
159 return Err(Error::InvalidPdl(String::from("<|schema|> appeared multiple times.")));
160 }
161
162 schema = Some(s);
163 },
164 Err(e) => {
165 if strict_mode {
166 return Err(e.into());
167 }
168 },
169 },
170 Some(PdlRole::Reasoning) => {},
172 _ => {
173 let raw_contents = line_buffer.join("\n");
176 let raw_contents = raw_contents.trim();
177
178 let role = match curr_role {
179 Some(role) => role,
180 None => {
181 if raw_contents.is_empty() {
182 curr_role = Some(PdlRole::from(t));
183 line_buffer = vec![];
184 continue;
185 }
186
187 if strict_mode {
188 return Err(Error::RoleMissing);
189 }
190
191 PdlRole::System
192 },
193 };
194
195 match into_message_contents(&raw_contents, curr_dir) {
196 Ok(t) => {
197 messages.push(Message {
198 role: role.into(),
199 content: t,
200 });
201 },
202 Err(e) => {
203 if strict_mode {
204 return Err(e);
205 }
206
207 else {
208 messages.push(Message {
209 role: role.into(),
210 content: vec![MessageContent::String(raw_contents.to_string())],
211 });
212 }
213 },
214 }
215 },
216 }
217 }
218
219 curr_role = Some(PdlRole::from(t));
220 line_buffer = vec![];
221 continue;
222 },
223 t => {
224 if strict_mode && t.chars().all(|c| c.is_ascii_alphabetic()) {
225 return Err(Error::InvalidTurnSeparator(t.to_string()));
226 }
227
228 line_buffer.push(line.to_string());
229 },
230 }
231 }
232
233 else {
234 line_buffer.push(line.to_string());
235 }
236 }
237
238 if let Some(Message { content, .. }) = messages.last() {
239 if content.is_empty() {
240 messages.pop().unwrap();
241 }
242 }
243
244 let result = Pdl {
245 schema,
246 messages,
247 };
248
249 if strict_mode {
250 result.validate()?;
251 }
252
253 Ok(result)
254}
255
/// Escapes text so that it can never be read as a pdl token (`<|...|>`).
///
/// `&` becomes `&amp;`, `|>` becomes `|&gt;` and `<|` becomes `&lt;|`.
/// `&` is escaped first so that the entities introduced by the later
/// replacements are not escaped a second time.
///
/// [`unescape_pdl_tokens`] is the inverse.
pub fn escape_pdl_tokens(s: &str) -> String {
    s.replace("&", "&amp;")
        .replace("|>", "|&gt;")
        .replace("<|", "&lt;|")
}

/// Reverses [`escape_pdl_tokens`].
///
/// `&lt;` becomes `<`, `&gt;` becomes `>` and `&amp;` becomes `&`.
/// `&amp;` is decoded last so that an escaped literal such as `&amp;lt;`
/// comes back as `&lt;` rather than `<`.
pub fn unescape_pdl_tokens(s: &str) -> String {
    s.replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&amp;", "&")
}
263
264fn into_message_contents(s: &str, curr_dir: &str) -> Result<Vec<MessageContent>, Error> {
265 let bytes = s.as_bytes().iter().map(|b| *b).collect::<Vec<_>>();
266 let mut index = 0;
267 let mut result = vec![];
268 let mut string_buffer = vec![];
269
270 loop {
271 match bytes.get(index) {
272 Some(b'<') => match try_parse_inline_block(&bytes, index, curr_dir) {
273 Ok(Some((image_type, bytes, new_index))) => {
274 if !string_buffer.is_empty() {
275 match String::from_utf8(string_buffer.clone()) {
276 Ok(s) => {
277 result.push(MessageContent::String(unescape_pdl_tokens(&s)));
278 },
279 Err(e) => {
280 return Err(e.into());
281 },
282 }
283 }
284
285 result.push(MessageContent::Image { image_type, bytes });
286 index = new_index;
287 string_buffer = vec![];
288 continue;
289 },
290 Ok(None) => {
291 string_buffer.push(b'<');
292 },
293 Err(e) => {
294 return Err(e);
295 },
296 },
297 Some(b) => {
298 string_buffer.push(*b);
299 },
300 None => {
301 if !string_buffer.is_empty() {
302 match String::from_utf8(string_buffer) {
303 Ok(s) => {
304 result.push(MessageContent::String(unescape_pdl_tokens(&s)));
305 },
306 Err(e) => {
307 return Err(e.into());
308 },
309 }
310 }
311
312 break;
313 },
314 }
315
316 index += 1;
317 }
318
319 Ok(result)
320}
321
322fn try_parse_inline_block(bytes: &[u8], index: usize, curr_dir: &str) -> Result<Option<(ImageType, Vec<u8>, usize)>, Error> {
326 match try_get_pdl_token(bytes, index) {
327 Some((token, new_index)) => {
328 let media_re = &MEDIA_RE;
329 let raw_media_re = &RAW_MEDIA_RE;
330
331 if let Some(cap) = raw_media_re.captures(token) {
332 let image_type = String::from_utf8_lossy(&cap[1]).to_string();
333 let image_bytes = String::from_utf8_lossy(&cap[2]).to_string();
334
335 Ok(Some((ImageType::from_extension(&image_type)?, decode_base64(&image_bytes)?, new_index)))
336 }
337
338 else if let Some(cap) = media_re.captures(token) {
339 let path = &cap[1];
340 let file = join(curr_dir, &String::from_utf8_lossy(path).to_string())?;
341 Ok(Some((ImageType::from_extension(&extension(&file)?.unwrap_or(String::new()))?, read_bytes(&file)?, new_index)))
342 }
343
344 else {
345 Err(Error::InvalidInlineBlock)
346 }
347 },
348
349 None => Ok(None),
351 }
352}
353
/// Extracts the `<|...|>` token that starts at `bytes[index]`, if any.
///
/// Returns the bytes strictly between the opening `<|` and the first `|>`
/// that follows it, together with the index just past that `|>`.
/// Returns `None` when `bytes` does not have `<|` at `index`, or when no
/// closing `|>` follows.
fn try_get_pdl_token(bytes: &[u8], index: usize) -> Option<(&[u8], usize)> {
    let body_start = index + 2;

    if bytes.get(index..body_start)? != b"<|" {
        return None;
    }

    // Offset of the first `|>` relative to the start of the token body.
    let closing_offset = bytes[body_start..]
        .windows(2)
        .position(|pair| pair == b"|>")?;
    let body_end = body_start + closing_offset;

    Some((&bytes[body_start..body_end], body_end + 2))
}
381
#[cfg(test)]
mod tests {
    use crate::{
        ImageType,
        Message,
        MessageContent,
        Pdl,
        Role,
        decode_base64,
        parse_pdl,
        parse_pdl_from_file,
    };
    use ragit_fs::{
        WriteMode,
        join,
        temp_dir,
        write_string,
    };

    /// Writes a two-turn pdl file to the temp directory and checks that strict
    /// parsing yields one system and one user message and no schema.
    #[test]
    fn messages_from_file_test() {
        let scratch_dir = temp_dir().unwrap();
        let pdl_path = join(&scratch_dir, "test_messages.tera").unwrap();
        let source = "
<|system|>

You're a code helper.

<|user|>

Write me a sudoku-solver.


";

        write_string(&pdl_path, source, WriteMode::CreateOrTruncate).unwrap();

        let Pdl { messages: parsed, schema } = parse_pdl_from_file(
            &pdl_path,
            &tera::Context::new(),
            true,
        ).unwrap();

        // Builds a message whose content is a single string.
        let text_message = |role: Role, body: &str| Message {
            role,
            content: vec![MessageContent::String(String::from(body))],
        };
        let expected = vec![
            text_message(Role::System, "You're a code helper."),
            text_message(Role::User, "Write me a sudoku-solver."),
        ];

        assert_eq!(parsed, expected);
        assert_eq!(schema, None);
    }

    /// Checks that a `raw_media` inline block becomes an image whose bytes are
    /// the base64-decoded payload.
    #[test]
    fn media_content_test() {
        let source = "
<|user|>

<|raw_media(png:HiMyNameIsBaehyunsol)|>
";

        let Pdl { messages: parsed, schema } = parse_pdl(
            source,
            &tera::Context::new(),
            ".",
            true,
        ).unwrap();

        let image = MessageContent::Image {
            image_type: ImageType::Png,
            bytes: decode_base64("HiMyNameIsBaehyunsol").unwrap(),
        };
        let expected = vec![
            Message {
                role: Role::User,
                content: vec![image],
            },
        ];

        assert_eq!(parsed, expected);
        assert_eq!(schema, None);
    }
}
486}