Skip to main content

cognis_llm/
structured.rs

//! Structured output: turn a `Client` into a `Runnable<Vec<Message>, T>` for
//! any `T: JsonSchema + DeserializeOwned`.
//!
//! Produces a typed value by:
//! 1. Generating a JSON Schema from `T` via `schemars`.
//! 2. Prepending a system message to the conversation telling the model to
//!    reply with JSON matching the schema.
//! 3. Parsing the assistant's reply with [`cognis_core::output_parsers::JsonParser`].
9
10use std::marker::PhantomData;
11
12use async_trait::async_trait;
13use schemars::JsonSchema;
14use serde::de::DeserializeOwned;
15
16use cognis_core::output_parsers::{JsonParser, OutputParser};
17use cognis_core::{Message, Result, Runnable, RunnableConfig};
18
19use crate::client::Client;
20
21/// `Client` with the output coerced to a typed value `T`.
22///
23/// Construct via [`Client::with_structured_output`].
24pub struct StructuredClient<T> {
25    client: Client,
26    schema_json: String,
27    parser: JsonParser<T>,
28    _t: PhantomData<fn() -> T>,
29}
30
31impl<T> StructuredClient<T>
32where
33    T: JsonSchema + DeserializeOwned + Send + 'static,
34{
35    /// Build a `StructuredClient<T>` over an existing `Client`.
36    pub fn new(client: Client) -> Self {
37        let schema = schemars::schema_for!(T);
38        let schema_json =
39            serde_json::to_string_pretty(&schema).unwrap_or_else(|_| "{}".to_string());
40        Self {
41            client,
42            schema_json,
43            parser: JsonParser::new(),
44            _t: PhantomData,
45        }
46    }
47
48    /// The JSON Schema generated from `T`.
49    pub fn schema(&self) -> &str {
50        &self.schema_json
51    }
52
53    fn instructions(&self) -> String {
54        format!(
55            "Reply with a single JSON object matching this JSON Schema. \
56             Do not include any text before or after the JSON. Do not wrap \
57             the JSON in markdown code fences.\n\nSchema:\n{}",
58            self.schema_json
59        )
60    }
61}
62
63#[async_trait]
64impl<T> Runnable<Vec<Message>, T> for StructuredClient<T>
65where
66    T: JsonSchema + DeserializeOwned + Send + 'static,
67{
68    async fn invoke(&self, mut input: Vec<Message>, _: RunnableConfig) -> Result<T> {
69        // Inject schema instructions as a system message at the head.
70        input.insert(0, Message::system(self.instructions()));
71        let reply = self.client.invoke(input).await?;
72        self.parser.parse(reply.content())
73    }
74
75    fn name(&self) -> &str {
76        "StructuredClient"
77    }
78}
79
80impl Client {
81    /// Coerce this client's output to a typed value `T`.
82    ///
83    /// `T` must derive `JsonSchema` (for prompt construction) and
84    /// `Deserialize` (for parsing). The returned value is itself a
85    /// `Runnable<Vec<Message>, T>`, so it composes with `.pipe()`.
86    pub fn with_structured_output<T>(self) -> StructuredClient<T>
87    where
88        T: JsonSchema + DeserializeOwned + Send + 'static,
89    {
90        StructuredClient::new(self)
91    }
92}