// llm_stack/structured.rs

//! Structured output — typed LLM responses with schema validation.
//!
//! This module provides high-level functions that combine schema derivation,
//! LLM generation, validation, and deserialization into a single call.
//!
//! # Non-streaming
//!
//! [`generate_object`] sends a request with a JSON Schema constraint and
//! returns a fully validated, deserialized `T`:
//!
//! ```rust,no_run
//! use llm_stack::structured::{generate_object, GenerateObjectConfig};
//! use llm_stack::{ChatMessage, ChatParams};
//! use serde::Deserialize;
//!
//! #[derive(Deserialize, schemars::JsonSchema)]
//! struct Person {
//!     name: String,
//!     age: u32,
//! }
//!
//! # async fn example(provider: &dyn llm_stack::DynProvider) -> Result<(), llm_stack::LlmError> {
//! let params = ChatParams {
//!     messages: vec![ChatMessage::user("Generate a person named Alice aged 30")],
//!     ..Default::default()
//! };
//!
//! let result = generate_object::<Person>(provider, params, GenerateObjectConfig::default()).await?;
//! assert_eq!(result.value.name, "Alice");
//! # Ok(())
//! # }
//! ```
//!
//! # Streaming
//!
//! [`stream_object_async`] sends the same schema-constrained request but
//! returns the provider's [`ChatStream`] directly. Pass the stream to
//! [`collect_stream_object`], which accumulates the streamed JSON and, once
//! the stream completes, returns a [`PartialObject<T>`] containing the
//! validated, deserialized object.
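//!
//! A minimal end-to-end sketch (marked `ignore` because it assumes
//! `JsonSchema` is re-exported at the crate root alongside the types above):
//!
//! ```rust,ignore
//! use llm_stack::structured::{collect_stream_object, stream_object_async, GenerateObjectConfig};
//! use llm_stack::{ChatMessage, ChatParams, JsonSchema};
//! use serde::Deserialize;
//!
//! #[derive(Deserialize, schemars::JsonSchema)]
//! struct Person {
//!     name: String,
//!     age: u32,
//! }
//!
//! # async fn example(provider: &dyn llm_stack::DynProvider) -> Result<(), llm_stack::LlmError> {
//! let params = ChatParams {
//!     messages: vec![ChatMessage::user("Generate a person named Alice aged 30")],
//!     ..Default::default()
//! };
//!
//! // Derive the schema once; it is needed again to validate the collected stream.
//! let schema = JsonSchema::from_type::<Person>().expect("schema derivation");
//! let stream = stream_object_async::<Person>(provider, params, GenerateObjectConfig::default()).await?;
//! let person = collect_stream_object::<Person>(stream, &schema).await?;
//! assert_eq!(person.complete.expect("complete object").name, "Alice");
//! # Ok(())
//! # }
//! ```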

use serde::de::DeserializeOwned;
use serde_json::Value;

use crate::chat::ChatResponse;
use crate::error::LlmError;
use crate::provider::{ChatParams, DynProvider, JsonSchema};
use crate::stream::{ChatStream, StreamEvent};
use crate::usage::Usage;

// ── GenerateObjectConfig ────────────────────────────────────────────

/// Configuration for [`generate_object`] and [`stream_object_async`].
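///
/// A short sketch of allowing a single retry (the defaults are shown in the
/// [`Default`] impl below):
///
/// ```rust
/// use llm_stack::structured::GenerateObjectConfig;
///
/// let config = GenerateObjectConfig {
///     max_attempts: 2, // one retry on parse/validation failure
///     ..Default::default()
/// };
/// assert_eq!(config.max_attempts, 2);
/// assert!(!config.system_prompt_fallback);
/// ```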
#[derive(Debug, Clone)]
pub struct GenerateObjectConfig {
    /// Maximum number of attempts (initial + retries). Default: 1
    /// (single attempt, no retries). Set to 2 for one retry, etc.
    pub max_attempts: u32,
    /// Whether to include the schema in the system prompt as a fallback
    /// for providers that don't support native structured output.
    /// Default: false.
    pub system_prompt_fallback: bool,
}

impl Default for GenerateObjectConfig {
    fn default() -> Self {
        Self {
            max_attempts: 1,
            system_prompt_fallback: false,
        }
    }
}

// ── PartialObject ───────────────────────────────────────────────────

/// A partially-received structured object from a streaming response.
#[derive(Debug, Clone)]
pub struct PartialObject<T> {
    /// The raw JSON accumulated so far (may be incomplete).
    pub partial_json: String,
    /// `Some(T)` once the stream completes and the object passes
    /// validation + deserialization.
    pub complete: Option<T>,
    /// Usage accumulated from the stream's usage events.
    pub usage: Usage,
}

// ── GenerateObjectResult ────────────────────────────────────────────

/// The result of a successful [`generate_object`] call.
#[derive(Debug)]
pub struct GenerateObjectResult<T> {
    /// The deserialized, validated object.
    pub value: T,
    /// The raw JSON string returned by the model.
    pub raw_json: String,
    /// Total token usage accumulated across all attempts, including failed retries.
    pub usage: Usage,
    /// How many attempts were made (1 = succeeded on first try).
    pub attempts: u32,
}

// ── generate_object ─────────────────────────────────────────────────

/// Generates a typed object from the LLM with schema validation.
///
/// 1. Derives a JSON Schema from `T` (via [`schemars`])
/// 2. Sets `structured_output` on `ChatParams`
/// 3. Calls the provider
/// 4. Parses the response text as JSON
/// 5. Validates against the schema
/// 6. Deserializes to `T`
///
/// Makes up to `config.max_attempts` attempts, retrying on parse/validation failures.
/// On retry, the model's invalid response and the validation error are
/// appended to the message history so the model can self-correct.
///
/// # Errors
///
/// Returns [`LlmError`] if:
/// - `max_attempts` is 0
/// - The schema cannot be derived from `T`
/// - The provider returns an error (propagated immediately, not retried)
/// - All attempts fail validation (returns the last validation error)
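///
/// # Example
///
/// A sketch allowing one self-correcting retry (mirrors the module-level
/// example; any `Deserialize + schemars::JsonSchema` type works as `T`):
///
/// ```rust,no_run
/// # use llm_stack::structured::{generate_object, GenerateObjectConfig};
/// # use llm_stack::{ChatMessage, ChatParams};
/// # use serde::Deserialize;
/// # #[derive(Deserialize, schemars::JsonSchema)]
/// # struct Person { name: String, age: u32 }
/// # async fn example(provider: &dyn llm_stack::DynProvider) -> Result<(), llm_stack::LlmError> {
/// let params = ChatParams {
///     messages: vec![ChatMessage::user("Generate a person")],
///     ..Default::default()
/// };
/// let config = GenerateObjectConfig {
///     max_attempts: 2,
///     ..Default::default()
/// };
///
/// let result = generate_object::<Person>(provider, params, config).await?;
/// println!("valid after {} attempt(s): {}", result.attempts, result.raw_json);
/// # Ok(())
/// # }
/// ```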
#[cfg(feature = "schema")]
pub async fn generate_object<T>(
    provider: &dyn DynProvider,
    mut params: ChatParams,
    config: GenerateObjectConfig,
) -> Result<GenerateObjectResult<T>, LlmError>
where
    T: DeserializeOwned + schemars::JsonSchema,
{
    if config.max_attempts == 0 {
        return Err(LlmError::InvalidRequest(
            "max_attempts must be at least 1".into(),
        ));
    }

    let schema = JsonSchema::from_type::<T>()
        .map_err(|e| LlmError::InvalidRequest(format!("failed to derive JSON schema: {e}")))?;

    params.structured_output = Some(schema.clone());

    if config.system_prompt_fallback {
        inject_schema_prompt(&mut params, &schema);
    }

    let mut total_usage = Usage::default();
    let mut last_error = None;

    for attempt in 1..=config.max_attempts {
        let response = provider.generate_boxed(&params).await?;
        total_usage += response.usage.clone();

        match extract_and_validate::<T>(&response, &schema) {
            Ok((value, raw_json)) => {
                return Ok(GenerateObjectResult {
                    value,
                    raw_json,
                    usage: total_usage,
                    attempts: attempt,
                });
            }
            Err(e) => {
                if attempt < config.max_attempts {
                    append_retry_feedback(&mut params, &response, &e);
                }
                last_error = Some(e);
            }
        }
    }

    Err(last_error.expect("max_attempts >= 1 guarantees at least one iteration"))
}

/// Async streaming variant of [`generate_object`].
///
/// Awaits the provider's stream creation and returns the [`ChatStream`]
/// directly. Yields [`StreamEvent`]s which can be collected via
/// [`collect_stream_object`].
///
/// The `max_attempts` field in `config` is ignored — retry logic must
/// be implemented by the caller.
#[cfg(feature = "schema")]
pub async fn stream_object_async<T>(
    provider: &dyn DynProvider,
    mut params: ChatParams,
    config: GenerateObjectConfig,
) -> Result<ChatStream, LlmError>
where
    T: DeserializeOwned + schemars::JsonSchema,
{
    let schema = JsonSchema::from_type::<T>()
        .map_err(|e| LlmError::InvalidRequest(format!("failed to derive JSON schema: {e}")))?;

    params.structured_output = Some(schema.clone());

    if config.system_prompt_fallback {
        inject_schema_prompt(&mut params, &schema);
    }

    provider.stream_boxed(&params).await
}

/// Collects a [`ChatStream`] into a [`PartialObject<T>`].
///
/// Accumulates text deltas, then validates and deserializes the result
/// when the stream completes. Use this with [`stream_object_async`].
///
/// Unlike [`generate_object`], this function does not retry on
/// validation failures — errors are returned immediately.
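///
/// A sketch with an in-memory stream, mirroring the unit tests below (marked
/// `ignore`: it assumes `ChatStream`, `StreamEvent`, `StopReason`, and
/// `JsonSchema` are re-exported at the crate root and that `futures` is
/// available to the caller):
///
/// ```rust,ignore
/// use llm_stack::structured::collect_stream_object;
/// use llm_stack::{ChatStream, JsonSchema, StopReason, StreamEvent};
///
/// # #[derive(serde::Deserialize, schemars::JsonSchema)]
/// # struct Person { name: String, age: u32 }
/// # async fn example() -> Result<(), llm_stack::LlmError> {
/// let schema = JsonSchema::from_type::<Person>().expect("schema derivation");
///
/// // Any ChatStream works; here the events come from an in-memory iterator.
/// let events = vec![
///     Ok(StreamEvent::TextDelta(r#"{"name": "Alice","#.into())),
///     Ok(StreamEvent::TextDelta(r#" "age": 30}"#.into())),
///     Ok(StreamEvent::Done { stop_reason: StopReason::EndTurn }),
/// ];
/// let stream: ChatStream = Box::pin(futures::stream::iter(events));
///
/// let partial = collect_stream_object::<Person>(stream, &schema).await?;
/// assert!(partial.complete.is_some());
/// # Ok(())
/// # }
/// ```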
#[cfg(feature = "schema")]
pub async fn collect_stream_object<T>(
    mut stream: ChatStream,
    schema: &JsonSchema,
) -> Result<PartialObject<T>, LlmError>
where
    T: DeserializeOwned,
{
    use futures::StreamExt;

    let mut json_buf = String::new();
    let mut usage = Usage::default();

    while let Some(event) = stream.next().await {
        match event? {
            StreamEvent::TextDelta(text) => json_buf.push_str(&text),
            StreamEvent::Usage(u) => usage += u,
            StreamEvent::Done { .. } => break,
            _ => {}
        }
    }

    if json_buf.is_empty() {
        return Err(LlmError::ResponseFormat {
            message: "model returned no text content for structured output".into(),
            raw: String::new(),
        });
    }

    // Parse, validate, deserialize
    let value: Value = serde_json::from_str(&json_buf).map_err(|e| LlmError::ResponseFormat {
        message: format!("invalid JSON in structured output: {e}"),
        raw: json_buf.clone(),
    })?;

    schema.validate(&value)?;

    let typed: T = serde_json::from_value(value).map_err(|e| LlmError::ResponseFormat {
        message: format!("failed to deserialize structured output: {e}"),
        raw: json_buf.clone(),
    })?;

    Ok(PartialObject {
        partial_json: json_buf,
        complete: Some(typed),
        usage,
    })
}

// ── Helpers ─────────────────────────────────────────────────────────

/// Extracts the text from a `ChatResponse`, parses it as JSON,
/// validates against the schema, and deserializes to `T`.
#[cfg(feature = "schema")]
fn extract_and_validate<T: DeserializeOwned>(
    response: &ChatResponse,
    schema: &JsonSchema,
) -> Result<(T, String), LlmError> {
    let raw_json = response.text().ok_or_else(|| LlmError::ResponseFormat {
        message: "model returned no text content for structured output".into(),
        raw: String::new(),
    })?;

    let value: Value = serde_json::from_str(raw_json).map_err(|e| LlmError::ResponseFormat {
        message: format!("invalid JSON in structured output: {e}"),
        raw: raw_json.to_string(),
    })?;

    schema.validate(&value)?;

    let typed: T = serde_json::from_value(value).map_err(|e| LlmError::ResponseFormat {
        message: format!("failed to deserialize structured output: {e}"),
        raw: raw_json.to_string(),
    })?;

    Ok((typed, raw_json.to_string()))
}

/// Injects a system prompt instructing the model to respond with JSON
/// matching the given schema.
fn inject_schema_prompt(params: &mut ChatParams, schema: &JsonSchema) {
    let schema_json = serde_json::to_string_pretty(schema.as_value())
        .expect("serializing Value to JSON cannot fail");

    let instruction = format!(
        "You must respond with valid JSON that conforms to this JSON Schema:\n\
         ```json\n{schema_json}\n```\n\
         Respond ONLY with the JSON object. No markdown, no explanation."
    );

    match &mut params.system {
        Some(existing) => {
            existing.push_str("\n\n");
            existing.push_str(&instruction);
        }
        None => params.system = Some(instruction),
    }
}

/// Appends the model's failed response and the validation error as
/// feedback messages, so the model can retry.
fn append_retry_feedback(params: &mut ChatParams, response: &ChatResponse, error: &LlmError) {
    use crate::chat::ChatMessage;

    // Add the model's response as an assistant message
    params
        .messages
        .push(ChatMessage::assistant(response.text().unwrap_or("")));

    // Add the error as a user message asking for correction
    params.messages.push(ChatMessage::user(format!(
        "Your response did not pass validation: {error}\n\
         Please try again with valid JSON that conforms to the schema."
    )));
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ContentBlock, StopReason};
    use crate::test_helpers::{mock_for, sample_usage};
    use serde::Deserialize;
    use serde_json::json;
    use std::collections::HashMap;

    #[derive(Debug, Deserialize, PartialEq, schemars::JsonSchema)]
    struct Person {
        name: String,
        age: u32,
    }

    #[derive(Debug, Deserialize, PartialEq, schemars::JsonSchema)]
    struct Coord {
        x: f64,
        y: f64,
    }

    fn json_response(json_str: &str) -> ChatResponse {
        ChatResponse {
            content: vec![ContentBlock::Text(json_str.into())],
            usage: sample_usage(),
            stop_reason: StopReason::EndTurn,
            model: "test-model".into(),
            metadata: HashMap::new(),
        }
    }

    // ── generate_object tests ──────────────────────────────────────

    #[tokio::test]
    async fn test_generate_object_happy_path() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"name": "Alice", "age": 30}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };

        let result: GenerateObjectResult<Person> =
            generate_object(&mock, params, GenerateObjectConfig::default())
                .await
                .unwrap();

        assert_eq!(
            result.value,
            Person {
                name: "Alice".into(),
                age: 30
            }
        );
        assert_eq!(result.attempts, 1);
        assert_eq!(result.raw_json, r#"{"name": "Alice", "age": 30}"#);
    }

    #[tokio::test]
    async fn test_generate_object_invalid_json() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response("not valid json"));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, GenerateObjectConfig::default())
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::ResponseFormat { .. }));
    }

    #[tokio::test]
    async fn test_generate_object_schema_violation() {
        let mock = mock_for("test", "test-model");
        // Missing required field "age"
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, GenerateObjectConfig::default())
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::SchemaValidation { .. }));
    }

    #[tokio::test]
    async fn test_generate_object_wrong_type() {
        let mock = mock_for("test", "test-model");
        // age is string instead of number
        mock.queue_response(json_response(r#"{"name": "Alice", "age": "thirty"}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, GenerateObjectConfig::default())
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::SchemaValidation { .. }));
    }

    #[tokio::test]
    async fn test_generate_object_retry_succeeds_on_second_attempt() {
        let mock = mock_for("test", "test-model");
        // First attempt: invalid
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        // Second attempt: valid
        mock.queue_response(json_response(r#"{"name": "Alice", "age": 30}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            max_attempts: 2,
            ..Default::default()
        };

        let result: GenerateObjectResult<Person> =
            generate_object(&mock, params, config).await.unwrap();

        assert_eq!(
            result.value,
            Person {
                name: "Alice".into(),
                age: 30
            }
        );
        assert_eq!(result.attempts, 2);
        // Usage should include both attempts
        assert_eq!(result.usage.input_tokens, 200);
    }

    #[tokio::test]
    async fn test_generate_object_retry_exhausted() {
        let mock = mock_for("test", "test-model");
        // All attempts fail
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        mock.queue_response(json_response(r#"{"name": "Bob"}"#));
        mock.queue_response(json_response(r#"{"name": "Charlie"}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            max_attempts: 3,
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, config)
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::SchemaValidation { .. }));
    }

    #[tokio::test]
    async fn test_generate_object_no_text_content() {
        let mock = mock_for("test", "test-model");
        // Response with no text (empty content)
        mock.queue_response(ChatResponse {
            content: vec![],
            usage: sample_usage(),
            stop_reason: StopReason::EndTurn,
            model: "test-model".into(),
            metadata: HashMap::new(),
        });

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, GenerateObjectConfig::default())
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::ResponseFormat { .. }));
    }

    #[tokio::test]
    async fn test_generate_object_sets_structured_output() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"x": 1.0, "y": 2.0}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate coords")],
            ..Default::default()
        };

        let _result: GenerateObjectResult<Coord> =
            generate_object(&mock, params, GenerateObjectConfig::default())
                .await
                .unwrap();

        // Verify the provider received structured_output
        let recorded = mock.recorded_calls();
        assert!(recorded[0].structured_output.is_some());
    }

    #[tokio::test]
    async fn test_generate_object_system_prompt_fallback() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"x": 1.0, "y": 2.0}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate coords")],
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            system_prompt_fallback: true,
            ..Default::default()
        };

        let _result: GenerateObjectResult<Coord> =
            generate_object(&mock, params, config).await.unwrap();

        let recorded = mock.recorded_calls();
        assert!(recorded[0].system.is_some());
        assert!(recorded[0].system.as_ref().unwrap().contains("JSON Schema"));
    }

    #[tokio::test]
    async fn test_generate_object_system_prompt_appends() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"x": 1.0, "y": 2.0}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate coords")],
            system: Some("You are a helpful assistant.".into()),
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            system_prompt_fallback: true,
            ..Default::default()
        };

        let _result: GenerateObjectResult<Coord> =
            generate_object(&mock, params, config).await.unwrap();

        let recorded = mock.recorded_calls();
        let system = recorded[0].system.as_ref().unwrap();
        assert!(system.starts_with("You are a helpful assistant."));
        assert!(system.contains("JSON Schema"));
    }

    #[tokio::test]
    async fn test_generate_object_retry_appends_feedback() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        mock.queue_response(json_response(r#"{"name": "Alice", "age": 30}"#));

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            max_attempts: 2,
            ..Default::default()
        };

        let _result: GenerateObjectResult<Person> =
            generate_object(&mock, params, config).await.unwrap();

        // Second call should have feedback messages
        let recorded = mock.recorded_calls();
        assert!(recorded[1].messages.len() > 1);
        // Should contain the model's failed response + error feedback
        let last_user_msg = recorded[1]
            .messages
            .iter()
            .rfind(|m| m.role == crate::chat::ChatRole::User)
            .unwrap();
        let text = last_user_msg.content.iter().find_map(|b| match b {
            ContentBlock::Text(t) => Some(t.as_str()),
            _ => None,
        });
        assert!(text.unwrap().contains("did not pass validation"));
    }

    // ── collect_stream_object tests ────────────────────────────────

    #[tokio::test]
    async fn test_collect_stream_object_happy_path() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let events = vec![
            Ok(StreamEvent::TextDelta(r#"{"name":"#.into())),
            Ok(StreamEvent::TextDelta(r#" "Alice", "age": 30}"#.into())),
            Ok(StreamEvent::Usage(sample_usage())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));

        let result: PartialObject<Person> = collect_stream_object(stream, &schema).await.unwrap();

        assert_eq!(
            result.complete.unwrap(),
            Person {
                name: "Alice".into(),
                age: 30
            }
        );
        assert_eq!(result.partial_json, r#"{"name": "Alice", "age": 30}"#);
        assert_eq!(result.usage.input_tokens, 100);
    }

    #[tokio::test]
    async fn test_collect_stream_object_invalid_json() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let events = vec![
            Ok(StreamEvent::TextDelta("not json".into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::ResponseFormat { .. }));
    }

    #[tokio::test]
    async fn test_collect_stream_object_schema_violation() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let events = vec![
            Ok(StreamEvent::TextDelta(r#"{"name": "Alice"}"#.into())),
            Ok(StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::SchemaValidation { .. }));
    }

    #[tokio::test]
    async fn test_collect_stream_object_empty_stream() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let events = vec![Ok(StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        })];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::ResponseFormat { .. }));
    }

    #[tokio::test]
    async fn test_collect_stream_object_mid_stream_error() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let events = vec![
            Ok(StreamEvent::TextDelta(r#"{"name"#.into())),
            Err(LlmError::Http {
                status: None,
                message: "connection lost".into(),
                retryable: true,
            }),
        ];
        let stream: ChatStream = Box::pin(futures::stream::iter(events));

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::Http { .. }));
    }

    // ── stream_object_async tests ──────────────────────────────────

    #[tokio::test]
    async fn test_stream_object_async_sets_structured_output() {
        let mock = mock_for("test", "test-model");
        mock.queue_stream(vec![
            StreamEvent::TextDelta(r#"{"x": 1.0, "y": 2.0}"#.into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate coords")],
            ..Default::default()
        };

        let _stream = stream_object_async::<Coord>(&mock, params, GenerateObjectConfig::default())
            .await
            .unwrap();

        let recorded = mock.recorded_calls();
        assert!(recorded[0].structured_output.is_some());
    }

    #[tokio::test]
    async fn test_stream_object_async_end_to_end() {
        let mock = mock_for("test", "test-model");
        mock.queue_stream(vec![
            StreamEvent::TextDelta(r#"{"x": 1.23"#.into()),
            StreamEvent::TextDelta(r#", "y": 4.56}"#.into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate coords")],
            ..Default::default()
        };

        let schema = JsonSchema::from_type::<Coord>().unwrap();
        let stream = stream_object_async::<Coord>(&mock, params, GenerateObjectConfig::default())
            .await
            .unwrap();

        let result: PartialObject<Coord> = collect_stream_object(stream, &schema).await.unwrap();

        let coord = result.complete.unwrap();
        assert!((coord.x - 1.23).abs() < 0.001);
        assert!((coord.y - 4.56).abs() < 0.001);
    }

    // ── Helper tests ───────────────────────────────────────────────

    #[test]
    fn test_inject_schema_prompt_new() {
        let schema = JsonSchema::new(json!({"type": "object"}));
        let mut params = ChatParams::default();
        inject_schema_prompt(&mut params, &schema);

        assert!(params.system.as_ref().unwrap().contains("JSON Schema"));
        assert!(params.system.as_ref().unwrap().contains("ONLY"));
    }

    #[test]
    fn test_inject_schema_prompt_appends() {
        let schema = JsonSchema::new(json!({"type": "object"}));
        let mut params = ChatParams {
            system: Some("Be helpful.".into()),
            ..Default::default()
        };
        inject_schema_prompt(&mut params, &schema);

        let system = params.system.unwrap();
        assert!(system.starts_with("Be helpful."));
        assert!(system.contains("JSON Schema"));
    }

    #[test]
    fn test_generate_object_config_default() {
        let config = GenerateObjectConfig::default();
        assert_eq!(config.max_attempts, 1);
        assert!(!config.system_prompt_fallback);
    }

    #[test]
    fn test_partial_object_debug() {
        let po: PartialObject<Person> = PartialObject {
            partial_json: "{}".into(),
            complete: None,
            usage: Usage::default(),
        };
        let debug = format!("{po:?}");
        assert!(debug.contains("PartialObject"));
    }

    #[test]
    fn test_generate_object_result_debug() {
        let result = GenerateObjectResult {
            value: Person {
                name: "Alice".into(),
                age: 30,
            },
            raw_json: "{}".into(),
            usage: Usage::default(),
            attempts: 1,
        };
        let debug = format!("{result:?}");
        assert!(debug.contains("GenerateObjectResult"));
    }

    #[tokio::test]
    async fn test_generate_object_zero_attempts_errors() {
        let mock = mock_for("test", "test-model");
        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            max_attempts: 0,
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, config)
            .await
            .unwrap_err();

        assert!(matches!(err, LlmError::InvalidRequest(_)));
    }

    #[tokio::test]
    async fn test_generate_object_provider_error_propagates() {
        let mock = mock_for("test", "test-model");
        // First response fails validation, second is a provider error
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        mock.queue_error(crate::mock::MockError::Http {
            status: Some(http::StatusCode::SERVICE_UNAVAILABLE),
            message: "service down".into(),
            retryable: true,
        });

        let params = ChatParams {
            messages: vec![ChatMessage::user("Generate a person")],
            ..Default::default()
        };
        let config = GenerateObjectConfig {
            max_attempts: 2,
            ..Default::default()
        };

        let err = generate_object::<Person>(&mock, params, config)
            .await
            .unwrap_err();

        // Provider errors propagate immediately, not wrapped
        assert!(matches!(err, LlmError::Http { .. }));
    }
}