aidale_core/strategy/
json_output.rs1use crate::error::AiError;
9use crate::types::{ChatCompletionRequest, ContentPart, Message, ResponseFormat, Role};
10
11pub trait JsonOutputStrategy: Send + Sync {
17 fn name(&self) -> &str;
19
20 fn apply(
27 &self,
28 req: &mut ChatCompletionRequest,
29 schema: &serde_json::Value,
30 ) -> Result<(), AiError>;
31}
32
/// Strategy that requests structured output through [`ResponseFormat::JsonSchema`].
///
/// The schema itself is handed to the provider as part of the response format instead of being
/// described in the prompt, so no messages are added to the request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct JsonSchemaStrategy {
    /// Value forwarded as the `strict` flag of [`ResponseFormat::JsonSchema`].
    pub strict: bool,
}

impl JsonSchemaStrategy {
    /// Creates a strategy with `strict` enabled.
    #[must_use]
    pub const fn new() -> Self {
        Self { strict: true }
    }

    /// Creates a strategy with an explicit `strict` setting.
    #[must_use]
    pub const fn with_strict(strict: bool) -> Self {
        Self { strict }
    }
}

impl Default for JsonSchemaStrategy {
    /// Same as [`JsonSchemaStrategy::new`], i.e. `strict == true`.
    fn default() -> Self {
        Self::new()
    }
}
60
61impl JsonOutputStrategy for JsonSchemaStrategy {
62 fn name(&self) -> &str {
63 "JsonSchemaStrategy"
64 }
65
66 fn apply(
67 &self,
68 req: &mut ChatCompletionRequest,
69 schema: &serde_json::Value,
70 ) -> Result<(), AiError> {
71 req.response_format = Some(ResponseFormat::JsonSchema {
73 name: "response".to_string(),
74 schema: schema.clone(),
75 strict: self.strict,
76 });
77
78 Ok(())
79 }
80}
81
82#[derive(Debug, Clone)]
91pub struct JsonModeStrategy {
92 pub use_system_message: bool,
94}
95
96impl JsonModeStrategy {
97 pub fn new() -> Self {
99 Self {
100 use_system_message: true,
101 }
102 }
103
104 pub fn with_system_message(use_system_message: bool) -> Self {
106 Self { use_system_message }
107 }
108
109 fn build_json_instruction(schema: &serde_json::Value) -> Result<String, AiError> {
111 let schema_str = serde_json::to_string_pretty(schema)?;
112 Ok(format!(
113 "You must respond with valid JSON that matches this schema:\n```json\n{}\n```\n\nIMPORTANT:\n\
114 1. Only return the JSON object, nothing else\n\
115 2. Ensure all required fields are present\n\
116 3. Follow the schema structure exactly\n\
117 4. Use the correct data types for each field",
118 schema_str
119 ))
120 }
121}
122
123impl Default for JsonModeStrategy {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl JsonOutputStrategy for JsonModeStrategy {
130 fn name(&self) -> &str {
131 "JsonModeStrategy"
132 }
133
134 fn apply(
135 &self,
136 req: &mut ChatCompletionRequest,
137 schema: &serde_json::Value,
138 ) -> Result<(), AiError> {
139 req.response_format = Some(ResponseFormat::JsonObject);
141
142 let instruction = Self::build_json_instruction(schema)?;
144
145 if self.use_system_message {
147 let system_msg = Message {
149 role: Role::System,
150 content: vec![ContentPart::Text { text: instruction }],
151 name: None,
152 };
153 req.messages.insert(0, system_msg);
154 } else {
155 if let Some(last_msg) = req.messages.iter_mut().rev().find(|m| m.role == Role::User) {
157 last_msg.content.push(ContentPart::Text {
158 text: format!("\n\n{}", instruction),
159 });
160 } else {
161 let user_msg = Message {
163 role: Role::User,
164 content: vec![ContentPart::Text { text: instruction }],
165 name: None,
166 };
167 req.messages.push(user_msg);
168 }
169 }
170
171 Ok(())
172 }
173}
174
175pub fn detect_json_strategy(provider_id: &str) -> Box<dyn JsonOutputStrategy> {
180 match provider_id {
181 "openai" | "anthropic" | "azure" => Box::new(JsonSchemaStrategy::new()),
183
184 "deepseek" => Box::new(JsonModeStrategy::new()),
186
187 _ => Box::new(JsonModeStrategy::new()),
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Small object schema shared by the tests below.
    fn sample_schema() -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        })
    }

    #[test]
    fn test_json_schema_strategy() {
        let strategy = JsonSchemaStrategy::new();
        let mut req = ChatCompletionRequest::new("test-model", vec![]);
        let schema = sample_schema();

        strategy.apply(&mut req, &schema).unwrap();

        match req.response_format {
            Some(ResponseFormat::JsonSchema {
                name: _,
                schema: s,
                strict,
            }) => {
                assert_eq!(s, schema);
                assert!(strict);
            }
            _ => panic!("Expected JsonSchema response format"),
        }
    }

    #[test]
    fn test_json_schema_strategy_non_strict() {
        let strategy = JsonSchemaStrategy::with_strict(false);
        let mut req = ChatCompletionRequest::new("test-model", vec![]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        match req.response_format {
            Some(ResponseFormat::JsonSchema { name, strict, .. }) => {
                assert_eq!(name, "response");
                assert!(!strict);
            }
            _ => panic!("Expected JsonSchema response format"),
        }
    }

    #[test]
    fn test_json_mode_strategy() {
        let strategy = JsonModeStrategy::new();
        let mut req = ChatCompletionRequest::new("test-model", vec![Message::user("Hello")]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        assert!(matches!(
            req.response_format,
            Some(ResponseFormat::JsonObject)
        ));

        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].role, Role::System);
    }

    #[test]
    fn test_json_mode_strategy_appends_to_last_user_message() {
        let strategy = JsonModeStrategy::with_system_message(false);
        let mut req = ChatCompletionRequest::new("test-model", vec![Message::user("Hello")]);
        let parts_before = req.messages[0].content.len();

        strategy.apply(&mut req, &sample_schema()).unwrap();

        assert!(matches!(
            req.response_format,
            Some(ResponseFormat::JsonObject)
        ));
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.messages[0].content.len(), parts_before + 1);
        match req.messages[0].content.last() {
            Some(ContentPart::Text { text }) => {
                assert!(text.starts_with("\n\n"));
                assert!(text.contains("valid JSON"));
            }
            _ => panic!("Expected an appended text part"),
        }
    }

    #[test]
    fn test_json_mode_strategy_pushes_user_message_when_none_exists() {
        let strategy = JsonModeStrategy::with_system_message(false);
        let mut req = ChatCompletionRequest::new("test-model", vec![]);

        strategy.apply(&mut req, &sample_schema()).unwrap();

        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, Role::User);
    }

    #[test]
    fn test_default_constructors() {
        assert!(JsonSchemaStrategy::default().strict);
        assert!(JsonModeStrategy::default().use_system_message);
    }

    #[test]
    fn test_detect_json_strategy() {
        for provider in ["openai", "anthropic", "azure"] {
            assert_eq!(
                detect_json_strategy(provider).name(),
                "JsonSchemaStrategy",
                "provider: {}",
                provider
            );
        }

        let deepseek_strategy = detect_json_strategy("deepseek");
        assert_eq!(deepseek_strategy.name(), "JsonModeStrategy");

        let unknown_strategy = detect_json_strategy("unknown");
        assert_eq!(unknown_strategy.name(), "JsonModeStrategy");
    }
}