1use std::fs;
2use std::path::{Path, PathBuf};
3
4use anyhow::{Context, Result};
5use clap::{Args, Subcommand};
6use greentic_types::provider::{ProviderDecl, ProviderRuntimeRef};
7use serde_yaml_bw::{self, Mapping, Sequence, Value as YamlValue};
8use walkdir::WalkDir;
9
10use crate::config::PackConfig;
11
/// Key of the flat form of the provider extension inside `extensions`; also
/// written as the extension's `kind` when that field is absent.
const PROVIDER_EXTENSION_KEY: &str = "greentic.provider-extension.v1";

/// Segments of the nested form: `extensions.greentic.provider-extension.v1`.
const PROVIDER_EXTENSION_PATH: [&str; 3] = ["greentic", "provider-extension", "v1"];

/// Value written to `runtime.world` of every provider declaration built here.
pub const PROVIDER_RUNTIME_WORLD: &str = "greentic:provider/schema-core@1.0.0";
15
16#[derive(Debug, Subcommand)]
17pub enum AddExtensionCommand {
18 Provider(ProviderArgs),
20}
21
22#[derive(Debug, Args)]
23pub struct ProviderArgs {
24 #[arg(long = "pack-dir", value_name = "DIR")]
26 pub pack_dir: PathBuf,
27
28 #[arg(long)]
30 pub dry_run: bool,
31
32 #[arg(long = "id", value_name = "PROVIDER_ID")]
34 pub provider_id: String,
35
36 #[arg(long = "kind", value_name = "KIND")]
38 pub kind: String,
39
40 #[arg(long, value_name = "TITLE")]
42 pub title: Option<String>,
43
44 #[arg(long, value_name = "DESCRIPTION")]
46 pub description: Option<String>,
47 #[arg(long = "validator-ref", value_name = "VALIDATOR_REF")]
49 pub validator_ref: Option<String>,
50 #[arg(long = "validator-digest", value_name = "DIGEST")]
52 pub validator_digest: Option<String>,
53
54 #[arg(long = "route", value_name = "ROUTE")]
56 pub route: Option<String>,
57
58 #[arg(long = "flow", value_name = "FLOW")]
60 pub flow: Option<String>,
61}
62
63pub fn handle(command: AddExtensionCommand) -> Result<()> {
64 match command {
65 AddExtensionCommand::Provider(args) => handle_provider(args),
66 }
67}
68
69fn handle_provider(args: ProviderArgs) -> Result<()> {
70 edit_pack_dir(&args.pack_dir, &args)?;
71 Ok(())
72}
73
74fn edit_pack_dir(pack_dir: &Path, args: &ProviderArgs) -> Result<()> {
75 let root = normalize_root(pack_dir)?;
76 let pack_yaml = root.join("pack.yaml");
77 let (pack_config, contents) = read_pack_yaml(&pack_yaml)?;
78 let metadata = ProviderMetadata::from_args(args);
79 let updated_yaml = inject_provider_entry(
80 &contents,
81 &build_provider_decl(args, &root)?,
82 metadata,
83 &pack_config.version,
84 )?;
85
86 if args.dry_run {
87 println!("--- dry-run: updated pack.yaml ---");
88 println!("{updated_yaml}");
89 return Ok(());
90 }
91
92 fs::write(&pack_yaml, updated_yaml)
93 .with_context(|| format!("write {}", pack_yaml.display()))?;
94 println!("provider extension updated in {}", pack_yaml.display());
95 Ok(())
96}
97
98fn normalize_root(path: &Path) -> Result<PathBuf> {
99 let canonical = if path.is_absolute() {
100 path.to_path_buf()
101 } else {
102 std::env::current_dir()?.join(path)
103 };
104 Ok(canonical)
105}
106
107fn read_pack_yaml(path: &Path) -> Result<(PackConfig, String)> {
108 let contents = fs::read_to_string(path).with_context(|| format!("read {}", path.display()))?;
109 let config: PackConfig = serde_yaml_bw::from_str(&contents)
110 .with_context(|| format!("{} is not a valid pack.yaml", path.display()))?;
111 Ok((config, contents))
112}
113
114#[derive(Default)]
115struct ProviderMetadata {
116 title: Option<String>,
117 description: Option<String>,
118 route: Option<String>,
119 flow: Option<String>,
120 validator_ref: Option<String>,
121 validator_digest: Option<String>,
122}
123
124impl ProviderMetadata {
125 fn from_args(args: &ProviderArgs) -> Self {
126 Self {
127 title: args.title.clone(),
128 description: args.description.clone(),
129 route: args.route.clone(),
130 flow: args.flow.clone(),
131 validator_ref: args.validator_ref.clone(),
132 validator_digest: args.validator_digest.clone(),
133 }
134 }
135}
136
137fn build_provider_decl(args: &ProviderArgs, root: &Path) -> Result<ProviderDecl> {
138 let config_ref = find_config_schema_ref(root, &args.kind, &args.provider_id);
139 let capabilities = vec![args.kind.clone()];
140 let ops = match args.kind.as_str() {
141 "messaging" => vec!["send".to_string(), "receive".to_string()],
142 "events" => vec!["emit".to_string(), "subscribe".to_string()],
143 _ => vec!["run".to_string()],
144 };
145
146 Ok(ProviderDecl {
147 provider_type: args.provider_id.clone(),
148 capabilities,
149 ops,
150 config_schema_ref: config_ref,
151 state_schema_ref: None,
152 runtime: ProviderRuntimeRef {
153 component_ref: args.provider_id.clone(),
154 export: "provider".to_string(),
155 world: PROVIDER_RUNTIME_WORLD.to_string(),
156 },
157 docs_ref: None,
158 })
159}
160
161fn find_config_schema_ref(root: &Path, kind: &str, provider_id: &str) -> String {
162 let schemas = root.join("schemas");
163 if schemas.exists() {
164 let provider_kw = provider_id.to_ascii_lowercase();
165 for entry in WalkDir::new(&schemas)
166 .into_iter()
167 .filter_map(Result::ok)
168 .filter(|entry| entry.file_type().is_file())
169 {
170 let name = entry.file_name().to_string_lossy().to_ascii_lowercase();
171 if name.contains(&provider_kw)
172 && name.contains("config.schema")
173 && let Ok(rel) = entry.path().strip_prefix(root)
174 {
175 return rel
176 .components()
177 .map(|comp| comp.as_os_str().to_string_lossy())
178 .collect::<Vec<_>>()
179 .join("/");
180 }
181 }
182 }
183
184 format!("schemas/{}/{}/config.schema.json", kind, provider_id)
185}
186
187fn inject_provider_entry(
188 contents: &str,
189 provider: &ProviderDecl,
190 metadata: ProviderMetadata,
191 version: &str,
192) -> Result<String> {
193 let mut document: YamlValue =
194 serde_yaml_bw::from_str(contents).context("parse pack.yaml for extension merge")?;
195 let mapping = document
196 .as_mapping_mut()
197 .ok_or_else(|| anyhow::anyhow!("pack.yaml root must be a mapping"))?;
198 let extensions = mapping
199 .entry(yaml_key("extensions"))
200 .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
201 let extensions_map = extensions
202 .as_mapping_mut()
203 .ok_or_else(|| anyhow::anyhow!("extensions must be a mapping"))?;
204
205 let location = detect_extension_location(extensions_map);
206 let extension_map = resolve_extension_map(extensions_map, &location)
207 .context("locate provider extension slot")?;
208 extension_map
209 .entry(yaml_key("kind"))
210 .or_insert_with(|| YamlValue::String(PROVIDER_EXTENSION_KEY.to_string(), None));
211 extension_map
212 .entry(yaml_key("version"))
213 .or_insert_with(|| YamlValue::String(version.to_string(), None));
214
215 let inline = extension_map
216 .entry(yaml_key("inline"))
217 .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
218 let inline_map = match inline {
219 YamlValue::Mapping(map) => map,
220 _ => {
221 *inline = YamlValue::Mapping(Mapping::new());
222 inline.as_mapping_mut().unwrap()
223 }
224 };
225
226 let providers_key = yaml_key("providers");
227 let providers_entry = inline_map
228 .entry(providers_key.clone())
229 .or_insert_with(|| YamlValue::Sequence(Sequence::default()));
230 let providers = match providers_entry {
231 YamlValue::Sequence(seq) => seq,
232 _ => {
233 *providers_entry = YamlValue::Sequence(Sequence::default());
234 providers_entry.as_sequence_mut().unwrap()
235 }
236 };
237
238 let mut provider_value =
239 serde_yaml_bw::to_value(provider).context("serialize provider declaration")?;
240 if let Some(map) = provider_value.as_mapping_mut() {
241 if let Some(title) = metadata.title {
242 map.insert(yaml_key("title"), YamlValue::String(title, None));
243 }
244 if let Some(desc) = metadata.description {
245 map.insert(yaml_key("description"), YamlValue::String(desc, None));
246 }
247 if let Some(route) = metadata.route {
248 map.insert(yaml_key("route"), YamlValue::String(route, None));
249 }
250 if let Some(flow) = metadata.flow {
251 map.insert(yaml_key("flow"), YamlValue::String(flow, None));
252 }
253 if let Some(validator_ref) = metadata.validator_ref {
254 map.insert(
255 yaml_key("validator_ref"),
256 YamlValue::String(validator_ref, None),
257 );
258 }
259 if let Some(validator_digest) = metadata.validator_digest {
260 map.insert(
261 yaml_key("validator_digest"),
262 YamlValue::String(validator_digest, None),
263 );
264 }
265 }
266 upsert_provider(providers, provider_value, &provider.provider_type);
267
268 serde_yaml_bw::to_string(&document).context("serialize updated pack.yaml")
269}
270
271fn upsert_provider(providers: &mut Vec<YamlValue>, provider: YamlValue, provider_id: &str) {
272 for entry in providers.iter_mut() {
273 if entry_matches_provider(entry, provider_id) {
274 *entry = provider;
275 return;
276 }
277 }
278 providers.push(provider);
279}
280
281fn entry_matches_provider(entry: &YamlValue, provider_id: &str) -> bool {
282 let provider_key = yaml_key("provider_type");
283 if let YamlValue::Mapping(map) = entry
284 && let Some(YamlValue::String(value, _)) = map.get(&provider_key)
285 {
286 return value == provider_id;
287 }
288 false
289}
290
/// Which layout the provider extension uses inside the `extensions` mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExtensionLocation {
    /// A single dotted key: `extensions["greentic.provider-extension.v1"]`.
    Flat,
    /// Nested mappings: `extensions.greentic.provider-extension.v1`.
    Nested,
}
295
296fn detect_extension_location(extensions: &Mapping) -> ExtensionLocation {
297 let provider_key = yaml_key(PROVIDER_EXTENSION_KEY);
298 if extensions.contains_key(&provider_key) {
299 return ExtensionLocation::Flat;
300 }
301 let mut current = extensions;
302 for segment in PROVIDER_EXTENSION_PATH
303 .iter()
304 .take(PROVIDER_EXTENSION_PATH.len() - 1)
305 {
306 let key = yaml_key(*segment);
307 if let Some(next) = current.get(&key).and_then(YamlValue::as_mapping) {
308 current = next;
309 } else {
310 return ExtensionLocation::Flat;
311 }
312 }
313 ExtensionLocation::Nested
314}
315
316fn resolve_extension_map<'a>(
317 extensions: &'a mut Mapping,
318 location: &ExtensionLocation,
319) -> Result<&'a mut Mapping> {
320 match location {
321 ExtensionLocation::Flat => {
322 let key = yaml_key(PROVIDER_EXTENSION_KEY);
323 let slot = extensions
324 .entry(key)
325 .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
326 slot.as_mapping_mut()
327 .ok_or_else(|| anyhow::anyhow!("extension slot must be a mapping"))
328 }
329 ExtensionLocation::Nested => {
330 let mut current_map = extensions;
331 for segment in PROVIDER_EXTENSION_PATH.iter() {
332 let key = yaml_key(*segment);
333 let entry = current_map
334 .entry(key)
335 .or_insert_with(|| YamlValue::Mapping(Mapping::new()));
336 current_map = entry
337 .as_mapping_mut()
338 .ok_or_else(|| anyhow::anyhow!("nested extension value must be a mapping"))?;
339 }
340 Ok(current_map)
341 }
342 }
343}
344
345fn yaml_key(value: impl Into<String>) -> YamlValue {
346 YamlValue::String(value.into(), None)
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Pack using the flat dotted extension key, with one existing provider.
    fn sample_flat_yaml() -> String {
        r#"pack_id: demo
version: 0.1.0
extensions:
  greentic.provider-extension.v1:
    kind: greentic.provider-extension.v1
    version: 0.1.0
    inline:
      providers:
        - provider_type: existing
          capabilities: [messaging]
          ops: [send]
          config_schema_ref: schemas/messaging/existing/config.schema.json
          runtime:
            component_ref: existing
            export: provider
            world: greentic:provider/schema-core@1.0.0
"#
        .to_string()
    }

    /// Pack using the nested `greentic.provider-extension.v1` layout.
    fn sample_nested_yaml() -> String {
        r#"pack_id: demo
version: 0.1.0
extensions:
  greentic:
    provider-extension:
      v1:
        inline:
          providers: []
"#
        .to_string()
    }

    fn provider_decl() -> ProviderDecl {
        ProviderDecl {
            provider_type: "demo.provider".to_string(),
            capabilities: vec!["messaging".to_string()],
            ops: vec!["send".to_string()],
            config_schema_ref: "schemas/messaging/demo/config.schema.json".to_string(),
            state_schema_ref: None,
            runtime: ProviderRuntimeRef {
                component_ref: "demo.provider".to_string(),
                export: "provider".to_string(),
                world: PROVIDER_RUNTIME_WORLD.to_string(),
            },
            docs_ref: None,
        }
    }

    #[test]
    fn inject_flat_extension() {
        let contents = sample_flat_yaml();
        let updated = inject_provider_entry(
            &contents,
            &provider_decl(),
            ProviderMetadata::default(),
            "0.1.0",
        )
        .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        let providers = doc["extensions"]["greentic.provider-extension.v1"]["inline"]["providers"]
            .as_sequence()
            .expect("providers list");
        assert!(
            providers
                .iter()
                .any(|entry| entry_matches_provider(entry, "demo.provider"))
        );
        // The pre-existing declaration must survive the merge.
        assert!(
            providers
                .iter()
                .any(|entry| entry_matches_provider(entry, "existing"))
        );
    }

    #[test]
    fn inject_nested_extension() {
        let contents = sample_nested_yaml();
        let updated = inject_provider_entry(
            &contents,
            &provider_decl(),
            ProviderMetadata::default(),
            "0.1.0",
        )
        .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        assert!(
            doc["extensions"]["greentic"]["provider-extension"]["v1"]["inline"]["providers"]
                .as_sequence()
                .unwrap()
                .iter()
                .any(|entry| entry_matches_provider(entry, "demo.provider"))
        );
    }

    #[test]
    fn reinjecting_provider_replaces_entry_instead_of_duplicating() {
        let first = inject_provider_entry(
            &sample_flat_yaml(),
            &provider_decl(),
            ProviderMetadata::default(),
            "0.1.0",
        )
        .unwrap();
        let metadata = ProviderMetadata {
            title: Some("Demo".to_string()),
            ..ProviderMetadata::default()
        };
        let second = inject_provider_entry(&first, &provider_decl(), metadata, "0.1.0").unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&second).unwrap();

        let providers = doc["extensions"]["greentic.provider-extension.v1"]["inline"]["providers"]
            .as_sequence()
            .expect("providers list");
        let matching: Vec<&YamlValue> = providers
            .iter()
            .filter(|entry| entry_matches_provider(entry, "demo.provider"))
            .collect();
        assert_eq!(matching.len(), 1);
        assert!(matches!(
            &matching[0]["title"],
            YamlValue::String(title, _) if title == "Demo"
        ));
    }

    #[test]
    fn inject_creates_flat_slot_when_extensions_missing() {
        let updated = inject_provider_entry(
            "pack_id: demo\nversion: 0.1.0\n",
            &provider_decl(),
            ProviderMetadata::default(),
            "0.2.0",
        )
        .unwrap();
        let doc: YamlValue = serde_yaml_bw::from_str(&updated).unwrap();

        let slot = &doc["extensions"]["greentic.provider-extension.v1"];
        assert!(matches!(
            &slot["kind"],
            YamlValue::String(kind, _) if kind == PROVIDER_EXTENSION_KEY
        ));
        assert!(matches!(
            &slot["version"],
            YamlValue::String(version, _) if version == "0.2.0"
        ));
        assert!(
            slot["inline"]["providers"]
                .as_sequence()
                .expect("providers list")
                .iter()
                .any(|entry| entry_matches_provider(entry, "demo.provider"))
        );
    }
}