1use std::collections::HashMap;
5use std::process::Command;
6use std::sync::Mutex;
7
8use agent_secrets::{
9 AgentSecretAccess, AgentVerifier, SecretError, SecretPolicy, SecretResult, SecretStore,
10 SessionContext,
11};
12use pyo3::exceptions::PyRuntimeError;
13use pyo3::prelude::*;
14use pyo3::types::{PyBytes, PyModule};
15use shadi_memory::{MemoryEntry as ShadiMemoryEntry, SqlCipherStore};
16use shadi_sandbox::{spawn_sandboxed, SandboxError, SandboxPolicy};
17use tracing::{field, info_span};
18
19struct SessionFlagVerifier;
20
21impl AgentVerifier for SessionFlagVerifier {
22 fn verify(&self, session: &SessionContext) -> SecretResult<()> {
23 if session.verified {
24 Ok(())
25 } else {
26 Err(SecretError::NotAuthorized)
27 }
28 }
29}
30
/// Python-facing handle to the agent secret store.
#[pyclass]
pub struct ShadiStore {
    /// Backing store (created with `agent_secrets::default_store()` in `new`),
    /// behind a mutex because every exposed method only takes `&self`.
    store: Mutex<Box<dyn SecretStore>>,
    /// Gate handed to `AgentSecretAccess` for put/get/delete; it only checks the
    /// session's `verified` flag.
    verifier: SessionFlagVerifier,
    /// Python callable installed by `set_verifier` and invoked by `verify_session`;
    /// `None` until configured.
    didvc_verifier: Mutex<Option<Py<PyAny>>>,
}
37
/// Python-facing handle to a SQLCipher-backed memory database.
#[pyclass]
pub struct SqlCipherMemoryStore {
    /// Store opened by `SqlCipherStore::open`; every Python method delegates to it.
    store: SqlCipherStore,
}
42
43#[pyclass]
44#[derive(Clone)]
45pub struct MemoryEntry {
46 #[pyo3(get)]
47 id: i64,
48 #[pyo3(get)]
49 scope: String,
50 #[pyo3(get)]
51 entry_key: String,
52 #[pyo3(get)]
53 payload: String,
54 #[pyo3(get)]
55 created_at: String,
56}
57
58impl MemoryEntry {
59 fn from_native(entry: ShadiMemoryEntry) -> Self {
60 Self {
61 id: entry.id,
62 scope: entry.scope,
63 entry_key: entry.entry_key,
64 payload: entry.payload,
65 created_at: entry.created_at,
66 }
67 }
68}
69
/// Python-side builder for a `SandboxPolicy`, consumed by `run_sandboxed`.
#[pyclass]
pub struct SandboxPolicyHandle {
    /// Accumulated policy; each builder method replaces it with an updated copy.
    policy: SandboxPolicy,
}
74
75#[pymethods]
76impl ShadiStore {
77 #[new]
78 fn new() -> Self {
79 Self {
80 store: Mutex::new(agent_secrets::default_store()),
81 verifier: SessionFlagVerifier,
82 didvc_verifier: Mutex::new(None),
83 }
84 }
85
86 fn set_verifier(&self, verifier: PyObject) -> PyResult<()> {
87 let mut guard = self
88 .didvc_verifier
89 .lock()
90 .map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
91 *guard = Some(verifier);
92 Ok(())
93 }
94
95 fn verify_session(
96 &self,
97 py: Python<'_>,
98 session: &Bound<'_, PySessionContext>,
99 presentation: &[u8],
100 ) -> PyResult<bool> {
101 let verifier = {
102 let guard = self
103 .didvc_verifier
104 .lock()
105 .map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
106 guard.clone().ok_or_else(|| PyRuntimeError::new_err("verifier not configured"))?
107 };
108
109 let (agent_id, session_id, claims) = {
110 let session_ref = session.borrow();
111 (
112 session_ref.agent_id.clone(),
113 session_ref.session_id.clone(),
114 session_ref.claims.clone(),
115 )
116 };
117
118 let payload = PyBytes::new_bound(py, presentation);
119 let result = verifier.call1(py, (agent_id, session_id, payload, claims))?;
120 let is_valid = result.is_truthy(py)?;
121
122 if is_valid {
123 let mut session_ref = session.borrow_mut();
124 session_ref.verified = true;
125 }
126
127 Ok(is_valid)
128 }
129
130 fn put(&self, session: &PySessionContext, key: &str, secret: &[u8]) -> PyResult<()> {
131 let span = info_span!("shadi.secret.put", secret.key = %key);
132 let _guard = span.enter();
133 let ctx = session.to_context();
134 let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
135 let access = AgentSecretAccess::new(guard.as_ref(), &self.verifier);
136 access
137 .put_for_session(&ctx, key, secret, SecretPolicy::default())
138 .map_err(map_secret_error)
139 }
140
141 fn get<'py>(
142 &self,
143 py: Python<'py>,
144 session: &PySessionContext,
145 key: &str,
146 ) -> PyResult<Bound<'py, PyBytes>> {
147 let span = info_span!("shadi.secret.get", secret.key = %key);
148 let _guard = span.enter();
149 let ctx = session.to_context();
150 let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
151 let access = AgentSecretAccess::new(guard.as_ref(), &self.verifier);
152 let secret = access.get_for_session(&ctx, key).map_err(map_secret_error)?;
153 let bytes = secret.expose(|data| data.to_vec());
154 Ok(PyBytes::new_bound(py, &bytes))
155 }
156
157 fn delete(&self, session: &PySessionContext, key: &str) -> PyResult<()> {
158 let span = info_span!("shadi.secret.delete", secret.key = %key);
159 let _guard = span.enter();
160 let ctx = session.to_context();
161 let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
162 let access = AgentSecretAccess::new(guard.as_ref(), &self.verifier);
163 access
164 .delete_for_session(&ctx, key)
165 .map_err(map_secret_error)
166 }
167
168 fn list_keys(&self, session: &PySessionContext) -> PyResult<Vec<String>> {
169 let span = info_span!("shadi.secret.list_keys");
170 let _guard = span.enter();
171 let ctx = session.to_context();
172 AgentSecretAccess::require_verified(&ctx).map_err(map_secret_error)?;
173 let guard = self.store.lock().map_err(|_| PyRuntimeError::new_err("lock poisoned"))?;
174 guard.list_keys().map_err(map_secret_error)
175 }
176}
177
178#[pymethods]
179impl SqlCipherMemoryStore {
180 #[new]
181 #[pyo3(signature = (db_path, key=None, key_name=None))]
182 fn new(db_path: String, key: Option<String>, key_name: Option<String>) -> PyResult<Self> {
183 let key = resolve_memory_key(key, key_name.as_deref())?;
184 let store = SqlCipherStore::open(db_path.as_ref(), &key)
185 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
186 Ok(Self { store })
187 }
188
189 fn put(&self, scope: &str, entry_key: &str, payload: &str) -> PyResult<i64> {
190 let span = info_span!("shadi.memory.put", memory.scope = %scope, memory.entry_key = %entry_key);
191 let _guard = span.enter();
192 self.store
193 .put(scope, entry_key, payload)
194 .map_err(|err| PyRuntimeError::new_err(err.to_string()))
195 }
196
197 fn get_latest(&self, scope: &str, entry_key: &str) -> PyResult<Option<MemoryEntry>> {
198 let span = info_span!("shadi.memory.get_latest", memory.scope = %scope, memory.entry_key = %entry_key);
199 let _guard = span.enter();
200 let entry = self
201 .store
202 .get_latest(scope, entry_key)
203 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
204 Ok(entry.map(MemoryEntry::from_native))
205 }
206
207 #[pyo3(signature = (query, scope=None, limit=10))]
208 fn search(&self, query: &str, scope: Option<String>, limit: usize) -> PyResult<Vec<MemoryEntry>> {
209 let span = info_span!(
210 "shadi.memory.search",
211 memory.query = %query,
212 memory.scope = %scope.as_deref().unwrap_or(""),
213 memory.limit = limit as i64,
214 );
215 let _guard = span.enter();
216 let entries = self
217 .store
218 .search(scope.as_deref(), query, limit)
219 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
220 Ok(entries
221 .into_iter()
222 .map(MemoryEntry::from_native)
223 .collect())
224 }
225
226 #[pyo3(signature = (scope=None, limit=50))]
227 fn list(&self, scope: Option<String>, limit: usize) -> PyResult<Vec<MemoryEntry>> {
228 let span = info_span!(
229 "shadi.memory.list",
230 memory.scope = %scope.as_deref().unwrap_or(""),
231 memory.limit = limit as i64,
232 );
233 let _guard = span.enter();
234 let entries = self
235 .store
236 .list(scope.as_deref(), limit)
237 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
238 Ok(entries
239 .into_iter()
240 .map(MemoryEntry::from_native)
241 .collect())
242 }
243
244 fn delete(&self, scope: &str, entry_key: &str) -> PyResult<usize> {
245 let span = info_span!("shadi.memory.delete", memory.scope = %scope, memory.entry_key = %entry_key);
246 let _guard = span.enter();
247 self.store
248 .delete(scope, entry_key)
249 .map_err(|err| PyRuntimeError::new_err(err.to_string()))
250 }
251}
252
253#[pymethods]
254impl SandboxPolicyHandle {
255 #[new]
256 fn new() -> Self {
257 Self {
258 policy: SandboxPolicy::new(),
259 }
260 }
261
262 fn allow_read_path(&mut self, path: &str) {
263 self.policy = self.policy.clone().allow_read_path(path);
264 }
265
266 fn allow_write_path(&mut self, path: &str) {
267 self.policy = self.policy.clone().allow_write_path(path);
268 }
269
270 fn block_network(&mut self, value: bool) {
271 self.policy = self.policy.clone().block_network(value);
272 }
273}
274
/// Python-visible session identity passed to `ShadiStore`'s secret operations.
#[pyclass]
pub struct PySessionContext {
    /// Agent identifier; copied into `SessionContext` and passed to the verifier.
    agent_id: String,
    /// Session identifier; copied into `SessionContext` and passed to the verifier.
    session_id: String,
    /// Whether secret access is allowed. Starts `false`; changed by `set_verified`
    /// or by `ShadiStore::verify_session`.
    verified: bool,
    /// Claim strings appended via `add_claim`.
    claims: Vec<String>,
}
282
283#[pymethods]
284impl PySessionContext {
285 #[new]
286 fn new(agent_id: String, session_id: String) -> Self {
287 Self {
288 agent_id,
289 session_id,
290 verified: false,
291 claims: Vec::new(),
292 }
293 }
294
295 fn set_verified(&mut self, value: bool) {
296 self.verified = value;
297 }
298
299 fn add_claim(&mut self, claim: String) {
300 self.claims.push(claim);
301 }
302}
303
304impl PySessionContext {
305 fn to_context(&self) -> SessionContext {
306 SessionContext {
307 agent_id: self.agent_id.clone(),
308 session_id: self.session_id.clone(),
309 verified: self.verified,
310 claims: self.claims.clone(),
311 }
312 }
313}
314
315#[pymodule]
316fn shadi(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
317 shadi_telemetry::init("shadi-runtime");
318 m.add_class::<ShadiStore>()?;
319 m.add_class::<PySessionContext>()?;
320 m.add_class::<SqlCipherMemoryStore>()?;
321 m.add_class::<MemoryEntry>()?;
322 m.add_class::<SandboxPolicyHandle>()?;
323 m.add_function(wrap_pyfunction!(run_sandboxed, m)?)?;
324 Ok(())
325}
326
327fn map_secret_error(err: SecretError) -> PyErr {
328 PyRuntimeError::new_err(err.to_string())
329}
330
331fn map_sandbox_error(err: SandboxError) -> PyErr {
332 PyRuntimeError::new_err(err.to_string())
333}
334
335fn resolve_memory_key(key: Option<String>, key_name: Option<&str>) -> PyResult<String> {
336 if let Some(key) = key {
337 if key.is_empty() {
338 return Err(PyRuntimeError::new_err("SHADI_MEMORY_KEY is empty"));
339 }
340 return Ok(key);
341 }
342
343 let name = key_name.unwrap_or("shadi/memory/sqlcipher_key");
344 let store = agent_secrets::default_store();
345 let secret = store
346 .get(name)
347 .map_err(|_| PyRuntimeError::new_err(format!("missing SHADI key: {}", name)))?;
348 let raw = secret.expose(|bytes| bytes.to_vec());
349 String::from_utf8(raw).map_err(|_| PyRuntimeError::new_err("SHADI memory key is not utf-8"))
350}
351
352fn inject_keychain_with_store(
353 store: &dyn SecretStore,
354 command: &mut Command,
355 mappings: &[String],
356) -> Result<(), String> {
357 let span = info_span!("shadi.secrets.inject", secret.count = mappings.len() as i64);
358 let _guard = span.enter();
359 for mapping in mappings {
360 let (key, env) = parse_key_env(mapping)?;
361 let secret = store
362 .get(key)
363 .map_err(|_| format!("keychain lookup failed for {}", key))?;
364 let value = secret.expose(|bytes| bytes.to_vec());
365 let value = String::from_utf8(value).map_err(|_| "secret is not utf-8".to_string())?;
366 command.env(env, value);
367 }
368
369 Ok(())
370}
371
/// Splits a `KEY=ENV` mapping at the first `=`.
///
/// `KEY` is the secret-store key and `ENV` the environment variable to set.
///
/// # Errors
/// If there is no `=`, if either side is empty, or if `ENV` itself contains `=` or a
/// NUL byte. Such a name cannot be a valid environment variable: `=` would be read
/// back as part of the value, and NUL makes spawning the child fail.
fn parse_key_env(value: &str) -> Result<(&str, &str), String> {
    let (key, env) = value.split_once('=').unwrap_or((value, ""));
    if key.is_empty() || env.is_empty() {
        return Err("inject-keychain must be in KEY=ENV format".to_string());
    }
    if env.contains(|c: char| c == '=' || c == '\0') {
        return Err(format!("inject-keychain env name is invalid: {:?}", env));
    }
    Ok((key, env))
}
381
382#[pyfunction]
383#[pyo3(signature = (command, policy, cwd=None, env=None, inject_keychain=None))]
384fn run_sandboxed(
385 command: Vec<String>,
386 policy: &SandboxPolicyHandle,
387 cwd: Option<String>,
388 env: Option<HashMap<String, String>>,
389 inject_keychain: Option<Vec<String>>,
390) -> PyResult<i32> {
391 if command.is_empty() {
392 return Err(PyRuntimeError::new_err("command must not be empty"));
393 }
394
395 let cwd_value = cwd.as_deref().unwrap_or("");
396 let span = info_span!(
397 "shadi.sandbox.run",
398 command = %command[0],
399 cwd = %cwd_value,
400 exit.code = field::Empty,
401 );
402 let _guard = span.enter();
403
404 let mut cmd = Command::new(&command[0]);
405 if command.len() > 1 {
406 cmd.args(&command[1..]);
407 }
408 if let Some(cwd) = cwd {
409 cmd.current_dir(cwd);
410 }
411 if let Some(env_map) = env {
412 cmd.envs(env_map);
413 }
414 if let Some(mappings) = inject_keychain {
415 let store = agent_secrets::default_store();
416 inject_keychain_with_store(store.as_ref(), &mut cmd, &mappings)
417 .map_err(PyRuntimeError::new_err)?;
418 }
419
420 let mut child = spawn_sandboxed(&mut cmd, &policy.policy).map_err(map_sandbox_error)?;
421 let status = child
422 .wait()
423 .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
424 span.record("exit.code", &status.code().unwrap_or(-1));
425 Ok(status.code().unwrap_or(1))
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Once;
    use std::time::{SystemTime, UNIX_EPOCH};

    static PY_INIT: Once = Once::new();

    /// Initializes the embedded interpreter exactly once per test binary.
    fn ensure_python() {
        PY_INIT.call_once(|| {
            pyo3::prepare_freethreaded_python();
        });
    }

    /// Builds a key that is unique per process and call, so concurrent test runs do
    /// not collide in the shared keychain.
    fn unique_key(prefix: &str) -> String {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time went backwards")
            .as_nanos();
        format!("{}-{}-{}", prefix, std::process::id(), nanos)
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn verify_session_sets_verified_flag() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let module = PyModule::from_code_bound(
                py,
                "def verify(agent_id, session_id, presentation, claims):\n return True\n",
                "verifier.py",
                "verifier",
            )
            .unwrap();
            let verifier = module.getattr("verify").unwrap();
            store.set_verifier(verifier.into_py(py)).unwrap();

            let mut base_session = PySessionContext::new("agent".to_string(), "session".to_string());
            base_session.add_claim("did:example:agent".to_string());
            let session = Py::new(py, base_session).unwrap();
            let session_bound = session.bind(py);

            let ok = store.verify_session(py, session_bound, b"presentation").unwrap();
            assert!(ok);
            assert!(session_bound.borrow().verified);
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn verify_session_rejects_falsy_result() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let module = PyModule::from_code_bound(
                py,
                "def verify(agent_id, session_id, presentation, claims):\n    return False\n",
                "verifier_false.py",
                "verifier_false",
            )
            .unwrap();
            let verifier = module.getattr("verify").unwrap();
            store.set_verifier(verifier.into_py(py)).unwrap();

            let session = Py::new(py, PySessionContext::new("agent".to_string(), "session".to_string())).unwrap();
            let session_bound = session.bind(py);

            let ok = store.verify_session(py, session_bound, b"presentation").unwrap();
            assert!(!ok);
            assert!(!session_bound.borrow().verified);
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn verify_session_requires_verifier() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let session = Py::new(py, PySessionContext::new("agent".to_string(), "session".to_string())).unwrap();
            let session_bound = session.bind(py);

            let err = store
                .verify_session(py, session_bound, b"presentation")
                .unwrap_err();
            assert!(err.is_instance_of::<PyRuntimeError>(py));
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn put_get_delete_roundtrip_requires_verified() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let mut session = PySessionContext::new("agent".to_string(), "session".to_string());
            session.add_claim("role:tourist".to_string());

            let err = store.put(&session, "key", b"value").unwrap_err();
            assert!(err.is_instance_of::<PyRuntimeError>(py));

            session.set_verified(true);
            let key = unique_key("shadi-py");
            let secret = b"secret-value";

            store.put(&session, &key, secret).unwrap();
            let bytes = store.get(py, &session, &key).unwrap();
            assert_eq!(bytes.as_bytes(), secret);
            store.delete(&session, &key).unwrap();
        });
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn list_keys_requires_verified() {
        ensure_python();
        Python::with_gil(|py| {
            let store = ShadiStore::new();
            let mut session = PySessionContext::new("agent".to_string(), "session".to_string());
            session.add_claim("role:secops".to_string());

            // The test name promises this: an unverified session must be refused.
            let err = store.list_keys(&session).unwrap_err();
            assert!(err.is_instance_of::<PyRuntimeError>(py));

            session.set_verified(true);
            let key = unique_key("shadi-py-list");
            store.put(&session, &key, b"value").unwrap();
            let keys = store.list_keys(&session).unwrap();
            assert!(keys.iter().any(|item| item == &key));

            store.delete(&session, &key).unwrap();
        });
    }
}