use astrid_core::capsule_abi::LogLevel;
use extism::{CurrentPlugin, Error, UserData, Val};
use crate::engine::wasm::host::util;
use crate::engine::wasm::host_state::HostState;
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_log_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
_outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let level_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], 64)?;
let message_bytes: Vec<u8> =
util::get_safe_bytes(plugin, &inputs[1], util::MAX_LOG_MESSAGE_LEN)?;
let level = String::from_utf8_lossy(&level_bytes).to_string();
let message = String::from_utf8_lossy(&message_bytes).to_string();
let ud = user_data.get()?;
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let capsule_id = state.capsule_id.as_str().to_owned();
drop(state);
let parsed_level: LogLevel = match level.to_lowercase().as_str() {
"trace" => LogLevel::Trace,
"debug" => LogLevel::Debug,
"warn" | "warning" => LogLevel::Warn,
"error" | "err" => LogLevel::Error,
_ => LogLevel::Info,
};
match parsed_level {
LogLevel::Trace => tracing::trace!(plugin = %capsule_id, "{message}"),
LogLevel::Debug => tracing::debug!(plugin = %capsule_id, "{message}"),
LogLevel::Info => tracing::info!(plugin = %capsule_id, "{message}"),
LogLevel::Warn => tracing::warn!(plugin = %capsule_id, "{message}"),
LogLevel::Error => tracing::error!(plugin = %capsule_id, "{message}"),
}
Ok(())
}
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_get_config_impl(
plugin: &mut CurrentPlugin,
inputs: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let key_bytes: Vec<u8> = util::get_safe_bytes(plugin, &inputs[0], util::MAX_KEY_LEN)?;
let key = String::from_utf8_lossy(&key_bytes).to_string();
let ud = user_data.get()?;
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let value = state.config.get(&key).cloned();
drop(state);
let result = match value {
Some(v) => serde_json::to_string(&v).unwrap_or_default(),
None => String::new(),
};
let mem = plugin.memory_new(&result)?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_get_caller_impl(
plugin: &mut CurrentPlugin,
_inputs: &[Val],
outputs: &mut [Val],
user_data: UserData<HostState>,
) -> Result<(), Error> {
let ud = user_data.get()?;
let state = ud
.lock()
.map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
let result = if let Some(_msg) = &state.caller_context {
let session_id = None::<String>; let user_id = None::<String>; serde_json::json!({
"session_id": session_id,
"user_id": user_id
})
.to_string()
} else {
String::from("{}")
};
drop(state);
let mem = plugin.memory_new(&result)?;
outputs[0] = plugin.memory_to_val(mem);
Ok(())
}
/// Host function: lets a capsule announce that it has finished starting up.
///
/// If the host state holds a readiness sender, `true` is sent on it and a
/// debug event is emitted. The send result is deliberately ignored
/// (best-effort). Without a sender this is a no-op.
///
/// # Errors
/// Fails if the host state cannot be obtained or its mutex is poisoned.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_signal_ready_impl(
    _plugin: &mut CurrentPlugin,
    _inputs: &[Val],
    _outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let shared = user_data.get()?;
    let guard = shared
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;

    let Some(ready_tx) = guard.ready_tx.as_ref() else {
        return Ok(());
    };

    let _ = ready_tx.send(true);
    tracing::debug!(capsule = %guard.capsule_id, "Capsule signaled ready");
    Ok(())
}
/// Host function: returns the current wall-clock time to the guest.
///
/// `outputs[0]` receives the number of milliseconds since the Unix epoch,
/// rendered as a decimal string. A system clock set before the epoch yields
/// `0`. A value too large for `u64` saturates to `u64::MAX`.
///
/// # Errors
/// Fails only if guest memory cannot be allocated.
pub(crate) fn astrid_clock_ms_impl(
    plugin: &mut CurrentPlugin,
    _inputs: &[Val],
    outputs: &mut [Val],
    _user_data: UserData<HostState>,
) -> Result<(), Error> {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    let millis: u64 = match since_epoch {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    };

    let text = millis.to_string();
    let handle = plugin.memory_new(&text)?;
    outputs[0] = plugin.memory_to_val(handle);
    Ok(())
}
/// JSON request body read by `astrid_trigger_hook_impl`.
#[derive(serde::Deserialize)]
struct TriggerRequest {
    /// Hook name. It is compared against each interceptor's `event` pattern
    /// with `crate::dispatcher::topic_matches`.
    hook: String,
    /// Arbitrary JSON, re-serialized and passed to every matching interceptor.
    payload: serde_json::Value,
}
/// Host function: fans a hook event out to other capsules' interceptors.
///
/// `inputs[0]` holds a JSON-encoded [`TriggerRequest`], read with a 1 MiB cap.
/// Every *other* capsule in the registry that is in the `Ready` state and
/// declares an interceptor whose `event` matches `hook` is invoked with the
/// re-serialized `payload`. The calling capsule is always skipped.
///
/// `outputs[0]` receives a JSON array of the interceptors' JSON responses, in
/// match-discovery order (registry listing order, then manifest interceptor
/// order). Some responses are dropped and logged rather than returned:
/// - empty responses are dropped;
/// - non-JSON responses are dropped and logged;
/// - failed invocations are dropped and logged;
/// - panicked invocations are dropped and logged.
///
/// When no capsule registry is attached, the output is `[]` and the request
/// is not parsed.
///
/// # Errors
/// Fails if guest memory cannot be read or allocated, the host state cannot
/// be obtained or its mutex is poisoned, or (when a registry is attached) the
/// request is not a valid `TriggerRequest`.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_trigger_hook_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let event_bytes = util::get_safe_bytes(plugin, &inputs[0], 1024 * 1024)?;
    let ud = user_data.get()?;
    let state = ud
        .lock()
        .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
    let caller_id = state.capsule_id.clone();
    let registry = state.capsule_registry.clone();
    let rt_handle = state.runtime_handle.clone();
    let host_semaphore = state.host_semaphore.clone();
    // Release the host-state lock before touching the registry or invoking
    // other capsules.
    drop(state);

    let result_bytes = if let Some(registry) = registry {
        let request: TriggerRequest = serde_json::from_slice(&event_bytes)
            .map_err(|e| Error::msg(format!("invalid trigger request: {e}")))?;
        let payload_bytes = serde_json::to_vec(&request.payload).unwrap_or_default();

        // Snapshot the matching interceptors under the registry read lock. The
        // lock is released when this block finishes, before any interceptor runs.
        let matches: Vec<(std::sync::Arc<dyn crate::capsule::Capsule>, String)> =
            util::bounded_block_on(&rt_handle, &host_semaphore, async {
                let registry = registry.read().await;
                let mut matches = Vec::new();
                for capsule_id in registry.list() {
                    if *capsule_id == caller_id {
                        continue;
                    }
                    if let Some(capsule) = registry.get(capsule_id) {
                        if !matches!(capsule.state(), crate::capsule::CapsuleState::Ready) {
                            continue;
                        }
                        for interceptor in &capsule.manifest().interceptors {
                            if crate::dispatcher::topic_matches(&request.hook, &interceptor.event) {
                                matches.push((
                                    std::sync::Arc::clone(&capsule),
                                    interceptor.action.clone(),
                                ));
                            }
                        }
                    }
                }
                matches
            });

        let responses: Vec<serde_json::Value> =
            util::bounded_block_on(&rt_handle, &host_semaphore, async {
                // `invoke_interceptor` is a synchronous call, so each invocation runs
                // on tokio's blocking pool rather than stalling an async worker thread.
                // Handles are awaited in match order, so the response array is
                // deterministic instead of following completion order.
                let handles: Vec<_> = matches
                    .into_iter()
                    .map(|(capsule, action)| {
                        let payload = payload_bytes.clone();
                        let hook = request.hook.clone();
                        tokio::task::spawn_blocking(move || {
                            match capsule.invoke_interceptor(&action, &payload) {
                                Ok(bytes) if bytes.is_empty() => None,
                                Ok(bytes) => {
                                    match serde_json::from_slice::<serde_json::Value>(&bytes) {
                                        Ok(val) => Some(val),
                                        Err(_) => {
                                            tracing::warn!(
                                                capsule_id = %capsule.id(),
                                                action = %action,
                                                hook = %hook,
                                                "interceptor returned non-JSON response, skipping"
                                            );
                                            None
                                        },
                                    }
                                },
                                Err(e) => {
                                    tracing::warn!(
                                        capsule_id = %capsule.id(),
                                        action = %action,
                                        hook = %hook,
                                        error = %e,
                                        "interceptor invocation failed during hook trigger"
                                    );
                                    None
                                },
                            }
                        })
                    })
                    .collect();

                let mut responses = Vec::with_capacity(handles.len());
                for handle in handles {
                    match handle.await {
                        Ok(Some(val)) => responses.push(val),
                        Ok(None) => {},
                        // A panicking interceptor must not abort the whole fan-out,
                        // but it should not vanish silently either.
                        Err(e) => {
                            tracing::warn!(
                                hook = %request.hook,
                                error = %e,
                                "interceptor task did not complete during hook trigger"
                            );
                        },
                    }
                }
                responses
            });

        match serde_json::to_vec(&responses) {
            Ok(bytes) => bytes,
            Err(e) => {
                tracing::warn!(error = %e, "failed to serialize hook responses");
                b"[]".to_vec()
            },
        }
    } else {
        b"[]".to_vec()
    };

    let mem = plugin.memory_new(&result_bytes)?;
    outputs[0] = plugin.memory_to_val(mem);
    Ok(())
}
/// JSON request body read by `astrid_check_capsule_capability_impl`.
#[derive(serde::Deserialize)]
struct CapabilityCheckRequest {
    /// UUID (string form) of the capsule whose capability is queried. It is
    /// parsed with `uuid::Uuid::parse_str`, and a malformed value is denied.
    source_uuid: String,
    /// Capability name. Only `allow_prompt_injection` is recognised; any
    /// other name is denied.
    capability: String,
}
/// Host function: asks whether another capsule, identified by UUID, holds a
/// named capability.
///
/// `inputs[0]` holds a JSON-encoded `CapabilityCheckRequest`, read with a
/// 1024-byte cap. `outputs[0]` receives `{"allowed":true}` or
/// `{"allowed":false}`. The answer defaults to deny in each of these cases:
/// - no registry is attached;
/// - the UUID is malformed or unknown;
/// - the capsule cannot be fetched;
/// - the capability name is not recognised.
///
/// # Errors
/// Fails if guest memory cannot be read or allocated, the request is not
/// valid JSON of the expected shape, the host state cannot be obtained, or
/// its mutex is poisoned.
#[expect(clippy::needless_pass_by_value)]
pub(crate) fn astrid_check_capsule_capability_impl(
    plugin: &mut CurrentPlugin,
    inputs: &[Val],
    outputs: &mut [Val],
    user_data: UserData<HostState>,
) -> Result<(), Error> {
    let raw_request = util::get_safe_bytes(plugin, &inputs[0], 1024)?;
    let request: CapabilityCheckRequest = serde_json::from_slice(&raw_request)
        .map_err(|e| Error::msg(format!("invalid capability check request: {e}")))?;

    // Copy out the shared handles; the host-state lock ends with this block.
    let (maybe_registry, rt_handle, host_semaphore) = {
        let shared = user_data.get()?;
        let guard = shared
            .lock()
            .map_err(|e| Error::msg(format!("host state lock poisoned: {e}")))?;
        (
            guard.capsule_registry.clone(),
            guard.runtime_handle.clone(),
            guard.host_semaphore.clone(),
        )
    };

    let allowed = match maybe_registry {
        None => false,
        Some(registry) => match uuid::Uuid::parse_str(&request.source_uuid) {
            Err(_) => {
                tracing::debug!(
                    uuid = %request.source_uuid,
                    "Malformed UUID in capability check, denying"
                );
                false
            },
            Ok(source_uuid) => util::bounded_block_on(&rt_handle, &host_semaphore, async {
                let reg = registry.read().await;
                let found = match reg.find_by_uuid(&source_uuid) {
                    Some(capsule_id) => reg.get(capsule_id),
                    None => {
                        tracing::debug!(
                            uuid = %source_uuid,
                            capability = %request.capability,
                            "UUID not found in registry, denying capability"
                        );
                        return false;
                    },
                };
                let Some(capsule) = found else {
                    return false;
                };
                if request.capability == "allow_prompt_injection" {
                    capsule.manifest().capabilities.allow_prompt_injection
                } else {
                    tracing::warn!(
                        capability = %request.capability,
                        "Unknown capability requested, denying"
                    );
                    false
                }
            }),
        },
    };

    let verdict = serde_json::json!({"allowed": allowed}).to_string();
    let handle = plugin.memory_new(&verdict)?;
    outputs[0] = plugin.memory_to_val(handle);
    Ok(())
}