use {
crate::connect::lsp::{
ClientId,
ClientKind,
},
serde::{
Deserialize,
Serialize,
},
std::collections::HashMap,
};
/// Protocol version stamped into every handshake built by the [`Handshake`]
/// constructors. Peers whose values differ are rejected by
/// [`Handshake::is_compatible`].
pub const PROTOCOL_VERSION: u32 = 3;
/// Handshake message exchanged between peers, (de)serialized with serde.
///
/// Clients build one with [`Handshake::new`], [`Handshake::with_client_kind`]
/// or [`Handshake::with_metadata`]; the reply built by
/// [`Handshake::server_response`] carries a `client_id` instead.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Handshake {
/// Version string of the sender; must match exactly for two handshakes to be
/// compatible.
pub version: String,
/// Protocol version of the sender. Required when deserializing (no serde
/// default), unlike the optional fields below.
pub protocol_version: u32,
/// Kind of client, if announced. Omitted from the serialized form when `None`
/// and defaults to `None` when absent; [`Handshake::client_kind`] maps `None`
/// to `ClientKind::Cli`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_kind: Option<ClientKind>,
/// Free-form string key/value pairs. Omitted from the serialized form when
/// empty and defaults to an empty map when absent.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, String>,
/// Client identifier; only set by [`Handshake::server_response`] among the
/// constructors here. Omitted when `None`, defaults to `None` when absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub client_id: Option<ClientId>,
}
impl Handshake {
    /// Builds a handshake for `version` at the current [`PROTOCOL_VERSION`],
    /// with no client kind, no metadata and no client id.
    ///
    /// This is the single place where field defaults are chosen; the other
    /// constructors start from it and override only what they set.
    pub fn new(version: impl Into<String>) -> Self {
        Self {
            version: version.into(),
            protocol_version: PROTOCOL_VERSION,
            client_kind: None,
            metadata: HashMap::new(),
            client_id: None,
        }
    }

    /// Like [`Handshake::new`], but also records the connecting client's `kind`.
    pub fn with_client_kind(
        version: impl Into<String>,
        kind: ClientKind,
    ) -> Self {
        Self {
            client_kind: Some(kind),
            ..Self::new(version)
        }
    }

    /// Like [`Handshake::with_client_kind`], additionally attaching the given
    /// `metadata` key/value pairs.
    pub fn with_metadata(
        version: impl Into<String>,
        kind: ClientKind,
        metadata: HashMap<String, String>,
    ) -> Self {
        Self {
            metadata,
            ..Self::with_client_kind(version, kind)
        }
    }

    /// Builds the reply handshake carrying `client_id`. The client kind and
    /// metadata are left unset.
    pub fn server_response(
        version: impl Into<String>,
        client_id: ClientId,
    ) -> Self {
        Self {
            client_id: Some(client_id),
            ..Self::new(version)
        }
    }

    /// Borrows the metadata map. It is empty when none was provided.
    pub fn metadata(&self) -> &HashMap<String, String> {
        &self.metadata
    }

    /// Moves the metadata map out, leaving an empty map in its place.
    pub fn take_metadata(&mut self) -> HashMap<String, String> {
        std::mem::take(&mut self.metadata)
    }

    /// Returns the announced client kind.
    ///
    /// Falls back to [`ClientKind::Cli`] when no kind was provided.
    pub fn client_kind(&self) -> ClientKind {
        self.client_kind.unwrap_or(ClientKind::Cli)
    }

    /// Returns `true` when `self` and `other` agree on both the version string
    /// and the protocol version (exact equality).
    ///
    /// Client kind, metadata and client id are not considered.
    #[must_use]
    pub fn is_compatible(&self, other: &Handshake) -> bool {
        self.version == other.version
            && self.protocol_version == other.protocol_version
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_handshake_new() {
        let h = Handshake::new("abc123");
        assert_eq!(h.version, "abc123");
        assert_eq!(h.protocol_version, PROTOCOL_VERSION);
        assert!(h.client_kind.is_none());
        assert!(h.metadata().is_empty());
        assert!(h.client_id.is_none());
    }

    #[test]
    fn test_handshake_client_kind_defaults_to_cli() {
        let h = Handshake::new("abc123");
        assert_eq!(h.client_kind(), ClientKind::Cli);
    }

    #[test]
    fn test_handshake_with_client_kind() {
        let h = Handshake::with_client_kind("abc123", ClientKind::Cli);
        assert_eq!(h.version, "abc123");
        assert_eq!(h.protocol_version, PROTOCOL_VERSION);
        assert_eq!(h.client_kind, Some(ClientKind::Cli));
        assert!(h.metadata().is_empty());
        assert!(h.client_id.is_none());
    }

    #[test]
    fn test_handshake_with_metadata_and_take() {
        let metadata =
            HashMap::from([("editor".to_string(), "vim".to_string())]);
        let mut h = Handshake::with_metadata(
            "abc123",
            ClientKind::Cli,
            metadata.clone(),
        );
        assert_eq!(h.client_kind, Some(ClientKind::Cli));
        assert!(h.client_id.is_none());
        assert_eq!(h.metadata(), &metadata);
        assert_eq!(h.take_metadata(), metadata);
        // Taking leaves an empty map behind.
        assert!(h.metadata().is_empty());
    }

    #[test]
    fn test_handshake_serialization_roundtrip() {
        let original = Handshake::new("v1.2.3-deadbeef");
        let json = serde_json::to_string(&original).unwrap();
        let parsed: Handshake = serde_json::from_str(&json).unwrap();
        assert_eq!(original, parsed);
    }

    #[test]
    fn test_handshake_full_roundtrip() {
        let metadata = HashMap::from([("k".to_string(), "v".to_string())]);
        let original =
            Handshake::with_metadata("v1", ClientKind::Cli, metadata);
        let json = serde_json::to_string(&original).unwrap();
        let parsed: Handshake = serde_json::from_str(&json).unwrap();
        assert_eq!(original, parsed);
    }

    #[test]
    fn test_handshake_serialization_omits_empty_optional_fields() {
        let json = serde_json::to_string(&Handshake::new("v1")).unwrap();
        assert_eq!(
            json,
            format!(
                r#"{{"version":"v1","protocol_version":{}}}"#,
                PROTOCOL_VERSION
            )
        );
    }

    #[test]
    fn test_handshake_minimal_json_uses_defaults() {
        let json = r#"{"version":"v1.0","protocol_version":1}"#;
        let h: Handshake = serde_json::from_str(json).unwrap();
        assert!(h.client_kind.is_none());
        assert!(h.metadata().is_empty());
        assert!(h.client_id.is_none());
    }

    #[test]
    fn test_handshake_is_compatible_same() {
        let h1 = Handshake::new("abc123");
        let h2 = Handshake::new("abc123");
        assert!(h1.is_compatible(&h2));
    }

    #[test]
    fn test_handshake_is_compatible_version_mismatch() {
        let h1 = Handshake::new("abc123");
        let h2 = Handshake::new("def456");
        assert!(!h1.is_compatible(&h2));
    }

    #[test]
    fn test_handshake_is_compatible_protocol_mismatch() {
        let h1 = Handshake::new("abc123");
        let mut h2 = Handshake::new("abc123");
        h2.protocol_version = 99;
        assert!(!h1.is_compatible(&h2));
    }

    #[test]
    fn test_handshake_extra_fields_ignored() {
        let json =
            r#"{"version":"v1.0","protocol_version":1,"unknown_field":true}"#;
        let h: Handshake = serde_json::from_str(json).unwrap();
        assert_eq!(h.version, "v1.0");
        assert_eq!(h.protocol_version, 1);
    }

    #[test]
    fn test_handshake_missing_fields_error() {
        let json = r#"{"version":"v1.0"}"#;
        let result: Result<Handshake, _> = serde_json::from_str(json);
        let err = result.expect_err("missing field should fail");
        assert!(
            err.to_string().contains("protocol_version"),
            "error should mention missing field: {}",
            err
        );
    }

    #[test]
    fn test_handshake_empty_version() {
        let h = Handshake::new("");
        assert_eq!(h.version, "");
        assert!(h.is_compatible(&Handshake::new("")));
    }

    #[test]
    fn test_handshake_unicode_version() {
        let h = Handshake::new("版本-1.0-🎉");
        let json = serde_json::to_string(&h).unwrap();
        let parsed: Handshake = serde_json::from_str(&json).unwrap();
        assert_eq!(h, parsed);
    }
}