Skip to main content

foundation_models/
schema.rs

1//! JSON-schema and dynamic schema builders for structured generation.
2
3use core::ffi::{c_char, c_void};
4use std::collections::BTreeMap;
5use std::ffi::CString;
6use std::sync::mpsc;
7
8use serde_json::{json, Map, Value};
9
10use crate::content::{FromGeneratedContent, GeneratedContent, ToGeneratedContent};
11use crate::error::FMError;
12use crate::ffi;
13
/// A FoundationModels generation schema, held as its JSON Schema text.
///
/// Schemas built with [`GenerationSchema::from_json_schema`] have been accepted by the Swift
/// bridge's validator. The built-in constructors (`string`, `integer`, …) and the crate-internal
/// unchecked constructor store their JSON without that round-trip.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GenerationSchema {
    /// The JSON Schema document, stored verbatim.
    json_schema: String,
}
19
20impl GenerationSchema {
21    /// Validate and store a JSON schema definition.
22    ///
23    /// # Errors
24    ///
25    /// Returns an [`FMError`] if Apple's `GenerationSchema` rejects the schema.
26    pub fn from_json_schema(json_schema: impl Into<String>) -> Result<Self, FMError> {
27        let json_schema = json_schema.into();
28        let schema_c = CString::new(json_schema.as_str()).map_err(|error| {
29            FMError::InvalidArgument(format!(
30                "schema JSON contains an interior NUL byte: {error}"
31            ))
32        })?;
33        let mut error_ptr: *mut c_char = core::ptr::null_mut();
34        let status =
35            unsafe { ffi::fm_generation_schema_validate_json(schema_c.as_ptr(), &mut error_ptr) };
36        if status != ffi::status::OK {
37            return Err(crate::error::from_swift(status, error_ptr));
38        }
39        Ok(Self { json_schema })
40    }
41
42    /// Create a schema from a dynamic root schema plus optional dependencies.
43    ///
44    /// # Errors
45    ///
46    /// Returns an [`FMError`] if the dynamic schema is invalid.
47    pub fn from_dynamic(
48        root: DynamicGenerationSchema,
49        dependencies: impl IntoIterator<Item = DynamicGenerationSchema>,
50    ) -> Result<Self, FMError> {
51        let request = json!({
52            "root": root.to_json_value(),
53            "dependencies": dependencies
54                .into_iter()
55                .map(|schema| schema.to_json_value())
56                .collect::<Vec<_>>(),
57        });
58        let request_json = serde_json::to_string(&request).map_err(|error| {
59            FMError::InvalidArgument(format!(
60                "dynamic schema request is not JSON-serializable: {error}"
61            ))
62        })?;
63        let request_c = CString::new(request_json).map_err(|error| {
64            FMError::InvalidArgument(format!("dynamic schema JSON contains NUL byte: {error}"))
65        })?;
66        let (tx, rx) = mpsc::channel();
67        let tx_box: Box<mpsc::Sender<Result<String, FMError>>> = Box::new(tx);
68        let context = Box::into_raw(tx_box).cast::<c_void>();
69        unsafe {
70            ffi::fm_generation_schema_compile_json(
71                request_c.as_ptr(),
72                context,
73                schema_callback_trampoline,
74            );
75        }
76        let json_schema = rx.recv().map_err(|_| FMError::Unknown {
77            code: ffi::status::UNKNOWN,
78            message: "Swift bridge dropped the schema callback channel".into(),
79        })??;
80        Ok(Self { json_schema })
81    }
82
83    /// The JSON Schema payload accepted by Apple's `GenerationSchema`.
84    #[must_use]
85    pub fn json_schema(&self) -> &str {
86        &self.json_schema
87    }
88
89    /// Best-effort name (the schema's `title`).
90    #[must_use]
91    pub fn name(&self) -> Option<String> {
92        let value: Value = serde_json::from_str(&self.json_schema).ok()?;
93        value.get("title")?.as_str().map(ToOwned::to_owned)
94    }
95
96    /// A JSON string schema.
97    #[must_use]
98    pub fn string() -> Self {
99        Self::from_json_schema_unchecked(r#"{"type":"string"}"#.into())
100    }
101
102    /// A JSON integer schema.
103    #[must_use]
104    pub fn integer() -> Self {
105        Self::from_json_schema_unchecked(r#"{"type":"integer"}"#.into())
106    }
107
108    /// A JSON number schema.
109    #[must_use]
110    pub fn number() -> Self {
111        Self::from_json_schema_unchecked(r#"{"type":"number"}"#.into())
112    }
113
114    /// A JSON boolean schema.
115    #[must_use]
116    pub fn boolean() -> Self {
117        Self::from_json_schema_unchecked(r#"{"type":"boolean"}"#.into())
118    }
119
120    /// A schema for arbitrary JSON (`GeneratedContent`).
121    #[must_use]
122    pub fn generated_content() -> Self {
123        Self::from_json_schema_unchecked(
124            r##"{"title":"GeneratedContent","description":"Any legal JSON","anyOf":[{"type":"object","additionalProperties":{"$ref":"#"}},{"type":"array","items":{"$ref":"#"}},{"type":"boolean"},{"type":"number"},{"type":"string"}]}"##.into(),
125        )
126    }
127
128    pub(crate) fn from_json_schema_unchecked(json_schema: String) -> Self {
129        Self { json_schema }
130    }
131}
132
133/// A dynamic FoundationModels schema description.
134#[derive(Debug, Clone, PartialEq)]
135pub enum DynamicGenerationSchema {
136    Object {
137        name: String,
138        description: Option<String>,
139        properties: BTreeMap<String, DynamicGenerationProperty>,
140    },
141    Array {
142        item: Box<DynamicGenerationSchema>,
143        minimum_elements: Option<usize>,
144        maximum_elements: Option<usize>,
145    },
146    AnyOf {
147        name: String,
148        description: Option<String>,
149        choices: Vec<DynamicGenerationSchema>,
150    },
151    String {
152        description: Option<String>,
153        guides: Vec<GenerationGuide>,
154    },
155    Integer {
156        description: Option<String>,
157        guides: Vec<GenerationGuide>,
158    },
159    Float {
160        description: Option<String>,
161        guides: Vec<GenerationGuide>,
162    },
163    Number {
164        description: Option<String>,
165        guides: Vec<GenerationGuide>,
166    },
167    Decimal {
168        description: Option<String>,
169        guides: Vec<GenerationGuide>,
170    },
171    Boolean {
172        description: Option<String>,
173    },
174    GeneratedContent {
175        description: Option<String>,
176    },
177    Reference {
178        name: String,
179    },
180}
181
182impl DynamicGenerationSchema {
183    /// Create an object schema.
184    #[must_use]
185    pub fn object(name: impl Into<String>) -> Self {
186        Self::Object {
187            name: name.into(),
188            description: None,
189            properties: BTreeMap::new(),
190        }
191    }
192
193    /// Create a string schema.
194    #[must_use]
195    pub fn string() -> Self {
196        Self::String {
197            description: None,
198            guides: Vec::new(),
199        }
200    }
201
202    /// Create an integer schema.
203    #[must_use]
204    pub fn integer() -> Self {
205        Self::Integer {
206            description: None,
207            guides: Vec::new(),
208        }
209    }
210
211    /// Create a floating-point schema.
212    #[must_use]
213    pub fn float() -> Self {
214        Self::Float {
215            description: None,
216            guides: Vec::new(),
217        }
218    }
219
220    /// Create a number schema.
221    #[must_use]
222    pub fn number() -> Self {
223        Self::Number {
224            description: None,
225            guides: Vec::new(),
226        }
227    }
228
229    /// Create a decimal schema.
230    #[must_use]
231    pub fn decimal() -> Self {
232        Self::Decimal {
233            description: None,
234            guides: Vec::new(),
235        }
236    }
237
238    /// Create a boolean schema.
239    #[must_use]
240    pub fn boolean() -> Self {
241        Self::Boolean { description: None }
242    }
243
244    /// Create an arbitrary-JSON schema.
245    #[must_use]
246    pub fn generated_content() -> Self {
247        Self::GeneratedContent { description: None }
248    }
249
250    /// Create an array schema.
251    #[must_use]
252    pub fn array_of(item: Self) -> Self {
253        Self::Array {
254            item: Box::new(item),
255            minimum_elements: None,
256            maximum_elements: None,
257        }
258    }
259
260    /// Create a reference to a named dependency.
261    #[must_use]
262    pub fn reference(name: impl Into<String>) -> Self {
263        Self::Reference { name: name.into() }
264    }
265
266    /// Create a named union of schemas.
267    #[must_use]
268    pub fn any_of(name: impl Into<String>, choices: Vec<Self>) -> Self {
269        Self::AnyOf {
270            name: name.into(),
271            description: None,
272            choices,
273        }
274    }
275
276    /// Attach a description.
277    #[must_use]
278    pub fn with_description(mut self, description: impl Into<String>) -> Self {
279        match &mut self {
280            Self::Object {
281                description: slot, ..
282            }
283            | Self::AnyOf {
284                description: slot, ..
285            }
286            | Self::String {
287                description: slot, ..
288            }
289            | Self::Integer {
290                description: slot, ..
291            }
292            | Self::Float {
293                description: slot, ..
294            }
295            | Self::Number {
296                description: slot, ..
297            }
298            | Self::Decimal {
299                description: slot, ..
300            }
301            | Self::Boolean { description: slot }
302            | Self::GeneratedContent { description: slot } => *slot = Some(description.into()),
303            Self::Array { .. } | Self::Reference { .. } => {}
304        }
305        self
306    }
307
308    /// Add a property to an object schema.
309    #[must_use]
310    pub fn with_property(
311        mut self,
312        name: impl Into<String>,
313        property: DynamicGenerationProperty,
314    ) -> Self {
315        if let Self::Object { properties, .. } = &mut self {
316            properties.insert(name.into(), property);
317        }
318        self
319    }
320
321    /// Set the array bounds.
322    #[must_use]
323    pub fn with_element_bounds(mut self, minimum: Option<usize>, maximum: Option<usize>) -> Self {
324        if let Self::Array {
325            minimum_elements,
326            maximum_elements,
327            ..
328        } = &mut self
329        {
330            *minimum_elements = minimum;
331            *maximum_elements = maximum;
332        }
333        self
334    }
335
336    /// Attach primitive guides.
337    #[must_use]
338    pub fn with_guides(mut self, guides: impl IntoIterator<Item = GenerationGuide>) -> Self {
339        let guides: Vec<_> = guides.into_iter().collect();
340        match &mut self {
341            Self::String { guides: slot, .. }
342            | Self::Integer { guides: slot, .. }
343            | Self::Float { guides: slot, .. }
344            | Self::Number { guides: slot, .. }
345            | Self::Decimal { guides: slot, .. } => *slot = guides,
346            Self::Object { .. }
347            | Self::Array { .. }
348            | Self::AnyOf { .. }
349            | Self::Boolean { .. }
350            | Self::GeneratedContent { .. }
351            | Self::Reference { .. } => {}
352        }
353        self
354    }
355
356    fn to_json_value(&self) -> Value {
357        match self {
358            Self::Object {
359                name,
360                description,
361                properties,
362            } => {
363                let property_map = properties
364                    .iter()
365                    .map(|(property_name, property)| {
366                        (property_name.clone(), property.to_json_value())
367                    })
368                    .collect::<Map<String, Value>>();
369                let mut map = Map::new();
370                map.insert("type".into(), Value::String("object".into()));
371                map.insert("name".into(), Value::String(name.clone()));
372                if let Some(description) = description {
373                    map.insert("description".into(), Value::String(description.clone()));
374                }
375                map.insert("properties".into(), Value::Object(property_map));
376                Value::Object(map)
377            }
378            Self::Array {
379                item,
380                minimum_elements,
381                maximum_elements,
382            } => {
383                let mut map = Map::new();
384                map.insert("type".into(), Value::String("array".into()));
385                map.insert("items".into(), item.to_json_value());
386                if let Some(minimum_elements) = minimum_elements {
387                    map.insert("min".into(), Value::from(*minimum_elements));
388                }
389                if let Some(maximum_elements) = maximum_elements {
390                    map.insert("max".into(), Value::from(*maximum_elements));
391                }
392                Value::Object(map)
393            }
394            Self::AnyOf {
395                name,
396                description,
397                choices,
398            } => {
399                let mut map = Map::new();
400                map.insert("type".into(), Value::String("any_of".into()));
401                map.insert("name".into(), Value::String(name.clone()));
402                if let Some(description) = description {
403                    map.insert("description".into(), Value::String(description.clone()));
404                }
405                map.insert(
406                    "choices".into(),
407                    Value::Array(choices.iter().map(Self::to_json_value).collect()),
408                );
409                Value::Object(map)
410            }
411            Self::String {
412                description,
413                guides,
414            } => primitive_schema_json("string", description, guides),
415            Self::Integer {
416                description,
417                guides,
418            } => primitive_schema_json("integer", description, guides),
419            Self::Float {
420                description,
421                guides,
422            } => primitive_schema_json("float", description, guides),
423            Self::Number {
424                description,
425                guides,
426            } => primitive_schema_json("number", description, guides),
427            Self::Decimal {
428                description,
429                guides,
430            } => primitive_schema_json("decimal", description, guides),
431            Self::Boolean { description } => primitive_schema_json("boolean", description, &[]),
432            Self::GeneratedContent { description } => {
433                primitive_schema_json("generated_content", description, &[])
434            }
435            Self::Reference { name } => json!({ "$ref": name }),
436        }
437    }
438}
439
440/// A property in a dynamic object schema.
441#[derive(Debug, Clone, PartialEq)]
442pub struct DynamicGenerationProperty {
443    pub schema: DynamicGenerationSchema,
444    pub description: Option<String>,
445    pub optional: bool,
446}
447
448impl DynamicGenerationProperty {
449    /// Create a property from a nested schema.
450    #[must_use]
451    pub fn new(schema: DynamicGenerationSchema) -> Self {
452        Self {
453            schema,
454            description: None,
455            optional: false,
456        }
457    }
458
459    /// Mark the property as optional.
460    #[must_use]
461    pub const fn optional(mut self, optional: bool) -> Self {
462        self.optional = optional;
463        self
464    }
465
466    /// Attach a property description.
467    #[must_use]
468    pub fn with_description(mut self, description: impl Into<String>) -> Self {
469        self.description = Some(description.into());
470        self
471    }
472
473    fn to_json_value(&self) -> Value {
474        let mut value = self.schema.to_json_value();
475        if let Value::Object(map) = &mut value {
476            if let Some(description) = &self.description {
477                map.insert("description".into(), Value::String(description.clone()));
478            }
479            if self.optional {
480                map.insert("optional".into(), Value::Bool(true));
481            }
482        }
483        value
484    }
485}
486
/// A constraint mirroring one of Apple's public `GenerationGuide` builders.
///
/// Numeric guides exist in `i64`, `f32`, `f64`, and string-encoded decimal flavours.
#[derive(Clone, Debug, PartialEq)]
pub enum GenerationGuide {
    /// The string must equal this constant.
    StringConstant(String),
    /// The string must be one of these values.
    StringAnyOf(Vec<String>),
    /// The string must match this pattern.
    StringPattern(String),
    /// Lower bound for an `i64` value.
    MinimumI64(i64),
    /// Upper bound for an `i64` value.
    MaximumI64(i64),
    /// Inclusive-style `(minimum, maximum)` pair for an `i64` value.
    RangeI64(i64, i64),
    /// Lower bound for an `f32` value.
    MinimumF32(f32),
    /// Upper bound for an `f32` value.
    MaximumF32(f32),
    /// `(minimum, maximum)` pair for an `f32` value.
    RangeF32(f32, f32),
    /// Lower bound for an `f64` value.
    MinimumF64(f64),
    /// Upper bound for an `f64` value.
    MaximumF64(f64),
    /// `(minimum, maximum)` pair for an `f64` value.
    RangeF64(f64, f64),
    /// Lower bound for a decimal, given as its string form.
    MinimumDecimal(String),
    /// Upper bound for a decimal, given as its string form.
    MaximumDecimal(String),
    /// `(minimum, maximum)` pair for a decimal, both given as strings.
    RangeDecimal(String, String),
}
506
507impl GenerationGuide {
508    #[must_use]
509    pub fn string_constant(value: impl Into<String>) -> Self {
510        Self::StringConstant(value.into())
511    }
512
513    #[must_use]
514    pub fn string_any_of(values: impl IntoIterator<Item = impl Into<String>>) -> Self {
515        Self::StringAnyOf(values.into_iter().map(Into::into).collect())
516    }
517
518    #[must_use]
519    pub fn string_pattern(pattern: impl Into<String>) -> Self {
520        Self::StringPattern(pattern.into())
521    }
522
523    #[must_use]
524    pub const fn minimum_i64(value: i64) -> Self {
525        Self::MinimumI64(value)
526    }
527
528    #[must_use]
529    pub const fn maximum_i64(value: i64) -> Self {
530        Self::MaximumI64(value)
531    }
532
533    #[must_use]
534    pub const fn range_i64(minimum: i64, maximum: i64) -> Self {
535        Self::RangeI64(minimum, maximum)
536    }
537
538    #[must_use]
539    pub const fn minimum_f32(value: f32) -> Self {
540        Self::MinimumF32(value)
541    }
542
543    #[must_use]
544    pub const fn maximum_f32(value: f32) -> Self {
545        Self::MaximumF32(value)
546    }
547
548    #[must_use]
549    pub const fn range_f32(minimum: f32, maximum: f32) -> Self {
550        Self::RangeF32(minimum, maximum)
551    }
552
553    #[must_use]
554    pub const fn minimum_f64(value: f64) -> Self {
555        Self::MinimumF64(value)
556    }
557
558    #[must_use]
559    pub const fn maximum_f64(value: f64) -> Self {
560        Self::MaximumF64(value)
561    }
562
563    #[must_use]
564    pub const fn range_f64(minimum: f64, maximum: f64) -> Self {
565        Self::RangeF64(minimum, maximum)
566    }
567
568    #[must_use]
569    pub fn minimum_decimal(value: impl Into<String>) -> Self {
570        Self::MinimumDecimal(value.into())
571    }
572
573    #[must_use]
574    pub fn maximum_decimal(value: impl Into<String>) -> Self {
575        Self::MaximumDecimal(value.into())
576    }
577
578    #[must_use]
579    pub fn range_decimal(minimum: impl Into<String>, maximum: impl Into<String>) -> Self {
580        Self::RangeDecimal(minimum.into(), maximum.into())
581    }
582
583    fn to_json_value(&self) -> Value {
584        match self {
585            Self::StringConstant(value) => json!({ "kind": "constant", "value": value }),
586            Self::StringAnyOf(values) => json!({ "kind": "any_of", "values": values }),
587            Self::StringPattern(pattern) => json!({ "kind": "pattern", "pattern": pattern }),
588            Self::MinimumI64(value) => json!({ "kind": "minimum", "value": value }),
589            Self::MaximumI64(value) => json!({ "kind": "maximum", "value": value }),
590            Self::RangeI64(minimum, maximum) => {
591                json!({ "kind": "range", "min": minimum, "max": maximum })
592            }
593            Self::MinimumF32(value) => json!({ "kind": "minimum", "value": value }),
594            Self::MaximumF32(value) => json!({ "kind": "maximum", "value": value }),
595            Self::RangeF32(minimum, maximum) => {
596                json!({ "kind": "range", "min": minimum, "max": maximum })
597            }
598            Self::MinimumF64(value) => json!({ "kind": "minimum", "value": value }),
599            Self::MaximumF64(value) => json!({ "kind": "maximum", "value": value }),
600            Self::RangeF64(minimum, maximum) => {
601                json!({ "kind": "range", "min": minimum, "max": maximum })
602            }
603            Self::MinimumDecimal(value) => json!({ "kind": "minimum", "value": value }),
604            Self::MaximumDecimal(value) => json!({ "kind": "maximum", "value": value }),
605            Self::RangeDecimal(minimum, maximum) => {
606                json!({ "kind": "range", "min": minimum, "max": maximum })
607            }
608        }
609    }
610}
611
612fn primitive_schema_json(
613    kind: &str,
614    description: &Option<String>,
615    guides: &[GenerationGuide],
616) -> Value {
617    let mut map = Map::new();
618    map.insert("type".into(), Value::String(kind.into()));
619    if let Some(description) = description {
620        map.insert("description".into(), Value::String(description.clone()));
621    }
622    if !guides.is_empty() {
623        map.insert(
624            "guides".into(),
625            Value::Array(guides.iter().map(GenerationGuide::to_json_value).collect()),
626        );
627    }
628    Value::Object(map)
629}
630
631/// Rust analogue of FoundationModels' `Generable` protocol.
632pub trait Generable: Sized + FromGeneratedContent + ToGeneratedContent {
633    /// Return the generation schema that describes `Self`.
634    fn generation_schema() -> Result<GenerationSchema, FMError>;
635}
636
637impl Generable for GeneratedContent {
638    fn generation_schema() -> Result<GenerationSchema, FMError> {
639        Ok(GenerationSchema::generated_content())
640    }
641}
642
643impl Generable for String {
644    fn generation_schema() -> Result<GenerationSchema, FMError> {
645        Ok(GenerationSchema::string())
646    }
647}
648
649impl Generable for bool {
650    fn generation_schema() -> Result<GenerationSchema, FMError> {
651        Ok(GenerationSchema::boolean())
652    }
653}
654
// Implements `Generable` for each listed type using the JSON integer schema.
macro_rules! impl_integer_generable {
    ($($target:ty),+ $(,)?) => {
        $(
            impl Generable for $target {
                fn generation_schema() -> Result<GenerationSchema, FMError> {
                    let schema = GenerationSchema::integer();
                    Ok(schema)
                }
            }
        )+
    };
}
666
// Implements `Generable` for each listed type using the JSON number schema.
macro_rules! impl_number_generable {
    ($($target:ty),+ $(,)?) => {
        $(
            impl Generable for $target {
                fn generation_schema() -> Result<GenerationSchema, FMError> {
                    let schema = GenerationSchema::number();
                    Ok(schema)
                }
            }
        )+
    };
}
678
679impl_integer_generable!(i8, i16, i32, i64, u8, u16, u32, u64);
680impl_number_generable!(f32, f64);
681
682impl<T> Generable for Vec<T>
683where
684    T: Generable,
685{
686    fn generation_schema() -> Result<GenerationSchema, FMError> {
687        let item_schema: Value = serde_json::from_str(T::generation_schema()?.json_schema())
688            .map_err(|error| {
689                FMError::InvalidArgument(format!("element schema is not valid JSON: {error}"))
690            })?;
691        Ok(GenerationSchema::from_json_schema_unchecked(
692            json!({ "type": "array", "items": item_schema }).to_string(),
693        ))
694    }
695}
696
697impl<T> Generable for Option<T>
698where
699    T: Generable,
700{
701    fn generation_schema() -> Result<GenerationSchema, FMError> {
702        T::generation_schema()
703    }
704}
705
706unsafe extern "C" fn schema_callback_trampoline(
707    context: *mut c_void,
708    response: *mut c_char,
709    error: *mut c_char,
710    status: i32,
711) {
712    let tx = Box::from_raw(context.cast::<mpsc::Sender<Result<String, FMError>>>());
713    let result = if status == ffi::status::OK && !response.is_null() {
714        let value = core::ffi::CStr::from_ptr(response)
715            .to_string_lossy()
716            .into_owned();
717        ffi::fm_string_free(response);
718        Ok(value)
719    } else {
720        Err(crate::error::from_swift(status, error))
721    };
722    let _ = tx.send(result);
723}