Skip to main content

composable_runtime/composition/
registry.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::any::{Any, TypeId};
4use std::collections::{HashMap, HashSet};
5use std::marker::PhantomData;
6use std::path::PathBuf;
7use std::sync::Arc;
8use wasmtime::component::{HasData, Linker};
9
10use super::composer::Composer;
11use super::graph::{ComponentGraph, Edge, Node};
12use super::wit::{ComponentMetadata, Parser};
13use crate::types::{CapabilityDefinition, ComponentDefinition, ComponentState, Function};
14
/// Trait implemented by host capability instances.
///
/// An instance represents a configured capability (from one TOML block).
/// Multiple TOML blocks with the same type create multiple instances.
///
/// Instances of custom (non-`wasi:`) capability types are produced by a
/// [`HostCapabilityFactory`] registered under the block's type name. Implementors must be
/// `Send + Sync`.
pub trait HostCapability: Send + Sync {
    /// Fully qualified interfaces this capability provides (namespace:package/interface@version)
    ///
    /// These strings are what a component's imports are checked against when deciding
    /// whether every import is satisfied.
    fn interfaces(&self) -> Vec<String>;

    /// Add bindings to the linker. Called once per component instantiation.
    fn link(&self, linker: &mut Linker<ComponentState>) -> wasmtime::Result<()>;

    /// Create per-component-instance state. Called once per component instantiation.
    /// Returns None if capability needs no per-instance state.
    ///
    /// When returning `Some`, the pair is the state value's `TypeId` and the boxed value
    /// itself. The `create_state!` macro generates an implementation that fills in the
    /// `TypeId` automatically. The default implementation returns `Ok(None)`.
    fn create_state_boxed(&self) -> Result<Option<(TypeId, Box<dyn Any + Send>)>> {
        Ok(None)
    }
}
32
/// `HasData` implementation that projects `ComponentState` to a capability's state type.
///
/// Use this in `link()` when adding bindings so host impls receive only their own state
/// (the type created by `create_state_boxed`), not full `ComponentState`.
///
/// This is a zero-sized marker: `T` is the capability's state type and is never stored.
///
/// # Example
///
/// ```ignore
/// fn link(&self, linker: &mut Linker<ComponentState>) -> wasmtime::Result<()> {
///     my_capability::add_to_linker::<_, CapabilityStateHasData<MyState>>(
///         linker,
///         |state| state.get_extension_mut::<MyState>().expect("MyState not initialized"),
///     )?;
///     Ok(())
/// }
/// ```
///
/// The host impl must then be `my_capability::Host for MyState` (not `ComponentState`).
pub struct CapabilityStateHasData<T>(PhantomData<T>);

// Bindings that use this marker receive `&mut T` (the capability's own state) as their data.
impl<T: Send + 'static> HasData for CapabilityStateHasData<T> {
    type Data<'a> = &'a mut T;
}
56
/// Factory function that creates a HostCapability instance from TOML config.
///
/// The argument is the capability block's properties serialized to a `serde_json::Value`.
/// Factories are looked up by capability type name when the capability registry is built.
/// Returning `Err` aborts registry construction.
pub type HostCapabilityFactory =
    Box<dyn Fn(serde_json::Value) -> Result<Box<dyn HostCapability>> + Send + Sync>;
60
/// Macro for implementing `create_state_boxed()` with automatic TypeId inference.
///
/// The `$body` expression receives the capability instance via `$self` and can use `?`
/// for fallible operations.
///
/// The macro expands to a complete `create_state_boxed` method, so it must be invoked inside an
/// `impl HostCapability for ...` block. The value produced by `$body` is typed as `$type`,
/// boxed, and returned together with `TypeId::of::<$type>()`. The expansion refers to
/// `anyhow::Result` by path, so the invoking crate needs `anyhow` as a dependency.
///
/// # Example
///
/// ```ignore
/// impl HostCapability for MyCapability {
///     // ...
///     create_state!(this, MyState, {
///         MyState {
///             shared_resource: this.get_resource(),
///             counter: 0,
///         }
///     });
/// }
/// ```
#[macro_export]
macro_rules! create_state {
    ($self:ident, $type:ty, $body:expr) => {
        fn create_state_boxed(
            &self,
        ) -> anyhow::Result<Option<(std::any::TypeId, Box<dyn std::any::Any + Send>)>> {
            // Expose `self` under the caller-chosen identifier so `$body` can refer to it.
            let $self = self;
            let state: $type = $body;
            Ok(Some((std::any::TypeId::of::<$type>(), Box::new(state))))
        }
    };
}
91
/// Macro for creating a `(&str, HostCapabilityFactory)` tuple with reduced boilerplate.
///
/// The first form matches a closure-like `|config| body`, where `config` is bound to the
/// capability's properties as a `serde_json::Value`. The second form ignores the config.
/// In both forms `$body` is evaluated on every factory call and must produce a type that
/// implements `HostCapability`. It is wrapped in a `move` closure, so anything it captures
/// is moved into the factory.
///
/// NOTE(review): the expansion names `$crate::HostCapability` and
/// `$crate::HostCapabilityFactory`, so both presumably must be re-exported at the crate
/// root. Confirm.
///
/// # Examples
///
/// ```ignore
/// // Without config — capability is constructed directly
/// create_capability!("greeting", GreetingCapability {
///     message: self.message.clone(),
/// })
///
/// // With config — closure receives the capability's config value
/// create_capability!("greeting", |config| {
///     let suffix = config.get("suffix").and_then(|v| v.as_str()).unwrap_or("!");
///     GreetingCapability { message: self.message.clone(), suffix: suffix.to_string() }
/// })
/// ```
#[macro_export]
macro_rules! create_capability {
    // Closure form: the config value is bound to the caller's identifier.
    ($name:expr, |$config:ident| $body:expr) => {
        (
            $name,
            Box::new(
                move |$config: serde_json::Value| -> anyhow::Result<Box<dyn $crate::HostCapability>> {
                    Ok(Box::new($body))
                },
            ) as $crate::HostCapabilityFactory,
        )
    };
    // Direct form: the config value is accepted but unused.
    ($name:expr, $body:expr) => {
        (
            $name,
            Box::new(
                move |_config: serde_json::Value| -> anyhow::Result<Box<dyn $crate::HostCapability>> {
                    Ok(Box::new($body))
                },
            ) as $crate::HostCapabilityFactory,
        )
    };
}
131
// TODO: `properties` (wasi:* only) and `instance` (custom only)
// should unify into an enum (e.g., Wasi vs Custom variants).
/// A capability configured from one definition block, as stored in a `CapabilityRegistry`.
#[derive(Serialize, Deserialize)]
pub struct Capability {
    /// Capability type: either a built-in `wasi:*` kind or a custom type name with a registered factory.
    pub kind: String,
    /// Import scope. Currently only `"any"` is accepted when a component imports this capability.
    pub scope: String,
    /// Fully qualified interfaces (`namespace:package/interface@version`) this capability provides.
    pub interfaces: Vec<String>,
    /// Properties from the definition block, kept as raw JSON values.
    pub properties: HashMap<String, serde_json::Value>,
    /// Host implementation for custom (non-`wasi:`) kinds; `None` for `wasi:*` kinds.
    /// Skipped by serde, so it is always `None` after deserialization.
    #[serde(skip)]
    pub instance: Option<Box<dyn HostCapability>>,
}
143
144impl std::fmt::Debug for Capability {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("Capability")
147            .field("kind", &self.kind)
148            .field("scope", &self.scope)
149            .field("interfaces", &self.interfaces)
150            .field(
151                "instance",
152                &self.instance.as_ref().map(|_| "<dyn HostCapability>"),
153            )
154            .finish()
155    }
156}
157
/// A fully built component: its (possibly composed) binary plus parsed interface metadata.
#[derive(Debug, Clone)]
pub struct ComponentSpec {
    /// Component name from its definition.
    pub name: String,
    /// Namespace from the parsed component metadata, if present.
    pub namespace: Option<String>,
    /// Package from the parsed component metadata, if present.
    pub package: Option<String>,
    /// Component binary after config/dependency composition. Kept behind `Arc` so clones are cheap.
    pub bytes: Arc<[u8]>,
    /// Imports remaining after composition, which must be satisfied by capabilities.
    pub imports: Vec<String>,
    /// Interfaces exported by the component.
    pub exports: Vec<String>,
    /// Names of capabilities available to this component, including those inherited from
    /// composed component dependencies.
    pub capabilities: Vec<String>,
    /// Functions reported by the component parser (presumably keyed by function name; confirm).
    pub functions: HashMap<String, Function>,
}

/// Map of capability name to [`Capability`], behind an `Arc` so cloning the registry is cheap.
#[derive(Debug, Clone)]
pub struct CapabilityRegistry {
    pub capabilities: Arc<HashMap<String, Capability>>,
}

/// Map of component name to built [`ComponentSpec`], behind an `Arc` so cloning is cheap.
/// The public lookup helpers hide names that start with `_`.
#[derive(Debug, Clone)]
pub struct ComponentRegistry {
    pub components: Arc<HashMap<String, ComponentSpec>>,
}
179
180impl CapabilityRegistry {
181    pub fn new(capabilities: HashMap<String, Capability>) -> Self {
182        Self {
183            capabilities: Arc::new(capabilities),
184        }
185    }
186
187    pub fn get_capability(&self, name: &str) -> Option<&Capability> {
188        self.capabilities.get(name)
189    }
190
191    // TODO: replace hardcoded "any" with label selector evaluation
192    pub fn verify_importable(
193        &self,
194        candidate: &CapabilityDefinition,
195        requester: &ComponentDefinition,
196    ) -> Result<()> {
197        match candidate.scope.as_str() {
198            "any" => Ok(()),
199            scope => Err(anyhow::anyhow!(
200                "Component '{}' cannot import capability '{}' (scope: '{scope}')",
201                requester.name,
202                candidate.name
203            )),
204        }
205    }
206}
207
208impl ComponentRegistry {
209    pub fn empty() -> Self {
210        Self {
211            components: Arc::new(HashMap::new()),
212        }
213    }
214
215    pub fn get_components(&self) -> impl Iterator<Item = &ComponentSpec> {
216        self.components
217            .values()
218            .filter(|spec| !spec.name.starts_with('_'))
219    }
220
221    pub fn get_component(&self, name: &str) -> Option<&ComponentSpec> {
222        if name.starts_with('_') {
223            return None;
224        }
225        self.components.get(name)
226    }
227
228    // TODO: replace hardcoded "any" with label selector evaluation
229    pub fn get_required_import(
230        &self,
231        candidate: &ComponentDefinition,
232        requester: &ComponentDefinition,
233        _requester_metadata: &ComponentMetadata,
234    ) -> Result<&ComponentSpec> {
235        let component = self
236            .components
237            .get(&candidate.name)
238            .expect("component must exist in registry");
239        match candidate.scope.as_str() {
240            "any" => Ok(component),
241            scope => Err(anyhow::anyhow!(
242                "Component '{}' cannot import dependency '{}' (scope: '{scope}')",
243                requester.name,
244                candidate.name
245            )),
246        }
247    }
248}
249
250impl Default for ComponentRegistry {
251    fn default() -> Self {
252        Self::empty()
253    }
254}
255
256/// Build registries from definitions
257pub async fn build_registries(
258    component_graph: &ComponentGraph,
259    factories: HashMap<&'static str, HostCapabilityFactory>,
260) -> Result<(ComponentRegistry, CapabilityRegistry)> {
261    let mut capability_definitions = Vec::new();
262    for node in component_graph.nodes() {
263        if let Node::Capability(def) = &node.weight {
264            capability_definitions.push(def.clone());
265        }
266    }
267
268    let capability_registry = create_capability_registry(capability_definitions, factories)?;
269
270    let sorted_indices = component_graph.get_build_order();
271
272    let mut built_components = HashMap::new();
273
274    for node_index in sorted_indices {
275        if let Node::Component(definition) = &component_graph[node_index] {
276            let temp_component_registry = ComponentRegistry {
277                components: Arc::new(built_components.clone()),
278            };
279
280            let component_spec = process_component(
281                node_index,
282                component_graph,
283                &temp_component_registry,
284                &capability_registry,
285            )
286            .await?;
287
288            built_components.insert(definition.name.clone(), component_spec);
289        }
290    }
291
292    Ok((
293        ComponentRegistry {
294            components: Arc::new(built_components),
295        },
296        capability_registry,
297    ))
298}
299
300fn create_capability_registry(
301    capability_definitions: Vec<CapabilityDefinition>,
302    factories: HashMap<&'static str, HostCapabilityFactory>,
303) -> Result<CapabilityRegistry> {
304    let mut capabilities = HashMap::new();
305
306    for def in capability_definitions {
307        let (interfaces, capability_instance) = if def.kind.starts_with("wasi:") {
308            let interfaces = get_interfaces_for_capability(&def.kind);
309            if interfaces.is_empty() {
310                anyhow::bail!("Unknown capability type: '{}'", def.kind);
311            }
312            (interfaces, None)
313        } else {
314            // Custom capability
315            let factory = factories.get(def.kind.as_str()).ok_or_else(|| {
316                anyhow::anyhow!(
317                    "Capability type '{}' not registered. Use Runtime::builder().with_capability::<T>(\"{}\")",
318                    def.kind,
319                    def.kind
320                )
321            })?;
322
323            let config_value = serde_json::to_value(&def.properties)?;
324            let cap = factory(config_value).map_err(|e| {
325                anyhow::anyhow!(
326                    "Failed to create capability '{}' from TOML block '{}': {}",
327                    def.kind,
328                    def.name,
329                    e
330                )
331            })?;
332
333            (cap.interfaces(), Some(cap))
334        };
335
336        let capability = Capability {
337            kind: def.kind.clone(),
338            scope: def.scope.clone(),
339            interfaces,
340            properties: def.properties,
341            instance: capability_instance,
342        };
343        capabilities.insert(def.name, capability);
344    }
345
346    Ok(CapabilityRegistry::new(capabilities))
347}
348
/// Map a built-in `wasi:*` capability kind to the fully qualified interfaces it provides.
///
/// Every interface is pinned to version `0.2.6`. `wasi:http` also lists the `wasi:io`
/// interfaces it depends on transitively. `wasi:p2` bundles cli (including exit and the
/// terminal interfaces), clocks, filesystem, io, random and sockets. An unrecognized kind
/// yields an empty list.
fn get_interfaces_for_capability(kind: &str) -> Vec<String> {
    let names: &[&str] = match kind {
        "wasi:cli" => &[
            "wasi:cli/stdin",
            "wasi:cli/stdout",
            "wasi:cli/stderr",
            "wasi:cli/environment",
        ],
        "wasi:clocks" => &["wasi:clocks/monotonic-clock", "wasi:clocks/wall-clock"],
        "wasi:http" => &[
            "wasi:http/outgoing-handler",
            "wasi:http/types",
            // io is a transitive dep
            "wasi:io/error",
            "wasi:io/poll",
            "wasi:io/streams",
        ],
        "wasi:io" => &["wasi:io/error", "wasi:io/poll", "wasi:io/streams"],
        "wasi:random" => &[
            "wasi:random/random",
            "wasi:random/insecure",
            "wasi:random/insecure-seed",
        ],
        "wasi:sockets" => &[
            "wasi:sockets/tcp",
            "wasi:sockets/udp",
            "wasi:sockets/network",
            "wasi:sockets/instance-network",
            "wasi:sockets/ip-name-lookup",
            "wasi:sockets/tcp-create-socket",
            "wasi:sockets/udp-create-socket",
        ],
        "wasi:p2" => &[
            "wasi:cli/environment",
            "wasi:cli/exit",
            "wasi:cli/stderr",
            "wasi:cli/stdin",
            "wasi:cli/stdout",
            "wasi:cli/terminal-input",
            "wasi:cli/terminal-output",
            "wasi:cli/terminal-stdin",
            "wasi:cli/terminal-stdout",
            "wasi:cli/terminal-stderr",
            "wasi:clocks/monotonic-clock",
            "wasi:clocks/wall-clock",
            "wasi:filesystem/preopens",
            "wasi:filesystem/types",
            "wasi:io/error",
            "wasi:io/poll",
            "wasi:io/streams",
            "wasi:random/random",
            "wasi:random/insecure",
            "wasi:random/insecure-seed",
            "wasi:sockets/tcp",
            "wasi:sockets/udp",
            "wasi:sockets/network",
            "wasi:sockets/instance-network",
            "wasi:sockets/ip-name-lookup",
            "wasi:sockets/tcp-create-socket",
            "wasi:sockets/udp-create-socket",
        ],
        _ => &[],
    };

    names.iter().map(|name| format!("{name}@0.2.6")).collect()
}
422
/// Return `true` if `import` is provided by one of `capability_interfaces`.
///
/// An exact string match always satisfies the import. Otherwise, when both sides carry a
/// plain `major.minor.patch` version after `@`, the import is satisfied by an interface
/// with the same name, the same major and minor, and a patch that is equal or newer.
/// Versions that do not parse (e.g. with pre-release suffixes) only match exactly.
fn is_import_satisfied(import: &str, capability_interfaces: &HashSet<String>) -> bool {
    // Fast path: exact match
    if capability_interfaces.contains(import) {
        return true;
    }

    let Some((wanted_name, wanted_version)) = import.rsplit_once('@') else {
        return false;
    };
    let Some((want_major, want_minor, want_patch)) = parse_semver(wanted_version) else {
        return false;
    };

    capability_interfaces.iter().any(|candidate| {
        candidate
            .rsplit_once('@')
            .filter(|(name, _)| *name == wanted_name)
            .and_then(|(_, version)| parse_semver(version))
            .is_some_and(|(major, minor, patch)| {
                major == want_major && minor == want_minor && patch >= want_patch
            })
    })
}

/// Parse a strict three-component `major.minor.patch` version into numbers.
///
/// Returns `None` unless there are exactly three dot-separated parts and each parses as `u32`.
fn parse_semver(version: &str) -> Option<(u32, u32, u32)> {
    let mut pieces = version.split('.');
    let major = pieces.next()?;
    let minor = pieces.next()?;
    let patch = pieces.next()?;
    if pieces.next().is_some() {
        return None;
    }
    Some((
        major.parse::<u32>().ok()?,
        minor.parse::<u32>().ok()?,
        patch.parse::<u32>().ok()?,
    ))
}
463
/// Build the [`ComponentSpec`] for the component node at `node_index`.
///
/// Steps, in order:
/// 1. Read the component bytes from its `uri` and parse metadata, imports, exports and functions.
/// 2. If it imports `wasi:config/store`, compose it with the definition's `config` and drop
///    those imports. Otherwise, log a warning when a non-empty config was supplied anyway.
/// 3. For each graph dependency:
///    - Component dependency: it is looked up in `component_registry`, which must already
///      contain it. When the edge is an interceptor and this component exports the advice
///      interface, a wrapper generated from the target is composed with this advice and
///      the target, and the result takes on the target's imports, exports and functions.
///      Otherwise the dependency is composed into this component. Either way, the
///      dependency's capabilities are inherited.
///    - Capability dependency: its scope is checked and its name is recorded.
/// 4. Fail if any remaining import is not satisfied by the recorded capabilities' interfaces.
///
/// # Errors
/// Failure to read or parse the component, composition failures, scope violations, a call
/// on a non-component node, or unsatisfied imports.
async fn process_component(
    node_index: petgraph::graph::NodeIndex,
    component_graph: &ComponentGraph,
    component_registry: &ComponentRegistry,
    capability_registry: &CapabilityRegistry,
) -> Result<ComponentSpec> {
    let definition = if let Node::Component(def) = &component_graph[node_index] {
        def
    } else {
        return Err(anyhow::anyhow!(
            "Internal error: process_component called on a non-component node"
        ));
    };

    let mut bytes = read_bytes(&definition.uri).await?;

    let (metadata, mut imports, mut exports, mut functions) =
        Parser::parse(&bytes).map_err(|e| anyhow::anyhow!("Failed to parse component: {e}"))?;

    // Prefix match, so the import is detected regardless of any `@version` suffix.
    let imports_config = imports
        .iter()
        .any(|import| import.starts_with("wasi:config/store"));

    if imports_config {
        // Satisfy the config import by composing in the definition's config values.
        bytes = Composer::compose_with_config(&bytes, &definition.config).map_err(|e| {
            anyhow::anyhow!(
                "Failed to compose component '{}' with config: {}",
                definition.name,
                e
            )
        })?;

        let config_keys: Vec<_> = definition.config.keys().collect();
        tracing::info!(
            "Composed component '{}' with config: {config_keys:?}",
            definition.name
        );

        // Config imports are handled by the composition above, so they are not checked
        // against capabilities below.
        imports.retain(|import| !import.starts_with("wasi:config/store"));
    } else if !definition.config.is_empty() {
        tracing::warn!(
            "Config provided for component '{}' but component doesn't import wasi:config/store",
            definition.name
        );
    }

    // Capability names available to this component: its direct capability dependencies
    // plus those inherited from composed component dependencies.
    let mut all_capabilities = HashSet::new();

    let dependencies: Vec<_> = component_graph.get_dependencies(node_index).collect();
    for (dependency_node_index, edge) in &dependencies {
        let dependency_node = &component_graph[*dependency_node_index];
        match dependency_node {
            Node::Component(dependency_def) => {
                let component_spec = component_registry.get_required_import(
                    dependency_def,
                    definition,
                    &metadata,
                )?;

                if matches!(edge, Edge::Interceptor(_)) && is_advice_component(&exports) {
                    // Current component is advice; the dependency is the target.
                    // Generate a wrapper from the target, plug in advice + target.
                    let wrapper_bytes = composable_interceptor::create_from_component(
                        &component_spec.bytes,
                        &[],
                    )
                    .map_err(|e| {
                        anyhow::anyhow!(
                            "Failed to generate interceptor wrapper for '{}' targeting '{}': {e}",
                            definition.name,
                            dependency_def.name,
                        )
                    })?;
                    let composed_wrapper = Composer::compose_components(&wrapper_bytes, &bytes)
                        .map_err(|e| {
                            anyhow::anyhow!(
                                "Failed composing interceptor wrapper with advice '{}': {e}",
                                definition.name,
                            )
                        })?;
                    bytes = Composer::compose_components(&composed_wrapper, &component_spec.bytes)
                        .map_err(|e| {
                            anyhow::anyhow!(
                                "Failed composing '{}' with target '{}': {e}",
                                definition.name,
                                dependency_def.name,
                            )
                        })?;

                    // The composed result should be functionally equivalent to
                    // the target: same exports/functions and remaining imports
                    // NOTE(review): this overwrites the advice component's own remaining
                    // imports, so any import the advice still needs after composition is no
                    // longer checked against capabilities below. Confirm that advice
                    // components are expected to have none.
                    imports = component_spec.imports.clone();
                    exports = component_spec.exports.clone();
                    functions = component_spec.functions.clone();

                    tracing::info!(
                        "Composed advice '{}' with target '{}'",
                        definition.name,
                        dependency_def.name
                    );
                } else {
                    bytes = Composer::compose_components(&bytes, &component_spec.bytes).map_err(
                        |e| {
                            anyhow::anyhow!(
                                "Failed composing '{}' with dependency '{}': {e}",
                                definition.name,
                                dependency_def.name
                            )
                        },
                    )?;
                    tracing::info!(
                        "Composed component '{}' with dependency '{}'",
                        definition.name,
                        dependency_def.name
                    );
                }

                // Drop imports that exactly match one of the dependency's exports. The
                // composition above wires these to the dependency.
                for export in &component_spec.exports {
                    imports.retain(|import| import != export);
                }
                all_capabilities.extend(component_spec.capabilities.iter().cloned());
            }
            Node::Capability(capability_def) => {
                capability_registry.verify_importable(capability_def, definition)?;
                all_capabilities.insert(capability_def.name.clone());
            }
        }
    }

    // Capability names missing from the registry are silently skipped here.
    let capability_interfaces: std::collections::HashSet<String> = all_capabilities
        .iter()
        .filter_map(|name| capability_registry.get_capability(name))
        .flat_map(|cap| cap.interfaces.iter().cloned())
        .collect();

    // Check for imports not satisfied by capabilities
    let unsatisfied: Vec<_> = imports
        .iter()
        .filter(|import| !is_import_satisfied(import, &capability_interfaces))
        .cloned()
        .collect();

    if !unsatisfied.is_empty() {
        return Err(anyhow::anyhow!(
            "Component '{}' has unsatisfied imports: {:?}",
            definition.name,
            unsatisfied
        ));
    }

    Ok(ComponentSpec {
        name: definition.name.clone(),
        namespace: metadata.namespace,
        package: metadata.package,
        bytes: Arc::from(bytes),
        imports,
        exports,
        // Collected from a HashSet, so the order is unspecified.
        capabilities: all_capabilities.into_iter().collect(),
        functions,
    })
}
625
/// Return `true` if any export belongs to the `modulewise:interceptor/advice` interface.
///
/// A match means the component acts as interceptor advice rather than as a plain target.
fn is_advice_component(exports: &[String]) -> bool {
    const ADVICE_PREFIX: &str = "modulewise:interceptor/advice";
    for export in exports {
        if export.starts_with(ADVICE_PREFIX) {
            return true;
        }
    }
    false
}
631
632async fn read_bytes(uri: &str) -> Result<Vec<u8>> {
633    if let Some(oci_ref) = uri.strip_prefix("oci://") {
634        let client = wasm_pkg_client::oci::client::Client::new(Default::default());
635        let image_ref = oci_ref.parse()?;
636        let auth = oci_client::secrets::RegistryAuth::Anonymous;
637        let media_types = vec!["application/wasm", "application/vnd.wasm.component"];
638
639        let image_data = client.pull(&image_ref, &auth, media_types).await?;
640
641        // Get the component bytes from the first layer
642        if let Some(layer) = image_data.layers.first() {
643            Ok(layer.data.to_vec())
644        } else {
645            Err(anyhow::anyhow!("No layers found in OCI image: {oci_ref}"))
646        }
647    } else {
648        // Handle both file:// and plain paths
649        let path = if let Some(path_str) = uri.strip_prefix("file://") {
650            PathBuf::from(path_str)
651        } else {
652            PathBuf::from(uri)
653        };
654        Ok(std::fs::read(path)?)
655    }
656}