systemprompt_extension/
builder.rs1use std::marker::PhantomData;
8
9use crate::any::{AnyExtension, ApiExtensionWrapper, ExtensionWrapper, SchemaExtensionWrapper};
10use crate::error::LoaderError;
11use crate::hlist::{Subset, TypeList};
12use crate::typed::{ApiExtensionTypedDyn, SchemaExtensionTyped};
13use crate::typed_registry::TypedExtensionRegistry;
14use crate::types::{Dependencies, ExtensionType};
15
16pub struct ExtensionBuilder<Registered: TypeList = ()> {
17 extensions: Vec<Box<dyn AnyExtension>>,
18 _marker: PhantomData<Registered>,
19}
20
21impl<R: TypeList> std::fmt::Debug for ExtensionBuilder<R> {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("ExtensionBuilder")
24 .field("extension_count", &self.extensions.len())
25 .finish_non_exhaustive()
26 }
27}
28
29impl ExtensionBuilder<()> {
30 #[must_use]
31 pub fn new() -> Self {
32 Self {
33 extensions: Vec::new(),
34 _marker: PhantomData,
35 }
36 }
37}
38
39impl Default for ExtensionBuilder<()> {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl<R: TypeList> ExtensionBuilder<R> {
46 pub fn extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
47 where
48 E: ExtensionType + Dependencies + std::fmt::Debug + 'static,
49 E::Deps: Subset<R>,
50 {
51 self.extensions.push(Box::new(ExtensionWrapper::new(ext)));
52 ExtensionBuilder {
53 extensions: self.extensions,
54 _marker: PhantomData,
55 }
56 }
57
58 pub fn schema_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
59 where
60 E: ExtensionType + Dependencies + SchemaExtensionTyped + std::fmt::Debug + 'static,
61 E::Deps: Subset<R>,
62 {
63 self.extensions
64 .push(Box::new(SchemaExtensionWrapper::new(ext)));
65 ExtensionBuilder {
66 extensions: self.extensions,
67 _marker: PhantomData,
68 }
69 }
70
71 pub fn api_extension<E>(mut self, ext: E) -> ExtensionBuilder<(E, R)>
72 where
73 E: ExtensionType + Dependencies + ApiExtensionTypedDyn + std::fmt::Debug + 'static,
74 E::Deps: Subset<R>,
75 {
76 self.extensions
77 .push(Box::new(ApiExtensionWrapper::new(ext)));
78 ExtensionBuilder {
79 extensions: self.extensions,
80 _marker: PhantomData,
81 }
82 }
83
84 pub fn build(self) -> Result<TypedExtensionRegistry, LoaderError> {
85 let mut registry = TypedExtensionRegistry::new();
86 let mut sorted = self.extensions;
87 sorted.sort_by_key(|e| e.priority());
88
89 for ext in sorted {
90 if registry.has(ext.id()) {
91 return Err(LoaderError::DuplicateExtension(ext.id().to_string()));
92 }
93
94 if let Some(api) = ext.as_api() {
95 registry.validate_api_path(ext.id(), api.base_path())?;
96 }
97
98 registry.add_boxed(ext);
99 }
100
101 Ok(registry)
102 }
103}