use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use cognis_core::error::{CognisError, Result};
use cognis_core::language_models::BaseChatModel;
use cognis_core::messages::{HumanMessage, Message, SystemMessage};
use cognis_core::output_parsers::OutputParser;
use cognis_core::runnables::base::Runnable;
use cognis_core::runnables::config::RunnableConfig;
/// Output parser that wraps another parser and, when invoked as a `Runnable`,
/// asks a chat model to repair text that the wrapped parser rejects.
///
/// Calling `OutputParser::parse` directly only delegates to the wrapped parser;
/// the LLM repair step happens in `Runnable::invoke`.
pub struct OutputFixingParser {
    /// Parser used for the first attempt and again on the model's repaired output.
    parser: Box<dyn OutputParser>,
    /// Chat model asked to rewrite output that failed to parse.
    llm: Arc<dyn BaseChatModel>,
}
/// Step-by-step constructor for [`OutputFixingParser`].
///
/// Both fields start unset. `Default` yields the same empty builder that
/// `OutputFixingParser::builder()` returns. Set `parser` and `llm` before
/// calling `build`, which panics if either is missing.
#[derive(Default)]
pub struct OutputFixingParserBuilder {
    /// Parser whose failures should be repaired; required.
    parser: Option<Box<dyn OutputParser>>,
    /// Chat model used for repairs; required.
    llm: Option<Arc<dyn BaseChatModel>>,
}
impl OutputFixingParserBuilder {
    /// Sets the parser whose parse failures should be repaired (required).
    ///
    /// Replaces any parser set earlier.
    pub fn parser(mut self, parser: impl OutputParser + 'static) -> Self {
        let boxed: Box<dyn OutputParser> = Box::new(parser);
        self.parser = Some(boxed);
        self
    }

    /// Sets the chat model used to repair malformed output (required).
    ///
    /// Replaces any model set earlier.
    pub fn llm(mut self, llm: Arc<dyn BaseChatModel>) -> Self {
        self.llm = Some(llm);
        self
    }

    /// Consumes the builder and produces an [`OutputFixingParser`].
    ///
    /// # Panics
    ///
    /// Panics with `"parser is required"` if no parser was set, or with
    /// `"llm is required"` if no model was set. The parser is checked first.
    pub fn build(self) -> OutputFixingParser {
        let Self { parser, llm } = self;
        let parser = parser.expect("parser is required");
        let llm = llm.expect("llm is required");
        OutputFixingParser { parser, llm }
    }
}
impl OutputFixingParser {
    /// Returns an empty [`OutputFixingParserBuilder`].
    ///
    /// Both `parser` and `llm` must be set before calling `build`.
    pub fn builder() -> OutputFixingParserBuilder {
        OutputFixingParserBuilder {
            parser: None,
            llm: None,
        }
    }

    /// Creates a fixing parser that wraps `parser` and uses `llm` to repair
    /// output that `parser` rejects.
    pub fn new(parser: impl OutputParser + 'static, llm: Arc<dyn BaseChatModel>) -> Self {
        Self {
            parser: Box::new(parser),
            llm,
        }
    }

    /// Asks the LLM to repair `malformed` and parses its reply with the wrapped parser.
    ///
    /// Exactly one repair attempt is made. The prompt includes the original parse
    /// `error` and the wrapped parser's format instructions, if it has any.
    ///
    /// # Errors
    ///
    /// Propagates any error from the model call. Returns the wrapped parser's
    /// error if the model's reply still fails to parse.
    async fn fix_output(&self, malformed: &str, error: &CognisError) -> Result<Value> {
        let system_msg = Message::System(SystemMessage::new(
            "You are a helpful assistant that fixes malformed output. \
             Return ONLY the corrected output with no additional text.",
        ));
        let format_instructions = self.parser.get_format_instructions();
        let user_content = Self::fix_prompt(malformed, error, format_instructions.as_deref());
        let user_msg = Message::Human(HumanMessage::new(&user_content));
        let ai_msg = self
            .llm
            .invoke_messages(&[system_msg, user_msg], None)
            .await?;
        let fixed_text = ai_msg.base.content.text();
        self.parser.parse(&fixed_text)
    }

    /// Builds the user message that asks the model to repair `malformed`.
    ///
    /// The malformed text is quoted in a backtick fence one longer than the longest
    /// run of backticks it contains, and never shorter than three. A Markdown code
    /// fence already present in the text (common in LLM output) therefore cannot
    /// close the quote early. The format-instructions paragraph is appended only
    /// when it is present and non-blank.
    fn fix_prompt(
        malformed: &str,
        error: impl std::fmt::Display,
        format_instructions: Option<&str>,
    ) -> String {
        // Splitting on every non-backtick char leaves exactly the backtick runs.
        let longest_backtick_run = malformed
            .split(|c: char| c != '`')
            .map(str::len)
            .max()
            .unwrap_or(0);
        let fence = "`".repeat((longest_backtick_run + 1).max(3));
        let mut prompt = format!(
            "The following output failed to parse:\n\n{}\n{}\n{}\n\n\
             Error: {}\n\n\
             Please fix the output so it conforms to the expected format.",
            fence, malformed, fence, error
        );
        if let Some(instructions) = format_instructions.filter(|s| !s.trim().is_empty()) {
            prompt.push_str("\n\n");
            prompt.push_str(instructions);
        }
        prompt
    }
}
impl OutputParser for OutputFixingParser {
    /// Identifier for this parser kind.
    fn parser_type(&self) -> &str {
        "output_fixing_parser"
    }

    /// Returns the wrapped parser's format instructions unchanged.
    fn get_format_instructions(&self) -> Option<String> {
        let inner = &*self.parser;
        inner.get_format_instructions()
    }

    /// Synchronous parse that delegates straight to the wrapped parser.
    ///
    /// No LLM repair is attempted here. Use `Runnable::invoke` for the
    /// parse-then-fix behaviour.
    fn parse(&self, text: &str) -> Result<Value> {
        let inner = &*self.parser;
        inner.parse(text)
    }
}
#[async_trait]
impl Runnable for OutputFixingParser {
    fn name(&self) -> &str {
        "OutputFixingParser"
    }

    /// Parses `input` with the wrapped parser and falls back to one LLM repair attempt on failure.
    ///
    /// A JSON string input is parsed as its raw contents. Any other JSON value is
    /// parsed from its serialized (`to_string`) form. `_config` is not used.
    async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
        let raw = match input {
            Value::String(s) => s,
            other => other.to_string(),
        };
        match self.parser.parse(&raw) {
            Err(parse_err) => self.fix_output(&raw, &parse_err).await,
            parsed => parsed,
        }
    }
}