hehe_core/capability/
mod.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashSet;
3
4#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum Capability {
7 TextInput,
8 ImageInput,
9 AudioInput,
10 VideoInput,
11 FileInput,
12 TextOutput,
13 ImageOutput,
14 AudioOutput,
15 ToolUse,
16 Streaming,
17 JsonMode,
18 SystemPrompt,
19 MultiTurn,
20 CodeExecution,
21 WebBrowsing,
22 FunctionCalling,
23 Vision,
24 Custom(String),
25}
26
27#[derive(Clone, Debug, Default, Serialize, Deserialize)]
28pub struct Capabilities {
29 #[serde(default)]
30 inner: HashSet<Capability>,
31}
32
33impl Capabilities {
34 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn with(mut self, cap: Capability) -> Self {
39 self.inner.insert(cap);
40 self
41 }
42
43 pub fn add(&mut self, cap: Capability) -> bool {
44 self.inner.insert(cap)
45 }
46
47 pub fn remove(&mut self, cap: &Capability) -> bool {
48 self.inner.remove(cap)
49 }
50
51 pub fn has(&self, cap: &Capability) -> bool {
52 self.inner.contains(cap)
53 }
54
55 pub fn has_all(&self, caps: &[Capability]) -> bool {
56 caps.iter().all(|c| self.has(c))
57 }
58
59 pub fn has_any(&self, caps: &[Capability]) -> bool {
60 caps.iter().any(|c| self.has(c))
61 }
62
63 pub fn iter(&self) -> impl Iterator<Item = &Capability> {
64 self.inner.iter()
65 }
66
67 pub fn len(&self) -> usize {
68 self.inner.len()
69 }
70
71 pub fn is_empty(&self) -> bool {
72 self.inner.is_empty()
73 }
74
75 pub fn text_basic() -> Self {
76 Self::new()
77 .with(Capability::TextInput)
78 .with(Capability::TextOutput)
79 .with(Capability::Streaming)
80 .with(Capability::SystemPrompt)
81 .with(Capability::MultiTurn)
82 }
83
84 pub fn vision() -> Self {
85 Self::text_basic()
86 .with(Capability::ImageInput)
87 .with(Capability::Vision)
88 }
89
90 pub fn multimodal() -> Self {
91 Self::vision()
92 .with(Capability::AudioInput)
93 .with(Capability::FileInput)
94 }
95
96 pub fn tool_capable() -> Self {
97 Self::text_basic()
98 .with(Capability::ToolUse)
99 .with(Capability::FunctionCalling)
100 }
101
102 pub fn full_agent() -> Self {
103 Self::multimodal()
104 .with(Capability::ToolUse)
105 .with(Capability::FunctionCalling)
106 .with(Capability::JsonMode)
107 }
108
109 pub fn merge(&mut self, other: &Capabilities) {
110 self.inner.extend(other.inner.iter().cloned());
111 }
112
113 pub fn intersection(&self, other: &Capabilities) -> Capabilities {
114 Capabilities {
115 inner: self.inner.intersection(&other.inner).cloned().collect(),
116 }
117 }
118}
119
120pub trait CapabilityProvider {
121 fn capabilities(&self) -> &Capabilities;
122
123 fn supports(&self, cap: &Capability) -> bool {
124 self.capabilities().has(cap)
125 }
126
127 fn supports_all(&self, caps: &[Capability]) -> bool {
128 self.capabilities().has_all(caps)
129 }
130
131 fn supports_any(&self, caps: &[Capability]) -> bool {
132 self.capabilities().has_any(caps)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capabilities() {
        let caps = Capabilities::text_basic();
        assert!(caps.has(&Capability::TextInput));
        assert!(caps.has(&Capability::TextOutput));
        assert!(!caps.has(&Capability::ImageInput));
        assert_eq!(caps.len(), 5);
    }

    #[test]
    fn test_full_agent() {
        let caps = Capabilities::full_agent();
        assert!(caps.has(&Capability::ToolUse));
        assert!(caps.has(&Capability::ImageInput));
        assert!(caps.has(&Capability::Streaming));
        // The agent preset deliberately stops short of these.
        assert!(!caps.has_any(&[Capability::CodeExecution, Capability::WebBrowsing]));
        assert_eq!(caps.len(), 12);
    }

    #[test]
    fn test_presets_are_nested() {
        let basic = Capabilities::text_basic();
        let vision = Capabilities::vision();
        let multimodal = Capabilities::multimodal();
        let tools = Capabilities::tool_capable();
        let agent = Capabilities::full_agent();

        assert!(basic.iter().all(|c| vision.has(c)));
        assert!(basic.iter().all(|c| tools.has(c)));
        assert!(vision.iter().all(|c| multimodal.has(c)));
        assert!(multimodal.iter().all(|c| agent.has(c)));
        assert!(tools.iter().all(|c| agent.has(c)));
    }

    #[test]
    fn test_add_remove_report_changes() {
        let mut caps = Capabilities::new();
        assert!(caps.is_empty());

        let custom = Capability::Custom("sandbox".to_string());
        assert!(caps.add(custom.clone()));
        assert!(!caps.add(custom.clone()), "second add must report no change");
        assert_eq!(caps.len(), 1);
        assert!(caps.has(&custom));
        assert!(!caps.has(&Capability::Custom("other".to_string())));

        assert!(caps.remove(&custom));
        assert!(!caps.remove(&custom), "second remove must report no change");
        assert!(caps.is_empty());
    }

    #[test]
    fn test_has_all_and_has_any_on_empty_input() {
        let caps = Capabilities::text_basic();
        assert!(caps.has_all(&[]));
        assert!(!caps.has_any(&[]));

        let empty = Capabilities::new();
        assert!(empty.has_all(&[]));
        assert!(!empty.has_any(&[Capability::TextInput]));
    }

    #[test]
    fn test_merge_and_intersection() {
        let tools = Capabilities::tool_capable();
        let vision = Capabilities::vision();

        let common = tools.intersection(&vision);
        assert_eq!(common.len(), Capabilities::text_basic().len());
        assert!(!common.has(&Capability::ToolUse));
        assert!(!common.has(&Capability::Vision));

        let mut merged = tools.clone();
        merged.merge(&vision);
        assert_eq!(merged.len(), 9);
        assert!(merged.has_all(&[Capability::ToolUse, Capability::Vision]));
        // merge must not mutate its argument.
        assert_eq!(vision.len(), 7);
    }
}