Skip to main content

cognis_core/output_parsers/json.rs

1//! Typed JSON parser — deserializes LLM output into any `DeserializeOwned`.
2
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::de::DeserializeOwned;
7
8use crate::output_parsers::OutputParser;
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::{CognisError, Result};
11
/// Parses JSON output into a typed value `T`.
///
/// Tolerates code-fenced output (` ```json ... ``` ` or plain ` ``` ... ``` `)
/// by stripping the fences before parsing. Tolerates leading/trailing
/// whitespace.
///
/// The parser stores no `T`: `PhantomData<fn() -> T>` only records the target
/// type, which keeps the parser `Send + Sync` whatever `T` is. For the same
/// reason `Debug`, `Clone` and `Copy` are implemented by hand below instead of
/// derived: `#[derive]` would add `T: Debug` / `T: Clone` / `T: Copy` bounds,
/// so e.g. a `JsonParser<MyStruct>` could not be cloned unless `MyStruct` could.
pub struct JsonParser<T> {
    _t: PhantomData<fn() -> T>,
}

impl<T> std::fmt::Debug for JsonParser<T> {
    // Same output shape as the previously derived impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JsonParser").field("_t", &self._t).finish()
    }
}

impl<T> Clone for JsonParser<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for JsonParser<T> {}
21
22impl<T> Default for JsonParser<T> {
23    fn default() -> Self {
24        Self { _t: PhantomData }
25    }
26}
27
28impl<T> JsonParser<T> {
29    /// Construct a `JsonParser<T>`.
30    pub fn new() -> Self {
31        Self::default()
32    }
33}
34
35impl<T> OutputParser<T> for JsonParser<T>
36where
37    T: DeserializeOwned + Send + 'static,
38{
39    fn parse(&self, text: &str) -> Result<T> {
40        let cleaned = strip_code_fence(text.trim());
41        serde_json::from_str(cleaned)
42            .map_err(|e| CognisError::Serialization(format!("json parse: {e}")))
43    }
44
45    fn format_instructions(&self) -> Option<String> {
46        Some(
47            "Reply with a single JSON object. Do not include any text before \
48             or after the JSON. Do not wrap the JSON in markdown code fences."
49                .to_string(),
50        )
51    }
52}
53
54#[async_trait]
55impl<T> Runnable<String, T> for JsonParser<T>
56where
57    T: DeserializeOwned + Send + 'static,
58{
59    async fn invoke(&self, input: String, _: RunnableConfig) -> Result<T> {
60        OutputParser::parse(self, &input)
61    }
62    fn name(&self) -> &str {
63        "JsonParser"
64    }
65}
66
/// Removes surrounding whitespace and an optional Markdown code fence.
///
/// Handles ` ```json ... ``` ` and plain ` ``` ... ``` `. The `json` tag is
/// matched ASCII case-insensitively, so ` ```JSON ` also works. Input without
/// an opening fence is only trimmed, except that a trailing ` ``` ` is still
/// removed. Returns a sub-slice of `s`; nothing is allocated.
fn strip_code_fence(s: &str) -> &str {
    const FENCE: &str = "```";
    const LANG_TAG: &str = "json";

    let s = s.trim();
    let s = match s.strip_prefix(FENCE) {
        // Drop the info-string tag when it is `json` in any ASCII case.
        // `get` yields `None` when `rest` is too short or the cut would split a
        // multi-byte character, so the re-slice below is always on a boundary.
        Some(rest) => match rest.get(..LANG_TAG.len()) {
            Some(tag) if tag.eq_ignore_ascii_case(LANG_TAG) => &rest[LANG_TAG.len()..],
            _ => rest,
        },
        None => s,
    };
    let s = s.trim_start();
    s.strip_suffix(FENCE).unwrap_or(s).trim_end()
}
76
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    // Sample target type used by most tests.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Foo {
        bar: String,
        baz: u32,
    }

    #[tokio::test]
    async fn parses_plain_json() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p
            .invoke(r#"{"bar":"hi","baz":7}"#.into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(
            out,
            Foo {
                bar: "hi".into(),
                baz: 7
            }
        );
    }

    #[test]
    fn strips_json_code_fence() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("```json\n{\"bar\":\"hi\",\"baz\":7}\n```").unwrap();
        assert_eq!(out.baz, 7);
    }

    #[test]
    fn strips_plain_code_fence() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("```\n{\"bar\":\"hi\",\"baz\":1}\n```").unwrap();
        assert_eq!(out.baz, 1);
    }

    // A fence with no newlines: ```json {...}```
    #[test]
    fn strips_fence_on_a_single_line() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("```json {\"bar\":\"hi\",\"baz\":4}```").unwrap();
        assert_eq!(out.baz, 4);
    }

    #[test]
    fn tolerates_surrounding_whitespace() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("  \n\t{\"bar\":\"hi\",\"baz\":3}\n  ").unwrap();
        assert_eq!(out.baz, 3);
    }

    // `T` is any `DeserializeOwned`, not only structs/objects.
    #[test]
    fn parses_non_object_targets() {
        let p: JsonParser<Vec<u32>> = JsonParser::new();
        assert_eq!(p.parse("```\n[1, 2, 3]\n```").unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn invalid_json_errors_with_serialization_kind() {
        let p: JsonParser<Foo> = JsonParser::new();
        let err = p.parse("not json").unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    // Syntactically valid JSON that does not match `T` is also a serialization error.
    #[test]
    fn missing_field_errors_with_serialization_kind() {
        let p: JsonParser<Foo> = JsonParser::new();
        let err = p.parse(r#"{"bar":"hi"}"#).unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    #[test]
    fn format_instructions_set() {
        let p: JsonParser<Foo> = JsonParser::new();
        assert!(OutputParser::format_instructions(&p).is_some());
    }

    #[test]
    fn runnable_name() {
        let p: JsonParser<Foo> = JsonParser::new();
        assert_eq!(
            <JsonParser<Foo> as Runnable<String, Foo>>::name(&p),
            "JsonParser"
        );
    }
}