// hehe_core/capability/mod.rs

use serde::{Deserialize, Serialize};
use std::collections::{hash_set, HashSet};

4#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum Capability {
7    TextInput,
8    ImageInput,
9    AudioInput,
10    VideoInput,
11    FileInput,
12    TextOutput,
13    ImageOutput,
14    AudioOutput,
15    ToolUse,
16    Streaming,
17    JsonMode,
18    SystemPrompt,
19    MultiTurn,
20    CodeExecution,
21    WebBrowsing,
22    FunctionCalling,
23    Vision,
24    Custom(String),
25}
26
27#[derive(Clone, Debug, Default, Serialize, Deserialize)]
28pub struct Capabilities {
29    #[serde(default)]
30    inner: HashSet<Capability>,
31}
32
33impl Capabilities {
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    pub fn with(mut self, cap: Capability) -> Self {
39        self.inner.insert(cap);
40        self
41    }
42
43    pub fn add(&mut self, cap: Capability) -> bool {
44        self.inner.insert(cap)
45    }
46
47    pub fn remove(&mut self, cap: &Capability) -> bool {
48        self.inner.remove(cap)
49    }
50
51    pub fn has(&self, cap: &Capability) -> bool {
52        self.inner.contains(cap)
53    }
54
55    pub fn has_all(&self, caps: &[Capability]) -> bool {
56        caps.iter().all(|c| self.has(c))
57    }
58
59    pub fn has_any(&self, caps: &[Capability]) -> bool {
60        caps.iter().any(|c| self.has(c))
61    }
62
63    pub fn iter(&self) -> impl Iterator<Item = &Capability> {
64        self.inner.iter()
65    }
66
67    pub fn len(&self) -> usize {
68        self.inner.len()
69    }
70
71    pub fn is_empty(&self) -> bool {
72        self.inner.is_empty()
73    }
74
75    pub fn text_basic() -> Self {
76        Self::new()
77            .with(Capability::TextInput)
78            .with(Capability::TextOutput)
79            .with(Capability::Streaming)
80            .with(Capability::SystemPrompt)
81            .with(Capability::MultiTurn)
82    }
83
84    pub fn vision() -> Self {
85        Self::text_basic()
86            .with(Capability::ImageInput)
87            .with(Capability::Vision)
88    }
89
90    pub fn multimodal() -> Self {
91        Self::vision()
92            .with(Capability::AudioInput)
93            .with(Capability::FileInput)
94    }
95
96    pub fn tool_capable() -> Self {
97        Self::text_basic()
98            .with(Capability::ToolUse)
99            .with(Capability::FunctionCalling)
100    }
101
102    pub fn full_agent() -> Self {
103        Self::multimodal()
104            .with(Capability::ToolUse)
105            .with(Capability::FunctionCalling)
106            .with(Capability::JsonMode)
107    }
108
109    pub fn merge(&mut self, other: &Capabilities) {
110        self.inner.extend(other.inner.iter().cloned());
111    }
112
113    pub fn intersection(&self, other: &Capabilities) -> Capabilities {
114        Capabilities {
115            inner: self.inner.intersection(&other.inner).cloned().collect(),
116        }
117    }
118}
119
120pub trait CapabilityProvider {
121    fn capabilities(&self) -> &Capabilities;
122
123    fn supports(&self, cap: &Capability) -> bool {
124        self.capabilities().has(cap)
125    }
126
127    fn supports_all(&self, caps: &[Capability]) -> bool {
128        self.capabilities().has_all(caps)
129    }
130
131    fn supports_any(&self, caps: &[Capability]) -> bool {
132        self.capabilities().has_any(caps)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider used to exercise the trait's default methods.
    struct FakeModel(Capabilities);

    impl CapabilityProvider for FakeModel {
        fn capabilities(&self) -> &Capabilities {
            &self.0
        }
    }

    #[test]
    fn test_capabilities() {
        let caps = Capabilities::text_basic();
        assert!(caps.has(&Capability::TextInput));
        assert!(caps.has(&Capability::TextOutput));
        assert!(!caps.has(&Capability::ImageInput));
    }

    #[test]
    fn test_full_agent() {
        let caps = Capabilities::full_agent();
        assert!(caps.has(&Capability::ToolUse));
        assert!(caps.has(&Capability::ImageInput));
        assert!(caps.has(&Capability::Streaming));
    }

    #[test]
    fn presets_build_on_each_other() {
        assert_eq!(Capabilities::text_basic().len(), 5);

        let vision = Capabilities::vision();
        assert!(vision.has_all(&[Capability::ImageInput, Capability::Vision]));
        assert_eq!(vision.len(), 7);

        let multimodal = Capabilities::multimodal();
        assert!(multimodal.has_all(&[Capability::AudioInput, Capability::FileInput]));
        assert_eq!(multimodal.len(), 9);

        let tools = Capabilities::tool_capable();
        assert!(tools.has_all(&[Capability::ToolUse, Capability::FunctionCalling]));
        assert!(!tools.has(&Capability::ImageInput));
        assert_eq!(tools.len(), 7);

        let agent = Capabilities::full_agent();
        assert!(agent.has(&Capability::JsonMode));
        assert_eq!(agent.len(), 12);
    }

    #[test]
    fn add_and_remove_report_membership_changes() {
        let custom = || Capability::Custom("rerank".to_string());
        let mut caps = Capabilities::new();
        assert!(caps.is_empty());

        assert!(caps.add(custom()));
        assert!(!caps.add(custom()), "second insert of the same value is a no-op");
        assert!(!caps.has(&Capability::Custom("other".to_string())));
        assert_eq!(caps.len(), 1);

        assert!(caps.remove(&custom()));
        assert!(!caps.remove(&custom()), "removing an absent value reports false");
        assert!(caps.is_empty());
    }

    #[test]
    fn has_all_and_has_any_edge_cases() {
        let caps = Capabilities::text_basic();
        assert!(caps.has_all(&[]));
        assert!(!caps.has_any(&[]));
        assert!(caps.has_any(&[Capability::ImageInput, Capability::TextInput]));
        assert!(!caps.has_all(&[Capability::ImageInput, Capability::TextInput]));
    }

    #[test]
    fn merge_and_intersection() {
        let mut a = Capabilities::text_basic();
        let b = Capabilities::new()
            .with(Capability::ToolUse)
            .with(Capability::TextInput);

        let common = a.intersection(&b);
        assert_eq!(common.len(), 1);
        assert!(common.has(&Capability::TextInput));

        a.merge(&b);
        assert_eq!(a.len(), 6);
        assert!(a.has(&Capability::ToolUse));
        assert_eq!(a.iter().count(), a.len());
    }

    #[test]
    fn provider_defaults_delegate_to_capabilities() {
        let model = FakeModel(Capabilities::tool_capable());
        assert!(model.supports(&Capability::ToolUse));
        assert!(!model.supports(&Capability::Vision));
        assert!(model.supports_all(&[Capability::TextInput, Capability::FunctionCalling]));
        assert!(!model.supports_all(&[Capability::TextInput, Capability::Vision]));
        assert!(model.supports_any(&[Capability::Vision, Capability::MultiTurn]));
        assert!(!model.supports_any(&[Capability::Vision, Capability::AudioOutput]));
    }
}