use std::marker::PhantomData;
use async_trait::async_trait;
use crate::runnable::{Runnable, RunnableConfig};
use crate::Result;
/// Wraps a [`Runnable`] and attaches JSON Schemas describing its input and output.
///
/// Invocation is forwarded unchanged to the wrapped runnable; the wrapper only changes what
/// `input_schema` / `output_schema` report. Build one with [`WithSchema::new`] (schemas derived
/// via `schemars`) or [`WithSchema::with_schemas`] (schemas supplied by the caller).
pub struct WithSchema<R, I, O> {
    /// The runnable every call is forwarded to.
    inner: R,
    /// Value reported as the input schema.
    input_schema: serde_json::Value,
    /// Value reported as the output schema.
    output_schema: serde_json::Value,
    /// Ties `I` and `O` to the type without storing values of either. The `fn(I) -> O` form
    /// keeps the wrapper's auto traits (`Send`/`Sync`) independent of `I` and `O`.
    _phantom: PhantomData<fn(I) -> O>,
}

// Hand-written rather than derived: `#[derive(Debug)]` would require `I: Debug` and
// `O: Debug` because of the phantom field, even though no `I`/`O` value is ever stored.
impl<R, I, O> std::fmt::Debug for WithSchema<R, I, O>
where
    R: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WithSchema")
            .field("inner", &self.inner)
            .field("input_schema", &self.input_schema)
            .field("output_schema", &self.output_schema)
            .finish()
    }
}
impl<R, I, O> WithSchema<R, I, O>
where
    R: Runnable<I, O>,
    I: schemars::JsonSchema + Send + 'static,
    O: schemars::JsonSchema + Send + 'static,
{
    /// Wraps `inner`, deriving both schemas from the `schemars::JsonSchema` impls of `I` and `O`.
    ///
    /// The input schema is derived first, then the output schema. If either derived schema
    /// fails to serialize to JSON, `serde_json::Value::Null` is stored in its place.
    pub fn new(inner: R) -> Self {
        /// Generates the root schema for `T` and converts it to a JSON value (`Null` on failure).
        fn derived<T: schemars::JsonSchema>() -> serde_json::Value {
            serde_json::to_value(schemars::schema_for!(T)).unwrap_or(serde_json::Value::Null)
        }

        Self::with_schemas(inner, derived::<I>(), derived::<O>())
    }
}
impl<R, I, O> WithSchema<R, I, O>
where
    R: Runnable<I, O>,
    I: Send + 'static,
    O: Send + 'static,
{
    /// Wraps `inner` with explicitly supplied input and output schemas.
    ///
    /// Unlike `new`, this does not require `I` or `O` to implement `schemars::JsonSchema`;
    /// the given values are stored and reported verbatim.
    pub fn with_schemas(
        inner: R,
        input_schema: serde_json::Value,
        output_schema: serde_json::Value,
    ) -> Self {
        Self {
            inner,
            input_schema,
            output_schema,
            _phantom: PhantomData,
        }
    }

    /// Replaces the input schema and returns the updated wrapper.
    ///
    /// The output schema is left unchanged.
    #[must_use = "this consumes the wrapper and returns the updated one; dropping it discards the wrapper"]
    pub fn override_input_schema(mut self, schema: serde_json::Value) -> Self {
        self.input_schema = schema;
        self
    }

    /// Replaces the output schema and returns the updated wrapper.
    ///
    /// The input schema is left unchanged.
    #[must_use = "this consumes the wrapper and returns the updated one; dropping it discards the wrapper"]
    pub fn override_output_schema(mut self, schema: serde_json::Value) -> Self {
        self.output_schema = schema;
        self
    }

    /// Borrows the stored schemas as `(input_schema, output_schema)`.
    pub fn schemas(&self) -> (&serde_json::Value, &serde_json::Value) {
        (&self.input_schema, &self.output_schema)
    }
}
/// Delegates execution and naming to the wrapped runnable, but reports the stored schemas.
#[async_trait]
impl<R, I, O> Runnable<I, O> for WithSchema<R, I, O>
where
    R: Runnable<I, O>,
    I: Send + 'static,
    O: Send + 'static,
{
    /// Runs the wrapped runnable with the same input and config and returns its result as-is.
    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
        <R as Runnable<I, O>>::invoke(&self.inner, input, config).await
    }

    /// Reports the wrapped runnable's name.
    fn name(&self) -> &str {
        <R as Runnable<I, O>>::name(&self.inner)
    }

    /// Always `Some`, holding a copy of the stored input schema.
    fn input_schema(&self) -> Option<serde_json::Value> {
        self.input_schema.clone().into()
    }

    /// Always `Some`, holding a copy of the stored output schema.
    fn output_schema(&self) -> Option<serde_json::Value> {
        self.output_schema.clone().into()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct In {
        topic: String,
    }

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct Out {
        summary: String,
    }

    /// Minimal runnable whose input/output types implement `JsonSchema`.
    struct R;

    #[async_trait]
    impl Runnable<In, Out> for R {
        async fn invoke(&self, input: In, _: RunnableConfig) -> Result<Out> {
            Ok(Out {
                summary: format!("about {}", input.topic),
            })
        }
    }

    #[test]
    fn auto_derives_schemas_from_jsonschema() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let inp = wrapped.input_schema().unwrap();
        let out = wrapped.output_schema().unwrap();
        assert!(inp.to_string().contains("topic"));
        assert!(out.to_string().contains("summary"));
        // Each schema must describe only its own type (guards against an input/output swap).
        assert!(!inp.to_string().contains("summary"));
        assert!(!out.to_string().contains("topic"));
    }

    #[test]
    fn override_schemas_replaces_derived() {
        let wrapped: WithSchema<R, In, Out> =
            WithSchema::new(R).override_input_schema(serde_json::json!({"custom": true}));
        let inp = wrapped.input_schema().unwrap();
        assert_eq!(inp["custom"], true);
        // Overriding the input must leave the derived output schema intact.
        assert!(wrapped.output_schema().unwrap().to_string().contains("summary"));
    }

    #[test]
    fn override_output_schema_replaces_derived() {
        let wrapped: WithSchema<R, In, Out> =
            WithSchema::new(R).override_output_schema(serde_json::json!({"custom": "out"}));
        assert_eq!(wrapped.output_schema().unwrap()["custom"], "out");
        // Overriding the output must leave the derived input schema intact.
        assert!(wrapped.input_schema().unwrap().to_string().contains("topic"));
    }

    #[tokio::test]
    async fn invoke_pass_through() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        let out = wrapped
            .invoke(
                In {
                    topic: "rust".into(),
                },
                RunnableConfig::default(),
            )
            .await
            .unwrap();
        assert_eq!(out.summary, "about rust");
    }

    #[test]
    fn name_pass_through() {
        let wrapped: WithSchema<R, In, Out> = WithSchema::new(R);
        assert_eq!(wrapped.name(), R.name());
    }

    #[test]
    fn with_schemas_skips_jsonschema_bound() {
        // The payloads exist only to give the types a field; nothing reads PlainOut's field.
        struct PlainIn(#[allow(dead_code)] String);
        struct PlainOut(#[allow(dead_code)] String);
        struct P;

        #[async_trait]
        impl Runnable<PlainIn, PlainOut> for P {
            async fn invoke(&self, input: PlainIn, _: RunnableConfig) -> Result<PlainOut> {
                Ok(PlainOut(input.0))
            }
        }

        // Distinct input/output schemas so a swap between them would be detected.
        let wrapped: WithSchema<P, PlainIn, PlainOut> = WithSchema::with_schemas(
            P,
            serde_json::json!({"type": "string"}),
            serde_json::json!({"type": "array"}),
        );
        assert_eq!(wrapped.input_schema().unwrap()["type"], "string");
        assert_eq!(wrapped.output_schema().unwrap()["type"], "array");

        let (inp, out) = wrapped.schemas();
        assert_eq!(inp["type"], "string");
        assert_eq!(out["type"], "array");
    }
}