1use std::marker::PhantomData;
11
12use async_trait::async_trait;
13use schemars::JsonSchema;
14use serde::de::DeserializeOwned;
15
16use cognis_core::output_parsers::{JsonParser, OutputParser};
17use cognis_core::{Message, Result, Runnable, RunnableConfig};
18
19use crate::client::Client;
20
21pub struct StructuredClient<T> {
25 client: Client,
26 schema_json: String,
27 parser: JsonParser<T>,
28 _t: PhantomData<fn() -> T>,
29}
30
31impl<T> StructuredClient<T>
32where
33 T: JsonSchema + DeserializeOwned + Send + 'static,
34{
35 pub fn new(client: Client) -> Self {
37 let schema = schemars::schema_for!(T);
38 let schema_json =
39 serde_json::to_string_pretty(&schema).unwrap_or_else(|_| "{}".to_string());
40 Self {
41 client,
42 schema_json,
43 parser: JsonParser::new(),
44 _t: PhantomData,
45 }
46 }
47
48 pub fn schema(&self) -> &str {
50 &self.schema_json
51 }
52
53 fn instructions(&self) -> String {
54 format!(
55 "Reply with a single JSON object matching this JSON Schema. \
56 Do not include any text before or after the JSON. Do not wrap \
57 the JSON in markdown code fences.\n\nSchema:\n{}",
58 self.schema_json
59 )
60 }
61}
62
63#[async_trait]
64impl<T> Runnable<Vec<Message>, T> for StructuredClient<T>
65where
66 T: JsonSchema + DeserializeOwned + Send + 'static,
67{
68 async fn invoke(&self, mut input: Vec<Message>, _: RunnableConfig) -> Result<T> {
69 input.insert(0, Message::system(self.instructions()));
71 let reply = self.client.invoke(input).await?;
72 self.parser.parse(reply.content())
73 }
74
75 fn name(&self) -> &str {
76 "StructuredClient"
77 }
78}
79
80impl Client {
81 pub fn with_structured_output<T>(self) -> StructuredClient<T>
87 where
88 T: JsonSchema + DeserializeOwned + Send + 'static,
89 {
90 StructuredClient::new(self)
91 }
92}