use async_trait::async_trait;
use futures_util::Stream;
use std::pin::Pin;
use crate::core::runnables::{Runnable, RunnableConfig};
use super::base::{BaseOutputParser, OutputParserError, OutputParserResult};
/// Output parser that turns raw model text into a `serde_json::Value`.
///
/// Derives `Debug` and `Clone` so the parser can be logged, inspected in
/// assertions, and duplicated like other plain configuration values.
#[derive(Debug, Clone)]
pub struct JsonOutputParser {
    /// When `true`, parsing tolerates truncated JSON by attempting a repair
    /// before giving up; when `false`, the input must be complete JSON.
    partial: bool,
}
impl JsonOutputParser {
    /// Creates a strict parser: the input must contain complete JSON.
    pub fn new() -> Self {
        Self { partial: false }
    }

    /// Creates a lenient parser (`partial = true`) for possibly truncated JSON.
    pub fn new_partial() -> Self {
        Self { partial: true }
    }

    /// Returns the slice of `text` that should contain the JSON payload.
    ///
    /// Lookup order:
    /// 1. a fence opened with "```json" — the content up to the next "```";
    /// 2. any other "```" fence — the content after the fence's first line
    ///    (the optional language tag) up to the next "```";
    /// 3. otherwise the whole trimmed input.
    ///
    /// A fence without a closing "```" is not treated as a fence. The result
    /// is always trimmed. This function currently never returns `Err`; the
    /// `Result` is kept so callers can use `?`.
    fn extract_json_str<'a>(&self, text: &'a str) -> OutputParserResult<&'a str> {
        let text = text.trim();
        if let Some(start) = text.find("```json") {
            // 7 == "```json".len()
            let content = &text[start + 7..];
            if let Some(end) = content.find("```") {
                return Ok(content[..end].trim());
            }
        }
        if let Some(start) = text.find("```") {
            let content = &text[start + 3..];
            // Everything up to the first newline is the fence's language tag.
            // The content must NOT be trimmed before this search: for a fence
            // without a tag, trimming removes the newline right after "```",
            // and the search would then skip the first line of the JSON itself.
            let body_start = content.find('\n').unwrap_or(0);
            let content = &content[body_start..];
            if let Some(end) = content.find("```") {
                return Ok(content[..end].trim());
            }
        }
        Ok(text)
    }
}
impl Default for JsonOutputParser {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl BaseOutputParser<serde_json::Value> for JsonOutputParser {
    /// Extracts the JSON payload from `text` and parses it.
    ///
    /// In partial mode the payload goes through `parse_partial_json`;
    /// otherwise it must be complete JSON.
    ///
    /// # Errors
    /// Returns `OutputParserError::JsonError` when the payload cannot be
    /// parsed. In strict mode the message carries the error position and a
    /// preview of at most 200 bytes of the payload.
    async fn parse(&self, text: &str) -> OutputParserResult<serde_json::Value> {
        let json_str = self.extract_json_str(text)?;
        if self.partial {
            self.parse_partial_json(json_str)
        } else {
            serde_json::from_str(json_str).map_err(|e| {
                // Slicing a `str` at a byte offset inside a multi-byte
                // character panics, so back off to the nearest char boundary
                // at or below 200 bytes (index 0 is always a boundary).
                let mut preview_end = json_str.len().min(200);
                while !json_str.is_char_boundary(preview_end) {
                    preview_end -= 1;
                }
                OutputParserError::JsonError(format!(
                    "JSON 解析失败(位置 {}:{}):{},输入:{}",
                    e.line(),
                    e.column(),
                    e,
                    &json_str[..preview_end]
                ))
            })
        }
    }

    /// Returns the prompt snippet asking the model to answer in valid JSON.
    fn get_format_instructions(&self) -> String {
        "请使用 JSON 格式输出,例如:{\"key\": \"value\"}。确保 JSON 是合法的。".to_string()
    }
}
impl JsonOutputParser {
    /// Parses `text` as JSON, falling back to a best-effort repair of
    /// truncated input.
    ///
    /// The text is first parsed as-is; only if that fails is the output of
    /// `repair_partial_json` parsed instead.
    ///
    /// # Errors
    /// Returns `OutputParserError::JsonError` with a preview of at most 200
    /// bytes of `text` when neither attempt yields valid JSON.
    fn parse_partial_json(&self, text: &str) -> OutputParserResult<serde_json::Value> {
        if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
            return Ok(value);
        }
        let repaired = self.repair_partial_json(text);
        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&repaired) {
            return Ok(value);
        }
        // Slicing a `str` at a byte offset inside a multi-byte character
        // panics, so back off to the nearest char boundary at or below 200.
        let mut preview_end = text.len().min(200);
        while !text.is_char_boundary(preview_end) {
            preview_end -= 1;
        }
        Err(OutputParserError::JsonError(format!(
            "部分 JSON 解析失败:{}",
            &text[..preview_end]
        )))
    }

    /// Builds a candidate completion of truncated JSON.
    ///
    /// The trimmed text is scanned once, left to right, tracking whether the
    /// cursor is inside a string literal (honouring backslash escapes) and
    /// which objects/arrays are still open. Then:
    /// - an unterminated string is closed with `"` (a dangling trailing
    ///   backslash is dropped first so it cannot escape the added quote);
    /// - otherwise a dangling trailing comma is removed;
    /// - the missing `}` / `]` are appended innermost-first.
    ///
    /// The result is only a candidate: inputs cut in the middle of a key, a
    /// literal, or a number may still be invalid, and the caller is expected
    /// to validate by parsing.
    fn repair_partial_json(&self, text: &str) -> String {
        let trimmed = text.trim();

        // Closers still owed, in opening order (the last entry is innermost).
        let mut pending_closers: Vec<char> = Vec::new();
        let mut in_string = false;
        let mut escaped = false;

        for c in trimmed.chars() {
            if in_string {
                if escaped {
                    escaped = false;
                } else if c == '\\' {
                    escaped = true;
                } else if c == '"' {
                    in_string = false;
                }
                // Braces and brackets inside a string are plain text.
                continue;
            }
            match c {
                '"' => in_string = true,
                '{' => pending_closers.push('}'),
                '[' => pending_closers.push(']'),
                '}' | ']' => {
                    // Only a matching closer balances an opener; a mismatched
                    // one means the input is malformed and is left untouched.
                    if pending_closers.last() == Some(&c) {
                        pending_closers.pop();
                    }
                }
                _ => {}
            }
        }

        let mut repaired = trimmed.to_string();
        if in_string {
            if escaped {
                repaired.pop();
            }
            repaired.push('"');
        } else if let Some(stripped) = Self::strip_incomplete_token(&repaired) {
            repaired = stripped;
        }
        repaired.extend(pending_closers.into_iter().rev());
        repaired
    }

    /// Removes a dangling trailing comma (and surrounding trailing
    /// whitespace) from `s`.
    ///
    /// Returns `Some(stripped)` when a comma was removed and `None` when `s`
    /// does not end with one. The caller must make sure `s` does not end
    /// inside a string literal, where a trailing comma is ordinary content.
    fn strip_incomplete_token(s: &str) -> Option<String> {
        s.trim_end()
            .strip_suffix(',')
            .map(|rest| rest.trim_end().to_string())
    }
}
#[async_trait]
impl Runnable<String, serde_json::Value> for JsonOutputParser {
type Error = OutputParserError;
async fn invoke(&self, input: String, _config: Option<RunnableConfig>) -> Result<serde_json::Value, Self::Error> {
self.parse(&input).await
}
async fn stream(
&self,
input: String,
_config: Option<RunnableConfig>,
) -> Result<Pin<Box<dyn Stream<Item = Result<serde_json::Value, Self::Error>> + Send>>, Self::Error> {
let result = self.parse(&input).await?;
let stream = futures_util::stream::once(async move { Ok(result) });
Ok(Box::pin(stream))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain JSON object is parsed field by field.
    #[tokio::test]
    async fn test_json_parser_standard_obj() {
        let raw = r#"{"name": "Rust", "year": 2015}"#;
        let value = JsonOutputParser::new().parse(raw).await.unwrap();
        assert_eq!(value["name"], "Rust");
        assert_eq!(value["year"], 2015);
    }

    /// JSON wrapped in a json-tagged markdown fence, with leading prose, is extracted.
    #[tokio::test]
    async fn test_json_parser_from_markdown_block() {
        let fenced = "以下是结果:\n```json\n{\"status\": \"ok\"}\n```\n";
        let value = JsonOutputParser::new().parse(fenced).await.unwrap();
        assert_eq!(value["status"], "ok");
    }

    /// Top-level arrays are accepted and indexable.
    #[tokio::test]
    async fn test_json_parser_array() {
        let value = JsonOutputParser::new().parse("[1, 2, 3]").await.unwrap();
        assert_eq!(value[0], 1);
        assert_eq!(value[2], 3);
    }

    /// Malformed JSON is reported as an error in strict mode.
    #[tokio::test]
    async fn test_json_parser_invalid_json() {
        let outcome = JsonOutputParser::new().parse("{invalid}").await;
        assert!(outcome.is_err());
    }

    /// The format instructions are non-empty.
    #[tokio::test]
    async fn test_json_parser_format_instructions() {
        let text = JsonOutputParser::new().get_format_instructions();
        assert!(!text.is_empty());
    }

    /// `Runnable::invoke` yields the same value as a direct parse would.
    #[tokio::test]
    async fn test_json_parser_invoke_runnable() {
        let payload = r#"{"key": "value"}"#.to_string();
        let value = JsonOutputParser::new().invoke(payload, None).await.unwrap();
        assert_eq!(value["key"], "value");
    }

    /// Partial mode still accepts already-complete JSON.
    #[tokio::test]
    async fn test_json_parser_partial_success() {
        let lenient = JsonOutputParser::new_partial();
        let value = lenient.parse(r#"{"a": 1}"#).await.unwrap();
        assert_eq!(value["a"], 1);
    }
}