systemprompt_extension/registry/
mod.rs1mod discovery;
9mod queries;
10mod validation;
11
12use crate::Extension;
13use crate::error::LoaderError;
14use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16use tracing::warn;
17
18pub use validation::RESERVED_PATHS;
19
20#[derive(Default)]
21pub struct ExtensionRegistry {
22 pub(crate) extensions: HashMap<String, Arc<dyn Extension>>,
23 pub(crate) sorted_extensions: Vec<Arc<dyn Extension>>,
24}
25
26impl std::fmt::Debug for ExtensionRegistry {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("ExtensionRegistry")
29 .field("extension_count", &self.extensions.len())
30 .finish_non_exhaustive()
31 }
32}
33
34impl ExtensionRegistry {
35 #[must_use]
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub(crate) fn sort_by_priority(&mut self) -> Result<(), LoaderError> {
48 let ids: Vec<String> = self
49 .sorted_extensions
50 .iter()
51 .map(|e| e.id().to_string())
52 .collect();
53 let id_set: HashSet<&str> = ids.iter().map(String::as_str).collect();
54
55 let mut by_id: HashMap<String, Arc<dyn Extension>> = HashMap::new();
56 for ext in self.sorted_extensions.drain(..) {
57 by_id.insert(ext.id().to_string(), ext);
58 }
59
60 for (owner, ext) in &by_id {
61 for dep in ext.dependencies() {
62 if !id_set.contains(dep) {
63 warn!(
64 extension = %owner,
65 missing_dependency = %dep,
66 "Extension declares dependency that is not loaded; treating as optional \
67 and ignoring for ordering"
68 );
69 }
70 }
71 }
72
73 let order = topo_sort(&ids, &by_id)?;
74
75 self.sorted_extensions = order
76 .into_iter()
77 .filter_map(|id| by_id.remove(&id))
78 .collect();
79 Ok(())
80 }
81
82 pub fn register(&mut self, ext: Arc<dyn Extension>) -> Result<(), LoaderError> {
83 let id = ext.id().to_string();
84 if self.extensions.contains_key(&id) {
85 return Err(LoaderError::DuplicateExtension(id));
86 }
87 self.extensions.insert(id, Arc::clone(&ext));
88 self.sorted_extensions.push(ext);
89 self.sort_by_priority()?;
90 Ok(())
91 }
92
93 pub fn merge(&mut self, extensions: Vec<Arc<dyn Extension>>) -> Result<(), LoaderError> {
94 for ext in extensions {
95 self.register(ext)?;
96 }
97 Ok(())
98 }
99
100 pub fn validate(&self) -> Result<(), LoaderError> {
101 self.validate_dependencies()?;
102 Ok(())
103 }
104
105 #[must_use]
106 pub fn len(&self) -> usize {
107 self.extensions.len()
108 }
109
110 #[must_use]
111 pub fn is_empty(&self) -> bool {
112 self.extensions.is_empty()
113 }
114}
115
116fn topo_sort(
117 ids: &[String],
118 by_id: &HashMap<String, Arc<dyn Extension>>,
119) -> Result<Vec<String>, LoaderError> {
120 const WHITE: u8 = 0;
121 const GRAY: u8 = 1;
122 const BLACK: u8 = 2;
123
124 fn visit(
125 node: &str,
126 by_id: &HashMap<String, Arc<dyn Extension>>,
127 color: &mut HashMap<String, u8>,
128 path: &mut Vec<String>,
129 out: &mut Vec<String>,
130 ) -> Result<(), LoaderError> {
131 let state = color.get(node).copied().unwrap_or(WHITE);
132 if state == BLACK {
133 return Ok(());
134 }
135 if state == GRAY {
136 let cycle_start = path.iter().position(|p| p == node).unwrap_or(0);
137 let mut chain: Vec<String> = path[cycle_start..].to_vec();
138 chain.push(node.to_string());
139 return Err(LoaderError::DependencyCycle {
140 chain: chain.join(" -> "),
141 });
142 }
143 color.insert(node.to_string(), GRAY);
144 path.push(node.to_string());
145
146 if let Some(ext) = by_id.get(node) {
147 let mut deps: Vec<&'static str> = ext
148 .dependencies()
149 .into_iter()
150 .filter(|d| by_id.contains_key(*d))
151 .collect();
152 deps.sort_by_key(|d| {
153 by_id.get(*d).map_or((u32::MAX, String::new()), |e| {
154 (e.priority(), e.id().to_string())
155 })
156 });
157 for dep in deps {
158 visit(dep, by_id, color, path, out)?;
159 }
160 }
161
162 path.pop();
163 color.insert(node.to_string(), BLACK);
164 out.push(node.to_string());
165 Ok(())
166 }
167
168 let mut roots: Vec<&String> = ids.iter().collect();
169 roots.sort_by_key(|id| {
170 by_id.get(*id).map_or((u32::MAX, String::new()), |e| {
171 (e.priority(), e.id().to_string())
172 })
173 });
174
175 let mut color: HashMap<String, u8> = HashMap::with_capacity(ids.len());
176 let mut path: Vec<String> = Vec::new();
177 let mut out: Vec<String> = Vec::with_capacity(ids.len());
178 for id in roots {
179 visit(id, by_id, &mut color, &mut path, &mut out)?;
180 }
181 Ok(out)
182}
183
184#[derive(Debug, Clone, Copy)]
185pub struct ExtensionRegistration {
186 pub factory: fn() -> Arc<dyn Extension>,
187}
188
189inventory::collect!(ExtensionRegistration);