Skip to main content

foundation_models/
schema.rs

1//! JSON-schema and dynamic schema builders for structured generation.
2
3use core::ffi::{c_char, c_void};
4use std::collections::BTreeMap;
5use std::ffi::CString;
6use std::sync::mpsc;
7
8use serde_json::{json, Map, Value};
9
10use crate::content::{FromGeneratedContent, ToGeneratedContent};
11use crate::error::FMError;
12use crate::ffi;
13
/// A validated FoundationModels generation schema encoded as JSON Schema.
///
/// Instances are created by [`GenerationSchema::from_json_schema`] or
/// [`GenerationSchema::from_dynamic`], which go through the Swift bridge, or
/// by the fixed primitive constructors, which store a literal schema string
/// without calling the bridge.
///
/// Equality is derived from the stored text, so two schemas compare equal
/// only when their JSON strings are identical.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenerationSchema {
    /// The schema document as JSON text.
    json_schema: String,
}
19
20impl GenerationSchema {
21    /// Validate and store a JSON schema definition.
22    ///
23    /// # Errors
24    ///
25    /// Returns an [`FMError`] if Apple's `GenerationSchema` rejects the schema.
26    pub fn from_json_schema(json_schema: impl Into<String>) -> Result<Self, FMError> {
27        let json_schema = json_schema.into();
28        let schema_c = CString::new(json_schema.as_str()).map_err(|error| {
29            FMError::InvalidArgument(format!(
30                "schema JSON contains an interior NUL byte: {error}"
31            ))
32        })?;
33        let mut error_ptr: *mut c_char = core::ptr::null_mut();
34        let status =
35            unsafe { ffi::fm_generation_schema_validate_json(schema_c.as_ptr(), &mut error_ptr) };
36        if status != ffi::status::OK {
37            return Err(crate::error::from_swift(status, error_ptr));
38        }
39        Ok(Self { json_schema })
40    }
41
42    /// Create a schema from a dynamic root schema plus optional dependencies.
43    ///
44    /// # Errors
45    ///
46    /// Returns an [`FMError`] if the dynamic schema is invalid.
47    pub fn from_dynamic(
48        root: DynamicGenerationSchema,
49        dependencies: impl IntoIterator<Item = DynamicGenerationSchema>,
50    ) -> Result<Self, FMError> {
51        let request = json!({
52            "root": root.to_json_value(),
53            "dependencies": dependencies
54                .into_iter()
55                .map(|schema| schema.to_json_value())
56                .collect::<Vec<_>>(),
57        });
58        let request_json = serde_json::to_string(&request).map_err(|error| {
59            FMError::InvalidArgument(format!(
60                "dynamic schema request is not JSON-serializable: {error}"
61            ))
62        })?;
63        let request_c = CString::new(request_json).map_err(|error| {
64            FMError::InvalidArgument(format!("dynamic schema JSON contains NUL byte: {error}"))
65        })?;
66        let (tx, rx) = mpsc::channel();
67        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
68        let context = Box::into_raw(tx_box).cast::<c_void>();
69        unsafe {
70            ffi::fm_generation_schema_compile_json(
71                request_c.as_ptr(),
72                context,
73                schema_callback_trampoline,
74            );
75        }
76        let json_schema = rx.recv().map_err(|_| FMError::Unknown {
77            code: ffi::status::UNKNOWN,
78            message: "Swift bridge dropped the schema callback channel".into(),
79        })??;
80        Ok(Self { json_schema })
81    }
82
83    /// The JSON Schema payload accepted by Apple's `GenerationSchema`.
84    #[must_use]
85    pub fn json_schema(&self) -> &str {
86        &self.json_schema
87    }
88
89    /// Best-effort name (the schema's `title`).
90    #[must_use]
91    pub fn name(&self) -> Option<String> {
92        let value: Value = serde_json::from_str(&self.json_schema).ok()?;
93        value.get("title")?.as_str().map(ToOwned::to_owned)
94    }
95
96    /// A JSON string schema.
97    #[must_use]
98    pub fn string() -> Self {
99        Self::from_json_schema_unchecked(r#"{"type":"string"}"#.into())
100    }
101
102    /// A JSON integer schema.
103    #[must_use]
104    pub fn integer() -> Self {
105        Self::from_json_schema_unchecked(r#"{"type":"integer"}"#.into())
106    }
107
108    /// A JSON number schema.
109    #[must_use]
110    pub fn number() -> Self {
111        Self::from_json_schema_unchecked(r#"{"type":"number"}"#.into())
112    }
113
114    /// A JSON boolean schema.
115    #[must_use]
116    pub fn boolean() -> Self {
117        Self::from_json_schema_unchecked(r#"{"type":"boolean"}"#.into())
118    }
119
120    /// A schema for arbitrary JSON (`GeneratedContent`).
121    #[must_use]
122    pub fn generated_content() -> Self {
123        Self::from_json_schema_unchecked(
124            r##"{"title":"GeneratedContent","description":"Any legal JSON","anyOf":[{"type":"object","additionalProperties":{"$ref":"#"}},{"type":"array","items":{"$ref":"#"}},{"type":"boolean"},{"type":"number"},{"type":"string"}]}"##.into(),
125        )
126    }
127
128    pub(crate) fn from_json_schema_unchecked(json_schema: String) -> Self {
129        Self { json_schema }
130    }
131}
132
/// A dynamic FoundationModels schema description.
///
/// Each variant is turned into a JSON object by the private
/// `to_json_value` method; the `"type"` tags mentioned below are the ones
/// that method emits.
#[derive(Debug, Clone, PartialEq)]
pub enum DynamicGenerationSchema {
    /// A named object with a sorted map of properties (`"type": "object"`).
    Object {
        name: String,
        description: Option<String>,
        properties: BTreeMap<String, DynamicGenerationProperty>,
    },
    /// An array of `item` with optional element-count bounds and guides
    /// (`"type": "array"`). Arrays carry no description of their own.
    Array {
        item: Box<DynamicGenerationSchema>,
        minimum_elements: Option<usize>,
        maximum_elements: Option<usize>,
        guides: Vec<GenerationGuide>,
    },
    /// A named union whose choices are schemas (`"type": "any_of"`).
    AnyOf {
        name: String,
        description: Option<String>,
        choices: Vec<DynamicGenerationSchema>,
    },
    /// A named union whose choices are constant strings; it is emitted with
    /// the same `"any_of"` tag as [`DynamicGenerationSchema::AnyOf`], with
    /// string choices instead of schema choices.
    AnyOfStrings {
        name: String,
        description: Option<String>,
        choices: Vec<String>,
    },
    /// A string value (`"type": "string"`).
    String {
        description: Option<String>,
        guides: Vec<GenerationGuide>,
    },
    /// An integer value (`"type": "integer"`).
    Integer {
        description: Option<String>,
        guides: Vec<GenerationGuide>,
    },
    /// A floating-point value (`"type": "float"`).
    Float {
        description: Option<String>,
        guides: Vec<GenerationGuide>,
    },
    /// A number value (`"type": "number"`).
    Number {
        description: Option<String>,
        guides: Vec<GenerationGuide>,
    },
    /// A decimal value (`"type": "decimal"`).
    Decimal {
        description: Option<String>,
        guides: Vec<GenerationGuide>,
    },
    /// A boolean value (`"type": "boolean"`); takes no guides.
    Boolean {
        description: Option<String>,
    },
    /// Arbitrary JSON (`"type": "generated_content"`); takes no guides.
    GeneratedContent {
        description: Option<String>,
    },
    /// A reference to a named dependency, emitted as `{"$ref": name}`.
    Reference {
        name: String,
    },
}
187
188impl DynamicGenerationSchema {
189    /// Create an object schema.
190    #[must_use]
191    pub fn object(name: impl Into<String>) -> Self {
192        Self::Object {
193            name: name.into(),
194            description: None,
195            properties: BTreeMap::new(),
196        }
197    }
198
199    /// Create a string schema.
200    #[must_use]
201    pub fn string() -> Self {
202        Self::String {
203            description: None,
204            guides: Vec::new(),
205        }
206    }
207
208    /// Create an integer schema.
209    #[must_use]
210    pub fn integer() -> Self {
211        Self::Integer {
212            description: None,
213            guides: Vec::new(),
214        }
215    }
216
217    /// Create a floating-point schema.
218    #[must_use]
219    pub fn float() -> Self {
220        Self::Float {
221            description: None,
222            guides: Vec::new(),
223        }
224    }
225
226    /// Create a number schema.
227    #[must_use]
228    pub fn number() -> Self {
229        Self::Number {
230            description: None,
231            guides: Vec::new(),
232        }
233    }
234
235    /// Create a decimal schema.
236    #[must_use]
237    pub fn decimal() -> Self {
238        Self::Decimal {
239            description: None,
240            guides: Vec::new(),
241        }
242    }
243
244    /// Create a boolean schema.
245    #[must_use]
246    pub fn boolean() -> Self {
247        Self::Boolean { description: None }
248    }
249
250    /// Create an arbitrary-JSON schema.
251    #[must_use]
252    pub fn generated_content() -> Self {
253        Self::GeneratedContent { description: None }
254    }
255
256    /// Create an array schema.
257    #[must_use]
258    pub fn array_of(item: Self) -> Self {
259        Self::Array {
260            item: Box::new(item),
261            minimum_elements: None,
262            maximum_elements: None,
263            guides: Vec::new(),
264        }
265    }
266
267    /// Create a reference to a named dependency.
268    #[must_use]
269    pub fn reference(name: impl Into<String>) -> Self {
270        Self::Reference { name: name.into() }
271    }
272
273    /// Create a named union of schemas.
274    #[must_use]
275    pub fn any_of(name: impl Into<String>, choices: Vec<Self>) -> Self {
276        Self::AnyOf {
277            name: name.into(),
278            description: None,
279            choices,
280        }
281    }
282
283    /// Create a named union of constant string choices.
284    #[must_use]
285    pub fn any_of_strings(
286        name: impl Into<String>,
287        choices: impl IntoIterator<Item = impl Into<String>>,
288    ) -> Self {
289        Self::AnyOfStrings {
290            name: name.into(),
291            description: None,
292            choices: choices.into_iter().map(Into::into).collect(),
293        }
294    }
295
296    /// Attach a description.
297    #[must_use]
298    pub fn with_description(mut self, description: impl Into<String>) -> Self {
299        match &mut self {
300            Self::Object {
301                description: slot, ..
302            }
303            | Self::AnyOf {
304                description: slot, ..
305            }
306            | Self::AnyOfStrings {
307                description: slot, ..
308            }
309            | Self::String {
310                description: slot, ..
311            }
312            | Self::Integer {
313                description: slot, ..
314            }
315            | Self::Float {
316                description: slot, ..
317            }
318            | Self::Number {
319                description: slot, ..
320            }
321            | Self::Decimal {
322                description: slot, ..
323            }
324            | Self::Boolean { description: slot }
325            | Self::GeneratedContent { description: slot } => *slot = Some(description.into()),
326            Self::Array { .. } | Self::Reference { .. } => {}
327        }
328        self
329    }
330
331    /// Add a property to an object schema.
332    #[must_use]
333    pub fn with_property(
334        mut self,
335        name: impl Into<String>,
336        property: DynamicGenerationProperty,
337    ) -> Self {
338        if let Self::Object { properties, .. } = &mut self {
339            properties.insert(name.into(), property);
340        }
341        self
342    }
343
344    /// Set the array bounds.
345    #[must_use]
346    pub fn with_element_bounds(mut self, minimum: Option<usize>, maximum: Option<usize>) -> Self {
347        if let Self::Array {
348            minimum_elements,
349            maximum_elements,
350            ..
351        } = &mut self
352        {
353            *minimum_elements = minimum;
354            *maximum_elements = maximum;
355        }
356        self
357    }
358
359    /// Attach FoundationModels generation guides.
360    #[must_use]
361    pub fn with_guides(mut self, guides: impl IntoIterator<Item = GenerationGuide>) -> Self {
362        let guides: Vec<_> = guides.into_iter().collect();
363        match &mut self {
364            Self::String { guides: slot, .. }
365            | Self::Integer { guides: slot, .. }
366            | Self::Float { guides: slot, .. }
367            | Self::Number { guides: slot, .. }
368            | Self::Decimal { guides: slot, .. }
369            | Self::Array { guides: slot, .. } => *slot = guides,
370            Self::Object { .. }
371            | Self::AnyOf { .. }
372            | Self::AnyOfStrings { .. }
373            | Self::Boolean { .. }
374            | Self::GeneratedContent { .. }
375            | Self::Reference { .. } => {}
376        }
377        self
378    }
379
380    fn to_json_value(&self) -> Value {
381        match self {
382            Self::Object {
383                name,
384                description,
385                properties,
386            } => object_schema_json(name, description, properties),
387            Self::Array {
388                item,
389                minimum_elements,
390                maximum_elements,
391                guides,
392            } => array_schema_json(item, *minimum_elements, *maximum_elements, guides),
393            Self::AnyOf {
394                name,
395                description,
396                choices,
397            } => named_schema_json(
398                "any_of",
399                name,
400                description,
401                Value::Array(choices.iter().map(Self::to_json_value).collect()),
402            ),
403            Self::AnyOfStrings {
404                name,
405                description,
406                choices,
407            } => named_schema_json(
408                "any_of",
409                name,
410                description,
411                Value::Array(choices.iter().cloned().map(Value::String).collect()),
412            ),
413            Self::String {
414                description,
415                guides,
416            } => primitive_schema_json("string", description, guides),
417            Self::Integer {
418                description,
419                guides,
420            } => primitive_schema_json("integer", description, guides),
421            Self::Float {
422                description,
423                guides,
424            } => primitive_schema_json("float", description, guides),
425            Self::Number {
426                description,
427                guides,
428            } => primitive_schema_json("number", description, guides),
429            Self::Decimal {
430                description,
431                guides,
432            } => primitive_schema_json("decimal", description, guides),
433            Self::Boolean { description } => primitive_schema_json("boolean", description, &[]),
434            Self::GeneratedContent { description } => {
435                primitive_schema_json("generated_content", description, &[])
436            }
437            Self::Reference { name } => json!({ "$ref": name }),
438        }
439    }
440}
441
442fn named_schema_json(
443    kind: &str,
444    name: &str,
445    description: &Option<String>,
446    choices: Value,
447) -> Value {
448    let mut map = Map::new();
449    map.insert("type".into(), Value::String(kind.into()));
450    map.insert("name".into(), Value::String(name.to_string()));
451    if let Some(description) = description {
452        map.insert("description".into(), Value::String(description.clone()));
453    }
454    map.insert("choices".into(), choices);
455    Value::Object(map)
456}
457
458fn object_schema_json(
459    name: &str,
460    description: &Option<String>,
461    properties: &BTreeMap<String, DynamicGenerationProperty>,
462) -> Value {
463    let property_map = properties
464        .iter()
465        .map(|(property_name, property)| (property_name.clone(), property.to_json_value()))
466        .collect::<Map<String, Value>>();
467    let mut map = Map::new();
468    map.insert("type".into(), Value::String("object".into()));
469    map.insert("name".into(), Value::String(name.to_string()));
470    if let Some(description) = description {
471        map.insert("description".into(), Value::String(description.clone()));
472    }
473    map.insert("properties".into(), Value::Object(property_map));
474    Value::Object(map)
475}
476
477fn array_schema_json(
478    item: &DynamicGenerationSchema,
479    minimum_elements: Option<usize>,
480    maximum_elements: Option<usize>,
481    guides: &[GenerationGuide],
482) -> Value {
483    let mut map = Map::new();
484    map.insert("type".into(), Value::String("array".into()));
485    map.insert("items".into(), item.to_json_value());
486    if let Some(minimum_elements) = minimum_elements {
487        map.insert("min".into(), Value::from(minimum_elements));
488    }
489    if let Some(maximum_elements) = maximum_elements {
490        map.insert("max".into(), Value::from(maximum_elements));
491    }
492    if !guides.is_empty() {
493        map.insert(
494            "guides".into(),
495            Value::Array(guides.iter().map(GenerationGuide::to_json_value).collect()),
496        );
497    }
498    Value::Object(map)
499}
500
/// A property in a dynamic object schema.
///
/// Pairs a nested schema with property-level metadata. When the property is
/// encoded, the metadata is written into the JSON object produced for
/// `schema`.
#[derive(Debug, Clone, PartialEq)]
pub struct DynamicGenerationProperty {
    /// Schema of the property's value.
    pub schema: DynamicGenerationSchema,
    /// Property description; when set, it is written under the
    /// `description` key of the encoded schema object.
    pub description: Option<String>,
    /// Whether the property is optional; encoded as `"optional": true` only
    /// when set.
    pub optional: bool,
}
508
509impl DynamicGenerationProperty {
510    /// Create a property from a nested schema.
511    #[must_use]
512    pub fn new(schema: DynamicGenerationSchema) -> Self {
513        Self {
514            schema,
515            description: None,
516            optional: false,
517        }
518    }
519
520    /// Mark the property as optional.
521    #[must_use]
522    pub const fn optional(mut self, optional: bool) -> Self {
523        self.optional = optional;
524        self
525    }
526
527    /// Attach a property description.
528    #[must_use]
529    pub fn with_description(mut self, description: impl Into<String>) -> Self {
530        self.description = Some(description.into());
531        self
532    }
533
534    fn to_json_value(&self) -> Value {
535        let mut value = self.schema.to_json_value();
536        if let Value::Object(map) = &mut value {
537            if let Some(description) = &self.description {
538                map.insert("description".into(), Value::String(description.clone()));
539            }
540            if self.optional {
541                map.insert("optional".into(), Value::Bool(true));
542            }
543        }
544        value
545    }
546}
547
/// One of Apple's public `GenerationGuide` builders.
///
/// Each variant is encoded by the private `to_json_value` method as a JSON
/// object with a `kind` tag; the tags are noted per variant below.
#[derive(Debug, Clone, PartialEq)]
pub enum GenerationGuide {
    /// A single constant string (`kind: "constant"`).
    StringConstant(String),
    /// One of several strings (`kind: "any_of"`).
    StringAnyOf(Vec<String>),
    /// A string pattern (`kind: "pattern"`).
    StringPattern(String),
    /// Lower bound for an `i64` (`kind: "minimum"`).
    MinimumI64(i64),
    /// Upper bound for an `i64` (`kind: "maximum"`).
    MaximumI64(i64),
    /// `(minimum, maximum)` for an `i64` (`kind: "range"`).
    RangeI64(i64, i64),
    /// Lower bound for an `f32` (`kind: "minimum"`).
    MinimumF32(f32),
    /// Upper bound for an `f32` (`kind: "maximum"`).
    MaximumF32(f32),
    /// `(minimum, maximum)` for an `f32` (`kind: "range"`).
    RangeF32(f32, f32),
    /// Lower bound for an `f64` (`kind: "minimum"`).
    MinimumF64(f64),
    /// Upper bound for an `f64` (`kind: "maximum"`).
    MaximumF64(f64),
    /// `(minimum, maximum)` for an `f64` (`kind: "range"`).
    RangeF64(f64, f64),
    /// Lower bound for a decimal, carried as a string (`kind: "minimum"`).
    MinimumDecimal(String),
    /// Upper bound for a decimal, carried as a string (`kind: "maximum"`).
    MaximumDecimal(String),
    /// `(minimum, maximum)` for a decimal, carried as strings
    /// (`kind: "range"`).
    RangeDecimal(String, String),
    /// Minimum element count (`kind: "minimum_count"`).
    MinimumCount(usize),
    /// Maximum element count (`kind: "maximum_count"`).
    MaximumCount(usize),
    /// `(minimum, maximum)` element count (`kind: "count"` with `min`/`max`).
    CountRange(usize, usize),
    /// Exact element count (`kind: "count"` with `value`).
    CountExact(usize),
    /// A guide wrapped for array elements (`kind: "element"`).
    Element(Box<GenerationGuide>),
}
572
573impl GenerationGuide {
574    #[must_use]
575    pub fn string_constant(value: impl Into<String>) -> Self {
576        Self::StringConstant(value.into())
577    }
578
579    #[must_use]
580    pub fn string_any_of(values: impl IntoIterator<Item = impl Into<String>>) -> Self {
581        Self::StringAnyOf(values.into_iter().map(Into::into).collect())
582    }
583
584    #[must_use]
585    pub fn string_pattern(pattern: impl Into<String>) -> Self {
586        Self::StringPattern(pattern.into())
587    }
588
589    #[must_use]
590    pub const fn minimum_i64(value: i64) -> Self {
591        Self::MinimumI64(value)
592    }
593
594    #[must_use]
595    pub const fn maximum_i64(value: i64) -> Self {
596        Self::MaximumI64(value)
597    }
598
599    #[must_use]
600    pub const fn range_i64(minimum: i64, maximum: i64) -> Self {
601        Self::RangeI64(minimum, maximum)
602    }
603
604    #[must_use]
605    pub const fn minimum_f32(value: f32) -> Self {
606        Self::MinimumF32(value)
607    }
608
609    #[must_use]
610    pub const fn maximum_f32(value: f32) -> Self {
611        Self::MaximumF32(value)
612    }
613
614    #[must_use]
615    pub const fn range_f32(minimum: f32, maximum: f32) -> Self {
616        Self::RangeF32(minimum, maximum)
617    }
618
619    #[must_use]
620    pub const fn minimum_f64(value: f64) -> Self {
621        Self::MinimumF64(value)
622    }
623
624    #[must_use]
625    pub const fn maximum_f64(value: f64) -> Self {
626        Self::MaximumF64(value)
627    }
628
629    #[must_use]
630    pub const fn range_f64(minimum: f64, maximum: f64) -> Self {
631        Self::RangeF64(minimum, maximum)
632    }
633
634    #[must_use]
635    pub fn minimum_decimal(value: impl Into<String>) -> Self {
636        Self::MinimumDecimal(value.into())
637    }
638
639    #[must_use]
640    pub fn maximum_decimal(value: impl Into<String>) -> Self {
641        Self::MaximumDecimal(value.into())
642    }
643
644    #[must_use]
645    pub fn range_decimal(minimum: impl Into<String>, maximum: impl Into<String>) -> Self {
646        Self::RangeDecimal(minimum.into(), maximum.into())
647    }
648
649    #[must_use]
650    pub const fn minimum_count(count: usize) -> Self {
651        Self::MinimumCount(count)
652    }
653
654    #[must_use]
655    pub const fn maximum_count(count: usize) -> Self {
656        Self::MaximumCount(count)
657    }
658
659    #[must_use]
660    pub const fn count_range(minimum: usize, maximum: usize) -> Self {
661        Self::CountRange(minimum, maximum)
662    }
663
664    #[must_use]
665    pub const fn count(count: usize) -> Self {
666        Self::CountExact(count)
667    }
668
669    #[must_use]
670    pub fn element(guide: GenerationGuide) -> Self {
671        Self::Element(Box::new(guide))
672    }
673
674    fn to_json_value(&self) -> Value {
675        match self {
676            Self::StringConstant(value) => json!({ "kind": "constant", "value": value }),
677            Self::StringAnyOf(values) => json!({ "kind": "any_of", "values": values }),
678            Self::StringPattern(pattern) => json!({ "kind": "pattern", "pattern": pattern }),
679            Self::MinimumI64(value) => json!({ "kind": "minimum", "value": value }),
680            Self::MaximumI64(value) => json!({ "kind": "maximum", "value": value }),
681            Self::RangeI64(minimum, maximum) => {
682                json!({ "kind": "range", "min": minimum, "max": maximum })
683            }
684            Self::MinimumF32(value) => json!({ "kind": "minimum", "value": value }),
685            Self::MaximumF32(value) => json!({ "kind": "maximum", "value": value }),
686            Self::RangeF32(minimum, maximum) => {
687                json!({ "kind": "range", "min": minimum, "max": maximum })
688            }
689            Self::MinimumF64(value) => json!({ "kind": "minimum", "value": value }),
690            Self::MaximumF64(value) => json!({ "kind": "maximum", "value": value }),
691            Self::RangeF64(minimum, maximum) => {
692                json!({ "kind": "range", "min": minimum, "max": maximum })
693            }
694            Self::MinimumDecimal(value) => json!({ "kind": "minimum", "value": value }),
695            Self::MaximumDecimal(value) => json!({ "kind": "maximum", "value": value }),
696            Self::RangeDecimal(minimum, maximum) => {
697                json!({ "kind": "range", "min": minimum, "max": maximum })
698            }
699            Self::MinimumCount(count) => json!({ "kind": "minimum_count", "value": count }),
700            Self::MaximumCount(count) => json!({ "kind": "maximum_count", "value": count }),
701            Self::CountRange(minimum, maximum) => {
702                json!({ "kind": "count", "min": minimum, "max": maximum })
703            }
704            Self::CountExact(count) => json!({ "kind": "count", "value": count }),
705            Self::Element(guide) => json!({ "kind": "element", "guide": guide.to_json_value() }),
706        }
707    }
708}
709
710fn primitive_schema_json(
711    kind: &str,
712    description: &Option<String>,
713    guides: &[GenerationGuide],
714) -> Value {
715    let mut map = Map::new();
716    map.insert("type".into(), Value::String(kind.into()));
717    if let Some(description) = description {
718        map.insert("description".into(), Value::String(description.clone()));
719    }
720    if !guides.is_empty() {
721        map.insert(
722            "guides".into(),
723            Value::Array(guides.iter().map(GenerationGuide::to_json_value).collect()),
724        );
725    }
726    Value::Object(map)
727}
728
/// Rust analogue of FoundationModels' `Generable` protocol.
///
/// Implementors must also implement [`FromGeneratedContent`] and
/// [`ToGeneratedContent`] (both are supertraits) and supply the schema that
/// describes the type.
pub trait Generable: Sized + FromGeneratedContent + ToGeneratedContent {
    /// Return the generation schema that describes `Self`.
    ///
    /// # Errors
    ///
    /// Implementations return an [`FMError`] when the schema cannot be
    /// produced.
    fn generation_schema() -> Result<GenerationSchema, FMError>;
}
734
735// SAFETY: `context` is a `Box<mpsc::Sender<...>>` raw pointer created by
736// `GenerationSchema::compile`. Swift calls this callback exactly once, so
737// there is no double-free risk. `response` and `error` are C strings owned
738// by the Swift bridge and only valid for this call.
739unsafe extern "C" fn schema_callback_trampoline(
740    context: *mut c_void,
741    response: *mut c_char,
742    error: *mut c_char,
743    status: i32,
744) {
745    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
746    let result = if status == ffi::status::OK && !response.is_null() {
747        let value = core::ffi::CStr::from_ptr(response)
748            .to_string_lossy()
749            .into_owned();
750        ffi::fm_string_free(response);
751        Ok(value)
752    } else {
753        Err(crate::error::from_swift(status, error))
754    };
755    let _ = tx.send(result);
756}