ragit_pdl/
schema.rs

// The first and foremost goal of schema validation is to give LLMs helpful error messages.
2
3use crate::error::Error;
4use serde_json::Value;
5use std::collections::HashSet;
6use std::fmt::{Debug, Display};
7use std::str::FromStr;
8
9mod code_fence;
10mod parse;
11mod parse_value;
12mod task_list;
13
14pub use code_fence::try_extract_code_fence;
15pub use parse::{SchemaParseError, parse_schema};
16use parse_value::{JsonMatch, extract_jsonish_literal};
17pub use task_list::{count_task_list_elements, try_extract_task_list};
18
19#[cfg(test)]
20mod tests;
21
// After adding a non-json schema_type,
//
// 1. Make sure that the code compiles.
// 2. Add a test case in `tests/`.
// 3. Update `render_pdl_schema`.
/// Type part of a [`Schema`].
///
/// Values are extracted from an LLM's output by `Schema::validate`. The non-json
/// types are represented in json as a bool (`Yesno`) or as a string (`Code` and
/// `TaskList`).
#[derive(Clone, Debug, PartialEq)]
pub enum SchemaType {
    Integer,
    Float,
    String,
    /// Schema of the elements. `None` means that the elements are not validated.
    Array(Option<Box<Schema>>),
    Boolean,
    /// Fields of the object, in declaration order. Keys that are not listed here are rejected.
    Object(Vec<(String, Schema)>),
    Null,
    Yesno,
    Code,

    // https://github.github.com/gfm/#task-list-items-extension-
    // https://github.com/baehyunsol/ragit/issues/17
    TaskList,
}
43
44impl SchemaType {
45    // LLMs will see this name (e.g. "I cannot find `array` in your output.")
46    pub fn type_name(&self) -> &'static str {
47        match self {
48            SchemaType::Integer => "integer",
49            SchemaType::Float => "float",
50            SchemaType::String => "string",
51            SchemaType::Array(_) => "array",
52            SchemaType::Boolean => "boolean",
53            SchemaType::Object(_) => "object",
54            SchemaType::Null => "null",
55            SchemaType::Yesno => "yes or no",
56            SchemaType::Code => "code",
57            SchemaType::TaskList => "markdown task list",
58        }
59    }
60
61    pub fn is_number(&self) -> bool {
62        match self {
63            SchemaType::Integer
64            | SchemaType::Float => true,
65            _ => false,
66        }
67    }
68
69    pub fn is_array(&self) -> bool {
70        matches!(self, SchemaType::Array(_))
71    }
72
73    pub fn unwrap_keys(&self) -> Vec<String> {
74        match self {
75            SchemaType::Object(obj) => obj.iter().map(|(key, _)| key.to_string()).collect(),
76            _ => panic!(),
77        }
78    }
79
80    // It's a word that describes this schema. It's used to
81    // format error messages: format!("your output doesn't contain a valid {}", self.one_word())
82    fn one_word(&self) -> &'static str {
83        match self {
84            SchemaType::Integer
85            | SchemaType::Float => "numeric value",
86            SchemaType::String => "string value",
87            SchemaType::Array(_)
88            | SchemaType::Object(_) => "json value",
89            SchemaType::Boolean => "boolean value",
90            SchemaType::Null => "null value",
91            SchemaType::Yesno => "yes/no value",
92            SchemaType::Code => "code block",
93            SchemaType::TaskList => "markdown task list",
94        }
95    }
96}
97
/// Why a json value doesn't match a [`Schema`]. `SchemaError::prettify` turns it
/// into a message for the LLM.
#[derive(Clone, Debug)]
pub enum SchemaError {
    // _ is too (small | big | short | long). Make sure that _ (is at least | is at most | has at least | has at most) (N | N characters | N elements).
    RangeError {
        s1: String,  // small | big | short | long
        s2: String,  // is at least | is at most | has at least | has at most
        s3: String,  // N | N characters | N elements  (constraint)
        s4: Option<String>,  // N characters | N elements (current value; `None` for numbers)
    },
    // Keys that the schema requires but the output doesn't have.
    MissingKeys(Vec<String>),
    // Keys that the output has but the schema doesn't mention.
    UnnecessaryKeys(Vec<String>),
    // The value of field `key` doesn't match the field's schema.
    ErrorInObject {
        key: String,
        error: Box<SchemaError>,
    },
    // The element at (0-based) `index` doesn't match the element schema.
    ErrorInArray {
        index: usize,
        error: Box<SchemaError>,
    },
    TypeError {
        expected: SchemaType,
        got: SchemaType,
    },
}
122
123impl SchemaError {
124    // This is an error message for LLMs, not for (human) programmers.
125    // It has to be short and readable english sentences, unlike
126    // compiler error messages.
127    pub fn prettify(&self, schema: &Schema) -> String {
128        match self {
129            SchemaError::RangeError { s1, s2, s3, s4 } => format!(
130                "Your output is too {s1}. Make sure that the output {s2} {s3}.{}",
131                if let Some(s4) = s4 { format!(" Currently, it has {s4}.") } else { String::new() },
132            ),
133            SchemaError::MissingKeys(keys) => {
134                let schema_keys = schema.unwrap_keys();
135
136                format!(
137                    "Your output is missing {}: {}. Make sure that your output contains {} key{}: {}",
138                    if keys.len() == 1 { "a field" } else { "fields "},
139                    keys.join(", "),
140                    schema_keys.len(),
141                    if schema_keys.len() == 1 { "" } else { "s" },
142                    schema_keys.join(", "),
143                )
144            },
145            SchemaError::UnnecessaryKeys(keys) => {
146                let schema_keys = schema.unwrap_keys();
147
148                format!(
149                    "Your output has {}unnecessary key{}: {}. Make sure that the output contains {}key{}: {}",
150                    if keys.len() == 1 { "an " } else { "" },
151                    if keys.len() == 1 { "" } else { "s" },
152                    keys.join(", "),
153                    if schema_keys.len() == 1 { "a " } else { "" },
154                    if schema_keys.len() == 1 { "" } else { "s" },
155                    schema_keys.join(", "),
156                )
157            },
158            SchemaError::ErrorInObject { key, error } => match error.as_ref() {
159                SchemaError::RangeError { s1, s2, s3, s4 } => format!(
160                    "Field `{key}` of your output is too {s1}. Make sure that the field {s2} {s3}.{}",
161                    if let Some(s4) = s4 { format!(" Currently, it has {s4}.") } else { String::new() },
162                ),
163                SchemaError::TypeError { expected, got } => format!(
164                    "Field `{key}` of your output has a wrong type. Make sure that the field is `{}`, not `{}`.",
165                    expected.type_name(),
166                    got.type_name(),
167                ),
168                // It assumes that the models can find the schema somewhere in prompts.
169                // TODO: better error messages in these cases
170                _ => String::from("Please make sure that your output has a correct schema."),
171            },
172            SchemaError::ErrorInArray { index, error } => match error.as_ref() {
173                SchemaError::RangeError { s1, s2, s3, s4 } => format!(
174                    "The {} value of your output is too {s1}. Make sure that the value {s2} {s3}.{}",
175                    match index {
176                        0 => String::from("first"),
177                        1 => String::from("second"),
178                        2 => String::from("third"),
179                        3 => String::from("forth"),
180                        4 => String::from("fifth"),
181                        n => format!("{}th", n + 1),
182                    },
183                    if let Some(s4) = s4 { format!(" Currently, it has {s4}.") } else { String::new() },
184                ),
185                SchemaError::TypeError { expected, got } => format!(
186                    "The {} value of your output has a wrong type. Make sure all the elements are `{}`, not `{}`.",
187                    match index {
188                        0 => String::from("first"),
189                        1 => String::from("second"),
190                        2 => String::from("third"),
191                        3 => String::from("forth"),
192                        4 => String::from("fifth"),
193                        n => format!("{}th", n + 1),
194                    },
195                    expected.type_name(),
196                    got.type_name(),
197                ),
198                // It assumes that the models can find the schema somewhere in prompts.
199                // TODO: better error messages in these cases
200                _ => String::from("Please make sure that your output has a correct schema."),
201            },
202            SchemaError::TypeError { expected, got } => format!(
203                "Your output has a wrong type. It has to be `{}`, not `{}`.",
204                expected.type_name(),
205                got.type_name(),
206            ),
207        }
208    }
209}
210
/// A pdl schema: a type, plus an optional constraint on its value, length or
/// number of elements.
#[derive(Clone, Debug, PartialEq)]
pub struct Schema {
    r#type: SchemaType,
    // `None` means unconstrained. Set by the `*_between` constructors or `add_constraint`.
    constraint: Option<Constraint>,
}
216
217impl Schema {
218    // If it's `Ok(s)`, `s` is an evaluable json string.
219    // If it's `Err(e)`, `e` is an error message which must be sent to the LLM.
220    pub fn validate(&self, s: &str) -> Result<Value, String> {
221        let extracted_text = self.extract_text(s)?;
222        let v = match serde_json::from_str::<Value>(&extracted_text) {
223            Ok(v) => v,
224            Err(_) => {
225                return Err(format!("I cannot parse your output. Please make sure that your output contains a valid {} with valid data.", self.one_word()));
226            },
227        };
228
229        self.validate_value(&v).map_err(|e| e.prettify(self))?;
230        Ok(v)
231    }
232
233    fn validate_value(&self, v: &Value) -> Result<(), SchemaError> {
234        match (&self.r#type, v) {
235            (SchemaType::Integer, Value::Number(n)) => match n.as_i64() {
236                Some(n) => {
237                    check_range(SchemaType::Integer, &self.constraint, n)?;
238                    Ok(())
239                },
240                None => Err(SchemaError::TypeError {
241                    expected: SchemaType::Integer,
242                    got: SchemaType::Float,
243                }),
244            },
245            (SchemaType::Float, Value::Number(n)) => match n.as_f64() {
246                Some(n) => {
247                    check_range(SchemaType::Float, &self.constraint, n)?;
248                    Ok(())
249                },
250                None => unreachable!(),
251            },
252            (ty @ (SchemaType::String | SchemaType::Code), Value::String(s)) => {
253                check_range(ty.clone(), &self.constraint, s.len())?;
254                Ok(())
255            },
256            (SchemaType::Array(schema), Value::Array(v)) => {
257                if let Some(schema) = schema {
258                    for (index, e) in v.iter().enumerate() {
259                        if let Err(e) = schema.validate_value(e) {
260                            return Err(SchemaError::ErrorInArray { index, error: Box::new(e) });
261                        }
262                    }
263                }
264
265                check_range(SchemaType::Array(None), &self.constraint, v.len())?;
266                Ok(())
267            },
268            (SchemaType::Object(obj_schema), Value::Object(obj)) => {
269                let mut keys_in_schema = HashSet::with_capacity(obj_schema.len());
270                let mut missing_keys = vec![];
271                let mut unnecessary_keys = vec![];
272
273                for (k, v_schema) in obj_schema.iter() {
274                    keys_in_schema.insert(k);
275
276                    match obj.get(k) {
277                        Some(v) => match v_schema.validate_value(v) {
278                            Ok(_) => {},
279                            Err(e) => {
280                                return Err(SchemaError::ErrorInObject {
281                                    key: k.to_string(),
282                                    error: Box::new(e),
283                                });
284                            },
285                        },
286                        None => {
287                            missing_keys.push(k.to_string());
288                        },
289                    }
290                }
291
292                for k in obj.keys() {
293                    if !keys_in_schema.contains(k) {
294                        unnecessary_keys.push(k.to_string());
295                    }
296                }
297
298                if !missing_keys.is_empty() {
299                    Err(SchemaError::MissingKeys(missing_keys))
300                }
301
302                else if !unnecessary_keys.is_empty() {
303                    Err(SchemaError::UnnecessaryKeys(unnecessary_keys))
304                }
305
306                else {
307                    Ok(())
308                }
309            },
310            (SchemaType::TaskList, Value::String(s)) => {
311                check_range(SchemaType::TaskList, &self.constraint, count_task_list_elements(s))?;
312                Ok(())
313            },
314            (SchemaType::Boolean | SchemaType::Yesno, Value::Bool(_)) => Ok(()),
315            (t1, t2) => Err(SchemaError::TypeError {
316                expected: t1.clone(),
317                got: get_schema_type(t2),
318            }),
319        }
320    }
321
322    // It tries to extract a json value from a haystack.
323    // It raises an error if there are multiple candidates.
324    // It can be more generous than json's syntax (e.g. it allows `true` and `True`),
325    // but its return value must be a valid json.
326    fn extract_text(&self, s: &str) -> Result<String, String> {
327        match &self.r#type {
328            SchemaType::Boolean | SchemaType::Yesno => {
329                let s = s.to_ascii_lowercase();
330                let t = if self.r#type == SchemaType::Boolean { s.contains("true")} else { s.contains("yes") };
331                let f = if self.r#type == SchemaType::Boolean { s.contains("false")} else { s.contains("no") };
332
333                match (t, f) {
334                    (true, false) => Ok(String::from("true")),
335                    (false, true) => Ok(String::from("false")),
336                    (true, true) => if self.r#type == SchemaType::Boolean {
337                        Err(String::from("Your output contains both `true` and `false`. Please be specific."))
338                    } else {
339                        Err(String::from("Just say yes or no."))
340                    },
341                    (false, false) => if self.r#type == SchemaType::Boolean {
342                        Err(format!("I cannot find `boolean` in your output. Please make sure that your output contains a valid {}.", self.one_word()))
343                    } else {
344                        Err(String::from("Just say yes or no."))
345                    },
346                }
347            },
348            SchemaType::Null => {
349                let low = s.to_ascii_lowercase();
350
351                if low == "null" || low == "none" {
352                    Ok(String::from("null"))
353                }
354
355                else {
356                    Err(format!("{s:?} is not null."))
357                }
358            },
359            SchemaType::String => Ok(format!("{s:?}")),
360            SchemaType::Code => Ok(format!("{:?}", try_extract_code_fence(s)?)),
361            SchemaType::TaskList => Ok(format!("{:?}", try_extract_task_list(s)?)),
362            SchemaType::Integer | SchemaType::Float
363            | SchemaType::Array(_) | SchemaType::Object(_) => {
364                let mut jsonish_literals = extract_jsonish_literal(s);
365
366                match jsonish_literals.get_matches(&self.r#type) {
367                    JsonMatch::NoMatch => Err(format!("I cannot find `{}` in your output. Please make sure that your output contains a valid {}.", self.type_name(), self.one_word())),
368                    JsonMatch::MultipleMatches => Err(format!("I see more than 1 candidates that look like `{}`. I don't know which one to choose. Please give me just one `{}`.", self.type_name(), self.type_name())),
369                    JsonMatch::Match(s) => Ok(s.to_string()),
370                    JsonMatch::ExpectedIntegerGotFloat(s) => Err(format!("I want an integer, but I can only find a float literal: `{s}`. Could you give me an integer literal?")),
371                }
372            },
373        }
374    }
375
376    pub fn default_integer() -> Self {
377        Schema {
378            r#type: SchemaType::Integer,
379            constraint: None,
380        }
381    }
382
383    /// Both inclusive.
384    pub fn integer_between(min: Option<i128>, max: Option<i128>) -> Self {
385        Schema {
386            r#type: SchemaType::Integer,
387            constraint: Some(Constraint {
388                min: min.map(|n| n.to_string()),
389                max: max.map(|n| n.to_string()),
390            }),
391        }
392    }
393
394    pub fn default_float() -> Self {
395        Schema {
396            r#type: SchemaType::Float,
397            constraint: None,
398        }
399    }
400
401    pub fn default_string() -> Self {
402        Schema {
403            r#type: SchemaType::String,
404            constraint: None,
405        }
406    }
407
408    /// Both inclusive
409    pub fn string_length_between(min: Option<usize>, max: Option<usize>) -> Self {
410        Schema {
411            r#type: SchemaType::String,
412            constraint: Some(Constraint {
413                min: min.map(|n| n.to_string()),
414                max: max.map(|n| n.to_string()),
415            }),
416        }
417    }
418
419    pub fn default_array(r#type: Option<Schema>) -> Self {
420        Schema {
421            r#type: SchemaType::Array(r#type.map(|t| Box::new(t))),
422            constraint: None,
423        }
424    }
425
426    pub fn default_boolean() -> Self {
427        Schema {
428            r#type: SchemaType::Boolean,
429            constraint: None,
430        }
431    }
432
433    pub fn default_yesno() -> Self {
434        Schema {
435            r#type: SchemaType::Yesno,
436            constraint: None,
437        }
438    }
439
440    pub fn default_code() -> Self {
441        Schema {
442            r#type: SchemaType::Code,
443            constraint: None,
444        }
445    }
446
447    pub fn default_task_list() -> Self {
448        Schema {
449            r#type: SchemaType::TaskList,
450            constraint: None,
451        }
452    }
453
454    pub fn add_constraint(&mut self, constraint: Constraint) {
455        debug_assert!(self.constraint.is_none());
456        self.constraint = Some(constraint);
457    }
458
459    pub fn validate_constraint(&self) -> Result<(), SchemaParseError> {
460        match (&self.r#type, &self.constraint) {
461            (ty @ (SchemaType::Integer | SchemaType::Array(_) | SchemaType::String | SchemaType::TaskList | SchemaType::Code), Some(constraint)) => {
462                let mut min_ = i64::MIN;
463                let mut max_ = i64::MAX;
464
465                if let Some(min) = &constraint.min {
466                    match min.parse::<i64>() {
467                        Ok(n) => { min_ = n; },
468                        Err(_) => {
469                            return Err(SchemaParseError::InvalidConstraint(format!("{min:?} is not a valid integer.")));
470                        },
471                    }
472                }
473
474                if let Some(max) = &constraint.max {
475                    match max.parse::<i64>() {
476                        Ok(n) => { max_ = n; },
477                        Err(_) => {
478                            return Err(SchemaParseError::InvalidConstraint(format!("{max:?} is not a valid integer.")));
479                        },
480                    }
481                }
482
483                if min_ > max_ {
484                    return Err(SchemaParseError::InvalidConstraint(format!("`min` ({min_}) is greater than `max` ({max_}).")));
485                }
486
487                if matches!(ty, SchemaType::String) || matches!(ty, SchemaType::Array(_)) {
488                    if constraint.min.is_some() && min_ < 0 {
489                        return Err(SchemaParseError::InvalidConstraint(format!("`min` is supposed to be a positive integer, but is {min_}")));
490                    }
491
492                    if constraint.max.is_some() && max_ < 0 {
493                        return Err(SchemaParseError::InvalidConstraint(format!("`max` is supposed to be a positive integer, but is {max_}")));
494                    }
495                }
496
497                Ok(())
498            },
499            (SchemaType::Float, Some(constraint)) => {
500                let mut min_ = f64::MIN;
501                let mut max_ = f64::MAX;
502
503                if let Some(min) = &constraint.min {
504                    match min.parse::<f64>() {
505                        Ok(n) => { min_ = n; },
506                        Err(_) => {
507                            return Err(SchemaParseError::InvalidConstraint(format!("{min:?} is not a valid number.")));
508                        },
509                    }
510                }
511
512                if let Some(max) = &constraint.max {
513                    match max.parse::<f64>() {
514                        Ok(n) => { max_ = n; },
515                        Err(_) => {
516                            return Err(SchemaParseError::InvalidConstraint(format!("{max:?} is not a valid number.")));
517                        },
518                    }
519                }
520
521                if min_ > max_ {
522                    return Err(SchemaParseError::InvalidConstraint(format!("`min` ({min_}) is greater than `max` ({max_}).")));
523                }
524
525                Ok(())
526            },
527            (ty @ (SchemaType::Null | SchemaType::Boolean | SchemaType::Object(_) | SchemaType::Yesno), Some(constraint)) => {
528                if constraint.min.is_some() {
529                    Err(SchemaParseError::InvalidConstraint(format!(
530                        "Type `{}` cannot have constraint `min`",
531                        ty.type_name(),
532                    )))
533                }
534
535                else if constraint.max.is_some() {
536                    Err(SchemaParseError::InvalidConstraint(format!(
537                        "Type `{}` cannot have constraint `max`",
538                        ty.type_name(),
539                    )))
540                }
541
542                else {
543                    Ok(())
544                }
545            },
546            (_, None) => Ok(()),
547        }
548    }
549
550    pub fn type_name(&self) -> &'static str {
551        self.r#type.type_name()
552    }
553
554    pub fn unwrap_keys(&self) -> Vec<String> {
555        self.r#type.unwrap_keys()
556    }
557
558    // It's a word that describes this schema. It's used to
559    // format error messages: format!("your output doesn't contain a valid {}", self.one_word())
560    fn one_word(&self) -> &'static str {
561        self.r#type.one_word()
562    }
563}
564
565/// pdl schema is a bit unintuitive when you're using non-json schema.
566/// For example, schema `yesno` will become `Value::Bool`. If you naively
567/// convert this to a string, you'll get "true" or "false", not "yes" or "no".
568///
569/// Likewise, schema `code` will become `Value::String` whose content is the code.
570/// If you naively convert this to a string (using serde_json::to_string), you'll
571/// get something like "\"fn main() ...\"".
572pub fn render_pdl_schema(
573    schema: &Schema,
574
575    // Result of `Schema::validate`
576    value: &Value,
577) -> Result<String, Error> {
578    let s = match (&schema.r#type, value) {
579        (SchemaType::Code, Value::String(s)) => s.to_string(),
580        (SchemaType::TaskList, Value::String(s)) => s.to_string(),
581        (SchemaType::Yesno, Value::Bool(b)) => if *b {
582            String::from("yes")
583        } else {
584            String::from("no")
585        },
586        _ => serde_json::to_string_pretty(value)?,
587    };
588
589    Ok(s)
590}
591
/// Union of all constraints: one type that holds the bounds of every schema type.
///
/// Both bounds are inclusive. They are kept as unparsed strings and parsed with the
/// type that the schema compares against; `Schema::validate_constraint` checks that
/// they make sense for the schema's type.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Constraint {
    // for `Integer` and `Float`, these are min/max values
    // for `String` and `Code`, these are min/max lengths of the text
    // for `Array` and `TaskList`, these are min/max numbers of elements
    min: Option<String>,
    max: Option<String>,
}
601
602fn get_schema_type(v: &Value) -> SchemaType {
603    match v {
604        Value::Number(n) => {
605            if n.is_i64() {
606                SchemaType::Integer
607            }
608
609            else {
610                SchemaType::Float
611            }
612        },
613        Value::String(_) => SchemaType::String,
614        Value::Array(_) => SchemaType::Array(None),
615        Value::Object(_) => SchemaType::Object(vec![]),
616        Value::Bool(_) => SchemaType::Boolean,
617        Value::Null => SchemaType::Null,
618    }
619}
620
621fn check_range<T: PartialOrd + FromStr + ToString + Display>(schema: SchemaType, constraint: &Option<Constraint>, n: T) -> Result<(), SchemaError> where <T as FromStr>::Err: Debug {
622    // It's okay to unwrap values because `Constraint` is always validated at creation.
623    if let Some(constraint) = constraint {
624        if let Constraint { min: Some(min), .. } = &constraint {
625            let min = min.parse::<T>().unwrap();
626
627            if n < min {
628                return Err(SchemaError::RangeError {
629                    s1: String::from(if schema.is_number() { "small" } else { "short" }),
630                    s2: String::from(if schema.is_number() { "is at least" } else { "has at least" }),
631                    s3: if schema.is_number() { min.to_string() } else if schema.is_array() { format!("{min} elements") } else { format!("{min} characters") },
632                    s4: if schema.is_number() { None } else if schema.is_array() { Some(format!("{n} elements")) } else { Some(format!("{n} characters")) },
633                });
634            }
635        }
636
637        if let Constraint { max: Some(max), .. } = &constraint {
638            let max = max.parse::<T>().unwrap();
639
640            if n > max {
641                return Err(SchemaError::RangeError {
642                    s1: String::from(if schema.is_number() { "big" } else { "long" }),
643                    s2: String::from(if schema.is_number() { "is at most" } else { "has at most" }),
644                    s3: if schema.is_number() { max.to_string() } else if schema.is_array() { format!("{max} elements") } else { format!("{max} characters") },
645                    s4: if schema.is_number() { None } else if schema.is_array() { Some(format!("{n} elements")) } else { Some(format!("{n} characters")) },
646                });
647            }
648        }
649    }
650
651    Ok(())
652}