Skip to main content

serdes_ai_output/
schema.rs

1//! Output schema trait and core types.
2//!
3//! This module provides the `OutputSchema` trait which defines how to
4//! parse and validate model responses into typed output.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{ObjectJsonSchema, ToolDefinition};
9
10use crate::error::OutputParseError;
11use crate::mode::OutputMode;
12
13/// Trait for output schemas that can validate model responses.
14///
15/// This trait defines how to parse model output from various formats
16/// (text, tool calls, native JSON) and optionally validate it.
17///
18/// # Type Parameters
19///
20/// - `T`: The output type to parse into.
21#[async_trait]
22pub trait OutputSchema<T: Send>: Send + Sync {
23    /// The preferred output mode for this schema.
24    fn mode(&self) -> OutputMode;
25
26    /// Get tool definitions if using tool mode.
27    ///
28    /// Returns an empty vector if not using tool mode.
29    fn tool_definitions(&self) -> Vec<ToolDefinition> {
30        vec![]
31    }
32
33    /// Get JSON schema for native/prompted mode.
34    ///
35    /// Returns `None` if no schema is available.
36    fn json_schema(&self) -> Option<ObjectJsonSchema> {
37        None
38    }
39
40    /// Whether this schema supports a given output mode.
41    fn supports_mode(&self, mode: OutputMode) -> bool {
42        match mode {
43            OutputMode::Text => true, // Text is always supported
44            OutputMode::Tool => !self.tool_definitions().is_empty(),
45            OutputMode::Native | OutputMode::Prompted => self.json_schema().is_some(),
46        }
47    }
48
49    /// Parse output from text.
50    fn parse_text(&self, text: &str) -> Result<T, OutputParseError>;
51
52    /// Parse output from a tool call.
53    fn parse_tool_call(&self, name: &str, args: &JsonValue) -> Result<T, OutputParseError>;
54
55    /// Parse output from native structured response.
56    fn parse_native(&self, value: &JsonValue) -> Result<T, OutputParseError>;
57
58    /// Parse output based on the mode.
59    fn parse(
60        &self,
61        mode: OutputMode,
62        text: Option<&str>,
63        tool_name: Option<&str>,
64        args: Option<&JsonValue>,
65    ) -> Result<T, OutputParseError> {
66        match mode {
67            OutputMode::Text => {
68                let text = text.ok_or_else(|| OutputParseError::custom("No text output"))?;
69                self.parse_text(text)
70            }
71            OutputMode::Tool => {
72                let name = tool_name.ok_or_else(|| OutputParseError::custom("No tool call"))?;
73                let args = args.ok_or_else(|| OutputParseError::custom("No tool arguments"))?;
74                self.parse_tool_call(name, args)
75            }
76            OutputMode::Native | OutputMode::Prompted => {
77                // Try tool call first, then native JSON
78                if let (Some(name), Some(args)) = (tool_name, args) {
79                    return self.parse_tool_call(name, args);
80                }
81                if let Some(args) = args {
82                    return self.parse_native(args);
83                }
84                if let Some(text) = text {
85                    return self.parse_text(text);
86                }
87                Err(OutputParseError::custom("No output to parse"))
88            }
89        }
90    }
91}
92
93/// Boxed output schema for dynamic dispatch.
94pub type BoxedOutputSchema<T> = Box<dyn OutputSchema<T>>;
95
/// A simple wrapper pairing an output schema `S` with its output type `T`.
///
/// No value of type `T` is stored; `T` is tracked only at the type level.
/// The marker is `PhantomData<fn() -> T>`, so the wrapper is `Send`/`Sync`
/// exactly when `S` is, regardless of whether `T` is. For example, `T` may
/// be a non-thread-safe type such as `Rc<_>` and the wrapper still stays
/// `Send + Sync` for a thread-safe `S`.
pub struct OutputSchemaWrapper<S, T> {
    inner: S,
    // `fn() -> T` keeps `T` covariant (same as `PhantomData<T>`) without
    // implying ownership of a `T`. As a result, auto traits depend only on `S`.
    _phantom: std::marker::PhantomData<fn() -> T>,
}

impl<S, T> OutputSchemaWrapper<S, T> {
    /// Create a new wrapper around `inner`.
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Get a shared reference to the wrapped schema.
    pub fn inner(&self) -> &S {
        &self.inner
    }
}

// Hand-written so that only `S: Debug` is required; `#[derive(Debug)]` would
// also demand `T: Debug` even though no `T` is stored. The field layout of the
// output matches what the derive produced.
impl<S: std::fmt::Debug, T> std::fmt::Debug for OutputSchemaWrapper<S, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OutputSchemaWrapper")
            .field("inner", &self.inner)
            .field("_phantom", &std::marker::PhantomData::<T>)
            .finish()
    }
}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal schema that echoes text, and string-valued JSON, back as `String`.
    struct MockSchema;

    #[async_trait]
    impl OutputSchema<String> for MockSchema {
        fn mode(&self) -> OutputMode {
            OutputMode::Text
        }

        fn parse_text(&self, text: &str) -> Result<String, OutputParseError> {
            Ok(text.to_string())
        }

        fn parse_tool_call(
            &self,
            _name: &str,
            args: &JsonValue,
        ) -> Result<String, OutputParseError> {
            args.as_str()
                .map(String::from)
                .ok_or(OutputParseError::NotJson)
        }

        fn parse_native(&self, value: &JsonValue) -> Result<String, OutputParseError> {
            value
                .as_str()
                .map(String::from)
                .ok_or(OutputParseError::NotJson)
        }
    }

    #[test]
    fn test_mock_schema_parse_text() {
        let schema = MockSchema;
        let result = schema.parse_text("hello").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_mock_schema_supports_mode() {
        let schema = MockSchema;
        assert!(schema.supports_mode(OutputMode::Text));
        // No tool definitions and no JSON schema: only text mode is supported.
        assert!(!schema.supports_mode(OutputMode::Tool));
        assert!(!schema.supports_mode(OutputMode::Native));
        assert!(!schema.supports_mode(OutputMode::Prompted));
    }

    #[test]
    fn test_parse_dispatch() {
        let schema = MockSchema;

        // Text mode
        let result = schema
            .parse(OutputMode::Text, Some("hello"), None, None)
            .unwrap();
        assert_eq!(result, "hello");

        // Missing text
        let result = schema.parse(OutputMode::Text, None, None, None);
        assert!(result.is_err());

        // Text mode ignores JSON arguments entirely.
        let args = JsonValue::from("ignored");
        assert!(schema
            .parse(OutputMode::Text, None, None, Some(&args))
            .is_err());
    }

    #[test]
    fn test_parse_dispatch_tool_mode() {
        let schema = MockSchema;
        let args = JsonValue::from("from tool");

        let result = schema
            .parse(OutputMode::Tool, None, Some("final_result"), Some(&args))
            .unwrap();
        assert_eq!(result, "from tool");

        // Tool mode needs both a tool name and arguments; text is not a fallback.
        assert!(schema
            .parse(OutputMode::Tool, None, None, Some(&args))
            .is_err());
        assert!(schema
            .parse(OutputMode::Tool, None, Some("final_result"), None)
            .is_err());
        assert!(schema
            .parse(OutputMode::Tool, Some("hello"), None, None)
            .is_err());
    }

    #[test]
    fn test_parse_dispatch_native_and_prompted() {
        let schema = MockSchema;
        let args = JsonValue::from("structured");

        // A complete tool call takes precedence over text.
        let result = schema
            .parse(
                OutputMode::Native,
                Some("hello"),
                Some("final_result"),
                Some(&args),
            )
            .unwrap();
        assert_eq!(result, "structured");

        // Arguments without a tool name are parsed as native JSON.
        let result = schema
            .parse(OutputMode::Native, Some("hello"), None, Some(&args))
            .unwrap();
        assert_eq!(result, "structured");

        // Without arguments, parsing falls back to text.
        let result = schema
            .parse(OutputMode::Prompted, Some("hello"), None, None)
            .unwrap();
        assert_eq!(result, "hello");
        let result = schema
            .parse(OutputMode::Native, Some("hello"), Some("final_result"), None)
            .unwrap();
        assert_eq!(result, "hello");

        // Nothing to parse is an error.
        assert!(schema.parse(OutputMode::Native, None, None, None).is_err());
        assert!(schema
            .parse(OutputMode::Prompted, None, None, None)
            .is_err());
    }

    #[test]
    fn test_parse_native_rejects_non_string_json() {
        let schema = MockSchema;
        let args = JsonValue::from(42);
        // MockSchema::parse_native only accepts JSON strings.
        assert!(schema
            .parse(OutputMode::Native, None, None, Some(&args))
            .is_err());
    }
}