astrid-capsule 0.5.0

Core runtime management for User-Space Capsules in Astrid OS
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
use astrid_core::capsule_abi::LogLevel;
use extism::{CurrentPlugin, Error, UserData, Val};

use crate::engine::wasm::host::util;
use crate::engine::wasm::host_state::HostState;

/// Host function backing guest log calls.
///
/// Reads a level name (`inputs[0]`, capped at 64 bytes) and a message
/// (`inputs[1]`, capped at `util::MAX_LOG_MESSAGE_LEN`) from guest memory,
/// decodes both lossily as UTF-8, and routes the message to one of three
/// sinks:
///
/// 1. If the caller context carries a valid principal that differs from the
///    capsule's own principal, the line is appended to
///    `<principal log dir>/<capsule_id>/<today>.log`.
/// 2. Otherwise, if the capsule has a pre-opened log file, the line is
///    appended to that file.
/// 3. Otherwise, the message is emitted through `tracing` at the parsed level.
///
/// Level names are matched case-insensitively; unrecognised names map to
/// `Info`. File-sink failures (directory creation, open, lock, write) are
/// deliberately ignored: logging is best-effort and never fails the guest call.
///
/// # Errors
///
/// Returns an error if either input cannot be read from guest memory, if the
/// user data handle cannot be obtained, or if the host state lock is poisoned.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_log_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    _outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    use std::io::Write;

    let level_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], 64)?;
    let message_bytes: Vec<u8> =
        util::get_safe_bytes(plugin, &inputs[1], util::MAX_LOG_MESSAGE_LEN)?;

    // Borrow the decoded text where the bytes are already valid UTF-8.
    let level = String::from_utf8_lossy(&level_bytes);
    let message = String::from_utf8_lossy(&message_bytes);

    let parsed_level: LogLevel = match level.to_lowercase().as_str() {
        "trace" => LogLevel::Trace,
        "debug" => LogLevel::Debug,
        "warn" | "warning" => LogLevel::Warn,
        "error" | "err" => LogLevel::Error,
        _ => LogLevel::Info,
    };

    // Single lock acquisition: extract everything we need, then drop
    // the guard BEFORE any filesystem I/O (critical for cross-principal
    // path which does create_dir_all + open).
    let ud = user_data.get()?;
    let (capsule_id, invocation_principal, log_file) = {
        let state = ud
            .lock()
            .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
        let inv_p = state
            .caller_context
            .as_ref()
            .and_then(|msg| msg.principal.as_deref())
            .and_then(|p| astrid_core::PrincipalId::new(p).ok())
            .filter(|p| *p != state.principal);
        (
            state.capsule_id.as_str().to_owned(),
            inv_p,
            state.capsule_log.clone(),
        )
    };
    // Guard dropped — safe to do filesystem I/O.

    let level_str = match parsed_level {
        LogLevel::Trace => "TRACE",
        LogLevel::Debug => "DEBUG",
        LogLevel::Info => "INFO",
        LogLevel::Warn => "WARN",
        LogLevel::Error => "ERROR",
    };

    // Render the complete line, trailing newline included, into one buffer.
    // `writeln!` directly on the file forwards every format fragment as its
    // own write, which costs extra syscalls on an unbuffered `File` and lets
    // concurrent appenders interleave inside a line. One `write_all` of the
    // finished line avoids both. Only the file sinks need it, so the
    // `tracing` path skips the work entirely.
    let render_line = || {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or_else(|_| "0".to_string(), |d| format!("{:.3}", d.as_secs_f64()));
        format!("{timestamp} {level_str} [{capsule_id}] {message}\n")
    };

    if let Some(ref inv_principal) = invocation_principal {
        // Cross-principal: open target log file, write, close.
        // No FD caching — append-mode open() is cheap, avoids leaks
        // in 1000-user deployments. OS filesystem cache handles the inode.
        let line = render_line();
        if let Ok(home) = astrid_core::dirs::AstridHome::resolve() {
            let ph = home.principal_home(inv_principal);
            let log_dir = ph.log_dir().join(&capsule_id);
            let _ = std::fs::create_dir_all(&log_dir);
            let today = crate::engine::wasm::today_date_string();
            if let Ok(mut f) = std::fs::OpenOptions::new()
                .create(true)
                .append(true)
                .open(log_dir.join(format!("{today}.log")))
            {
                let _ = f.write_all(line.as_bytes());
            }
        }
    } else if let Some(ref log_file) = log_file {
        // Default principal: use pre-opened log file (fast path).
        let line = render_line();
        if let Ok(mut f) = log_file.lock() {
            let _ = f.write_all(line.as_bytes());
        }
    } else {
        match parsed_level {
            LogLevel::Trace => tracing::trace!(plugin = %capsule_id, "{message}"),
            LogLevel::Debug => tracing::debug!(plugin = %capsule_id, "{message}"),
            LogLevel::Info => tracing::info!(plugin = %capsule_id, "{message}"),
            LogLevel::Warn => tracing::warn!(plugin = %capsule_id, "{message}"),
            LogLevel::Error => tracing::error!(plugin = %capsule_id, "{message}"),
        }
    }

    Ok(())
}

/// Host function: look up a configuration value by key for the guest.
///
/// `inputs[0]` holds the key (capped at `util::MAX_KEY_LEN` bytes, decoded
/// lossily as UTF-8). The value is written to guest memory and its handle is
/// stored in `outputs[0]`:
///
/// - a JSON string value is returned as its raw contents (no quotes),
/// - any other JSON value is returned in its serialized JSON form,
/// - a missing key yields an empty string.
///
/// # Errors
///
/// Returns an error if the key cannot be read from guest memory, if the user
/// data handle cannot be obtained, if the host state lock is poisoned, or if
/// allocating the result in guest memory fails.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_get_config_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let raw_key: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], util::MAX_KEY_LEN)?;
    let key = String::from_utf8_lossy(&raw_key).to_string();

    // Hold the host state lock only long enough to clone the entry.
    let handle = user_data.get()?;
    let looked_up = {
        let guard = handle
            .lock()
            .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
        guard.config.get(&key).cloned()
    };

    // Strings go back verbatim rather than JSON-encoded: serde_json::to_string
    // would wrap them in quotes ("\"value\""), and the SDK's env::var would
    // then see a double-encoded value.
    let rendered = match looked_up {
        None => String::new(),
        Some(serde_json::Value::String(text)) => text,
        Some(other) => serde_json::to_string(&other).unwrap_or_default(),
    };

    let handle_in_guest = plugin.memory_new(&rendered)?;
    outputs[0] = plugin.memory_to_val(handle_in_guest);
    Ok(())
}

/// Host function: describe the current caller context to the guest.
///
/// Writes a JSON object with `principal`, `source_id` and `timestamp`
/// (RFC 3339) to guest memory and stores its handle in `outputs[0]`. When no
/// caller context is set, the result is the empty object `{}`.
///
/// # Errors
///
/// Returns an error if the user data handle cannot be obtained, if the host
/// state lock is poisoned, or if allocating the result in guest memory fails.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_get_caller_impl(
    plugin: &mut CurrentPlugin,
    _inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let handle = user_data.get()?;
    let guard = handle
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;

    let caller_json = guard.caller_context.as_ref().map_or_else(
        || String::from("{}"),
        |msg| {
            serde_json::json!({
                "principal": msg.principal,
                "source_id": msg.source_id.to_string(),
                "timestamp": msg.timestamp.to_rfc3339(),
            })
            .to_string()
        },
    );
    // Release the host state lock before touching guest memory.
    drop(guard);

    let handle_in_guest = plugin.memory_new(&caller_json)?;
    outputs[0] = plugin.memory_to_val(handle_in_guest);
    Ok(())
}

/// Mark the capsule's run loop as ready (its subscriptions are active).
///
/// The WASM guest calls this after setting up its IPC subscriptions. When the
/// host state holds a readiness sender, `true` is sent on it so the kernel can
/// proceed with loading dependent capsules; without a sender the call is a
/// no-op. A failed send is ignored.
///
/// # Errors
///
/// Returns an error if the user data handle cannot be obtained or if the host
/// state lock is poisoned.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_signal_ready_impl(
    _plugin: &mut CurrentPlugin,
    _inputs: &[Val],
    _outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let handle = user_data.get()?;
    let guard = handle
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;

    // Nothing to signal when no readiness channel was installed.
    let Some(ready_tx) = &guard.ready_tx else {
        return Ok(());
    };

    let _ = ready_tx.send(true);
    tracing::debug!(
        capsule = %guard.capsule_id,
        "Capsule signaled ready"
    );

    Ok(())
}

/// Reports the current wall-clock time in milliseconds since the UNIX epoch.
///
/// Takes no inputs. The value is written to guest memory as a UTF-8 decimal
/// string and its handle is stored in `outputs[0]`. A clock set before the
/// epoch reports `0`; a millisecond count that does not fit in `u64`
/// saturates at `u64::MAX`.
///
/// # Errors
///
/// Returns an error if allocating the result in guest memory fails.
pub(crate) fn astrid_clock_ms_impl(
    plugin: &mut CurrentPlugin,
    _inputs: &[Val],
    outputs: &mut [Val],
    _user_data: UserData<HostState>,
) -> Result<(), Error> {
    // A pre-epoch clock makes `duration_since` fail; the default (zero)
    // duration then yields 0 ms.
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    let millis = u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX);

    let text = millis.to_string();
    let handle_in_guest = plugin.memory_new(&text)?;
    outputs[0] = plugin.memory_to_val(handle_in_guest);
    Ok(())
}

/// Trigger request sent by WASM capsules via `hooks::trigger`.
///
/// Deserialized from the guest-supplied JSON bytes in
/// `astrid_trigger_hook_impl`.
#[derive(serde::Deserialize)]
struct TriggerRequest {
    /// The hook/interceptor topic to fan out (e.g. `before_tool_call`).
    ///
    /// Compared against each interceptor's `event` with
    /// `crate::topic::topic_matches`.
    hook: String,
    /// Opaque JSON payload forwarded to each matching interceptor.
    ///
    /// It is re-serialized to bytes once and the same bytes are handed to
    /// every interceptor invocation.
    payload: serde_json::Value,
}

#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_trigger_hook_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let event_bytes = util::get_safe_bytes(plugin, &inputs[0], 1024 * 1024)?; // 1MB max payload

    let ud = user_data.get()?;
    let state = ud
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;

    let caller_id = state.capsule_id.clone();
    let registry = state.capsule_registry.clone();
    let rt_handle = state.runtime_handle.clone();
    let host_semaphore = state.host_semaphore.clone();
    drop(state);

    let result_bytes = if let Some(registry) = registry {
        // Deserialize the trigger request from the WASM guest.
        let request: TriggerRequest = serde_json::from_slice(&event_bytes)
            .map_err(|e| Error::msg(format!("invalid trigger request: {e}")))?;

        let payload_bytes = serde_json::to_vec(&request.payload).unwrap_or_default();

        // Fan out: find all capsules with interceptors matching the hook topic,
        // invoke each (skipping the caller to prevent infinite recursion),
        // and collect their responses.
        //
        // Step 1: Collect matching capsules under the registry read lock.
        // This happens inside block_in_place → block_on so we can acquire
        // the async RwLock, but we do NOT call invoke_interceptor here
        // (which itself does block_in_place and would panic if nested).
        let matches: Vec<(std::sync::Arc<dyn crate::capsule::Capsule>, String)> =
            util::bounded_block_on(&rt_handle, &host_semaphore, async {
                let registry = registry.read().await;
                let mut matches = Vec::new();

                for capsule_id in registry.list() {
                    // Skip the calling capsule to prevent recursion.
                    if *capsule_id == caller_id {
                        continue;
                    }
                    if let Some(capsule) = registry.get(capsule_id) {
                        if !matches!(capsule.state(), crate::capsule::CapsuleState::Ready) {
                            continue;
                        }
                        for interceptor in &capsule.manifest().interceptors {
                            if crate::topic::topic_matches(&request.hook, &interceptor.event) {
                                matches.push((
                                    std::sync::Arc::clone(&capsule),
                                    interceptor.action.clone(),
                                ));
                            }
                        }
                    }
                }
                matches
                // Read lock dropped here.
            });

        // Step 2: Dispatch each interceptor via spawned tasks and collect
        // results. Each invoke_interceptor call may use block_in_place
        // internally, which is safe because it runs in its own spawned task
        // (not nested inside our block_on).
        let responses: Vec<serde_json::Value> =
            util::bounded_block_on(&rt_handle, &host_semaphore, async {
                let mut join_set = tokio::task::JoinSet::new();

                for (capsule, action) in matches {
                    let payload = payload_bytes.clone();
                    let hook = request.hook.clone();
                    join_set.spawn(async move {
                        match capsule.invoke_interceptor(&action, &payload, None) {
                            Ok(crate::capsule::InterceptResult::Continue(bytes))
                                if bytes.is_empty() =>
                            {
                                None
                            },
                            Ok(
                                crate::capsule::InterceptResult::Continue(bytes)
                                | crate::capsule::InterceptResult::Final(bytes),
                            ) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                                Ok(val) => Some(val),
                                Err(_) => {
                                    tracing::warn!(
                                        capsule_id = %capsule.id(),
                                        action = %action,
                                        "interceptor returned non-JSON response, skipping"
                                    );
                                    None
                                },
                            },
                            Ok(crate::capsule::InterceptResult::Deny { reason }) => {
                                tracing::warn!(
                                    capsule_id = %capsule.id(),
                                    action = %action,
                                    hook = %hook,
                                    reason = %reason,
                                    "interceptor denied during hook trigger"
                                );
                                None
                            },
                            Err(e) => {
                                tracing::warn!(
                                    capsule_id = %capsule.id(),
                                    action = %action,
                                    hook = %hook,
                                    error = %e,
                                    "interceptor invocation failed during hook trigger"
                                );
                                None
                            },
                        }
                    });
                }

                let mut responses = Vec::new();
                while let Some(result) = join_set.join_next().await {
                    if let Ok(Some(val)) = result {
                        responses.push(val);
                    }
                }
                responses
            });

        match serde_json::to_vec(&responses) {
            Ok(bytes) => bytes,
            Err(e) => {
                tracing::warn!(error = %e, "failed to serialize hook responses");
                b"[]".to_vec()
            },
        }
    } else {
        // No registry available — return empty array (no subscribers).
        b"[]".to_vec()
    };

    let mem = plugin.memory_new(&result_bytes)?;
    outputs[0] = plugin.memory_to_val(mem);
    Ok(())
}

/// Request payload for cross-capsule capability checks.
///
/// Deserialized from the guest-supplied JSON bytes in
/// `astrid_check_capsule_capability_impl`.
#[derive(serde::Deserialize)]
struct CapabilityCheckRequest {
    /// The UUID of the capsule whose capability is being queried.
    ///
    /// Kept as text here and parsed with `uuid::Uuid::parse_str` by the
    /// handler; a value that fails to parse results in a denial.
    source_uuid: String,
    /// The capability to check (e.g. `"allow_prompt_injection"`).
    ///
    /// The handler currently recognises only `"allow_prompt_injection"`;
    /// any other name is denied.
    capability: String,
}

/// Report whether the capsule behind a session UUID declares a given
/// manifest capability.
///
/// Input: JSON `{"source_uuid": "...", "capability": "allow_prompt_injection"}`
/// (at most 1024 bytes).
/// Output: JSON `{"allowed": true/false}`, written to guest memory with the
/// handle stored in `outputs[0]`.
///
/// The check fails closed: the answer is `{"allowed": false}` when the
/// registry is unavailable, the UUID is malformed or unknown, the capsule
/// cannot be fetched, or the capability name is not recognised.
///
/// # Errors
///
/// Returns an error if the input cannot be read from guest memory or is not
/// valid JSON for the request, if the user data handle cannot be obtained, if
/// the host state lock is poisoned, or if allocating the result in guest
/// memory fails.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_check_capsule_capability_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let raw_request = util::get_safe_bytes(plugin, &inputs[0], 1024)?;
    let request: CapabilityCheckRequest = serde_json::from_slice(&raw_request)
        .map_err(|e| Error::msg(format!("invalid capability check request: {e}")))?;

    // Snapshot the handles we need, then release the host state lock before
    // blocking on the registry.
    let handle = user_data.get()?;
    let (registry, rt_handle, host_semaphore) = {
        let guard = handle
            .lock()
            .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
        (
            guard.capsule_registry.clone(),
            guard.runtime_handle.clone(),
            guard.host_semaphore.clone(),
        )
    };

    // No registry means no answer other than "denied".
    let allowed = registry.is_some_and(|registry| {
        let Ok(source_uuid) = uuid::Uuid::parse_str(&request.source_uuid) else {
            tracing::debug!(
                uuid = %request.source_uuid,
                "Malformed UUID in capability check, denying"
            );
            return false;
        };

        util::bounded_block_on(&rt_handle, &host_semaphore, async {
            let reg = registry.read().await;
            let Some(capsule_id) = reg.find_by_uuid(&source_uuid) else {
                tracing::debug!(
                    uuid = %source_uuid,
                    capability = %request.capability,
                    "UUID not found in registry, denying capability"
                );
                return false;
            };
            let Some(capsule) = reg.get(capsule_id) else {
                return false;
            };

            if request.capability == "allow_prompt_injection" {
                capsule.manifest().capabilities.allow_prompt_injection
            } else {
                tracing::warn!(
                    capability = %request.capability,
                    "Unknown capability requested, denying"
                );
                false
            }
        })
    });

    let verdict = serde_json::json!({"allowed": allowed}).to_string();
    let handle_in_guest = plugin.memory_new(&verdict)?;
    outputs[0] = plugin.memory_to_val(handle_in_guest);
    Ok(())
}