Skip to main content

rustia/
lib.rs

1#![forbid(unsafe_code)]
2
3//! Serde-based LLM JSON utilities and data trait for `rustia`.
4
5use std::collections::HashSet;
6
7use serde::de::DeserializeOwned;
8use serde_json::Value;
9
10mod lenient_json;
11mod validate;
12
13pub use lenient_json::parse_lenient_json_value;
14#[cfg(feature = "derive")]
15pub use rustia_macros::LLMData;
16pub use serde;
17pub use serde_json;
18pub use validate::{
19    IValidation, IValidationError, TagRuntime, Validate, apply_tags, join_index_path,
20    join_object_path, merge_prefixed_errors, prepend_path,
21};
22
/// Detailed parsing error emitted by the lenient parser or by validation.
///
/// Errors from both stages are collected into the same list, so callers can
/// report every problem with an LLM response in one pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LlmJsonParseError {
    /// Location of the problem. For validation errors this is the
    /// `$input`-rooted path reported by [`Validate`] (e.g. `$input.items[0]`);
    /// lenient-parser errors supply their own path string.
    pub path: String,
    /// What was expected at `path` (presumably a type or constraint
    /// description produced by the reporting stage — confirm per source).
    pub expected: String,
    /// Human-readable explanation. Validation errors that carry no
    /// description are reported as `"validation failed"`.
    pub description: String,
}
30
31/// Result of LLM JSON parsing.
32#[derive(Debug, Clone, PartialEq)]
33pub enum LlmJsonParseResult<T> {
34    Success {
35        data: T,
36    },
37    Failure {
38        data: Option<serde_json::Value>,
39        input: String,
40        errors: Vec<LlmJsonParseError>,
41    },
42}
43
44/// Trait for Serde-powered LLM data parsing/validation/stringification.
45pub trait LLMData: Validate + serde::Serialize + DeserializeOwned + Sized {
46    /// Parse raw LLM output using rustia's lenient JSON parser and then
47    /// validate the shape through serde deserialization.
48    fn parse(input: &str) -> LlmJsonParseResult<Self> {
49        match parse_lenient_json_value(input) {
50            LlmJsonParseResult::Success { data } => {
51                match validate_with_parse_coercion::<Self>(data) {
52                    CoercionValidation::Success { data, .. } => {
53                        LlmJsonParseResult::Success { data }
54                    }
55                    CoercionValidation::Failure { value, errors } => LlmJsonParseResult::Failure {
56                        data: Some(value),
57                        input: input.to_owned(),
58                        errors: map_validation_errors(errors),
59                    },
60                }
61            }
62            LlmJsonParseResult::Failure {
63                data,
64                input,
65                mut errors,
66            } => {
67                let data = data.map(|value| match validate_with_parse_coercion::<Self>(value) {
68                    CoercionValidation::Success { value, .. } => value,
69                    CoercionValidation::Failure {
70                        value: coerced,
71                        errors: validation_errors,
72                    } => {
73                        errors.extend(map_validation_errors(validation_errors));
74                        coerced
75                    }
76                });
77
78                LlmJsonParseResult::Failure {
79                    data,
80                    input,
81                    errors,
82                }
83            }
84        }
85    }
86
87    /// Serialize into compact JSON text.
88    fn stringify(&self) -> Result<String, serde_json::Error> {
89        serde_json::to_string(self)
90    }
91}
92
93#[doc(hidden)]
94pub mod __private {
95    pub use crate::validate::__private::*;
96}
97
/// Upper bound on coercion passes, used both for re-validation rounds and for
/// unwrapping nested string encodings, so malformed input cannot loop forever.
const MAX_PARSE_COERCION_ROUNDS: usize = 16;
99
100enum CoercionValidation<T> {
101    Success {
102        data: T,
103        value: Value,
104    },
105    Failure {
106        value: Value,
107        errors: Vec<IValidationError>,
108    },
109}
110
/// One step of a parsed `$input...` validation path.
#[derive(Debug, Clone, PartialEq, Eq)]
enum JsonPathSegment {
    /// Object member lookup by key.
    Key(String),
    /// Array element lookup by position.
    Index(usize),
}
115
116fn validate_with_parse_coercion<T>(mut value: Value) -> CoercionValidation<T>
117where
118    T: Validate,
119{
120    for _ in 0..MAX_PARSE_COERCION_ROUNDS {
121        match T::validate(value.clone()) {
122            IValidation::Success { data } => return CoercionValidation::Success { data, value },
123            IValidation::Failure { errors, .. } => {
124                if !coerce_value_from_errors(&mut value, &errors) {
125                    return CoercionValidation::Failure { value, errors };
126                }
127            }
128        }
129    }
130
131    match T::validate(value.clone()) {
132        IValidation::Success { data } => CoercionValidation::Success { data, value },
133        IValidation::Failure { errors, .. } => CoercionValidation::Failure { value, errors },
134    }
135}
136
137fn coerce_value_from_errors(value: &mut Value, errors: &[IValidationError]) -> bool {
138    let mut changed = false;
139    let mut seen = HashSet::new();
140
141    for error in errors {
142        if !seen.insert(error.path.clone()) {
143            continue;
144        }
145        let Some(path) = parse_validation_path(&error.path) else {
146            continue;
147        };
148        if coerce_stringified_path(value, &path) {
149            changed = true;
150        }
151    }
152
153    changed
154}
155
156fn coerce_stringified_path(root: &mut Value, path: &[JsonPathSegment]) -> bool {
157    let Some(target) = value_mut_on_path(root, path) else {
158        return false;
159    };
160
161    let raw = match target {
162        Value::String(raw) => raw.clone(),
163        _ => return false,
164    };
165
166    let Some(coerced) = parse_stringified_non_string(&raw) else {
167        return false;
168    };
169    *target = coerced;
170    true
171}
172
173fn parse_stringified_non_string(raw: &str) -> Option<Value> {
174    let mut cursor = raw.to_owned();
175    for _ in 0..MAX_PARSE_COERCION_ROUNDS {
176        let LlmJsonParseResult::Success { data } = parse_lenient_json_value(&cursor) else {
177            return None;
178        };
179        match data {
180            Value::String(next) => {
181                if next == cursor {
182                    return None;
183                }
184                cursor = next;
185            }
186            other => return Some(other),
187        }
188    }
189    None
190}
191
192fn value_mut_on_path<'a>(root: &'a mut Value, path: &[JsonPathSegment]) -> Option<&'a mut Value> {
193    let mut cursor = root;
194
195    for segment in path {
196        match segment {
197            JsonPathSegment::Key(key) => {
198                cursor = cursor.as_object_mut()?.get_mut(key)?;
199            }
200            JsonPathSegment::Index(index) => {
201                cursor = cursor.as_array_mut()?.get_mut(*index)?;
202            }
203        }
204    }
205
206    Some(cursor)
207}
208
209fn parse_validation_path(path: &str) -> Option<Vec<JsonPathSegment>> {
210    let mut chars = path.chars().peekable();
211
212    for expected in "$input".chars() {
213        if chars.next()? != expected {
214            return None;
215        }
216    }
217
218    let mut output = Vec::new();
219
220    while let Some(ch) = chars.peek().copied() {
221        match ch {
222            '.' => {
223                chars.next();
224                let mut key = String::new();
225                while let Some(next) = chars.peek().copied() {
226                    if matches!(next, '.' | '[') {
227                        break;
228                    }
229                    key.push(next);
230                    chars.next();
231                }
232                if key.is_empty() {
233                    continue;
234                }
235                output.push(JsonPathSegment::Key(key));
236            }
237            '[' => {
238                chars.next();
239                match chars.peek().copied() {
240                    Some('"') => {
241                        chars.next();
242                        let mut key = String::new();
243                        let mut escaped = false;
244
245                        for next in chars.by_ref() {
246                            if escaped {
247                                key.push(next);
248                                escaped = false;
249                                continue;
250                            }
251                            if next == '\\' {
252                                escaped = true;
253                                continue;
254                            }
255                            if next == '"' {
256                                break;
257                            }
258                            key.push(next);
259                        }
260
261                        if escaped || chars.next() != Some(']') {
262                            return None;
263                        }
264                        output.push(JsonPathSegment::Key(key));
265                    }
266                    Some(next) if next.is_ascii_digit() => {
267                        let mut digits = String::new();
268                        while let Some(digit) = chars.peek().copied() {
269                            if !digit.is_ascii_digit() {
270                                break;
271                            }
272                            digits.push(digit);
273                            chars.next();
274                        }
275                        if chars.next() != Some(']') {
276                            return None;
277                        }
278                        let index = digits.parse::<usize>().ok()?;
279                        output.push(JsonPathSegment::Index(index));
280                    }
281                    _ => return None,
282                }
283            }
284            _ => return None,
285        }
286    }
287
288    Some(output)
289}
290
291fn map_validation_errors(errors: Vec<IValidationError>) -> Vec<LlmJsonParseError> {
292    errors
293        .into_iter()
294        .map(|error| LlmJsonParseError {
295            path: error.path,
296            expected: error.expected,
297            description: error
298                .description
299                .unwrap_or_else(|| "validation failed".to_owned()),
300        })
301        .collect()
302}