use std::collections::BTreeMap;
use serde_json::{Map, Value};
use crate::error::Error;
/// A value that has already been rendered into its model-facing form.
///
/// The constructor is crate-private; inside the crate the usual way to obtain one
/// is [`LlmRenderable::for_llm`]. The wrapper is transparent for formatting,
/// `AsRef<str>` and (see the serde impls) serialisation.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RenderedForLlm<T>(T);

impl<T> RenderedForLlm<T> {
    /// Wraps `inner`, which the caller asserts is already in rendered form.
    pub(crate) const fn new(inner: T) -> Self {
        Self(inner)
    }

    /// Borrows the rendered value.
    #[must_use]
    pub const fn as_inner(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper, handing back the rendered value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}

/// Exposes a rendered string-like value as `&str`.
impl<T: AsRef<str>> AsRef<str> for RenderedForLlm<T> {
    fn as_ref(&self) -> &str {
        AsRef::<str>::as_ref(self.as_inner())
    }
}

/// Formats exactly like the inner value; the caller's formatter (width, fill,
/// precision, …) is passed straight through rather than re-formatted.
impl<T: std::fmt::Display> std::fmt::Display for RenderedForLlm<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self.as_inner(), f)
    }
}
/// Serialises as the wrapped value itself; the wrapper contributes no envelope.
impl<T> serde::Serialize for RenderedForLlm<T>
where
    T: serde::Serialize,
{
    fn serialize<S: serde::Serializer>(&self, ser: S) -> std::result::Result<S::Ok, S::Error> {
        serde::Serialize::serialize(self.as_inner(), ser)
    }
}
/// Deserialises from the plain representation of `T` (no envelope).
///
/// NOTE(review): this constructs a `RenderedForLlm` from arbitrary input without
/// going through `render_for_llm`, sidestepping the `pub(crate)` constructor —
/// confirm that is intended (e.g. for reloading previously stored values).
impl<'de, T> serde::Deserialize<'de> for RenderedForLlm<T>
where
    T: serde::Deserialize<'de>,
{
    fn deserialize<D: serde::Deserializer<'de>>(de: D) -> std::result::Result<Self, D::Error> {
        let inner = T::deserialize(de)?;
        Ok(Self::new(inner))
    }
}
/// Conversion of a value into the representation that may be shown to a model.
///
/// Implementors provide [`render_for_llm`](Self::render_for_llm); the provided
/// [`for_llm`](Self::for_llm) wraps its result in [`RenderedForLlm`].
pub trait LlmRenderable<T> {
    /// Produces the model-facing representation of `self`.
    fn render_for_llm(&self) -> T;

    /// Same as [`render_for_llm`](Self::render_for_llm), but tags the result as
    /// rendered by wrapping it in [`RenderedForLlm`].
    fn for_llm(&self) -> RenderedForLlm<T> {
        let rendered = self.render_for_llm();
        RenderedForLlm::new(rendered)
    }
}
// clippy's `use_self` would ask for `Self` in place of `String` inside this impl;
// `String` is spelled out so the return type visibly matches the trait parameter.
#[allow(clippy::use_self)]
impl LlmRenderable<String> for String {
    /// Returns an owned copy of the text, unchanged.
    fn render_for_llm(&self) -> String {
        String::clone(self)
    }
}
impl LlmRenderable<String> for &str {
    /// Copies the borrowed text into an owned `String`, unchanged.
    fn render_for_llm(&self) -> String {
        let text: &str = self;
        String::from(text)
    }
}
impl LlmRenderable<String> for Error {
    /// Renders the error as short text for a model to read.
    ///
    /// `InvalidRequest` forwards its message behind an `invalid input: ` prefix and
    /// `ModelRetry` forwards its hint (read via `as_inner`). Every other variant
    /// maps to a fixed phrase, and any payload it carries is not included.
    fn render_for_llm(&self) -> String {
        let phrase = match self {
            Self::InvalidRequest(msg) => return format!("invalid input: {msg}"),
            Self::ModelRetry { hint, .. } => return hint.as_inner().clone(),
            Self::Provider { .. } => "upstream model error",
            Self::Auth(_) => "authentication failed",
            Self::Config(_) => "tool misconfigured",
            Self::Cancelled => "cancelled",
            Self::DeadlineExceeded => "timed out",
            Self::Interrupted { .. } => "awaiting human review",
            Self::Serde(_) => "output could not be serialised",
            Self::UsageLimitExceeded(_) => "request quota reached",
        };
        phrase.to_owned()
    }
}
/// Namespace type for [`LlmFacingSchema::strip`], which reduces a JSON Schema to
/// the subset of keywords this module keeps for LLM-facing use.
pub struct LlmFacingSchema;
/// How a kept schema keyword's value is treated while stripping (see [`classify`]).
enum AllowedKey {
    /// Value copied verbatim (a `format` value may still be filtered out).
    Literal,
    /// Value is a single sub-schema, or an array of them (legacy tuple `items`).
    Schema,
    /// Value is an array of sub-schemas.
    SchemaArray,
    /// Value is an object mapping names to sub-schemas.
    SchemaMap,
    /// Value is instance data (`enum`/`default`/`const`) or property names
    /// (`required`); copied verbatim, never recursed into.
    UserData,
}

/// Maps a JSON Schema keyword to its [`AllowedKey`] handling, or `None` when the
/// keyword is dropped from the LLM-facing schema.
///
/// Everything not listed is dropped, e.g. `$schema`, `$id`, `title`, `examples`,
/// `$ref` (resolved separately by the caller) and the `$defs`/`definitions`
/// containers.
fn classify(key: &str) -> Option<AllowedKey> {
    Some(match key {
        "type" | "description" | "minimum" | "maximum" | "exclusiveMinimum"
        | "exclusiveMaximum" | "multipleOf" | "minLength" | "maxLength" | "minItems"
        | "maxItems" | "uniqueItems" | "minProperties" | "maxProperties" | "pattern"
        | "format" => AllowedKey::Literal,
        "items" | "additionalProperties" | "not" => AllowedKey::Schema,
        // `prefixItems` holds the per-position element schemas of 2020-12 tuples;
        // dropping it would degrade a typed tuple to an untyped array.
        "anyOf" | "oneOf" | "allOf" | "prefixItems" => AllowedKey::SchemaArray,
        "properties" => AllowedKey::SchemaMap,
        "enum" | "default" | "const" | "required" => AllowedKey::UserData,
        _ => return None,
    })
}
/// `format` values that only restate a numeric type's storage width (the tags
/// generated for Rust integer and float types) and add no constraint the model
/// can use. `strip_schema` drops a `format` whose value is listed here.
/// Meaningful formats such as `date-time` are not listed and are kept.
const NOISY_FORMATS: &[&str] = &[
    // Signed integers; plain `int` is the tag used for `isize`.
    "int", "int8", "int16", "int32", "int64", "int128",
    // Unsigned integers; plain `uint` is the tag used for `usize`.
    "uint", "uint8", "uint16", "uint32", "uint64", "uint128",
    // Floating point.
    "float", "double",
];
impl LlmFacingSchema {
    /// Returns a copy of `schema` reduced to the keywords kept for LLM-facing use.
    ///
    /// Local `$ref`s of the form `#/$defs/<name>` or `#/definitions/<name>` that
    /// resolve against the root's definition sections are inlined. The sections
    /// themselves, together with envelope keywords such as `$schema` and `title`,
    /// are dropped. Integer/float width `format` tags are also removed.
    #[must_use]
    pub fn strip(schema: &Value) -> Value {
        strip_schema(schema, &collect_defs(schema))
    }
}
/// Gathers the root schema's reusable definitions, keyed by definition name.
///
/// Reads the `$defs` section (newer drafts) and then the `definitions` section
/// (older drafts). When a name appears in both, the `definitions` body wins
/// because it is inserted last. Only root-level sections are collected; a
/// non-object root yields an empty map.
fn collect_defs(schema: &Value) -> BTreeMap<String, Value> {
    let mut defs = BTreeMap::new();
    let Some(root) = schema.as_object() else {
        return defs;
    };
    let sections = ["$defs", "definitions"]
        .into_iter()
        .filter_map(|section| root.get(section).and_then(Value::as_object));
    for section in sections {
        defs.extend(section.iter().map(|(name, body)| (name.clone(), body.clone())));
    }
    defs
}
/// Recursively reduces one schema node to the keywords [`classify`] allows.
///
/// * Non-object nodes (e.g. boolean schemas) are returned unchanged.
/// * Sub-schemas under `Schema`, `SchemaArray` and `SchemaMap` keywords are
///   stripped recursively. A value of an unexpected JSON type is copied as-is.
/// * A `format` listed in [`NOISY_FORMATS`] is dropped.
/// * A `$ref` of the form `#/$defs/<name>` or `#/definitions/<name>` whose name
///   is in `defs` is replaced by the stripped definition. Allowed keywords written
///   next to the `$ref` (for example a field-level `description`) are kept and
///   override the definition's own values. Any other `$ref` is simply dropped.
/// * Self-referential definitions are expanded once per path. A `$ref` back to
///   a definition that is already being inlined is treated like an unresolvable
///   reference instead of recursing forever.
fn strip_schema(node: &Value, defs: &BTreeMap<String, Value>) -> Value {
    strip_node(node, defs, &mut Vec::new())
}

/// Worker for [`strip_schema`]. `inlining` holds the names of the definitions
/// currently being expanded on the path from the root down to `node`.
fn strip_node<'d>(
    node: &Value,
    defs: &'d BTreeMap<String, Value>,
    inlining: &mut Vec<&'d str>,
) -> Value {
    let Some(obj) = node.as_object() else {
        return node.clone();
    };

    let mut out = Map::new();
    for (key, value) in obj {
        let Some(kind) = classify(key) else {
            continue;
        };
        let stripped = match kind {
            AllowedKey::Literal => {
                if key == "format"
                    && let Some(format) = value.as_str()
                    && NOISY_FORMATS.contains(&format)
                {
                    continue;
                }
                value.clone()
            }
            AllowedKey::Schema => match value {
                // Legacy tuple form: `items` given as an array of schemas.
                Value::Array(arr) => {
                    Value::Array(arr.iter().map(|v| strip_node(v, defs, inlining)).collect())
                }
                other => strip_node(other, defs, inlining),
            },
            AllowedKey::SchemaArray => match value {
                Value::Array(arr) => {
                    Value::Array(arr.iter().map(|v| strip_node(v, defs, inlining)).collect())
                }
                other => other.clone(),
            },
            AllowedKey::SchemaMap => match value {
                Value::Object(map) => Value::Object(
                    map.iter()
                        .map(|(k, v)| (k.clone(), strip_node(v, defs, inlining)))
                        .collect(),
                ),
                other => other.clone(),
            },
            AllowedKey::UserData => value.clone(),
        };
        out.insert(key.clone(), stripped);
    }

    // Inline a resolvable local `$ref`, unless that definition is already being
    // expanded further up this path (a recursive type), which would never end.
    if let Some(Value::String(reference)) = obj.get("$ref")
        && let Some(name) = reference
            .strip_prefix("#/$defs/")
            .or_else(|| reference.strip_prefix("#/definitions/"))
        && let Some((def_name, target)) = defs.get_key_value(name)
        && !inlining.contains(&def_name.as_str())
    {
        inlining.push(def_name.as_str());
        let inlined = strip_node(target, defs, inlining);
        inlining.pop();
        return match inlined {
            // Keywords written next to the `$ref` refine the referenced schema.
            Value::Object(mut merged) => {
                merged.extend(out);
                Value::Object(merged)
            }
            other => other,
        };
    }

    Value::Object(out)
}
#[cfg(test)]
// Tests index into `serde_json::Value` (`stripped["type"]`), which is the most
// readable way to probe nested schema output; a missing key yields `Null` rather
// than panicking.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A provider failure must not surface its HTTP status, body, or the internal
    /// "provider returned" wording in model-visible text.
    #[test]
    fn render_for_llm_omits_provider_status() {
        let rendered = Error::provider_http(503, "vendor down".to_owned()).render_for_llm();
        for leaked in ["503", "vendor down", "provider returned"] {
            assert!(!rendered.contains(leaked), "{rendered}");
        }
    }

    /// Caller-authored validation messages are forwarded with a fixed prefix.
    #[test]
    fn render_for_llm_invalid_request_carries_caller_message() {
        let rendered = Error::invalid_request("missing 'task' field").render_for_llm();
        assert_eq!(rendered, "invalid input: missing 'task' field");
    }

    /// Envelope keywords and integer width formats are removed; structure stays.
    #[test]
    fn strip_removes_schema_envelope() {
        let schema = json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "title": "DoubleInput",
            "type": "object",
            "properties": {"n": {"type": "integer", "format": "int64"}},
            "required": ["n"]
        });
        let stripped = LlmFacingSchema::strip(&schema);
        for dropped in ["$schema", "title"] {
            assert!(stripped.get(dropped).is_none());
        }
        assert_eq!(stripped["type"], "object");
        let n = &stripped["properties"]["n"];
        assert_eq!(n["type"], "integer");
        assert!(n.get("format").is_none());
        assert_eq!(stripped["required"], json!(["n"]));
    }

    /// `$ref`s into `$defs` are inlined (and stripped) and `$defs` disappears.
    #[test]
    fn strip_inlines_refs_and_drops_defs_envelope() {
        let schema = json!({
            "$schema": "https://json-schema.org/draft/2020-12/schema",
            "title": "Outer",
            "type": "object",
            "properties": {"inner": {"$ref": "#/$defs/Inner"}},
            "$defs": {
                "Inner": {
                    "title": "Inner",
                    "type": "object",
                    "properties": {"x": {"type": "string"}},
                    "required": ["x"]
                }
            }
        });
        let stripped = LlmFacingSchema::strip(&schema);
        assert!(stripped.get("$defs").is_none());
        let inner = &stripped["properties"]["inner"];
        assert_eq!(inner["type"], "object");
        assert_eq!(inner["properties"]["x"]["type"], "string");
        assert!(inner.get("title").is_none());
    }

    /// Formats that carry meaning for the model survive stripping.
    #[test]
    fn strip_keeps_meaningful_format_specifiers() {
        let stripped = LlmFacingSchema::strip(&json!({
            "type": "string",
            "format": "date-time"
        }));
        assert_eq!(stripped["format"], "date-time");
    }
}