1use crate::error::Error;
4use serde_json::Value;
5use std::collections::HashSet;
6use std::fmt::{Debug, Display};
7use std::str::FromStr;
8
9mod code_fence;
10mod parse;
11mod parse_value;
12mod task_list;
13
14pub use code_fence::try_extract_code_fence;
15pub use parse::{SchemaParseError, parse_schema};
16use parse_value::{JsonMatch, extract_jsonish_literal};
17pub use task_list::{count_task_list_elements, try_extract_task_list};
18
19#[cfg(test)]
20mod tests;
21
22#[derive(Clone, Debug, PartialEq)]
28pub enum SchemaType {
29 Integer,
30 Float,
31 String,
32 Array(Option<Box<Schema>>),
33 Boolean,
34 Object(Vec<(String, Schema)>),
35 Null,
36 Yesno,
37 Code,
38
39 TaskList,
42}
43
44impl SchemaType {
45 pub fn type_name(&self) -> &'static str {
47 match self {
48 SchemaType::Integer => "integer",
49 SchemaType::Float => "float",
50 SchemaType::String => "string",
51 SchemaType::Array(_) => "array",
52 SchemaType::Boolean => "boolean",
53 SchemaType::Object(_) => "object",
54 SchemaType::Null => "null",
55 SchemaType::Yesno => "yes or no",
56 SchemaType::Code => "code",
57 SchemaType::TaskList => "markdown task list",
58 }
59 }
60
61 pub fn is_number(&self) -> bool {
62 match self {
63 SchemaType::Integer
64 | SchemaType::Float => true,
65 _ => false,
66 }
67 }
68
69 pub fn is_array(&self) -> bool {
70 matches!(self, SchemaType::Array(_))
71 }
72
73 pub fn unwrap_keys(&self) -> Vec<String> {
74 match self {
75 SchemaType::Object(obj) => obj.iter().map(|(key, _)| key.to_string()).collect(),
76 _ => panic!(),
77 }
78 }
79
80 fn one_word(&self) -> &'static str {
83 match self {
84 SchemaType::Integer
85 | SchemaType::Float => "numeric value",
86 SchemaType::String => "string value",
87 SchemaType::Array(_)
88 | SchemaType::Object(_) => "json value",
89 SchemaType::Boolean => "boolean value",
90 SchemaType::Null => "null value",
91 SchemaType::Yesno => "yes/no value",
92 SchemaType::Code => "code block",
93 SchemaType::TaskList => "markdown task list",
94 }
95 }
96}
97
98#[derive(Clone, Debug)]
99pub enum SchemaError {
100 RangeError {
102 s1: String, s2: String, s3: String, s4: Option<String>, },
107 MissingKeys(Vec<String>),
108 UnnecessaryKeys(Vec<String>),
109 ErrorInObject {
110 key: String,
111 error: Box<SchemaError>,
112 },
113 ErrorInArray {
114 index: usize,
115 error: Box<SchemaError>,
116 },
117 TypeError {
118 expected: SchemaType,
119 got: SchemaType,
120 },
121}
122
123impl SchemaError {
124 pub fn prettify(&self, schema: &Schema) -> String {
128 match self {
129 SchemaError::RangeError { s1, s2, s3, s4 } => format!(
130 "Your output is too {s1}. Make sure that the output {s2} {s3}.{}",
131 if let Some(s4) = s4 { format!(" Currently, it has {s4}.") } else { String::new() },
132 ),
133 SchemaError::MissingKeys(keys) => {
134 let schema_keys = schema.unwrap_keys();
135
136 format!(
137 "Your output is missing {}: {}. Make sure that your output contains {} key{}: {}",
138 if keys.len() == 1 { "a field" } else { "fields "},
139 keys.join(", "),
140 schema_keys.len(),
141 if schema_keys.len() == 1 { "" } else { "s" },
142 schema_keys.join(", "),
143 )
144 },
145 SchemaError::UnnecessaryKeys(keys) => {
146 let schema_keys = schema.unwrap_keys();
147
148 format!(
149 "Your output has {}unnecessary key{}: {}. Make sure that the output contains {}key{}: {}",
150 if keys.len() == 1 { "an " } else { "" },
151 if keys.len() == 1 { "" } else { "s" },
152 keys.join(", "),
153 if schema_keys.len() == 1 { "a " } else { "" },
154 if schema_keys.len() == 1 { "" } else { "s" },
155 schema_keys.join(", "),
156 )
157 },
158 SchemaError::ErrorInObject { key, error } => match error.as_ref() {
159 SchemaError::RangeError { s1, s2, s3, s4 } => format!(
160 "Field `{key}` of your output is too {s1}. Make sure that the field {s2} {s3}.{}",
161 if let Some(s4) = s4 { format!(" Currently, it has {s4}.") } else { String::new() },
162 ),
163 SchemaError::TypeError { expected, got } => format!(
164 "Field `{key}` of your output has a wrong type. Make sure that the field is `{}`, not `{}`.",
165 expected.type_name(),
166 got.type_name(),
167 ),
168 _ => String::from("Please make sure that your output has a correct schema."),
171 },
172 SchemaError::ErrorInArray { index, error } => match error.as_ref() {
173 SchemaError::RangeError { s1, s2, s3, s4 } => format!(
174 "The {} value of your output is too {s1}. Make sure that the value {s2} {s3}.{}",
175 match index {
176 0 => String::from("first"),
177 1 => String::from("second"),
178 2 => String::from("third"),
179 3 => String::from("forth"),
180 4 => String::from("fifth"),
181 n => format!("{}th", n + 1),
182 },
183 if let Some(s4) = s4 { format!(" Currently, it has {s4}.") } else { String::new() },
184 ),
185 SchemaError::TypeError { expected, got } => format!(
186 "The {} value of your output has a wrong type. Make sure all the elements are `{}`, not `{}`.",
187 match index {
188 0 => String::from("first"),
189 1 => String::from("second"),
190 2 => String::from("third"),
191 3 => String::from("forth"),
192 4 => String::from("fifth"),
193 n => format!("{}th", n + 1),
194 },
195 expected.type_name(),
196 got.type_name(),
197 ),
198 _ => String::from("Please make sure that your output has a correct schema."),
201 },
202 SchemaError::TypeError { expected, got } => format!(
203 "Your output has a wrong type. It has to be `{}`, not `{}`.",
204 expected.type_name(),
205 got.type_name(),
206 ),
207 }
208 }
209}
210
211#[derive(Clone, Debug, PartialEq)]
212pub struct Schema {
213 r#type: SchemaType,
214 constraint: Option<Constraint>,
215}
216
217impl Schema {
218 pub fn validate(&self, s: &str) -> Result<Value, String> {
221 let extracted_text = self.extract_text(s)?;
222 let v = match serde_json::from_str::<Value>(&extracted_text) {
223 Ok(v) => v,
224 Err(_) => {
225 return Err(format!("I cannot parse your output. Please make sure that your output contains a valid {} with valid data.", self.one_word()));
226 },
227 };
228
229 self.validate_value(&v).map_err(|e| e.prettify(self))?;
230 Ok(v)
231 }
232
233 fn validate_value(&self, v: &Value) -> Result<(), SchemaError> {
234 match (&self.r#type, v) {
235 (SchemaType::Integer, Value::Number(n)) => match n.as_i64() {
236 Some(n) => {
237 check_range(SchemaType::Integer, &self.constraint, n)?;
238 Ok(())
239 },
240 None => Err(SchemaError::TypeError {
241 expected: SchemaType::Integer,
242 got: SchemaType::Float,
243 }),
244 },
245 (SchemaType::Float, Value::Number(n)) => match n.as_f64() {
246 Some(n) => {
247 check_range(SchemaType::Float, &self.constraint, n)?;
248 Ok(())
249 },
250 None => unreachable!(),
251 },
252 (ty @ (SchemaType::String | SchemaType::Code), Value::String(s)) => {
253 check_range(ty.clone(), &self.constraint, s.len())?;
254 Ok(())
255 },
256 (SchemaType::Array(schema), Value::Array(v)) => {
257 if let Some(schema) = schema {
258 for (index, e) in v.iter().enumerate() {
259 if let Err(e) = schema.validate_value(e) {
260 return Err(SchemaError::ErrorInArray { index, error: Box::new(e) });
261 }
262 }
263 }
264
265 check_range(SchemaType::Array(None), &self.constraint, v.len())?;
266 Ok(())
267 },
268 (SchemaType::Object(obj_schema), Value::Object(obj)) => {
269 let mut keys_in_schema = HashSet::with_capacity(obj_schema.len());
270 let mut missing_keys = vec![];
271 let mut unnecessary_keys = vec![];
272
273 for (k, v_schema) in obj_schema.iter() {
274 keys_in_schema.insert(k);
275
276 match obj.get(k) {
277 Some(v) => match v_schema.validate_value(v) {
278 Ok(_) => {},
279 Err(e) => {
280 return Err(SchemaError::ErrorInObject {
281 key: k.to_string(),
282 error: Box::new(e),
283 });
284 },
285 },
286 None => {
287 missing_keys.push(k.to_string());
288 },
289 }
290 }
291
292 for k in obj.keys() {
293 if !keys_in_schema.contains(k) {
294 unnecessary_keys.push(k.to_string());
295 }
296 }
297
298 if !missing_keys.is_empty() {
299 Err(SchemaError::MissingKeys(missing_keys))
300 }
301
302 else if !unnecessary_keys.is_empty() {
303 Err(SchemaError::UnnecessaryKeys(unnecessary_keys))
304 }
305
306 else {
307 Ok(())
308 }
309 },
310 (SchemaType::TaskList, Value::String(s)) => {
311 check_range(SchemaType::TaskList, &self.constraint, count_task_list_elements(s))?;
312 Ok(())
313 },
314 (SchemaType::Boolean | SchemaType::Yesno, Value::Bool(_)) => Ok(()),
315 (t1, t2) => Err(SchemaError::TypeError {
316 expected: t1.clone(),
317 got: get_schema_type(t2),
318 }),
319 }
320 }
321
322 fn extract_text(&self, s: &str) -> Result<String, String> {
327 match &self.r#type {
328 SchemaType::Boolean | SchemaType::Yesno => {
329 let s = s.to_ascii_lowercase();
330 let t = if self.r#type == SchemaType::Boolean { s.contains("true")} else { s.contains("yes") };
331 let f = if self.r#type == SchemaType::Boolean { s.contains("false")} else { s.contains("no") };
332
333 match (t, f) {
334 (true, false) => Ok(String::from("true")),
335 (false, true) => Ok(String::from("false")),
336 (true, true) => if self.r#type == SchemaType::Boolean {
337 Err(String::from("Your output contains both `true` and `false`. Please be specific."))
338 } else {
339 Err(String::from("Just say yes or no."))
340 },
341 (false, false) => if self.r#type == SchemaType::Boolean {
342 Err(format!("I cannot find `boolean` in your output. Please make sure that your output contains a valid {}.", self.one_word()))
343 } else {
344 Err(String::from("Just say yes or no."))
345 },
346 }
347 },
348 SchemaType::Null => {
349 let low = s.to_ascii_lowercase();
350
351 if low == "null" || low == "none" {
352 Ok(String::from("null"))
353 }
354
355 else {
356 Err(format!("{s:?} is not null."))
357 }
358 },
359 SchemaType::String => Ok(format!("{s:?}")),
360 SchemaType::Code => Ok(format!("{:?}", try_extract_code_fence(s)?)),
361 SchemaType::TaskList => Ok(format!("{:?}", try_extract_task_list(s)?)),
362 SchemaType::Integer | SchemaType::Float
363 | SchemaType::Array(_) | SchemaType::Object(_) => {
364 let mut jsonish_literals = extract_jsonish_literal(s);
365
366 match jsonish_literals.get_matches(&self.r#type) {
367 JsonMatch::NoMatch => Err(format!("I cannot find `{}` in your output. Please make sure that your output contains a valid {}.", self.type_name(), self.one_word())),
368 JsonMatch::MultipleMatches => Err(format!("I see more than 1 candidates that look like `{}`. I don't know which one to choose. Please give me just one `{}`.", self.type_name(), self.type_name())),
369 JsonMatch::Match(s) => Ok(s.to_string()),
370 JsonMatch::ExpectedIntegerGotFloat(s) => Err(format!("I want an integer, but I can only find a float literal: `{s}`. Could you give me an integer literal?")),
371 }
372 },
373 }
374 }
375
376 pub fn default_integer() -> Self {
377 Schema {
378 r#type: SchemaType::Integer,
379 constraint: None,
380 }
381 }
382
383 pub fn integer_between(min: Option<i128>, max: Option<i128>) -> Self {
385 Schema {
386 r#type: SchemaType::Integer,
387 constraint: Some(Constraint {
388 min: min.map(|n| n.to_string()),
389 max: max.map(|n| n.to_string()),
390 }),
391 }
392 }
393
394 pub fn default_float() -> Self {
395 Schema {
396 r#type: SchemaType::Float,
397 constraint: None,
398 }
399 }
400
401 pub fn default_string() -> Self {
402 Schema {
403 r#type: SchemaType::String,
404 constraint: None,
405 }
406 }
407
408 pub fn string_length_between(min: Option<usize>, max: Option<usize>) -> Self {
410 Schema {
411 r#type: SchemaType::String,
412 constraint: Some(Constraint {
413 min: min.map(|n| n.to_string()),
414 max: max.map(|n| n.to_string()),
415 }),
416 }
417 }
418
419 pub fn default_array(r#type: Option<Schema>) -> Self {
420 Schema {
421 r#type: SchemaType::Array(r#type.map(|t| Box::new(t))),
422 constraint: None,
423 }
424 }
425
426 pub fn default_boolean() -> Self {
427 Schema {
428 r#type: SchemaType::Boolean,
429 constraint: None,
430 }
431 }
432
433 pub fn default_yesno() -> Self {
434 Schema {
435 r#type: SchemaType::Yesno,
436 constraint: None,
437 }
438 }
439
440 pub fn default_code() -> Self {
441 Schema {
442 r#type: SchemaType::Code,
443 constraint: None,
444 }
445 }
446
447 pub fn default_task_list() -> Self {
448 Schema {
449 r#type: SchemaType::TaskList,
450 constraint: None,
451 }
452 }
453
454 pub fn add_constraint(&mut self, constraint: Constraint) {
455 debug_assert!(self.constraint.is_none());
456 self.constraint = Some(constraint);
457 }
458
459 pub fn validate_constraint(&self) -> Result<(), SchemaParseError> {
460 match (&self.r#type, &self.constraint) {
461 (ty @ (SchemaType::Integer | SchemaType::Array(_) | SchemaType::String | SchemaType::TaskList | SchemaType::Code), Some(constraint)) => {
462 let mut min_ = i64::MIN;
463 let mut max_ = i64::MAX;
464
465 if let Some(min) = &constraint.min {
466 match min.parse::<i64>() {
467 Ok(n) => { min_ = n; },
468 Err(_) => {
469 return Err(SchemaParseError::InvalidConstraint(format!("{min:?} is not a valid integer.")));
470 },
471 }
472 }
473
474 if let Some(max) = &constraint.max {
475 match max.parse::<i64>() {
476 Ok(n) => { max_ = n; },
477 Err(_) => {
478 return Err(SchemaParseError::InvalidConstraint(format!("{max:?} is not a valid integer.")));
479 },
480 }
481 }
482
483 if min_ > max_ {
484 return Err(SchemaParseError::InvalidConstraint(format!("`min` ({min_}) is greater than `max` ({max_}).")));
485 }
486
487 if matches!(ty, SchemaType::String) || matches!(ty, SchemaType::Array(_)) {
488 if constraint.min.is_some() && min_ < 0 {
489 return Err(SchemaParseError::InvalidConstraint(format!("`min` is supposed to be a positive integer, but is {min_}")));
490 }
491
492 if constraint.max.is_some() && max_ < 0 {
493 return Err(SchemaParseError::InvalidConstraint(format!("`max` is supposed to be a positive integer, but is {max_}")));
494 }
495 }
496
497 Ok(())
498 },
499 (SchemaType::Float, Some(constraint)) => {
500 let mut min_ = f64::MIN;
501 let mut max_ = f64::MAX;
502
503 if let Some(min) = &constraint.min {
504 match min.parse::<f64>() {
505 Ok(n) => { min_ = n; },
506 Err(_) => {
507 return Err(SchemaParseError::InvalidConstraint(format!("{min:?} is not a valid number.")));
508 },
509 }
510 }
511
512 if let Some(max) = &constraint.max {
513 match max.parse::<f64>() {
514 Ok(n) => { max_ = n; },
515 Err(_) => {
516 return Err(SchemaParseError::InvalidConstraint(format!("{max:?} is not a valid number.")));
517 },
518 }
519 }
520
521 if min_ > max_ {
522 return Err(SchemaParseError::InvalidConstraint(format!("`min` ({min_}) is greater than `max` ({max_}).")));
523 }
524
525 Ok(())
526 },
527 (ty @ (SchemaType::Null | SchemaType::Boolean | SchemaType::Object(_) | SchemaType::Yesno), Some(constraint)) => {
528 if constraint.min.is_some() {
529 Err(SchemaParseError::InvalidConstraint(format!(
530 "Type `{}` cannot have constraint `min`",
531 ty.type_name(),
532 )))
533 }
534
535 else if constraint.max.is_some() {
536 Err(SchemaParseError::InvalidConstraint(format!(
537 "Type `{}` cannot have constraint `max`",
538 ty.type_name(),
539 )))
540 }
541
542 else {
543 Ok(())
544 }
545 },
546 (_, None) => Ok(()),
547 }
548 }
549
550 pub fn type_name(&self) -> &'static str {
551 self.r#type.type_name()
552 }
553
554 pub fn unwrap_keys(&self) -> Vec<String> {
555 self.r#type.unwrap_keys()
556 }
557
558 fn one_word(&self) -> &'static str {
561 self.r#type.one_word()
562 }
563}
564
565pub fn render_pdl_schema(
573 schema: &Schema,
574
575 value: &Value,
577) -> Result<String, Error> {
578 let s = match (&schema.r#type, value) {
579 (SchemaType::Code, Value::String(s)) => s.to_string(),
580 (SchemaType::TaskList, Value::String(s)) => s.to_string(),
581 (SchemaType::Yesno, Value::Bool(b)) => if *b {
582 String::from("yes")
583 } else {
584 String::from("no")
585 },
586 _ => serde_json::to_string_pretty(value)?,
587 };
588
589 Ok(s)
590}
591
/// Inclusive lower and upper bounds attached to a [`Schema`].
///
/// Bounds are stored as raw strings and parsed into the measured quantity's type when they
/// are checked.
#[derive(Default, PartialEq, Clone, Debug)]
pub struct Constraint {
    /// Inclusive lower bound, or `None` for no lower bound.
    min: Option<String>,
    /// Inclusive upper bound, or `None` for no upper bound.
    max: Option<String>,
}
601
602fn get_schema_type(v: &Value) -> SchemaType {
603 match v {
604 Value::Number(n) => {
605 if n.is_i64() {
606 SchemaType::Integer
607 }
608
609 else {
610 SchemaType::Float
611 }
612 },
613 Value::String(_) => SchemaType::String,
614 Value::Array(_) => SchemaType::Array(None),
615 Value::Object(_) => SchemaType::Object(vec![]),
616 Value::Bool(_) => SchemaType::Boolean,
617 Value::Null => SchemaType::Null,
618 }
619}
620
621fn check_range<T: PartialOrd + FromStr + ToString + Display>(schema: SchemaType, constraint: &Option<Constraint>, n: T) -> Result<(), SchemaError> where <T as FromStr>::Err: Debug {
622 if let Some(constraint) = constraint {
624 if let Constraint { min: Some(min), .. } = &constraint {
625 let min = min.parse::<T>().unwrap();
626
627 if n < min {
628 return Err(SchemaError::RangeError {
629 s1: String::from(if schema.is_number() { "small" } else { "short" }),
630 s2: String::from(if schema.is_number() { "is at least" } else { "has at least" }),
631 s3: if schema.is_number() { min.to_string() } else if schema.is_array() { format!("{min} elements") } else { format!("{min} characters") },
632 s4: if schema.is_number() { None } else if schema.is_array() { Some(format!("{n} elements")) } else { Some(format!("{n} characters")) },
633 });
634 }
635 }
636
637 if let Constraint { max: Some(max), .. } = &constraint {
638 let max = max.parse::<T>().unwrap();
639
640 if n > max {
641 return Err(SchemaError::RangeError {
642 s1: String::from(if schema.is_number() { "big" } else { "long" }),
643 s2: String::from(if schema.is_number() { "is at most" } else { "has at most" }),
644 s3: if schema.is_number() { max.to_string() } else if schema.is_array() { format!("{max} elements") } else { format!("{max} characters") },
645 s4: if schema.is_number() { None } else if schema.is_array() { Some(format!("{n} elements")) } else { Some(format!("{n} characters")) },
646 });
647 }
648 }
649 }
650
651 Ok(())
652}