aidale_core/strategy/
json_output.rs

//! JSON output strategies for different providers.
//!
//! This module defines strategies that handle the differences in JSON output
//! support between providers:
//! - `JsonSchemaStrategy`: for providers that support strict JSON Schema (OpenAI, Anthropic)
//! - `JsonModeStrategy`: for providers that only support basic JSON object mode (DeepSeek)
7
8use crate::error::AiError;
9use crate::types::{ChatCompletionRequest, ContentPart, Message, ResponseFormat, Role};
10
11/// Strategy for handling JSON output in chat completion requests.
12///
13/// Different providers have different capabilities for JSON output:
14/// - Some support strict JSON Schema validation (OpenAI)
15/// - Some only support basic JSON object mode with prompt engineering (DeepSeek)
16pub trait JsonOutputStrategy: Send + Sync {
17    /// Get the strategy name for debugging
18    fn name(&self) -> &str;
19
20    /// Apply this strategy to a chat completion request to enable JSON output.
21    ///
22    /// This method modifies the request to use the appropriate JSON output mode
23    /// for the provider. It may:
24    /// - Set response_format to JsonSchema (for providers that support it)
25    /// - Set response_format to JsonObject and inject schema into prompt (for providers that don't)
26    fn apply(
27        &self,
28        req: &mut ChatCompletionRequest,
29        schema: &serde_json::Value,
30    ) -> Result<(), AiError>;
31}
32
/// Strategy for providers that accept a strict JSON Schema.
///
/// Intended for providers such as OpenAI that understand the
/// `response_format.json_schema` parameter with strict schema validation.
#[derive(Debug, Clone)]
pub struct JsonSchemaStrategy {
    /// Whether strict mode is requested
    pub strict: bool,
}

impl JsonSchemaStrategy {
    /// Builds a strategy with strict mode turned on.
    pub fn new() -> Self {
        Self::with_strict(true)
    }

    /// Builds a strategy whose strict mode is set to `strict`.
    pub fn with_strict(strict: bool) -> Self {
        JsonSchemaStrategy { strict }
    }
}

impl Default for JsonSchemaStrategy {
    /// Same result as [`JsonSchemaStrategy::new`]: strict mode enabled.
    fn default() -> Self {
        JsonSchemaStrategy { strict: true }
    }
}
60
61impl JsonOutputStrategy for JsonSchemaStrategy {
62    fn name(&self) -> &str {
63        "JsonSchemaStrategy"
64    }
65
66    fn apply(
67        &self,
68        req: &mut ChatCompletionRequest,
69        schema: &serde_json::Value,
70    ) -> Result<(), AiError> {
71        // Set response_format to JsonSchema with the provided schema
72        req.response_format = Some(ResponseFormat::JsonSchema {
73            name: "response".to_string(),
74            schema: schema.clone(),
75            strict: self.strict,
76        });
77
78        Ok(())
79    }
80}
81
82/// JSON Mode strategy for providers that only support basic JSON object mode.
83///
84/// This is used for providers like DeepSeek that don't support JSON Schema
85/// but can output JSON objects when instructed via prompt.
86///
87/// This strategy:
88/// 1. Sets response_format to JsonObject
89/// 2. Injects the schema into the system prompt to guide the model
90#[derive(Debug, Clone)]
91pub struct JsonModeStrategy {
92    /// Whether to inject schema as a system message (true) or append to last user message (false)
93    pub use_system_message: bool,
94}
95
96impl JsonModeStrategy {
97    /// Create a new JSON Mode strategy that uses system messages
98    pub fn new() -> Self {
99        Self {
100            use_system_message: true,
101        }
102    }
103
104    /// Create a new JSON Mode strategy with configurable message injection
105    pub fn with_system_message(use_system_message: bool) -> Self {
106        Self { use_system_message }
107    }
108
109    /// Build a JSON instruction from a schema
110    fn build_json_instruction(schema: &serde_json::Value) -> Result<String, AiError> {
111        let schema_str = serde_json::to_string_pretty(schema)?;
112        Ok(format!(
113            "You must respond with valid JSON that matches this schema:\n```json\n{}\n```\n\nIMPORTANT:\n\
114            1. Only return the JSON object, nothing else\n\
115            2. Ensure all required fields are present\n\
116            3. Follow the schema structure exactly\n\
117            4. Use the correct data types for each field",
118            schema_str
119        ))
120    }
121}
122
123impl Default for JsonModeStrategy {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl JsonOutputStrategy for JsonModeStrategy {
130    fn name(&self) -> &str {
131        "JsonModeStrategy"
132    }
133
134    fn apply(
135        &self,
136        req: &mut ChatCompletionRequest,
137        schema: &serde_json::Value,
138    ) -> Result<(), AiError> {
139        // Set response_format to JsonObject (basic JSON mode)
140        req.response_format = Some(ResponseFormat::JsonObject);
141
142        // Build the schema instruction
143        let instruction = Self::build_json_instruction(schema)?;
144
145        // Inject the schema instruction into messages
146        if self.use_system_message {
147            // Add as system message at the beginning
148            let system_msg = Message {
149                role: Role::System,
150                content: vec![ContentPart::Text { text: instruction }],
151                name: None,
152            };
153            req.messages.insert(0, system_msg);
154        } else {
155            // Append to the last user message
156            if let Some(last_msg) = req.messages.iter_mut().rev().find(|m| m.role == Role::User) {
157                last_msg.content.push(ContentPart::Text {
158                    text: format!("\n\n{}", instruction),
159                });
160            } else {
161                // If no user message found, create one with just the instruction
162                let user_msg = Message {
163                    role: Role::User,
164                    content: vec![ContentPart::Text { text: instruction }],
165                    name: None,
166                };
167                req.messages.push(user_msg);
168            }
169        }
170
171        Ok(())
172    }
173}
174
175/// Auto-detect the appropriate JSON output strategy for a provider.
176///
177/// This function returns the recommended strategy based on the provider ID.
178/// In the future, this could be enhanced to query provider capabilities.
179pub fn detect_json_strategy(provider_id: &str) -> Box<dyn JsonOutputStrategy> {
180    match provider_id {
181        // Providers that support JSON Schema
182        "openai" | "anthropic" | "azure" => Box::new(JsonSchemaStrategy::new()),
183
184        // Providers that only support basic JSON mode
185        "deepseek" => Box::new(JsonModeStrategy::new()),
186
187        // Default to JSON Mode for unknown providers (safer fallback)
188        _ => Box::new(JsonModeStrategy::new()),
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal object schema shared by the tests below.
    fn sample_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        })
    }

    #[test]
    fn test_json_schema_strategy() {
        let strategy = JsonSchemaStrategy::new();
        let mut req = ChatCompletionRequest::new("test-model", vec![]);
        let schema = sample_schema();

        strategy.apply(&mut req, &schema).unwrap();

        // The schema strategy must not touch the conversation itself.
        assert!(req.messages.is_empty());

        match req.response_format {
            Some(ResponseFormat::JsonSchema {
                name,
                schema: s,
                strict,
            }) => {
                assert_eq!(name, "response");
                assert_eq!(s, schema);
                assert!(strict);
            }
            _ => panic!("Expected JsonSchema response format"),
        }
    }

    #[test]
    fn test_json_schema_strategy_non_strict() {
        let strategy = JsonSchemaStrategy::with_strict(false);
        let mut req = ChatCompletionRequest::new("test-model", vec![]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        // The configured strict flag must be forwarded unchanged.
        assert!(matches!(
            req.response_format,
            Some(ResponseFormat::JsonSchema { strict: false, .. })
        ));
    }

    #[test]
    fn test_json_mode_strategy() {
        let strategy = JsonModeStrategy::new();
        let mut req = ChatCompletionRequest::new("test-model", vec![Message::user("Hello")]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        // Should have JsonObject response format
        assert!(matches!(
            req.response_format,
            Some(ResponseFormat::JsonObject)
        ));

        // Should have injected system message ahead of the existing user message
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, Role::System);
        assert_eq!(req.messages[1].role, Role::User);
    }

    #[test]
    fn test_json_mode_strategy_appends_to_last_user_message() {
        let strategy = JsonModeStrategy::with_system_message(false);
        let mut req = ChatCompletionRequest::new("test-model", vec![Message::user("Hello")]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        assert!(matches!(
            req.response_format,
            Some(ResponseFormat::JsonObject)
        ));

        // No new message: the instruction is appended to the existing user message
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, Role::User);
        assert!(matches!(
            req.messages[0].content.last(),
            Some(ContentPart::Text { text }) if text.starts_with("\n\n")
        ));
    }

    #[test]
    fn test_json_mode_strategy_without_user_message() {
        let strategy = JsonModeStrategy::with_system_message(false);
        let mut req = ChatCompletionRequest::new("test-model", vec![]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        // With nothing to append to, the instruction becomes its own user message
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, Role::User);
    }

    #[test]
    fn test_detect_json_strategy() {
        // Providers with JSON Schema support should get JsonSchemaStrategy
        for provider in ["openai", "anthropic", "azure"] {
            assert_eq!(detect_json_strategy(provider).name(), "JsonSchemaStrategy");
        }

        // DeepSeek should get JsonModeStrategy
        let deepseek_strategy = detect_json_strategy("deepseek");
        assert_eq!(deepseek_strategy.name(), "JsonModeStrategy");

        // Unknown providers should get JsonModeStrategy (safer fallback)
        let unknown_strategy = detect_json_strategy("unknown");
        assert_eq!(unknown_strategy.name(), "JsonModeStrategy");
    }
}