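//! `packc/cli/add_extension.rs`
//!
//! CLI handling for adding or updating a provider declaration in the provider extension
//! (`greentic.provider-extension.v1`) of a pack's `pack.yaml`.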

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use clap::{Args, Subcommand};
6use greentic_types::provider::{ProviderDecl, ProviderRuntimeRef};
7use serde_yaml_bw::{self, Mapping, Sequence, Value as YamlValue};
8use walkdir::WalkDir;
9
10use crate::config::PackConfig;
11
/// Key under `extensions` for the flat layout; also written as the extension's `kind`.
const PROVIDER_EXTENSION_KEY: &str = "greentic.provider-extension.v1";
/// The same key split into segments, for packs that nest it as
/// `greentic` -> `provider-extension` -> `v1`.
const PROVIDER_EXTENSION_PATH: [&str; 3] = ["greentic", "provider-extension", "v1"];
/// World identifier written into every generated provider's `runtime.world`.
pub const PROVIDER_RUNTIME_WORLD: &str = "greentic:provider/schema-core@1.0.0";
15
// Subcommands accepted by this module; `handle` dispatches on the variant.
// NOTE: clap renders the `///` comment on each variant as `--help` text, so that text is
// user-visible output and is intentionally left unchanged.
#[derive(Debug, Subcommand)]
pub enum AddExtensionCommand {
    /// Add or update the provider extension entry.
    Provider(ProviderArgs),
}
21
// Arguments of the provider subcommand.
// `pack_dir`, `provider_id` and `kind` are non-`Option` fields, so clap requires
// `--pack-dir`, `--id` and `--kind`; the `Option` fields are optional flags and
// `dry_run` is a boolean switch.
// NOTE: the `///` comments below are rendered by clap as `--help` text; editing them
// changes user-visible output.
#[derive(Debug, Args)]
pub struct ProviderArgs {
    /// Path to a pack source directory containing pack.yaml.
    #[arg(long = "pack-dir", value_name = "DIR")]
    pub pack_dir: PathBuf,

    /// Print what would change without writing files.
    #[arg(long)]
    pub dry_run: bool,

    /// Provider identifier to add or update.
    // Also used as the runtime `component_ref` and to match existing entries by
    // `provider_type`.
    #[arg(long = "id", value_name = "PROVIDER_ID")]
    pub provider_id: String,

    /// Provider kind (e.g. messaging, events).
    // `build_provider_decl` gives `messaging` and `events` dedicated default ops; any other
    // value gets a single `run` op.
    #[arg(long = "kind", value_name = "KIND")]
    pub kind: String,

    /// Optional provider title to store in metadata.
    #[arg(long, value_name = "TITLE")]
    pub title: Option<String>,

    /// Optional description to store in metadata.
    #[arg(long, value_name = "DESCRIPTION")]
    pub description: Option<String>,
    /// Optional validator reference for generated provider metadata.
    #[arg(long = "validator-ref", value_name = "VALIDATOR_REF")]
    pub validator_ref: Option<String>,
    /// Optional validator digest for strict pinning.
    #[arg(long = "validator-digest", value_name = "DIGEST")]
    pub validator_digest: Option<String>,

    /// Convenience route hint (if schema supports route<->flow binding).
    #[arg(long = "route", value_name = "ROUTE")]
    pub route: Option<String>,

    /// Convenience flow hint (if schema supports route<->flow binding).
    #[arg(long = "flow", value_name = "FLOW")]
    pub flow: Option<String>,
}
62
63pub fn handle(command: AddExtensionCommand) -> Result<()> {
64    match command {
65        AddExtensionCommand::Provider(args) => handle_provider(args),
66    }
67}
68
69fn handle_provider(args: ProviderArgs) -> Result<()> {
70    edit_pack_dir(&args.pack_dir, &args)?;
71    Ok(())
72}
73
74fn edit_pack_dir(pack_dir: &Path, args: &ProviderArgs) -> Result<()> {
75    let root = normalize_root(pack_dir)?;
76    let pack_yaml = root.join("pack.yaml");
77    let (pack_config, contents) = read_pack_yaml(&pack_yaml)?;
78    let metadata = ProviderMetadata::from_args(args);
79    let updated_yaml = inject_provider_entry(
80        &contents,
81        &build_provider_decl(args, &root)?,
82        metadata,
83        &pack_config.version,
84    )?;
85
86    if args.dry_run {
87        println!("--- dry-run: updated pack.yaml ---");
88        println!("{updated_yaml}");
89        return Ok(());
90    }
91
92    fs::write(&pack_yaml, updated_yaml)
93        .with_context(|| format!("write {}", pack_yaml.display()))?;
94    println!("provider extension updated in {}", pack_yaml.display());
95    Ok(())
96}
97
/// Makes `path` absolute by joining it onto the current working directory when it is
/// relative; absolute paths are returned unchanged.
///
/// This is a purely lexical join: symlinks and `..` components are not resolved.
///
/// # Errors
/// Fails only when `path` is relative and the current directory cannot be determined.
fn normalize_root(path: &Path) -> Result<PathBuf> {
    if path.is_absolute() {
        return Ok(path.to_path_buf());
    }
    let cwd = std::env::current_dir()?;
    Ok(cwd.join(path))
}
106
107fn read_pack_yaml(path: &Path) -> Result<(PackConfig, String)> {
108    let contents = fs::read_to_string(path).with_context(|| format!("read {}", path.display()))?;
109    let config: PackConfig = serde_yaml_bw::from_str(&contents)
110        .with_context(|| format!("{} is not a valid pack.yaml", path.display()))?;
111    Ok((config, contents))
112}
113
/// Optional, free-form fields copied from the CLI flags into the generated provider entry.
///
/// When the declaration is merged into pack.yaml:
/// - each `Some` value is written as a string under the key of the same name (`title`,
///   `description`, `route`, `flow`, `validator_ref`, `validator_digest`);
/// - `None` fields are omitted.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ProviderMetadata {
    /// Human-readable provider title (`--title`).
    title: Option<String>,
    /// Provider description (`--description`).
    description: Option<String>,
    /// Route hint (`--route`).
    route: Option<String>,
    /// Flow hint (`--flow`).
    flow: Option<String>,
    /// Validator reference (`--validator-ref`).
    validator_ref: Option<String>,
    /// Validator digest for pinning (`--validator-digest`).
    validator_digest: Option<String>,
}
123
124impl ProviderMetadata {
125    fn from_args(args: &ProviderArgs) -> Self {
126        Self {
127            title: args.title.clone(),
128            description: args.description.clone(),
129            route: args.route.clone(),
130            flow: args.flow.clone(),
131            validator_ref: args.validator_ref.clone(),
132            validator_digest: args.validator_digest.clone(),
133        }
134    }
135}
136
137fn build_provider_decl(args: &ProviderArgs, root: &Path) -> Result<ProviderDecl> {
138    let config_ref = find_config_schema_ref(root, &args.kind, &args.provider_id);
139    let capabilities = vec![args.kind.clone()];
140    let ops = match args.kind.as_str() {
141        "messaging" => vec!["send".to_string(), "receive".to_string()],
142        "events" => vec!["emit".to_string(), "subscribe".to_string()],
143        _ => vec!["run".to_string()],
144    };
145
146    Ok(ProviderDecl {
147        provider_type: args.provider_id.clone(),
148        capabilities,
149        ops,
150        config_schema_ref: config_ref,
151        state_schema_ref: None,
152        runtime: ProviderRuntimeRef {
153            component_ref: args.provider_id.clone(),
154            export: "provider".to_string(),
155            world: PROVIDER_RUNTIME_WORLD.to_string(),
156        },
157        docs_ref: None,
158    })
159}
160
161fn find_config_schema_ref(root: &Path, kind: &str, provider_id: &str) -> String {
162    let schemas = root.join("schemas");
163    if schemas.exists() {
164        let provider_kw = provider_id.to_ascii_lowercase();
165        for entry in WalkDir::new(&schemas)
166            .into_iter()
167            .filter_map(Result::ok)
168            .filter(|entry| entry.file_type().is_file())
169        {
170            let name = entry.file_name().to_string_lossy().to_ascii_lowercase();
171            if name.contains(&provider_kw)
172                && name.contains("config.schema")
173                && let Ok(rel) = entry.path().strip_prefix(root)
174            {
175                return rel
176                    .components()
177                    .map(|comp| comp.as_os_str().to_string_lossy())
178                    .collect::<Vec<_>>()
179                    .join("/");
180            }
181        }
182    }
183
184    format!("schemas/{}/{}/config.schema.json", kind, provider_id)
185}
186
187fn inject_provider_entry(
188    contents: &str,
189    provider: &ProviderDecl,
190    metadata: ProviderMetadata,
191    version: &str,
192) -> Result<String> {
193    let mut document: YamlValue =
194        serde_yaml_bw::from_str(contents).context("parse pack.yaml for extension merge")?;
195    let mapping = document
196        .as_mapping_mut()
197        .ok_or_else(|| anyhow::anyhow!("pack.yaml root must be a mapping"))?;
198    let extensions = mapping
199        .entry(yaml_key("extensions"))
200        .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
201    let extensions_map = extensions
202        .as_mapping_mut()
203        .ok_or_else(|| anyhow::anyhow!("extensions must be a mapping"))?;
204
205    let location = detect_extension_location(extensions_map);
206    let extension_map = resolve_extension_map(extensions_map, &location)
207        .context("locate provider extension slot")?;
208    extension_map
209        .entry(yaml_key("kind"))
210        .or_insert_with(|| YamlValue::String(PROVIDER_EXTENSION_KEY.to_string(), None));
211    extension_map
212        .entry(yaml_key("version"))
213        .or_insert_with(|| YamlValue::String(version.to_string(), None));
214
215    let inline = extension_map
216        .entry(yaml_key("inline"))
217        .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
218    let inline_map = match inline {
219        YamlValue::Mapping(map) => map,
220        _ => {
221            *inline = YamlValue::Mapping(Mapping::new());
222            inline.as_mapping_mut().unwrap()
223        }
224    };
225
226    let providers_key = yaml_key("providers");
227    let providers_entry = inline_map
228        .entry(providers_key.clone())
229        .or_insert_with(|| YamlValue::Sequence(Sequence::default()));
230    let providers = match providers_entry {
231        YamlValue::Sequence(seq) => seq,
232        _ => {
233            *providers_entry = YamlValue::Sequence(Sequence::default());
234            providers_entry.as_sequence_mut().unwrap()
235        }
236    };
237
238    let mut provider_value =
239        serde_yaml_bw::to_value(provider).context("serialize provider declaration")?;
240    if let Some(map) = provider_value.as_mapping_mut() {
241        if let Some(title) = metadata.title {
242            map.insert(yaml_key("title"), YamlValue::String(title, None));
243        }
244        if let Some(desc) = metadata.description {
245            map.insert(yaml_key("description"), YamlValue::String(desc, None));
246        }
247        if let Some(route) = metadata.route {
248            map.insert(yaml_key("route"), YamlValue::String(route, None));
249        }
250        if let Some(flow) = metadata.flow {
251            map.insert(yaml_key("flow"), YamlValue::String(flow, None));
252        }
253        if let Some(validator_ref) = metadata.validator_ref {
254            map.insert(
255                yaml_key("validator_ref"),
256                YamlValue::String(validator_ref, None),
257            );
258        }
259        if let Some(validator_digest) = metadata.validator_digest {
260            map.insert(
261                yaml_key("validator_digest"),
262                YamlValue::String(validator_digest, None),
263            );
264        }
265    }
266    upsert_provider(providers, provider_value, &provider.provider_type);
267
268    serde_yaml_bw::to_string(&document).context("serialize updated pack.yaml")
269}
270
271fn upsert_provider(providers: &mut Vec<YamlValue>, provider: YamlValue, provider_id: &str) {
272    for entry in providers.iter_mut() {
273        if entry_matches_provider(entry, provider_id) {
274            *entry = provider;
275            return;
276        }
277    }
278    providers.push(provider);
279}
280
281fn entry_matches_provider(entry: &YamlValue, provider_id: &str) -> bool {
282    let provider_key = yaml_key("provider_type");
283    if let YamlValue::Mapping(map) = entry
284        && let Some(YamlValue::String(value, _)) = map.get(&provider_key)
285    {
286        return value == provider_id;
287    }
288    false
289}
290
/// Layout of the provider extension inside the pack's `extensions` mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExtensionLocation {
    /// A single dotted key: `extensions["greentic.provider-extension.v1"]`.
    Flat,
    /// One mapping per segment: `extensions.greentic["provider-extension"].v1`.
    Nested,
}
295
296fn detect_extension_location(extensions: &Mapping) -> ExtensionLocation {
297    let provider_key = yaml_key(PROVIDER_EXTENSION_KEY);
298    if extensions.contains_key(&provider_key) {
299        return ExtensionLocation::Flat;
300    }
301    let mut current = extensions;
302    for segment in PROVIDER_EXTENSION_PATH
303        .iter()
304        .take(PROVIDER_EXTENSION_PATH.len() - 1)
305    {
306        let key = yaml_key(*segment);
307        if let Some(next) = current.get(&key).and_then(YamlValue::as_mapping) {
308            current = next;
309        } else {
310            return ExtensionLocation::Flat;
311        }
312    }
313    ExtensionLocation::Nested
314}
315
316fn resolve_extension_map<'a>(
317    extensions: &'a mut Mapping,
318    location: &ExtensionLocation,
319) -> Result<&'a mut Mapping> {
320    match location {
321        ExtensionLocation::Flat => {
322            let key = yaml_key(PROVIDER_EXTENSION_KEY);
323            let slot = extensions
324                .entry(key)
325                .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
326            slot.as_mapping_mut()
327                .ok_or_else(|| anyhow::anyhow!("extension slot must be a mapping"))
328        }
329        ExtensionLocation::Nested => {
330            let mut current_map = extensions;
331            for segment in PROVIDER_EXTENSION_PATH.iter() {
332                let key = yaml_key(*segment);
333                let entry = current_map
334                    .entry(key)
335                    .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
336                current_map = entry
337                    .as_mapping_mut()
338                    .ok_or_else(|| anyhow::anyhow!("nested extension value must be a mapping"))?;
339            }
340            Ok(current_map)
341        }
342    }
343}
344
345fn yaml_key(value: impl Into<String>) -> YamlValue {
346    YamlValue::String(value.into(), None)
347}
348
#[cfg(test)]
mod tests {
    // `serde_yaml_bw` is reachable as an extern crate path (and via `super::*`), so no
    // separate `use serde_yaml_bw;` is needed (clippy: single_component_path_imports).
    use super::*;

    fn sample_flat_yaml() -> String {
        r#"pack_id: demo
version: 0.1.0
extensions:
  greentic.provider-extension.v1:
    kind: greentic.provider-extension.v1
    version: 0.1.0
    inline:
      providers:
        - provider_type: existing
          capabilities: [messaging]
          ops: [send]
          config_schema_ref: schemas/messaging/existing/config.schema.json
          runtime:
            component_ref: existing
            export: provider
            world: greentic:provider/schema-core@1.0.0
"#
        .to_string()
    }

    fn sample_nested_yaml() -> String {
        r#"pack_id: demo
version: 0.1.0
extensions:
  greentic:
    provider-extension:
      v1:
        inline:
          providers: []
"#
        .to_string()
    }

    fn provider_decl() -> ProviderDecl {
        ProviderDecl {
            provider_type: "demo.provider".to_string(),
            capabilities: vec!["messaging".to_string()],
            ops: vec!["send".to_string()],
            config_schema_ref: "schemas/messaging/demo/config.schema.json".to_string(),
            state_schema_ref: None,
            runtime: ProviderRuntimeRef {
                component_ref: "demo.provider".to_string(),
                export: "provider".to_string(),
                world: PROVIDER_RUNTIME_WORLD.to_string(),
            },
            docs_ref: None,
        }
    }

    #[test]
    fn inject_flat_extension() {
        let contents = sample_flat_yaml();
        let updated = inject_provider_entry(
            &contents,
            &provider_decl(),
            ProviderMetadata::default(),
            "0.1.0",
        )
        .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        let providers = doc["extensions"]["greentic.provider-extension.v1"]["inline"]["providers"]
            .as_sequence()
            .expect("providers list");
        assert!(
            providers
                .iter()
                .any(|entry| entry_matches_provider(entry, "demo.provider"))
        );
        // The pre-existing provider must survive the append.
        assert!(
            providers
                .iter()
                .any(|entry| entry_matches_provider(entry, "existing"))
        );
        assert_eq!(providers.len(), 2);
    }

    #[test]
    fn inject_nested_extension() {
        let contents = sample_nested_yaml();
        let updated = inject_provider_entry(
            &contents,
            &provider_decl(),
            ProviderMetadata::default(),
            "0.1.0",
        )
        .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        assert!(
            doc["extensions"]["greentic"]["provider-extension"]["v1"]["inline"]["providers"]
                .as_sequence()
                .unwrap()
                .iter()
                .any(|entry| entry_matches_provider(entry, "demo.provider"))
        );
    }

    #[test]
    fn inject_replaces_existing_provider_in_place() {
        let mut decl = provider_decl();
        decl.provider_type = "existing".to_string();
        let metadata = ProviderMetadata {
            title: Some("Existing provider".to_string()),
            ..ProviderMetadata::default()
        };
        let updated =
            inject_provider_entry(&sample_flat_yaml(), &decl, metadata, "0.1.0").unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        let providers = doc["extensions"][PROVIDER_EXTENSION_KEY]["inline"]["providers"]
            .as_sequence()
            .expect("providers list");
        assert_eq!(providers.len(), 1);
        let entry = providers
            .iter()
            .find(|entry| entry_matches_provider(entry, "existing"))
            .expect("existing provider");
        assert!(matches!(
            &entry["title"],
            YamlValue::String(title, _) if title == "Existing provider"
        ));
    }

    #[test]
    fn inject_writes_optional_metadata() {
        let metadata = ProviderMetadata {
            title: Some("Demo".to_string()),
            description: Some("Demo provider".to_string()),
            route: Some("/demo".to_string()),
            flow: Some("main".to_string()),
            validator_ref: Some("oci://validators/demo".to_string()),
            validator_digest: Some("sha256:abc".to_string()),
        };
        let updated =
            inject_provider_entry(&sample_flat_yaml(), &provider_decl(), metadata, "0.1.0")
                .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        let providers = doc["extensions"][PROVIDER_EXTENSION_KEY]["inline"]["providers"]
            .as_sequence()
            .expect("providers list");
        let entry = providers
            .iter()
            .find(|entry| entry_matches_provider(entry, "demo.provider"))
            .expect("demo provider");
        let expected = [
            ("title", "Demo"),
            ("description", "Demo provider"),
            ("route", "/demo"),
            ("flow", "main"),
            ("validator_ref", "oci://validators/demo"),
            ("validator_digest", "sha256:abc"),
        ];
        for (key, want) in expected {
            assert!(
                matches!(&entry[key], YamlValue::String(value, _) if value == want),
                "metadata field {key} not written as expected"
            );
        }
    }

    #[test]
    fn inject_creates_flat_extension_when_missing() {
        let contents = "pack_id: demo\nversion: 0.1.0\n";
        let updated = inject_provider_entry(
            contents,
            &provider_decl(),
            ProviderMetadata::default(),
            "0.2.0",
        )
        .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        let slot = &doc["extensions"][PROVIDER_EXTENSION_KEY];
        assert!(matches!(
            &slot["kind"],
            YamlValue::String(kind, _) if kind == PROVIDER_EXTENSION_KEY
        ));
        assert!(matches!(
            &slot["version"],
            YamlValue::String(version, _) if version == "0.2.0"
        ));
        let providers = slot["inline"]["providers"]
            .as_sequence()
            .expect("providers list");
        assert_eq!(providers.len(), 1);
        assert!(
            providers
                .iter()
                .any(|entry| entry_matches_provider(entry, "demo.provider"))
        );
    }

    #[test]
    fn inject_rejects_non_mapping_root() {
        let result = inject_provider_entry(
            "- a\n- b\n",
            &provider_decl(),
            ProviderMetadata::default(),
            "0.1.0",
        );
        assert!(result.is_err());
    }
}