use crate::messages::AssistantMessage;
/// Characters that may legally follow a backslash inside a JSON string literal
/// (RFC 8259, section 7).
///
/// Note that `u` only forms a complete escape when it is followed by exactly four
/// hexadecimal digits; membership in this list alone does not check that.
const VALID_JSON_ESCAPES: &[char] = &[
    // Characters escaped as themselves.
    '"', '\\', '/',
    // Short forms for control characters.
    'b', 'f', 'n', 'r', 't',
    // Introduces a `\uXXXX` escape.
    'u',
];
/// Returns `true` for the C0 control characters U+0000..=U+001F.
///
/// These are exactly the characters JSON forbids from appearing unescaped inside a
/// string literal; DEL (U+007F) and everything above are allowed as-is.
fn is_control_character(ch: char) -> bool {
    matches!(ch, '\u{0000}'..='\u{001F}')
}
/// Renders `ch` as a JSON string escape sequence.
///
/// The five characters with a dedicated short form (backspace, form feed, newline,
/// carriage return, tab) use it. Every other character is written as `\u` followed
/// by its code point in lowercase hex, zero-padded to at least four digits.
fn escape_control_character(ch: char) -> String {
    let short_form = match ch {
        '\u{0008}' => "\\b",
        '\u{000C}' => "\\f",
        '\n' => "\\n",
        '\r' => "\\r",
        '\t' => "\\t",
        // No short form: fall back to the generic unicode escape.
        _ => return format!("\\u{:04x}", ch as u32),
    };
    String::from(short_form)
}
/// Repairs string-level defects in `json` so that a strict JSON parser can accept it.
///
/// Text outside string literals is copied through unchanged. Inside a string literal:
///
/// * raw control characters (U+0000..=U+001F) are replaced with their JSON escape;
/// * valid escape sequences, including `\uXXXX` with exactly four hex digits, are kept;
/// * any other backslash is doubled so that it becomes a literal backslash. This covers
///   an unknown escape such as `\q`, a `\u` not followed by four hex digits, and a
///   backslash at the very end of the input.
///
/// Structural problems (missing braces, trailing commas, ...) are not addressed, so the
/// result is not guaranteed to be valid JSON.
pub fn repair_json(json: &str) -> String {
    let chars: Vec<char> = json.chars().collect();
    let mut repaired = String::with_capacity(json.len());
    let mut in_string = false;
    let mut index = 0;

    while index < chars.len() {
        let ch = chars[index];

        if !in_string {
            repaired.push(ch);
            in_string = ch == '"';
            index += 1;
            continue;
        }

        // Number of input chars consumed by this step.
        let consumed = match ch {
            '"' => {
                repaired.push(ch);
                in_string = false;
                1
            }
            '\\' => match chars.get(index + 1).copied() {
                Some('u') => match chars.get(index + 2..index + 6) {
                    Some(digits) if digits.iter().all(char::is_ascii_hexdigit) => {
                        repaired.push_str("\\u");
                        repaired.extend(digits);
                        6
                    }
                    // A `\u` without four hex digits is rejected by strict parsers, so
                    // treat its backslash as a literal like any other unknown escape.
                    _ => {
                        repaired.push_str("\\\\");
                        1
                    }
                },
                Some(next) if VALID_JSON_ESCAPES.contains(&next) => {
                    repaired.push('\\');
                    repaired.push(next);
                    2
                }
                // Unknown escape or trailing backslash: escape the backslash itself and
                // let the following character (if any) be processed on the next step.
                _ => {
                    repaired.push_str("\\\\");
                    1
                }
            },
            _ if is_control_character(ch) => {
                repaired.push_str(&escape_control_character(ch));
                1
            }
            _ => {
                repaired.push(ch);
                1
            }
        };
        index += consumed;
    }

    repaired
}
/// Parses `json` strictly and, if that fails, retries once on the output of `repair_json`.
///
/// The retry only happens when repair actually changed the text; an unchanged string
/// would fail identically.
///
/// # Errors
///
/// Returns the error from the *first* (unrepaired) parse when both attempts fail, so
/// the error describes the caller's input rather than the repaired text.
pub fn parse_json_with_repair<T: serde::de::DeserializeOwned>(
    json: &str,
) -> Result<T, serde_json::Error> {
    let original_error = match serde_json::from_str::<T>(json) {
        Ok(parsed) => return Ok(parsed),
        Err(err) => err,
    };

    let repaired = repair_json(json);
    if repaired == json {
        return Err(original_error);
    }

    serde_json::from_str::<T>(&repaired).map_err(|_| original_error)
}
/// Best-effort parse of a JSON payload that may be malformed or carry trailing data.
///
/// This never fails. After trimming whitespace it tries, in order:
///
/// 1. a strict parse of the input;
/// 2. a strict parse of `repair_json`'s output (only if repair changed something);
/// 3. `parse_partial_json` on the input;
/// 4. `parse_partial_json` on the repaired input (again only if it differs).
///
/// It returns `T::default()` when the input is blank or every attempt fails.
pub fn parse_streaming_json<T: serde::de::DeserializeOwned + Default>(json: &str) -> T {
    let trimmed = json.trim();
    if trimmed.is_empty() {
        return T::default();
    }
    if let Ok(result) = serde_json::from_str(trimmed) {
        return result;
    }

    // Repair once and reuse it for both the strict and the partial retry; previously
    // the strict parse and the repair were each performed twice.
    let repaired = repair_json(trimmed);
    let was_repaired = repaired != trimmed;

    if was_repaired {
        if let Ok(result) = serde_json::from_str(&repaired) {
            return result;
        }
    }
    if let Some(result) = parse_partial_json(trimmed) {
        return result;
    }
    if was_repaired {
        if let Some(result) = parse_partial_json(&repaired) {
            return result;
        }
    }
    T::default()
}
/// Parses the first complete top-level JSON object or array at the start of `json`,
/// ignoring whatever follows it.
///
/// After trimming, the input must begin with `{` or `[`. Brackets are counted only
/// outside string literals, and backslash escapes are honoured so that `\"` does not
/// end a string. Counting stops as soon as the depth first returns to zero, and that
/// prefix is handed to `serde_json`. Bracket kinds are not matched against each other
/// here; a mismatched prefix is simply rejected by the final parse.
///
/// Returns `None` in three cases:
///
/// * the input does not start with a bracket;
/// * no balanced prefix exists (e.g. truncated input);
/// * the prefix fails to deserialize into `T`.
fn parse_partial_json<T: serde::de::DeserializeOwned>(json: &str) -> Option<T> {
    let trimmed = json.trim();
    if !trimmed.starts_with('{') && !trimmed.starts_with('[') {
        return None;
    }

    // The first byte is an opening bracket, so `depth` is >= 1 before any closing
    // bracket is seen. The scan stops as soon as depth returns to 0, so it never underflows.
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;
    let mut value_end = None;

    // Scanning bytes is sound for UTF-8: every byte of a multi-byte sequence is >= 0x80,
    // so it can never be mistaken for one of the ASCII delimiters matched below.
    for (i, &byte) in trimmed.as_bytes().iter().enumerate() {
        if in_string {
            if escaped {
                // The byte after a backslash is part of the escape, never a delimiter.
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                in_string = false;
            }
            continue;
        }
        match byte {
            b'"' => in_string = true,
            b'{' | b'[' => depth += 1,
            b'}' | b']' => {
                depth -= 1;
                if depth == 0 {
                    // The first top-level value is complete; anything after it is
                    // trailing data that a strict parse would reject.
                    value_end = Some(i);
                    break;
                }
            }
            _ => {}
        }
    }

    // `value_end` indexes an ASCII bracket, so slicing through it is on a char boundary.
    serde_json::from_str(&trimmed[..=value_end?]).ok()
}
/// Extracts and parses the payload of a server-sent-events `data:` line.
///
/// The line is trimmed first. Per the SSE format the space after `data:` is optional,
/// so both `data: {...}` and `data:{...}` are accepted.
///
/// Returns `None` in these cases:
///
/// * the line is not a `data:` line;
/// * the payload is empty;
/// * the payload is exactly `[DONE]`, presumably a provider end-of-stream marker.
///
/// Any other payload is handed to `parse_streaming_json`, which yields `T::default()`
/// when it cannot be parsed, so the result is then `Some(T::default())`.
pub fn parse_sse_data<T: serde::de::DeserializeOwned + Default>(line: &str) -> Option<T> {
    let data = line.trim().strip_prefix("data:")?.trim_start();
    if data.is_empty() || data == "[DONE]" {
        return None;
    }
    Some(parse_streaming_json(data))
}
/// Returns the error text stored on `message`, or `"Unknown error"` when none is set.
pub fn extract_error_message(message: &AssistantMessage) -> String {
    match message.error_message.as_deref() {
        Some(text) => text.to_owned(),
        None => String::from("Unknown error"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Minimal payload type used to exercise the generic parsing helpers.
    #[derive(Debug, Deserialize, PartialEq, Default)]
    struct TestObj {
        name: String,
        value: Option<i64>,
    }

    #[test]
    fn test_repair_json_valid() {
        let json = r#"{"name": "test"}"#;
        assert_eq!(repair_json(json), json);
    }

    #[test]
    fn test_repair_json_control_chars() {
        let json = "{\"name\": \"hello\nworld\"}";
        let repaired = repair_json(json);
        assert_eq!(repaired, r#"{"name": "hello\nworld"}"#);
        assert!(!repaired.contains("hello\nworld"));
    }

    #[test]
    fn test_repair_json_tab() {
        let json = "{\"name\": \"hello\tworld\"}";
        assert_eq!(repair_json(json), r#"{"name": "hello\tworld"}"#);
    }

    #[test]
    fn test_repair_json_invalid_escape() {
        let json = r#"{"name": "hello\qworld"}"#;
        // The stray backslash is doubled so that it becomes a literal backslash.
        assert_eq!(repair_json(json), r#"{"name": "hello\\qworld"}"#);
    }

    #[test]
    fn test_repair_json_trailing_backslash() {
        let json = r#"{"name": "test\"#;
        assert_eq!(repair_json(json), r#"{"name": "test\\"#);
    }

    #[test]
    fn test_repair_json_valid_escapes_preserved() {
        let json = r#"{"name": "hello\nworld"}"#;
        let repaired = repair_json(json);
        assert_eq!(repaired, json);
    }

    #[test]
    fn test_repair_json_unicode_escape_preserved() {
        let json = r#"{"name": "\u0041"}"#;
        let repaired = repair_json(json);
        assert_eq!(repaired, json);
    }

    #[test]
    fn test_repair_json_outside_string_untouched() {
        // Whitespace control characters between tokens are legal JSON and must be kept.
        let json = "{\n\t\"name\": \"x\"\n}";
        assert_eq!(repair_json(json), json);
    }

    #[test]
    fn test_parse_json_with_repair_valid() {
        let result: TestObj = parse_json_with_repair(r#"{"name": "test", "value": 42}"#).unwrap();
        assert_eq!(result.name, "test");
        assert_eq!(result.value, Some(42));
    }

    #[test]
    fn test_parse_json_with_repair_control_chars() {
        let json = "{\"name\": \"hello\nworld\"}";
        let result: TestObj = parse_json_with_repair(json).unwrap();
        assert_eq!(result.name, "hello\nworld");
    }

    #[test]
    fn test_parse_json_with_repair_invalid_escape() {
        let result: TestObj = parse_json_with_repair(r#"{"name": "a\qb"}"#).unwrap();
        assert_eq!(result.name, r"a\qb");
    }

    #[test]
    fn test_parse_json_with_repair_unrepairable() {
        let result: Result<TestObj, _> = parse_json_with_repair("not json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_streaming_json_valid() {
        let result: TestObj = parse_streaming_json(r#"{"name": "test"}"#);
        assert_eq!(result.name, "test");
    }

    #[test]
    fn test_parse_streaming_json_empty() {
        let result: TestObj = parse_streaming_json("");
        assert_eq!(result, TestObj::default());
    }

    #[test]
    fn test_parse_streaming_json_whitespace() {
        let result: TestObj = parse_streaming_json(" ");
        assert_eq!(result, TestObj::default());
    }

    #[test]
    fn test_parse_streaming_json_partial() {
        let result: TestObj = parse_streaming_json(r#"{"name": "test"}, "extra""#);
        assert_eq!(result.name, "test");
    }

    #[test]
    fn test_parse_streaming_json_garbage_returns_default() {
        let result: TestObj = parse_streaming_json("not json");
        assert_eq!(result, TestObj::default());
    }

    #[test]
    fn test_parse_sse_data_valid() {
        let result: TestObj = parse_sse_data(r#"data: {"name": "test"}"#).unwrap();
        assert_eq!(result.name, "test");
    }

    #[test]
    fn test_parse_sse_data_surrounding_whitespace() {
        let result: TestObj = parse_sse_data("  data: {\"name\": \"x\"}\r\n").unwrap();
        assert_eq!(result.name, "x");
    }

    #[test]
    fn test_parse_sse_data_done() {
        let result: Option<TestObj> = parse_sse_data("data: [DONE]");
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_sse_data_not_data_line() {
        let result: Option<TestObj> = parse_sse_data("event: message");
        assert!(result.is_none());
    }

    #[test]
    fn test_parse_sse_data_empty_data() {
        let result: Option<TestObj> = parse_sse_data("data: ");
        assert!(result.is_none());
    }

    #[test]
    fn test_escape_control_character_special() {
        assert_eq!(escape_control_character('\n'), "\\n");
        assert_eq!(escape_control_character('\r'), "\\r");
        assert_eq!(escape_control_character('\t'), "\\t");
        assert_eq!(escape_control_character('\u{0008}'), "\\b");
        assert_eq!(escape_control_character('\u{000C}'), "\\f");
    }

    #[test]
    fn test_escape_control_character_generic() {
        assert_eq!(escape_control_character('\u{0000}'), "\\u0000");
        assert_eq!(escape_control_character('\u{0001}'), "\\u0001");
        assert_eq!(escape_control_character('\u{001F}'), "\\u001f");
    }
}