use std::marker::PhantomData;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use cognis_core::output_parsers::{JsonParser, OutputParser};
use cognis_core::{Message, Result, Runnable, RunnableConfig};
use crate::client::Client;
/// A [`Client`] wrapper whose replies are parsed into a typed value `T`.
///
/// The JSON Schema for `T` is generated once, at construction, and kept as a string.
/// Each call puts instructions that embed that schema in front of the conversation,
/// then parses the model's reply text with a [`JsonParser<T>`].
pub struct StructuredClient<T> {
    /// The underlying chat client that actually sends the messages.
    client: Client,
    /// Pretty-printed JSON Schema for `T`. The constructor stores `"{}"` here if
    /// serializing the schema fails.
    schema_json: String,
    /// Parser used to turn the reply text into a `T`.
    parser: JsonParser<T>,
    // NOTE(review): `parser` already mentions `T`. `PhantomData<fn() -> T>` is always
    // `Send + Sync` and covariant in `T`, so this marker does not seem to change
    // auto-traits or variance. Confirm whether it is still needed.
    _t: PhantomData<fn() -> T>,
}
impl<T> StructuredClient<T>
where
    T: JsonSchema + DeserializeOwned + Send + 'static,
{
    /// Wraps `client` so that its replies are parsed as `T`.
    ///
    /// The JSON Schema for `T` is generated here and rendered once as pretty-printed
    /// JSON. If that rendering fails, an empty schema (`"{}"`) is stored instead of
    /// failing construction.
    pub fn new(client: Client) -> Self {
        let schema_json = match serde_json::to_string_pretty(&schemars::schema_for!(T)) {
            Ok(rendered) => rendered,
            // Best-effort: keep a usable (if uninformative) schema instead of erroring.
            Err(_) => String::from("{}"),
        };
        Self {
            client,
            parser: JsonParser::new(),
            schema_json,
            _t: PhantomData,
        }
    }

    /// Returns the pretty-printed JSON Schema for `T` that is embedded in the prompt.
    pub fn schema(&self) -> &str {
        self.schema_json.as_str()
    }

    /// Builds the system-prompt text: fixed formatting rules followed by the schema.
    fn instructions(&self) -> String {
        // The trailing `\` continuations drop the newline and the leading indentation.
        const PREAMBLE: &str = "Reply with a single JSON object matching this JSON Schema. \
            Do not include any text before or after the JSON. Do not wrap \
            the JSON in markdown code fences.\n\nSchema:\n";
        let mut prompt = String::with_capacity(PREAMBLE.len() + self.schema_json.len());
        prompt.push_str(PREAMBLE);
        prompt.push_str(&self.schema_json);
        prompt
    }
}
#[async_trait]
impl<T> Runnable<Vec<Message>, T> for StructuredClient<T>
where
    T: JsonSchema + DeserializeOwned + Send + 'static,
{
    /// Sends `input` with the schema instructions placed first, then parses the reply as `T`.
    ///
    /// The instructions go in as a system message at the front of the conversation.
    /// The `RunnableConfig` argument is ignored. Errors from the client call and from
    /// parsing the reply content are both returned to the caller.
    async fn invoke(&self, input: Vec<Message>, _config: RunnableConfig) -> Result<T> {
        let mut conversation = Vec::with_capacity(input.len() + 1);
        conversation.push(Message::system(self.instructions()));
        conversation.extend(input);

        let reply = self.client.invoke(conversation).await?;
        self.parser.parse(reply.content())
    }

    /// Returns the fixed name `"StructuredClient"` for this runnable.
    fn name(&self) -> &str {
        "StructuredClient"
    }
}
impl Client {
    /// Consumes this client and returns a [`StructuredClient`] whose replies are
    /// parsed into `T`.
    ///
    /// The JSON Schema for `T` is generated immediately; see [`StructuredClient::new`].
    // `#[must_use]`: this takes `self` by value, so ignoring the result would silently
    // drop the client.
    #[must_use]
    pub fn with_structured_output<T>(self) -> StructuredClient<T>
    where
        T: JsonSchema + DeserializeOwned + Send + 'static,
    {
        StructuredClient::new(self)
    }
}