Skip to main content

serdes_ai_agent/
output.rs

1//! Output validation and parsing.
2//!
3//! This module provides traits and implementations for validating
4//! and transforming agent outputs.
5
6use crate::context::RunContext;
7use crate::errors::{OutputParseError, OutputValidationError};
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde_json::Value as JsonValue;
11use std::any::TypeId;
12use std::marker::PhantomData;
13
/// Trait for validating agent outputs.
///
/// Implementors must be `Send + Sync`. The trait is declared through
/// `#[async_trait]`, which is why `validate` can be written as an `async fn`.
#[async_trait]
pub trait OutputValidator<Output, Deps>: Send + Sync {
    /// Validate and optionally transform the output.
    ///
    /// Takes ownership of `output` and borrows the run context `ctx`.
    ///
    /// Returns the validated output or an error.
    ///
    /// # Errors
    ///
    /// Returns an [`OutputValidationError`] when the output is rejected.
    async fn validate(
        &self,
        output: Output,
        ctx: &RunContext<Deps>,
    ) -> Result<Output, OutputValidationError>;
}
26
27// ============================================================================
28// Function-based Validators
29// ============================================================================
30
31/// Validator that uses an async function.
32pub struct AsyncValidator<F, Deps, Output, Fut>
33where
34    F: Fn(Output, &RunContext<Deps>) -> Fut + Send + Sync,
35    Fut: std::future::Future<Output = Result<Output, OutputValidationError>> + Send,
36{
37    func: F,
38    _phantom: PhantomData<(Deps, Output, Fut)>,
39}
40
41impl<F, Deps, Output, Fut> AsyncValidator<F, Deps, Output, Fut>
42where
43    F: Fn(Output, &RunContext<Deps>) -> Fut + Send + Sync,
44    Fut: std::future::Future<Output = Result<Output, OutputValidationError>> + Send,
45{
46    /// Create a new async validator.
47    pub fn new(func: F) -> Self {
48        Self {
49            func,
50            _phantom: PhantomData,
51        }
52    }
53}
54
55#[async_trait]
56impl<F, Deps, Output, Fut> OutputValidator<Output, Deps> for AsyncValidator<F, Deps, Output, Fut>
57where
58    F: Fn(Output, &RunContext<Deps>) -> Fut + Send + Sync,
59    Fut: std::future::Future<Output = Result<Output, OutputValidationError>> + Send + Sync,
60    Deps: Send + Sync,
61    Output: Send + Sync,
62{
63    async fn validate(
64        &self,
65        output: Output,
66        ctx: &RunContext<Deps>,
67    ) -> Result<Output, OutputValidationError> {
68        (self.func)(output, ctx).await
69    }
70}
71
72/// Validator that uses a sync function.
73pub struct SyncValidator<F, Deps, Output>
74where
75    F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError> + Send + Sync,
76{
77    func: F,
78    _phantom: PhantomData<(Deps, Output)>,
79}
80
81impl<F, Deps, Output> SyncValidator<F, Deps, Output>
82where
83    F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError> + Send + Sync,
84{
85    /// Create a new sync validator.
86    pub fn new(func: F) -> Self {
87        Self {
88            func,
89            _phantom: PhantomData,
90        }
91    }
92}
93
94#[async_trait]
95impl<F, Deps, Output> OutputValidator<Output, Deps> for SyncValidator<F, Deps, Output>
96where
97    F: Fn(Output, &RunContext<Deps>) -> Result<Output, OutputValidationError> + Send + Sync,
98    Deps: Send + Sync,
99    Output: Send + Sync,
100{
101    async fn validate(
102        &self,
103        output: Output,
104        ctx: &RunContext<Deps>,
105    ) -> Result<Output, OutputValidationError> {
106        (self.func)(output, ctx)
107    }
108}
109
110// ============================================================================
111// Common Validators
112// ============================================================================
113
114/// Validator that checks string outputs are not empty.
115pub struct NonEmptyValidator;
116
117#[async_trait]
118impl<Deps: Send + Sync> OutputValidator<String, Deps> for NonEmptyValidator {
119    async fn validate(
120        &self,
121        output: String,
122        _ctx: &RunContext<Deps>,
123    ) -> Result<String, OutputValidationError> {
124        if output.trim().is_empty() {
125            Err(OutputValidationError::failed("Output cannot be empty"))
126        } else {
127            Ok(output)
128        }
129    }
130}
131
/// Validator that checks string length against optional bounds.
///
/// Both bounds start out unset; configure them with [`LengthValidator::min`]
/// and [`LengthValidator::max`].
#[derive(Default)]
pub struct LengthValidator {
    min: Option<usize>,
    max: Option<usize>,
}

impl LengthValidator {
    /// Create a length validator with no bounds set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the minimum length, replacing any previous minimum.
    pub fn min(self, min: usize) -> Self {
        Self {
            min: Some(min),
            ..self
        }
    }

    /// Set the maximum length, replacing any previous maximum.
    pub fn max(self, max: usize) -> Self {
        Self {
            max: Some(max),
            ..self
        }
    }
}
165
166#[async_trait]
167impl<Deps: Send + Sync> OutputValidator<String, Deps> for LengthValidator {
168    async fn validate(
169        &self,
170        output: String,
171        _ctx: &RunContext<Deps>,
172    ) -> Result<String, OutputValidationError> {
173        let len = output.len();
174
175        if let Some(min) = self.min {
176            if len < min {
177                return Err(OutputValidationError::failed(format!(
178                    "Output too short: {} < {}",
179                    len, min
180                )));
181            }
182        }
183
184        if let Some(max) = self.max {
185            if len > max {
186                return Err(OutputValidationError::failed(format!(
187                    "Output too long: {} > {}",
188                    len, max
189                )));
190            }
191        }
192
193        Ok(output)
194    }
195}
196
/// Validator that accepts only strings matched by a regular expression.
#[cfg(feature = "regex")]
pub struct RegexValidator {
    pattern: regex::Regex,
    message: String,
}

#[cfg(feature = "regex")]
impl RegexValidator {
    /// Compile `pattern` and remember `message` as the failure text.
    ///
    /// # Errors
    ///
    /// Returns the `regex::Error` produced when `pattern` does not compile.
    pub fn new(pattern: &str, message: impl Into<String>) -> Result<Self, regex::Error> {
        let pattern = regex::Regex::new(pattern)?;
        let message = message.into();
        Ok(Self { pattern, message })
    }
}

#[cfg(feature = "regex")]
#[async_trait]
impl<Deps: Send + Sync> OutputValidator<String, Deps> for RegexValidator {
    /// Passes `output` through when the pattern matches it, otherwise fails
    /// with the configured message.
    async fn validate(
        &self,
        output: String,
        _ctx: &RunContext<Deps>,
    ) -> Result<String, OutputValidationError> {
        if !self.pattern.is_match(&output) {
            return Err(OutputValidationError::failed(&self.message));
        }
        Ok(output)
    }
}
230
231// ============================================================================
232// Chained Validators
233// ============================================================================
234
235/// Chain multiple validators together.
236pub struct ChainedValidator<Output, Deps> {
237    validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
238}
239
240impl<Output: Send + Sync + 'static, Deps: Send + Sync + 'static> ChainedValidator<Output, Deps> {
241    /// Create a new chained validator.
242    pub fn new() -> Self {
243        Self {
244            validators: Vec::new(),
245        }
246    }
247
248    /// Add a validator.
249    #[allow(clippy::should_implement_trait)]
250    pub fn add<V: OutputValidator<Output, Deps> + 'static>(mut self, validator: V) -> Self {
251        self.validators.push(Box::new(validator));
252        self
253    }
254}
255
256impl<Output: Send + Sync + 'static, Deps: Send + Sync + 'static> Default
257    for ChainedValidator<Output, Deps>
258{
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
264#[async_trait]
265impl<Output: Send + Sync, Deps: Send + Sync> OutputValidator<Output, Deps>
266    for ChainedValidator<Output, Deps>
267{
268    async fn validate(
269        &self,
270        mut output: Output,
271        ctx: &RunContext<Deps>,
272    ) -> Result<Output, OutputValidationError> {
273        for validator in &self.validators {
274            output = validator.validate(output, ctx).await?;
275        }
276        Ok(output)
277    }
278}
279
280// ============================================================================
281// Output Schema
282// ============================================================================
283
/// Output mode for the model.
///
/// Reported by [`OutputSchema::mode`]. [`OutputMode::Text`] is the
/// `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OutputMode {
    /// Plain text output.
    #[default]
    Text,
    /// JSON output.
    Json,
    /// Tool call output.
    ToolCall,
}
295
/// Schema for parsing and validating output.
///
/// Only [`OutputSchema::parse_text`] is required. Every other method has a
/// default that describes a plain-text schema with no JSON schema and no
/// output tool.
pub trait OutputSchema<Output>: Send + Sync {
    /// Get the JSON schema for structured output.
    ///
    /// Defaults to `None`.
    fn json_schema(&self) -> Option<JsonValue> {
        None
    }

    /// Get the output mode.
    ///
    /// Defaults to [`OutputMode::Text`].
    fn mode(&self) -> OutputMode {
        OutputMode::Text
    }

    /// Get the name of the output tool (if using tool mode).
    ///
    /// Defaults to `None`.
    fn tool_name(&self) -> Option<&str> {
        None
    }

    /// Parse text output.
    fn parse_text(&self, text: &str) -> Result<Output, OutputParseError>;

    /// Parse tool call output.
    ///
    /// The default implementation ignores both arguments and always returns
    /// [`OutputParseError::ToolNotCalled`].
    fn parse_tool_call(&self, _name: &str, _args: &JsonValue) -> Result<Output, OutputParseError> {
        Err(OutputParseError::ToolNotCalled)
    }
}
321
322/// Text output schema (returns String).
323#[derive(Debug, Clone, Default)]
324pub struct TextOutputSchema;
325
326impl OutputSchema<String> for TextOutputSchema {
327    fn parse_text(&self, text: &str) -> Result<String, OutputParseError> {
328        Ok(text.to_string())
329    }
330}
331
/// Default output schema (text for String, JSON for others).
///
/// This is a zero-sized marker: it stores no `Output` value, only remembers
/// the type. `Debug`, `Clone` and `Default` are therefore implemented by hand
/// below without any bound on `Output`. `#[derive]` would require
/// `Output: Debug`, `Output: Clone` and `Output: Default` respectively, which
/// would make the schema unusable with output types that lack those traits.
pub struct DefaultOutputSchema<Output> {
    _phantom: PhantomData<Output>,
}

impl<Output> DefaultOutputSchema<Output> {
    /// Create a new default output schema.
    pub fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<Output> std::fmt::Debug for DefaultOutputSchema<Output> {
    /// Produces the same text the derived implementation would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DefaultOutputSchema")
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<Output> Clone for DefaultOutputSchema<Output> {
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl<Output> Default for DefaultOutputSchema<Output> {
    fn default() -> Self {
        Self::new()
    }
}
346
347impl<Output: DeserializeOwned + Send + Sync + 'static> OutputSchema<Output>
348    for DefaultOutputSchema<Output>
349{
350    fn mode(&self) -> OutputMode {
351        if TypeId::of::<Output>() == TypeId::of::<String>() {
352            OutputMode::Text
353        } else {
354            OutputMode::Json
355        }
356    }
357
358    fn parse_text(&self, text: &str) -> Result<Output, OutputParseError> {
359        if TypeId::of::<Output>() == TypeId::of::<String>() {
360            // For String output, use serde_json with Value::String to safely convert
361            serde_json::from_value(serde_json::Value::String(text.to_string()))
362                .map_err(OutputParseError::Json)
363        } else {
364            let json_str = extract_json(text).unwrap_or(text);
365            serde_json::from_str(json_str).map_err(OutputParseError::Json)
366        }
367    }
368}
369
370/// JSON output schema (parses JSON to type).
371pub struct JsonOutputSchema<T> {
372    schema: Option<JsonValue>,
373    _phantom: PhantomData<T>,
374}
375
376impl<T: DeserializeOwned> JsonOutputSchema<T> {
377    /// Create a new JSON output schema.
378    pub fn new() -> Self {
379        Self {
380            schema: None,
381            _phantom: PhantomData,
382        }
383    }
384
385    /// Set the JSON schema.
386    pub fn with_schema(mut self, schema: JsonValue) -> Self {
387        self.schema = Some(schema);
388        self
389    }
390}
391
392impl<T: DeserializeOwned> Default for JsonOutputSchema<T> {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
398impl<T: DeserializeOwned + Send + Sync> OutputSchema<T> for JsonOutputSchema<T> {
399    fn json_schema(&self) -> Option<JsonValue> {
400        self.schema.clone()
401    }
402
403    fn mode(&self) -> OutputMode {
404        OutputMode::Json
405    }
406
407    fn parse_text(&self, text: &str) -> Result<T, OutputParseError> {
408        // Try to extract JSON from the text
409        let json_str = extract_json(text).unwrap_or(text);
410        serde_json::from_str(json_str).map_err(OutputParseError::Json)
411    }
412}
413
414/// Tool-based output schema.
415pub struct ToolOutputSchema<T> {
416    tool_name: String,
417    schema: Option<JsonValue>,
418    _phantom: PhantomData<T>,
419}
420
421impl<T: DeserializeOwned> ToolOutputSchema<T> {
422    /// Create a new tool output schema.
423    pub fn new(tool_name: impl Into<String>) -> Self {
424        Self {
425            tool_name: tool_name.into(),
426            schema: None,
427            _phantom: PhantomData,
428        }
429    }
430
431    /// Set the JSON schema.
432    pub fn with_schema(mut self, schema: JsonValue) -> Self {
433        self.schema = Some(schema);
434        self
435    }
436}
437
438impl<T: DeserializeOwned + Send + Sync> OutputSchema<T> for ToolOutputSchema<T> {
439    fn json_schema(&self) -> Option<JsonValue> {
440        self.schema.clone()
441    }
442
443    fn mode(&self) -> OutputMode {
444        OutputMode::ToolCall
445    }
446
447    fn tool_name(&self) -> Option<&str> {
448        Some(&self.tool_name)
449    }
450
451    fn parse_text(&self, _text: &str) -> Result<T, OutputParseError> {
452        Err(OutputParseError::ToolNotCalled)
453    }
454
455    fn parse_tool_call(&self, name: &str, args: &JsonValue) -> Result<T, OutputParseError> {
456        if name != self.tool_name {
457            return Err(OutputParseError::ToolNotCalled);
458        }
459        serde_json::from_value(args.clone()).map_err(OutputParseError::Json)
460    }
461}
462
/// Extract JSON from text (handles markdown code blocks).
///
/// Candidates are tried in this order and the first hit wins:
///
/// 1. The trimmed contents of the first fenced block opened with "```json".
/// 2. The trimmed contents of the first fenced block of any kind, provided
///    they start with `{` or `[`. Anything on the same line as the opening
///    fence is treated as a language tag and skipped.
/// 3. Raw JSON in the text: the span from the first `{` to the last `}`.
///    When that span sits strictly inside the span from the first `[` to the
///    last `]`, the bracketed span is returned instead, so a top-level array
///    of objects is not cut down to its elements.
///
/// Returns `None` when none of these apply. The returned slice always
/// borrows from `text`; it is not checked for being valid JSON.
fn extract_json(text: &str) -> Option<&str> {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    // 1. Explicitly tagged JSON block.
    if let Some(open) = text.find(JSON_FENCE) {
        let body = &text[open + JSON_FENCE.len()..];
        if let Some(close) = body.find(FENCE) {
            return Some(body[..close].trim());
        }
    }

    // 2. Any fenced block whose contents look like JSON.
    if let Some(open) = text.find(FENCE) {
        let after_fence = &text[open + FENCE.len()..];
        // Skip the language tag line. If the fence is not followed by a
        // newline there is nothing to skip; stepping forward anyway could
        // land past the end of the text or inside a multi-byte character.
        let body = match after_fence.find('\n') {
            Some(newline) => &after_fence[newline + 1..],
            None => after_fence,
        };
        if let Some(close) = body.find(FENCE) {
            let candidate = body[..close].trim();
            if candidate.starts_with('{') || candidate.starts_with('[') {
                return Some(candidate);
            }
        }
    }

    // 3. Raw JSON embedded in the text. The delimiters are ASCII, so the
    // inclusive ranges below always end on a character boundary.
    let span = |opener: char, closer: char| -> Option<(usize, usize)> {
        let start = text.find(opener)?;
        let end = text.rfind(closer)?;
        if end > start {
            Some((start, end))
        } else {
            None
        }
    };

    match (span('{', '}'), span('[', ']')) {
        (Some((obj_start, obj_end)), Some((arr_start, arr_end)))
            if arr_start < obj_start && obj_end < arr_end =>
        {
            Some(&text[arr_start..=arr_end])
        }
        (Some((obj_start, obj_end)), _) => Some(&text[obj_start..=obj_end]),
        _ => None,
    }
}
498
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::sync::Arc;

    /// Builds a throwaway context with unit deps for driving validators.
    fn make_context() -> RunContext<()> {
        RunContext {
            deps: Arc::new(()),
            run_id: String::from("test"),
            start_time: Utc::now(),
            model_name: String::from("test"),
            model_settings: Default::default(),
            tool_name: None,
            tool_call_id: None,
            retry_count: 0,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_non_empty_validator() {
        let ctx = make_context();
        let validator = NonEmptyValidator;

        // Real content passes; empty and whitespace-only strings do not.
        assert!(validator.validate("hello".to_string(), &ctx).await.is_ok());
        assert!(validator.validate("".to_string(), &ctx).await.is_err());
        assert!(validator.validate("   ".to_string(), &ctx).await.is_err());
    }

    #[tokio::test]
    async fn test_length_validator() {
        let ctx = make_context();
        let validator = LengthValidator::new().min(5).max(10);

        // Exactly at the minimum, below the minimum, above the maximum.
        assert!(validator.validate("hello".to_string(), &ctx).await.is_ok());
        assert!(validator.validate("hi".to_string(), &ctx).await.is_err());
        assert!(validator
            .validate("hello world!".to_string(), &ctx)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_chained_validator() {
        let ctx = make_context();
        let validator = ChainedValidator::<String, ()>::new()
            .add(NonEmptyValidator)
            .add(LengthValidator::new().min(3));

        assert!(validator.validate("hello".to_string(), &ctx).await.is_ok());
        // Non-empty, but rejected by the second validator in the chain.
        assert!(validator.validate("hi".to_string(), &ctx).await.is_err());
    }

    #[test]
    fn test_text_output_schema() {
        let parsed = TextOutputSchema.parse_text("hello world");
        assert_eq!(parsed.unwrap(), "hello world");
    }

    #[test]
    fn test_json_output_schema() {
        use serde::Deserialize;

        #[derive(Debug, Deserialize, PartialEq)]
        struct Person {
            name: String,
            age: u32,
        }

        let person = |name: &str, age: u32| Person {
            name: name.to_string(),
            age,
        };
        let schema = JsonOutputSchema::<Person>::new();

        // Plain JSON
        let plain = r#"{"name": "Alice", "age": 30}"#;
        assert_eq!(schema.parse_text(plain).unwrap(), person("Alice", 30));

        // JSON in code block
        let fenced = r#"Here's the person:
```json
{"name": "Bob", "age": 25}
```"#;
        assert_eq!(schema.parse_text(fenced).unwrap(), person("Bob", 25));
    }

    #[test]
    fn test_extract_json() {
        let embedded = "Here's some JSON: {\"a\": 1}";
        assert_eq!(extract_json(embedded), Some("{\"a\": 1}"));

        let fenced = "```json\n{\"a\": 1}\n```";
        assert!(extract_json(fenced).is_some());
    }
}