cognis_core/output_parsers/
json.rs1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use serde::de::DeserializeOwned;
7
8use crate::output_parsers::OutputParser;
9use crate::runnable::{Runnable, RunnableConfig};
10use crate::{CognisError, Result};
11
/// Output parser that deserializes a model's text reply into `T` using `serde_json`.
///
/// The parser is stateless: `T` is only a type-level marker selecting the deserialization
/// target. Before parsing, surrounding whitespace and an optional markdown code fence are
/// stripped from the input.
///
/// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>` because the parser
/// never owns a `T`. Function pointers are always `Send + Sync`, so the parser is too,
/// whatever `T` is.
pub struct JsonParser<T> {
    _t: PhantomData<fn() -> T>,
}

// `Clone`, `Copy` and `Debug` are implemented by hand on purpose: `#[derive]` would add
// `T: Clone`, `T: Copy` and `T: Debug` bounds, so e.g. `JsonParser<String>` would not be
// `Copy` even though the parser holds no `T` value at all.
impl<T> Clone for JsonParser<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for JsonParser<T> {}

impl<T> std::fmt::Debug for JsonParser<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived output; `PhantomData` is `Debug` for every `T`.
        f.debug_struct("JsonParser").field("_t", &self._t).finish()
    }
}
21
22impl<T> Default for JsonParser<T> {
23 fn default() -> Self {
24 Self { _t: PhantomData }
25 }
26}
27
28impl<T> JsonParser<T> {
29 pub fn new() -> Self {
31 Self::default()
32 }
33}
34
35impl<T> OutputParser<T> for JsonParser<T>
36where
37 T: DeserializeOwned + Send + 'static,
38{
39 fn parse(&self, text: &str) -> Result<T> {
40 let cleaned = strip_code_fence(text.trim());
41 serde_json::from_str(cleaned)
42 .map_err(|e| CognisError::Serialization(format!("json parse: {e}")))
43 }
44
45 fn format_instructions(&self) -> Option<String> {
46 Some(
47 "Reply with a single JSON object. Do not include any text before \
48 or after the JSON. Do not wrap the JSON in markdown code fences."
49 .to_string(),
50 )
51 }
52}
53
54#[async_trait]
55impl<T> Runnable<String, T> for JsonParser<T>
56where
57 T: DeserializeOwned + Send + 'static,
58{
59 async fn invoke(&self, input: String, _: RunnableConfig) -> Result<T> {
60 OutputParser::parse(self, &input)
61 }
62 fn name(&self) -> &str {
63 "JsonParser"
64 }
65}
66
/// Strips surrounding whitespace and an optional markdown code fence from `s`.
///
/// Accepted shapes (after trimming):
/// - a plain payload, returned unchanged;
/// - a triple-backtick fence whose language tag sits on its own first line, with any tag
///   accepted (`json`, `JSON`, `json5`, `jsonc`, ...);
/// - a triple-backtick fence with no tag;
/// - a triple-backtick fence with a `json` tag (any case) glued directly onto the payload
///   on the same line.
///
/// A trailing triple backtick is removed even without an opening fence. Text outside the
/// fence (prose before or after it) is not removed.
fn strip_code_fence(s: &str) -> &str {
    let s = s.trim();
    let body = match s.strip_prefix("```") {
        Some(after_ticks) => skip_info_string(after_ticks),
        None => s,
    };
    let body = body.trim_start();
    body.strip_suffix("```").unwrap_or(body).trim_end()
}

/// Drops the info string (language tag) that may follow an opening code fence.
///
/// `rest` is the text immediately after the opening backticks.
fn skip_info_string(rest: &str) -> &str {
    // Usual shape: the tag is alone on the first line, e.g. "json\n{...}".
    if let Some((first_line, after)) = rest.split_once('\n') {
        if is_language_tag(first_line.trim()) {
            return after;
        }
    }
    // Also accept a `json` tag glued onto the payload on the same line, e.g. "json{...}".
    // `get` returns `None` when byte 4 is not a char boundary, so the slice below is safe.
    match rest.get(..4) {
        Some(tag) if tag.eq_ignore_ascii_case("json") => &rest[4..],
        _ => rest,
    }
}

/// Whether `candidate` looks like a code-fence language tag rather than JSON content.
///
/// The JSON literals `true`, `false` and `null` are rejected, so a fenced bare literal such
/// as "```true\n```" is still kept as the payload.
fn is_language_tag(candidate: &str) -> bool {
    let mut chars = candidate.chars();
    let starts_with_letter = matches!(chars.next(), Some(c) if c.is_ascii_alphabetic());
    starts_with_letter
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.'))
        && !matches!(candidate, "true" | "false" | "null")
}
76
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Foo {
        bar: String,
        baz: u32,
    }

    /// Unfenced JSON goes through the async `Runnable` entry point.
    #[tokio::test]
    async fn parses_plain_json() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p
            .invoke(r#"{"bar":"hi","baz":7}"#.into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(
            out,
            Foo {
                bar: "hi".into(),
                baz: 7
            }
        );
    }

    #[test]
    fn strips_json_code_fence() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("```json\n{\"bar\":\"hi\",\"baz\":7}\n```").unwrap();
        assert_eq!(out.baz, 7);
    }

    #[test]
    fn strips_plain_code_fence() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("```\n{\"bar\":\"hi\",\"baz\":1}\n```").unwrap();
        assert_eq!(out.baz, 1);
    }

    /// Whitespace (newlines, tabs) around an unfenced payload is ignored.
    #[test]
    fn tolerates_surrounding_whitespace() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse("  \n{\"bar\":\"hi\",\"baz\":2}\n\t ").unwrap();
        assert_eq!(out.baz, 2);
    }

    /// Backticks inside a JSON string value must not be treated as a code fence.
    #[test]
    fn keeps_backticks_inside_json_strings() {
        let p: JsonParser<Foo> = JsonParser::new();
        let out = p.parse(r#"{"bar":"```","baz":3}"#).unwrap();
        assert_eq!(out.bar, "```");
        assert_eq!(out.baz, 3);
    }

    #[test]
    fn invalid_json_errors_with_serialization_kind() {
        let p: JsonParser<Foo> = JsonParser::new();
        let err = p.parse("not json").unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    /// Well-formed JSON whose shape does not match `T` is also a `Serialization` error.
    #[test]
    fn type_mismatch_errors_with_serialization_kind() {
        let p: JsonParser<Foo> = JsonParser::new();
        let err = p.parse(r#"{"bar":"hi","baz":"seven"}"#).unwrap_err();
        assert!(matches!(err, CognisError::Serialization(_)));
    }

    #[test]
    fn format_instructions_set() {
        let p: JsonParser<Foo> = JsonParser::new();
        assert!(OutputParser::format_instructions(&p).is_some());
    }
}
130}