Skip to main content

cognis_core/wrappers/
schema.rs

1//! Schema introspection wrapper — fills [`Runnable::input_schema`] and
2//! [`Runnable::output_schema`] for any `(I, O)` whose types implement
3//! [`schemars::JsonSchema`].
4//!
//! Why a wrapper rather than a blanket impl? Many runnables are generic
//! over `(I, O)` without bounding `I` and `O` by `JsonSchema`, and the
//! trait itself is object-safe. The wrapper opts a runnable in at the
//! call site without polluting the trait.
9//!
//! Customization: construct with [`WithSchema::with_schemas`] to supply
//! JSON Schema values directly instead of deriving them (useful when the
//! runnable accepts a wider type than its schema advertises, or when
//! schema generation is too expensive), or call
//! [`WithSchema::override_input_schema`] /
//! [`WithSchema::override_output_schema`] to replace one derived schema.
14
15use std::marker::PhantomData;
16
17use async_trait::async_trait;
18
19use crate::runnable::{Runnable, RunnableConfig};
20use crate::Result;
21
22/// Wraps a runnable so that `input_schema()` / `output_schema()` return
23/// real [`schemars::JsonSchema`]-derived JSON Schema values.
24pub struct WithSchema<R, I, O> {
25    inner: R,
26    input_schema: serde_json::Value,
27    output_schema: serde_json::Value,
28    _phantom: PhantomData<fn(I) -> O>,
29}
30
31impl<R, I, O> WithSchema<R, I, O>
32where
33    R: Runnable<I, O>,
34    I: schemars::JsonSchema + Send + 'static,
35    O: schemars::JsonSchema + Send + 'static,
36{
37    /// Wrap with schemas auto-derived from `I` and `O` via `schemars`.
38    pub fn new(inner: R) -> Self {
39        let input_schema =
40            serde_json::to_value(schemars::schema_for!(I)).unwrap_or(serde_json::Value::Null);
41        let output_schema =
42            serde_json::to_value(schemars::schema_for!(O)).unwrap_or(serde_json::Value::Null);
43        Self {
44            inner,
45            input_schema,
46            output_schema,
47            _phantom: PhantomData,
48        }
49    }
50}
51
52impl<R, I, O> WithSchema<R, I, O>
53where
54    R: Runnable<I, O>,
55    I: Send + 'static,
56    O: Send + 'static,
57{
58    /// Wrap with caller-supplied schemas (no `JsonSchema` bound on
59    /// `I`/`O`). Use when the runnable accepts a more permissive type
60    /// than its public schema advertises.
61    pub fn with_schemas(
62        inner: R,
63        input_schema: serde_json::Value,
64        output_schema: serde_json::Value,
65    ) -> Self {
66        Self {
67            inner,
68            input_schema,
69            output_schema,
70            _phantom: PhantomData,
71        }
72    }
73
74    /// Replace just the input schema (e.g. after a [`Self::new`]).
75    pub fn override_input_schema(mut self, schema: serde_json::Value) -> Self {
76        self.input_schema = schema;
77        self
78    }
79
80    /// Replace just the output schema.
81    pub fn override_output_schema(mut self, schema: serde_json::Value) -> Self {
82        self.output_schema = schema;
83        self
84    }
85
86    /// Borrow the configured schemas (read-only).
87    pub fn schemas(&self) -> (&serde_json::Value, &serde_json::Value) {
88        (&self.input_schema, &self.output_schema)
89    }
90}
91
92#[async_trait]
93impl<R, I, O> Runnable<I, O> for WithSchema<R, I, O>
94where
95    R: Runnable<I, O>,
96    I: Send + 'static,
97    O: Send + 'static,
98{
99    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
100        self.inner.invoke(input, config).await
101    }
102
103    fn name(&self) -> &str {
104        self.inner.name()
105    }
106
107    fn input_schema(&self) -> Option<serde_json::Value> {
108        Some(self.input_schema.clone())
109    }
110
111    fn output_schema(&self) -> Option<serde_json::Value> {
112        Some(self.output_schema.clone())
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct In {
        topic: String,
    }
    #[derive(Serialize, Deserialize, JsonSchema)]
    struct Out {
        summary: String,
    }

    struct R;

    #[async_trait]
    impl Runnable<In, Out> for R {
        async fn invoke(&self, input: In, _: RunnableConfig) -> Result<Out> {
            Ok(Out {
                summary: format!("about {}", input.topic),
            })
        }
    }

    // The schema accessors are synchronous, so only the `invoke` test below
    // needs an async runtime.

    #[test]
    fn auto_derives_schemas_from_jsonschema() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let inp = wrapped.input_schema().unwrap();
        let out = wrapped.output_schema().unwrap();
        // Each derived schema should mention its type's field name.
        assert!(inp.to_string().contains("topic"));
        assert!(out.to_string().contains("summary"));
    }

    #[test]
    fn override_schemas_replaces_derived() {
        let wrapped: WithSchema<R, In, Out> =
            WithSchema::new(R).override_input_schema(serde_json::json!({"custom": true}));
        let inp = wrapped.input_schema().unwrap();
        assert_eq!(inp["custom"], true);
        // The output side keeps its derived schema.
        let out = wrapped.output_schema().unwrap();
        assert!(out.to_string().contains("summary"));
    }

    #[test]
    fn override_output_schema_replaces_only_output() {
        let wrapped: WithSchema<R, In, Out> =
            WithSchema::new(R).override_output_schema(serde_json::json!({"custom": "out"}));
        assert_eq!(wrapped.output_schema().unwrap()["custom"], "out");
        // The input side keeps its derived schema.
        assert!(wrapped.input_schema().unwrap().to_string().contains("topic"));
    }

    #[test]
    fn schemas_returns_configured_pair() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::with_schemas(
            R,
            serde_json::json!({"side": "in"}),
            serde_json::json!({"side": "out"}),
        );
        let (inp, out) = wrapped.schemas();
        assert_eq!(inp["side"], "in");
        assert_eq!(out["side"], "out");
    }

    #[tokio::test]
    async fn invoke_pass_through() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let out = wrapped
            .invoke(
                In {
                    topic: "rust".into(),
                },
                RunnableConfig::default(),
            )
            .await
            .unwrap();
        assert_eq!(out.summary, "about rust");
    }

    #[test]
    fn with_schemas_skips_jsonschema_bound() {
        // When I/O don't impl JsonSchema, with_schemas still works.
        struct PlainIn(#[allow(dead_code)] String);
        struct PlainOut(#[allow(dead_code)] String);
        struct P;
        #[async_trait]
        impl Runnable<PlainIn, PlainOut> for P {
            async fn invoke(&self, input: PlainIn, _: RunnableConfig) -> Result<PlainOut> {
                Ok(PlainOut(input.0))
            }
        }
        let wrapped: WithSchema<P, PlainIn, PlainOut> = WithSchema::with_schemas(
            P,
            serde_json::json!({"type": "string"}),
            serde_json::json!({"type": "string"}),
        );
        assert_eq!(wrapped.input_schema().unwrap()["type"], "string");
        assert_eq!(wrapped.output_schema().unwrap()["type"], "string");
    }
}