Skip to main content

systemprompt_extension/
builder.rs

//! Fluent typed builder that enforces dependency ordering at compile time.
//!
//! Each `extension`/`schema_extension`/`api_extension` call advances the
//! `Registered` typestate, and the `where E::Deps: Subset<R>` bound forces
//! callers to register dependencies before their dependents.
6
7use std::marker::PhantomData;
8
9use crate::any::{AnyExtension, ApiExtensionWrapper, ExtensionWrapper, SchemaExtensionWrapper};
10use crate::error::LoaderError;
11use crate::hlist::{Subset, TypeList};
12use crate::typed::{ApiExtensionTypedDyn, SchemaExtensionTyped};
13use crate::typed_registry::TypedExtensionRegistry;
14use crate::types::{Dependencies, ExtensionType};
15
16pub struct ExtensionBuilder<Registered: TypeList = ()> {
17    extensions: Vec<Box<dyn AnyExtension>>,
18    _marker: PhantomData<Registered>,
19}
20
21impl<R: TypeList> std::fmt::Debug for ExtensionBuilder<R> {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("ExtensionBuilder")
24            .field("extension_count", &self.extensions.len())
25            .finish_non_exhaustive()
26    }
27}
28
29impl ExtensionBuilder<()> {
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            extensions: Vec::new(),
34            _marker: PhantomData,
35        }
36    }
37}
38
39impl Default for ExtensionBuilder<()> {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl<R: TypeList> ExtensionBuilder<R> {
46    pub fn extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
47    where
48        E: ExtensionType + Dependencies + std::fmt::Debug + 'static,
49        E::Deps: Subset<R>,
50    {
51        self.extensions.push(Box::new(ExtensionWrapper::new(ext)));
52        ExtensionBuilder {
53            extensions: self.extensions,
54            _marker: PhantomData,
55        }
56    }
57
58    pub fn schema_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
59    where
60        E: ExtensionType + Dependencies + SchemaExtensionTyped + std::fmt::Debug + 'static,
61        E::Deps: Subset<R>,
62    {
63        self.extensions
64            .push(Box::new(SchemaExtensionWrapper::new(ext)));
65        ExtensionBuilder {
66            extensions: self.extensions,
67            _marker: PhantomData,
68        }
69    }
70
71    pub fn api_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
72    where
73        E: ExtensionType + Dependencies + ApiExtensionTypedDyn + std::fmt::Debug + 'static,
74        E::Deps: Subset<R>,
75    {
76        self.extensions
77            .push(Box::new(ApiExtensionWrapper::new(ext)));
78        ExtensionBuilder {
79            extensions: self.extensions,
80            _marker: PhantomData,
81        }
82    }
83
84    pub fn build(self) -> Result<TypedExtensionRegistry, LoaderError> {
85        let mut registry = TypedExtensionRegistry::new();
86        let mut sorted = self.extensions;
87        sorted.sort_by_key(|e| e.priority());
88
89        for ext in sorted {
90            if registry.has(ext.id()) {
91                return Err(LoaderError::DuplicateExtension(ext.id().to_string()));
92            }
93
94            if let Some(api) = ext.as_api() {
95                registry.validate_api_path(ext.id(), api.base_path())?;
96            }
97
98            registry.add_boxed(ext);
99        }
100
101        Ok(registry)
102    }
103}