// bamboo_engine/skills/access_control.rs

//! Skill access control — authorization checks for skill loading and resource reads.
2
3use std::collections::{BTreeSet, HashMap, HashSet};
4
5use crate::skills::runtime_metadata::{
6    LAST_LOADED_SKILL_ID_METADATA_KEY, LAST_LOADED_SKILL_SUMMARY_METADATA_KEY,
7    LOADED_SKILL_IDS_METADATA_KEY, SELECTED_SKILL_IDS_METADATA_KEY,
8    SELECTED_SKILL_MODE_METADATA_KEY,
9};
10use crate::skills::selection::parse_selected_skill_ids_metadata;
11pub use crate::skills::session_port::SkillSessionPort;
12
13/// Error type for skill access control operations.
14#[derive(Debug, thiserror::Error)]
15pub enum SkillAccessError {
16    #[error("{0}")]
17    NotAllowed(String),
18    #[error("{0}")]
19    NotLoaded(String),
20    #[error("{0}")]
21    SessionRequired(String),
22    #[error("{0}")]
23    SessionNotFound(String),
24    #[error("{0}")]
25    PersistenceError(String),
26}
27
// ---------------------------------------------------------------------------
// Pure helpers
// ---------------------------------------------------------------------------
31
32pub fn parse_loaded_skill_ids(raw: &str) -> HashSet<String> {
33    let trimmed = raw.trim();
34    if trimmed.is_empty() {
35        return HashSet::new();
36    }
37
38    if let Ok(ids) = serde_json::from_str::<Vec<String>>(trimmed) {
39        return ids
40            .into_iter()
41            .map(|id| id.trim().to_string())
42            .filter(|id| !id.is_empty())
43            .collect();
44    }
45
46    trimmed
47        .split(',')
48        .map(|id| id.trim().to_string())
49        .filter(|id| !id.is_empty())
50        .collect()
51}
52
53pub fn serialize_loaded_skill_ids(ids: &HashSet<String>) -> String {
54    let sorted: BTreeSet<String> = ids
55        .iter()
56        .map(|id| id.trim().to_string())
57        .filter(|id| !id.is_empty())
58        .collect();
59    serde_json::to_string(&sorted.into_iter().collect::<Vec<String>>()).unwrap_or("[]".to_string())
60}
61
62pub fn extract_skill_allowlist(metadata: &HashMap<String, String>) -> Option<HashSet<String>> {
63    metadata
64        .get(SELECTED_SKILL_IDS_METADATA_KEY)
65        .and_then(|raw| parse_selected_skill_ids_metadata(raw))
66        .map(|ids| ids.into_iter().collect())
67}
68
69pub fn extract_skill_mode(metadata: &HashMap<String, String>) -> Option<String> {
70    let mode = metadata
71        .get(SELECTED_SKILL_MODE_METADATA_KEY)
72        .or_else(|| metadata.get("mode"))?;
73    let trimmed = mode.trim();
74    if trimmed.is_empty() {
75        None
76    } else {
77        Some(trimmed.to_string())
78    }
79}
80
81fn extract_loaded_ids_from_metadata(metadata: &HashMap<String, String>) -> HashSet<String> {
82    metadata
83        .get(LOADED_SKILL_IDS_METADATA_KEY)
84        .map(|raw| parse_loaded_skill_ids(raw))
85        .unwrap_or_default()
86}
87
// ---------------------------------------------------------------------------
// Access control functions
// ---------------------------------------------------------------------------
91
92pub async fn ensure_skill_allowed(
93    port: &dyn SkillSessionPort,
94    skill_id: &str,
95    session_id: Option<&str>,
96) -> Result<(), SkillAccessError> {
97    let disabled = port.disabled_skill_ids().await;
98    if disabled.contains(skill_id) {
99        return Err(SkillAccessError::NotAllowed(format!(
100            "Skill '{skill_id}' is globally disabled in Bamboo settings"
101        )));
102    }
103
104    let Some(session_id) = session_id else {
105        return Ok(());
106    };
107
108    let Some(metadata) = port.load_session_metadata(session_id).await else {
109        return Ok(());
110    };
111
112    let Some(allowlist) = extract_skill_allowlist(&metadata) else {
113        return Ok(());
114    };
115
116    if allowlist.contains(skill_id) {
117        return Ok(());
118    }
119
120    Err(SkillAccessError::NotAllowed(format!(
121        "Skill '{skill_id}' is not selected for this request"
122    )))
123}
124
125pub async fn ensure_skill_loaded(
126    port: &dyn SkillSessionPort,
127    skill_id: &str,
128    session_id: Option<&str>,
129) -> Result<(), SkillAccessError> {
130    let Some(session_id) = session_id else {
131        return Err(SkillAccessError::SessionRequired(
132            "read_skill_resource requires a session_id in tool context".to_string(),
133        ));
134    };
135
136    let Some(metadata) = port.load_session_metadata(session_id).await else {
137        return Err(SkillAccessError::SessionNotFound(format!(
138            "Session '{session_id}' was not found while verifying loaded skill state"
139        )));
140    };
141
142    let loaded_ids = extract_loaded_ids_from_metadata(&metadata);
143
144    if loaded_ids.contains(skill_id) {
145        return Ok(());
146    }
147
148    Err(SkillAccessError::NotLoaded(format!(
149        "Skill '{skill_id}' has not been loaded in this session. Call load_skill first."
150    )))
151}
152
153pub async fn mark_skill_loaded(
154    port: &dyn SkillSessionPort,
155    skill_id: &str,
156    session_id: Option<&str>,
157) -> Result<(), SkillAccessError> {
158    let Some(session_id) = session_id else {
159        return Err(SkillAccessError::SessionRequired(
160            "load_skill requires a session_id in tool context".to_string(),
161        ));
162    };
163
164    let metadata = port
165        .load_session_metadata(session_id)
166        .await
167        .ok_or_else(|| {
168            SkillAccessError::SessionNotFound(format!(
169                "Session '{session_id}' not found while persisting loaded skill state"
170            ))
171        })?;
172
173    let mut loaded_ids = extract_loaded_ids_from_metadata(&metadata);
174    loaded_ids.insert(skill_id.to_string());
175
176    let serialized_ids = serialize_loaded_skill_ids(&loaded_ids);
177    let summary = serde_json::json!({
178        "skill_id": skill_id,
179        "loaded_ids": loaded_ids.iter().cloned().collect::<BTreeSet<_>>(),
180        "selected_skill_mode": metadata.get(SELECTED_SKILL_MODE_METADATA_KEY).cloned(),
181        "loaded_count": loaded_ids.len()
182    })
183    .to_string();
184
185    let updates = vec![
186        (
187            LOADED_SKILL_IDS_METADATA_KEY.to_string(),
188            Some(serialized_ids),
189        ),
190        (
191            LAST_LOADED_SKILL_ID_METADATA_KEY.to_string(),
192            Some(skill_id.to_string()),
193        ),
194        (
195            LAST_LOADED_SKILL_SUMMARY_METADATA_KEY.to_string(),
196            Some(summary),
197        ),
198    ];
199
200    port.save_metadata_updates(session_id, &updates)
201        .await
202        .map_err(SkillAccessError::PersistenceError)?;
203
204    Ok(())
205}
206
207pub async fn selected_skill_allowlist(
208    port: &dyn SkillSessionPort,
209    session_id: Option<&str>,
210) -> Option<HashSet<String>> {
211    let session_id = session_id?;
212    let metadata = port.load_session_metadata(session_id).await?;
213    extract_skill_allowlist(&metadata)
214}
215
216pub async fn selected_skill_mode(
217    port: &dyn SkillSessionPort,
218    session_id: Option<&str>,
219) -> Option<String> {
220    let session_id = session_id?;
221    let metadata = port.load_session_metadata(session_id).await?;
222    extract_skill_mode(&metadata)
223}
224
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
228
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers. The async, port-based checks are not
    // exercised here because they need a `SkillSessionPort` implementation.
    use super::*;

    #[test]
    fn parse_loaded_skill_ids_supports_json_and_csv() {
        let from_json = parse_loaded_skill_ids(r#"["skill-b","skill-a","skill-a"]"#);
        assert_eq!(from_json.len(), 2);
        assert!(from_json.contains("skill-a"));
        assert!(from_json.contains("skill-b"));

        let from_csv = parse_loaded_skill_ids("skill-c, skill-d , skill-c");
        assert_eq!(from_csv.len(), 2);
        assert!(from_csv.contains("skill-c"));
        assert!(from_csv.contains("skill-d"));
    }

    #[test]
    fn parse_loaded_skill_ids_handles_blank_and_padded_input() {
        assert!(parse_loaded_skill_ids("").is_empty());
        assert!(parse_loaded_skill_ids("   ").is_empty());
        assert!(parse_loaded_skill_ids(" , ,").is_empty());

        let padded = parse_loaded_skill_ids(r#"[" skill-a ", "", "skill-b"]"#);
        assert_eq!(padded.len(), 2);
        assert!(padded.contains("skill-a"));
        assert!(padded.contains("skill-b"));
    }

    #[test]
    fn serialize_loaded_skill_ids_is_stable_and_sorted() {
        let mut ids = HashSet::new();
        ids.insert("skill-b".to_string());
        ids.insert("skill-a".to_string());

        assert_eq!(serialize_loaded_skill_ids(&ids), r#"["skill-a","skill-b"]"#);
    }

    #[test]
    fn serialize_loaded_skill_ids_drops_blanks_and_round_trips() {
        let ids: HashSet<String> = [" skill-b ", "", "skill-a"]
            .iter()
            .map(|id| id.to_string())
            .collect();
        let serialized = serialize_loaded_skill_ids(&ids);
        assert_eq!(serialized, r#"["skill-a","skill-b"]"#);

        let expected: HashSet<String> = ["skill-a", "skill-b"]
            .iter()
            .map(|id| id.to_string())
            .collect();
        assert_eq!(parse_loaded_skill_ids(&serialized), expected);

        assert_eq!(serialize_loaded_skill_ids(&HashSet::new()), "[]");
    }

    #[test]
    fn extract_skill_allowlist_parses_metadata_json() {
        let mut metadata = HashMap::new();
        metadata.insert(
            "selected_skill_ids".to_string(),
            r#"["pdf","skill-creator"]"#.to_string(),
        );

        let allowlist = extract_skill_allowlist(&metadata).unwrap();
        assert!(allowlist.contains("pdf"));
        assert!(allowlist.contains("skill-creator"));
    }

    #[test]
    fn extract_skill_allowlist_is_none_without_key() {
        assert!(extract_skill_allowlist(&HashMap::new()).is_none());
    }

    #[test]
    fn extract_skill_mode_prefers_skill_mode_key() {
        let mut metadata = HashMap::new();
        metadata.insert("mode".to_string(), "ask".to_string());
        metadata.insert("skill_mode".to_string(), "code".to_string());

        assert_eq!(extract_skill_mode(&metadata).as_deref(), Some("code"));
    }

    #[test]
    fn extract_skill_mode_falls_back_to_mode_and_trims() {
        let mut metadata = HashMap::new();
        metadata.insert("mode".to_string(), "  ask  ".to_string());
        assert_eq!(extract_skill_mode(&metadata).as_deref(), Some("ask"));

        metadata.insert("mode".to_string(), "   ".to_string());
        assert_eq!(extract_skill_mode(&metadata), None);

        assert_eq!(extract_skill_mode(&HashMap::new()), None);
    }
}