1use serde::de::DeserializeOwned;
41use serde_json::Value;
42
43use crate::chat::ChatResponse;
44use crate::error::LlmError;
45use crate::provider::{ChatParams, DynProvider, JsonSchema};
46use crate::stream::{ChatStream, StreamEvent};
47use crate::usage::Usage;
48
#[derive(Debug, Clone, PartialEq, Eq)]
/// Options shared by `generate_object` and `stream_object_async`.
pub struct GenerateObjectConfig {
    /// Upper bound on the number of provider calls `generate_object` makes before giving up.
    ///
    /// Must be at least 1: `generate_object` rejects 0 with `LlmError::InvalidRequest`.
    /// `stream_object_async` makes a single request and does not read this field.
    pub max_attempts: u32,
    /// When `true`, the JSON Schema is also written into the system prompt as plain text,
    /// in addition to being sent as `ChatParams::structured_output` (presumably for
    /// providers that do not enforce native structured output — confirm per provider).
    pub system_prompt_fallback: bool,
}
62
63impl Default for GenerateObjectConfig {
64 fn default() -> Self {
65 Self {
66 max_attempts: 1,
67 system_prompt_fallback: false,
68 }
69 }
70}
71
#[derive(Debug, Clone)]
/// Text and token usage gathered from a structured-output stream, plus the parsed value.
pub struct PartialObject<T> {
    /// As filled by `collect_stream_object`: every `StreamEvent::TextDelta` payload,
    /// concatenated in arrival order.
    pub partial_json: String,
    /// The value parsed from `partial_json`; `None` means no parsed value is available.
    /// `collect_stream_object` only returns this struct after a successful parse, with `Some`.
    pub complete: Option<T>,
    /// Sum of all `StreamEvent::Usage` events seen before the stream finished.
    pub usage: Usage,
}
85
86#[derive(Debug)]
90pub struct GenerateObjectResult<T> {
91 pub value: T,
93 pub raw_json: String,
95 pub usage: Usage,
97 pub attempts: u32,
99}
100
#[cfg(feature = "schema")]
/// Requests a response that deserializes into `T`, re-prompting the model when its output
/// is unusable.
///
/// The JSON Schema derived from `T` is attached as `params.structured_output`. With
/// `config.system_prompt_fallback`, it is also spelled out in the system prompt.
///
/// Each response's text is parsed as JSON, validated against that schema, and deserialized
/// into `T`. When that fails and attempts remain, the rejected text and the failure reason
/// are appended to `params.messages` and the provider is called again. At most
/// `config.max_attempts` calls are made in total.
///
/// On success, the result carries:
/// - the value;
/// - the raw text it came from;
/// - usage summed over every call made;
/// - the 1-based number of the attempt that succeeded.
///
/// # Errors
///
/// - `LlmError::InvalidRequest` if `config.max_attempts` is 0 or no schema can be derived
///   for `T`. No provider call is made in either case.
/// - Any error from the provider call, returned immediately without retrying. Usage from
///   earlier attempts is not reported on this path.
/// - The parse/validation/deserialization error from the final attempt, once every attempt
///   has been rejected.
pub async fn generate_object<T>(
    provider: &dyn DynProvider,
    mut params: ChatParams,
    config: GenerateObjectConfig,
) -> Result<GenerateObjectResult<T>, LlmError>
where
    T: DeserializeOwned + schemars::JsonSchema,
{
    if config.max_attempts == 0 {
        return Err(LlmError::InvalidRequest(
            "max_attempts must be at least 1".into(),
        ));
    }

    let schema = JsonSchema::from_type::<T>()
        .map_err(|e| LlmError::InvalidRequest(format!("failed to derive JSON schema: {e}")))?;

    params.structured_output = Some(schema.clone());

    if config.system_prompt_fallback {
        inject_schema_prompt(&mut params, &schema);
    }

    let mut total_usage = Usage::default();
    let mut attempt: u32 = 1;

    // Every path through the match either returns or retries with `attempt` still below
    // `max_attempts`, so no "last error" bookkeeping or post-loop unwrap is needed.
    loop {
        // Provider failures are propagated as-is; only unusable model output is retried.
        let response = provider.generate_boxed(&params).await?;
        total_usage += response.usage.clone();

        match extract_and_validate::<T>(&response, &schema) {
            Ok((value, raw_json)) => {
                return Ok(GenerateObjectResult {
                    value,
                    raw_json,
                    usage: total_usage,
                    attempts: attempt,
                });
            }
            // Out of attempts: surface the most recent failure.
            Err(err) if attempt >= config.max_attempts => return Err(err),
            // Show the model its rejected answer and the reason before asking again.
            Err(err) => append_retry_feedback(&mut params, &response, &err),
        }

        attempt += 1;
    }
}
178
#[cfg(feature = "schema")]
/// Opens a provider stream configured to produce JSON matching the schema of `T`.
///
/// The request is prepared the same way `generate_object` prepares it:
/// - `params.structured_output` is set;
/// - with `config.system_prompt_fallback`, the schema is also injected into the system
///   prompt.
///
/// The raw event stream is then returned without any parsing. Pass it to
/// `collect_stream_object` together with the same schema. Only one request is made, and
/// `config.max_attempts` is not read.
///
/// # Errors
///
/// `LlmError::InvalidRequest` if no schema can be derived for `T`; otherwise whatever the
/// provider's `stream_boxed` call returns.
pub async fn stream_object_async<T>(
    provider: &dyn DynProvider,
    mut params: ChatParams,
    config: GenerateObjectConfig,
) -> Result<ChatStream, LlmError>
where
    T: DeserializeOwned + schemars::JsonSchema,
{
    let schema = match JsonSchema::from_type::<T>() {
        Ok(derived) => derived,
        Err(e) => {
            return Err(LlmError::InvalidRequest(format!(
                "failed to derive JSON schema: {e}"
            )));
        }
    };

    // The prompt fallback only edits `params.system`, so it can run before the schema is
    // moved into `structured_output`.
    if config.system_prompt_fallback {
        inject_schema_prompt(&mut params, &schema);
    }
    params.structured_output = Some(schema);

    provider.stream_boxed(&params).await
}
207
#[cfg(feature = "schema")]
/// Drains `stream` and parses the collected text as a `T` validated against `schema`.
///
/// Text deltas are concatenated and usage events are summed. This continues until a `Done`
/// event arrives or the stream ends; any other event kind is skipped. `schema` comes from
/// the caller and should be the one used for the request, e.g.
/// `JsonSchema::from_type::<T>()`.
///
/// # Errors
///
/// - The first error item yielded by the stream, returned as soon as it is seen.
/// - `LlmError::ResponseFormat` if no text arrived, the text is not valid JSON, or it
///   cannot be deserialized into `T`. In the last two cases `raw` holds the collected text.
/// - Whatever `schema.validate` reports when the JSON violates the schema.
pub async fn collect_stream_object<T>(
    mut stream: ChatStream,
    schema: &JsonSchema,
) -> Result<PartialObject<T>, LlmError>
where
    T: DeserializeOwned,
{
    use futures::StreamExt;

    let mut collected = String::new();
    let mut usage = Usage::default();

    // `transpose()?` turns `Some(Err(e))` into an early return, leaving plain events here.
    while let Some(event) = stream.next().await.transpose()? {
        match event {
            StreamEvent::TextDelta(delta) => collected.push_str(&delta),
            StreamEvent::Usage(u) => usage += u,
            StreamEvent::Done { .. } => break,
            _ => {}
        }
    }

    if collected.is_empty() {
        return Err(LlmError::ResponseFormat {
            message: "model returned no text content for structured output".into(),
            raw: String::new(),
        });
    }

    let value: Value = match serde_json::from_str(&collected) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(LlmError::ResponseFormat {
                message: format!("invalid JSON in structured output: {e}"),
                raw: collected,
            });
        }
    };

    schema.validate(&value)?;

    match serde_json::from_value::<T>(value) {
        Ok(typed) => Ok(PartialObject {
            partial_json: collected,
            complete: Some(typed),
            usage,
        }),
        Err(e) => Err(LlmError::ResponseFormat {
            message: format!("failed to deserialize structured output: {e}"),
            raw: collected,
        }),
    }
}
263
#[cfg(feature = "schema")]
/// Parses the text of `response` as JSON, checks it against `schema`, and deserializes it.
///
/// Returns the typed value together with an owned copy of the raw response text.
///
/// # Errors
///
/// `LlmError::ResponseFormat` when the response has no text, the text is not valid JSON,
/// or deserialization into `T` fails. Any error from `schema.validate` is passed through.
fn extract_and_validate<T: DeserializeOwned>(
    response: &ChatResponse,
    schema: &JsonSchema,
) -> Result<(T, String), LlmError> {
    let text = match response.text() {
        Some(text) => text,
        None => {
            return Err(LlmError::ResponseFormat {
                message: "model returned no text content for structured output".into(),
                raw: String::new(),
            });
        }
    };
    // One owned copy serves every error payload below as well as the success value.
    let raw_json = text.to_string();

    let value: Value = match serde_json::from_str(&raw_json) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(LlmError::ResponseFormat {
                message: format!("invalid JSON in structured output: {e}"),
                raw: raw_json,
            });
        }
    };

    schema.validate(&value)?;

    match serde_json::from_value::<T>(value) {
        Ok(typed) => Ok((typed, raw_json)),
        Err(e) => Err(LlmError::ResponseFormat {
            message: format!("failed to deserialize structured output: {e}"),
            raw: raw_json,
        }),
    }
}
292
#[cfg(feature = "schema")]
/// Writes `schema` into the system prompt as a plain-text instruction.
///
/// The pretty-printed schema is embedded in a fenced `json` block, followed by a demand for
/// a bare JSON reply. Existing system text is kept, and the instruction is appended after a
/// blank line. With no system prompt, the instruction becomes the whole prompt.
///
/// Gated like the other private helpers here: its only non-test callers are the
/// `schema`-feature entry points, so without the gate it is dead code in builds without
/// that feature.
fn inject_schema_prompt(params: &mut ChatParams, schema: &JsonSchema) {
    // `as_value()` is expected to be a `serde_json::Value`, whose serialization cannot fail.
    let schema_json = serde_json::to_string_pretty(schema.as_value())
        .expect("serializing Value to JSON cannot fail");

    let instruction = format!(
        "You must respond with valid JSON that conforms to this JSON Schema:\n\
         ```json\n{schema_json}\n```\n\
         Respond ONLY with the JSON object. No markdown, no explanation."
    );

    match &mut params.system {
        Some(existing) => {
            existing.push_str("\n\n");
            existing.push_str(&instruction);
        }
        None => params.system = Some(instruction),
    }
}
313
#[cfg(feature = "schema")]
/// Extends the conversation in `params` so the next attempt sees why the last one failed.
///
/// Pushes two messages:
/// - an assistant turn echoing the rejected response text (empty if the response had no
///   text);
/// - a user turn quoting `error` and asking again for schema-conforming JSON.
///
/// Gated like the other private helpers here: its only non-test caller is
/// `generate_object`, which is itself behind the `schema` feature.
///
/// NOTE(review): when `response.text()` is `None`, this pushes an empty assistant message.
/// Confirm that every provider accepts empty assistant content.
fn append_retry_feedback(params: &mut ChatParams, response: &ChatResponse, error: &LlmError) {
    use crate::chat::ChatMessage;

    // Echo the rejected answer so the feedback below has something concrete to refer to.
    params
        .messages
        .push(ChatMessage::assistant(response.text().unwrap_or("")));

    params.messages.push(ChatMessage::user(format!(
        "Your response did not pass validation: {error}\n\
         Please try again with valid JSON that conforms to the schema."
    )));
}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ChatMessage, ContentBlock, StopReason};
    use crate::test_helpers::{mock_for, sample_usage};
    use serde::Deserialize;
    use serde_json::json;
    use std::collections::HashMap;

    // --- target types --------------------------------------------------------

    /// Target type with a required integer `age`, so `{"name": ...}` alone is rejected.
    #[derive(Debug, Deserialize, PartialEq, schemars::JsonSchema)]
    struct Person {
        name: String,
        age: u32,
    }

    /// Float-valued target type for tests that mostly inspect the outgoing request.
    #[derive(Debug, Deserialize, PartialEq, schemars::JsonSchema)]
    struct Coord {
        x: f64,
        y: f64,
    }

    // --- fixtures ------------------------------------------------------------

    /// The `Person` every successful `Person` test expects back.
    fn alice() -> Person {
        Person {
            name: "Alice".into(),
            age: 30,
        }
    }

    /// Finished non-streaming response carrying `body` as its only (text) block.
    fn json_response(body: &str) -> ChatResponse {
        ChatResponse {
            content: vec![ContentBlock::Text(body.into())],
            usage: sample_usage(),
            stop_reason: StopReason::EndTurn,
            model: "test-model".into(),
            metadata: HashMap::new(),
        }
    }

    /// Request consisting of a single user message.
    fn user_params(prompt: &'static str) -> ChatParams {
        ChatParams {
            messages: vec![ChatMessage::user(prompt)],
            ..Default::default()
        }
    }

    /// Default config with a custom attempt budget.
    fn with_attempts(max_attempts: u32) -> GenerateObjectConfig {
        GenerateObjectConfig {
            max_attempts,
            ..Default::default()
        }
    }

    /// Default config with the system-prompt fallback switched on.
    fn with_prompt_fallback() -> GenerateObjectConfig {
        GenerateObjectConfig {
            system_prompt_fallback: true,
            ..Default::default()
        }
    }

    /// End-of-stream marker used by every scripted stream.
    fn done() -> StreamEvent {
        StreamEvent::Done {
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Replays `events` as a `ChatStream`.
    fn scripted(events: Vec<Result<StreamEvent, LlmError>>) -> ChatStream {
        Box::pin(futures::stream::iter(events))
    }

    // --- generate_object -----------------------------------------------------

    #[tokio::test]
    async fn test_generate_object_happy_path() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"name": "Alice", "age": 30}"#));

        let result: GenerateObjectResult<Person> = generate_object(
            &mock,
            user_params("Generate a person"),
            GenerateObjectConfig::default(),
        )
        .await
        .unwrap();

        assert_eq!(result.value, alice());
        assert_eq!(result.attempts, 1);
        assert_eq!(result.raw_json, r#"{"name": "Alice", "age": 30}"#);
    }

    #[tokio::test]
    async fn test_generate_object_invalid_json() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response("not valid json"));

        let err = generate_object::<Person>(
            &mock,
            user_params("Generate a person"),
            GenerateObjectConfig::default(),
        )
        .await
        .expect_err("unparseable text must be rejected");

        assert!(matches!(err, LlmError::ResponseFormat { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_generate_object_schema_violation() {
        let mock = mock_for("test", "test-model");
        // `age` is required by the derived schema.
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));

        let err = generate_object::<Person>(
            &mock,
            user_params("Generate a person"),
            GenerateObjectConfig::default(),
        )
        .await
        .expect_err("a missing required field must fail validation");

        assert!(matches!(err, LlmError::SchemaValidation { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_generate_object_wrong_type() {
        let mock = mock_for("test", "test-model");
        // `age` must be an integer, not a string.
        mock.queue_response(json_response(r#"{"name": "Alice", "age": "thirty"}"#));

        let err = generate_object::<Person>(
            &mock,
            user_params("Generate a person"),
            GenerateObjectConfig::default(),
        )
        .await
        .expect_err("a mistyped field must fail validation");

        assert!(matches!(err, LlmError::SchemaValidation { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_generate_object_retry_succeeds_on_second_attempt() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        mock.queue_response(json_response(r#"{"name": "Alice", "age": 30}"#));

        let result: GenerateObjectResult<Person> =
            generate_object(&mock, user_params("Generate a person"), with_attempts(2))
                .await
                .unwrap();

        assert_eq!(result.value, alice());
        assert_eq!(result.attempts, 2);
        // Usage from both provider calls is summed.
        assert_eq!(result.usage.input_tokens, 200);
    }

    #[tokio::test]
    async fn test_generate_object_retry_exhausted() {
        let mock = mock_for("test", "test-model");
        for body in [
            r#"{"name": "Alice"}"#,
            r#"{"name": "Bob"}"#,
            r#"{"name": "Charlie"}"#,
        ] {
            mock.queue_response(json_response(body));
        }

        let err = generate_object::<Person>(&mock, user_params("Generate a person"), with_attempts(3))
            .await
            .expect_err("every attempt is invalid");

        assert!(matches!(err, LlmError::SchemaValidation { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_generate_object_no_text_content() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(ChatResponse {
            content: Vec::new(),
            usage: sample_usage(),
            stop_reason: StopReason::EndTurn,
            model: "test-model".into(),
            metadata: HashMap::new(),
        });

        let err = generate_object::<Person>(
            &mock,
            user_params("Generate a person"),
            GenerateObjectConfig::default(),
        )
        .await
        .expect_err("a response without text must be rejected");

        assert!(matches!(err, LlmError::ResponseFormat { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_generate_object_sets_structured_output() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"x": 1.0, "y": 2.0}"#));

        let _generated: GenerateObjectResult<Coord> = generate_object(
            &mock,
            user_params("Generate coords"),
            GenerateObjectConfig::default(),
        )
        .await
        .unwrap();

        let calls = mock.recorded_calls();
        assert!(calls[0].structured_output.is_some());
    }

    #[tokio::test]
    async fn test_generate_object_system_prompt_fallback() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"x": 1.0, "y": 2.0}"#));

        let _generated: GenerateObjectResult<Coord> =
            generate_object(&mock, user_params("Generate coords"), with_prompt_fallback())
                .await
                .unwrap();

        let calls = mock.recorded_calls();
        let system = calls[0]
            .system
            .as_deref()
            .expect("the fallback must create a system prompt");
        assert!(system.contains("JSON Schema"));
    }

    #[tokio::test]
    async fn test_generate_object_system_prompt_appends() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"x": 1.0, "y": 2.0}"#));

        let params = ChatParams {
            system: Some("You are a helpful assistant.".into()),
            ..user_params("Generate coords")
        };

        let _generated: GenerateObjectResult<Coord> =
            generate_object(&mock, params, with_prompt_fallback())
                .await
                .unwrap();

        let calls = mock.recorded_calls();
        let system = calls[0].system.as_ref().unwrap();
        assert!(system.starts_with("You are a helpful assistant."));
        assert!(system.contains("JSON Schema"));
    }

    #[tokio::test]
    async fn test_generate_object_retry_appends_feedback() {
        let mock = mock_for("test", "test-model");
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        mock.queue_response(json_response(r#"{"name": "Alice", "age": 30}"#));

        let _generated: GenerateObjectResult<Person> =
            generate_object(&mock, user_params("Generate a person"), with_attempts(2))
                .await
                .unwrap();

        let calls = mock.recorded_calls();
        let second = &calls[1];
        // The retry request must carry more than the original single user turn.
        assert!(second.messages.len() > 1);

        let feedback = second
            .messages
            .iter()
            .rfind(|m| m.role == crate::chat::ChatRole::User)
            .and_then(|m| {
                m.content.iter().find_map(|block| match block {
                    ContentBlock::Text(t) => Some(t.as_str()),
                    _ => None,
                })
            })
            .expect("the retry request should end with a textual user message");
        assert!(feedback.contains("did not pass validation"));
    }

    #[tokio::test]
    async fn test_generate_object_zero_attempts_errors() {
        let mock = mock_for("test", "test-model");

        let err = generate_object::<Person>(&mock, user_params("Generate a person"), with_attempts(0))
            .await
            .expect_err("zero attempts is not a valid configuration");

        assert!(matches!(err, LlmError::InvalidRequest(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn test_generate_object_provider_error_propagates() {
        let mock = mock_for("test", "test-model");
        // The first answer fails validation, so the second call hits the queued HTTP error.
        mock.queue_response(json_response(r#"{"name": "Alice"}"#));
        mock.queue_error(crate::mock::MockError::Http {
            status: Some(http::StatusCode::SERVICE_UNAVAILABLE),
            message: "service down".into(),
            retryable: true,
        });

        let err = generate_object::<Person>(&mock, user_params("Generate a person"), with_attempts(2))
            .await
            .expect_err("the provider error must surface");

        assert!(matches!(err, LlmError::Http { .. }), "got {err:?}");
    }

    // --- collect_stream_object -----------------------------------------------

    #[tokio::test]
    async fn test_collect_stream_object_happy_path() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let stream = scripted(vec![
            Ok(StreamEvent::TextDelta(r#"{"name":"#.into())),
            Ok(StreamEvent::TextDelta(r#" "Alice", "age": 30}"#.into())),
            Ok(StreamEvent::Usage(sample_usage())),
            Ok(done()),
        ]);

        let result: PartialObject<Person> = collect_stream_object(stream, &schema).await.unwrap();

        assert_eq!(result.complete.unwrap(), alice());
        assert_eq!(result.partial_json, r#"{"name": "Alice", "age": 30}"#);
        assert_eq!(result.usage.input_tokens, 100);
    }

    #[tokio::test]
    async fn test_collect_stream_object_invalid_json() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let stream = scripted(vec![Ok(StreamEvent::TextDelta("not json".into())), Ok(done())]);

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .expect_err("unparseable text must be rejected");

        assert!(matches!(err, LlmError::ResponseFormat { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_collect_stream_object_schema_violation() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let stream = scripted(vec![
            Ok(StreamEvent::TextDelta(r#"{"name": "Alice"}"#.into())),
            Ok(done()),
        ]);

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .expect_err("a missing required field must fail validation");

        assert!(matches!(err, LlmError::SchemaValidation { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_collect_stream_object_empty_stream() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let stream = scripted(vec![Ok(done())]);

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .expect_err("a stream without text must be rejected");

        assert!(matches!(err, LlmError::ResponseFormat { .. }), "got {err:?}");
    }

    #[tokio::test]
    async fn test_collect_stream_object_mid_stream_error() {
        let schema = JsonSchema::from_type::<Person>().unwrap();
        let stream = scripted(vec![
            Ok(StreamEvent::TextDelta(r#"{"name"#.into())),
            Err(LlmError::Http {
                status: None,
                message: "connection lost".into(),
                retryable: true,
            }),
        ]);

        let err = collect_stream_object::<Person>(stream, &schema)
            .await
            .expect_err("the stream error must surface");

        assert!(matches!(err, LlmError::Http { .. }), "got {err:?}");
    }

    // --- stream_object_async -------------------------------------------------

    #[tokio::test]
    async fn test_stream_object_async_sets_structured_output() {
        let mock = mock_for("test", "test-model");
        mock.queue_stream(vec![
            StreamEvent::TextDelta(r#"{"x": 1.0, "y": 2.0}"#.into()),
            done(),
        ]);

        let _stream = stream_object_async::<Coord>(
            &mock,
            user_params("Generate coords"),
            GenerateObjectConfig::default(),
        )
        .await
        .unwrap();

        let calls = mock.recorded_calls();
        assert!(calls[0].structured_output.is_some());
    }

    #[tokio::test]
    async fn test_stream_object_async_end_to_end() {
        let mock = mock_for("test", "test-model");
        mock.queue_stream(vec![
            StreamEvent::TextDelta(r#"{"x": 1.23"#.into()),
            StreamEvent::TextDelta(r#", "y": 4.56}"#.into()),
            done(),
        ]);

        let schema = JsonSchema::from_type::<Coord>().unwrap();
        let stream = stream_object_async::<Coord>(
            &mock,
            user_params("Generate coords"),
            GenerateObjectConfig::default(),
        )
        .await
        .unwrap();

        let result: PartialObject<Coord> = collect_stream_object(stream, &schema).await.unwrap();

        let coord = result.complete.unwrap();
        assert!((coord.x - 1.23).abs() < 0.001);
        assert!((coord.y - 4.56).abs() < 0.001);
    }

    // --- prompt injection ----------------------------------------------------

    #[test]
    fn test_inject_schema_prompt_new() {
        let schema = JsonSchema::new(json!({"type": "object"}));
        let mut params = ChatParams::default();
        inject_schema_prompt(&mut params, &schema);

        let system = params.system.expect("a system prompt should be created");
        assert!(system.contains("JSON Schema"));
        assert!(system.contains("ONLY"));
    }

    #[test]
    fn test_inject_schema_prompt_appends() {
        let schema = JsonSchema::new(json!({"type": "object"}));
        let mut params = ChatParams {
            system: Some("Be helpful.".into()),
            ..Default::default()
        };
        inject_schema_prompt(&mut params, &schema);

        let system = params.system.unwrap();
        assert!(system.starts_with("Be helpful."));
        assert!(system.contains("JSON Schema"));
    }

    // --- plain data types ----------------------------------------------------

    #[test]
    fn test_generate_object_config_default() {
        let config = GenerateObjectConfig::default();
        assert_eq!(config.max_attempts, 1);
        assert!(!config.system_prompt_fallback);
    }

    #[test]
    fn test_partial_object_debug() {
        let snapshot: PartialObject<Person> = PartialObject {
            partial_json: "{}".into(),
            complete: None,
            usage: Usage::default(),
        };
        assert!(format!("{snapshot:?}").contains("PartialObject"));
    }

    #[test]
    fn test_generate_object_result_debug() {
        let result = GenerateObjectResult {
            value: alice(),
            raw_json: "{}".into(),
            usage: Usage::default(),
            attempts: 1,
        };
        assert!(format!("{result:?}").contains("GenerateObjectResult"));
    }
}