1#![forbid(unsafe_code)]
2
3use std::collections::HashSet;
6
7use serde::de::DeserializeOwned;
8use serde_json::Value;
9
10mod lenient_json;
11mod validate;
12
13pub use lenient_json::parse_lenient_json_value;
14#[cfg(feature = "derive")]
15pub use rustia_macros::LLMData;
16pub use serde;
17pub use serde_json;
18pub use validate::{
19 IValidation, IValidationError, TagRuntime, Validate, apply_tags, join_index_path,
20 join_object_path, merge_prefixed_errors, prepend_path,
21};
22
/// A single problem reported while parsing or validating LLM-produced JSON.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmJsonParseError {
    /// Location of the offending value. Errors that originate from validation
    /// carry the validator's path string unchanged.
    pub path: String,
    /// What was expected at `path`. For validation errors this is copied
    /// verbatim from the validator's error.
    pub expected: String,
    /// Human-readable detail. Validation errors that come without a
    /// description are given the text `"validation failed"`.
    pub description: String,
}
30
/// Outcome of turning LLM-produced JSON text into a value of type `T`.
#[derive(Debug, Clone, PartialEq)]
pub enum LlmJsonParseResult<T> {
    /// The input produced a usable value.
    Success {
        /// The resulting value.
        data: T,
    },
    /// The input could not be turned into a valid `T`.
    Failure {
        /// Best-effort JSON value when one is available. When validation is
        /// what failed, this is the JSON value after any string coercions
        /// were applied.
        data: Option<serde_json::Value>,
        /// The input text this failure refers to.
        input: String,
        /// Every problem that was collected.
        errors: Vec<LlmJsonParseError>,
    },
}
43
44pub trait LLMData: Validate + serde::Serialize + DeserializeOwned + Sized {
46 fn parse(input: &str) -> LlmJsonParseResult<Self> {
49 match parse_lenient_json_value(input) {
50 LlmJsonParseResult::Success { data } => {
51 match validate_with_parse_coercion::<Self>(data) {
52 CoercionValidation::Success { data, .. } => {
53 LlmJsonParseResult::Success { data }
54 }
55 CoercionValidation::Failure { value, errors } => LlmJsonParseResult::Failure {
56 data: Some(value),
57 input: input.to_owned(),
58 errors: map_validation_errors(errors),
59 },
60 }
61 }
62 LlmJsonParseResult::Failure {
63 data,
64 input,
65 mut errors,
66 } => {
67 let data = data.map(|value| match validate_with_parse_coercion::<Self>(value) {
68 CoercionValidation::Success { value, .. } => value,
69 CoercionValidation::Failure {
70 value: coerced,
71 errors: validation_errors,
72 } => {
73 errors.extend(map_validation_errors(validation_errors));
74 coerced
75 }
76 });
77
78 LlmJsonParseResult::Failure {
79 data,
80 input,
81 errors,
82 }
83 }
84 }
85 }
86
87 fn stringify(&self) -> Result<String, serde_json::Error> {
89 serde_json::to_string(self)
90 }
91}
92
// Re-exports everything from `validate::__private` under `crate::__private`
// while keeping it out of the rendered documentation.
// NOTE(review): presumably this path is what generated (derive) code refers
// to — confirm before renaming or trimming it.
#[doc(hidden)]
pub mod __private {
    pub use crate::validate::__private::*;
}
97
/// Iteration cap shared by the validate-then-coerce loop in
/// `validate_with_parse_coercion` and the nested-string unwrapping loop in
/// `parse_stringified_non_string`, so neither can spin indefinitely.
const MAX_PARSE_COERCION_ROUNDS: usize = 16;
99
/// Outcome of `validate_with_parse_coercion`.
enum CoercionValidation<T> {
    /// Validation succeeded, possibly after some values were coerced.
    Success {
        /// The typed value returned by the validator.
        data: T,
        /// The JSON value (including any coercions) that was validated.
        value: Value,
    },
    /// Validation still failed and no further coercion could be applied.
    Failure {
        /// The JSON value including whatever coercions were applied.
        value: Value,
        /// Errors from the last validation attempt.
        errors: Vec<IValidationError>,
    },
}
110
/// One step of a parsed validation path: either a key into a JSON object or
/// a position in a JSON array.
///
/// The common value traits are derived so segments can be debug-printed and
/// compared (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
enum JsonPathSegment {
    /// Look up this key in an object.
    Key(String),
    /// Look up this position in an array.
    Index(usize),
}
115
116fn validate_with_parse_coercion<T>(mut value: Value) -> CoercionValidation<T>
117where
118 T: Validate,
119{
120 for _ in 0..MAX_PARSE_COERCION_ROUNDS {
121 match T::validate(value.clone()) {
122 IValidation::Success { data } => return CoercionValidation::Success { data, value },
123 IValidation::Failure { errors, .. } => {
124 if !coerce_value_from_errors(&mut value, &errors) {
125 return CoercionValidation::Failure { value, errors };
126 }
127 }
128 }
129 }
130
131 match T::validate(value.clone()) {
132 IValidation::Success { data } => CoercionValidation::Success { data, value },
133 IValidation::Failure { errors, .. } => CoercionValidation::Failure { value, errors },
134 }
135}
136
137fn coerce_value_from_errors(value: &mut Value, errors: &[IValidationError]) -> bool {
138 let mut changed = false;
139 let mut seen = HashSet::new();
140
141 for error in errors {
142 if !seen.insert(error.path.clone()) {
143 continue;
144 }
145 let Some(path) = parse_validation_path(&error.path) else {
146 continue;
147 };
148 if coerce_stringified_path(value, &path) {
149 changed = true;
150 }
151 }
152
153 changed
154}
155
156fn coerce_stringified_path(root: &mut Value, path: &[JsonPathSegment]) -> bool {
157 let Some(target) = value_mut_on_path(root, path) else {
158 return false;
159 };
160
161 let raw = match target {
162 Value::String(raw) => raw.clone(),
163 _ => return false,
164 };
165
166 let Some(coerced) = parse_stringified_non_string(&raw) else {
167 return false;
168 };
169 *target = coerced;
170 true
171}
172
173fn parse_stringified_non_string(raw: &str) -> Option<Value> {
174 let mut cursor = raw.to_owned();
175 for _ in 0..MAX_PARSE_COERCION_ROUNDS {
176 let LlmJsonParseResult::Success { data } = parse_lenient_json_value(&cursor) else {
177 return None;
178 };
179 match data {
180 Value::String(next) => {
181 if next == cursor {
182 return None;
183 }
184 cursor = next;
185 }
186 other => return Some(other),
187 }
188 }
189 None
190}
191
192fn value_mut_on_path<'a>(root: &'a mut Value, path: &[JsonPathSegment]) -> Option<&'a mut Value> {
193 let mut cursor = root;
194
195 for segment in path {
196 match segment {
197 JsonPathSegment::Key(key) => {
198 cursor = cursor.as_object_mut()?.get_mut(key)?;
199 }
200 JsonPathSegment::Index(index) => {
201 cursor = cursor.as_array_mut()?.get_mut(*index)?;
202 }
203 }
204 }
205
206 Some(cursor)
207}
208
209fn parse_validation_path(path: &str) -> Option<Vec<JsonPathSegment>> {
210 let mut chars = path.chars().peekable();
211
212 for expected in "$input".chars() {
213 if chars.next()? != expected {
214 return None;
215 }
216 }
217
218 let mut output = Vec::new();
219
220 while let Some(ch) = chars.peek().copied() {
221 match ch {
222 '.' => {
223 chars.next();
224 let mut key = String::new();
225 while let Some(next) = chars.peek().copied() {
226 if matches!(next, '.' | '[') {
227 break;
228 }
229 key.push(next);
230 chars.next();
231 }
232 if key.is_empty() {
233 continue;
234 }
235 output.push(JsonPathSegment::Key(key));
236 }
237 '[' => {
238 chars.next();
239 match chars.peek().copied() {
240 Some('"') => {
241 chars.next();
242 let mut key = String::new();
243 let mut escaped = false;
244
245 for next in chars.by_ref() {
246 if escaped {
247 key.push(next);
248 escaped = false;
249 continue;
250 }
251 if next == '\\' {
252 escaped = true;
253 continue;
254 }
255 if next == '"' {
256 break;
257 }
258 key.push(next);
259 }
260
261 if escaped || chars.next() != Some(']') {
262 return None;
263 }
264 output.push(JsonPathSegment::Key(key));
265 }
266 Some(next) if next.is_ascii_digit() => {
267 let mut digits = String::new();
268 while let Some(digit) = chars.peek().copied() {
269 if !digit.is_ascii_digit() {
270 break;
271 }
272 digits.push(digit);
273 chars.next();
274 }
275 if chars.next() != Some(']') {
276 return None;
277 }
278 let index = digits.parse::<usize>().ok()?;
279 output.push(JsonPathSegment::Index(index));
280 }
281 _ => return None,
282 }
283 }
284 _ => return None,
285 }
286 }
287
288 Some(output)
289}
290
291fn map_validation_errors(errors: Vec<IValidationError>) -> Vec<LlmJsonParseError> {
292 errors
293 .into_iter()
294 .map(|error| LlmJsonParseError {
295 path: error.path,
296 expected: error.expected,
297 description: error
298 .description
299 .unwrap_or_else(|| "validation failed".to_owned()),
300 })
301 .collect()
302}