cognis_core/wrappers/
schema.rs1use std::marker::PhantomData;
16
17use async_trait::async_trait;
18
19use crate::runnable::{Runnable, RunnableConfig};
20use crate::Result;
21
/// Wraps a [`Runnable`] and attaches JSON schema values describing its input and output.
///
/// Invocation is delegated unchanged to the wrapped runnable; the wrapper only adds the
/// stored schema values, which it reports through its own `Runnable` implementation.
pub struct WithSchema<R, I, O> {
    /// The wrapped runnable that performs the actual work.
    inner: R,
    /// Schema value describing the input type `I`.
    input_schema: serde_json::Value,
    /// Schema value describing the output type `O`.
    output_schema: serde_json::Value,
    // Ties `I` and `O` to the type without storing values of them. A function-pointer
    // marker is always `Send + Sync`, so the wrapper's `Send`/`Sync` depend only on `R`
    // (the schema fields are plain `serde_json::Value`s).
    _phantom: PhantomData<fn(I) -> O>,
}
30
31impl<R, I, O> WithSchema<R, I, O>
32where
33 R: Runnable<I, O>,
34 I: schemars::JsonSchema + Send + 'static,
35 O: schemars::JsonSchema + Send + 'static,
36{
37 pub fn new(inner: R) -> Self {
39 let input_schema =
40 serde_json::to_value(schemars::schema_for!(I)).unwrap_or(serde_json::Value::Null);
41 let output_schema =
42 serde_json::to_value(schemars::schema_for!(O)).unwrap_or(serde_json::Value::Null);
43 Self {
44 inner,
45 input_schema,
46 output_schema,
47 _phantom: PhantomData,
48 }
49 }
50}
51
52impl<R, I, O> WithSchema<R, I, O>
53where
54 R: Runnable<I, O>,
55 I: Send + 'static,
56 O: Send + 'static,
57{
58 pub fn with_schemas(
62 inner: R,
63 input_schema: serde_json::Value,
64 output_schema: serde_json::Value,
65 ) -> Self {
66 Self {
67 inner,
68 input_schema,
69 output_schema,
70 _phantom: PhantomData,
71 }
72 }
73
74 pub fn override_input_schema(mut self, schema: serde_json::Value) -> Self {
76 self.input_schema = schema;
77 self
78 }
79
80 pub fn override_output_schema(mut self, schema: serde_json::Value) -> Self {
82 self.output_schema = schema;
83 self
84 }
85
86 pub fn schemas(&self) -> (&serde_json::Value, &serde_json::Value) {
88 (&self.input_schema, &self.output_schema)
89 }
90}
91
92#[async_trait]
93impl<R, I, O> Runnable<I, O> for WithSchema<R, I, O>
94where
95 R: Runnable<I, O>,
96 I: Send + 'static,
97 O: Send + 'static,
98{
99 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
100 self.inner.invoke(input, config).await
101 }
102
103 fn name(&self) -> &str {
104 self.inner.name()
105 }
106
107 fn input_schema(&self) -> Option<serde_json::Value> {
108 Some(self.input_schema.clone())
109 }
110
111 fn output_schema(&self) -> Option<serde_json::Value> {
112 Some(self.output_schema.clone())
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct In {
        topic: String,
    }

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct Out {
        summary: String,
    }

    /// Minimal runnable whose types implement `JsonSchema`.
    struct R;

    #[async_trait]
    impl Runnable<In, Out> for R {
        async fn invoke(&self, input: In, _: RunnableConfig) -> Result<Out> {
            Ok(Out {
                summary: format!("about {}", input.topic),
            })
        }
    }

    #[test]
    fn auto_derives_schemas_from_jsonschema() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let inp = wrapped.input_schema().unwrap();
        let out = wrapped.output_schema().unwrap();
        // Each derived schema mentions the fields of its own type.
        assert!(inp.to_string().contains("topic"));
        assert!(out.to_string().contains("summary"));
    }

    #[test]
    fn override_schemas_replaces_derived() {
        let wrapped: WithSchema<R, In, Out> =
            WithSchema::new(R).override_input_schema(serde_json::json!({"custom": true}));
        let inp = wrapped.input_schema().unwrap();
        assert_eq!(inp["custom"], true);
        // The output schema is untouched by an input override.
        assert!(wrapped.output_schema().unwrap().to_string().contains("summary"));
    }

    #[test]
    fn override_output_schema_keeps_derived_input() {
        let wrapped: WithSchema<R, In, Out> =
            WithSchema::new(R).override_output_schema(serde_json::json!({"custom": "out"}));
        assert_eq!(wrapped.output_schema().unwrap()["custom"], "out");
        assert!(wrapped.input_schema().unwrap().to_string().contains("topic"));
    }

    #[test]
    fn schemas_accessor_matches_trait_methods() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let (inp, out) = wrapped.schemas();
        assert_eq!(Some(inp.clone()), wrapped.input_schema());
        assert_eq!(Some(out.clone()), wrapped.output_schema());
    }

    #[tokio::test]
    async fn invoke_pass_through() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let out = wrapped
            .invoke(
                In {
                    topic: "rust".into(),
                },
                RunnableConfig::default(),
            )
            .await
            .unwrap();
        assert_eq!(out.summary, "about rust");
    }

    #[tokio::test]
    async fn with_schemas_skips_jsonschema_bound() {
        // Neither type implements `JsonSchema`; `with_schemas` must still accept them.
        struct PlainIn(String);
        struct PlainOut(String);
        struct P;

        #[async_trait]
        impl Runnable<PlainIn, PlainOut> for P {
            async fn invoke(&self, input: PlainIn, _: RunnableConfig) -> Result<PlainOut> {
                Ok(PlainOut(input.0))
            }
        }

        let wrapped: WithSchema<P, PlainIn, PlainOut> = WithSchema::with_schemas(
            P,
            serde_json::json!({"type": "string"}),
            serde_json::json!({"type": "string"}),
        );
        assert_eq!(wrapped.input_schema().unwrap()["type"], "string");
        assert_eq!(wrapped.output_schema().unwrap()["type"], "string");

        let out = wrapped
            .invoke(PlainIn("echo".into()), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out.0, "echo");
    }
}