1use std::fmt;
22
23use serde_json::Value;
24use tracing::{info, warn};
25
26use crate::types::{ChatMessage, ChatRequest, LlmProvider, MessageRole, RunnerError};
27
/// A chat request paired with the JSON Schema that the model's reply must satisfy.
///
/// Consumed by `request_structured_output`, which works on its own copy of `request`; the
/// value passed in is never modified.
#[derive(Debug, Clone)]
pub struct StructuredOutputRequest {
    /// Base request. Its messages are extended with schema instructions (and, on retries, with
    /// the model's previous reply plus corrective feedback) before being sent non-streaming.
    pub request: ChatRequest,
    /// JSON Schema the parsed reply is validated against (only the subset of keywords this
    /// module understands is enforced).
    pub schema: Value,
    /// Number of additional attempts allowed after the first one fails; the total number of
    /// completions requested is at most `max_retries + 1`.
    pub max_retries: u32,
}
38
/// A single schema violation reported by `validate_against_schema`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchemaValidationError {
    /// Human-readable description of what is wrong.
    pub message: String,
    /// JSONPath-like location of the offending value, rooted at `$`
    /// (for example `$.address.city` or `$[1]`).
    pub path: String,
}

impl fmt::Display for SchemaValidationError {
    /// Formats the violation as `<path>: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.path, self.message)
    }
}

// A public error type should be usable with `?`, `Box<dyn Error>` and error-reporting crates.
impl std::error::Error for SchemaValidationError {}
53
54pub async fn request_structured_output(
68 provider: &dyn LlmProvider,
69 structured_request: &StructuredOutputRequest,
70) -> Result<Value, RunnerError> {
71 let schema_str = serde_json::to_string_pretty(&structured_request.schema)
72 .map_err(|e| RunnerError::internal(format!("failed to serialize schema: {e}")))?;
73
74 let schema_instruction = format!(
75 "\n\nYou MUST respond with ONLY valid JSON that conforms to the following JSON Schema. \
76 Do NOT include any explanatory text, markdown formatting, or anything other than the \
77 JSON object.\n\nSchema:\n```json\n{schema_str}\n```"
78 );
79
80 let mut messages = structured_request.request.messages.clone();
81
82 inject_schema_instruction(&mut messages, &schema_instruction);
84
85 let total_attempts = structured_request.max_retries + 1;
86 for attempt in 0..total_attempts {
87 let request = ChatRequest {
88 messages: messages.clone(),
89 model: structured_request.request.model.clone(),
90 temperature: structured_request.request.temperature,
91 max_tokens: structured_request.request.max_tokens,
92 stream: false,
93 tools: structured_request.request.tools.clone(),
94 tool_choice: structured_request.request.tool_choice.clone(),
95 top_p: structured_request.request.top_p,
96 stop: structured_request.request.stop.clone(),
97 response_format: structured_request.request.response_format.clone(),
98 turn_id: structured_request.request.turn_id,
99 };
100
101 let response = provider.complete(&request).await?;
102
103 let json_str = extract_json_from_response(&response.content);
105
106 let parsed: Value = match serde_json::from_str(&json_str) {
107 Ok(v) => v,
108 Err(parse_err) => {
109 warn!(
110 attempt,
111 error = %parse_err,
112 "structured output: failed to parse JSON from response"
113 );
114 if attempt < structured_request.max_retries {
115 messages.push(ChatMessage::assistant(response.content.clone()));
116 messages.push(ChatMessage::user(format!(
117 "Your response was not valid JSON: {parse_err}. \
118 Please respond with ONLY a valid JSON object matching the schema."
119 )));
120 }
121 continue;
122 }
123 };
124
125 let errors = validate_against_schema(&parsed, &structured_request.schema);
126
127 if errors.is_empty() {
128 info!(attempt, "structured output: validation passed");
129 return Ok(parsed);
130 }
131
132 warn!(
133 attempt,
134 error_count = errors.len(),
135 "structured output: schema validation failed"
136 );
137
138 if attempt < structured_request.max_retries {
139 let error_feedback: Vec<String> = errors.iter().map(ToString::to_string).collect();
140 messages.push(ChatMessage::assistant(response.content.clone()));
141 messages.push(ChatMessage::user(format!(
142 "Your JSON response had validation errors:\n- {}\n\
143 Please fix these and respond with ONLY a valid JSON object.",
144 error_feedback.join("\n- ")
145 )));
146 }
147 }
148
149 Err(RunnerError::external_service(
150 provider.name(),
151 "structured output validation exhausted after all retries",
152 ))
153}
154
155fn inject_schema_instruction(messages: &mut Vec<ChatMessage>, instruction: &str) {
157 if let Some(first) = messages.first_mut() {
158 if first.role == MessageRole::System {
159 let augmented = format!("{}{instruction}", first.content);
160 *first = ChatMessage::system(augmented);
161 return;
162 }
163 }
164 messages.insert(0, ChatMessage::system(instruction.to_owned()));
165}
166
/// Extracts the JSON payload (an object or an array) from a raw model reply.
///
/// Candidates are tried in order:
/// 1. the reply itself, when it already starts with `{` or `[` (so code fences that appear
///    inside JSON string values are left alone);
/// 2. the body of the first Markdown code fence, when that body starts with `{` or `[`;
/// 3. the first JSON object or array found anywhere in the text.
///
/// A found value is cut just after its matching closing bracket so trailing prose is dropped.
/// If nothing looks like JSON, the trimmed reply is returned unchanged so the caller's parser
/// can report a meaningful error.
pub fn extract_json_from_response(content: &str) -> String {
    let trimmed = content.trim();

    if opens_json(trimmed) {
        if let Some(json) = locate_json_value(trimmed) {
            return json;
        }
    }

    if let Some(start) = trimmed.find("```") {
        let after_fence = &trimmed[start + 3..];
        // Skip the info string (e.g. "json") on the opening fence line.
        let content_start = after_fence.find('\n').map_or(0, |pos| pos + 1);
        let fence_content = &after_fence[content_start..];

        if let Some(end) = fence_content.find("```") {
            let inside = fence_content[..end].trim();
            if opens_json(inside) {
                if let Some(json) = locate_json_value(inside) {
                    return json;
                }
            }
        }
    }

    locate_json_value(trimmed).unwrap_or_else(|| trimmed.to_owned())
}

/// Returns true when `text` begins with a JSON object or array opener.
fn opens_json(text: &str) -> bool {
    text.starts_with('{') || text.starts_with('[')
}

/// Finds the first JSON object or array in `text` and returns it, cut at its matching closer.
///
/// Objects are preferred. A `[` that appears before the first `{` is used only when its bracketed
/// span is balanced and encloses that `{` (e.g. `[{"a":1},{"a":2}]`), or when the text contains
/// no `{` at all. This keeps prose such as `see [1]: {...}` from hijacking the object after it.
/// Returns `None` when `text` contains neither `{` nor `[`.
fn locate_json_value(text: &str) -> Option<String> {
    let object_start = text.find('{');
    let array_start = text.find('[');

    let start = match (object_start, array_start) {
        (Some(obj), Some(arr)) if arr < obj => {
            let encloses_object =
                matches!(balanced_span_len(&text[arr..]), Some(len) if arr + len > obj);
            if encloses_object {
                arr
            } else {
                obj
            }
        }
        (Some(obj), _) => obj,
        (None, Some(arr)) => arr,
        (None, None) => return None,
    };

    Some(extract_braced_json(&text[start..]))
}

/// Returns the byte length of the balanced bracket span at the start of `text`, or `None` if the
/// brackets never balance.
///
/// `text` is expected to begin with `{` or `[`. Both bracket kinds share one depth counter, which
/// is sufficient for well-formed JSON. Brackets inside string literals (including strings with
/// escaped quotes) are ignored.
fn balanced_span_len(text: &str) -> Option<usize> {
    let mut depth: i32 = 0;
    let mut in_string = false;
    let mut escape_next = false;

    for (i, ch) in text.char_indices() {
        if escape_next {
            escape_next = false;
            continue;
        }

        match ch {
            '\\' if in_string => escape_next = true,
            '"' => in_string = !in_string,
            '{' | '[' if !in_string => depth += 1,
            '}' | ']' if !in_string => {
                depth -= 1;
                if depth == 0 {
                    return Some(i + ch.len_utf8());
                }
            }
            _ => {}
        }
    }

    None
}

/// Returns `text` cut just after the bracket that closes its leading `{` or `[`.
///
/// When the brackets never balance, all of `text` is returned so the caller's parser reports the
/// problem.
fn extract_braced_json(text: &str) -> String {
    let end = balanced_span_len(text).unwrap_or(text.len());
    text[..end].to_owned()
}
231
232pub fn validate_against_schema(value: &Value, schema: &Value) -> Vec<SchemaValidationError> {
237 let mut errors = Vec::new();
238 validate_value(value, schema, "$", &mut errors);
239 errors
240}
241
242fn validate_value(
243 value: &Value,
244 schema: &Value,
245 path: &str,
246 errors: &mut Vec<SchemaValidationError>,
247) {
248 if let Some(expected_type) = schema.get("type").and_then(Value::as_str) {
250 let actual_type = json_type_name(value);
251 if actual_type != expected_type {
252 errors.push(SchemaValidationError {
253 message: format!("expected type \"{expected_type}\", got \"{actual_type}\""),
254 path: path.to_owned(),
255 });
256 return;
257 }
258 }
259
260 if let Some(enum_values) = schema.get("enum").and_then(Value::as_array) {
262 if !enum_values.contains(value) {
263 errors.push(SchemaValidationError {
264 message: format!("value not in enum: expected one of {enum_values:?}, got {value}"),
265 path: path.to_owned(),
266 });
267 return;
268 }
269 }
270
271 if let Some(num) = value.as_f64() {
273 if let Some(min) = schema.get("minimum").and_then(Value::as_f64) {
274 if num < min {
275 errors.push(SchemaValidationError {
276 message: format!("value {num} is less than minimum {min}"),
277 path: path.to_owned(),
278 });
279 }
280 }
281 if let Some(max) = schema.get("maximum").and_then(Value::as_f64) {
282 if num > max {
283 errors.push(SchemaValidationError {
284 message: format!("value {num} exceeds maximum {max}"),
285 path: path.to_owned(),
286 });
287 }
288 }
289 }
290
291 if let Some(obj) = value.as_object() {
293 if let Some(required) = schema.get("required").and_then(Value::as_array) {
294 for req in required {
295 if let Some(field_name) = req.as_str() {
296 if !obj.contains_key(field_name) {
297 errors.push(SchemaValidationError {
298 message: format!("missing required field \"{field_name}\""),
299 path: format!("{path}.{field_name}"),
300 });
301 }
302 }
303 }
304 }
305
306 if let Some(properties) = schema.get("properties").and_then(Value::as_object) {
307 for (prop_name, prop_schema) in properties {
308 if let Some(prop_value) = obj.get(prop_name) {
309 let prop_path = format!("{path}.{prop_name}");
310 validate_value(prop_value, prop_schema, &prop_path, errors);
312 }
313 }
314
315 if schema.get("additionalProperties") == Some(&Value::Bool(false)) {
317 for key in obj.keys() {
318 if !properties.contains_key(key) {
319 errors.push(SchemaValidationError {
320 message: format!("unexpected additional property \"{key}\""),
321 path: format!("{path}.{key}"),
322 });
323 }
324 }
325 }
326 }
327 }
328
329 if let Some(arr) = value.as_array() {
331 if let Some(items_schema) = schema.get("items") {
332 for (i, item) in arr.iter().enumerate() {
333 let item_path = format!("{path}[{i}]");
334 validate_value(item, items_schema, &item_path, errors);
335 }
336 }
337 }
338}
339
340fn json_type_name(value: &Value) -> &'static str {
342 match value {
343 Value::Null => "null",
344 Value::Bool(_) => "boolean",
345 Value::Number(n) => {
346 if n.is_i64() || n.is_u64() {
347 "integer"
348 } else {
349 "number"
350 }
351 }
352 Value::String(_) => "string",
353 Value::Array(_) => "array",
354 Value::Object(_) => "object",
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{
        ChatMessage, ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider,
        RunnerError,
    };
    use async_trait::async_trait;
    use serde_json::json;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// Scripted provider: replays canned results in order and counts `complete` calls.
    struct TestProvider {
        responses: Mutex<VecDeque<Result<ChatResponse, RunnerError>>>,
        call_count: AtomicU32,
    }

    impl TestProvider {
        fn new(responses: Vec<Result<ChatResponse, RunnerError>>) -> Self {
            Self {
                responses: Mutex::new(responses.into()),
                call_count: AtomicU32::new(0),
            }
        }

        /// Number of times `complete` has been called so far.
        fn calls(&self) -> u32 {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        fn name(&self) -> &'static str {
            "test"
        }
        fn display_name(&self) -> &str {
            "Test Provider"
        }
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities::text_only()
        }
        fn default_model(&self) -> &'static str {
            "test-model"
        }
        fn available_models(&self) -> &[String] {
            &[]
        }
        async fn complete(&self, _request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            self.responses
                .lock()
                .expect("test lock")
                .pop_front()
                .unwrap_or_else(|| Err(RunnerError::internal("no more test responses")))
        }
        async fn complete_stream(&self, _request: &ChatRequest) -> Result<ChatStream, RunnerError> {
            Err(RunnerError::internal("not supported"))
        }
        async fn health_check(&self) -> Result<bool, RunnerError> {
            Ok(true)
        }
    }

    fn make_response(content: &str) -> ChatResponse {
        ChatResponse {
            content: content.to_owned(),
            model: "test-model".to_owned(),
            usage: None,
            finish_reason: Some("stop".to_owned()),
            warnings: None,
            tool_calls: None,
        }
    }

    /// Schema requiring a string `name` and an integer `age`.
    fn person_schema() -> Value {
        json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name", "age"]
        })
    }

    fn structured_request(schema: Value, max_retries: u32) -> StructuredOutputRequest {
        StructuredOutputRequest {
            request: ChatRequest::new(vec![ChatMessage::user("give me data")]),
            schema,
            max_retries,
        }
    }

    #[test]
    fn validate_valid_object() {
        let value = json!({"name": "Alice", "age": 30});
        let errors = validate_against_schema(&value, &person_schema());
        assert!(errors.is_empty());
    }

    #[test]
    fn validate_missing_required_fields() {
        let value = json!({"name": "Alice"});
        let errors = validate_against_schema(&value, &person_schema());
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("age"));
    }

    #[test]
    fn validate_wrong_types() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"}
            },
            "required": ["name"]
        });

        let value = json!({"name": 42, "age": "not a number"});
        let errors = validate_against_schema(&value, &schema);
        assert_eq!(errors.len(), 2);
    }

    #[test]
    fn validate_wrong_root_type() {
        let schema = json!({"type": "object"});
        let value = json!("just a string");
        let errors = validate_against_schema(&value, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("expected type \"object\""));
    }

    #[test]
    fn extract_raw_json() {
        let content = r#"{"name": "Alice", "age": 30}"#;
        let extracted = extract_json_from_response(content);
        let parsed: Value = serde_json::from_str(&extracted).expect("valid JSON");
        assert_eq!(parsed["name"], "Alice");
    }

    #[test]
    fn extract_json_from_markdown_fences() {
        let content = "Here is the result:\n```json\n{\"name\": \"Bob\", \"age\": 25}\n```\nDone.";
        let extracted = extract_json_from_response(content);
        let parsed: Value = serde_json::from_str(&extracted).expect("valid JSON");
        assert_eq!(parsed["name"], "Bob");
    }

    #[test]
    fn extract_json_with_nested_braces() {
        let content = r#"{"outer": {"inner": "value"}, "list": [1, 2]}"#;
        let extracted = extract_json_from_response(content);
        let parsed: Value = serde_json::from_str(&extracted).expect("valid JSON");
        assert_eq!(parsed["outer"]["inner"], "value");
    }

    #[tokio::test]
    async fn full_retry_loop_eventual_success() {
        let provider = TestProvider::new(vec![
            Ok(make_response("not json at all")),
            Ok(make_response(r#"{"name": "Alice", "age": 30}"#)),
        ]);
        let structured = structured_request(person_schema(), 2);

        let result = request_structured_output(&provider, &structured)
            .await
            .expect("should succeed on retry");
        assert_eq!(result["name"], "Alice");
        assert_eq!(result["age"], 30);
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn schema_violation_is_retried_until_valid() {
        let provider = TestProvider::new(vec![
            Ok(make_response(r#"{"name": "Alice"}"#)),
            Ok(make_response(r#"{"name": "Alice", "age": 30}"#)),
        ]);
        let structured = structured_request(person_schema(), 1);

        let result = request_structured_output(&provider, &structured)
            .await
            .expect("should succeed after schema feedback");
        assert_eq!(result["age"], 30);
        assert_eq!(provider.calls(), 2);
    }

    #[tokio::test]
    async fn exhaustion_returns_error() {
        let provider = TestProvider::new(vec![
            Ok(make_response("garbage")),
            Ok(make_response("still garbage")),
            Ok(make_response("nope")),
        ]);
        let schema = json!({
            "type": "object",
            "required": ["name"]
        });
        let structured = structured_request(schema, 2);

        let result = request_structured_output(&provider, &structured).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("exhausted"));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(provider.calls(), 3);
    }

    #[tokio::test]
    async fn zero_retries_makes_single_attempt() {
        let provider = TestProvider::new(vec![Ok(make_response("garbage"))]);
        let structured = structured_request(person_schema(), 0);

        let result = request_structured_output(&provider, &structured).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().message.contains("exhausted"));
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn provider_error_aborts_without_retry() {
        let provider = TestProvider::new(vec![Err(RunnerError::internal("provider down"))]);
        let structured = structured_request(person_schema(), 2);

        let result = request_structured_output(&provider, &structured).await;
        assert!(result.is_err());
        assert_eq!(provider.calls(), 1);
    }

    #[test]
    fn validate_nested_object() {
        let schema = json!({
            "type": "object",
            "properties": {
                "address": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string"},
                        "zip": {"type": "string"}
                    },
                    "required": ["city"]
                }
            },
            "required": ["address"]
        });

        let valid = json!({"address": {"city": "Paris", "zip": "75001"}});
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let missing_city = json!({"address": {"zip": "75001"}});
        let errors = validate_against_schema(&missing_city, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].path.contains("city"));

        let wrong_type = json!({"address": {"city": 42}});
        let errors = validate_against_schema(&wrong_type, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("expected type \"string\""));
    }

    #[test]
    fn validate_array_items() {
        let schema = json!({
            "type": "array",
            "items": {"type": "string"}
        });

        let valid = json!(["a", "b", "c"]);
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let invalid = json!(["a", 42, "c"]);
        let errors = validate_against_schema(&invalid, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].path.contains("[1]"));
    }

    #[test]
    fn validate_enum_values() {
        let schema = json!({
            "type": "string",
            "enum": ["red", "green", "blue"]
        });

        let valid = json!("green");
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let invalid = json!("yellow");
        let errors = validate_against_schema(&invalid, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("not in enum"));
    }

    #[test]
    fn validate_numeric_bounds() {
        let schema = json!({
            "type": "integer",
            "minimum": 0,
            "maximum": 100
        });

        let valid = json!(50);
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let too_low = json!(-1);
        let errors = validate_against_schema(&too_low, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("less than minimum"));

        let too_high = json!(101);
        let errors = validate_against_schema(&too_high, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("exceeds maximum"));
    }

    #[test]
    fn validate_additional_properties_false() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let valid = json!({"name": "Alice"});
        assert!(validate_against_schema(&valid, &schema).is_empty());

        let with_extra = json!({"name": "Alice", "age": 30});
        let errors = validate_against_schema(&with_extra, &schema);
        assert_eq!(errors.len(), 1);
        assert!(errors[0].message.contains("unexpected additional property"));
    }
}