use super::{LlmClient, Message, StreamEvent, TokenUsage, ToolDefinition};
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_util::sync::CancellationToken;
/// Strategy used to obtain structured (JSON) output from the model.
///
/// Serialized in `snake_case` (e.g. `"auto"`, `"tool"`).
///
/// Within this file only `Tool` and `Prompt` receive special handling; every
/// other variant sends the prompt unchanged, registers no tools and reads the
/// answer from the message text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StructuredMode {
/// No mode-specific handling in this file.
/// NOTE(review): presumably resolved to a concrete mode elsewhere — confirm.
Auto,
/// No mode-specific handling in this file.
Strict,
/// No mode-specific handling in this file.
Json,
/// The model is asked to call a single `emit_<schema_name>` tool whose
/// parameters are the schema; the tool-call arguments are the output.
Tool,
/// The schema is embedded in the user prompt and the model is instructed to
/// answer with JSON only.
Prompt,
}
/// Input for a structured-generation call.
#[derive(Debug, Clone)]
pub struct StructuredRequest {
/// User prompt; in `Prompt` mode it is extended with the schema and JSON-only instructions.
pub prompt: String,
/// Optional base system prompt; mode-specific instructions are appended to it.
pub system: Option<String>,
/// JSON Schema the output is validated against. Also used as the tool's
/// parameters in `Tool` mode and embedded in the prompt in `Prompt` mode.
pub schema: Value,
/// Name used to build the tool name `emit_<schema_name>`.
pub schema_name: String,
/// Tool description; when `None` a generic description is generated.
pub schema_description: Option<String>,
/// Strategy used to elicit the structured output.
pub mode: StructuredMode,
/// Maximum number of repair round-trips `generate_blocking` may perform
/// after a parse or validation failure.
pub max_repair_attempts: u8,
}
/// Outcome of a successful structured-generation call.
#[derive(Debug, Clone, Serialize)]
pub struct StructuredResult {
/// The parsed value that passed schema validation.
pub object: Value,
/// Raw model output the value was parsed from (the generators in this file always set it).
pub raw_text: Option<String>,
/// Token usage: summed over all calls by `generate_blocking`, taken from the
/// final response by `generate_streaming`.
pub usage: TokenUsage,
/// Number of repair round-trips performed (always 0 for streaming).
pub repair_rounds: u8,
/// The mode taken from the request.
pub mode_used: StructuredMode,
}
/// Callback used by `generate_streaming`: invoked with each newly parsed
/// partial object while streaming, and once more with the final validated value.
pub type PartialObjectCallback = Box<dyn Fn(&Value) + Send>;
/// Generates a schema-conforming JSON value with a non-streaming LLM call,
/// asking the model to repair its output when it is unparsable or invalid.
///
/// Each failed attempt (JSON parse failure or schema validation failure)
/// appends the assistant's answer plus a repair instruction to the
/// conversation and retries, up to `req.max_repair_attempts` times. Token
/// usage is accumulated over every call.
///
/// # Errors
///
/// Fails when the LLM call itself fails, or when the output is still
/// unparsable / invalid once the repair budget is exhausted.
pub async fn generate_blocking(
    client: &dyn LlmClient,
    req: &StructuredRequest,
) -> Result<StructuredResult> {
    let mode = req.mode;
    let system = build_system_prompt(req, mode);
    let tools = build_tools(req, mode);
    let mut messages = build_initial_messages(req, mode);

    let mut total_usage = TokenUsage::default();
    let mut repair_rounds: u8 = 0;

    loop {
        let resp = client
            .complete(&messages, Some(&system), &tools)
            .await
            .context("LLM call failed during structured generation")?;
        accumulate_usage(&mut total_usage, &resp.usage);

        let raw_text = extract_raw_output(&resp.message, mode);
        let can_retry = repair_rounds < req.max_repair_attempts;

        // Either return the validated value, give up, or produce the repair
        // instruction for the next round.
        let repair_msg = match extract_json_value(&raw_text) {
            Err(parse_err) => {
                if !can_retry {
                    bail!(
                        "Structured output failed JSON parsing after {} repair attempts: {}",
                        repair_rounds,
                        parse_err
                    );
                }
                format!(
                    "Your previous output could not be parsed as JSON:\n\n{}\n\nError: {}\n\nPlease return ONLY a valid JSON object matching the schema.",
                    raw_text, parse_err
                )
            }
            Ok(value) => match validate_against_schema(&value, &req.schema) {
                Ok(()) => {
                    return Ok(StructuredResult {
                        object: value,
                        raw_text: Some(raw_text),
                        usage: total_usage,
                        repair_rounds,
                        mode_used: mode,
                    });
                }
                Err(errors) => {
                    if !can_retry {
                        bail!(
                            "Structured output failed schema validation after {} repair attempts. Errors: {}",
                            repair_rounds,
                            errors.join("; ")
                        );
                    }
                    build_repair_message(&raw_text, &errors)
                }
            },
        };

        repair_rounds += 1;
        append_repair_context(&mut messages, &resp.message, &repair_msg, mode, &raw_text);
    }
}
pub async fn generate_streaming(
client: &dyn LlmClient,
req: &StructuredRequest,
on_partial: PartialObjectCallback,
) -> Result<StructuredResult> {
let mode = req.mode;
let messages = build_initial_messages(req, mode);
let system = build_system_prompt(req, mode);
let tools = build_tools(req, mode);
let cancel_token = CancellationToken::new();
let mut rx = client
.complete_streaming(&messages, Some(&system), &tools, cancel_token)
.await
.context("LLM streaming call failed during structured generation")?;
let mut json_buffer = String::new();
let mut last_valid_partial: Option<Value> = None;
let mut final_response: Option<super::LlmResponse> = None;
let mut last_parse_len: usize = 0;
const PARSE_THRESHOLD: usize = 8;
while let Some(event) = rx.recv().await {
match event {
StreamEvent::ToolUseInputDelta(delta) if mode == StructuredMode::Tool => {
if final_response.is_some() {
continue;
}
json_buffer.push_str(&delta);
if json_buffer.len() - last_parse_len >= PARSE_THRESHOLD {
if let Some(partial) = try_parse_partial_json(&json_buffer) {
if last_valid_partial.as_ref() != Some(&partial) {
on_partial(&partial);
last_valid_partial = Some(partial);
}
}
last_parse_len = json_buffer.len();
}
}
StreamEvent::TextDelta(delta) if mode != StructuredMode::Tool => {
if final_response.is_some() {
continue;
}
json_buffer.push_str(&delta);
if json_buffer.len() - last_parse_len >= PARSE_THRESHOLD {
if let Some(json_start) = find_json_start(&json_buffer) {
let candidate = &json_buffer[json_start..];
if let Some(partial) = try_parse_partial_json(candidate) {
if last_valid_partial.as_ref() != Some(&partial) {
on_partial(&partial);
last_valid_partial = Some(partial);
}
}
}
last_parse_len = json_buffer.len();
}
}
StreamEvent::Done(resp) => {
final_response = Some(resp);
}
_ => {}
}
}
let resp = final_response.context("Stream ended without Done event")?;
let raw_text = extract_raw_output(&resp.message, mode);
let value =
extract_json_value(&raw_text).context("Failed to parse final streamed output as JSON")?;
validate_against_schema(&value, &req.schema).map_err(|errors| {
anyhow::anyhow!(
"Streamed structured output failed schema validation: {}",
errors.join("; ")
)
})?;
on_partial(&value);
Ok(StructuredResult {
object: value,
raw_text: Some(raw_text),
usage: resp.usage,
repair_rounds: 0,
mode_used: mode,
})
}
/// Extracts a JSON value from raw LLM output.
///
/// Strategies are tried in order, returning the first success:
/// 1. the whole (trimmed) text parses as a JSON object or array;
/// 2. the content of a surrounding Markdown code fence parses as an object or array;
/// 3. the first balanced `{ ... }` span parses as JSON;
/// 4. the first balanced `[ ... ]` span parses as JSON.
///
/// # Errors
///
/// Returns an error when none of the strategies yields valid JSON.
pub fn extract_json_value(text: &str) -> Result<Value> {
    let body = text.trim();

    // Parses `s` and keeps the result only when it is an object or an array.
    let parse_container = |s: &str| {
        serde_json::from_str::<Value>(s)
            .ok()
            .filter(|v| v.is_object() || v.is_array())
    };

    if let Some(direct) = parse_container(body) {
        return Ok(direct);
    }

    if let Some(fenced) = strip_code_fence(body).and_then(|inner| parse_container(inner.trim())) {
        return Ok(fenced);
    }

    let embedded = find_balanced_json_object(body)
        .and_then(|span| serde_json::from_str::<Value>(span).ok())
        .or_else(|| {
            find_balanced_json_array(body).and_then(|span| serde_json::from_str::<Value>(span).ok())
        });

    match embedded {
        Some(found) => Ok(found),
        None => bail!("No valid JSON object found in LLM output"),
    }
}
/// Returns the content of a Markdown code fence that starts at the beginning
/// of `text`, or `None` when `text` does not start with a fence or has no
/// closing fence.
///
/// Openers followed directly by a newline (`` ```json `` / `` ``` `` plus
/// `\n` or `\r\n`) yield the inner text verbatim; the looser fallbacks
/// (`` ```json `` or `` ``` `` followed by anything) yield it trimmed. The
/// closing fence is the last `` ``` `` in the remaining text.
fn strip_code_fence(text: &str) -> Option<&str> {
    const FENCE: &str = "```";

    let verbatim = ["```json\n", "```json\r\n", "```\n", "```\r\n"]
        .iter()
        .find_map(|opener| {
            let body = text.strip_prefix(*opener)?;
            body.rfind(FENCE).map(|end| &body[..end])
        });
    if verbatim.is_some() {
        return verbatim;
    }

    ["```json", FENCE].iter().find_map(|opener| {
        let body = text.strip_prefix(*opener)?;
        body.rfind(FENCE).map(|end| body[..end].trim())
    })
}
/// Returns the first balanced `{ ... }` span in `text`, as located by
/// `find_balanced`, or `None` when there is none.
fn find_balanced_json_object(text: &str) -> Option<&str> {
find_balanced(text, '{', '}')
}
/// Returns the first balanced `[ ... ]` span in `text`, as located by
/// `find_balanced`, or `None` when there is none.
fn find_balanced_json_array(text: &str) -> Option<&str> {
find_balanced(text, '[', ']')
}
/// Finds the first balanced `open ... close` span in `text`.
///
/// The scan starts at the first `open` byte that is not inside a
/// double-quoted string, then counts nested `open`/`close` bytes (again
/// ignoring anything inside double-quoted strings, with `\` escaping the next
/// byte) until the nesting returns to zero.
///
/// Returns the span including both delimiters, or `None` when no opener
/// exists or the span is never closed. `open` and `close` are compared as
/// single bytes, so they are expected to be ASCII.
fn find_balanced(text: &str, open: char, close: char) -> Option<&str> {
    let raw = text.as_bytes();
    let (opener, closer) = (open as u8, close as u8);

    // Pass 1: locate the first opener outside of a string literal.
    let mut quoted = false;
    let mut skip = false;
    let mut begin: Option<usize> = None;
    for (idx, &byte) in raw.iter().enumerate() {
        if skip {
            skip = false;
            continue;
        }
        if quoted {
            match byte {
                b'\\' => skip = true,
                b'"' => quoted = false,
                _ => {}
            }
        } else if byte == b'"' {
            quoted = true;
        } else if byte == opener {
            begin = Some(idx);
            break;
        }
    }
    let begin = begin?;

    // Pass 2: walk forward from the opener until nesting drops back to zero.
    let mut nesting = 0i32;
    quoted = false;
    skip = false;
    for (offset, &byte) in raw[begin..].iter().enumerate() {
        if skip {
            skip = false;
            continue;
        }
        if quoted {
            match byte {
                b'\\' => skip = true,
                b'"' => quoted = false,
                _ => {}
            }
            continue;
        }
        if byte == b'"' {
            quoted = true;
        } else if byte == opener {
            nesting += 1;
        } else if byte == closer {
            nesting -= 1;
            if nesting == 0 {
                return Some(&text[begin..=begin + offset]);
            }
        }
    }
    None
}
/// Returns the byte offset in `text` of the first `{` or `[` that is not
/// inside a double-quoted string, or `None` when there is none.
///
/// A leading `` ```json `` or `` ``` `` fence marker is skipped before
/// scanning; the returned offset is still relative to the start of `text`.
fn find_json_start(text: &str) -> Option<usize> {
    let (haystack, skipped) = match text.strip_prefix("```json") {
        Some(rest) => (rest, 7),
        None => match text.strip_prefix("```") {
            Some(rest) => (rest, 3),
            None => (text, 0),
        },
    };

    let mut quoted = false;
    let mut skip = false;
    haystack
        .bytes()
        .position(|byte| {
            if skip {
                // This byte is escaped by a preceding backslash.
                skip = false;
                return false;
            }
            match byte {
                b'\\' if quoted => skip = true,
                b'"' => quoted = !quoted,
                b'{' | b'[' => return !quoted,
                _ => {}
            }
            false
        })
        .map(|idx| skipped + idx)
}
/// Best-effort parse of a possibly truncated JSON document.
///
/// If the trimmed text already parses as an object or array it is returned
/// as-is. Otherwise the text is scanned to find which containers are still
/// open, a minimal suffix is synthesized (closing an open string, filling a
/// dangling `:` with `null`, dropping a trailing `,`, then appending the
/// missing `}` / `]`), and the repaired text is parsed.
///
/// Returns `None` for empty input, when no container is left open, or when
/// the repaired text still does not parse as an object or array.
fn try_parse_partial_json(text: &str) -> Option<Value> {
let trimmed = text.trim();
if trimmed.is_empty() {
return None;
}
// Fast path: the text is already complete JSON.
if let Ok(v) = serde_json::from_str::<Value>(trimmed) {
if v.is_object() || v.is_array() {
return Some(v);
}
}
// Stack of closing delimiters for the containers currently open.
let mut closers = Vec::new();
let mut in_string = false;
let mut escape_next = false;
// Last non-whitespace byte seen outside a string (a closed string counts as `"`).
let mut last_significant: Option<u8> = None;
for &b in trimmed.as_bytes() {
if escape_next {
// This byte is escaped by the preceding backslash.
escape_next = false;
continue;
}
match b {
b'\\' if in_string => {
escape_next = true;
}
b'"' => {
in_string = !in_string;
if !in_string {
last_significant = Some(b'"');
}
}
_ if in_string => {}
b'{' => {
closers.push(b'}');
last_significant = Some(b'{');
}
b'[' => {
closers.push(b']');
last_significant = Some(b'[');
}
b'}' | b']' => {
// The popped closer is not checked against `b`.
closers.pop();
last_significant = Some(b);
}
b':' | b',' => {
last_significant = Some(b);
}
b if !b.is_ascii_whitespace() => {
last_significant = Some(b);
}
_ => {}
}
}
// Nothing is open, yet the fast path failed: there is nothing to repair.
if closers.is_empty() {
return None; }
let mut repaired = String::with_capacity(trimmed.len() + closers.len() + 6);
repaired.push_str(trimmed);
if in_string {
// Terminate the string that was cut off mid-way.
repaired.push('"');
last_significant = Some(b'"');
}
if let Some(last) = last_significant {
if last == b':' {
// A key without a value yet: supply a placeholder.
repaired.push_str("null");
} else if last == b',' {
// Drop the trailing comma (and any whitespace after it).
if let Some(pos) = repaired.rfind(',') {
repaired.truncate(pos);
}
}
}
// Close the open containers, innermost first.
for &closer in closers.iter().rev() {
repaired.push(closer as char);
}
serde_json::from_str::<Value>(&repaired)
.ok()
.filter(|v| v.is_object() || v.is_array())
}
/// Validates `value` against `schema` using `basic_schema_validate`, starting
/// at the document root.
///
/// Returns `Ok(())` when no problems are found, otherwise `Err` with one
/// message per violation.
fn validate_against_schema(value: &Value, schema: &Value) -> Result<(), Vec<String>> {
    match basic_schema_validate(value, schema, "") {
        problems if problems.is_empty() => Ok(()),
        problems => Err(problems),
    }
}
/// Validates `value` against a subset of JSON Schema and returns one message
/// per violation (empty when valid).
///
/// `path` is the location of `value` within the document (empty for the
/// root) and is used to prefix error messages.
///
/// Supported keywords: `anyOf`/`oneOf`, `enum`, `const`, `type`,
/// `properties`, `required`, `additionalProperties: false`, `items`,
/// `minItems`/`maxItems`, `minLength`/`maxLength`, `pattern`,
/// `minimum`/`maximum`, `exclusiveMinimum`/`exclusiveMaximum`.
///
/// Limitations visible in this implementation:
/// - a schema containing `$ref` is not resolved and accepts any value;
/// - `oneOf` is treated like `anyOf` (at least one variant must match);
/// - `anyOf`/`oneOf`, `enum` and `const` return early, so sibling keywords
///   in the same schema object are not checked;
/// - a `pattern` that fails to compile as a regex is silently skipped;
/// - `additionalProperties: false` is only enforced when `properties` is present.
fn basic_schema_validate(value: &Value, schema: &Value, path: &str) -> Vec<String> {
let mut errors = Vec::new();
// References are not resolved: accept anything.
if schema.get("$ref").is_some() {
return errors;
}
// `anyOf` takes precedence over `oneOf` when both are present.
if let Some(any_of) = schema
.get("anyOf")
.or_else(|| schema.get("oneOf"))
.and_then(|v| v.as_array())
{
let matched = any_of
.iter()
.any(|sub| basic_schema_validate(value, sub, path).is_empty());
if !matched {
errors.push(format!(
"{}: value does not match any variant in anyOf/oneOf",
path_or_root(path),
));
}
return errors;
}
if let Some(enum_values) = schema.get("enum").and_then(|v| v.as_array()) {
if !enum_values.contains(value) {
errors.push(format!(
"{}: value {:?} not in enum {:?}",
path_or_root(path),
value,
enum_values
));
}
return errors;
}
if let Some(const_val) = schema.get("const") {
if value != const_val {
errors.push(format!(
"{}: expected const {:?}, got {:?}",
path_or_root(path),
const_val,
value
));
}
return errors;
}
if let Some(type_val) = schema.get("type") {
// `type` may be a single name or an array of allowed names; any other
// shape is treated as "no constraint".
let type_ok = if let Some(type_str) = type_val.as_str() {
check_type(value, type_str)
} else if let Some(type_arr) = type_val.as_array() {
type_arr
.iter()
.filter_map(|t| t.as_str())
.any(|t| check_type(value, t))
} else {
true
};
if !type_ok {
errors.push(format!(
"{}: expected type {:?}, got {:?}",
path_or_root(path),
type_val,
value_type_name(value)
));
// A type mismatch makes the remaining keyword checks meaningless.
return errors;
}
}
// Object keywords.
if let Some(obj) = value.as_object() {
if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
for (key, prop_schema) in properties {
// Only properties that are present are validated; absence is
// handled by `required` below.
if let Some(child_value) = obj.get(key) {
let child_path = if path.is_empty() {
format!(".{}", key)
} else {
format!("{}.{}", path, key)
};
errors.extend(basic_schema_validate(child_value, prop_schema, &child_path));
}
}
}
if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
for req_field in required {
if let Some(field_name) = req_field.as_str() {
if !obj.contains_key(field_name) {
errors.push(format!(
"{}: missing required field '{}'",
path_or_root(path),
field_name
));
}
}
}
}
if schema.get("additionalProperties") == Some(&Value::Bool(false)) {
if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
for key in obj.keys() {
if !properties.contains_key(key) {
errors.push(format!(
"{}: unexpected additional property '{}'",
path_or_root(path),
key
));
}
}
}
}
}
// Array keywords.
if let Some(arr) = value.as_array() {
if let Some(items_schema) = schema.get("items") {
for (i, item) in arr.iter().enumerate() {
let child_path = format!("{}[{}]", path, i);
errors.extend(basic_schema_validate(item, items_schema, &child_path));
}
}
if let Some(min) = schema.get("minItems").and_then(|v| v.as_u64()) {
if (arr.len() as u64) < min {
errors.push(format!(
"{}: array has {} items, minimum is {}",
path_or_root(path),
arr.len(),
min
));
}
}
if let Some(max) = schema.get("maxItems").and_then(|v| v.as_u64()) {
if (arr.len() as u64) > max {
errors.push(format!(
"{}: array has {} items, maximum is {}",
path_or_root(path),
arr.len(),
max
));
}
}
}
// String keywords; lengths are counted in Unicode scalar values (`chars`).
if let Some(s) = value.as_str() {
if let Some(min_len) = schema.get("minLength").and_then(|v| v.as_u64()) {
if (s.chars().count() as u64) < min_len {
errors.push(format!(
"{}: string length {} < minLength {}",
path_or_root(path),
s.chars().count(),
min_len
));
}
}
if let Some(max_len) = schema.get("maxLength").and_then(|v| v.as_u64()) {
if (s.chars().count() as u64) > max_len {
errors.push(format!(
"{}: string length {} > maxLength {}",
path_or_root(path),
s.chars().count(),
max_len
));
}
}
if let Some(pattern) = schema.get("pattern").and_then(|v| v.as_str()) {
// Patterns that do not compile are ignored rather than reported.
if let Ok(re) = regex::Regex::new(pattern) {
if !re.is_match(s) {
errors.push(format!(
"{}: string does not match pattern '{}'",
path_or_root(path),
pattern
));
}
}
}
}
// Numeric keywords; comparisons are done on the `f64` view of the number.
if let Some(n) = value.as_f64() {
if let Some(min) = schema.get("minimum").and_then(|v| v.as_f64()) {
if n < min {
errors.push(format!(
"{}: value {} < minimum {}",
path_or_root(path),
n,
min
));
}
}
if let Some(max) = schema.get("maximum").and_then(|v| v.as_f64()) {
if n > max {
errors.push(format!(
"{}: value {} > maximum {}",
path_or_root(path),
n,
max
));
}
}
if let Some(exc_min) = schema.get("exclusiveMinimum").and_then(|v| v.as_f64()) {
if n <= exc_min {
errors.push(format!(
"{}: value {} <= exclusiveMinimum {}",
path_or_root(path),
n,
exc_min
));
}
}
if let Some(exc_max) = schema.get("exclusiveMaximum").and_then(|v| v.as_f64()) {
if n >= exc_max {
errors.push(format!(
"{}: value {} >= exclusiveMaximum {}",
path_or_root(path),
n,
exc_max
));
}
}
}
errors
}
/// Reports whether `value` satisfies the JSON Schema type name `type_str`.
///
/// `"integer"` accepts any number stored as an integer as well as a finite
/// float with no fractional part. Unknown type names are treated as
/// satisfied.
fn check_type(value: &Value, type_str: &str) -> bool {
    match type_str {
        "null" => value.is_null(),
        "boolean" => value.is_boolean(),
        "string" => value.is_string(),
        "number" => value.is_number(),
        "integer" => {
            if value.is_i64() || value.is_u64() {
                true
            } else {
                value
                    .as_f64()
                    .map_or(false, |f| f.is_finite() && f.fract() == 0.0)
            }
        }
        "array" => value.is_array(),
        "object" => value.is_object(),
        _ => true,
    }
}
/// Returns `path` for display in error messages, substituting `"$"` for the
/// empty (root) path.
fn path_or_root(path: &str) -> &str {
    match path {
        "" => "$",
        nested => nested,
    }
}
/// Returns the JSON type name of `value` for use in error messages.
///
/// All numbers are reported as `"number"`; integers are not distinguished.
fn value_type_name(value: &Value) -> &'static str {
    match value {
        Value::Object(_) => "object",
        Value::Array(_) => "array",
        Value::String(_) => "string",
        Value::Number(_) => "number",
        Value::Bool(_) => "boolean",
        Value::Null => "null",
    }
}
/// Builds the initial conversation: a single user message.
///
/// In `Prompt` mode the user prompt is extended with JSON-only instructions
/// and the pretty-printed schema; in every other mode the prompt is sent
/// unchanged.
fn build_initial_messages(req: &StructuredRequest, mode: StructuredMode) -> Vec<Message> {
    if mode != StructuredMode::Prompt {
        return vec![Message::user(&req.prompt)];
    }

    // If the schema cannot be pretty-printed an empty string is embedded.
    let schema_text = serde_json::to_string_pretty(&req.schema).unwrap_or_default();
    let augmented = format!(
        "{}\n\nYou MUST respond with ONLY a valid JSON object (no markdown, no explanation) that conforms to this JSON Schema:\n\n```json\n{}\n```",
        req.prompt, schema_text
    );
    vec![Message::user(&augmented)]
}
/// Builds the system prompt for the given mode.
///
/// `Tool` and `Prompt` modes append a mode-specific directive to the
/// caller's system prompt (separated by a blank line when the caller supplied
/// one); all other modes return the caller's system prompt unchanged (empty
/// when absent).
fn build_system_prompt(req: &StructuredRequest, mode: StructuredMode) -> String {
    let base = req.system.as_deref().unwrap_or("");

    let directive = match mode {
        StructuredMode::Tool => format!(
            "You MUST respond by calling the `emit_{}` tool exactly once with a valid argument matching the schema. Do not output any text outside the tool call.",
            req.schema_name
        ),
        StructuredMode::Prompt => String::from(
            "You are a structured data extraction assistant. Always respond with valid JSON only, no markdown fences, no explanation text.",
        ),
        _ => return base.to_string(),
    };

    let separator = if base.is_empty() { "" } else { "\n\n" };
    format!("{}{}{}", base, separator, directive)
}
/// Builds the tool list for the given mode.
///
/// `Tool` mode exposes a single `emit_<schema_name>` tool whose parameters
/// are the request schema; every other mode uses no tools.
fn build_tools(req: &StructuredRequest, mode: StructuredMode) -> Vec<ToolDefinition> {
    if mode != StructuredMode::Tool {
        return Vec::new();
    }

    let description = match &req.schema_description {
        Some(text) => text.clone(),
        None => format!("Emit a structured {} object", req.schema_name),
    };

    vec![ToolDefinition {
        name: format!("emit_{}", req.schema_name),
        description,
        parameters: req.schema.clone(),
    }]
}
/// Extracts the raw structured output from an assistant message.
///
/// In `Tool` mode this is the JSON-serialized arguments of the first tool
/// call (an empty string if serialization fails), falling back to the
/// message text when the model made no tool call. In every other mode it is
/// the message text.
fn extract_raw_output(message: &super::Message, mode: StructuredMode) -> String {
    if mode == StructuredMode::Tool {
        let calls = message.tool_calls();
        if let Some(first_call) = calls.first() {
            return serde_json::to_string(&first_call.args).unwrap_or_default();
        }
    }
    message.text()
}
/// Builds the instruction sent back to the model after its output failed
/// schema validation.
///
/// The message quotes the previous output (truncated to at most 2000 bytes,
/// with a note giving the full size) followed by a bulleted list of
/// `errors`.
fn build_repair_message(raw_text: &str, errors: &[String]) -> String {
    const MAX_QUOTED_BYTES: usize = 2000;

    let truncated_raw = if raw_text.len() > MAX_QUOTED_BYTES {
        // Slicing a `str` in the middle of a multi-byte character panics, so
        // back off to the nearest character boundary at or below the limit.
        // Index 0 is always a boundary, which guarantees termination.
        let mut cut = MAX_QUOTED_BYTES;
        while !raw_text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!(
            "{}...[truncated, {} bytes total]",
            &raw_text[..cut],
            raw_text.len()
        )
    } else {
        raw_text.to_string()
    };

    let bullet_list = errors
        .iter()
        .map(|e| format!("- {}", e))
        .collect::<Vec<_>>()
        .join("\n");

    format!(
        "Your previous output failed schema validation:\n\n{}\n\nValidation errors:\n{}\n\nPlease return ONLY a corrected JSON object that fixes these errors. No explanation, no markdown.",
        truncated_raw, bullet_list
    )
}
/// Adds the token counts of `delta` onto the running `total`.
fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
    total.total_tokens += delta.total_tokens;
    total.completion_tokens += delta.completion_tokens;
    total.prompt_tokens += delta.prompt_tokens;
}
/// Appends the failed assistant turn and a repair instruction to `messages`.
///
/// The assistant message is always echoed back first. In `Tool` mode, when
/// the assistant actually made a tool call, the repair text is delivered as
/// an error tool result for the first call. Otherwise (non-tool modes, or
/// tool mode where the model answered with plain text) it is delivered as a
/// regular user message.
///
/// A tool result is never fabricated for a non-existent call: a result whose
/// id matches no tool use in the preceding assistant turn does not form a
/// valid conversation.
///
/// `_raw_text` is unused and kept only for signature compatibility.
fn append_repair_context(
    messages: &mut Vec<Message>,
    assistant_msg: &Message,
    repair_text: &str,
    mode: StructuredMode,
    _raw_text: &str,
) {
    messages.push(assistant_msg.clone());

    // Id of the tool call to answer, if there is one to answer.
    let tool_use_id = if mode == StructuredMode::Tool {
        assistant_msg.tool_calls().first().map(|tc| tc.id.clone())
    } else {
        None
    };

    match tool_use_id {
        Some(id) => messages.push(Message::tool_result(&id, repair_text, true)),
        None => messages.push(Message::user(repair_text)),
    }
}
// Unit tests for this module live in the sibling file `structured_tests.rs`
// and are compiled only for test builds.
#[cfg(test)]
#[path = "structured_tests.rs"]
mod structured_tests;