use std::fmt;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::context::CapsuleContext;
use crate::error::{CapsuleError, CapsuleResult};
use crate::manifest::CapsuleManifest;
use crate::tool::CapsuleTool;
/// Identifier of a capsule.
///
/// Ids produced by [`CapsuleId::new`] and by deserialization are validated: non-empty and
/// made only of ASCII lowercase letters, ASCII digits, and `-`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize)]
pub struct CapsuleId(String);
impl<'de> Deserialize<'de> for CapsuleId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Self::new(s).map_err(serde::de::Error::custom)
}
}
impl CapsuleId {
    /// Creates a validated capsule id.
    ///
    /// # Errors
    ///
    /// Returns an error if `id` is empty or contains any character other than an ASCII
    /// lowercase letter, an ASCII digit, or `-`.
    pub fn new(id: impl Into<String>) -> CapsuleResult<Self> {
        let id = id.into();
        Self::validate(&id)?;
        Ok(Self(id))
    }

    /// Creates a capsule id from a string the caller asserts is valid, without a `Result`.
    ///
    /// Intended for ids written directly in source code. The same validation as
    /// [`CapsuleId::new`] is checked in debug builds. A typo in a hard-coded id is therefore
    /// caught early, instead of producing a `CapsuleId` that `new` (and deserialization)
    /// would reject.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `id` fails validation. Release builds do not validate.
    #[must_use]
    pub fn from_static(id: &str) -> Self {
        debug_assert!(
            Self::validate(id).is_ok(),
            "CapsuleId::from_static called with invalid id: {id:?}"
        );
        Self(id.to_string())
    }

    /// Borrows the id as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }

    /// Checks that `id` is non-empty and contains only `[a-z0-9-]`.
    ///
    /// NOTE(review): errors are reported as `CapsuleError::UnsupportedEntryPoint`, which does
    /// not describe an invalid id; a dedicated variant would be clearer — confirm against
    /// `crate::error` before changing, since callers may match on it.
    fn validate(id: &str) -> CapsuleResult<()> {
        if id.is_empty() {
            return Err(CapsuleError::UnsupportedEntryPoint(
                "capsule id must not be empty".into(),
            ));
        }
        if !id
            .chars()
            .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
        {
            return Err(CapsuleError::UnsupportedEntryPoint(format!(
                "capsule id must contain only lowercase alphanumeric characters and hyphens, got: {id}"
            )));
        }
        Ok(())
    }
}
/// Displays the raw id string (formatter width/fill flags are not applied).
impl fmt::Display for CapsuleId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Lets a [`CapsuleId`] be passed wherever `impl AsRef<str>` is accepted.
impl AsRef<str> for CapsuleId {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}
/// Lifecycle state of a [`Capsule`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CapsuleState {
    /// Not loaded.
    Unloaded,
    /// A load is in progress.
    Loading,
    /// Loaded successfully.
    Ready,
    /// Loading failed; carries a human-readable error message.
    Failed(String),
    /// An unload is in progress.
    Unloading,
}
/// A loadable unit of functionality identified by a [`CapsuleId`].
///
/// Implementors must be `Send + Sync`. The lifecycle is driven through [`Capsule::load`] and
/// [`Capsule::unload`], and [`Capsule::state`] reports the current [`CapsuleState`].
#[async_trait]
pub trait Capsule: Send + Sync {
    /// The capsule's identifier.
    fn id(&self) -> &CapsuleId;

    /// The capsule's manifest.
    fn manifest(&self) -> &CapsuleManifest;

    /// The current lifecycle state, returned by value.
    fn state(&self) -> CapsuleState;

    /// Loads the capsule using `ctx`.
    async fn load(&mut self, ctx: &CapsuleContext) -> CapsuleResult<()>;

    /// Unloads the capsule.
    async fn unload(&mut self) -> CapsuleResult<()>;

    /// Tools exposed by the capsule; the default implementation exposes none.
    fn tools(&self) -> &[std::sync::Arc<dyn CapsuleTool>] {
        &[]
    }

    /// Takes the receiver for inbound messages, if the capsule has one.
    ///
    /// The default implementation returns `None`.
    fn take_inbound_rx(
        &mut self,
    ) -> Option<tokio::sync::mpsc::Receiver<astrid_core::InboundMessage>> {
        None
    }
}
/// A capsule assembled from one or more execution engines.
///
/// Its tool list is the concatenation of the tools of every engine, collected on load.
pub struct CompositeCapsule {
    /// Id derived from the manifest's package name.
    id: CapsuleId,
    /// The manifest this capsule was constructed with.
    manifest: CapsuleManifest,
    /// Current lifecycle state.
    state: CapsuleState,
    /// Engines, loaded and unloaded in insertion order.
    engines: Vec<Box<dyn crate::engine::ExecutionEngine>>,
    /// Tools gathered from the engines during the most recent load.
    tools: Vec<std::sync::Arc<dyn CapsuleTool>>,
}
impl CompositeCapsule {
    /// Creates an unloaded capsule with no engines. Its id comes from
    /// `manifest.package.name`.
    ///
    /// # Errors
    ///
    /// Fails when the package name is not a valid [`CapsuleId`].
    pub fn new(manifest: CapsuleManifest) -> CapsuleResult<Self> {
        let package_name = manifest.package.name.clone();
        CapsuleId::new(package_name).map(|id| Self {
            id,
            manifest,
            state: CapsuleState::Unloaded,
            engines: Vec::default(),
            tools: Vec::default(),
        })
    }

    /// Appends an engine.
    ///
    /// Its tools are only picked up by the next call to `load`.
    pub fn add_engine(&mut self, engine: Box<dyn crate::engine::ExecutionEngine>) {
        self.engines.push(engine);
    }
}
#[async_trait]
impl Capsule for CompositeCapsule {
    fn id(&self) -> &CapsuleId {
        &self.id
    }

    fn manifest(&self) -> &CapsuleManifest {
        &self.manifest
    }

    fn state(&self) -> CapsuleState {
        self.state.clone()
    }

    /// Loads every engine in insertion order and collects their tools.
    ///
    /// Engines after a failing engine are never loaded. On failure:
    /// - engines that had already loaded are unloaded again, best effort and in reverse order;
    /// - the collected tools are discarded;
    /// - the capsule moves to [`CapsuleState::Failed`] with the engine's error message.
    ///
    /// # Errors
    ///
    /// Returns the error from the first engine whose `load` fails.
    async fn load(&mut self, ctx: &CapsuleContext) -> CapsuleResult<()> {
        self.state = CapsuleState::Loading;
        self.tools.clear();

        let mut failure = None;
        for (index, engine) in self.engines.iter_mut().enumerate() {
            if let Err(e) = engine.load(ctx).await {
                failure = Some((index, e));
                break;
            }
            self.tools.extend_from_slice(engine.tools());
        }

        if let Some((failed_at, err)) = failure {
            // Pair every successful `load` with an `unload` so a failed capsule does not keep
            // partially-loaded engines around. Rollback errors are secondary to the load
            // error being reported, so they are deliberately ignored.
            for engine in self.engines[..failed_at].iter_mut().rev() {
                let _ = engine.unload().await;
            }
            self.tools.clear();
            self.state = CapsuleState::Failed(err.to_string());
            return Err(err);
        }

        self.state = CapsuleState::Ready;
        Ok(())
    }

    /// Unloads every engine, then drops the collected tools.
    ///
    /// Every engine is asked to unload even if an earlier one fails. The capsule always ends
    /// in [`CapsuleState::Unloaded`].
    ///
    /// # Errors
    ///
    /// Returns the first error reported by an engine's `unload`, after all engines have been
    /// processed.
    async fn unload(&mut self) -> CapsuleResult<()> {
        self.state = CapsuleState::Unloading;

        let mut first_error = None;
        for engine in &mut self.engines {
            if let Err(e) = engine.unload().await {
                // Keep going so the remaining engines still get unloaded; remember only the
                // first failure for the caller.
                if first_error.is_none() {
                    first_error = Some(e);
                }
            }
        }

        self.tools.clear();
        self.state = CapsuleState::Unloaded;
        match first_error {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }

    fn tools(&self) -> &[std::sync::Arc<dyn CapsuleTool>] {
        &self.tools
    }

    /// Returns the inbound receiver of the first engine that yields one.
    ///
    /// Engines after that one are not asked.
    fn take_inbound_rx(
        &mut self,
    ) -> Option<tokio::sync::mpsc::Receiver<astrid_core::InboundMessage>> {
        self.engines
            .iter_mut()
            .find_map(|engine| engine.take_inbound_rx())
    }
}