use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::capsule::CapsuleId;
use astrid_core::connector::{ConnectorDescriptor, InboundMessage, MAX_CONNECTORS_PER_PLUGIN};
use astrid_storage::ScopedKvStore;
use crate::security::CapsuleSecurityGate;
/// Host-side state bundle associated with a single capsule.
///
/// `Debug` is implemented by hand further down in this file and prints only a
/// summary of these fields rather than every one of them.
// NOTE(review): the per-field descriptions below are derived from the field
// types and from how this file uses them; hedged where the semantics are not
// visible here.
pub struct HostState {
/// Capsule identifier; printed by the `Debug` impl.
pub capsule_id: CapsuleId,
/// UUID stored alongside `capsule_id` (presumably a per-instance id — confirm with callers).
pub capsule_uuid: uuid::Uuid,
/// Workspace root path; printed by the `Debug` impl.
pub workspace_root: PathBuf,
/// Shared, dynamically-dispatched virtual filesystem implementation.
pub vfs: Arc<dyn astrid_vfs::Vfs>,
/// Directory handle (presumably for the root of `vfs` — confirm); printed by the `Debug` impl.
pub vfs_root_handle: astrid_capabilities::DirHandle,
/// Optional shared temporary directory.
// NOTE(review): the name suggests an overlay "upper" layer — confirm with the VFS setup code.
pub upper_dir: Option<Arc<tempfile::TempDir>>,
/// Scoped key-value store.
pub kv: ScopedKvStore,
/// Event bus instance.
pub event_bus: astrid_events::EventBus,
/// IPC rate limiter instance.
pub ipc_limiter: astrid_events::ipc::IpcRateLimiter,
/// Event receivers keyed by a `u64` id.
pub subscriptions: HashMap<u64, astrid_events::EventReceiver>,
/// Presumably the id to assign to the next entry in `subscriptions` — confirm with callers.
pub next_subscription_id: u64,
/// Configuration entries: string keys mapped to JSON values.
pub config: HashMap<String, serde_json::Value>,
/// Optional security gate; the `Debug` impl reports only whether it is set.
pub security: Option<Arc<dyn CapsuleSecurityGate>>,
/// Handle to a Tokio runtime.
pub runtime_handle: tokio::runtime::Handle,
/// Connector capability flag.
// NOTE(review): `register_connector` in this file does not consult this flag;
// presumably callers enforce it before registering — confirm.
pub has_connector_capability: bool,
/// Optional sender for inbound messages; populated by `set_inbound_tx`.
pub inbound_tx: Option<mpsc::Sender<InboundMessage>>,
/// Connectors accepted so far by `register_connector`, in registration order.
pub registered_connectors: Vec<ConnectorDescriptor>,
}
impl HostState {
    /// Appends `descriptor` to the list of registered connectors.
    ///
    /// # Errors
    ///
    /// Returns a static message, leaving the list untouched, when either:
    ///
    /// - [`MAX_CONNECTORS_PER_PLUGIN`] connectors are already registered
    ///   (this is checked first), or
    /// - a previously registered connector has the same `name` *and* the same
    ///   `frontend_type` as `descriptor`.
    pub fn register_connector(
        &mut self,
        descriptor: ConnectorDescriptor,
    ) -> Result<(), &'static str> {
        let at_capacity = self.registered_connectors.len() >= MAX_CONNECTORS_PER_PLUGIN;
        if at_capacity {
            return Err("connector registration limit reached");
        }

        // The same name may be reused across different frontend types; only
        // an exact (name, frontend_type) match counts as a duplicate.
        for known in &self.registered_connectors {
            if known.name == descriptor.name && known.frontend_type == descriptor.frontend_type {
                return Err("duplicate connector name and platform");
            }
        }

        self.registered_connectors.push(descriptor);
        Ok(())
    }

    /// Returns every connector registered so far, in registration order.
    #[must_use]
    pub fn connectors(&self) -> &[ConnectorDescriptor] {
        self.registered_connectors.as_slice()
    }

    /// Stores `tx` as the inbound-message sender, replacing any previous one.
    pub fn set_inbound_tx(&mut self, tx: mpsc::Sender<InboundMessage>) {
        let sender = Some(tx);
        self.inbound_tx = sender;
    }
}
/// Summary-style `Debug` output: identifying fields are printed as-is, while
/// the security gate, the inbound sender and the connector list are reduced
/// to presence flags / a count. Remaining fields are elided (`..`).
impl std::fmt::Debug for HostState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_security = self.security.is_some();
        let has_inbound_tx = self.inbound_tx.is_some();
        let connector_count = self.registered_connectors.len();

        let mut out = f.debug_struct("HostState");
        out.field("capsule_id", &self.capsule_id);
        out.field("workspace_root", &self.workspace_root);
        out.field("vfs_root_handle", &self.vfs_root_handle);
        out.field("has_security", &has_security);
        out.field("has_connector_capability", &self.has_connector_capability);
        out.field("has_inbound_tx", &has_inbound_tx);
        out.field("registered_connectors", &connector_count);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use astrid_core::connector::{ConnectorCapabilities, ConnectorProfile, ConnectorSource};
    use astrid_core::identity::FrontendType;

    /// Builds the current-thread runtime whose handle the fixture state stores.
    ///
    /// Each test keeps the returned runtime alive for its whole body so the
    /// handle held by the state never outlives it.
    fn test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
    }

    /// Builds a `HostState` fixture backed by in-memory stores.
    ///
    /// Only `has_connector_capability` varies between tests; every other
    /// field uses the same minimal value.
    fn test_state(rt: &tokio::runtime::Runtime, has_connector_capability: bool) -> HostState {
        let store = Arc::new(astrid_storage::MemoryKvStore::new());
        let kv = ScopedKvStore::new(store, "plugin:test").unwrap();
        HostState {
            capsule_uuid: uuid::Uuid::new_v4(),
            capsule_id: CapsuleId::from_static("test"),
            workspace_root: PathBuf::from("/tmp"),
            vfs: Arc::new(astrid_vfs::HostVfs::new()),
            vfs_root_handle: astrid_capabilities::DirHandle::new(),
            upper_dir: None,
            kv,
            event_bus: astrid_events::EventBus::with_capacity(128),
            ipc_limiter: astrid_events::ipc::IpcRateLimiter::new(),
            subscriptions: HashMap::new(),
            next_subscription_id: 1,
            config: HashMap::new(),
            security: None,
            runtime_handle: rt.handle().clone(),
            has_connector_capability,
            inbound_tx: None,
            registered_connectors: Vec::new(),
        }
    }

    /// Builds a receive-only, chat-profile WASM connector descriptor with the
    /// given name and frontend type.
    fn wasm_descriptor(name: &str, frontend: FrontendType) -> ConnectorDescriptor {
        // Pass an owned `String`: the builder is known to accept one (the
        // original tests passed `format!(..)` results directly).
        ConnectorDescriptor::builder(name.to_owned(), frontend)
            .source(ConnectorSource::Wasm {
                capsule_id: "test".into(),
            })
            .capabilities(ConnectorCapabilities::receive_only())
            .profile(ConnectorProfile::Chat)
            .build()
    }

    #[test]
    fn host_state_debug_format() {
        let rt = test_runtime();
        let state = test_state(&rt, false);

        let debug = format!("{state:?}");
        assert!(debug.contains("test"));
        assert!(debug.contains("has_security"));
        assert!(debug.contains("has_inbound_tx"));
        assert!(debug.contains("registered_connectors"));
    }

    #[test]
    fn register_connector_accumulates() {
        let rt = test_runtime();
        let mut state = test_state(&rt, true);
        assert!(state.connectors().is_empty());

        let desc = wasm_descriptor("test-conn", FrontendType::Discord);
        state.register_connector(desc).unwrap();

        assert_eq!(state.connectors().len(), 1);
        assert_eq!(state.connectors()[0].name, "test-conn");
    }

    #[test]
    fn set_inbound_tx_stores_sender() {
        let rt = test_runtime();
        let mut state = test_state(&rt, false);
        assert!(state.inbound_tx.is_none());

        let (tx, _rx) = mpsc::channel(256);
        state.set_inbound_tx(tx);
        assert!(state.inbound_tx.is_some());
    }

    #[test]
    fn register_connector_rejects_at_limit() {
        let rt = test_runtime();
        let mut state = test_state(&rt, true);

        // Fill the state up to the limit with distinct names.
        for i in 0..MAX_CONNECTORS_PER_PLUGIN {
            let desc = wasm_descriptor(&format!("conn-{i}"), FrontendType::Discord);
            assert!(state.register_connector(desc).is_ok());
        }
        assert_eq!(state.connectors().len(), MAX_CONNECTORS_PER_PLUGIN);

        // One more must be rejected — specifically by the limit check — and
        // must leave the list untouched.
        let extra = wasm_descriptor("over-limit", FrontendType::Discord);
        let err = state.register_connector(extra).unwrap_err();
        assert!(err.contains("limit"), "expected limit error: {err}");
        assert_eq!(state.connectors().len(), MAX_CONNECTORS_PER_PLUGIN);
    }

    #[test]
    fn register_connector_rejects_duplicate_name_and_platform() {
        let rt = test_runtime();
        let mut state = test_state(&rt, true);

        let first = wasm_descriptor("my-conn", FrontendType::Discord);
        assert!(state.register_connector(first).is_ok());

        // Same name + same platform: rejected as a duplicate.
        let same_platform = wasm_descriptor("my-conn", FrontendType::Discord);
        let err = state.register_connector(same_platform).unwrap_err();
        assert!(err.contains("duplicate"), "expected duplicate error: {err}");

        // Same name + different platform: accepted.
        let other_platform = wasm_descriptor("my-conn", FrontendType::Telegram);
        assert!(state.register_connector(other_platform).is_ok());
        assert_eq!(state.connectors().len(), 2);
    }
}