Skip to main content

a3s_code_core/llm/
structured.rs

//! Structured object generation from LLM output.
//!
//! Provides JSON object generation with schema validation, automatic repair,
//! and streaming partial-object support. Several request modes are available
//! (strict JSON schema, json_mode, tool-call, or prompt-only); the mode is
//! chosen per request via [`StructuredMode`].
7
8use super::{LlmClient, Message, StreamEvent, TokenUsage, ToolDefinition};
9use anyhow::{bail, Context, Result};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use tokio_util::sync::CancellationToken;
13
14// ---------------------------------------------------------------------------
15// Public types
16// ---------------------------------------------------------------------------
17
18/// Mode selection for structured output generation.
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum StructuredMode {
22    /// Auto-select best mode based on provider capabilities.
23    Auto,
24    /// OpenAI native strict JSON schema (response_format.type = json_schema).
25    Strict,
26    /// OpenAI json_object mode (guarantees valid JSON, not schema-conformant).
27    Json,
28    /// Use tool-calling: inject a synthetic tool whose parameters IS the schema.
29    /// Works on all providers that support tool use (Anthropic, OpenAI, etc).
30    Tool,
31    /// Prompt-only: append schema instructions to the prompt. Least reliable.
32    Prompt,
33}
34
35/// Request specification for structured object generation.
36#[derive(Debug, Clone)]
37pub struct StructuredRequest {
38    pub prompt: String,
39    pub system: Option<String>,
40    pub schema: Value,
41    pub schema_name: String,
42    pub schema_description: Option<String>,
43    pub mode: StructuredMode,
44    pub max_repair_attempts: u8,
45}
46
47/// Result of a successful structured generation.
48#[derive(Debug, Clone, Serialize)]
49pub struct StructuredResult {
50    pub object: Value,
51    pub raw_text: Option<String>,
52    pub usage: TokenUsage,
53    pub repair_rounds: u8,
54    pub mode_used: StructuredMode,
55}
56
57/// Callback for streaming partial object snapshots.
58pub type PartialObjectCallback = Box<dyn Fn(&Value) + Send>;
59
60// ---------------------------------------------------------------------------
61// Core generation: blocking (non-streaming)
62// ---------------------------------------------------------------------------
63
64/// Generate a structured JSON object using the given LLM client.
65///
66/// Selects the best mode based on `req.mode`, calls the LLM, validates against
67/// the schema, and retries with repair prompts if validation fails.
68pub async fn generate_blocking(
69    client: &dyn LlmClient,
70    req: &StructuredRequest,
71) -> Result<StructuredResult> {
72    let mode = req.mode;
73    let mut messages = build_initial_messages(req, mode);
74    let system = build_system_prompt(req, mode);
75    let tools = build_tools(req, mode);
76
77    let mut total_usage = TokenUsage::default();
78    let mut repair_rounds: u8 = 0;
79
80    loop {
81        let resp = client
82            .complete(&messages, Some(&system), &tools)
83            .await
84            .context("LLM call failed during structured generation")?;
85
86        accumulate_usage(&mut total_usage, &resp.usage);
87
88        let raw_text = extract_raw_output(&resp.message, mode);
89        let parsed = extract_json_value(&raw_text);
90
91        match parsed {
92            Ok(value) => match validate_against_schema(&value, &req.schema) {
93                Ok(()) => {
94                    return Ok(StructuredResult {
95                        object: value,
96                        raw_text: Some(raw_text),
97                        usage: total_usage,
98                        repair_rounds,
99                        mode_used: mode,
100                    });
101                }
102                Err(errors) if repair_rounds < req.max_repair_attempts => {
103                    repair_rounds += 1;
104                    let repair_msg = build_repair_message(&raw_text, &errors);
105                    append_repair_context(
106                        &mut messages,
107                        &resp.message,
108                        &repair_msg,
109                        mode,
110                        &raw_text,
111                    );
112                }
113                Err(errors) => {
114                    bail!(
115                            "Structured output failed schema validation after {} repair attempts. Errors: {}",
116                            repair_rounds,
117                            errors.join("; ")
118                        );
119                }
120            },
121            Err(parse_err) if repair_rounds < req.max_repair_attempts => {
122                repair_rounds += 1;
123                let repair_msg = format!(
124                    "Your previous output could not be parsed as JSON:\n\n{}\n\nError: {}\n\nPlease return ONLY a valid JSON object matching the schema.",
125                    raw_text, parse_err
126                );
127                append_repair_context(&mut messages, &resp.message, &repair_msg, mode, &raw_text);
128            }
129            Err(parse_err) => {
130                bail!(
131                    "Structured output failed JSON parsing after {} repair attempts: {}",
132                    repair_rounds,
133                    parse_err
134                );
135            }
136        }
137    }
138}
139
140// ---------------------------------------------------------------------------
141// Core generation: streaming
142// ---------------------------------------------------------------------------
143
144/// Generate a structured JSON object with streaming partial updates.
145///
146/// Calls `on_partial` with progressively more complete partial objects as tokens
147/// arrive. Returns the final validated object.
148///
149/// In streaming mode, `max_repair_attempts` defaults to 0 because a repair
150/// would reset the partial object stream (confusing for consumers).
151pub async fn generate_streaming(
152    client: &dyn LlmClient,
153    req: &StructuredRequest,
154    on_partial: PartialObjectCallback,
155) -> Result<StructuredResult> {
156    let mode = req.mode;
157    let messages = build_initial_messages(req, mode);
158    let system = build_system_prompt(req, mode);
159    let tools = build_tools(req, mode);
160
161    let cancel_token = CancellationToken::new();
162    let mut rx = client
163        .complete_streaming(&messages, Some(&system), &tools, cancel_token)
164        .await
165        .context("LLM streaming call failed during structured generation")?;
166
167    let mut json_buffer = String::new();
168    let mut last_valid_partial: Option<Value> = None;
169    let mut final_response: Option<super::LlmResponse> = None;
170    let mut last_parse_len: usize = 0;
171    // Minimum bytes of new data before attempting a partial parse (reduces CPU)
172    const PARSE_THRESHOLD: usize = 8;
173
174    while let Some(event) = rx.recv().await {
175        match event {
176            StreamEvent::ToolUseInputDelta(delta) if mode == StructuredMode::Tool => {
177                if final_response.is_some() {
178                    continue;
179                }
180                json_buffer.push_str(&delta);
181                if json_buffer.len() - last_parse_len >= PARSE_THRESHOLD {
182                    if let Some(partial) = try_parse_partial_json(&json_buffer) {
183                        if last_valid_partial.as_ref() != Some(&partial) {
184                            on_partial(&partial);
185                            last_valid_partial = Some(partial);
186                        }
187                    }
188                    last_parse_len = json_buffer.len();
189                }
190            }
191            StreamEvent::TextDelta(delta) if mode != StructuredMode::Tool => {
192                if final_response.is_some() {
193                    continue;
194                }
195                json_buffer.push_str(&delta);
196                if json_buffer.len() - last_parse_len >= PARSE_THRESHOLD {
197                    if let Some(json_start) = find_json_start(&json_buffer) {
198                        let candidate = &json_buffer[json_start..];
199                        if let Some(partial) = try_parse_partial_json(candidate) {
200                            if last_valid_partial.as_ref() != Some(&partial) {
201                                on_partial(&partial);
202                                last_valid_partial = Some(partial);
203                            }
204                        }
205                    }
206                    last_parse_len = json_buffer.len();
207                }
208            }
209            StreamEvent::Done(resp) => {
210                final_response = Some(resp);
211            }
212            _ => {}
213        }
214    }
215
216    let resp = final_response.context("Stream ended without Done event")?;
217    let raw_text = extract_raw_output(&resp.message, mode);
218    let value =
219        extract_json_value(&raw_text).context("Failed to parse final streamed output as JSON")?;
220
221    validate_against_schema(&value, &req.schema).map_err(|errors| {
222        anyhow::anyhow!(
223            "Streamed structured output failed schema validation: {}",
224            errors.join("; ")
225        )
226    })?;
227
228    // Emit final complete object
229    on_partial(&value);
230
231    Ok(StructuredResult {
232        object: value,
233        raw_text: Some(raw_text),
234        usage: resp.usage,
235        repair_rounds: 0,
236        mode_used: mode,
237    })
238}
239
240// ---------------------------------------------------------------------------
241// JSON extraction and parsing
242// ---------------------------------------------------------------------------
243
244/// Extract a JSON value from potentially dirty LLM output.
245///
246/// Handles: raw JSON, markdown code fences, leading/trailing prose.
247pub fn extract_json_value(text: &str) -> Result<Value> {
248    let trimmed = text.trim();
249
250    // 1. Direct parse
251    if let Ok(v) = serde_json::from_str::<Value>(trimmed) {
252        if v.is_object() || v.is_array() {
253            return Ok(v);
254        }
255    }
256
257    // 2. Strip markdown code fence
258    if let Some(inner) = strip_code_fence(trimmed) {
259        if let Ok(v) = serde_json::from_str::<Value>(inner.trim()) {
260            if v.is_object() || v.is_array() {
261                return Ok(v);
262            }
263        }
264    }
265
266    // 3. Find balanced JSON substring (first { to matching })
267    if let Some(candidate) = find_balanced_json_object(trimmed) {
268        if let Ok(v) = serde_json::from_str::<Value>(candidate) {
269            return Ok(v);
270        }
271    }
272
273    // 4. Try array
274    if let Some(candidate) = find_balanced_json_array(trimmed) {
275        if let Ok(v) = serde_json::from_str::<Value>(candidate) {
276            return Ok(v);
277        }
278    }
279
280    bail!("No valid JSON object found in LLM output")
281}
282
/// Return the trimmed body of a markdown code fence that wraps `text`.
///
/// `text` must start with three backticks. Everything up to the LAST
/// three-backtick run is the fence body, so backticks inside JSON string values
/// do not cut it short. A leading info string (`json`, `JSON`, `jsonc`, ...) is
/// dropped whether or not a newline follows it. Both "```json\n{...}\n```" and
/// "```json{...}```" therefore yield `{...}`.
///
/// Returns `None` if the text is not fenced or the closing fence is missing.
fn strip_code_fence(text: &str) -> Option<&str> {
    let rest = text.strip_prefix("```")?;
    let end = rest.rfind("```")?;
    let body = &rest[..end];
    // The info string is the run right after the opening fence that is neither
    // whitespace nor the start of a JSON container.
    let body = body.trim_start_matches(|c: char| !c.is_whitespace() && c != '{' && c != '[');
    Some(body.trim())
}
307
308/// Find the first balanced `{...}` substring using bracket counting.
309fn find_balanced_json_object(text: &str) -> Option<&str> {
310    find_balanced(text, '{', '}')
311}
312
313/// Find the first balanced `[...]` substring.
314fn find_balanced_json_array(text: &str) -> Option<&str> {
315    find_balanced(text, '[', ']')
316}
317
/// Return the first balanced `open ... close` span of `text`.
///
/// Bytes are scanned left to right. Double-quoted string literals (with
/// backslash escapes) are tracked, so delimiters inside strings are ignored.
/// The span starts at the first `open` outside a string and ends at the `close`
/// that brings the nesting depth back to zero. `close` bytes seen before the
/// first `open` are ignored. Returns `None` if there is no `open` or the span
/// never closes.
///
/// `open` and `close` are compared as single bytes, so they are expected to be
/// ASCII. That also guarantees the returned slice lies on char boundaries.
fn find_balanced(text: &str, open: char, close: char) -> Option<&str> {
    let (open, close) = (open as u8, close as u8);
    let mut span_start: Option<usize> = None;
    let mut depth = 0i32;
    let mut quoted = false;
    let mut escaped = false;

    for (idx, byte) in text.bytes().enumerate() {
        if escaped {
            escaped = false;
            continue;
        }
        if quoted {
            match byte {
                b'\\' => escaped = true,
                b'"' => quoted = false,
                _ => {}
            }
            continue;
        }
        if byte == b'"' {
            quoted = true;
        } else if byte == open {
            span_start.get_or_insert(idx);
            depth += 1;
        } else if byte == close {
            if let Some(start) = span_start {
                depth -= 1;
                if depth == 0 {
                    return Some(&text[start..=idx]);
                }
            }
        }
    }
    None
}
371
/// Byte offset of the first `{` or `[` in `text` that is not inside a string literal.
///
/// A leading "```json" or "```" marker is skipped first, so its characters are
/// never scanned. The returned offset is still relative to the full `text`.
/// Returns `None` when no such byte exists.
fn find_json_start(text: &str) -> Option<usize> {
    let fence_len = ["```json", "```"]
        .iter()
        .find(|fence| text.starts_with(**fence))
        .map_or(0, |fence| fence.len());

    let mut quoted = false;
    let mut escaped = false;
    text.as_bytes()[fence_len..]
        .iter()
        .position(|&byte| {
            if escaped {
                escaped = false;
                return false;
            }
            match byte {
                b'\\' if quoted => escaped = true,
                b'"' => quoted = !quoted,
                b'{' | b'[' => return !quoted,
                _ => {}
            }
            false
        })
        .map(|pos| fence_len + pos)
}
406
407// ---------------------------------------------------------------------------
408// Partial JSON parsing (for streaming)
409// ---------------------------------------------------------------------------
410
411/// Attempt to parse a potentially incomplete JSON string into the most complete
412/// valid partial object possible.
413///
414/// Strategy: try parsing as-is first. If that fails, progressively close open
415/// braces/brackets and try again. This handles the common case where the LLM
416/// has output `{"name": "foo", "items": [1, 2` — we close it to get a partial.
417fn try_parse_partial_json(text: &str) -> Option<Value> {
418    let trimmed = text.trim();
419    if trimmed.is_empty() {
420        return None;
421    }
422
423    // Fast path: already valid
424    if let Ok(v) = serde_json::from_str::<Value>(trimmed) {
425        if v.is_object() || v.is_array() {
426            return Some(v);
427        }
428    }
429
430    // Count unclosed brackets/braces (respecting strings)
431    let mut closers = Vec::new();
432    let mut in_string = false;
433    let mut escape_next = false;
434    // Track if we're mid-value (after a colon or comma, before the value is complete)
435    let mut last_significant: Option<u8> = None;
436
437    for &b in trimmed.as_bytes() {
438        if escape_next {
439            escape_next = false;
440            continue;
441        }
442        match b {
443            b'\\' if in_string => {
444                escape_next = true;
445            }
446            b'"' => {
447                in_string = !in_string;
448                if !in_string {
449                    last_significant = Some(b'"');
450                }
451            }
452            _ if in_string => {}
453            b'{' => {
454                closers.push(b'}');
455                last_significant = Some(b'{');
456            }
457            b'[' => {
458                closers.push(b']');
459                last_significant = Some(b'[');
460            }
461            b'}' | b']' => {
462                closers.pop();
463                last_significant = Some(b);
464            }
465            b':' | b',' => {
466                last_significant = Some(b);
467            }
468            b if !b.is_ascii_whitespace() => {
469                last_significant = Some(b);
470            }
471            _ => {}
472        }
473    }
474
475    if closers.is_empty() {
476        return None; // Already balanced but didn't parse — genuinely invalid
477    }
478
479    // Pre-allocate repair buffer: original + at most 6 extra chars (null + closers)
480    let mut repaired = String::with_capacity(trimmed.len() + closers.len() + 6);
481    repaired.push_str(trimmed);
482
483    if in_string {
484        repaired.push('"');
485        last_significant = Some(b'"');
486    }
487
488    // If last significant char suggests an incomplete key or value, handle it
489    if let Some(last) = last_significant {
490        if last == b':' {
491            // Key with no value yet — add null
492            repaired.push_str("null");
493        } else if last == b',' {
494            // Trailing comma — some parsers choke on this, trim it
495            if let Some(pos) = repaired.rfind(',') {
496                repaired.truncate(pos);
497            }
498        }
499    }
500
501    // Close all open brackets/braces
502    for &closer in closers.iter().rev() {
503        repaired.push(closer as char);
504    }
505
506    serde_json::from_str::<Value>(&repaired)
507        .ok()
508        .filter(|v| v.is_object() || v.is_array())
509}
510
511// ---------------------------------------------------------------------------
512// Schema validation
513// ---------------------------------------------------------------------------
514
515/// Validate a JSON value against a JSON Schema.
516/// Returns Ok(()) on success, or a list of human-readable error strings.
517fn validate_against_schema(value: &Value, schema: &Value) -> Result<(), Vec<String>> {
518    // We do a basic recursive validation here. For production, consider using
519    // the `jsonschema` crate, but to avoid adding a heavy dependency we implement
520    // the subset of JSON Schema that matters for structured output.
521    let errors = basic_schema_validate(value, schema, "");
522    if errors.is_empty() {
523        Ok(())
524    } else {
525        Err(errors)
526    }
527}
528
529/// Basic JSON Schema validator covering the most common constraints.
530fn basic_schema_validate(value: &Value, schema: &Value, path: &str) -> Vec<String> {
531    let mut errors = Vec::new();
532
533    // Handle $ref — not supported in basic validator, skip
534    if schema.get("$ref").is_some() {
535        return errors;
536    }
537
538    // Handle anyOf / oneOf: value must match at least one sub-schema
539    if let Some(any_of) = schema
540        .get("anyOf")
541        .or_else(|| schema.get("oneOf"))
542        .and_then(|v| v.as_array())
543    {
544        let matched = any_of
545            .iter()
546            .any(|sub| basic_schema_validate(value, sub, path).is_empty());
547        if !matched {
548            errors.push(format!(
549                "{}: value does not match any variant in anyOf/oneOf",
550                path_or_root(path),
551            ));
552        }
553        return errors;
554    }
555
556    // Handle enum
557    if let Some(enum_values) = schema.get("enum").and_then(|v| v.as_array()) {
558        if !enum_values.contains(value) {
559            errors.push(format!(
560                "{}: value {:?} not in enum {:?}",
561                path_or_root(path),
562                value,
563                enum_values
564            ));
565        }
566        return errors;
567    }
568
569    // Handle const
570    if let Some(const_val) = schema.get("const") {
571        if value != const_val {
572            errors.push(format!(
573                "{}: expected const {:?}, got {:?}",
574                path_or_root(path),
575                const_val,
576                value
577            ));
578        }
579        return errors;
580    }
581
582    // Type checking (supports nullable via type array: ["string", "null"])
583    if let Some(type_val) = schema.get("type") {
584        let type_ok = if let Some(type_str) = type_val.as_str() {
585            check_type(value, type_str)
586        } else if let Some(type_arr) = type_val.as_array() {
587            type_arr
588                .iter()
589                .filter_map(|t| t.as_str())
590                .any(|t| check_type(value, t))
591        } else {
592            true
593        };
594        if !type_ok {
595            errors.push(format!(
596                "{}: expected type {:?}, got {:?}",
597                path_or_root(path),
598                type_val,
599                value_type_name(value)
600            ));
601            return errors;
602        }
603    }
604
605    // Object validation
606    if let Some(obj) = value.as_object() {
607        if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
608            for (key, prop_schema) in properties {
609                if let Some(child_value) = obj.get(key) {
610                    let child_path = if path.is_empty() {
611                        format!(".{}", key)
612                    } else {
613                        format!("{}.{}", path, key)
614                    };
615                    errors.extend(basic_schema_validate(child_value, prop_schema, &child_path));
616                }
617            }
618        }
619
620        if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
621            for req_field in required {
622                if let Some(field_name) = req_field.as_str() {
623                    if !obj.contains_key(field_name) {
624                        errors.push(format!(
625                            "{}: missing required field '{}'",
626                            path_or_root(path),
627                            field_name
628                        ));
629                    }
630                }
631            }
632        }
633
634        // additionalProperties: false
635        if schema.get("additionalProperties") == Some(&Value::Bool(false)) {
636            if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
637                for key in obj.keys() {
638                    if !properties.contains_key(key) {
639                        errors.push(format!(
640                            "{}: unexpected additional property '{}'",
641                            path_or_root(path),
642                            key
643                        ));
644                    }
645                }
646            }
647        }
648    }
649
650    // Array validation
651    if let Some(arr) = value.as_array() {
652        if let Some(items_schema) = schema.get("items") {
653            for (i, item) in arr.iter().enumerate() {
654                let child_path = format!("{}[{}]", path, i);
655                errors.extend(basic_schema_validate(item, items_schema, &child_path));
656            }
657        }
658        if let Some(min) = schema.get("minItems").and_then(|v| v.as_u64()) {
659            if (arr.len() as u64) < min {
660                errors.push(format!(
661                    "{}: array has {} items, minimum is {}",
662                    path_or_root(path),
663                    arr.len(),
664                    min
665                ));
666            }
667        }
668        if let Some(max) = schema.get("maxItems").and_then(|v| v.as_u64()) {
669            if (arr.len() as u64) > max {
670                errors.push(format!(
671                    "{}: array has {} items, maximum is {}",
672                    path_or_root(path),
673                    arr.len(),
674                    max
675                ));
676            }
677        }
678    }
679
680    // String validation
681    if let Some(s) = value.as_str() {
682        if let Some(min_len) = schema.get("minLength").and_then(|v| v.as_u64()) {
683            if (s.chars().count() as u64) < min_len {
684                errors.push(format!(
685                    "{}: string length {} < minLength {}",
686                    path_or_root(path),
687                    s.chars().count(),
688                    min_len
689                ));
690            }
691        }
692        if let Some(max_len) = schema.get("maxLength").and_then(|v| v.as_u64()) {
693            if (s.chars().count() as u64) > max_len {
694                errors.push(format!(
695                    "{}: string length {} > maxLength {}",
696                    path_or_root(path),
697                    s.chars().count(),
698                    max_len
699                ));
700            }
701        }
702        if let Some(pattern) = schema.get("pattern").and_then(|v| v.as_str()) {
703            if let Ok(re) = regex::Regex::new(pattern) {
704                if !re.is_match(s) {
705                    errors.push(format!(
706                        "{}: string does not match pattern '{}'",
707                        path_or_root(path),
708                        pattern
709                    ));
710                }
711            }
712        }
713    }
714
715    // Number validation
716    if let Some(n) = value.as_f64() {
717        if let Some(min) = schema.get("minimum").and_then(|v| v.as_f64()) {
718            if n < min {
719                errors.push(format!(
720                    "{}: value {} < minimum {}",
721                    path_or_root(path),
722                    n,
723                    min
724                ));
725            }
726        }
727        if let Some(max) = schema.get("maximum").and_then(|v| v.as_f64()) {
728            if n > max {
729                errors.push(format!(
730                    "{}: value {} > maximum {}",
731                    path_or_root(path),
732                    n,
733                    max
734                ));
735            }
736        }
737        if let Some(exc_min) = schema.get("exclusiveMinimum").and_then(|v| v.as_f64()) {
738            if n <= exc_min {
739                errors.push(format!(
740                    "{}: value {} <= exclusiveMinimum {}",
741                    path_or_root(path),
742                    n,
743                    exc_min
744                ));
745            }
746        }
747        if let Some(exc_max) = schema.get("exclusiveMaximum").and_then(|v| v.as_f64()) {
748            if n >= exc_max {
749                errors.push(format!(
750                    "{}: value {} >= exclusiveMaximum {}",
751                    path_or_root(path),
752                    n,
753                    exc_max
754                ));
755            }
756        }
757    }
758
759    errors
760}
761
762fn check_type(value: &Value, type_str: &str) -> bool {
763    match type_str {
764        "object" => value.is_object(),
765        "array" => value.is_array(),
766        "string" => value.is_string(),
767        "number" => value.is_number(),
768        "integer" => {
769            value.is_i64()
770                || value.is_u64()
771                || value
772                    .as_f64()
773                    .map(|f| f.fract() == 0.0 && f.is_finite())
774                    .unwrap_or(false)
775        }
776        "boolean" => value.is_boolean(),
777        "null" => value.is_null(),
778        _ => true,
779    }
780}
781
/// Render a validation path for error messages, showing the root (`""`) as `$`.
fn path_or_root(path: &str) -> &str {
    match path {
        "" => "$",
        non_root => non_root,
    }
}
789
790fn value_type_name(value: &Value) -> &'static str {
791    match value {
792        Value::Null => "null",
793        Value::Bool(_) => "boolean",
794        Value::Number(_) => "number",
795        Value::String(_) => "string",
796        Value::Array(_) => "array",
797        Value::Object(_) => "object",
798    }
799}
800
801// ---------------------------------------------------------------------------
802// Message/prompt construction helpers
803// ---------------------------------------------------------------------------
804
805fn build_initial_messages(req: &StructuredRequest, mode: StructuredMode) -> Vec<Message> {
806    match mode {
807        StructuredMode::Tool => {
808            // For tool mode, the prompt is the user message; the LLM will respond
809            // with a tool call whose input is the structured object.
810            vec![Message::user(&req.prompt)]
811        }
812        StructuredMode::Prompt => {
813            // Append schema instructions to the user prompt
814            let augmented = format!(
815                "{}\n\nYou MUST respond with ONLY a valid JSON object (no markdown, no explanation) that conforms to this JSON Schema:\n\n```json\n{}\n```",
816                req.prompt,
817                serde_json::to_string_pretty(&req.schema).unwrap_or_default()
818            );
819            vec![Message::user(&augmented)]
820        }
821        _ => {
822            // Strict/Json modes: the schema constraint is enforced by the provider,
823            // so the user message is just the prompt.
824            vec![Message::user(&req.prompt)]
825        }
826    }
827}
828
829fn build_system_prompt(req: &StructuredRequest, mode: StructuredMode) -> String {
830    let base = req.system.as_deref().unwrap_or("");
831
832    match mode {
833        StructuredMode::Tool => {
834            format!(
835                "{}{}You MUST respond by calling the `emit_{}` tool exactly once with a valid argument matching the schema. Do not output any text outside the tool call.",
836                base,
837                if base.is_empty() { "" } else { "\n\n" },
838                req.schema_name
839            )
840        }
841        StructuredMode::Prompt => {
842            format!(
843                "{}{}You are a structured data extraction assistant. Always respond with valid JSON only, no markdown fences, no explanation text.",
844                base,
845                if base.is_empty() { "" } else { "\n\n" },
846            )
847        }
848        _ => base.to_string(),
849    }
850}
851
852fn build_tools(req: &StructuredRequest, mode: StructuredMode) -> Vec<ToolDefinition> {
853    match mode {
854        StructuredMode::Tool => {
855            vec![ToolDefinition {
856                name: format!("emit_{}", req.schema_name),
857                description: req
858                    .schema_description
859                    .clone()
860                    .unwrap_or_else(|| format!("Emit a structured {} object", req.schema_name)),
861                parameters: req.schema.clone(),
862            }]
863        }
864        _ => vec![],
865    }
866}
867
868/// Extract the raw JSON string from the LLM response based on mode.
869fn extract_raw_output(message: &super::Message, mode: StructuredMode) -> String {
870    match mode {
871        StructuredMode::Tool => {
872            // Look for tool call input
873            let calls = message.tool_calls();
874            if let Some(call) = calls.first() {
875                serde_json::to_string(&call.args).unwrap_or_default()
876            } else {
877                // Fallback: maybe the model responded with text anyway
878                message.text()
879            }
880        }
881        _ => message.text(),
882    }
883}
884
/// Build the repair prompt sent after a schema-validation failure.
///
/// Echoes the failed output, then lists one `- <error>` line per validation
/// error. The echoed output is cut to at most 2000 bytes, backing off to a UTF-8
/// character boundary, with a note giving the full size.
fn build_repair_message(raw_text: &str, errors: &[String]) -> String {
    // Truncate raw output in repair message to avoid blowing context.
    const MAX_ECHO_BYTES: usize = 2000;

    let truncated_raw = if raw_text.len() > MAX_ECHO_BYTES {
        // Slicing at a fixed byte offset would panic inside a multi-byte character.
        let mut cut = MAX_ECHO_BYTES;
        while !raw_text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!(
            "{}...[truncated, {} bytes total]",
            &raw_text[..cut],
            raw_text.len()
        )
    } else {
        raw_text.to_string()
    };
    let error_lines = errors
        .iter()
        .map(|e| format!("- {}", e))
        .collect::<Vec<_>>()
        .join("\n");
    format!(
        "Your previous output failed schema validation:\n\n{}\n\nValidation errors:\n{}\n\nPlease return ONLY a corrected JSON object that fixes these errors. No explanation, no markdown.",
        truncated_raw, error_lines
    )
}
902
903fn accumulate_usage(total: &mut TokenUsage, delta: &TokenUsage) {
904    total.prompt_tokens += delta.prompt_tokens;
905    total.completion_tokens += delta.completion_tokens;
906    total.total_tokens += delta.total_tokens;
907}
908
909/// Append repair context to the message history, respecting conversation structure.
910///
911/// In tool mode, the LLM returned a tool_use block. The correct follow-up is:
912///   assistant (tool_use) → user (tool_result with error) → assistant (retry)
913/// In text modes, it's simply:
914///   assistant (text) → user (repair request) → assistant (retry)
915fn append_repair_context(
916    messages: &mut Vec<Message>,
917    assistant_msg: &Message,
918    repair_text: &str,
919    mode: StructuredMode,
920    _raw_text: &str,
921) {
922    if mode == StructuredMode::Tool {
923        // Push the original assistant message (with tool_use block intact)
924        messages.push(assistant_msg.clone());
925        // Find the tool_use ID to construct a proper tool_result
926        let tool_use_id = assistant_msg
927            .tool_calls()
928            .first()
929            .map(|tc| tc.id.clone())
930            .unwrap_or_else(|| "unknown".to_string());
931        // Return the error as a tool_result so the conversation stays valid
932        messages.push(Message::tool_result(&tool_use_id, repair_text, true));
933    } else {
934        // Text modes: push assistant text then user repair request
935        messages.push(assistant_msg.clone());
936        messages.push(Message::user(repair_text));
937    }
938}
939
940// ---------------------------------------------------------------------------
941// Tests
942// ---------------------------------------------------------------------------
943
944#[cfg(test)]
945#[path = "structured_tests.rs"]
946mod structured_tests;