laburnum 1.17.0

An LSP framework for building language servers and compilers, powered by an incremental query tree with content-addressed storage, task-based dataflow, and parallel queries.
Documentation
// Copyright Two Neutron Stars Incorporated and contributors
// SPDX-License-Identifier: BlueOak-1.0.0

use {
  super::{
    super::otel::TraceContext,
    Id,
    Version,
  },
  crate::{
    connect::lsp::ClientId,
    protocol::lsp::LSPAny,
  },
  serde::{
    Deserialize,
    Deserializer,
    Serialize,
  },
  std::{
    borrow::Cow,
    fmt::{
      self,
      Display,
      Formatter,
    },
    str::FromStr,
  },
};

/// Deserializes a value that is present in the input and wraps it in `Some`.
///
/// Meant to be paired with `#[serde(default, deserialize_with = ...)]` on an
/// `Option` field. Serde only calls this function when the field exists in
/// the input. So an explicit JSON `null` is handed to `T`'s own deserializer
/// instead of being collapsed into `None`, while a missing field falls back
/// to the `default` of `None`.
fn deserialize_some<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
where
  T: Deserialize<'de>,
  D: Deserializer<'de>,
{
  let value = T::deserialize(deserializer)?;
  Ok(Some(value))
}

/// A JSON-RPC request.
///
/// Serializes to a JSON object with `jsonrpc`, `method`, and `id` members.
/// The optional `params`, `traceparent`, `tracestate`, and `client_id`
/// members are written only when set. Construct one with [`Request::build`],
/// or parse one from JSON text via its [`FromStr`] implementation.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Request {
  /// Protocol version marker (serialized as `"2.0"`).
  jsonrpc:     Version,
  /// Name of the method to be invoked.
  method:      Cow<'static, str>,
  /// Method arguments.
  ///
  /// An absent member deserializes to `None` (via `default`). An explicit
  /// JSON `null` is routed through `deserialize_some`, so it is not collapsed
  /// into `None`. A `None` value is omitted from the serialized output.
  #[serde(default, deserialize_with = "deserialize_some")]
  #[serde(skip_serializing_if = "Option::is_none")]
  params:      Option<LSPAny>,
  /// Unique ID of this request.
  id:          Id,
  /// W3C Trace Context `traceparent` value; omitted when `None`.
  #[serde(default, skip_serializing_if = "Option::is_none")]
  traceparent: Option<String>,
  /// W3C Trace Context `tracestate` value; omitted when `None`.
  #[serde(default, skip_serializing_if = "Option::is_none")]
  tracestate:  Option<String>,
  /// Client identifier set through `RequestBuilder::client_id`; omitted
  /// when `None`.
  #[serde(default, skip_serializing_if = "Option::is_none")]
  client_id:   Option<ClientId>,
}

impl Request {
  /// Begins constructing a JSON-RPC request with the given method and ID.
  ///
  /// The returned [`RequestBuilder`] starts with no `params`, no trace
  /// context and no `client_id`. Use its setters to fill those in, then call
  /// [`RequestBuilder::finish`] to obtain the `Request`.
  pub fn build<M, I>(method: M, id: I) -> RequestBuilder
  where
    M: Into<Cow<'static, str>>,
    I: Into<Id>,
  {
    let method = method.into();
    let id = id.into();
    RequestBuilder {
      method,
      params: None,
      id,
      traceparent: None,
      tracestate: None,
      client_id: None,
    }
  }

  /// Borrows the name of the method this request invokes.
  #[must_use]
  pub fn method(&self) -> &str {
    &self.method
  }

  /// Borrows the unique ID of this request.
  #[must_use]
  pub const fn id(&self) -> &Id {
    &self.id
  }

  /// Borrows the `params` member, or `None` when the request carries none.
  #[must_use]
  pub const fn params(&self) -> Option<&LSPAny> {
    self.params.as_ref()
  }

  /// Borrows the W3C `traceparent` value, or `None` when unset.
  #[must_use]
  pub fn traceparent(&self) -> Option<&str> {
    self.traceparent.as_deref()
  }

  /// Borrows the W3C `tracestate` value, or `None` when unset.
  #[must_use]
  pub fn tracestate(&self) -> Option<&str> {
    self.tracestate.as_deref()
  }

  /// Builds an owned [`TraceContext`] from copies of the `traceparent` and
  /// `tracestate` fields.
  #[must_use]
  pub fn trace_context(&self) -> TraceContext {
    let parent = self.traceparent.clone();
    let state = self.tracestate.clone();
    TraceContext::new(parent, state)
  }

  /// Returns the `client_id` value, or `None` when unset.
  #[must_use]
  pub const fn client_id(&self) -> Option<ClientId> {
    self.client_id
  }

  /// Consumes the request and yields `(method, id, params)`.
  ///
  /// The `traceparent`, `tracestate`, and `client_id` fields are dropped.
  /// Read them first (e.g. via [`Request::trace_context`]) if they are needed.
  #[must_use]
  pub fn into_parts(self) -> (Cow<'static, str>, Id, Option<LSPAny>) {
    let Self {
      method, id, params, ..
    } = self;
    (method, id, params)
  }
}

impl Display for Request {
  fn fmt(&self, f: &mut Formatter) -> fmt::Result {
    use std::{
      io,
      str,
    };

    struct WriterFormatter<'a, 'b: 'a> {
      inner: &'a mut Formatter<'b>,
    }

    impl io::Write for WriterFormatter<'_, '_> {
      fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        fn io_error<E>(_: E) -> io::Error {
          // Error value does not matter because fmt::Display impl below just
          // maps it to fmt::Error
          io::Error::other("fmt error")
        }
        let s = str::from_utf8(buf).map_err(io_error)?;
        self.inner.write_str(s).map_err(io_error)?;
        Ok(buf.len())
      }

      fn flush(&mut self) -> io::Result<()> {
        Ok(())
      }
    }

    let mut w = WriterFormatter { inner: f };
    serde_json::to_writer(&mut w, self).map_err(|_| fmt::Error)
  }
}

impl FromStr for Request {
  type Err = serde_json::Error;

  /// Parses a request from JSON text (the format written by this type's
  /// `Display` implementation).
  ///
  /// # Errors
  ///
  /// Returns the `serde_json::Error` when `s` is not valid JSON or does not
  /// match the request schema.
  fn from_str(s: &str) -> Result<Self, Self::Err> {
    serde_json::from_str::<Self>(s)
  }
}

/// A builder to construct the properties of a `Request`.
///
/// To construct a `RequestBuilder`, refer to [`Request::build`], which fixes
/// the method name and ID and leaves every optional field unset. Call
/// [`RequestBuilder::finish`] to produce the [`Request`].
#[derive(Debug)]
pub struct RequestBuilder {
  /// Method name, fixed when the builder is created.
  method:      Cow<'static, str>,
  /// Optional `params` member; left out of the request when `None`.
  params:      Option<LSPAny>,
  /// Request ID, fixed when the builder is created.
  id:          Id,
  /// Optional W3C `traceparent` value.
  traceparent: Option<String>,
  /// Optional W3C `tracestate` value.
  tracestate:  Option<String>,
  /// Optional client identifier.
  client_id:   Option<ClientId>,
}

impl RequestBuilder {
  /// Attaches a `params` member, converting the value into [`LSPAny`].
  ///
  /// If this is never called, the finished request has no `params` and the
  /// member is left out of its serialized form.
  #[must_use]
  pub fn params<V: Into<LSPAny>>(self, params: V) -> Self {
    Self {
      params: Some(params.into()),
      ..self
    }
  }

  /// Attaches a W3C Trace Context `traceparent` value.
  #[must_use]
  pub fn traceparent<S: Into<String>>(self, traceparent: S) -> Self {
    Self {
      traceparent: Some(traceparent.into()),
      ..self
    }
  }

  /// Attaches a W3C Trace Context `tracestate` value.
  #[must_use]
  pub fn tracestate<S: Into<String>>(self, tracestate: S) -> Self {
    Self {
      tracestate: Some(tracestate.into()),
      ..self
    }
  }

  /// Replaces both trace fields with the contents of `ctx`.
  ///
  /// Any `None` in `ctx` clears the corresponding field, even if it was set
  /// earlier on this builder.
  #[must_use]
  pub fn with_trace_context(self, ctx: TraceContext) -> Self {
    Self {
      traceparent: ctx.traceparent,
      tracestate: ctx.tracestate,
      ..self
    }
  }

  /// Attaches a client identifier.
  #[must_use]
  pub fn client_id(self, id: ClientId) -> Self {
    Self {
      client_id: Some(id),
      ..self
    }
  }

  /// Consumes the builder and produces the finished [`Request`].
  #[must_use]
  pub fn finish(self) -> Request {
    let Self {
      method,
      params,
      id,
      traceparent,
      tracestate,
      client_id,
    } = self;

    Request {
      jsonrpc: Version,
      method,
      params,
      id,
      traceparent,
      tracestate,
      client_id,
    }
  }
}

#[cfg(test)]
mod tests {
  use {
    super::*,
    serde_json::json,
  };

  #[test]
  fn request_serializes_with_trace_context() {
    let request = Request::build("test/method", 1)
      .params(json!({"key": "value"}))
      .traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
      .tracestate("rojo=00f067aa0ba902b7")
      .finish();

    let json = serde_json::to_value(&request).unwrap();
    assert_eq!(json["jsonrpc"], "2.0");
    assert_eq!(json["method"], "test/method");
    assert_eq!(json["id"], 1);
    assert_eq!(json["params"]["key"], "value");
    assert_eq!(
      json["traceparent"],
      "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
    );
    assert_eq!(json["tracestate"], "rojo=00f067aa0ba902b7");
    // Verify no _meta pollution
    assert!(json["params"].get("_meta").is_none());
  }

  #[test]
  fn request_deserializes_with_trace_context() {
    let json = json!({
      "jsonrpc": "2.0",
      "method": "test/method",
      "params": {"key": "value"},
      "id": 1,
      "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
      "tracestate": "rojo=00f067aa0ba902b7"
    });

    let request: Request = serde_json::from_value(json).unwrap();
    assert_eq!(request.method(), "test/method");
    assert_eq!(
      request.traceparent(),
      Some("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
    );
    assert_eq!(request.tracestate(), Some("rojo=00f067aa0ba902b7"));
  }

  #[test]
  fn request_without_trace_context() {
    let json = json!({
      "jsonrpc": "2.0",
      "method": "test/method",
      "id": 1
    });

    let request: Request = serde_json::from_value(json).unwrap();
    assert_eq!(request.traceparent(), None);
    assert_eq!(request.tracestate(), None);
  }

  #[test]
  fn request_null_params_with_trace_context() {
    let request = Request::build("test/method", 1)
      .traceparent("00-abc123-def456-01")
      .finish();

    let json = serde_json::to_value(&request).unwrap();
    // params should be absent (not null)
    assert!(json.get("params").is_none());
    assert_eq!(json["traceparent"], "00-abc123-def456-01");
  }

  #[test]
  fn request_with_trace_context_struct() {
    let ctx = TraceContext::new(
      Some("00-traceid-spanid-01".to_string()),
      Some("vendor=value".to_string()),
    );

    let request = Request::build("test/method", 1)
      .with_trace_context(ctx)
      .finish();

    assert_eq!(request.traceparent(), Some("00-traceid-spanid-01"));
    assert_eq!(request.tracestate(), Some("vendor=value"));
  }

  #[test]
  fn request_trace_context_accessor() {
    let request = Request::build("test/method", 1)
      .traceparent("00-abc-def-01")
      .tracestate("key=value")
      .finish();

    let ctx = request.trace_context();
    assert_eq!(ctx.traceparent, Some("00-abc-def-01".to_string()));
    assert_eq!(ctx.tracestate, Some("key=value".to_string()));
    assert!(!ctx.is_empty());
  }

  #[test]
  fn request_serialization_round_trip() {
    let original = Request::build("test/method", 42)
      .params(json!({"foo": "bar"}))
      .traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
      .tracestate("rojo=00f067aa0ba902b7,congo=t61rcWkgMzE")
      .finish();

    let serialized = serde_json::to_string(&original).unwrap();
    let deserialized: Request = serde_json::from_str(&serialized).unwrap();

    assert_eq!(original, deserialized);
  }

  // `deserialize_some` exists so that an absent `params` member and an
  // explicit `"params": null` stay distinguishable; pin both sides of that.
  #[test]
  fn request_absent_params_is_none() {
    let json = json!({
      "jsonrpc": "2.0",
      "method": "test/method",
      "id": 1
    });

    let request: Request = serde_json::from_value(json).unwrap();
    assert!(request.params().is_none());
  }

  #[test]
  fn request_explicit_null_params_is_preserved() {
    let json = json!({
      "jsonrpc": "2.0",
      "method": "test/method",
      "params": null,
      "id": 1
    });

    let request: Request = serde_json::from_value(json).unwrap();
    assert!(request.params().is_some());

    // Re-serializing keeps the explicit null rather than dropping the member.
    let reserialized = serde_json::to_value(&request).unwrap();
    assert_eq!(reserialized.get("params"), Some(&serde_json::Value::Null));
  }

  // Exercises the Display -> FromStr pair, including non-ASCII text that
  // flows through the UTF-8 check in the Display writer adapter.
  #[test]
  fn request_display_from_str_round_trip() {
    let original = Request::build("test/method", 7)
      .params(json!({"nested": {"text": "h\u{e9}llo \u{2713}"}}))
      .traceparent("00-abc-def-01")
      .finish();

    let text = original.to_string();
    assert_eq!(text, serde_json::to_string(&original).unwrap());

    let parsed: Request = text.parse().unwrap();
    assert_eq!(original, parsed);
  }

  #[test]
  fn request_into_parts_returns_method_id_and_params() {
    let request = Request::build("test/method", 3)
      .params(json!([1, 2]))
      .traceparent("00-abc-def-01")
      .finish();
    let expected_id = request.id().clone();
    let expected_params = request.params().cloned();

    let (method, id, params) = request.into_parts();
    assert_eq!(method, "test/method");
    assert_eq!(id, expected_id);
    assert_eq!(params, expected_params);
  }
}