1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use super::hooks::events::HookKind;
7use super::permissions::PermissionSet;
8
/// Manifest protocol version understood by this host.
///
/// `ExtensionManifest::validate` rejects manifests declaring any other value.
pub const CURRENT_EXTENSION_PROTOCOL_VERSION: u32 = 1;

/// Serde default for `ExtensionManifest::protocol_version`: a manifest that
/// omits the field is treated as targeting the current protocol version.
const fn default_protocol_version() -> u32 {
    CURRENT_EXTENSION_PROTOCOL_VERSION
}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ExtensionManifest {
19 #[serde(default = "default_protocol_version")]
21 pub protocol_version: u32,
22 pub runtime: ExtensionRuntime,
24 pub command: String,
26 #[serde(default)]
34 pub setup: Option<String>,
35 #[serde(default)]
44 pub prebuilt: std::collections::HashMap<String, PrebuiltAsset>,
45 #[serde(default)]
47 pub args: Vec<String>,
48 #[serde(default)]
50 pub permissions: Vec<String>,
51 #[serde(default)]
53 pub hooks: Vec<HookSubscription>,
54 #[serde(default)]
56 pub config: Vec<ExtensionConfigEntry>,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
68pub struct PrebuiltAsset {
69 pub url: String,
73 pub sha256: String,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
81#[serde(rename_all = "lowercase")]
82pub enum ExtensionConfigValueKind {
83 String,
84 Bool,
85 Number,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
89pub struct ExtensionConfigEntry {
90 pub key: String,
91 #[serde(default, rename = "type")]
92 pub value_type: Option<ExtensionConfigValueKind>,
93 #[serde(default)]
94 pub description: Option<String>,
95 #[serde(default)]
96 pub required: bool,
97 #[serde(default)]
98 pub default: Option<Value>,
99 #[serde(default)]
100 pub secret_env: Option<String>,
101}
102
103#[derive(Debug, Clone)]
105pub struct ValidatedExtensionManifest {
106 pub permissions: PermissionSet,
107 pub subscriptions: Vec<(HookKind, Option<String>, Option<HookMatcher>)>,
108}
109
110impl ExtensionManifest {
111 pub fn validate(&self, id: &str) -> Result<ValidatedExtensionManifest, String> {
113 if self.protocol_version != CURRENT_EXTENSION_PROTOCOL_VERSION {
114 return Err(format!(
115 "Extension '{}' uses unsupported protocol_version {} (supported: {})",
116 id, self.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION,
117 ));
118 }
119
120 if self.command.trim().is_empty() {
121 return Err(format!("Extension '{}' has empty command", id));
122 }
123
124 let has_capability_permission = self.permissions.iter().any(|permission| {
125 matches!(
126 permission.as_str(),
127 "tools.register" | "providers.register" | "memory.read" | "memory.write"
128 | "config.write" | "config.subscribe" | "audio.input" | "audio.output"
129 )
130 });
131 if self.hooks.is_empty() && !has_capability_permission {
132 return Err(format!("Extension '{}' must subscribe to at least one hook or request a registration permission", id));
133 }
134
135 let permissions = PermissionSet::try_from_strings(&self.permissions)?;
136 let mut subscriptions = Vec::with_capacity(self.hooks.len());
137 for sub in &self.hooks {
138 let kind = HookKind::from_str(&sub.hook).ok_or_else(|| {
139 format!("Unknown hook kind: '{}' in extension '{}'", sub.hook, id)
140 })?;
141 if !permissions.allows_hook(kind) {
142 return Err(format!(
143 "Extension '{}' lacks permission '{}' required for hook '{}'",
144 id,
145 kind.required_permission().as_str(),
146 kind.as_str(),
147 ));
148 }
149 if sub.tool.is_some() && !kind.allows_tool_filter() {
150 return Err(format!(
151 "Extension '{}' hook '{}' does not allow a tool filter",
152 id,
153 kind.as_str(),
154 ));
155 }
156 subscriptions.push((kind, sub.tool.clone(), sub.matcher.clone()));
157 }
158
159 Ok(ValidatedExtensionManifest {
160 permissions,
161 subscriptions,
162 })
163 }
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
168#[serde(rename_all = "lowercase")]
169pub enum ExtensionRuntime {
170 Process,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct HookSubscription {
176 pub hook: String,
178 #[serde(default)]
180 pub tool: Option<String>,
181 #[serde(default, rename = "match")]
183 pub matcher: Option<HookMatcher>,
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187#[serde(deny_unknown_fields)]
188pub struct HookMatcher {
189 #[serde(default)]
190 pub input_contains: Option<String>,
191 #[serde(default)]
192 pub input_equals: Option<serde_json::Value>,
193}
194
195impl HookMatcher {
196 pub const SUPPORTED_KEYS: &'static [&'static str] = &["input_contains", "input_equals"];
197
198 pub fn matches(&self, event: &crate::extensions::hooks::events::HookEvent) -> bool {
199 let input = event.tool_input.as_ref().unwrap_or(&serde_json::Value::Null);
200 if let Some(expected) = &self.input_equals {
201 if input != expected {
202 return false;
203 }
204 }
205 if let Some(needle) = &self.input_contains {
206 let haystack = serde_json::to_string(input).unwrap_or_default();
207 if !haystack.contains(needle) {
208 return false;
209 }
210 }
211 true
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::extensions::hooks::events::HookEvent;
    use serde_json::json;

    /// Deserializes a manifest the test expects to be well-formed.
    fn parse_manifest(json: &str) -> ExtensionManifest {
        serde_json::from_str(json).expect("manifest should deserialize")
    }

    /// Asserts that manifest deserialization rejects `json`.
    fn assert_manifest_rejected(json: &str, reason: &str) {
        let parsed: Result<ExtensionManifest, _> = serde_json::from_str(json);
        assert!(parsed.is_err(), "{}", reason);
    }

    /// Version-1 `process` manifest running `ext`; tests adjust fields as needed.
    fn manifest_with(permissions: &[&str], hooks: Vec<HookSubscription>) -> ExtensionManifest {
        ExtensionManifest {
            protocol_version: 1,
            runtime: ExtensionRuntime::Process,
            command: "ext".to_string(),
            setup: None,
            prebuilt: ::std::collections::HashMap::new(),
            args: Vec::new(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
            hooks,
            config: Vec::new(),
        }
    }

    /// A subscription to `hook` with an optional tool filter and no matcher.
    fn subscribe(hook: &str, tool: Option<&str>) -> HookSubscription {
        HookSubscription {
            hook: hook.to_string(),
            tool: tool.map(str::to_string),
            matcher: None,
        }
    }

    /// A `before_tool_call` event for the `bash` tool with the given input.
    fn bash_call(input: serde_json::Value) -> HookEvent {
        HookEvent::before_tool_call("bash", input)
    }

    #[test]
    fn deserialize_full_manifest() {
        let m = parse_manifest(
            r#"{
                "protocol_version": 1,
                "runtime": "process",
                "command": "/usr/bin/my-ext",
                "args": ["--port", "0"],
                "permissions": ["tools.intercept", "session.lifecycle"],
                "hooks": [
                    {"hook": "before_tool_call", "tool": "bash"},
                    {"hook": "on_session_start"}
                ]
            }"#,
        );

        assert_eq!(m.protocol_version, 1);
        assert_eq!(m.runtime, ExtensionRuntime::Process);
        assert_eq!(m.command, "/usr/bin/my-ext");
        assert_eq!(m.args, ["--port", "0"]);
        assert_eq!(m.permissions, ["tools.intercept", "session.lifecycle"]);

        let hooks: Vec<(&str, Option<&str>)> = m
            .hooks
            .iter()
            .map(|h| (h.hook.as_str(), h.tool.as_deref()))
            .collect();
        assert_eq!(
            hooks,
            [("before_tool_call", Some("bash")), ("on_session_start", None)]
        );
    }

    #[test]
    fn missing_optional_fields_get_defaults() {
        let m = parse_manifest(r#"{"runtime": "process", "command": "my-ext"}"#);

        assert_eq!(m.protocol_version, CURRENT_EXTENSION_PROTOCOL_VERSION);
        assert_eq!(m.runtime, ExtensionRuntime::Process);
        assert_eq!(m.command, "my-ext");
        assert!(m.args.is_empty(), "args should default to []");
        assert!(m.permissions.is_empty(), "permissions should default to []");
        assert!(m.hooks.is_empty(), "hooks should default to []");
    }

    #[test]
    fn extension_config_entry_deserializes_optional_type() {
        let entry: ExtensionConfigEntry = serde_json::from_str(
            r#"{
                "key": "backend",
                "type": "string",
                "description": "Backend selector",
                "default": "auto"
            }"#,
        )
        .unwrap();

        assert_eq!(entry.key, "backend");
        assert_eq!(entry.value_type, Some(ExtensionConfigValueKind::String));
        assert_eq!(entry.description.as_deref(), Some("Backend selector"));
        assert_eq!(entry.default, Some(json!("auto")));
    }

    #[test]
    fn extension_config_entry_omitted_type_is_none() {
        let entry: ExtensionConfigEntry = serde_json::from_str(r#"{"key": "backend"}"#).unwrap();

        assert_eq!(entry.key, "backend");
        assert!(entry.value_type.is_none());
    }

    #[test]
    fn hook_subscription_tool_defaults_to_none() {
        let m = parse_manifest(
            r#"{
                "runtime": "process",
                "command": "ext",
                "hooks": [{"hook": "on_session_start"}]
            }"#,
        );

        assert!(m.hooks[0].tool.is_none());
    }

    #[test]
    fn missing_command_fails() {
        assert_manifest_rejected(r#"{"runtime": "process"}"#, "command is required");
    }

    #[test]
    fn missing_runtime_fails() {
        assert_manifest_rejected(r#"{"command": "my-ext"}"#, "runtime is required");
    }

    #[test]
    fn unknown_runtime_type_errors() {
        assert_manifest_rejected(
            r#"{"runtime": "wasm", "command": "my-ext"}"#,
            "unknown runtime 'wasm' should be rejected",
        );
    }

    #[test]
    fn runtime_is_case_sensitive() {
        assert_manifest_rejected(
            r#"{"runtime": "Process", "command": "ext"}"#,
            "runtime matching is lowercase-only",
        );
    }

    #[test]
    fn validate_rejects_unsupported_protocol_version() {
        let mut manifest =
            manifest_with(&["tools.intercept"], vec![subscribe("before_tool_call", None)]);
        manifest.protocol_version = 999;

        let err = manifest.validate("bad-version").unwrap_err();
        assert!(err.contains("unsupported protocol_version 999"));
    }

    #[test]
    fn validate_allows_hookless_provider_registration_extensions() {
        manifest_with(&["providers.register"], Vec::new())
            .validate("provider-only")
            .unwrap();
    }

    #[test]
    fn validate_rejects_tool_filter_on_non_tool_hook() {
        let manifest = manifest_with(
            &["session.lifecycle"],
            vec![subscribe("on_session_start", Some("bash"))],
        );

        let err = manifest.validate("bad-filter").unwrap_err();
        assert!(err.contains("does not allow a tool filter"));
    }

    #[test]
    fn serialize_roundtrip() {
        let mut original = manifest_with(
            &["tools.intercept"],
            vec![subscribe("before_tool_call", Some("bash"))],
        );
        original.command = "my-ext".to_string();
        original.args = vec!["--verbose".to_string()];

        let encoded = serde_json::to_string(&original).unwrap();
        let restored: ExtensionManifest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(restored.protocol_version, original.protocol_version);
        assert_eq!(restored.runtime, original.runtime);
        assert_eq!(restored.command, original.command);
        assert_eq!(restored.args, original.args);
        assert_eq!(restored.permissions, original.permissions);
        assert_eq!(restored.hooks[0].hook, original.hooks[0].hook);
        assert_eq!(restored.hooks[0].tool, original.hooks[0].tool);
    }

    #[test]
    fn matcher_input_equals_requires_exact_tool_input() {
        let matcher = HookMatcher {
            input_contains: None,
            input_equals: Some(json!({"command": "echo safe"})),
        };

        assert!(matcher.matches(&bash_call(json!({"command": "echo safe"}))));
        assert!(!matcher.matches(&bash_call(json!({"command": "echo safe", "extra": true}))));
    }

    #[test]
    fn matcher_conditions_are_combined_with_and() {
        let matcher = HookMatcher {
            input_contains: Some("safe".to_string()),
            input_equals: Some(json!({"command": "echo safe"})),
        };

        assert!(matcher.matches(&bash_call(json!({"command": "echo safe"}))));
        assert!(!matcher.matches(&bash_call(json!({"command": "echo ok"}))));
    }

    #[test]
    fn runtime_serializes_as_lowercase() {
        let encoded = serde_json::to_string(&ExtensionRuntime::Process).unwrap();
        assert_eq!(encoded, r#""process""#);
    }

    #[test]
    fn extension_manifest_defaults_prebuilt_to_empty_when_absent() {
        let m = parse_manifest(r#"{"runtime": "process", "command": "bin/ext"}"#);

        assert!(m.prebuilt.is_empty());
        assert!(m.setup.is_none());
    }

    #[test]
    fn extension_manifest_round_trips_prebuilt_assets() {
        let m = parse_manifest(
            r#"{
                "runtime": "process",
                "command": "bin/ext",
                "prebuilt": {
                    "linux-x86_64": {
                        "url": "https://example.com/ext-linux-x86_64.tar.gz",
                        "sha256": "abc123"
                    },
                    "darwin-arm64": {
                        "url": "https://example.com/ext-darwin-arm64.tar.gz",
                        "sha256": "def456"
                    }
                }
            }"#,
        );

        assert_eq!(m.prebuilt.len(), 2);
        let linux = m.prebuilt.get("linux-x86_64").expect("linux entry");
        assert_eq!(linux.url, "https://example.com/ext-linux-x86_64.tar.gz");
        assert_eq!(linux.sha256, "abc123");

        let encoded = serde_json::to_value(&m).unwrap();
        assert_eq!(encoded["prebuilt"]["darwin-arm64"]["sha256"], json!("def456"));
    }

    #[test]
    fn prebuilt_asset_requires_both_url_and_sha256() {
        let parsed: Result<PrebuiltAsset, _> =
            serde_json::from_str(r#"{ "url": "https://example.com/x.tar.gz" }"#);
        assert!(parsed.is_err(), "PrebuiltAsset without sha256 must fail to parse");
    }
}