ragit_pdl/
lib.rs

1use lazy_static::lazy_static;
2use ragit_fs::{extension, join, parent, read_bytes, read_string};
3use regex::bytes::Regex;
4use serde::Serialize;
5use serde_json::Value;
6
7mod error;
8mod image;
9mod message;
10mod role;
11mod schema;
12mod util;
13
14pub use error::{Error, JsonType};
15pub use image::ImageType;
16pub use message::{Message, MessageContent};
17pub use role::{PdlRole, Role};
18pub use schema::{Schema, SchemaParseError, parse_schema, render_pdl_schema};
19pub use util::{decode_base64, encode_base64};
20
21lazy_static! {
22    static ref MEDIA_RE: Regex = Regex::new(r"^media\((.+)\)$").unwrap();
23    static ref RAW_MEDIA_RE: Regex = Regex::new(r"^raw_media\(([a-zA-Z0-9]+):([^:]+)\)$").unwrap();
24}
25
26/// `parse_pdl` takes `tera::Context` as an input. If you're too lazy to
27/// construct a `tera::Context`, you can use this function. It converts
28/// a rust struct into a json object.
29pub fn into_context<T: Serialize>(v: &T) -> Result<tera::Context, Error> {
30    let v = serde_json::to_value(v)?;
31    let mut result = tera::Context::new();
32
33    match v {
34        Value::Object(object) => {
35            for (k, v) in object.iter() {
36                result.insert(k, v);
37            }
38
39            Ok(result)
40        },
41        _ => Err(Error::JsonTypeError {
42            expected: JsonType::Object,
43            got: (&v).into(),
44        }),
45    }
46}
47
48#[derive(Clone, Debug)]
49pub struct Pdl {
50    pub schema: Option<Schema>,
51    pub messages: Vec<Message>,
52}
53
54impl Pdl {
55    pub fn validate(&self) -> Result<(), Error> {
56        if self.messages.is_empty() {
57            return Err(Error::InvalidPdl(String::from("A pdl file is empty.")));
58        }
59
60        let mut after_user = false;
61        let mut after_assistant = false;
62
63        for (index, Message { role, .. }) in self.messages.iter().enumerate() {
64            match role {
65                Role::User => {
66                    if after_user {
67                        return Err(Error::InvalidPdl(String::from("<|user|> appeared twice in a row.")));
68                    }
69
70                    after_user = true;
71                    after_assistant = false;
72                },
73                Role::Assistant => {
74                    if after_assistant {
75                        return Err(Error::InvalidPdl(String::from("<|assistant|> appeared twice in a row.")));
76                    }
77
78                    after_user = false;
79                    after_assistant = true;
80                },
81                Role::System => {
82                    if index != 0 {
83                        return Err(Error::InvalidPdl(String::from("<|system|> must appear at top.")));
84                    }
85                },
86                Role::Reasoning => {},  // TODO
87            }
88        }
89
90        match self.messages.last() {
91            Some(Message { role: Role::Assistant, .. }) => {
92                return Err(Error::InvalidPdl(String::from("A pdl file ends with <|assistant|>.")));
93            },
94            _ => {},
95        }
96
97        Ok(())
98    }
99}
100
101pub fn parse_pdl_from_file(
102    path: &str,
103    context: &tera::Context,
104
105    // If it's not set, it would never return `Err`.
106    strict_mode: bool,
107) -> Result<Pdl, Error> {
108    parse_pdl(
109        &read_string(path)?,
110        context,
111        &parent(path)?,
112        strict_mode,
113    )
114}
115
116pub fn parse_pdl(
117    s: &str,
118    context: &tera::Context,
119    curr_dir: &str,
120
121    // If it's not set, it would never return `Err`.
122    strict_mode: bool,
123) -> Result<Pdl, Error> {
124    let mut renderer = tera::Tera::default();
125    renderer.autoescape_on(vec!["__tera_one_off"]);
126    renderer.set_escape_fn(escape_pdl_tokens);
127
128    let tera_rendered = match renderer.render_str(s, context) {
129        Ok(t) => t,
130        Err(e) => if strict_mode {
131            return Err(e.into());
132        } else {
133            s.to_string()
134        },
135    };
136
137    let mut messages = vec![];
138    let mut schema = None;
139    let mut curr_role = None;
140    let mut line_buffer = vec![];
141
142    // simple hack: Adding this line to the content makes the code
143    // handle the last turn correctly. Since this fake turn is empty,
144    // it will be removed later.
145    let last_line = "<|assistant|>";
146
147    for line in tera_rendered.lines().chain(std::iter::once(last_line)) {
148        let trimmed = line.trim();
149
150        // maybe a turn-separator
151        if trimmed.starts_with("<|") && trimmed.ends_with("|>") && trimmed.len() > 4 {
152            match trimmed.to_ascii_lowercase().get(2..(trimmed.len() - 2)).unwrap() {
153                t @ ("user" | "system" | "assistant" | "schema" | "reasoning") => {
154                    if !line_buffer.is_empty() || curr_role.is_some() {
155                        match curr_role {
156                            Some(PdlRole::Schema) => match parse_schema(&line_buffer.join("\n")) {
157                                Ok(s) => {
158                                    if schema.is_some() && strict_mode {
159                                        return Err(Error::InvalidPdl(String::from("<|schema|> appeared multiple times.")));
160                                    }
161
162                                    schema = Some(s);
163                                },
164                                Err(e) => {
165                                    if strict_mode {
166                                        return Err(e.into());
167                                    }
168                                },
169                            },
170                            // reasoning tokens are not fed to llm contexts
171                            Some(PdlRole::Reasoning) => {},
172                            _ => {
173                                // there must be lots of unnecessary newlines due to the nature of the format
174                                // let's just trim them away
175                                let raw_contents = line_buffer.join("\n");
176                                let raw_contents = raw_contents.trim();
177
178                                let role = match curr_role {
179                                    Some(role) => role,
180                                    None => {
181                                        if raw_contents.is_empty() {
182                                            curr_role = Some(PdlRole::from(t));
183                                            line_buffer = vec![];
184                                            continue;
185                                        }
186
187                                        if strict_mode {
188                                            return Err(Error::RoleMissing);
189                                        }
190
191                                        PdlRole::System
192                                    },
193                                };
194
195                                match into_message_contents(&raw_contents, curr_dir) {
196                                    Ok(t) => {
197                                        messages.push(Message {
198                                            role: role.into(),
199                                            content: t,
200                                        });
201                                    },
202                                    Err(e) => {
203                                        if strict_mode {
204                                            return Err(e);
205                                        }
206
207                                        else {
208                                            messages.push(Message {
209                                                role: role.into(),
210                                                content: vec![MessageContent::String(raw_contents.to_string())],
211                                            });
212                                        }
213                                    },
214                                }
215                            },
216                        }
217                    }
218
219                    curr_role = Some(PdlRole::from(t));
220                    line_buffer = vec![];
221                    continue;
222                },
223                t => {
224                    if strict_mode && t.chars().all(|c| c.is_ascii_alphabetic()) {
225                        return Err(Error::InvalidTurnSeparator(t.to_string()));
226                    }
227
228                    line_buffer.push(line.to_string());
229                },
230            }
231        }
232
233        else {
234            line_buffer.push(line.to_string());
235        }
236    }
237
238    if let Some(Message { content, .. }) = messages.last() {
239        if content.is_empty() {
240            messages.pop().unwrap();
241        }
242    }
243
244    let result = Pdl {
245        schema,
246        messages,
247    };
248
249    if strict_mode {
250        result.validate()?;
251    }
252
253    Ok(result)
254}
255
/// Escapes text so it cannot form PDL tokens.
///
/// Applies, in this order: `&` -> `&amp;`, `|>` -> `|&gt;`, `<|` -> `&lt;|`.
/// Each rule runs over the output of the previous one; escaping `&` first
/// keeps the later escapes reversible by `unescape_pdl_tokens`.
pub fn escape_pdl_tokens(s: &str) -> String {
    const RULES: [(&str, &str); 3] = [
        ("&", "&amp;"),
        ("|>", "|&gt;"),
        ("<|", "&lt;|"),
    ];

    RULES
        .iter()
        .fold(s.to_string(), |escaped, &(pattern, replacement)| escaped.replace(pattern, replacement))
}
259
/// Reverses `escape_pdl_tokens`-style escaping.
///
/// Applies, in this order: `&lt;` -> `<`, `&gt;` -> `>`, `&amp;` -> `&`.
/// `&amp;` is decoded last so an escaped `&lt;` (i.e. `&amp;lt;`) comes back
/// as the literal `&lt;` rather than `<`.
pub fn unescape_pdl_tokens(s: &str) -> String {  // TODO: use `Cow` type
    const RULES: [(&str, &str); 3] = [
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&amp;", "&"),
    ];

    RULES
        .iter()
        .fold(s.to_string(), |decoded, &(pattern, replacement)| decoded.replace(pattern, replacement))
}
263
264fn into_message_contents(s: &str, curr_dir: &str) -> Result<Vec<MessageContent>, Error> {
265    let bytes = s.as_bytes().iter().map(|b| *b).collect::<Vec<_>>();
266    let mut index = 0;
267    let mut result = vec![];
268    let mut string_buffer = vec![];
269
270    loop {
271        match bytes.get(index) {
272            Some(b'<') => match try_parse_inline_block(&bytes, index, curr_dir) {
273                Ok(Some((image_type, bytes, new_index))) => {
274                    if !string_buffer.is_empty() {
275                        match String::from_utf8(string_buffer.clone()) {
276                            Ok(s) => {
277                                result.push(MessageContent::String(unescape_pdl_tokens(&s)));
278                            },
279                            Err(e) => {
280                                return Err(e.into());
281                            },
282                        }
283                    }
284
285                    result.push(MessageContent::Image { image_type, bytes });
286                    index = new_index;
287                    string_buffer = vec![];
288                    continue;
289                },
290                Ok(None) => {
291                    string_buffer.push(b'<');
292                },
293                Err(e) => {
294                    return Err(e);
295                },
296            },
297            Some(b) => {
298                string_buffer.push(*b);
299            },
300            None => {
301                if !string_buffer.is_empty() {
302                    match String::from_utf8(string_buffer) {
303                        Ok(s) => {
304                            result.push(MessageContent::String(unescape_pdl_tokens(&s)));
305                        },
306                        Err(e) => {
307                            return Err(e.into());
308                        },
309                    }
310                }
311
312                break;
313            },
314        }
315
316        index += 1;
317    }
318
319    Ok(result)
320}
321
322// 1. It returns `Ok(Some(_))` if it's a valid inline block.
323// 2. It returns `Ok(None)` if it's not an inline block.
324// 3. It returns `Err(_)` if it's an inline block, but there's an error (syntax error, image type error, file error, ...).
325fn try_parse_inline_block(bytes: &[u8], index: usize, curr_dir: &str) -> Result<Option<(ImageType, Vec<u8>, usize)>, Error> {
326    match try_get_pdl_token(bytes, index) {
327        Some((token, new_index)) => {
328            let media_re = &MEDIA_RE;
329            let raw_media_re = &RAW_MEDIA_RE;
330
331            if let Some(cap) = raw_media_re.captures(token) {
332                let image_type = String::from_utf8_lossy(&cap[1]).to_string();
333                let image_bytes = String::from_utf8_lossy(&cap[2]).to_string();
334
335                Ok(Some((ImageType::from_extension(&image_type)?, decode_base64(&image_bytes)?, new_index)))
336            }
337
338            else if let Some(cap) = media_re.captures(token) {
339                let path = &cap[1];
340                let file = join(curr_dir, &String::from_utf8_lossy(path).to_string())?;
341                Ok(Some((ImageType::from_extension(&extension(&file)?.unwrap_or(String::new()))?, read_bytes(&file)?, new_index)))
342            }
343
344            else {
345                Err(Error::InvalidInlineBlock)
346            }
347        },
348
349        // not an inline block at all
350        None => Ok(None),
351    }
352}
353
/// Extracts a `<|...|>` token that starts exactly at `bytes[index]`.
///
/// Returns the bytes between `<|` and the first following `|>`, plus the index
/// just past that `|>`. Returns `None` when `bytes[index..]` does not begin
/// with `<|`, or when no closing `|>` follows.
///
/// The opening `|` is never reused as part of the closing delimiter, so `<|>`
/// is not a token. The body may contain any bytes, including newlines.
fn try_get_pdl_token(bytes: &[u8], index: usize) -> Option<(&[u8], usize)> {
    let rest = bytes.get(index..)?;

    if !rest.starts_with(b"<|") {
        return None;
    }

    let body_start = index + 2;
    let body_len = rest[2..].windows(2).position(|pair| pair == b"|>")?;
    let body_end = body_start + body_len;

    Some((&bytes[body_start..body_end], body_end + 2))
}
381
#[cfg(test)]
mod tests {
    use crate::{
        ImageType,
        Message,
        MessageContent,
        Pdl,
        Role,
        decode_base64,
        parse_pdl,
        parse_pdl_from_file,
    };
    use ragit_fs::{
        WriteMode,
        join,
        temp_dir,
        write_string,
    };

    // More thorough test suites live in `tests/`.

    /// Builds a message holding a single text part.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: vec![MessageContent::String(text.to_string())],
        }
    }

    /// A two-turn file parsed in strict mode yields one system and one user
    /// message, with the surrounding blank lines trimmed, and no schema.
    #[test]
    fn messages_from_file_test() {
        let source = "
<|system|>

You're a code helper.

<|user|>

Write me a sudoku-solver.


";
        let tmp_path = join(&temp_dir().unwrap(), "test_messages.tera").unwrap();
        write_string(&tmp_path, source, WriteMode::CreateOrTruncate).unwrap();

        let strict_mode = true;
        let pdl: Pdl = parse_pdl_from_file(&tmp_path, &tera::Context::new(), strict_mode).unwrap();

        let expected = vec![
            text_message(Role::System, "You're a code helper."),
            text_message(Role::User, "Write me a sudoku-solver."),
        ];
        assert_eq!(pdl.messages, expected);
        assert_eq!(pdl.schema, None);
    }

    /// A `<|raw_media(png:...)|>` inline block becomes a single PNG image whose
    /// bytes are the base64-decoded payload.
    #[test]
    fn media_content_test() {
        let source = "
<|user|>

<|raw_media(png:HiMyNameIsBaehyunsol)|>
";
        // `curr_dir` is irrelevant here: there's no `<|media|>` block to resolve.
        let strict_mode = true;
        let pdl: Pdl = parse_pdl(source, &tera::Context::new(), ".", strict_mode).unwrap();

        let expected_image = MessageContent::Image {
            image_type: ImageType::Png,
            bytes: decode_base64("HiMyNameIsBaehyunsol").unwrap(),
        };
        assert_eq!(
            pdl.messages,
            vec![Message { role: Role::User, content: vec![expected_image] }],
        );
        assert_eq!(pdl.schema, None);
    }
}