use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{Semaphore, mpsc, watch};
use tokio_util::sync::CancellationToken;
use crate::capsule::CapsuleId;
use astrid_core::uplink::{InboundMessage, MAX_UPLINKS_PER_CAPSULE, UplinkDescriptor};
use astrid_storage::ScopedKvStore;
use astrid_storage::secret::SecretStore;
/// Lifecycle hook a capsule may be executing.
///
/// Stored in `HostState::lifecycle_phase` as `Option<LifecyclePhase>`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LifecyclePhase {
    /// Install phase.
    Install,
    /// Upgrade phase.
    Upgrade,
}
/// An interceptor registered by a capsule, as tracked in
/// `HostState::interceptor_handles`.
///
/// Field order is significant: it determines the order of both the `Debug`
/// and the serialized output.
#[derive(Clone, Debug, serde::Serialize)]
pub struct InterceptorHandle {
    /// Numeric identifier of this interceptor handle.
    pub handle_id: u64,
    /// Action name associated with the interceptor.
    pub action: String,
    /// Topic the interceptor is attached to.
    pub topic: String,
}
use crate::engine::wasm::host::process::ProcessTracker;
use crate::security::CapsuleSecurityGate;
/// State carried for a single capsule (identified by `capsule_id`).
///
/// Bundles the capsule's identity, filesystem views, storage, IPC plumbing,
/// uplink registrations and handles to background resources.
///
/// Field declaration order also fixes drop order, so fields should not be
/// reordered casually (e.g. `overlay_vfs` is declared before `upper_dir`).
pub struct HostState {
    /// Identifier of the capsule this state belongs to.
    pub capsule_id: CapsuleId,
    /// IPC message associated with the current invocation, if any.
    /// NOTE(review): inferred from name and type; confirm semantics with callers.
    pub caller_context: Option<astrid_events::ipc::IpcMessage>,
    /// UUID associated with this capsule.
    pub capsule_uuid: uuid::Uuid,
    /// Host path of the capsule's workspace root.
    pub workspace_root: PathBuf,
    /// Virtual filesystem used for workspace access.
    pub vfs: Arc<dyn astrid_vfs::Vfs>,
    /// Capability handle for the root of `vfs`.
    pub vfs_root_handle: astrid_capabilities::DirHandle,
    /// Optional global (non-workspace) root path.
    pub global_root: Option<PathBuf>,
    /// Virtual filesystem for the global root, when configured.
    pub global_vfs: Option<Arc<dyn astrid_vfs::Vfs>>,
    /// Capability handle for the root of `global_vfs`, when configured.
    pub global_vfs_root_handle: Option<astrid_capabilities::DirHandle>,
    /// Optional overlay filesystem.
    pub overlay_vfs: Option<Arc<astrid_vfs::OverlayVfs>>,
    /// Temporary directory kept alive by this state; presumably the writable
    /// upper layer of `overlay_vfs` (TODO confirm with the constructor).
    pub upper_dir: Option<Arc<tempfile::TempDir>>,
    /// Key-value store scoped to this capsule.
    pub kv: ScopedKvStore,
    /// Event bus used for IPC.
    pub event_bus: astrid_events::EventBus,
    /// Rate limiter for IPC traffic.
    pub ipc_limiter: astrid_events::ipc::IpcRateLimiter,
    /// Active event subscriptions keyed by subscription id.
    pub subscriptions: HashMap<u64, astrid_events::EventReceiver>,
    /// Id for the next entry in `subscriptions` (presumably incremented on use).
    pub next_subscription_id: u64,
    /// Capsule configuration values keyed by name.
    pub config: HashMap<String, serde_json::Value>,
    /// Topic patterns for publishing (presumably an allow-list; verify).
    pub ipc_publish_patterns: Vec<String>,
    /// Topic patterns for subscribing (presumably an allow-list; verify).
    pub ipc_subscribe_patterns: Vec<String>,
    /// Security gate for this capsule, if installed.
    pub security: Option<Arc<dyn CapsuleSecurityGate>>,
    /// Type-erased hook manager; must be downcast to its concrete type to use.
    pub hook_manager: Option<Arc<dyn std::any::Any + Send + Sync>>,
    /// Shared capsule registry, if available.
    pub capsule_registry: Option<Arc<tokio::sync::RwLock<crate::registry::CapsuleRegistry>>>,
    /// Handle to the Tokio runtime.
    pub runtime_handle: tokio::runtime::Handle,
    /// Whether the capsule was granted the uplink capability.
    /// Note: `register_uplink` does not consult this flag itself.
    pub has_uplink_capability: bool,
    /// Sender for inbound messages; installed via `set_inbound_tx`.
    pub inbound_tx: Option<mpsc::Sender<InboundMessage>>,
    /// Uplinks registered by this capsule; `register_uplink` caps the count at
    /// `MAX_UPLINKS_PER_CAPSULE` and rejects duplicate (name, platform) pairs.
    pub registered_uplinks: Vec<UplinkDescriptor>,
    /// Unix-socket listener for CLI connections, if bound.
    pub cli_socket_listener: Option<Arc<tokio::sync::Mutex<tokio::net::UnixListener>>>,
    /// Open Unix-socket streams keyed by stream id.
    pub active_streams: HashMap<u64, Arc<tokio::sync::Mutex<tokio::net::UnixStream>>>,
    /// Id for the next entry in `active_streams` (presumably incremented on use).
    pub next_stream_id: u64,
    /// Lifecycle phase being run, if any (presumably `None` outside lifecycle hooks).
    pub lifecycle_phase: Option<LifecyclePhase>,
    /// Store for capsule secrets.
    pub secret_store: Arc<dyn SecretStore>,
    /// Readiness signal sender, if present.
    pub ready_tx: Option<watch::Sender<bool>>,
    /// Semaphore shared by host operations; see `default_host_semaphore`.
    pub host_semaphore: Arc<Semaphore>,
    /// Cancellation token for this capsule's work.
    pub cancel_token: CancellationToken,
    /// Session token, if one has been issued.
    pub session_token: Option<Arc<astrid_core::session_token::SessionToken>>,
    /// Interceptors registered by this capsule.
    pub interceptor_handles: Vec<InterceptorHandle>,
    /// Approval allowance store, if available.
    pub allowance_store: Option<Arc<astrid_approval::AllowanceStore>>,
    /// Identity store, if available.
    pub identity_store: Option<Arc<dyn astrid_storage::IdentityStore>>,
    /// Managed background processes keyed by process id.
    pub background_processes: HashMap<u64, crate::engine::wasm::host::process::ManagedProcess>,
    /// Id for the next entry in `background_processes` (presumably incremented on use).
    pub next_process_id: u64,
    /// Shared process tracker.
    pub process_tracker: Arc<ProcessTracker>,
}
impl HostState {
    /// Registers an uplink declared by this capsule.
    ///
    /// # Errors
    ///
    /// - `"uplink registration limit reached"` if `MAX_UPLINKS_PER_CAPSULE`
    ///   uplinks are already registered (this check runs first).
    /// - `"duplicate uplink name and platform"` if an uplink with the same
    ///   `name` and `platform` is already registered.
    ///
    /// On error the descriptor is dropped and the registered list is unchanged.
    pub fn register_uplink(&mut self, descriptor: UplinkDescriptor) -> Result<(), &'static str> {
        let registered = &self.registered_uplinks;
        if registered.len() >= MAX_UPLINKS_PER_CAPSULE {
            return Err("uplink registration limit reached");
        }

        // Two uplinks collide when both their name and their platform match.
        let same_identity = |existing: &UplinkDescriptor| {
            existing.name == descriptor.name && existing.platform == descriptor.platform
        };
        if registered.iter().any(same_identity) {
            return Err("duplicate uplink name and platform");
        }

        self.registered_uplinks.push(descriptor);
        Ok(())
    }

    /// Returns the uplinks registered so far, in registration order.
    #[must_use]
    pub fn uplinks(&self) -> &[UplinkDescriptor] {
        self.registered_uplinks.as_slice()
    }

    /// Builds the default semaphore for host operations.
    ///
    /// The permit count is the detected parallelism minus two, floored at two.
    /// If parallelism cannot be detected, two permits are used.
    #[must_use]
    pub fn default_host_semaphore() -> Arc<Semaphore> {
        // Subtracted from detected parallelism (presumably to leave headroom).
        const RESERVED: usize = 2;
        // Lower bound on the permit count, also used as the fallback.
        const MIN_PERMITS: usize = 2;

        let permits = match std::thread::available_parallelism() {
            Ok(cores) => cores.get().saturating_sub(RESERVED).max(MIN_PERMITS),
            Err(_) => MIN_PERMITS,
        };
        Arc::new(Semaphore::new(permits))
    }

    /// Installs the sender for inbound messages, replacing any previous one.
    pub fn set_inbound_tx(&mut self, tx: mpsc::Sender<InboundMessage>) {
        self.inbound_tx = Some(tx);
    }
}
// Hand-written Debug: prints identifiers, presence flags and counters only.
// Filesystem handles, stores, the session token and sockets are omitted, and
// `finish_non_exhaustive` marks the output with a trailing `..`.
impl std::fmt::Debug for HostState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_global_root = self.global_root.is_some();
        let has_security = self.security.is_some();
        let has_inbound_tx = self.inbound_tx.is_some();
        let uplink_count = self.registered_uplinks.len();
        let permits = self.host_semaphore.available_permits();
        let cancelled = self.cancel_token.is_cancelled();
        let has_identity_store = self.identity_store.is_some();

        f.debug_struct("HostState")
            .field("capsule_id", &self.capsule_id)
            .field("workspace_root", &self.workspace_root)
            .field("vfs_root_handle", &self.vfs_root_handle)
            .field("has_global_root", &has_global_root)
            .field("has_security", &has_security)
            .field("has_uplink_capability", &self.has_uplink_capability)
            .field("has_inbound_tx", &has_inbound_tx)
            .field("registered_uplinks", &uplink_count)
            .field("host_semaphore_permits", &permits)
            .field("cancel_token_cancelled", &cancelled)
            .field("has_identity_store", &has_identity_store)
            .field("process_tracker", &self.process_tracker)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use astrid_core::uplink::{UplinkCapabilities, UplinkProfile, UplinkSource};

    /// Builds a `HostState` with inert defaults, plus the runtime it points at.
    ///
    /// The runtime is returned so the caller keeps it alive while the state is
    /// in use: both `runtime_handle` and the secret store hold handles into it.
    /// Bind it as `_rt` (not `_`) so it is not dropped immediately.
    fn test_state(has_uplink_capability: bool) -> (tokio::runtime::Runtime, HostState) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("build current-thread runtime");
        let store = Arc::new(astrid_storage::MemoryKvStore::new());
        let kv = ScopedKvStore::new(store, "capsule:test").expect("create scoped KV store");
        let secret_store: Arc<dyn SecretStore> = Arc::new(astrid_storage::KvSecretStore::new(
            kv.clone(),
            rt.handle().clone(),
        ));
        let state = HostState {
            capsule_uuid: uuid::Uuid::new_v4(),
            caller_context: None,
            capsule_id: CapsuleId::from_static("test"),
            workspace_root: PathBuf::from("/tmp"),
            vfs: Arc::new(astrid_vfs::HostVfs::new()),
            vfs_root_handle: astrid_capabilities::DirHandle::new(),
            global_root: None,
            global_vfs: None,
            global_vfs_root_handle: None,
            overlay_vfs: None,
            upper_dir: None,
            kv,
            event_bus: astrid_events::EventBus::with_capacity(128),
            ipc_limiter: astrid_events::ipc::IpcRateLimiter::new(),
            subscriptions: HashMap::new(),
            next_subscription_id: 1,
            config: HashMap::new(),
            ipc_publish_patterns: Vec::new(),
            ipc_subscribe_patterns: Vec::new(),
            security: None,
            hook_manager: None,
            capsule_registry: None,
            runtime_handle: rt.handle().clone(),
            has_uplink_capability,
            inbound_tx: None,
            registered_uplinks: Vec::new(),
            cli_socket_listener: None,
            active_streams: HashMap::new(),
            next_stream_id: 1,
            lifecycle_phase: None,
            secret_store,
            ready_tx: None,
            host_semaphore: Arc::new(Semaphore::new(2)),
            cancel_token: CancellationToken::new(),
            session_token: None,
            interceptor_handles: Vec::new(),
            allowance_store: None,
            identity_store: None,
            background_processes: HashMap::new(),
            next_process_id: 1,
            process_tracker: Arc::new(ProcessTracker::new()),
        };
        (rt, state)
    }

    /// Builds a receive-only chat uplink descriptor sourced from the `test` capsule.
    fn wasm_chat_uplink(name: String, platform: &'static str) -> UplinkDescriptor {
        UplinkDescriptor::builder(name, platform)
            .source(UplinkSource::Wasm {
                capsule_id: "test".into(),
            })
            .capabilities(UplinkCapabilities::receive_only())
            .profile(UplinkProfile::Chat)
            .build()
    }

    #[test]
    fn host_state_debug_format() {
        let (_rt, state) = test_state(false);
        let debug = format!("{state:?}");
        assert!(debug.contains("test"));
        assert!(debug.contains("has_security"));
        assert!(debug.contains("has_inbound_tx"));
        assert!(debug.contains("registered_uplinks"));
    }

    #[test]
    fn register_uplink_accumulates() {
        let (_rt, mut state) = test_state(true);
        assert!(state.uplinks().is_empty());
        state
            .register_uplink(wasm_chat_uplink("test-conn".to_string(), "discord"))
            .unwrap();
        assert_eq!(state.uplinks().len(), 1);
        assert_eq!(state.uplinks()[0].name, "test-conn");
    }

    #[test]
    fn set_inbound_tx_stores_sender() {
        let (_rt, mut state) = test_state(false);
        assert!(state.inbound_tx.is_none());
        let (tx, _rx) = mpsc::channel(256);
        state.set_inbound_tx(tx);
        assert!(state.inbound_tx.is_some());
    }

    #[test]
    fn register_uplink_rejects_at_limit() {
        let (_rt, mut state) = test_state(true);
        for i in 0..MAX_UPLINKS_PER_CAPSULE {
            let desc = wasm_chat_uplink(format!("conn-{i}"), "discord");
            assert!(state.register_uplink(desc).is_ok());
        }
        assert_eq!(state.uplinks().len(), MAX_UPLINKS_PER_CAPSULE);

        let err = state
            .register_uplink(wasm_chat_uplink("over-limit".to_string(), "discord"))
            .unwrap_err();
        assert!(err.contains("limit"), "expected limit error: {err}");
        assert_eq!(state.uplinks().len(), MAX_UPLINKS_PER_CAPSULE);
    }

    #[test]
    fn register_uplink_rejects_duplicate_name_and_platform() {
        let (_rt, mut state) = test_state(true);
        assert!(state
            .register_uplink(wasm_chat_uplink("my-conn".to_string(), "discord"))
            .is_ok());

        let err = state
            .register_uplink(wasm_chat_uplink("my-conn".to_string(), "discord"))
            .unwrap_err();
        assert!(err.contains("duplicate"), "expected duplicate error: {err}");

        // Same name on a different platform is a distinct uplink.
        assert!(state
            .register_uplink(wasm_chat_uplink("my-conn".to_string(), "telegram"))
            .is_ok());
        assert_eq!(state.uplinks().len(), 2);
    }
}