1use std::collections::{BTreeMap, HashMap};
2
3use crate::cbor;
4
/// Human-readable text that is either one untagged string or a set of
/// translations keyed by language tag.
///
/// `PartialEq`/`Eq` are derived so values can be compared directly, e.g. in
/// `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalizedString {
    /// A single string used regardless of the requested language.
    Plain(String),
    /// Translations keyed by language tag (the tests use tags such as `"en"`
    /// and `"zh-Hans"`).
    Localized(HashMap<String, String>),
}
18
19impl Default for LocalizedString {
20 fn default() -> Self {
21 Self::Plain(String::new())
22 }
23}
24
25impl LocalizedString {
26 pub fn plain(text: impl Into<String>) -> Self {
28 Self::Plain(text.into())
29 }
30
31 pub fn new(lang: impl Into<String>, text: impl Into<String>) -> Self {
33 let mut map = HashMap::new();
34 map.insert(lang.into(), text.into());
35 Self::Localized(map)
36 }
37
38 pub fn get(&self, lang: &str) -> Option<&str> {
43 match self {
44 Self::Plain(text) => Some(text.as_str()),
45 Self::Localized(map) => map.get(lang).map(|s| s.as_str()),
46 }
47 }
48
49 pub fn resolve(&self, lang: &str) -> &str {
54 match self {
55 Self::Plain(text) => text.as_str(),
56 Self::Localized(map) => {
57 if let Some(text) = map.get(lang) {
59 return text.as_str();
60 }
61 if let Some(text) = map
63 .iter()
64 .find(|(tag, _)| tag.starts_with(lang) || lang.starts_with(tag.as_str()))
65 .map(|(_, text)| text.as_str())
66 {
67 return text;
68 }
69 map.values().next().map(|s| s.as_str()).unwrap_or("")
71 }
72 }
73 }
74
75 pub fn any_text(&self) -> &str {
78 match self {
79 Self::Plain(text) => text.as_str(),
80 Self::Localized(map) => map.values().next().map(|s| s.as_str()).unwrap_or(""),
81 }
82 }
83}
84
85impl From<String> for LocalizedString {
86 fn from(s: String) -> Self {
87 Self::Plain(s)
88 }
89}
90
91impl From<&str> for LocalizedString {
92 fn from(s: &str) -> Self {
93 Self::Plain(s.to_string())
94 }
95}
96
97impl From<Vec<(String, String)>> for LocalizedString {
98 fn from(v: Vec<(String, String)>) -> Self {
99 Self::Localized(v.into_iter().collect())
100 }
101}
102
103impl From<HashMap<String, String>> for LocalizedString {
104 fn from(map: HashMap<String, String>) -> Self {
105 Self::Localized(map)
106 }
107}
108
109#[derive(Debug, Clone, Default)]
115pub struct Metadata(HashMap<String, serde_json::Value>);
116
117impl Metadata {
118 pub fn new() -> Self {
119 Self(HashMap::new())
120 }
121
122 pub fn insert(&mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) {
124 self.0.insert(key.into(), value.into());
125 }
126
127 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
129 self.0.get(key)
130 }
131
132 pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
134 self.0
135 .get(key)
136 .and_then(|v| serde_json::from_value(v.clone()).ok())
137 }
138
139 pub fn contains_key(&self, key: &str) -> bool {
141 self.0.contains_key(key)
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.0.is_empty()
147 }
148
149 pub fn iter(&self) -> impl Iterator<Item = (&String, &serde_json::Value)> {
151 self.0.iter()
152 }
153
154 pub fn len(&self) -> usize {
156 self.0.len()
157 }
158
159 pub fn extend(&mut self, other: Metadata) {
161 self.0.extend(other.0);
162 }
163}
164
165impl From<serde_json::Value> for Metadata {
167 fn from(value: serde_json::Value) -> Self {
168 match value {
169 serde_json::Value::Object(map) => Self(map.into_iter().collect()),
170 _ => Self::new(),
171 }
172 }
173}
174
175impl From<Metadata> for serde_json::Value {
177 fn from(m: Metadata) -> Self {
178 serde_json::Value::Object(m.0.into_iter().collect())
179 }
180}
181
182impl From<Vec<(String, Vec<u8>)>> for Metadata {
184 fn from(v: Vec<(String, Vec<u8>)>) -> Self {
185 Self(
186 v.into_iter()
187 .filter_map(|(k, cbor_bytes)| {
188 let val = cbor::cbor_to_json(&cbor_bytes).ok()?;
189 Some((k, val))
190 })
191 .collect(),
192 )
193 }
194}
195
196impl From<Metadata> for Vec<(String, Vec<u8>)> {
198 fn from(m: Metadata) -> Self {
199 m.0.into_iter()
200 .map(|(k, v)| (k, cbor::to_cbor(&v)))
201 .collect()
202 }
203}
204
205use crate::constants::*;
206
207#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
211pub struct FilesystemCap {
212 #[serde(
214 rename = "mount-root",
215 default,
216 skip_serializing_if = "Option::is_none"
217 )]
218 pub mount_root: Option<String>,
219
220 #[serde(default, skip_serializing_if = "Vec::is_empty")]
224 pub allow: Vec<FilesystemAllow>,
225}
226
227#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
230pub struct FilesystemAllow {
231 pub path: String,
233 pub mode: FsMode,
235}
236
237#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
239#[serde(rename_all = "lowercase")]
240pub enum FsMode {
241 Ro,
243 Rw,
245}
246
247#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
249pub struct HttpCap {
250 #[serde(default, skip_serializing_if = "Vec::is_empty")]
254 pub allow: Vec<HttpAllow>,
255}
256
257#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
263pub struct HttpAllow {
264 pub host: String,
265 #[serde(default, skip_serializing_if = "Option::is_none")]
266 pub scheme: Option<String>,
267 #[serde(default, skip_serializing_if = "Option::is_none")]
268 pub methods: Option<Vec<String>>,
269 #[serde(default, skip_serializing_if = "Option::is_none")]
270 pub ports: Option<Vec<u16>>,
271}
272
273#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
275pub struct SocketsCap {
276 #[serde(default, skip_serializing_if = "Vec::is_empty")]
279 pub allow: Vec<SocketsAllow>,
280}
281
282#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
288pub struct SocketsAllow {
289 #[serde(default, skip_serializing_if = "Option::is_none")]
292 pub host: Option<String>,
293 #[serde(default, skip_serializing_if = "Option::is_none")]
295 pub cidr: Option<String>,
296 pub ports: Vec<u16>,
298 #[serde(
300 default = "default_socket_protocols",
301 skip_serializing_if = "is_default_protocols"
302 )]
303 pub protocols: Vec<SocketProtocol>,
304}
305
306fn default_socket_protocols() -> Vec<SocketProtocol> {
307 vec![SocketProtocol::Tcp, SocketProtocol::Udp]
308}
309
310fn is_default_protocols(v: &[SocketProtocol]) -> bool {
311 v == [SocketProtocol::Tcp, SocketProtocol::Udp]
312}
313
314#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
316#[serde(rename_all = "lowercase")]
317pub enum SocketProtocol {
318 Tcp,
319 Udp,
320}
321
322#[serde_with::skip_serializing_none]
327#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
328#[serde(default)]
329pub struct Capabilities {
330 #[serde(rename = "wasi:filesystem")]
332 pub filesystem: Option<FilesystemCap>,
333 #[serde(rename = "wasi:http")]
335 pub http: Option<HttpCap>,
336 #[serde(rename = "wasi:sockets")]
338 pub sockets: Option<SocketsCap>,
339 #[serde(flatten)]
341 pub other: BTreeMap<String, serde_json::Value>,
342}
343
344impl Capabilities {
345 pub fn is_empty(&self) -> bool {
347 self.http.is_none()
348 && self.filesystem.is_none()
349 && self.sockets.is_none()
350 && self.other.is_empty()
351 }
352
353 pub fn has(&self, id: &str) -> bool {
355 match id {
356 CAP_HTTP => self.http.is_some(),
357 CAP_FILESYSTEM => self.filesystem.is_some(),
358 CAP_SOCKETS => self.sockets.is_some(),
359 other => self.other.contains_key(other),
360 }
361 }
362
363 pub fn fs_mount_root(&self) -> Option<&str> {
365 self.filesystem.as_ref()?.mount_root.as_deref()
366 }
367}
368
369#[non_exhaustive]
376#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
377pub struct ComponentInfo {
378 #[serde(default)]
380 pub std: StdComponentInfo,
381 #[serde(flatten, default, skip_serializing_if = "HashMap::is_empty")]
383 pub extra: HashMap<String, serde_json::Value>,
384}
385
386#[non_exhaustive]
388#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
389pub struct StdComponentInfo {
390 #[serde(default)]
391 pub name: String,
392 #[serde(default)]
393 pub version: String,
394 #[serde(default)]
395 pub description: String,
396 #[serde(
397 rename = "default-language",
398 default,
399 skip_serializing_if = "Option::is_none"
400 )]
401 pub default_language: Option<String>,
402 #[serde(default, skip_serializing_if = "Capabilities::is_empty")]
403 pub capabilities: Capabilities,
404}
405
406impl ComponentInfo {
407 pub fn new(
408 name: impl Into<String>,
409 version: impl Into<String>,
410 description: impl Into<String>,
411 ) -> Self {
412 Self {
413 std: StdComponentInfo {
414 name: name.into(),
415 version: version.into(),
416 description: description.into(),
417 ..Default::default()
418 },
419 ..Default::default()
420 }
421 }
422
423 pub fn name(&self) -> &str {
425 &self.std.name
426 }
427 pub fn version(&self) -> &str {
428 &self.std.version
429 }
430 pub fn description(&self) -> &str {
431 &self.std.description
432 }
433}
434
/// Error value made of a kind identifier and a human-readable message.
///
/// `PartialEq`/`Eq` are derived so errors, and therefore `Result`s carrying
/// them, can be compared directly, e.g. in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActError {
    /// Error kind identifier (the named constructors fill it from the `ERR_*`
    /// constants).
    pub kind: String,
    /// Human-readable description of what went wrong.
    pub message: String,
}
443
444impl ActError {
445 pub fn new(kind: impl Into<String>, message: impl Into<String>) -> Self {
446 Self {
447 kind: kind.into(),
448 message: message.into(),
449 }
450 }
451
452 pub fn not_found(message: impl Into<String>) -> Self {
453 Self::new(ERR_NOT_FOUND, message)
454 }
455
456 pub fn invalid_args(message: impl Into<String>) -> Self {
457 Self::new(ERR_INVALID_ARGS, message)
458 }
459
460 pub fn internal(message: impl Into<String>) -> Self {
461 Self::new(ERR_INTERNAL, message)
462 }
463
464 pub fn timeout(message: impl Into<String>) -> Self {
465 Self::new(ERR_TIMEOUT, message)
466 }
467
468 pub fn capability_denied(message: impl Into<String>) -> Self {
469 Self::new(ERR_CAPABILITY_DENIED, message)
470 }
471
472 pub fn session_not_found(message: impl Into<String>) -> Self {
473 Self::new(ERR_SESSION_NOT_FOUND, message)
474 }
475}
476
477impl std::fmt::Display for ActError {
478 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479 write!(f, "{}: {}", self.kind, self.message)
480 }
481}
482
483impl std::error::Error for ActError {}
484
485pub type ActResult<T> = Result<T, ActError>;
487
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Top-level shape of the TOML fixtures: everything lives under `[std]`.
    #[derive(serde::Deserialize)]
    struct Manifest {
        std: StdSection,
    }

    /// The `[std]` table of a fixture; only `capabilities` is inspected.
    #[derive(serde::Deserialize)]
    struct StdSection {
        capabilities: Capabilities,
    }

    /// Parses a TOML fixture and returns its `std.capabilities` table.
    fn parse_capabilities(toml_input: &str) -> Result<Capabilities, toml::de::Error> {
        toml::from_str::<Manifest>(toml_input).map(|manifest| manifest.std.capabilities)
    }

    /// Encodes `info` as CBOR and decodes it again.
    fn cbor_roundtrip(info: &ComponentInfo) -> ComponentInfo {
        let mut buf = Vec::new();
        ciborium::into_writer(info, &mut buf).unwrap();
        ciborium::from_reader(&buf[..]).unwrap()
    }

    #[test]
    fn localized_string_plain() {
        let ls = LocalizedString::plain("hello");
        assert_eq!(ls.resolve("en"), "hello");
        assert_eq!(ls.any_text(), "hello");
    }

    #[test]
    fn localized_string_from_str() {
        let ls = LocalizedString::from("hello");
        assert_eq!(ls.any_text(), "hello");
    }

    #[test]
    fn localized_string_default() {
        let ls = LocalizedString::default();
        assert_eq!(ls.any_text(), "");
    }

    #[test]
    fn localized_string_resolve_by_lang() {
        let mut map = HashMap::new();
        map.insert("en".to_string(), "hello".to_string());
        map.insert("ru".to_string(), "привет".to_string());
        let ls = LocalizedString::Localized(map);
        assert_eq!(ls.resolve("ru"), "привет");
        assert_eq!(ls.resolve("en"), "hello");
        // An unknown language still resolves to some available translation.
        assert!(!ls.resolve("fr").is_empty());
    }

    #[test]
    fn localized_string_resolve_prefix() {
        let mut map = HashMap::new();
        map.insert("zh-Hans".to_string(), "你好".to_string());
        map.insert("en".to_string(), "hello".to_string());
        let ls = LocalizedString::Localized(map);
        assert_eq!(ls.resolve("zh"), "你好");
    }

    #[test]
    fn localized_string_get() {
        let ls = LocalizedString::new("en", "hello");
        assert_eq!(ls.get("en"), Some("hello"));
        assert_eq!(ls.get("ru"), None);
    }

    #[test]
    fn localized_string_from_vec() {
        let v = vec![("en".to_string(), "hi".to_string())];
        let ls = LocalizedString::from(v);
        assert_eq!(ls.resolve("en"), "hi");
    }

    #[test]
    fn metadata_insert_and_get() {
        let mut m = Metadata::new();
        m.insert("std:read-only", true);
        assert_eq!(m.get("std:read-only"), Some(&json!(true)));
        assert_eq!(m.get_as::<bool>("std:read-only"), Some(true));
    }

    #[test]
    fn metadata_to_json_empty() {
        let json: serde_json::Value = Metadata::new().into();
        assert_eq!(json, json!({}));
    }

    #[test]
    fn metadata_to_json_with_values() {
        let mut m = Metadata::new();
        m.insert("std:read-only", true);
        let json: serde_json::Value = m.into();
        assert_eq!(json["std:read-only"], json!(true));
    }

    #[test]
    fn metadata_from_vec() {
        let v = vec![("key".to_string(), cbor::to_cbor(&42u32))];
        let m = Metadata::from(v);
        assert_eq!(m.get("key"), Some(&json!(42)));
        assert_eq!(m.get_as::<u32>("key"), Some(42));
    }

    #[test]
    fn capabilities_cbor_roundtrip() {
        let mut info = ComponentInfo::new("test", "0.1.0", "test component");
        info.std.capabilities.http = Some(HttpCap::default());
        info.std.capabilities.filesystem = Some(FilesystemCap {
            mount_root: Some("/data".to_string()),
            ..Default::default()
        });

        let decoded = cbor_roundtrip(&info);
        assert!(decoded.std.capabilities.http.is_some());
        assert!(decoded.std.capabilities.filesystem.is_some());
        assert!(decoded.std.capabilities.sockets.is_none());
        assert_eq!(decoded.std.capabilities.fs_mount_root(), Some("/data"));
    }

    #[test]
    fn capabilities_empty_roundtrip() {
        let info = ComponentInfo::new("test", "0.1.0", "test");

        let decoded = cbor_roundtrip(&info);
        assert!(decoded.std.capabilities.is_empty());
    }

    #[test]
    fn capabilities_fs_no_params_roundtrip() {
        let mut info = ComponentInfo::new("test", "0.1.0", "test");
        info.std.capabilities.filesystem = Some(FilesystemCap::default());

        let decoded = cbor_roundtrip(&info);
        assert!(decoded.std.capabilities.filesystem.is_some());
        assert_eq!(decoded.std.capabilities.fs_mount_root(), None);
    }

    #[test]
    fn capabilities_unknown_preserved() {
        let mut info = ComponentInfo::new("test", "0.1.0", "test");
        info.std
            .capabilities
            .other
            .insert("acme:gpu".to_string(), json!({"cores": 8}));

        let decoded = cbor_roundtrip(&info);
        assert!(decoded.std.capabilities.has("acme:gpu"));
        assert_eq!(decoded.std.capabilities.other["acme:gpu"]["cores"], 8);
    }

    #[test]
    fn filesystem_cap_with_allow_roundtrips() {
        let toml_input = r#"
[std.capabilities."wasi:filesystem"]
description = "test"

[[std.capabilities."wasi:filesystem".allow]]
path = "/etc/**"
mode = "ro"

[[std.capabilities."wasi:filesystem".allow]]
path = "/tmp/**"
mode = "rw"
"#;
        let caps = parse_capabilities(toml_input).expect("parses");
        let fs = caps.filesystem.expect("fs declared");
        assert_eq!(fs.allow.len(), 2);
        assert_eq!(fs.allow[0].path, "/etc/**");
        assert!(matches!(fs.allow[0].mode, FsMode::Ro));
        assert_eq!(fs.allow[1].path, "/tmp/**");
        assert!(matches!(fs.allow[1].mode, FsMode::Rw));
    }

    #[test]
    fn filesystem_cap_requires_path_and_mode_on_each_entry() {
        let bad = r#"
[[std.capabilities."wasi:filesystem".allow]]
path = "/tmp/**"
"#;
        assert!(
            parse_capabilities(bad).is_err(),
            "missing mode must fail"
        );
    }

    #[test]
    fn http_cap_with_allow_roundtrips() {
        let toml_input = r#"
[std.capabilities."wasi:http"]
description = "Calls OpenAI + GitHub"

[[std.capabilities."wasi:http".allow]]
host = "api.openai.com"
scheme = "https"
methods = ["GET", "POST"]

[[std.capabilities."wasi:http".allow]]
host = "*.github.com"
scheme = "https"
"#;
        let caps = parse_capabilities(toml_input).expect("parses");
        let http = caps.http.expect("http declared");
        assert_eq!(http.allow.len(), 2);
        assert_eq!(http.allow[0].host, "api.openai.com");
        assert_eq!(http.allow[0].scheme.as_deref(), Some("https"));
        assert_eq!(
            http.allow[0].methods.as_deref(),
            Some(&["GET".to_string(), "POST".to_string()][..])
        );
        assert_eq!(http.allow[1].host, "*.github.com");
    }

    #[test]
    fn http_cap_requires_host_on_each_entry() {
        let bad = r#"
[[std.capabilities."wasi:http".allow]]
scheme = "https"
"#;
        assert!(
            parse_capabilities(bad).is_err(),
            "missing host must fail"
        );
    }

    #[test]
    fn http_cap_wildcard_host() {
        let toml_input = r#"
[[std.capabilities."wasi:http".allow]]
host = "*"
"#;
        let caps = parse_capabilities(toml_input).expect("parses");
        assert_eq!(caps.http.unwrap().allow[0].host, "*");
    }

    #[test]
    fn sockets_allow_round_trip() {
        let toml_input = r#"
[std.capabilities."wasi:sockets"]
allow = [
    { host = "vnc.example.com", ports = [5900], protocols = ["tcp"] },
    { cidr = "10.0.0.0/8", ports = [80, 443] },
]
"#;
        let caps = parse_capabilities(toml_input).unwrap();
        let sockets = caps.sockets.expect("sockets declared");
        assert_eq!(sockets.allow.len(), 2);

        let a = &sockets.allow[0];
        assert_eq!(a.host.as_deref(), Some("vnc.example.com"));
        assert_eq!(a.cidr, None);
        assert_eq!(a.ports, vec![5900]);
        assert_eq!(a.protocols, vec![SocketProtocol::Tcp]);

        // The second entry omits `protocols`, so the serde default applies.
        let b = &sockets.allow[1];
        assert_eq!(b.host, None);
        assert_eq!(b.cidr.as_deref(), Some("10.0.0.0/8"));
        assert_eq!(b.ports, vec![80, 443]);
        assert_eq!(b.protocols, vec![SocketProtocol::Tcp, SocketProtocol::Udp]);
    }

    #[test]
    fn sockets_cap_has_string() {
        let mut c = Capabilities::default();
        assert!(!c.has(crate::constants::CAP_SOCKETS));
        c.sockets = Some(SocketsCap::default());
        assert!(c.has(crate::constants::CAP_SOCKETS));
    }

    #[test]
    fn sockets_allow_default_protocols_not_emitted() {
        let toml_input = r#"
[[allow]]
host = "vnc.example.com"
ports = [5900]
"#;
        // This fixture has no `[std]` wrapper, so it needs its own shape.
        #[derive(serde::Serialize, serde::Deserialize)]
        struct W {
            allow: Vec<SocketsAllow>,
        }
        let w: W = toml::from_str(toml_input).unwrap();
        assert_eq!(
            w.allow[0].protocols,
            vec![SocketProtocol::Tcp, SocketProtocol::Udp]
        );

        let re = toml::to_string(&w).unwrap();
        assert!(
            !re.contains("protocols"),
            "default protocols leaked into re-serialized output: {re}"
        );

        // Reading the trimmed output back must restore the default.
        let w2: W = toml::from_str(&re).unwrap();
        assert_eq!(
            w2.allow[0].protocols,
            vec![SocketProtocol::Tcp, SocketProtocol::Udp]
        );
    }
}