1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use juniper::Variables;
use serde::Deserialize;

use crate::util::default_for_null;

/// Body of a client request to run an operation (see [`ClientMessage::Subscribe`]).
///
/// Receiving this payload triggers execution of a query, mutation, or subscription. Keys are
/// read in camelCase form, so `operation_name` is deserialized from `operationName`.
#[derive(Debug, Deserialize, PartialEq)]
#[serde(
    bound(deserialize = "S: Deserialize<'de>"),
    rename_all = "camelCase"
)]
pub struct SubscribePayload<S> {
    /// Text of the document to execute.
    pub query: String,

    /// Variables for the operation, if any.
    ///
    /// A missing key falls back to the default (empty) set; a `null` value is routed through
    /// `default_for_null`.
    #[serde(default, deserialize_with = "default_for_null")]
    pub variables: Variables<S>,

    /// Name of the operation to run, if given. It is required when the document contains
    /// multiple operations.
    pub operation_name: Option<String>,

    /// Extension data, if any.
    ///
    /// A missing key falls back to the default (empty) set; a `null` value is routed through
    /// `default_for_null`.
    #[serde(default, deserialize_with = "default_for_null")]
    pub extensions: Variables<S>,
}

/// Messages that a client can send to the server.
///
/// The variant is selected by the message's `"type"` field, whose value is the snake_case form
/// of the variant name (for example `"connection_init"` or `"subscribe"`).
#[derive(Debug, Deserialize, PartialEq)]
#[serde(
    bound(deserialize = "S: Deserialize<'de>"),
    rename_all = "snake_case",
    tag = "type"
)]
pub enum ClientMessage<S> {
    /// Sent by the client upon connecting.
    ConnectionInit {
        /// Arbitrary client-supplied parameters, commonly used for authentication.
        ///
        /// A missing key falls back to the default (empty) value; a `null` value is routed
        /// through `default_for_null`.
        #[serde(default, deserialize_with = "default_for_null")]
        payload: Variables<S>,
    },

    /// Probe used for detecting failed connections, displaying latency metrics, or other types
    /// of network probing.
    Ping {
        /// Arbitrary parameters carrying additional details about the ping.
        ///
        /// A missing key falls back to the default (empty) value; a `null` value is routed
        /// through `default_for_null`.
        #[serde(default, deserialize_with = "default_for_null")]
        payload: Variables<S>,
    },

    /// Reply to a `Ping` message.
    Pong {
        /// Arbitrary parameters carrying additional details about the pong.
        ///
        /// A missing key falls back to the default (empty) value; a `null` value is routed
        /// through `default_for_null`.
        #[serde(default, deserialize_with = "default_for_null")]
        payload: Variables<S>,
    },

    /// Asks the server to run the operation described by `payload`.
    Subscribe {
        /// Identifier of the operation. Any value is accepted, but it must be unique: if another
        /// in-flight operation already uses the same id, the message will cause an error.
        id: String,

        /// The query, variables, and operation name.
        payload: SubscribePayload<S>,
    },

    /// Signals that the client has stopped listening and wants to complete the subscription.
    Complete {
        /// Identifier of the operation to stop.
        id: String,
    },
}

#[cfg(test)]
mod test {
    use juniper::{graphql_vars, DefaultScalarValue};

    use super::*;

    /// Client message with the scalar type pinned to juniper's default scalar.
    type Message = ClientMessage<DefaultScalarValue>;

    /// Deserializes `json` into a [`Message`], panicking when deserialization fails.
    #[track_caller]
    fn parse(json: &str) -> Message {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_deserialization() {
        // `connection_init` with an explicit payload.
        let want = Message::ConnectionInit {
            payload: graphql_vars! {"foo": "bar"},
        };
        let got = parse(r#"{"type": "connection_init", "payload": {"foo": "bar"}}"#);
        assert_eq!(want, got);

        // `connection_init` without a payload yields empty variables.
        let want = Message::ConnectionInit {
            payload: graphql_vars! {},
        };
        let got = parse(r#"{"type": "connection_init"}"#);
        assert_eq!(want, got);

        // `subscribe` with every optional field except extensions supplied.
        let want = Message::Subscribe {
            id: "foo".into(),
            payload: SubscribePayload {
                query: "query MyQuery { __typename }".into(),
                variables: graphql_vars! {"foo": "bar"},
                operation_name: Some("MyQuery".into()),
                extensions: Default::default(),
            },
        };
        let got = parse(
            r#"{"type": "subscribe", "id": "foo", "payload": {
                "query": "query MyQuery { __typename }",
                "variables": {
                    "foo": "bar"
                },
                "operationName": "MyQuery"
            }}"#,
        );
        assert_eq!(want, got);

        // `subscribe` carrying only the mandatory query.
        let want = Message::Subscribe {
            id: "foo".into(),
            payload: SubscribePayload {
                query: "query MyQuery { __typename }".into(),
                variables: graphql_vars! {},
                operation_name: None,
                extensions: Default::default(),
            },
        };
        let got = parse(
            r#"{"type": "subscribe", "id": "foo", "payload": {
                "query": "query MyQuery { __typename }"
            }}"#,
        );
        assert_eq!(want, got);

        // `complete` only needs the operation id.
        let want = Message::Complete { id: "foo".into() };
        let got = parse(r#"{"type": "complete", "id": "foo"}"#);
        assert_eq!(want, got);
    }

    #[test]
    fn test_deserialization_of_null() -> serde_json::Result<()> {
        // An explicit `null` for `variables` must behave like an absent key.
        let parsed = serde_json::from_str::<SubscribePayload<DefaultScalarValue>>(
            r#"{"query":"query","variables":null}"#,
        )?;

        assert_eq!(
            SubscribePayload {
                query: "query".into(),
                variables: graphql_vars! {},
                operation_name: None,
                extensions: Default::default(),
            },
            parsed,
        );

        Ok(())
    }
}