1use crate::{capabilities, Event};
2use anyhow::{anyhow, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use uuid::Uuid;
6
/// Protocol version 0. Negotiation never reports a fallback for it, so it acts as the
/// baseline every build can serve (presumably the JSONL transport — confirm with transport code).
pub const PROTOCOL_V0: u32 = 0;

/// Protocol version 1. Preferred over v0 during negotiation; serving it depends on the
/// `protobuf` cargo feature.
pub const PROTOCOL_V1: u32 = 1;

/// The protocol version this crate reports as its current default.
pub const CURRENT_VERSION: u32 = PROTOCOL_V0;

/// Versions the server accepts, ordered from most to least preferred: negotiation picks
/// the first entry here that the client also lists.
pub const SUPPORTED_VERSIONS: &[u32] = &[PROTOCOL_V1, PROTOCOL_V0];
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct NegotiationResult {
18 pub version: u32,
19 pub capabilities: Vec<String>,
20 pub fallback_reason: Option<String>,
21 pub service_info: HashMap<String, String>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ClientInfo {
27 pub supported_versions: Vec<u32>,
28 pub requested_capabilities: Vec<String>,
29 pub client_metadata: HashMap<String, String>,
30}
31
32pub struct ProtocolNegotiator {
34 service_id: Uuid,
35 available_capabilities: Vec<String>,
36 service_metadata: HashMap<String, String>,
37}
38
39impl ProtocolNegotiator {
40 pub fn new(service_id: Uuid) -> Self {
41 let mut service_metadata = HashMap::new();
42 service_metadata.insert(
43 "service_name".to_string(),
44 "claude-code-rs-core".to_string(),
45 );
46 service_metadata.insert("version".to_string(), env!("CARGO_PKG_VERSION").to_string());
47 service_metadata.insert(
48 "build_timestamp".to_string(),
49 std::env::var("BUILD_TIMESTAMP").unwrap_or_else(|_| "unknown".to_string()),
50 );
51 service_metadata.insert(
52 "git_commit".to_string(),
53 std::env::var("GIT_COMMIT").unwrap_or_else(|_| "unknown".to_string()),
54 );
55
56 let available_capabilities = vec![
58 capabilities::SHELL_EXEC.to_string(),
59 capabilities::REPLAY.to_string(),
60 capabilities::TRACING.to_string(),
61 #[cfg(feature = "hooks-quickjs")]
62 capabilities::HOOKS_JS.to_string(),
63 #[cfg(feature = "hooks-rust")]
64 capabilities::HOOKS_RUST.to_string(),
65 #[cfg(feature = "nats")]
66 capabilities::NATS.to_string(),
67 ];
68
69 #[cfg(feature = "protobuf")]
70 service_metadata.insert("protobuf_support".to_string(), "true".to_string());
71
72 Self {
73 service_id,
74 available_capabilities,
75 service_metadata,
76 }
77 }
78
79 pub fn negotiate(&self, client_info: ClientInfo) -> Result<NegotiationResult> {
81 let selected_version = self.select_version(&client_info.supported_versions)?;
83
84 let granted_capabilities = self.filter_capabilities(&client_info.requested_capabilities);
86
87 let fallback_reason =
89 self.check_fallback_conditions(selected_version, &granted_capabilities);
90
91 let mut service_info = self.service_metadata.clone();
93 service_info.insert(
94 "negotiated_version".to_string(),
95 selected_version.to_string(),
96 );
97 service_info.insert(
98 "granted_capabilities".to_string(),
99 granted_capabilities.len().to_string(),
100 );
101
102 Ok(NegotiationResult {
103 version: selected_version,
104 capabilities: granted_capabilities,
105 fallback_reason,
106 service_info,
107 })
108 }
109
110 fn select_version(&self, client_versions: &[u32]) -> Result<u32> {
112 for &server_version in SUPPORTED_VERSIONS {
114 if client_versions.contains(&server_version) {
115 return Ok(server_version);
116 }
117 }
118
119 Err(anyhow!(
120 "No compatible protocol version found. Server supports: {:?}, Client supports: {:?}",
121 SUPPORTED_VERSIONS,
122 client_versions
123 ))
124 }
125
126 fn filter_capabilities(&self, requested: &[String]) -> Vec<String> {
128 requested
129 .iter()
130 .filter(|cap| self.available_capabilities.contains(cap))
131 .cloned()
132 .collect()
133 }
134
135 fn check_fallback_conditions(&self, version: u32, _capabilities: &[String]) -> Option<String> {
137 match version {
138 PROTOCOL_V1 => {
139 #[cfg(not(feature = "protobuf"))]
141 {
142 Some("Protobuf support not compiled in, falling back to JSONL".to_string())
143 }
144
145 #[cfg(feature = "protobuf")]
146 {
147 None
149 }
150 }
151 PROTOCOL_V0 => None, _ => Some(format!("Unsupported version {}, using v0", version)),
153 }
154 }
155
156 pub fn create_ready_event(&self, result: &NegotiationResult) -> Event {
158 Event::Ready {
159 version: result.version,
160 capabilities: result.capabilities.clone(),
161 service_id: self.service_id,
162 }
163 }
164
165 pub fn get_available_capabilities(&self) -> &[String] {
167 &self.available_capabilities
168 }
169
170 pub fn supports_capability(&self, capability: &str) -> bool {
172 self.available_capabilities
173 .contains(&capability.to_string())
174 }
175
176 pub fn get_service_metadata(&self) -> &HashMap<String, String> {
178 &self.service_metadata
179 }
180}
181
182pub struct CapabilityChecker;
184
185impl CapabilityChecker {
186 pub fn check_compatibility(capabilities: &[String]) -> Result<Vec<String>> {
188 let mut warnings = Vec::new();
189
190 let has_js_hooks = capabilities.contains(&capabilities::HOOKS_JS.to_string());
192 let has_rust_hooks = capabilities.contains(&capabilities::HOOKS_RUST.to_string());
193
194 if has_js_hooks && has_rust_hooks {
195 warnings
196 .push("Both JS and Rust hooks enabled - performance may be impacted".to_string());
197 }
198
199 let has_nats = capabilities.contains(&capabilities::NATS.to_string());
201 let has_tracing = capabilities.contains(&capabilities::TRACING.to_string());
202
203 if has_nats && !has_tracing {
204 warnings.push("NATS enabled without tracing - reduced observability".to_string());
205 }
206
207 let has_replay = capabilities.contains(&capabilities::REPLAY.to_string());
209 if has_replay {
210 warnings.push(
211 "Replay enabled - performance overhead for recording all operations".to_string(),
212 );
213 }
214
215 Ok(warnings)
216 }
217
218 pub fn recommend_capabilities(use_case: &str) -> Vec<String> {
220 match use_case {
221 "development" => vec![
222 capabilities::SHELL_EXEC.to_string(),
223 capabilities::HOOKS_JS.to_string(),
224 capabilities::REPLAY.to_string(),
225 capabilities::TRACING.to_string(),
226 ],
227 "production" => vec![
228 capabilities::SHELL_EXEC.to_string(),
229 capabilities::HOOKS_RUST.to_string(),
230 capabilities::TRACING.to_string(),
231 capabilities::NATS.to_string(),
232 ],
233 "testing" => vec![
234 capabilities::SHELL_EXEC.to_string(),
235 capabilities::REPLAY.to_string(),
236 capabilities::TRACING.to_string(),
237 ],
238 "minimal" => vec![capabilities::SHELL_EXEC.to_string()],
239 _ => vec![
240 capabilities::SHELL_EXEC.to_string(),
241 capabilities::TRACING.to_string(),
242 ],
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh negotiator with a random service id.
    fn negotiator() -> ProtocolNegotiator {
        ProtocolNegotiator::new(Uuid::new_v4())
    }

    /// `ClientInfo` with the given versions and capability names and empty metadata.
    fn client(versions: &[u32], caps: &[&str]) -> ClientInfo {
        ClientInfo {
            supported_versions: versions.to_vec(),
            requested_capabilities: caps.iter().map(|cap| cap.to_string()).collect(),
            client_metadata: HashMap::new(),
        }
    }

    /// Whether `list` contains the capability `name`.
    fn has(list: &[String], name: &str) -> bool {
        list.iter().any(|entry| entry == name)
    }

    /// Asserts that `use_case` recommends every capability in `expected`.
    fn assert_recommends(use_case: &str, expected: &[&str]) {
        let recommended = CapabilityChecker::recommend_capabilities(use_case);
        for cap in expected {
            assert!(
                has(&recommended, cap),
                "Use case '{}' should include capability '{}'",
                use_case,
                cap
            );
        }
    }

    #[test]
    fn test_version_negotiation() {
        let result = negotiator()
            .negotiate(client(
                &[0, 1],
                &[capabilities::SHELL_EXEC, capabilities::TRACING],
            ))
            .unwrap();

        assert!(result.version <= 1);
        assert!(has(&result.capabilities, capabilities::SHELL_EXEC));
        assert!(has(&result.capabilities, capabilities::TRACING));
    }

    #[test]
    fn test_incompatible_versions() {
        assert!(negotiator().negotiate(client(&[999], &[])).is_err());
    }

    #[test]
    fn test_capability_filtering() {
        let result = negotiator()
            .negotiate(client(
                &[0],
                &[capabilities::SHELL_EXEC, "non_existent_capability"],
            ))
            .unwrap();

        assert!(has(&result.capabilities, capabilities::SHELL_EXEC));
        assert!(!has(&result.capabilities, "non_existent_capability"));
    }

    #[test]
    fn test_capability_compatibility() {
        let warnings = CapabilityChecker::check_compatibility(&[
            capabilities::HOOKS_JS.to_string(),
            capabilities::HOOKS_RUST.to_string(),
        ])
        .unwrap();

        let first = warnings.first().expect("expected at least one warning");
        assert!(first.contains("JS and Rust hooks"));
    }

    #[test]
    fn test_use_case_recommendations() {
        let dev = CapabilityChecker::recommend_capabilities("development");
        let prod = CapabilityChecker::recommend_capabilities("production");

        assert!(has(&dev, capabilities::REPLAY));
        assert!(has(&dev, capabilities::HOOKS_JS));

        assert!(!has(&prod, capabilities::REPLAY));
        assert!(has(&prod, capabilities::HOOKS_RUST));
        assert!(has(&prod, capabilities::NATS));
    }

    #[test]
    fn test_protocol_constants() {
        assert_eq!(PROTOCOL_V0, 0);
        assert_eq!(PROTOCOL_V1, 1);
        assert_eq!(CURRENT_VERSION, PROTOCOL_V0);

        assert!(SUPPORTED_VERSIONS.contains(&PROTOCOL_V0));
        assert!(SUPPORTED_VERSIONS.contains(&PROTOCOL_V1));
        assert!(SUPPORTED_VERSIONS.len() >= 2);
    }

    #[test]
    fn test_negotiator_methods() {
        let n = negotiator();

        assert!(n.supports_capability(capabilities::SHELL_EXEC));
        assert!(n.supports_capability(capabilities::TRACING));
        assert!(!n.supports_capability("non_existent_capability"));

        let caps = n.get_available_capabilities();
        assert!(!caps.is_empty());
        assert!(has(caps, capabilities::SHELL_EXEC));

        let metadata = n.get_service_metadata();
        assert!(metadata.contains_key("service_name"));
        assert!(metadata.contains_key("version"));
        assert_eq!(
            metadata.get("service_name").map(String::as_str),
            Some("claude-code-rs-core")
        );
    }

    #[test]
    fn test_ready_event_creation() {
        let service_id = Uuid::new_v4();
        let n = ProtocolNegotiator::new(service_id);

        let outcome = NegotiationResult {
            version: PROTOCOL_V0,
            capabilities: vec![capabilities::SHELL_EXEC.to_string()],
            fallback_reason: None,
            service_info: HashMap::new(),
        };

        match n.create_ready_event(&outcome) {
            Event::Ready {
                version,
                capabilities: caps,
                service_id: event_service_id,
            } => {
                assert_eq!(version, PROTOCOL_V0);
                assert_eq!(caps, vec![capabilities::SHELL_EXEC.to_string()]);
                assert_eq!(event_service_id, service_id);
            }
            _ => panic!("Expected Ready event"),
        }
    }

    #[test]
    fn test_fallback_conditions_v0() {
        assert!(negotiator()
            .check_fallback_conditions(PROTOCOL_V0, &[])
            .is_none());
    }

    #[test]
    fn test_fallback_conditions_unsupported_version() {
        let reason = negotiator()
            .check_fallback_conditions(999, &[])
            .expect("unsupported version should yield a fallback reason");
        assert!(reason.contains("Unsupported version 999"));
    }

    #[test]
    fn test_capability_compatibility_nats_without_tracing() {
        let warnings =
            CapabilityChecker::check_compatibility(&[capabilities::NATS.to_string()]).unwrap();

        assert!(!warnings.is_empty());
        assert!(warnings
            .iter()
            .any(|w| w.contains("NATS enabled without tracing")));
    }

    #[test]
    fn test_capability_compatibility_with_replay() {
        let warnings =
            CapabilityChecker::check_compatibility(&[capabilities::REPLAY.to_string()]).unwrap();

        assert!(!warnings.is_empty());
        assert!(warnings.iter().any(|w| w.contains("Replay enabled")));
    }

    #[test]
    fn test_capability_compatibility_good_config() {
        let warnings = CapabilityChecker::check_compatibility(&[
            capabilities::SHELL_EXEC.to_string(),
            capabilities::TRACING.to_string(),
        ])
        .unwrap();

        assert!(warnings.is_empty());
    }

    #[test]
    fn test_use_case_recommendations_all_variants() {
        assert_recommends(
            "development",
            &[
                capabilities::SHELL_EXEC,
                capabilities::HOOKS_JS,
                capabilities::REPLAY,
                capabilities::TRACING,
            ],
        );
        assert_recommends(
            "production",
            &[
                capabilities::SHELL_EXEC,
                capabilities::HOOKS_RUST,
                capabilities::TRACING,
                capabilities::NATS,
            ],
        );
        assert_recommends(
            "testing",
            &[
                capabilities::SHELL_EXEC,
                capabilities::REPLAY,
                capabilities::TRACING,
            ],
        );
        assert_recommends("minimal", &[capabilities::SHELL_EXEC]);
        assert_recommends(
            "unknown_use_case",
            &[capabilities::SHELL_EXEC, capabilities::TRACING],
        );
    }

    #[test]
    fn test_client_info_and_negotiation_result_serialization() {
        let mut info = client(&[0, 1], &[capabilities::SHELL_EXEC]);
        info.client_metadata
            .insert("client_version".to_string(), "1.0.0".to_string());

        let info_json = serde_json::to_string(&info).unwrap();
        let info_back: ClientInfo = serde_json::from_str(&info_json).unwrap();
        assert_eq!(info_back.supported_versions, vec![0, 1]);
        assert_eq!(
            info_back.requested_capabilities,
            vec![capabilities::SHELL_EXEC.to_string()]
        );

        let result = NegotiationResult {
            version: PROTOCOL_V0,
            capabilities: vec![capabilities::SHELL_EXEC.to_string()],
            fallback_reason: Some("test fallback".to_string()),
            service_info: std::iter::once(("key".to_string(), "value".to_string())).collect(),
        };

        let result_json = serde_json::to_string(&result).unwrap();
        let result_back: NegotiationResult = serde_json::from_str(&result_json).unwrap();
        assert_eq!(result_back.version, PROTOCOL_V0);
        assert_eq!(
            result_back.fallback_reason,
            Some("test fallback".to_string())
        );
    }

    #[test]
    fn test_empty_client_versions_negotiation() {
        let err = negotiator().negotiate(client(&[], &[])).unwrap_err();
        assert!(err.to_string().contains("No compatible protocol version"));
    }

    #[test]
    fn test_empty_requested_capabilities() {
        let result = negotiator()
            .negotiate(client(&[PROTOCOL_V0], &[]))
            .unwrap();

        assert_eq!(result.version, PROTOCOL_V0);
        assert!(result.capabilities.is_empty());
    }

    #[test]
    fn test_version_selection_priority() {
        let result = negotiator()
            .negotiate(client(&[PROTOCOL_V0, PROTOCOL_V1], &[]))
            .unwrap();

        assert!(SUPPORTED_VERSIONS.contains(&result.version));
    }
}