laburnum 1.17.0

An LSP framework for building language servers and compilers, powered by an incremental query tree with content-addressed storage, task-based dataflow, and parallel queries.
Documentation
// Copyright Two Neutron Stars Incorporated and contributors
// SPDX-License-Identifier: BlueOak-1.0.0

//! Version handshake protocol for IPC connections.
//!
//! On connect, both parties exchange a [`Handshake`]. If either the version
//! string or the protocol version differs, the connection is rejected. This
//! guards against same-binary deployments, where the client and server come
//! from the same program but may nonetheless be running different versions.

use {
  crate::connect::lsp::{
    ClientId,
    ClientKind,
  },
  serde::{
    Deserialize,
    Serialize,
  },
  std::collections::HashMap,
};

/// Current protocol version. Increment when making breaking changes
/// to the handshake or framing format.
///
/// Every [`Handshake`] constructor stamps this value into
/// `protocol_version`, and [`Handshake::is_compatible`] requires it to match
/// the peer's value exactly.
pub const PROTOCOL_VERSION: u32 = 3;

/// Handshake message exchanged on connection.
///
/// Both peers exchange one of these on connect; [`Handshake::is_compatible`]
/// decides whether two handshakes match. Only `version` and
/// `protocol_version` are required on the wire. The remaining fields are
/// omitted when serialized empty and fall back to their defaults when
/// absent. Unknown fields are ignored when deserializing.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Handshake {
  /// Version string provided by the implementing crate (e.g., commit hash).
  /// Compared exactly by [`Handshake::is_compatible`].
  pub version:          String,
  /// Protocol version for wire format compatibility. Required when
  /// deserializing; the constructors always set it to [`PROTOCOL_VERSION`].
  pub protocol_version: u32,
  /// Client kind (IDE or CLI). Optional for backwards compatibility with old
  /// clients; omitted from the serialized form when `None`. Read it through
  /// [`Handshake::client_kind`] to get the CLI fallback.
  #[serde(default, skip_serializing_if = "Option::is_none")]
  pub client_kind:      Option<ClientKind>,
  /// Client metadata (e.g., editor name, version). Optional for backwards
  /// compatibility; omitted from the serialized form when empty.
  #[serde(default, skip_serializing_if = "HashMap::is_empty")]
  pub metadata:         HashMap<String, String>,
  /// Client ID assigned by the server (only present in server response).
  /// Clients should include this ID in subsequent messages for proper routing.
  /// Omitted from the serialized form when `None`.
  #[serde(default, skip_serializing_if = "Option::is_none")]
  pub client_id:        Option<ClientId>,
}

impl Handshake {
  /// Create a client handshake carrying only a version string.
  ///
  /// `protocol_version` is set to [`PROTOCOL_VERSION`]. `client_kind`,
  /// `metadata` and `client_id` are left empty, so none of them appear in the
  /// serialized form. The other constructors build on this one, so new
  /// fields only need a default here.
  pub fn new(version: impl Into<String>) -> Self {
    Self {
      version:          version.into(),
      protocol_version: PROTOCOL_VERSION,
      client_kind:      None,
      metadata:         HashMap::new(),
      client_id:        None,
    }
  }

  /// Create a client handshake that declares its [`ClientKind`].
  ///
  /// Identical to [`Handshake::new`] except that `client_kind` is set.
  pub fn with_client_kind(
    version: impl Into<String>,
    kind: ClientKind,
  ) -> Self {
    Self {
      client_kind: Some(kind),
      ..Self::new(version)
    }
  }

  /// Create a client handshake that declares its [`ClientKind`] and carries
  /// free-form metadata (e.g. editor name and version).
  ///
  /// An empty `metadata` map is omitted from the serialized form.
  pub fn with_metadata(
    version: impl Into<String>,
    kind: ClientKind,
    metadata: HashMap<String, String>,
  ) -> Self {
    Self {
      client_kind: Some(kind),
      metadata,
      ..Self::new(version)
    }
  }

  /// Create the server's reply handshake, carrying the `client_id` the
  /// server assigned to this connection.
  ///
  /// `client_kind` and `metadata` are left empty; they describe the client
  /// and are not sent back.
  pub fn server_response(
    version: impl Into<String>,
    client_id: ClientId,
  ) -> Self {
    Self {
      client_id: Some(client_id),
      ..Self::new(version)
    }
  }

  /// Borrow the client metadata (empty if none was sent).
  pub fn metadata(&self) -> &HashMap<String, String> {
    &self.metadata
  }

  /// Move the metadata out, leaving an empty map in its place.
  pub fn take_metadata(&mut self) -> HashMap<String, String> {
    std::mem::take(&mut self.metadata)
  }

  /// Get the client kind, falling back to [`ClientKind::Cli`] when the peer
  /// did not send one (e.g. older clients that predate the field).
  pub fn client_kind(&self) -> ClientKind {
    self.client_kind.unwrap_or(ClientKind::Cli)
  }

  /// Check if this handshake is compatible with another.
  ///
  /// Both the version string and the protocol version must match exactly.
  /// `client_kind`, `metadata` and `client_id` are not considered. The check
  /// is symmetric.
  pub fn is_compatible(&self, other: &Handshake) -> bool {
    self.version == other.version
      && self.protocol_version == other.protocol_version
  }
}

#[cfg(test)]
mod tests {
  // Covers construction, compatibility rules, and the serde wire format,
  // including the backwards-compatible optional fields.
  use super::*;

  #[test]
  fn test_handshake_new() {
    let h = Handshake::new("abc123");
    assert_eq!(h.version, "abc123");
    assert_eq!(h.protocol_version, PROTOCOL_VERSION);
  }

  #[test]
  fn test_handshake_serialization_roundtrip() {
    let original = Handshake::new("v1.2.3-deadbeef");
    let json = serde_json::to_string(&original).unwrap();
    let parsed: Handshake = serde_json::from_str(&json).unwrap();
    assert_eq!(original, parsed);
  }

  #[test]
  fn test_handshake_is_compatible_same() {
    let h1 = Handshake::new("abc123");
    let h2 = Handshake::new("abc123");
    assert!(h1.is_compatible(&h2));
  }

  #[test]
  fn test_handshake_is_compatible_version_mismatch() {
    let h1 = Handshake::new("abc123");
    let h2 = Handshake::new("def456");
    assert!(!h1.is_compatible(&h2));
  }

  #[test]
  fn test_handshake_is_compatible_protocol_mismatch() {
    let h1 = Handshake::new("abc123");
    let mut h2 = Handshake::new("abc123");
    h2.protocol_version = 99;
    assert!(!h1.is_compatible(&h2));
  }

  #[test]
  fn test_handshake_extra_fields_ignored() {
    let json =
      r#"{"version":"v1.0","protocol_version":1,"unknown_field":true}"#;
    let h: Handshake = serde_json::from_str(json).unwrap();
    assert_eq!(h.version, "v1.0");
    assert_eq!(h.protocol_version, 1);
  }

  #[test]
  fn test_handshake_missing_fields_error() {
    let json = r#"{"version":"v1.0"}"#;
    let result: Result<Handshake, _> = serde_json::from_str(json);
    let err = result.expect_err("missing field should fail");
    assert!(
      err.to_string().contains("protocol_version"),
      "error should mention missing field: {}",
      err
    );
  }

  #[test]
  fn test_handshake_empty_version() {
    let h = Handshake::new("");
    assert_eq!(h.version, "");
    assert!(h.is_compatible(&Handshake::new("")));
  }

  #[test]
  fn test_handshake_unicode_version() {
    let h = Handshake::new("版本-1.0-🎉");
    let json = serde_json::to_string(&h).unwrap();
    let parsed: Handshake = serde_json::from_str(&json).unwrap();
    assert_eq!(h, parsed);
  }

  #[test]
  fn test_handshake_new_omits_optional_fields() {
    // Old peers only understand the two required fields, so a bare
    // handshake must not emit anything else.
    let value = serde_json::to_value(&Handshake::new("abc123")).unwrap();
    let obj = value.as_object().expect("handshake serializes to an object");
    assert_eq!(obj.len(), 2, "only required fields expected: {:?}", obj);
    assert!(obj.contains_key("version"));
    assert!(obj.contains_key("protocol_version"));
  }

  #[test]
  fn test_handshake_old_client_defaults() {
    let json = r#"{"version":"v1.0","protocol_version":3}"#;
    let h: Handshake = serde_json::from_str(json).unwrap();
    assert_eq!(h.client_kind, None);
    assert!(h.metadata().is_empty());
    assert!(h.client_id.is_none());
    assert_eq!(h.client_kind(), ClientKind::Cli);
  }

  #[test]
  fn test_handshake_with_client_kind() {
    let h = Handshake::with_client_kind("abc123", ClientKind::Cli);
    assert_eq!(h.version, "abc123");
    assert_eq!(h.protocol_version, PROTOCOL_VERSION);
    assert_eq!(h.client_kind, Some(ClientKind::Cli));
    assert!(h.metadata().is_empty());
    assert!(h.client_id.is_none());
  }

  #[test]
  fn test_handshake_with_metadata_roundtrip() {
    let metadata = HashMap::from([
      ("editor".to_string(), "vscode".to_string()),
      ("editor_version".to_string(), "1.90".to_string()),
    ]);
    let original =
      Handshake::with_metadata("abc123", ClientKind::Cli, metadata.clone());
    assert_eq!(original.metadata(), &metadata);
    assert_eq!(original.client_kind, Some(ClientKind::Cli));
    let json = serde_json::to_string(&original).unwrap();
    let parsed: Handshake = serde_json::from_str(&json).unwrap();
    assert_eq!(original, parsed);
  }

  #[test]
  fn test_handshake_take_metadata() {
    let metadata =
      HashMap::from([("editor".to_string(), "helix".to_string())]);
    let mut h =
      Handshake::with_metadata("abc123", ClientKind::Cli, metadata.clone());
    assert_eq!(h.take_metadata(), metadata);
    assert!(h.metadata().is_empty());
  }

  #[test]
  fn test_handshake_compatibility_ignores_optional_fields() {
    let client = Handshake::with_metadata(
      "abc123",
      ClientKind::Cli,
      HashMap::from([("k".to_string(), "v".to_string())]),
    );
    let bare = Handshake::new("abc123");
    assert!(client.is_compatible(&bare));
    assert!(bare.is_compatible(&client));
  }
}