Skip to main content

serdes_ai_output/
structured.rs

1//! Structured output schema implementation.
2//!
3//! This module provides `StructuredOutputSchema` for handling typed
4//! structured output using serde deserialization.
5
6use async_trait::async_trait;
7use serde::de::DeserializeOwned;
8use serde_json::Value as JsonValue;
9use serdes_ai_tools::{ObjectJsonSchema, ToolDefinition};
10use std::marker::PhantomData;
11
12use crate::error::OutputParseError;
13use crate::mode::OutputMode;
14use crate::schema::OutputSchema;
15
/// Name of the tool the model is asked to call when it delivers structured output.
pub const DEFAULT_OUTPUT_TOOL_NAME: &str = "final_result";

/// Description attached to the structured-output tool definition.
pub const DEFAULT_OUTPUT_TOOL_DESCRIPTION: &str =
    "The final response which ends this conversation";
21
22/// Schema for structured output using serde.
23///
24/// This schema parses model output into a typed Rust struct using serde.
25/// It supports multiple output modes (tool, native, prompted, text).
26///
27/// # Example
28///
29/// ```rust
30/// use serdes_ai_output::StructuredOutputSchema;
31/// use serdes_ai_tools::ObjectJsonSchema;
32/// use serde::Deserialize;
33///
34/// #[derive(Deserialize)]
35/// struct Person {
36///     name: String,
37///     age: u32,
38/// }
39///
40/// let schema = ObjectJsonSchema::new()
41///     .with_property("name", serdes_ai_tools::PropertySchema::string("Name").build(), true)
42///     .with_property("age", serdes_ai_tools::PropertySchema::integer("Age").build(), true);
43///
44/// let output_schema: StructuredOutputSchema<Person> = StructuredOutputSchema::new(schema);
45/// ```
46#[derive(Debug, Clone)]
47pub struct StructuredOutputSchema<T> {
48    /// Tool name when using tool mode.
49    pub tool_name: String,
50    /// Tool description.
51    pub tool_description: String,
52    /// JSON schema for the output.
53    pub schema: ObjectJsonSchema,
54    /// Whether to use strict mode (for OpenAI).
55    pub strict: Option<bool>,
56    /// Output mode preference.
57    mode: OutputMode,
58    _phantom: PhantomData<T>,
59}
60
61impl<T: DeserializeOwned + Send + Sync> StructuredOutputSchema<T> {
62    /// Create a new structured output schema.
63    #[must_use]
64    pub fn new(schema: ObjectJsonSchema) -> Self {
65        Self {
66            tool_name: DEFAULT_OUTPUT_TOOL_NAME.to_string(),
67            tool_description: DEFAULT_OUTPUT_TOOL_DESCRIPTION.to_string(),
68            schema,
69            strict: None,
70            mode: OutputMode::Tool,
71            _phantom: PhantomData,
72        }
73    }
74
75    /// Set the tool name.
76    #[must_use]
77    pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
78        self.tool_name = name.into();
79        self
80    }
81
82    /// Set the tool description.
83    #[must_use]
84    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
85        self.tool_description = desc.into();
86        self
87    }
88
89    /// Set strict mode (for OpenAI structured outputs).
90    #[must_use]
91    pub fn with_strict(mut self, strict: bool) -> Self {
92        self.strict = Some(strict);
93        self
94    }
95
96    /// Set the preferred output mode.
97    #[must_use]
98    pub fn with_mode(mut self, mode: OutputMode) -> Self {
99        self.mode = mode;
100        self
101    }
102}
103
104#[async_trait]
105impl<T: DeserializeOwned + Send + Sync> OutputSchema<T> for StructuredOutputSchema<T> {
106    fn mode(&self) -> OutputMode {
107        self.mode
108    }
109
110    fn tool_definitions(&self) -> Vec<ToolDefinition> {
111        vec![ToolDefinition::new(&self.tool_name, &self.tool_description)
112            .with_parameters(self.schema.clone())
113            .with_strict(self.strict.unwrap_or(false))]
114    }
115
116    fn json_schema(&self) -> Option<ObjectJsonSchema> {
117        Some(self.schema.clone())
118    }
119
120    fn parse_text(&self, text: &str) -> Result<T, OutputParseError> {
121        // Try to extract JSON from text (may be wrapped in markdown)
122        let json_str = extract_json(text)?;
123        serde_json::from_str(&json_str).map_err(OutputParseError::JsonParse)
124    }
125
126    fn parse_tool_call(&self, name: &str, args: &JsonValue) -> Result<T, OutputParseError> {
127        if name != self.tool_name {
128            return Err(OutputParseError::unexpected_tool(&self.tool_name, name));
129        }
130        serde_json::from_value(args.clone()).map_err(OutputParseError::JsonParse)
131    }
132
133    fn parse_native(&self, value: &JsonValue) -> Result<T, OutputParseError> {
134        serde_json::from_value(value.clone()).map_err(OutputParseError::JsonParse)
135    }
136}
137
138/// Extract JSON from text that might be wrapped in markdown code blocks.
139///
140/// This function handles common patterns:
141/// - ` ```json ... ``` ` blocks
142/// - ` ``` ... ``` ` blocks (no language)
143/// - Raw JSON objects `{ ... }`
144/// - Raw JSON arrays `[ ... ]`
145pub fn extract_json(text: &str) -> Result<String, OutputParseError> {
146    let text = text.trim();
147
148    // Check for markdown code blocks with json language
149    if let Some(rest) = text.strip_prefix("```json") {
150        if let Some(end) = rest.find("```") {
151            return Ok(rest[..end].trim().to_string());
152        }
153    }
154
155    // Check for markdown code blocks without language
156    if let Some(rest) = text.strip_prefix("```") {
157        // Skip any language identifier on the first line
158        let rest = if let Some(newline) = rest.find('\n') {
159            &rest[newline + 1..]
160        } else {
161            rest
162        };
163        if let Some(end) = rest.find("```") {
164            return Ok(rest[..end].trim().to_string());
165        }
166    }
167
168    // Look for JSON object
169    if let Some(start) = text.find('{') {
170        if let Some(end) = text.rfind('}') {
171            if end > start {
172                let candidate = &text[start..=end];
173                // Validate it's actually JSON
174                if serde_json::from_str::<JsonValue>(candidate).is_ok() {
175                    return Ok(candidate.to_string());
176                }
177            }
178        }
179    }
180
181    // Look for JSON array
182    if let Some(start) = text.find('[') {
183        if let Some(end) = text.rfind(']') {
184            if end > start {
185                let candidate = &text[start..=end];
186                // Validate it's actually JSON
187                if serde_json::from_str::<JsonValue>(candidate).is_ok() {
188                    return Ok(candidate.to_string());
189                }
190            }
191        }
192    }
193
194    // Try parsing the whole thing as JSON
195    if serde_json::from_str::<JsonValue>(text).is_ok() {
196        return Ok(text.to_string());
197    }
198
199    Err(OutputParseError::NoJsonFound)
200}
201
202/// Schema that accepts any valid JSON value.
203#[derive(Debug, Clone, Default)]
204pub struct AnyJsonSchema {
205    tool_name: String,
206    tool_description: String,
207}
208
209impl AnyJsonSchema {
210    /// Create a new any JSON schema.
211    #[must_use]
212    pub fn new() -> Self {
213        Self {
214            tool_name: DEFAULT_OUTPUT_TOOL_NAME.to_string(),
215            tool_description: DEFAULT_OUTPUT_TOOL_DESCRIPTION.to_string(),
216        }
217    }
218
219    /// Set the tool name.
220    #[must_use]
221    pub fn with_tool_name(mut self, name: impl Into<String>) -> Self {
222        self.tool_name = name.into();
223        self
224    }
225}
226
227#[async_trait]
228impl OutputSchema<JsonValue> for AnyJsonSchema {
229    fn mode(&self) -> OutputMode {
230        OutputMode::Tool
231    }
232
233    fn tool_definitions(&self) -> Vec<ToolDefinition> {
234        vec![ToolDefinition::new(&self.tool_name, &self.tool_description)]
235    }
236
237    fn parse_text(&self, text: &str) -> Result<JsonValue, OutputParseError> {
238        let json_str = extract_json(text)?;
239        serde_json::from_str(&json_str).map_err(OutputParseError::JsonParse)
240    }
241
242    fn parse_tool_call(
243        &self,
244        _name: &str,
245        args: &JsonValue,
246    ) -> Result<JsonValue, OutputParseError> {
247        Ok(args.clone())
248    }
249
250    fn parse_native(&self, value: &JsonValue) -> Result<JsonValue, OutputParseError> {
251        Ok(value.clone())
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use serdes_ai_tools::PropertySchema;

    /// Minimal typed payload used to exercise `StructuredOutputSchema`.
    #[derive(Debug, Deserialize, PartialEq)]
    struct Person {
        name: String,
        age: u32,
    }

    /// JSON schema matching `Person`: two required properties.
    fn person_schema() -> ObjectJsonSchema {
        let name = PropertySchema::string("Name").build();
        let age = PropertySchema::integer("Age").build();
        ObjectJsonSchema::new()
            .with_property("name", name, true)
            .with_property("age", age, true)
    }

    /// A `Person` output schema with every setting left at its default.
    fn person_output() -> StructuredOutputSchema<Person> {
        StructuredOutputSchema::new(person_schema())
    }

    #[test]
    fn test_structured_schema_new() {
        let schema = person_output();
        assert_eq!(schema.tool_name, DEFAULT_OUTPUT_TOOL_NAME);
        assert_eq!(schema.mode(), OutputMode::Tool);
    }

    #[test]
    fn test_structured_schema_with_tool_name() {
        let schema = person_output().with_tool_name("submit_person");
        assert_eq!(schema.tool_name, "submit_person");
    }

    #[test]
    fn test_structured_schema_tool_definitions() {
        let defs = person_output()
            .with_tool_name("result")
            .with_description("Submit the person")
            .tool_definitions();

        assert_eq!(defs.len(), 1);
        let def = &defs[0];
        assert_eq!(def.name, "result");
        assert_eq!(def.description, "Submit the person");
    }

    #[test]
    fn test_structured_schema_parse_tool_call() {
        let args = serde_json::json!({"name": "Alice", "age": 30});
        let person = person_output().parse_tool_call("final_result", &args).unwrap();
        assert_eq!(person, Person { name: "Alice".to_owned(), age: 30 });
    }

    #[test]
    fn test_structured_schema_parse_tool_call_wrong_name() {
        let args = serde_json::json!({"name": "Alice", "age": 30});
        assert!(person_output().parse_tool_call("wrong_tool", &args).is_err());
    }

    #[test]
    fn test_structured_schema_parse_native() {
        let value = serde_json::json!({"name": "Bob", "age": 25});
        let person = person_output().parse_native(&value).unwrap();
        assert_eq!(person, Person { name: "Bob".to_owned(), age: 25 });
    }

    #[test]
    fn test_structured_schema_parse_text_raw_json() {
        let text = r#"{"name": "Charlie", "age": 35}"#;
        let person = person_output().parse_text(text).unwrap();
        assert_eq!(person, Person { name: "Charlie".to_owned(), age: 35 });
    }

    #[test]
    fn test_structured_schema_parse_text_markdown() {
        let text = r#"Here is the result:
```json
{"name": "Diana", "age": 28}
```
Done!"#;
        let person = person_output().parse_text(text).unwrap();
        assert_eq!(person, Person { name: "Diana".to_owned(), age: 28 });
    }

    #[test]
    fn test_extract_json_code_block() {
        let text = r#"```json
{"key": "value"}
```"#;
        assert_eq!(extract_json(text).unwrap(), r#"{"key": "value"}"#);
    }

    #[test]
    fn test_extract_json_plain_code_block() {
        let text = r#"```
{"key": "value"}
```"#;
        assert_eq!(extract_json(text).unwrap(), r#"{"key": "value"}"#);
    }

    #[test]
    fn test_extract_json_embedded() {
        let text = r#"The result is: {"x": 1, "y": 2} and that's it."#;
        assert_eq!(extract_json(text).unwrap(), r#"{"x": 1, "y": 2}"#);
    }

    #[test]
    fn test_extract_json_array() {
        let text = "Here are the items: [1, 2, 3]";
        assert_eq!(extract_json(text).unwrap(), "[1, 2, 3]");
    }

    #[test]
    fn test_extract_json_not_found() {
        assert!(extract_json("This is just plain text with no JSON.").is_err());
    }

    #[test]
    fn test_any_json_schema() {
        let value = serde_json::json!({"anything": [1, 2, 3]});
        let parsed = AnyJsonSchema::new().parse_native(&value).unwrap();
        assert_eq!(parsed, value);
    }
}