Skip to main content

wire/
capabilities.rs

// SPDX-License-Identifier: Apache-2.0
//! Capability negotiation for the Heddle protocol.
//!
//! Capabilities allow clients and servers to negotiate features and
//! protocol extensions during the handshake phase.
6
7use std::collections::HashSet;
8
9use serde::{Deserialize, Serialize};
10
// Capability flag strings. Negotiation intersects flag sets by exact string
// equality, so these values must match byte-for-byte between peers.

/// Flag for chunked transfer; toggled by `Capabilities::with_chunked_transfer`.
pub const CAPABILITY_CHUNKED_TRANSFER: &str = "chunked-transfer";
/// Flag for resumable transfer; toggled by `Capabilities::with_resumable_transfer`.
pub const CAPABILITY_RESUMABLE_TRANSFER: &str = "resumable-transfer";
/// Flag for pack transfer; toggled by `Capabilities::with_pack_transfer`.
pub const CAPABILITY_PACK_TRANSFER: &str = "pack-transfer";
/// Flag for partial fetch; toggled by `Capabilities::with_partial_fetch`.
pub const CAPABILITY_PARTIAL_FETCH: &str = "partial-fetch";
15
16/// Set of protocol capabilities.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Capabilities {
19    /// Protocol version.
20    pub version: u32,
21    /// Supported capability flags.
22    pub flags: HashSet<String>,
23    /// Maximum object size supported (in bytes).
24    pub max_object_size: u64,
25    /// Preferred chunk size for streaming (in bytes).
26    pub chunk_size: u32,
27    /// Whether delta compression is supported.
28    pub delta_compression: bool,
29    /// Supported compression algorithms.
30    pub compression: Vec<String>,
31}
32
33impl Capabilities {
34    /// Create default capabilities.
35    pub fn new(version: u32) -> Self {
36        let mut flags = HashSet::new();
37        flags.insert("baseline".to_string());
38
39        Self {
40            version,
41            flags,
42            max_object_size: 128 * 1024 * 1024,
43            chunk_size: 64 * 1024,
44            delta_compression: true,
45            compression: vec!["none".to_string()],
46        }
47    }
48
49    pub fn with_flag(mut self, flag: impl Into<String>) -> Self {
50        self.flags.insert(flag.into());
51        self
52    }
53
54    pub fn with_chunked_transfer(mut self, enabled: bool) -> Self {
55        if enabled {
56            self.flags.insert(CAPABILITY_CHUNKED_TRANSFER.to_string());
57        } else {
58            self.flags.remove(CAPABILITY_CHUNKED_TRANSFER);
59        }
60        self
61    }
62
63    pub fn with_resumable_transfer(mut self, enabled: bool) -> Self {
64        if enabled {
65            self.flags.insert(CAPABILITY_RESUMABLE_TRANSFER.to_string());
66        } else {
67            self.flags.remove(CAPABILITY_RESUMABLE_TRANSFER);
68        }
69        self
70    }
71
72    pub fn with_pack_transfer(mut self, enabled: bool) -> Self {
73        if enabled {
74            self.flags.insert(CAPABILITY_PACK_TRANSFER.to_string());
75        } else {
76            self.flags.remove(CAPABILITY_PACK_TRANSFER);
77        }
78        self
79    }
80
81    pub fn with_partial_fetch(mut self, enabled: bool) -> Self {
82        if enabled {
83            self.flags.insert(CAPABILITY_PARTIAL_FETCH.to_string());
84        } else {
85            self.flags.remove(CAPABILITY_PARTIAL_FETCH);
86        }
87        self
88    }
89
90    pub fn has_flag(&self, flag: &str) -> bool {
91        self.flags.contains(flag)
92    }
93
94    pub fn supports_chunked_transfer(&self) -> bool {
95        self.has_flag(CAPABILITY_CHUNKED_TRANSFER)
96    }
97
98    pub fn supports_resumable_transfer(&self) -> bool {
99        self.has_flag(CAPABILITY_RESUMABLE_TRANSFER)
100    }
101
102    pub fn supports_pack_transfer(&self) -> bool {
103        self.has_flag(CAPABILITY_PACK_TRANSFER)
104    }
105
106    pub fn supports_partial_fetch(&self) -> bool {
107        self.has_flag(CAPABILITY_PARTIAL_FETCH)
108    }
109
110    pub fn with_delta(mut self, enabled: bool) -> Self {
111        self.delta_compression = enabled;
112        self
113    }
114
115    pub fn with_compression(mut self, algo: impl Into<String>) -> Self {
116        let algo = algo.into();
117        if !self.compression.contains(&algo) {
118            self.compression.push(algo);
119        }
120        self
121    }
122
123    pub fn with_chunk_size(mut self, size: u32) -> Self {
124        self.chunk_size = size;
125        self
126    }
127
128    pub fn with_max_object_size(mut self, size: u64) -> Self {
129        self.max_object_size = size;
130        self
131    }
132
133    pub fn negotiate(&self, other: &Capabilities) -> Capabilities {
134        let version = self.version.min(other.version);
135        let flags: HashSet<_> = self.flags.intersection(&other.flags).cloned().collect();
136        let max_object_size = self.max_object_size.min(other.max_object_size);
137        let chunk_size = self.chunk_size.min(other.chunk_size);
138        let delta_compression = self.delta_compression && other.delta_compression;
139        let compression: Vec<_> = self
140            .compression
141            .iter()
142            .filter(|candidate| other.compression.contains(*candidate))
143            .cloned()
144            .collect();
145
146        Capabilities {
147            version,
148            flags,
149            max_object_size,
150            chunk_size,
151            delta_compression,
152            compression,
153        }
154    }
155
156    pub fn validate(&self) -> Result<(), String> {
157        if !self.has_flag("baseline") {
158            return Err("missing baseline capability".to_string());
159        }
160        if self.version == 0 {
161            return Err("invalid protocol version".to_string());
162        }
163        if self.chunk_size == 0 {
164            return Err("invalid chunk size".to_string());
165        }
166        if self.max_object_size == 0 {
167            return Err("invalid max object size".to_string());
168        }
169        if self.compression.is_empty() {
170            return Err("no common compression algorithms".to_string());
171        }
172        Ok(())
173    }
174
175    pub fn validate_with_required(&self, required_flags: &[&str]) -> Result<(), String> {
176        self.validate()?;
177
178        for flag in required_flags {
179            if !self.has_flag(flag) {
180                return Err(format!("missing required capability: {flag}"));
181            }
182        }
183
184        Ok(())
185    }
186}
187
188impl Default for Capabilities {
189    fn default() -> Self {
190        Self::new(1)
191    }
192}
193
194/// A set of capabilities that have been negotiated.
195#[derive(Debug, Clone)]
196pub struct CapabilitySet {
197    pub caps: Capabilities,
198    pub valid: bool,
199    pub error: Option<String>,
200}
201
202impl CapabilitySet {
203    pub fn new(client: &Capabilities, server: &Capabilities) -> Self {
204        let caps = client.negotiate(server);
205
206        match caps.validate() {
207            Ok(()) => Self {
208                caps,
209                valid: true,
210                error: None,
211            },
212            Err(error) => Self {
213                caps,
214                valid: false,
215                error: Some(error),
216            },
217        }
218    }
219
220    pub fn has_flag(&self, flag: &str) -> bool {
221        self.valid && self.caps.has_flag(flag)
222    }
223
224    pub fn delta_enabled(&self) -> bool {
225        self.valid && self.caps.delta_compression
226    }
227
228    pub fn chunk_size(&self) -> usize {
229        self.caps.chunk_size as usize
230    }
231
232    pub fn max_object_size(&self) -> usize {
233        self.caps.max_object_size.min(usize::MAX as u64) as usize
234    }
235
236    pub fn chunked_transfer_enabled(&self) -> bool {
237        self.has_flag(CAPABILITY_CHUNKED_TRANSFER)
238    }
239
240    pub fn resumable_transfer_enabled(&self) -> bool {
241        self.has_flag(CAPABILITY_RESUMABLE_TRANSFER)
242    }
243
244    pub fn pack_transfer_enabled(&self) -> bool {
245        self.has_flag(CAPABILITY_PACK_TRANSFER)
246    }
247
248    pub fn partial_fetch_enabled(&self) -> bool {
249        self.has_flag(CAPABILITY_PARTIAL_FETCH)
250    }
251}
252
#[cfg(test)]
mod tests {
    // Unit tests for capability construction, negotiation and validation.
    use super::*;

    #[test]
    fn test_capabilities_default() {
        let caps = Capabilities::default();
        assert!(caps.has_flag("baseline"));
        assert!(caps.delta_compression);
        assert_eq!(caps.version, 1);
    }

    #[test]
    fn test_capabilities_negotiate() {
        let client = Capabilities::new(1)
            .with_flag("fast-import")
            .with_delta(true);
        let server = Capabilities::new(1)
            .with_flag("fast-import")
            .with_flag("server-side-merging")
            .with_delta(true);

        let negotiated = client.negotiate(&server);

        assert!(negotiated.has_flag("baseline"));
        assert!(negotiated.has_flag("fast-import"));
        assert!(!negotiated.has_flag("server-side-merging"));
        assert!(negotiated.delta_compression);
    }

    #[test]
    fn test_capabilities_version_negotiate() {
        let client = Capabilities::new(1);
        let server = Capabilities::new(2);
        let negotiated = client.negotiate(&server);
        assert_eq!(negotiated.version, 1);
    }

    #[test]
    fn test_negotiate_takes_minimum_limits() {
        let client = Capabilities::new(1)
            .with_chunk_size(4096)
            .with_max_object_size(1024);
        let server = Capabilities::new(1)
            .with_chunk_size(8192)
            .with_max_object_size(512);

        let negotiated = client.negotiate(&server);

        assert_eq!(negotiated.chunk_size, 4096);
        assert_eq!(negotiated.max_object_size, 512);
    }

    #[test]
    fn test_negotiate_delta_requires_both_sides() {
        let client = Capabilities::new(1).with_delta(true);
        let server = Capabilities::new(1).with_delta(false);
        assert!(!client.negotiate(&server).delta_compression);
        assert!(!server.negotiate(&client).delta_compression);
    }

    #[test]
    fn test_negotiate_compression_keeps_receiver_order() {
        let client = Capabilities::new(1)
            .with_compression("zstd")
            .with_compression("gzip");
        let server = Capabilities::new(1)
            .with_compression("gzip")
            .with_compression("zstd");

        assert_eq!(
            client.negotiate(&server).compression,
            vec!["none", "zstd", "gzip"]
        );
        assert_eq!(
            server.negotiate(&client).compression,
            vec!["none", "gzip", "zstd"]
        );
    }

    #[test]
    fn test_with_compression_deduplicates() {
        let caps = Capabilities::new(1)
            .with_compression("none")
            .with_compression("zstd")
            .with_compression("zstd");
        assert_eq!(caps.compression, vec!["none", "zstd"]);
    }

    #[test]
    fn test_capability_set() {
        let client = Capabilities::new(1).with_flag("test-feature");
        let server = Capabilities::new(1).with_flag("test-feature");
        let set = CapabilitySet::new(&client, &server);
        assert!(set.valid);
        assert!(set.has_flag("test-feature"));
        assert!(set.has_flag("baseline"));
    }

    #[test]
    fn test_capability_set_invalid() {
        let mut client = Capabilities::new(1);
        client.flags.clear();
        let server = Capabilities::new(1);
        let set = CapabilitySet::new(&client, &server);
        assert!(!set.valid);
        assert!(set.error.is_some());
    }

    #[test]
    fn test_capability_set_without_common_compression_is_invalid() {
        let mut client = Capabilities::new(1)
            .with_chunked_transfer(true)
            .with_delta(true);
        client.compression = vec!["zstd".to_string()];
        let server = Capabilities::new(1)
            .with_chunked_transfer(true)
            .with_delta(true);

        let set = CapabilitySet::new(&client, &server);

        assert!(!set.valid);
        assert_eq!(set.error.as_deref(), Some("no common compression algorithms"));
        // The raw negotiated caps still carry the flag, but the gated helpers
        // must refuse to report anything for an invalid set.
        assert!(set.caps.has_flag(CAPABILITY_CHUNKED_TRANSFER));
        assert!(!set.chunked_transfer_enabled());
        assert!(!set.delta_enabled());
        assert!(!set.has_flag("baseline"));
    }

    #[test]
    fn test_capability_set_transport_helpers() {
        let client = Capabilities::new(1)
            .with_pack_transfer(true)
            .with_partial_fetch(true)
            .with_resumable_transfer(true);
        let server = Capabilities::new(1)
            .with_pack_transfer(true)
            .with_partial_fetch(true);

        let set = CapabilitySet::new(&client, &server);

        assert!(set.valid);
        assert!(set.pack_transfer_enabled());
        assert!(set.partial_fetch_enabled());
        assert!(!set.resumable_transfer_enabled());
        assert!(!set.chunked_transfer_enabled());
        assert!(set.delta_enabled());
        assert_eq!(set.chunk_size(), 64 * 1024);
        assert_eq!(set.max_object_size(), 128 * 1024 * 1024);
    }

    #[test]
    fn test_capabilities_validate_required_flags() {
        let caps = Capabilities::new(1).with_flag("refs");
        assert!(caps.validate_with_required(&["refs"]).is_ok());
        assert!(caps.validate_with_required(&["objects"]).is_err());
    }

    #[test]
    fn test_validate_with_required_error_order() {
        let caps = Capabilities::new(1).with_flag("refs");
        assert_eq!(caps.validate_with_required(&[]), Ok(()));
        assert_eq!(
            caps.validate_with_required(&["refs", "objects", "packs"]),
            Err("missing required capability: objects".to_string())
        );
        // Base validation failures take precedence over missing flags.
        assert_eq!(
            Capabilities::new(0).validate_with_required(&["refs"]),
            Err("invalid protocol version".to_string())
        );
    }

    #[test]
    fn test_capabilities_validate_limits() {
        let caps = Capabilities::new(1).with_chunk_size(0);
        assert!(caps.validate().is_err());
    }

    #[test]
    fn test_validate_error_messages() {
        assert_eq!(Capabilities::new(1).validate(), Ok(()));
        assert_eq!(
            Capabilities::new(0).validate(),
            Err("invalid protocol version".to_string())
        );
        assert_eq!(
            Capabilities::new(1).with_chunk_size(0).validate(),
            Err("invalid chunk size".to_string())
        );
        assert_eq!(
            Capabilities::new(1).with_max_object_size(0).validate(),
            Err("invalid max object size".to_string())
        );

        let mut no_compression = Capabilities::new(1);
        no_compression.compression.clear();
        assert_eq!(
            no_compression.validate(),
            Err("no common compression algorithms".to_string())
        );

        let mut no_baseline = Capabilities::new(1);
        no_baseline.flags.clear();
        assert_eq!(
            no_baseline.validate(),
            Err("missing baseline capability".to_string())
        );
    }

    #[test]
    fn test_transport_capability_helpers_round_trip() {
        let caps = Capabilities::new(1)
            .with_chunked_transfer(true)
            .with_resumable_transfer(true)
            .with_pack_transfer(true)
            .with_partial_fetch(true);

        assert!(caps.supports_chunked_transfer());
        assert!(caps.supports_resumable_transfer());
        assert!(caps.supports_pack_transfer());
        assert!(caps.supports_partial_fetch());
    }

    #[test]
    fn test_transport_capability_helpers_toggle_off() {
        let caps = Capabilities::new(1)
            .with_chunked_transfer(true)
            .with_chunked_transfer(false)
            .with_resumable_transfer(true)
            .with_resumable_transfer(false)
            .with_pack_transfer(true)
            .with_pack_transfer(false)
            .with_partial_fetch(true)
            .with_partial_fetch(false);

        assert!(!caps.supports_chunked_transfer());
        assert!(!caps.supports_resumable_transfer());
        assert!(!caps.supports_pack_transfer());
        assert!(!caps.supports_partial_fetch());
        // Toggling transport flags must not disturb the baseline flag.
        assert!(caps.has_flag("baseline"));
    }
}