use {
super::{
super::otel::TraceContext,
Version,
},
crate::{
connect::lsp::ClientId,
protocol::lsp::LSPAny,
},
serde::{
Deserialize,
Serialize,
},
std::{
borrow::Cow,
fmt::{
self,
Display,
Formatter,
},
str::FromStr,
},
};
/// A JSON-RPC notification message (it carries no `id` field).
///
/// Trace-context values and the optional client id travel as top-level
/// fields next to `jsonrpc`/`method`/`params`. Every optional field is left
/// out of the serialized form when `None`, and is read back as `None` when
/// absent from the input. With serde's derive, fields serialize in the
/// declaration order below.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Notification {
    /// Protocol version marker; serialized as `"2.0"`.
    jsonrpc: Version,
    /// Method name, e.g. `"test/notification"`.
    method: Cow<'static, str>,
    /// Optional payload; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    params: Option<LSPAny>,
    /// Trace-context `traceparent` value; omitted from output when `None`.
    /// The tests use the `00-<trace-id>-<span-id>-<flags>` form.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    traceparent: Option<String>,
    /// Trace-context `tracestate` value; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    tracestate: Option<String>,
    /// Identifier of the associated client, serialized under the key
    /// `client_id` when present.
    /// NOTE(review): presumably used for routing between connections; confirm
    /// against callers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    client_id: Option<ClientId>,
}
impl Notification {
    /// Starts building a notification for `method`.
    ///
    /// Params, trace context and client id all start unset; chain the
    /// builder's setters and call [`NotificationBuilder::finish`] to obtain
    /// the [`Notification`].
    #[must_use]
    pub fn build<M>(method: M) -> NotificationBuilder
    where
        M: Into<Cow<'static, str>>,
    {
        NotificationBuilder {
            method: method.into(),
            params: None,
            traceparent: None,
            tracestate: None,
            client_id: None,
        }
    }

    /// Returns the method name.
    #[must_use]
    pub fn method(&self) -> &str {
        self.method.as_ref()
    }

    /// Returns the payload, if one was set.
    #[must_use]
    pub const fn params(&self) -> Option<&LSPAny> {
        self.params.as_ref()
    }

    /// Returns the `traceparent` value, if present.
    #[must_use]
    pub fn traceparent(&self) -> Option<&str> {
        self.traceparent.as_deref()
    }

    /// Returns the `tracestate` value, if present.
    #[must_use]
    pub fn tracestate(&self) -> Option<&str> {
        self.tracestate.as_deref()
    }

    /// Builds a [`TraceContext`] from owned copies of this notification's
    /// `traceparent` and `tracestate` values.
    #[must_use]
    pub fn trace_context(&self) -> TraceContext {
        TraceContext::new(self.traceparent.clone(), self.tracestate.clone())
    }

    /// Returns the client id, if one was attached.
    #[must_use]
    pub const fn client_id(&self) -> Option<ClientId> {
        self.client_id
    }

    /// Consumes the notification and returns its method name and payload.
    ///
    /// The trace context and client id are discarded. Read them first via
    /// [`Notification::trace_context`] / [`Notification::client_id`] if
    /// they are still needed.
    #[must_use]
    pub fn into_parts(self) -> (Cow<'static, str>, Option<LSPAny>) {
        (self.method, self.params)
    }
}
impl Display for Notification {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
use std::{
io,
str,
};
struct WriterFormatter<'a, 'b: 'a> {
inner: &'a mut Formatter<'b>,
}
impl io::Write for WriterFormatter<'_, '_> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
fn io_error<E>(_: E) -> io::Error {
io::Error::other("fmt error")
}
let s = str::from_utf8(buf).map_err(io_error)?;
self.inner.write_str(s).map_err(io_error)?;
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
let mut w = WriterFormatter { inner: f };
serde_json::to_writer(&mut w, self).map_err(|_| fmt::Error)
}
}
impl FromStr for Notification {
    type Err = serde_json::Error;

    /// Parses a notification from its JSON text form.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`serde_json::Error`] when `s` is not valid JSON
    /// or does not deserialize into a [`Notification`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        serde_json::from_str::<Self>(s)
    }
}
/// Incremental builder for [`Notification`], created by
/// [`Notification::build`].
///
/// Every optional part starts as `None`. Chain the setters, then call
/// [`NotificationBuilder::finish`].
#[derive(Debug)]
pub struct NotificationBuilder {
    /// Method name the finished notification will carry.
    method: Cow<'static, str>,
    /// Payload, if any.
    params: Option<LSPAny>,
    /// Trace-context `traceparent` value, if any.
    traceparent: Option<String>,
    /// Trace-context `tracestate` value, if any.
    tracestate: Option<String>,
    /// Client identifier, if any.
    client_id: Option<ClientId>,
}
impl NotificationBuilder {
    /// Sets the payload, converting it into [`LSPAny`].
    #[must_use]
    pub fn params<V: Into<LSPAny>>(self, params: V) -> Self {
        Self {
            params: Some(params.into()),
            ..self
        }
    }

    /// Sets the `traceparent` value.
    #[must_use]
    pub fn traceparent<S: Into<String>>(self, traceparent: S) -> Self {
        Self {
            traceparent: Some(traceparent.into()),
            ..self
        }
    }

    /// Sets the `tracestate` value.
    #[must_use]
    pub fn tracestate<S: Into<String>>(self, tracestate: S) -> Self {
        Self {
            tracestate: Some(tracestate.into()),
            ..self
        }
    }

    /// Replaces both trace fields with those taken from `ctx`.
    ///
    /// A `None` in `ctx` clears any value set earlier by
    /// [`NotificationBuilder::traceparent`] or
    /// [`NotificationBuilder::tracestate`].
    #[must_use]
    pub fn with_trace_context(self, ctx: TraceContext) -> Self {
        Self {
            traceparent: ctx.traceparent,
            tracestate: ctx.tracestate,
            ..self
        }
    }

    /// Attaches a client identifier.
    #[must_use]
    pub fn client_id(self, id: ClientId) -> Self {
        Self {
            client_id: Some(id),
            ..self
        }
    }

    /// Produces the finished [`Notification`], stamping the JSON-RPC version.
    #[must_use]
    pub fn finish(self) -> Notification {
        let Self {
            method,
            params,
            traceparent,
            tracestate,
            client_id,
        } = self;
        Notification {
            jsonrpc: Version,
            method,
            params,
            traceparent,
            tracestate,
            client_id,
        }
    }
}
#[cfg(test)]
mod tests {
    use {
        super::*,
        serde_json::json,
    };

    /// All set fields appear as top-level keys, and no `id` or `_meta` is injected.
    #[test]
    fn notification_serializes_with_trace_context() {
        let notification = Notification::build("test/notification")
            .params(json!({"key": "value"}))
            .traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
            .tracestate("rojo=00f067aa0ba902b7")
            .finish();
        let json = serde_json::to_value(&notification).unwrap();
        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["method"], "test/notification");
        assert_eq!(json["params"]["key"], "value");
        assert_eq!(
            json["traceparent"],
            "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
        );
        assert_eq!(json["tracestate"], "rojo=00f067aa0ba902b7");
        assert!(json.get("id").is_none());
        assert!(json["params"].get("_meta").is_none());
    }

    /// Top-level trace fields are picked up on deserialization.
    #[test]
    fn notification_deserializes_with_trace_context() {
        let json = json!({
            "jsonrpc": "2.0",
            "method": "test/notification",
            "params": {"key": "value"},
            "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
            "tracestate": "rojo=00f067aa0ba902b7"
        });
        let notification: Notification = serde_json::from_value(json).unwrap();
        assert_eq!(notification.method(), "test/notification");
        assert_eq!(
            notification.traceparent(),
            Some("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
        );
        assert_eq!(notification.tracestate(), Some("rojo=00f067aa0ba902b7"));
    }

    /// Missing trace fields deserialize as `None`.
    #[test]
    fn notification_without_trace_context() {
        let json = json!({
            "jsonrpc": "2.0",
            "method": "test/notification"
        });
        let notification: Notification = serde_json::from_value(json).unwrap();
        assert_eq!(notification.traceparent(), None);
        assert_eq!(notification.tracestate(), None);
    }

    /// Unset params are omitted from the output rather than written as `null`.
    #[test]
    fn notification_null_params_with_trace_context() {
        let notification = Notification::build("test/notification")
            .traceparent("00-abc123-def456-01")
            .finish();
        let json = serde_json::to_value(&notification).unwrap();
        assert!(json.get("params").is_none());
        assert_eq!(json["traceparent"], "00-abc123-def456-01");
    }

    /// `with_trace_context` copies both fields from a `TraceContext`.
    #[test]
    fn notification_with_trace_context_struct() {
        let ctx = TraceContext::new(
            Some("00-traceid-spanid-01".to_string()),
            Some("vendor=value".to_string()),
        );
        let notification = Notification::build("test/notification")
            .with_trace_context(ctx)
            .finish();
        assert_eq!(notification.traceparent(), Some("00-traceid-spanid-01"));
        assert_eq!(notification.tracestate(), Some("vendor=value"));
    }

    /// Serializing and deserializing yields an equal value.
    #[test]
    fn notification_serialization_round_trip() {
        let original = Notification::build("test/notification")
            .params(json!({"foo": "bar"}))
            .traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
            .tracestate("rojo=00f067aa0ba902b7,congo=t61rcWkgMzE")
            .finish();
        let serialized = serde_json::to_string(&original).unwrap();
        let deserialized: Notification = serde_json::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    /// `Display` output parses back through `FromStr` into an equal value.
    #[test]
    fn notification_display_round_trips_through_from_str() {
        let original = Notification::build("test/notification")
            .params(json!({"n": 1, "s": "héllo"}))
            .traceparent("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01")
            .finish();
        let rendered = original.to_string();
        let parsed: Notification = rendered.parse().unwrap();
        assert_eq!(original, parsed);
    }
}