bamboo_engine/skills/
access_control.rs1use std::collections::{BTreeSet, HashMap, HashSet};
4
5use crate::skills::runtime_metadata::{
6 LAST_LOADED_SKILL_ID_METADATA_KEY, LAST_LOADED_SKILL_SUMMARY_METADATA_KEY,
7 LOADED_SKILL_IDS_METADATA_KEY, SELECTED_SKILL_IDS_METADATA_KEY,
8 SELECTED_SKILL_MODE_METADATA_KEY,
9};
10use crate::skills::selection::parse_selected_skill_ids_metadata;
11pub use crate::skills::session_port::SkillSessionPort;
12
13#[derive(Debug, thiserror::Error)]
15pub enum SkillAccessError {
16 #[error("{0}")]
17 NotAllowed(String),
18 #[error("{0}")]
19 NotLoaded(String),
20 #[error("{0}")]
21 SessionRequired(String),
22 #[error("{0}")]
23 SessionNotFound(String),
24 #[error("{0}")]
25 PersistenceError(String),
26}
27
28pub fn parse_loaded_skill_ids(raw: &str) -> HashSet<String> {
33 let trimmed = raw.trim();
34 if trimmed.is_empty() {
35 return HashSet::new();
36 }
37
38 if let Ok(ids) = serde_json::from_str::<Vec<String>>(trimmed) {
39 return ids
40 .into_iter()
41 .map(|id| id.trim().to_string())
42 .filter(|id| !id.is_empty())
43 .collect();
44 }
45
46 trimmed
47 .split(',')
48 .map(|id| id.trim().to_string())
49 .filter(|id| !id.is_empty())
50 .collect()
51}
52
53pub fn serialize_loaded_skill_ids(ids: &HashSet<String>) -> String {
54 let sorted: BTreeSet<String> = ids
55 .iter()
56 .map(|id| id.trim().to_string())
57 .filter(|id| !id.is_empty())
58 .collect();
59 serde_json::to_string(&sorted.into_iter().collect::<Vec<String>>()).unwrap_or("[]".to_string())
60}
61
62pub fn extract_skill_allowlist(metadata: &HashMap<String, String>) -> Option<HashSet<String>> {
63 metadata
64 .get(SELECTED_SKILL_IDS_METADATA_KEY)
65 .and_then(|raw| parse_selected_skill_ids_metadata(raw))
66 .map(|ids| ids.into_iter().collect())
67}
68
69pub fn extract_skill_mode(metadata: &HashMap<String, String>) -> Option<String> {
70 let mode = metadata
71 .get(SELECTED_SKILL_MODE_METADATA_KEY)
72 .or_else(|| metadata.get("mode"))?;
73 let trimmed = mode.trim();
74 if trimmed.is_empty() {
75 None
76 } else {
77 Some(trimmed.to_string())
78 }
79}
80
81fn extract_loaded_ids_from_metadata(metadata: &HashMap<String, String>) -> HashSet<String> {
82 metadata
83 .get(LOADED_SKILL_IDS_METADATA_KEY)
84 .map(|raw| parse_loaded_skill_ids(raw))
85 .unwrap_or_default()
86}
87
88pub async fn ensure_skill_allowed(
93 port: &dyn SkillSessionPort,
94 skill_id: &str,
95 session_id: Option<&str>,
96) -> Result<(), SkillAccessError> {
97 let disabled = port.disabled_skill_ids().await;
98 if disabled.contains(skill_id) {
99 return Err(SkillAccessError::NotAllowed(format!(
100 "Skill '{skill_id}' is globally disabled in Bamboo settings"
101 )));
102 }
103
104 let Some(session_id) = session_id else {
105 return Ok(());
106 };
107
108 let Some(metadata) = port.load_session_metadata(session_id).await else {
109 return Ok(());
110 };
111
112 let Some(allowlist) = extract_skill_allowlist(&metadata) else {
113 return Ok(());
114 };
115
116 if allowlist.contains(skill_id) {
117 return Ok(());
118 }
119
120 Err(SkillAccessError::NotAllowed(format!(
121 "Skill '{skill_id}' is not selected for this request"
122 )))
123}
124
125pub async fn ensure_skill_loaded(
126 port: &dyn SkillSessionPort,
127 skill_id: &str,
128 session_id: Option<&str>,
129) -> Result<(), SkillAccessError> {
130 let Some(session_id) = session_id else {
131 return Err(SkillAccessError::SessionRequired(
132 "read_skill_resource requires a session_id in tool context".to_string(),
133 ));
134 };
135
136 let Some(metadata) = port.load_session_metadata(session_id).await else {
137 return Err(SkillAccessError::SessionNotFound(format!(
138 "Session '{session_id}' was not found while verifying loaded skill state"
139 )));
140 };
141
142 let loaded_ids = extract_loaded_ids_from_metadata(&metadata);
143
144 if loaded_ids.contains(skill_id) {
145 return Ok(());
146 }
147
148 Err(SkillAccessError::NotLoaded(format!(
149 "Skill '{skill_id}' has not been loaded in this session. Call load_skill first."
150 )))
151}
152
153pub async fn mark_skill_loaded(
154 port: &dyn SkillSessionPort,
155 skill_id: &str,
156 session_id: Option<&str>,
157) -> Result<(), SkillAccessError> {
158 let Some(session_id) = session_id else {
159 return Err(SkillAccessError::SessionRequired(
160 "load_skill requires a session_id in tool context".to_string(),
161 ));
162 };
163
164 let metadata = port
165 .load_session_metadata(session_id)
166 .await
167 .ok_or_else(|| {
168 SkillAccessError::SessionNotFound(format!(
169 "Session '{session_id}' not found while persisting loaded skill state"
170 ))
171 })?;
172
173 let mut loaded_ids = extract_loaded_ids_from_metadata(&metadata);
174 loaded_ids.insert(skill_id.to_string());
175
176 let serialized_ids = serialize_loaded_skill_ids(&loaded_ids);
177 let summary = serde_json::json!({
178 "skill_id": skill_id,
179 "loaded_ids": loaded_ids.iter().cloned().collect::<BTreeSet<_>>(),
180 "selected_skill_mode": metadata.get(SELECTED_SKILL_MODE_METADATA_KEY).cloned(),
181 "loaded_count": loaded_ids.len()
182 })
183 .to_string();
184
185 let updates = vec![
186 (
187 LOADED_SKILL_IDS_METADATA_KEY.to_string(),
188 Some(serialized_ids),
189 ),
190 (
191 LAST_LOADED_SKILL_ID_METADATA_KEY.to_string(),
192 Some(skill_id.to_string()),
193 ),
194 (
195 LAST_LOADED_SKILL_SUMMARY_METADATA_KEY.to_string(),
196 Some(summary),
197 ),
198 ];
199
200 port.save_metadata_updates(session_id, &updates)
201 .await
202 .map_err(SkillAccessError::PersistenceError)?;
203
204 Ok(())
205}
206
207pub async fn selected_skill_allowlist(
208 port: &dyn SkillSessionPort,
209 session_id: Option<&str>,
210) -> Option<HashSet<String>> {
211 let session_id = session_id?;
212 let metadata = port.load_session_metadata(session_id).await?;
213 extract_skill_allowlist(&metadata)
214}
215
216pub async fn selected_skill_mode(
217 port: &dyn SkillSessionPort,
218 session_id: Option<&str>,
219) -> Option<String> {
220 let session_id = session_id?;
221 let metadata = port.load_session_metadata(session_id).await?;
222 extract_skill_mode(&metadata)
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned id set from string literals.
    fn id_set(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|id| id.to_string()).collect()
    }

    #[test]
    fn parse_loaded_skill_ids_supports_json_and_csv() {
        // Duplicates collapse in both encodings; CSV entries are trimmed.
        assert_eq!(
            parse_loaded_skill_ids(r#"["skill-b","skill-a","skill-a"]"#),
            id_set(&["skill-a", "skill-b"])
        );
        assert_eq!(
            parse_loaded_skill_ids("skill-c, skill-d , skill-c"),
            id_set(&["skill-c", "skill-d"])
        );
    }

    #[test]
    fn serialize_loaded_skill_ids_is_stable_and_sorted() {
        let ids = id_set(&["skill-b", "skill-a"]);

        assert_eq!(serialize_loaded_skill_ids(&ids), r#"["skill-a","skill-b"]"#);
    }

    #[test]
    fn extract_skill_allowlist_parses_metadata_json() {
        let metadata = HashMap::from([(
            "selected_skill_ids".to_string(),
            r#"["pdf","skill-creator"]"#.to_string(),
        )]);

        let allowlist = extract_skill_allowlist(&metadata).unwrap();
        for expected in ["pdf", "skill-creator"] {
            assert!(allowlist.contains(expected));
        }
    }

    #[test]
    fn extract_skill_mode_prefers_skill_mode_key() {
        let metadata = HashMap::from([
            ("mode".to_string(), "ask".to_string()),
            ("skill_mode".to_string(), "code".to_string()),
        ]);

        assert_eq!(extract_skill_mode(&metadata).as_deref(), Some("code"));
    }
}