Skip to main content

shadi/
lib.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::process::Command;
6use std::sync::Mutex;
7
8use agent_secrets::{
9    AgentSecretAccess, AgentVerifier, SecretError, SecretPolicy, SecretResult, SecretStore,
10    SessionContext,
11};
12use pyo3::exceptions::PyRuntimeError;
13use pyo3::prelude::*;
14use pyo3::types::{PyBytes, PyModule};
15use shadi_memory::{MemoryEntry as ShadiMemoryEntry, SqlCipherStore};
16use shadi_sandbox::{spawn_sandboxed, SandboxError, SandboxPolicy};
17use tracing::{field, info_span};
18
19struct SessionFlagVerifier;
20
21impl AgentVerifier for SessionFlagVerifier {
22    fn verify(&self, session: &SessionContext) -> SecretResult<()> {
23        if session.verified {
24            Ok(())
25        } else {
26            Err(SecretError::NotAuthorized)
27        }
28    }
29}
30
31#[pyclass]
32pub struct ShadiStore {
33    store: Mutex<Box<dyn SecretStore>>,
34    verifier: SessionFlagVerifier,
35    didvc_verifier: Mutex<Option<Py<PyAny>>>,
36}
37
38#[pyclass]
39pub struct SqlCipherMemoryStore {
40    store: SqlCipherStore,
41}
42
43#[pyclass]
44#[derive(Clone)]
45pub struct MemoryEntry {
46    #[pyo3(get)]
47    id: i64,
48    #[pyo3(get)]
49    scope: String,
50    #[pyo3(get)]
51    entry_key: String,
52    #[pyo3(get)]
53    payload: String,
54    #[pyo3(get)]
55    created_at: String,
56}
57
58impl MemoryEntry {
59    fn from_native(entry: ShadiMemoryEntry) -> Self {
60        Self {
61            id: entry.id,
62            scope: entry.scope,
63            entry_key: entry.entry_key,
64            payload: entry.payload,
65            created_at: entry.created_at,
66        }
67    }
68}
69
70#[pyclass]
71pub struct SandboxPolicyHandle {
72    policy: SandboxPolicy,
73}
74
75#[pymethods]
76impl ShadiStore {
77    #[new]
78    fn new() -> Self {
79        Self {
80            store: Mutex::new(agent_secrets::default_store()),
81            verifier: SessionFlagVerifier,
82            didvc_verifier: Mutex::new(None),
83        }
84    }
85
86    fn set_verifier(&self, verifier: PyObject) -> PyResult<()> {
87        let mut guard = self
88            .didvc_verifier
89            .lock()
90            .map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
91        *guard = Some(verifier);
92        Ok(())
93    }
94
95    fn verify_session(
96        &self,
97        py: Python<'_>,
98        session: &Bound<'_, PySessionContext>,
99        presentation: &[u8],
100    ) -> PyResult<bool> {
101        let verifier = {
102            let guard = self
103                .didvc_verifier
104                .lock()
105                .map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
106            guard.clone().ok_or_else(|| PyRuntimeError::new_err("verifier not configured"))?
107        };
108
109        let (agent_id, session_id, claims) = {
110            let session_ref = session.borrow();
111            (
112                session_ref.agent_id.clone(),
113                session_ref.session_id.clone(),
114                session_ref.claims.clone(),
115            )
116        };
117
118        let payload = PyBytes::new_bound(py, presentation);
119        let result = verifier.call1(py, (agent_id, session_id, payload, claims))?;
120        let is_valid = result.is_truthy(py)?;
121
122        if is_valid {
123            let mut session_ref = session.borrow_mut();
124            session_ref.verified = true;
125        }
126
127        Ok(is_valid)
128    }
129
130    fn put(&self, session: &PySessionContext, key: &str, secret: &[u8]) -> PyResult<()> {
131        let span = info_span!("shadi.secret.put", secret.key = %key);
132        let _guard = span.enter();
133        let ctx = session.to_context();
134        let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
135        let access = AgentSecretAccess::new(guard.as_ref(), &self.verifier);
136        access
137            .put_for_session(&ctx, key, secret, SecretPolicy::default())
138            .map_err(map_secret_error)
139    }
140
141    fn get<'py>(
142        &self,
143        py: Python<'py>,
144        session: &PySessionContext,
145        key: &str,
146    ) -> PyResult<Bound<'py, PyBytes>> {
147        let span = info_span!("shadi.secret.get", secret.key = %key);
148        let _guard = span.enter();
149        let ctx = session.to_context();
150        let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
151        let access = AgentSecretAccess::new(guard.as_ref(), &self.verifier);
152        let secret = access.get_for_session(&ctx, key).map_err(map_secret_error)?;
153        let bytes = secret.expose(|data| data.to_vec());
154        Ok(PyBytes::new_bound(py, &bytes))
155    }
156
157    fn delete(&self, session: &PySessionContext, key: &str) -> PyResult<()> {
158        let span = info_span!("shadi.secret.delete", secret.key = %key);
159        let _guard = span.enter();
160        let ctx = session.to_context();
161        let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
162        let access = AgentSecretAccess::new(guard.as_ref(), &self.verifier);
163        access
164            .delete_for_session(&ctx, key)
165            .map_err(map_secret_error)
166    }
167
168    fn list_keys(&self, session: &PySessionContext) -> PyResult<Vec<String>> {
169        let span = info_span!("shadi.secret.list_keys");
170        let _guard = span.enter();
171        let ctx = session.to_context();
172        AgentSecretAccess::require_verified(&ctx).map_err(map_secret_error)?;
173        let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
174        guard.list_keys().map_err(map_secret_error)
175    }
176}
177
178#[pymethods]
179impl SqlCipherMemoryStore {
180    #[new]
181    #[pyo3(signature = (db_path, key=None, key_name=None))]
182    fn new(db_path: String, key: Option<String>, key_name: Option<String>) -> PyResult<Self> {
183        let key = resolve_memory_key(key, key_name.as_deref())?;
184        let store = SqlCipherStore::open(db_path.as_ref(), &key)
185            .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
186        Ok(Self { store })
187    }
188
189    fn put(&self, scope: &str, entry_key: &str, payload: &str) -> PyResult<i64> {
190        let span = info_span!("shadi.memory.put", memory.scope = %scope, memory.entry_key = %entry_key);
191        let _guard = span.enter();
192        self.store
193            .put(scope, entry_key, payload)
194            .map_err(|err| PyRuntimeError::new_err(err.to_string()))
195    }
196
197    fn get_latest(&self, scope: &str, entry_key: &str) -> PyResult<Option<MemoryEntry>> {
198        let span = info_span!("shadi.memory.get_latest", memory.scope = %scope, memory.entry_key = %entry_key);
199        let _guard = span.enter();
200        let entry = self
201            .store
202            .get_latest(scope, entry_key)
203            .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
204        Ok(entry.map(MemoryEntry::from_native))
205    }
206
207    #[pyo3(signature = (query, scope=None, limit=10))]
208    fn search(&self, query: &str, scope: Option<String>, limit: usize) -> PyResult<Vec<MemoryEntry>> {
209        let span = info_span!(
210            "shadi.memory.search",
211            memory.query = %query,
212            memory.scope = %scope.as_deref().unwrap_or(""),
213            memory.limit = limit as i64,
214        );
215        let _guard = span.enter();
216        let entries = self
217            .store
218            .search(scope.as_deref(), query, limit)
219            .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
220        Ok(entries
221            .into_iter()
222            .map(MemoryEntry::from_native)
223            .collect())
224    }
225
226    #[pyo3(signature = (scope=None, limit=50))]
227    fn list(&self, scope: Option<String>, limit: usize) -> PyResult<Vec<MemoryEntry>> {
228        let span = info_span!(
229            "shadi.memory.list",
230            memory.scope = %scope.as_deref().unwrap_or(""),
231            memory.limit = limit as i64,
232        );
233        let _guard = span.enter();
234        let entries = self
235            .store
236            .list(scope.as_deref(), limit)
237            .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
238        Ok(entries
239            .into_iter()
240            .map(MemoryEntry::from_native)
241            .collect())
242    }
243
244    fn delete(&self, scope: &str, entry_key: &str) -> PyResult<usize> {
245        let span = info_span!("shadi.memory.delete", memory.scope = %scope, memory.entry_key = %entry_key);
246        let _guard = span.enter();
247        self.store
248            .delete(scope, entry_key)
249            .map_err(|err| PyRuntimeError::new_err(err.to_string()))
250    }
251}
252
253#[pymethods]
254impl SandboxPolicyHandle {
255    #[new]
256    fn new() -> Self {
257        Self {
258            policy: SandboxPolicy::new(),
259        }
260    }
261
262    fn allow_read_path(&mut self, path: &str) {
263        self.policy = self.policy.clone().allow_read_path(path);
264    }
265
266    fn allow_write_path(&mut self, path: &str) {
267        self.policy = self.policy.clone().allow_write_path(path);
268    }
269
270    fn block_network(&mut self, value: bool) {
271        self.policy = self.policy.clone().block_network(value);
272    }
273}
274
275#[pyclass]
276pub struct PySessionContext {
277    agent_id: String,
278    session_id: String,
279    verified: bool,
280    claims: Vec<String>,
281}
282
283#[pymethods]
284impl PySessionContext {
285    #[new]
286    fn new(agent_id: String, session_id: String) -> Self {
287        Self {
288            agent_id,
289            session_id,
290            verified: false,
291            claims: Vec::new(),
292        }
293    }
294
295    fn set_verified(&mut self, value: bool) {
296        self.verified = value;
297    }
298
299    fn add_claim(&mut self, claim: String) {
300        self.claims.push(claim);
301    }
302}
303
304impl PySessionContext {
305    fn to_context(&self) -> SessionContext {
306        SessionContext {
307            agent_id: self.agent_id.clone(),
308            session_id: self.session_id.clone(),
309            verified: self.verified,
310            claims: self.claims.clone(),
311        }
312    }
313}
314
315#[pymodule]
316fn shadi(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
317    shadi_telemetry::init("shadi-runtime");
318    m.add_class::<ShadiStore>()?;
319    m.add_class::<PySessionContext>()?;
320    m.add_class::<SqlCipherMemoryStore>()?;
321    m.add_class::<MemoryEntry>()?;
322    m.add_class::<SandboxPolicyHandle>()?;
323    m.add_function(wrap_pyfunction!(run_sandboxed, m)?)?;
324    Ok(())
325}
326
327fn map_secret_error(err: SecretError) -> PyErr {
328    PyRuntimeError::new_err(err.to_string())
329}
330
331fn map_sandbox_error(err: SandboxError) -> PyErr {
332    PyRuntimeError::new_err(err.to_string())
333}
334
335fn resolve_memory_key(key: Option<String>, key_name: Option<&str>) -> PyResult<String> {
336    if let Some(key) = key {
337        if key.is_empty() {
338            return Err(PyRuntimeError::new_err("SHADI_MEMORY_KEY is empty"));
339        }
340        return Ok(key);
341    }
342
343    let name = key_name.unwrap_or("shadi/memory/sqlcipher_key");
344    let store = agent_secrets::default_store();
345    let secret = store
346        .get(name)
347        .map_err(|_| PyRuntimeError::new_err(format!("missing SHADI key: {}", name)))?;
348    let raw = secret.expose(|bytes| bytes.to_vec());
349    String::from_utf8(raw).map_err(|_| PyRuntimeError::new_err("SHADI memory key is not utf-8"))
350}
351
352fn inject_keychain_with_store(
353    store: &dyn SecretStore,
354    command: &mut Command,
355    mappings: &[String],
356) -> Result<(), String> {
357    let span = info_span!("shadi.secrets.inject", secret.count = mappings.len() as i64);
358    let _guard = span.enter();
359    for mapping in mappings {
360        let (key, env) = parse_key_env(mapping)?;
361        let secret = store
362            .get(key)
363            .map_err(|_| format!("keychain lookup failed for {}", key))?;
364        let value = secret.expose(|bytes| bytes.to_vec());
365        let value = String::from_utf8(value).map_err(|_| "secret is not utf-8".to_string())?;
366        command.env(env, value);
367    }
368
369    Ok(())
370}
371
/// Splits an inject-keychain mapping of the form `KEY=ENV`.
///
/// Only the first `=` separates the halves, so `KEY` never contains `=`.
/// `ENV` must be non-empty and must not contain `=` either. POSIX forbids `=`
/// in environment variable names, so such a name would not reach the child as
/// the variable the caller asked for.
///
/// # Errors
/// Returns a human-readable message when either half is empty, there is no
/// `=`, or `ENV` contains `=`.
fn parse_key_env(value: &str) -> Result<(&str, &str), String> {
    let (key, env) = match value.split_once('=') {
        Some((key, env)) if !key.is_empty() && !env.is_empty() => (key, env),
        _ => return Err("inject-keychain must be in KEY=ENV format".to_string()),
    };
    if env.contains('=') {
        return Err(format!(
            "inject-keychain env name must not contain '=': {}",
            env
        ));
    }
    Ok((key, env))
}
381
382#[pyfunction]
383#[pyo3(signature = (command, policy, cwd=None, env=None, inject_keychain=None))]
384fn run_sandboxed(
385    command: Vec<String>,
386    policy: &SandboxPolicyHandle,
387    cwd: Option<String>,
388    env: Option<HashMap<String, String>>,
389    inject_keychain: Option<Vec<String>>,
390) -> PyResult<i32> {
391    if command.is_empty() {
392        return Err(PyRuntimeError::new_err("command must not be empty"));
393    }
394
395    let cwd_value = cwd.as_deref().unwrap_or("");
396    let span = info_span!(
397        "shadi.sandbox.run",
398        command = %command[0],
399        cwd = %cwd_value,
400        exit.code = field::Empty,
401    );
402    let _guard = span.enter();
403
404    let mut cmd = Command::new(&command[0]);
405    if command.len() > 1 {
406        cmd.args(&command[1..]);
407    }
408    if let Some(cwd) = cwd {
409        cmd.current_dir(cwd);
410    }
411    if let Some(env_map) = env {
412        cmd.envs(env_map);
413    }
414    if let Some(mappings) = inject_keychain {
415        let store = agent_secrets::default_store();
416        inject_keychain_with_store(store.as_ref(), &mut cmd, &mappings)
417            .map_err(PyRuntimeError::new_err)?;
418    }
419
420    let mut child = spawn_sandboxed(&mut cmd, &policy.policy).map_err(map_sandbox_error)?;
421    let status = child
422        .wait()
423        .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
424    span.record("exit.code", &status.code().unwrap_or(-1));
425    Ok(status.code().unwrap_or(1))
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Once;
    use std::time::{SystemTime, UNIX_EPOCH};

    static PY_INIT: Once = Once::new();

    /// Initializes the embedded Python interpreter once per test binary.
    fn ensure_python() {
        PY_INIT.call_once(|| {
            pyo3::prepare_freethreaded_python();
        });
    }

    /// Builds a secret key unlikely to collide across runs (pid + wall-clock nanos).
    fn unique_key(prefix: &str) -> String {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_nanos();
        format!("{}-{}-{}", prefix, std::process::id(), nanos)
    }

    #[test]
    fn parse_key_env_splits_key_and_env() {
        assert_eq!(
            parse_key_env("svc/token=API_TOKEN"),
            Ok(("svc/token", "API_TOKEN"))
        );
    }

    #[test]
    fn parse_key_env_rejects_missing_halves() {
        for bad in &["", "KEY", "=ENV", "KEY="] {
            assert!(parse_key_env(bad).is_err(), "{:?} should be rejected", bad);
        }
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn verify_session_sets_verified_flag() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let module = PyModule::from_code_bound(
                py,
                "def verify(agent_id, session_id, presentation, claims):\n    return True\n",
                "verifier.py",
                "verifier",
            )
            .unwrap();
            let verifier = module.getattr("verify").unwrap();
            store.set_verifier(verifier.into_py(py)).unwrap();

            let mut base_session = PySessionContext::new("agent".to_string(), "session".to_string());
            base_session.add_claim("did:example:agent".to_string());
            let session = Py::new(py, base_session).unwrap();
            let session_bound = session.bind(py);

            let ok = store.verify_session(py, session_bound, b"presentation").unwrap();
            assert!(ok);
            assert!(session_bound.borrow().verified);
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn verify_session_rejection_leaves_session_unverified() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let module = PyModule::from_code_bound(
                py,
                "def verify(agent_id, session_id, presentation, claims):\n    return False\n",
                "verifier_reject.py",
                "verifier_reject",
            )
            .unwrap();
            let verifier = module.getattr("verify").unwrap();
            store.set_verifier(verifier.into_py(py)).unwrap();

            let session = Py::new(py, PySessionContext::new("agent".to_string(), "session".to_string())).unwrap();
            let session_bound = session.bind(py);

            let ok = store.verify_session(py, session_bound, b"presentation").unwrap();
            assert!(!ok);
            assert!(!session_bound.borrow().verified);
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn verify_session_requires_verifier() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let session = Py::new(py, PySessionContext::new("agent".to_string(), "session".to_string())).unwrap();
            let session_bound = session.bind(py);

            let err = store
                .verify_session(py, session_bound, b"presentation")
                .unwrap_err();
            assert!(err.is_instance_of::<PyRuntimeError>(py));
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn put_get_delete_roundtrip_requires_verified() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let mut session = PySessionContext::new("agent".to_string(), "session".to_string());
            session.add_claim("role:tourist".to_string());

            let err = store.put(&session, "key", b"value").unwrap_err();
            assert!(err.is_instance_of::<PyRuntimeError>(py));

            session.set_verified(true);
            let key = unique_key("shadi-py");
            let secret = b"secret-value";

            store.put(&session, &key, secret).unwrap();
            // Read, then delete before asserting, so a failed assertion does
            // not leave the test secret behind in the secret store.
            let fetched = store
                .get(py, &session, &key)
                .map(|bytes| bytes.as_bytes().to_vec());
            store.delete(&session, &key).unwrap();
            assert_eq!(fetched.unwrap(), secret);
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn list_keys_requires_verified() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();

            // The unverified path must be rejected before the store is consulted.
            let unverified = PySessionContext::new("agent".to_string(), "session".to_string());
            let err = store.list_keys(&unverified).unwrap_err();
            assert!(err.is_instance_of::<PyRuntimeError>(py));

            let mut session = PySessionContext::new("agent".to_string(), "session".to_string());
            session.add_claim("role:secops".to_string());
            session.set_verified(true);

            let key = unique_key("shadi-py-list");
            store.put(&session, &key, b"value").unwrap();
            let keys = store.list_keys(&session);
            store.delete(&session, &key).unwrap();
            assert!(keys.unwrap().iter().any(|item| item == &key));
        });
    }
}