1use super::{LlmClient, Message, StreamEvent, TokenUsage, ToolDefinition};
9use anyhow::{bail, Context, Result};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use tokio_util::sync::CancellationToken;
13
/// Strategy used to obtain schema-conforming output from the model.
///
/// Serialized in snake_case (e.g. `"tool"`, `"prompt"`). Within this module only `Tool`
/// and `Prompt` receive mode-specific handling. The other variants take the default
/// branches: the user prompt is sent unchanged, the caller's system prompt is used as-is,
/// and no tools are registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StructuredMode {
    /// NOTE(review): no special handling in this module; presumably means "pick the best
    /// strategy for the provider" — confirm where it is resolved.
    Auto,
    /// NOTE(review): no special handling in this module; presumably provider-enforced
    /// strict schema output — confirm.
    Strict,
    /// NOTE(review): no special handling in this module; presumably provider JSON mode —
    /// confirm.
    Json,
    /// Expose the schema as a single `emit_<schema_name>` tool, instruct the model to call
    /// it exactly once, and read the result from the tool call's arguments.
    Tool,
    /// Append the pretty-printed schema to the user prompt with a JSON-only instruction and
    /// parse the JSON out of the model's text reply.
    Prompt,
}
34
/// Input to `generate_blocking` / `generate_streaming`.
#[derive(Debug, Clone)]
pub struct StructuredRequest {
    /// The user prompt. In `Prompt` mode the schema and a JSON-only instruction are
    /// appended to it.
    pub prompt: String,
    /// Optional caller system prompt. In `Tool` and `Prompt` modes the mode-specific
    /// instruction is appended after it (separated by a blank line).
    pub system: Option<String>,
    /// JSON Schema the output is checked against with this module's lightweight
    /// validator; also sent as the tool's parameters in `Tool` mode.
    pub schema: Value,
    /// Used to name the `emit_<schema_name>` tool in `Tool` mode.
    pub schema_name: String,
    /// Tool description in `Tool` mode; defaults to
    /// "Emit a structured <schema_name> object" when `None`.
    pub schema_description: Option<String>,
    /// Output strategy.
    pub mode: StructuredMode,
    /// Maximum number of repair round-trips `generate_blocking` performs after a parse or
    /// validation failure. The streaming path performs no repairs.
    pub max_repair_attempts: u8,
}
46
/// Outcome of a successful structured generation.
// Field order is also the key order of the serialized form.
#[derive(Debug, Clone, Serialize)]
pub struct StructuredResult {
    /// Extracted JSON value that passed this module's basic schema check.
    pub object: Value,
    /// Raw model output the object was extracted from (re-serialized tool-call arguments
    /// in `Tool` mode, response text otherwise). Always `Some` in this module.
    pub raw_text: Option<String>,
    /// Token usage. `generate_blocking` sums the prompt/completion/total counts of every
    /// call, including repair rounds.
    pub usage: TokenUsage,
    /// How many repair round-trips were needed (always 0 for the streaming path).
    pub repair_rounds: u8,
    /// Mode that produced the output (in this module, always the requested mode).
    pub mode_used: StructuredMode,
}
56
/// Callback receiving best-effort partial objects during `generate_streaming`, and finally
/// the complete validated object.
pub type PartialObjectCallback = Box<dyn Fn(&Value) + Send>;
59
60pub async fn generate_blocking(
69 client: &dyn LlmClient,
70 req: &StructuredRequest,
71) -> Result<StructuredResult> {
72 let mode = req.mode;
73 let mut messages = build_initial_messages(req, mode);
74 let system = build_system_prompt(req, mode);
75 let tools = build_tools(req, mode);
76
77 let mut total_usage = TokenUsage::default();
78 let mut repair_rounds: u8 = 0;
79
80 loop {
81 let resp = client
82 .complete(&messages, Some(&system), &tools)
83 .await
84 .context("LLM call failed during structured generation")?;
85
86 accumulate_usage(&mut total_usage, &resp.usage);
87
88 let raw_text = extract_raw_output(&resp.message, mode);
89 let parsed = extract_json_value(&raw_text);
90
91 match parsed {
92 Ok(value) => match validate_against_schema(&value, &req.schema) {
93 Ok(()) => {
94 return Ok(StructuredResult {
95 object: value,
96 raw_text: Some(raw_text),
97 usage: total_usage,
98 repair_rounds,
99 mode_used: mode,
100 });
101 }
102 Err(errors) if repair_rounds < req.max_repair_attempts => {
103 repair_rounds += 1;
104 let repair_msg = build_repair_message(&raw_text, &errors);
105 append_repair_context(
106 &mut messages,
107 &resp.message,
108 &repair_msg,
109 mode,
110 &raw_text,
111 );
112 }
113 Err(errors) => {
114 bail!(
115 "Structured output failed schema validation after {} repair attempts. Errors: {}",
116 repair_rounds,
117 errors.join("; ")
118 );
119 }
120 },
121 Err(parse_err) if repair_rounds < req.max_repair_attempts => {
122 repair_rounds += 1;
123 let repair_msg = format!(
124 "Your previous output could not be parsed as JSON:\n\n{}\n\nError: {}\n\nPlease return ONLY a valid JSON object matching the schema.",
125 raw_text, parse_err
126 );
127 append_repair_context(&mut messages, &resp.message, &repair_msg, mode, &raw_text);
128 }
129 Err(parse_err) => {
130 bail!(
131 "Structured output failed JSON parsing after {} repair attempts: {}",
132 repair_rounds,
133 parse_err
134 );
135 }
136 }
137 }
138}
139
140pub async fn generate_streaming(
152 client: &dyn LlmClient,
153 req: &StructuredRequest,
154 on_partial: PartialObjectCallback,
155) -> Result<StructuredResult> {
156 let mode = req.mode;
157 let messages = build_initial_messages(req, mode);
158 let system = build_system_prompt(req, mode);
159 let tools = build_tools(req, mode);
160
161 let cancel_token = CancellationToken::new();
162 let mut rx = client
163 .complete_streaming(&messages, Some(&system), &tools, cancel_token)
164 .await
165 .context("LLM streaming call failed during structured generation")?;
166
167 let mut json_buffer = String::new();
168 let mut last_valid_partial: Option<Value> = None;
169 let mut final_response: Option<super::LlmResponse> = None;
170 let mut last_parse_len: usize = 0;
171 const PARSE_THRESHOLD: usize = 8;
173
174 while let Some(event) = rx.recv().await {
175 match event {
176 StreamEvent::ToolUseInputDelta(delta) if mode == StructuredMode::Tool => {
177 if final_response.is_some() {
178 continue;
179 }
180 json_buffer.push_str(&delta);
181 if json_buffer.len() - last_parse_len >= PARSE_THRESHOLD {
182 if let Some(partial) = try_parse_partial_json(&json_buffer) {
183 if last_valid_partial.as_ref() != Some(&partial) {
184 on_partial(&partial);
185 last_valid_partial = Some(partial);
186 }
187 }
188 last_parse_len = json_buffer.len();
189 }
190 }
191 StreamEvent::TextDelta(delta) if mode != StructuredMode::Tool => {
192 if final_response.is_some() {
193 continue;
194 }
195 json_buffer.push_str(&delta);
196 if json_buffer.len() - last_parse_len >= PARSE_THRESHOLD {
197 if let Some(json_start) = find_json_start(&json_buffer) {
198 let candidate = &json_buffer[json_start..];
199 if let Some(partial) = try_parse_partial_json(candidate) {
200 if last_valid_partial.as_ref() != Some(&partial) {
201 on_partial(&partial);
202 last_valid_partial = Some(partial);
203 }
204 }
205 }
206 last_parse_len = json_buffer.len();
207 }
208 }
209 StreamEvent::Done(resp) => {
210 final_response = Some(resp);
211 }
212 _ => {}
213 }
214 }
215
216 let resp = final_response.context("Stream ended without Done event")?;
217 let raw_text = extract_raw_output(&resp.message, mode);
218 let value =
219 extract_json_value(&raw_text).context("Failed to parse final streamed output as JSON")?;
220
221 validate_against_schema(&value, &req.schema).map_err(|errors| {
222 anyhow::anyhow!(
223 "Streamed structured output failed schema validation: {}",
224 errors.join("; ")
225 )
226 })?;
227
228 on_partial(&value);
230
231 Ok(StructuredResult {
232 object: value,
233 raw_text: Some(raw_text),
234 usage: resp.usage,
235 repair_rounds: 0,
236 mode_used: mode,
237 })
238}
239
240pub fn extract_json_value(text: &str) -> Result<Value> {
248 let trimmed = text.trim();
249
250 if let Ok(v) = serde_json::from_str::<Value>(trimmed) {
252 if v.is_object() || v.is_array() {
253 return Ok(v);
254 }
255 }
256
257 if let Some(inner) = strip_code_fence(trimmed) {
259 if let Ok(v) = serde_json::from_str::<Value>(inner.trim()) {
260 if v.is_object() || v.is_array() {
261 return Ok(v);
262 }
263 }
264 }
265
266 if let Some(candidate) = find_balanced_json_object(trimmed) {
268 if let Ok(v) = serde_json::from_str::<Value>(candidate) {
269 return Ok(v);
270 }
271 }
272
273 if let Some(candidate) = find_balanced_json_array(trimmed) {
275 if let Ok(v) = serde_json::from_str::<Value>(candidate) {
276 return Ok(v);
277 }
278 }
279
280 bail!("No valid JSON object found in LLM output")
281}
282
/// Returns the body of a triple-backtick code fence that opens at the very start of `text`.
///
/// An opener followed by a line break (optionally tagged `json`, with `\n` or `\r\n`)
/// yields the body verbatim. Otherwise an opener without a line break (tagged `json` or
/// untagged) yields the trimmed body. The closing fence is the last triple backtick in the
/// text. Returns `None` when no opener matches or no closing fence follows it.
fn strip_code_fence(text: &str) -> Option<&str> {
    const LINE_OPENERS: [&str; 4] = ["```json\n", "```json\r\n", "```\n", "```\r\n"];
    const BARE_OPENERS: [&str; 2] = ["```json", "```"];

    let verbatim = LINE_OPENERS.iter().copied().find_map(|opener| {
        let body = text.strip_prefix(opener)?;
        body.rfind("```").map(|close| &body[..close])
    });
    if verbatim.is_some() {
        return verbatim;
    }

    BARE_OPENERS.iter().copied().find_map(|opener| {
        let body = text.strip_prefix(opener)?;
        body.rfind("```").map(|close| body[..close].trim())
    })
}
307
308fn find_balanced_json_object(text: &str) -> Option<&str> {
310 find_balanced(text, '{', '}')
311}
312
313fn find_balanced_json_array(text: &str) -> Option<&str> {
315 find_balanced(text, '[', ']')
316}
317
/// Returns the first balanced `open ... close` span in `text`, ignoring delimiters that
/// appear inside double-quoted strings.
///
/// Phase 1 finds the first `open` byte outside a string. Quote and escape state is tracked
/// from the start of `text`, including any prose before the JSON. Phase 2 walks forward
/// from there, counting nesting depth while skipping string contents, and returns the slice
/// ending at the matching `close`. Returns `None` if there is no opener or it is never
/// closed. `open` and `close` are expected to be ASCII (they are compared as bytes).
fn find_balanced(text: &str, open: char, close: char) -> Option<&str> {
    let (opener, closer) = (open as u8, close as u8);
    let bytes = text.as_bytes();

    // Phase 1: locate the first opener that is not inside a string.
    let mut quoted = false;
    let mut escaped = false;
    let begin = bytes.iter().position(|&b| {
        if escaped {
            escaped = false;
            return false;
        }
        match b {
            b'\\' if quoted => {
                escaped = true;
                false
            }
            b'"' => {
                quoted = !quoted;
                false
            }
            _ => !quoted && b == opener,
        }
    })?;

    // Phase 2: match nesting from the opener, with fresh string state.
    let mut depth: i32 = 0;
    let mut quoted = false;
    let mut escaped = false;
    for (offset, &b) in bytes[begin..].iter().enumerate() {
        if escaped {
            escaped = false;
        } else if quoted {
            match b {
                b'\\' => escaped = true,
                b'"' => quoted = false,
                _ => {}
            }
        } else if b == b'"' {
            quoted = true;
        } else if b == opener {
            depth += 1;
        } else if b == closer {
            depth -= 1;
            if depth == 0 {
                return Some(&text[begin..=begin + offset]);
            }
        }
    }
    None
}
371
/// Returns the byte offset of the first `{` or `[` in `text` that is not inside a
/// double-quoted string.
///
/// A leading triple-backtick fence (optionally tagged `json`) needs no special handling:
/// its bytes are neither quotes, backslashes nor brackets, so scanning straight through it
/// yields the same offset.
fn find_json_start(text: &str) -> Option<usize> {
    let mut quoted = false;
    let mut escaped = false;
    text.bytes().position(|b| {
        if escaped {
            escaped = false;
            return false;
        }
        match b {
            b'\\' if quoted => {
                escaped = true;
                false
            }
            b'"' => {
                quoted = !quoted;
                false
            }
            b'{' | b'[' => !quoted,
            _ => false,
        }
    })
}
406
407fn try_parse_partial_json(text: &str) -> Option<Value> {
418 let trimmed = text.trim();
419 if trimmed.is_empty() {
420 return None;
421 }
422
423 if let Ok(v) = serde_json::from_str::<Value>(trimmed) {
425 if v.is_object() || v.is_array() {
426 return Some(v);
427 }
428 }
429
430 let mut closers = Vec::new();
432 let mut in_string = false;
433 let mut escape_next = false;
434 let mut last_significant: Option<u8> = None;
436
437 for &b in trimmed.as_bytes() {
438 if escape_next {
439 escape_next = false;
440 continue;
441 }
442 match b {
443 b'\\' if in_string => {
444 escape_next = true;
445 }
446 b'"' => {
447 in_string = !in_string;
448 if !in_string {
449 last_significant = Some(b'"');
450 }
451 }
452 _ if in_string => {}
453 b'{' => {
454 closers.push(b'}');
455 last_significant = Some(b'{');
456 }
457 b'[' => {
458 closers.push(b']');
459 last_significant = Some(b'[');
460 }
461 b'}' | b']' => {
462 closers.pop();
463 last_significant = Some(b);
464 }
465 b':' | b',' => {
466 last_significant = Some(b);
467 }
468 b if !b.is_ascii_whitespace() => {
469 last_significant = Some(b);
470 }
471 _ => {}
472 }
473 }
474
475 if closers.is_empty() {
476 return None; }
478
479 let mut repaired = String::with_capacity(trimmed.len() + closers.len() + 6);
481 repaired.push_str(trimmed);
482
483 if in_string {
484 repaired.push('"');
485 last_significant = Some(b'"');
486 }
487
488 if let Some(last) = last_significant {
490 if last == b':' {
491 repaired.push_str("null");
493 } else if last == b',' {
494 if let Some(pos) = repaired.rfind(',') {
496 repaired.truncate(pos);
497 }
498 }
499 }
500
501 for &closer in closers.iter().rev() {
503 repaired.push(closer as char);
504 }
505
506 serde_json::from_str::<Value>(&repaired)
507 .ok()
508 .filter(|v| v.is_object() || v.is_array())
509}
510
511fn validate_against_schema(value: &Value, schema: &Value) -> Result<(), Vec<String>> {
518 let errors = basic_schema_validate(value, schema, "");
522 if errors.is_empty() {
523 Ok(())
524 } else {
525 Err(errors)
526 }
527}
528
529fn basic_schema_validate(value: &Value, schema: &Value, path: &str) -> Vec<String> {
531 let mut errors = Vec::new();
532
533 if schema.get("$ref").is_some() {
535 return errors;
536 }
537
538 if let Some(any_of) = schema
540 .get("anyOf")
541 .or_else(|| schema.get("oneOf"))
542 .and_then(|v| v.as_array())
543 {
544 let matched = any_of
545 .iter()
546 .any(|sub| basic_schema_validate(value, sub, path).is_empty());
547 if !matched {
548 errors.push(format!(
549 "{}: value does not match any variant in anyOf/oneOf",
550 path_or_root(path),
551 ));
552 }
553 return errors;
554 }
555
556 if let Some(enum_values) = schema.get("enum").and_then(|v| v.as_array()) {
558 if !enum_values.contains(value) {
559 errors.push(format!(
560 "{}: value {:?} not in enum {:?}",
561 path_or_root(path),
562 value,
563 enum_values
564 ));
565 }
566 return errors;
567 }
568
569 if let Some(const_val) = schema.get("const") {
571 if value != const_val {
572 errors.push(format!(
573 "{}: expected const {:?}, got {:?}",
574 path_or_root(path),
575 const_val,
576 value
577 ));
578 }
579 return errors;
580 }
581
582 if let Some(type_val) = schema.get("type") {
584 let type_ok = if let Some(type_str) = type_val.as_str() {
585 check_type(value, type_str)
586 } else if let Some(type_arr) = type_val.as_array() {
587 type_arr
588 .iter()
589 .filter_map(|t| t.as_str())
590 .any(|t| check_type(value, t))
591 } else {
592 true
593 };
594 if !type_ok {
595 errors.push(format!(
596 "{}: expected type {:?}, got {:?}",
597 path_or_root(path),
598 type_val,
599 value_type_name(value)
600 ));
601 return errors;
602 }
603 }
604
605 if let Some(obj) = value.as_object() {
607 if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
608 for (key, prop_schema) in properties {
609 if let Some(child_value) = obj.get(key) {
610 let child_path = if path.is_empty() {
611 format!(".{}", key)
612 } else {
613 format!("{}.{}", path, key)
614 };
615 errors.extend(basic_schema_validate(child_value, prop_schema, &child_path));
616 }
617 }
618 }
619
620 if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
621 for req_field in required {
622 if let Some(field_name) = req_field.as_str() {
623 if !obj.contains_key(field_name) {
624 errors.push(format!(
625 "{}: missing required field '{}'",
626 path_or_root(path),
627 field_name
628 ));
629 }
630 }
631 }
632 }
633
634 if schema.get("additionalProperties") == Some(&Value::Bool(false)) {
636 if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
637 for key in obj.keys() {
638 if !properties.contains_key(key) {
639 errors.push(format!(
640 "{}: unexpected additional property '{}'",
641 path_or_root(path),
642 key
643 ));
644 }
645 }
646 }
647 }
648 }
649
650 if let Some(arr) = value.as_array() {
652 if let Some(items_schema) = schema.get("items") {
653 for (i, item) in arr.iter().enumerate() {
654 let child_path = format!("{}[{}]", path, i);
655 errors.extend(basic_schema_validate(item, items_schema, &child_path));
656 }
657 }
658 if let Some(min) = schema.get("minItems").and_then(|v| v.as_u64()) {
659 if (arr.len() as u64) < min {
660 errors.push(format!(
661 "{}: array has {} items, minimum is {}",
662 path_or_root(path),
663 arr.len(),
664 min
665 ));
666 }
667 }
668 if let Some(max) = schema.get("maxItems").and_then(|v| v.as_u64()) {
669 if (arr.len() as u64) > max {
670 errors.push(format!(
671 "{}: array has {} items, maximum is {}",
672 path_or_root(path),
673 arr.len(),
674 max
675 ));
676 }
677 }
678 }
679
680 if let Some(s) = value.as_str() {
682 if let Some(min_len) = schema.get("minLength").and_then(|v| v.as_u64()) {
683 if (s.chars().count() as u64) < min_len {
684 errors.push(format!(
685 "{}: string length {} < minLength {}",
686 path_or_root(path),
687 s.chars().count(),
688 min_len
689 ));
690 }
691 }
692 if let Some(max_len) = schema.get("maxLength").and_then(|v| v.as_u64()) {
693 if (s.chars().count() as u64) > max_len {
694 errors.push(format!(
695 "{}: string length {} > maxLength {}",
696 path_or_root(path),
697 s.chars().count(),
698 max_len
699 ));
700 }
701 }
702 if let Some(pattern) = schema.get("pattern").and_then(|v| v.as_str()) {
703 if let Ok(re) = regex::Regex::new(pattern) {
704 if !re.is_match(s) {
705 errors.push(format!(
706 "{}: string does not match pattern '{}'",
707 path_or_root(path),
708 pattern
709 ));
710 }
711 }
712 }
713 }
714
715 if let Some(n) = value.as_f64() {
717 if let Some(min) = schema.get("minimum").and_then(|v| v.as_f64()) {
718 if n < min {
719 errors.push(format!(
720 "{}: value {} < minimum {}",
721 path_or_root(path),
722 n,
723 min
724 ));
725 }
726 }
727 if let Some(max) = schema.get("maximum").and_then(|v| v.as_f64()) {
728 if n > max {
729 errors.push(format!(
730 "{}: value {} > maximum {}",
731 path_or_root(path),
732 n,
733 max
734 ));
735 }
736 }
737 if let Some(exc_min) = schema.get("exclusiveMinimum").and_then(|v| v.as_f64()) {
738 if n <= exc_min {
739 errors.push(format!(
740 "{}: value {} <= exclusiveMinimum {}",
741 path_or_root(path),
742 n,
743 exc_min
744 ));
745 }
746 }
747 if let Some(exc_max) = schema.get("exclusiveMaximum").and_then(|v| v.as_f64()) {
748 if n >= exc_max {
749 errors.push(format!(
750 "{}: value {} >= exclusiveMaximum {}",
751 path_or_root(path),
752 n,
753 exc_max
754 ));
755 }
756 }
757 }
758
759 errors
760}
761
762fn check_type(value: &Value, type_str: &str) -> bool {
763 match type_str {
764 "object" => value.is_object(),
765 "array" => value.is_array(),
766 "string" => value.is_string(),
767 "number" => value.is_number(),
768 "integer" => {
769 value.is_i64()
770 || value.is_u64()
771 || value
772 .as_f64()
773 .map(|f| f.fract() == 0.0 && f.is_finite())
774 .unwrap_or(false)
775 }
776 "boolean" => value.is_boolean(),
777 "null" => value.is_null(),
778 _ => true,
779 }
780}
781
/// Renders a validation path for messages, using `$` for the root (empty path).
fn path_or_root(path: &str) -> &str {
    match path {
        "" => "$",
        nested => nested,
    }
}
789
790fn value_type_name(value: &Value) -> &'static str {
791 match value {
792 Value::Null => "null",
793 Value::Bool(_) => "boolean",
794 Value::Number(_) => "number",
795 Value::String(_) => "string",
796 Value::Array(_) => "array",
797 Value::Object(_) => "object",
798 }
799}
800
801fn build_initial_messages(req: &StructuredRequest, mode: StructuredMode) -> Vec<Message> {
806 match mode {
807 StructuredMode::Tool => {
808 vec![Message::user(&req.prompt)]
811 }
812 StructuredMode::Prompt => {
813 let augmented = format!(
815 "{}\n\nYou MUST respond with ONLY a valid JSON object (no markdown, no explanation) that conforms to this JSON Schema:\n\n```json\n{}\n```",
816 req.prompt,
817 serde_json::to_string_pretty(&req.schema).unwrap_or_default()
818 );
819 vec![Message::user(&augmented)]
820 }
821 _ => {
822 vec![Message::user(&req.prompt)]
825 }
826 }
827}
828
829fn build_system_prompt(req: &StructuredRequest, mode: StructuredMode) -> String {
830 let base = req.system.as_deref().unwrap_or("");
831
832 match mode {
833 StructuredMode::Tool => {
834 format!(
835 "{}{}You MUST respond by calling the `emit_{}` tool exactly once with a valid argument matching the schema. Do not output any text outside the tool call.",
836 base,
837 if base.is_empty() { "" } else { "\n\n" },
838 req.schema_name
839 )
840 }
841 StructuredMode::Prompt => {
842 format!(
843 "{}{}You are a structured data extraction assistant. Always respond with valid JSON only, no markdown fences, no explanation text.",
844 base,
845 if base.is_empty() { "" } else { "\n\n" },
846 )
847 }
848 _ => base.to_string(),
849 }
850}
851
852fn build_tools(req: &StructuredRequest, mode: StructuredMode) -> Vec<ToolDefinition> {
853 match mode {
854 StructuredMode::Tool => {
855 vec![ToolDefinition {
856 name: format!("emit_{}", req.schema_name),
857 description: req
858 .schema_description
859 .clone()
860 .unwrap_or_else(|| format!("Emit a structured {} object", req.schema_name)),
861 parameters: req.schema.clone(),
862 }]
863 }
864 _ => vec![],
865 }
866}
867
868fn extract_raw_output(message: &super::Message, mode: StructuredMode) -> String {
870 match mode {
871 StructuredMode::Tool => {
872 let calls = message.tool_calls();
874 if let Some(call) = calls.first() {
875 serde_json::to_string(&call.args).unwrap_or_default()
876 } else {
877 message.text()
879 }
880 }
881 _ => message.text(),
882 }
883}
884
/// Builds the follow-up prompt sent after a schema-validation failure.
///
/// Echoes the failed output (at most 2000 bytes, followed by a truncation marker that
/// reports the full byte length) and lists each validation error as a `- ` bullet.
/// Truncation backs off to the nearest UTF-8 character boundary, so multi-byte model
/// output cannot cause a slicing panic.
fn build_repair_message(raw_text: &str, errors: &[String]) -> String {
    const MAX_ECHO_BYTES: usize = 2000;

    let truncated_raw = if raw_text.len() > MAX_ECHO_BYTES {
        // Index 0 is always a char boundary, so `find` always succeeds.
        let cut = (0..=MAX_ECHO_BYTES)
            .rev()
            .find(|&i| raw_text.is_char_boundary(i))
            .unwrap_or(0);
        format!(
            "{}...[truncated, {} bytes total]",
            &raw_text[..cut],
            raw_text.len()
        )
    } else {
        raw_text.to_string()
    };
    let bullets = errors
        .iter()
        .map(|e| format!("- {}", e))
        .collect::<Vec<_>>()
        .join("\n");
    format!(
        "Your previous output failed schema validation:\n\n{}\n\nValidation errors:\n{}\n\nPlease return ONLY a corrected JSON object that fixes these errors. No explanation, no markdown.",
        truncated_raw, bullets
    )
}
902
903fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
904 total.prompt_tokens += delta.prompt_tokens;
905 total.completion_tokens += delta.completion_tokens;
906 total.total_tokens += delta.total_tokens;
907}
908
909fn append_repair_context(
916 messages: &mut Vec<Message>,
917 assistant_msg: &Message,
918 repair_text: &str,
919 mode: StructuredMode,
920 _raw_text: &str,
921) {
922 if mode == StructuredMode::Tool {
923 messages.push(assistant_msg.clone());
925 let tool_use_id = assistant_msg
927 .tool_calls()
928 .first()
929 .map(|tc| tc.id.clone())
930 .unwrap_or_else(|| "unknown".to_string());
931 messages.push(Message::tool_result(&tool_use_id, repair_text, true));
933 } else {
934 messages.push(assistant_msg.clone());
936 messages.push(Message::user(repair_text));
937 }
938}
939
// Unit tests live in the sibling file `structured_tests.rs` and are compiled only for
// test builds.
#[cfg(test)]
#[path = "structured_tests.rs"]
mod structured_tests;