Skip to main content

api_testing_core/websocket/
schema.rs

1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3
4use anyhow::Context;
5
6use crate::Result;
7
/// Assertions to check against a received websocket message.
///
/// Both fields are optional; the parser in this module only produces a value
/// when at least one of them is set.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WebsocketExpect {
    /// jq expression from the `jq` key (trimmed, never blank).
    pub jq: Option<String>,
    /// Expected substring from `textContains` / `contains` (trimmed, never blank).
    pub text_contains: Option<String>,
}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub enum WebsocketStep {
16    Send {
17        text: String,
18    },
19    Receive {
20        timeout_seconds: Option<u64>,
21        expect: Option<WebsocketExpect>,
22    },
23    Close,
24}
25
26#[derive(Debug, Clone, PartialEq)]
27pub struct WebsocketRequest {
28    pub url: Option<String>,
29    pub headers: Vec<(String, String)>,
30    pub connect_timeout_seconds: Option<u64>,
31    pub steps: Vec<WebsocketStep>,
32    pub expect: Option<WebsocketExpect>,
33    pub raw: serde_json::Value,
34}
35
36#[derive(Debug, Clone, PartialEq)]
37pub struct WebsocketRequestFile {
38    pub path: PathBuf,
39    pub request: WebsocketRequest,
40}
41
42impl WebsocketRequestFile {
43    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
44        let path = path.as_ref();
45        let bytes = std::fs::read(path)
46            .with_context(|| format!("read websocket request file: {}", path.display()))?;
47        let raw: serde_json::Value = serde_json::from_slice(&bytes).map_err(|_| {
48            anyhow::anyhow!(
49                "websocket request file is not valid JSON: {}",
50                path.display()
51            )
52        })?;
53        let request = parse_websocket_request_json(raw)?;
54        Ok(Self {
55            path: std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf()),
56            request,
57        })
58    }
59}
60
61fn scalar_to_string(value: &serde_json::Value) -> Result<String> {
62    match value {
63        serde_json::Value::String(s) => Ok(s.clone()),
64        serde_json::Value::Number(n) => Ok(n.to_string()),
65        serde_json::Value::Bool(b) => Ok(b.to_string()),
66        serde_json::Value::Null => Ok(String::new()),
67        serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
68            anyhow::bail!("headers values must be scalar")
69        }
70    }
71}
72
73fn parse_optional_u64(path_label: &str, raw: Option<&serde_json::Value>) -> Result<Option<u64>> {
74    match raw {
75        None | Some(serde_json::Value::Null) => Ok(None),
76        Some(serde_json::Value::Number(n)) => n
77            .as_u64()
78            .map(Some)
79            .ok_or_else(|| anyhow::anyhow!("{path_label} must be a positive integer")),
80        Some(serde_json::Value::String(s)) => {
81            let s = s.trim();
82            if s.is_empty() {
83                Ok(None)
84            } else {
85                Ok(Some(s.parse::<u64>().with_context(|| {
86                    format!("{path_label} is not a positive integer: {s}")
87                })?))
88            }
89        }
90        _ => anyhow::bail!("{path_label} must be a positive integer"),
91    }
92}
93
94fn parse_expect(
95    raw: Option<&serde_json::Value>,
96    path_label: &str,
97) -> Result<Option<WebsocketExpect>> {
98    let Some(raw) = raw else {
99        return Ok(None);
100    };
101    if raw.is_null() {
102        return Ok(None);
103    }
104
105    let obj = raw
106        .as_object()
107        .with_context(|| format!("{path_label} must be an object"))?;
108
109    let jq = obj
110        .get("jq")
111        .and_then(|v| v.as_str())
112        .map(str::trim)
113        .filter(|s| !s.is_empty())
114        .map(ToString::to_string);
115    let text_contains = obj
116        .get("textContains")
117        .or_else(|| obj.get("contains"))
118        .and_then(|v| v.as_str())
119        .map(str::trim)
120        .filter(|s| !s.is_empty())
121        .map(ToString::to_string);
122
123    if jq.is_none() && text_contains.is_none() {
124        return Ok(None);
125    }
126
127    Ok(Some(WebsocketExpect { jq, text_contains }))
128}
129
130fn parse_send_text(raw: &serde_json::Value) -> Result<String> {
131    match raw {
132        serde_json::Value::String(s) => Ok(s.clone()),
133        serde_json::Value::Object(_)
134        | serde_json::Value::Array(_)
135        | serde_json::Value::Number(_)
136        | serde_json::Value::Bool(_)
137        | serde_json::Value::Null => {
138            serde_json::to_string(raw).context("failed to serialize websocket send payload to text")
139        }
140    }
141}
142
143fn parse_steps(raw_steps: Option<&serde_json::Value>) -> Result<Vec<WebsocketStep>> {
144    let Some(raw_steps) = raw_steps else {
145        return Ok(Vec::new());
146    };
147    let arr = raw_steps
148        .as_array()
149        .context("websocket request .steps must be an array")?;
150
151    let mut out = Vec::new();
152    for (idx, raw_step) in arr.iter().enumerate() {
153        let obj = raw_step
154            .as_object()
155            .with_context(|| format!("websocket request .steps[{idx}] must be an object"))?;
156
157        let step_type = obj
158            .get("type")
159            .and_then(|v| v.as_str())
160            .unwrap_or_default()
161            .trim()
162            .to_ascii_lowercase();
163
164        match step_type.as_str() {
165            "send" => {
166                let send_raw = obj
167                    .get("text")
168                    .or_else(|| obj.get("json"))
169                    .or_else(|| obj.get("payload"))
170                    .with_context(|| {
171                        format!(
172                            "websocket request .steps[{idx}] send step requires text/json/payload"
173                        )
174                    })?;
175                out.push(WebsocketStep::Send {
176                    text: parse_send_text(send_raw)?,
177                });
178            }
179            "receive" => {
180                let timeout_seconds = parse_optional_u64(
181                    &format!("websocket request .steps[{idx}].timeoutSeconds"),
182                    obj.get("timeoutSeconds"),
183                )?;
184                let expect = parse_expect(
185                    obj.get("expect"),
186                    &format!("websocket request .steps[{idx}].expect"),
187                )?;
188                out.push(WebsocketStep::Receive {
189                    timeout_seconds,
190                    expect,
191                });
192            }
193            "close" => out.push(WebsocketStep::Close),
194            _ => {
195                anyhow::bail!(
196                    "websocket request .steps[{idx}] has unsupported type '{}'",
197                    step_type
198                );
199            }
200        }
201    }
202
203    Ok(out)
204}
205
206pub fn parse_websocket_request_json(raw: serde_json::Value) -> Result<WebsocketRequest> {
207    let obj = raw
208        .as_object()
209        .context("websocket request file must be a JSON object")?;
210
211    let url = obj
212        .get("url")
213        .and_then(|v| v.as_str())
214        .map(str::trim)
215        .filter(|s| !s.is_empty())
216        .map(ToString::to_string);
217
218    let mut headers: Vec<(String, String)> = Vec::new();
219    if let Some(v) = obj.get("headers")
220        && !v.is_null()
221    {
222        let m = v
223            .as_object()
224            .context("websocket request .headers must be an object")?;
225        let mut sorted: BTreeMap<String, String> = BTreeMap::new();
226        for (k, raw_v) in m {
227            let key = k.trim();
228            if key.is_empty() {
229                continue;
230            }
231            let value = scalar_to_string(raw_v)?;
232            if !value.trim().is_empty() {
233                sorted.insert(key.to_string(), value);
234            }
235        }
236        headers.extend(sorted);
237    }
238
239    let connect_timeout_seconds = parse_optional_u64(
240        "websocket request .connectTimeoutSeconds",
241        obj.get("connectTimeoutSeconds"),
242    )?;
243
244    let mut steps = parse_steps(obj.get("steps"))?;
245    if steps.is_empty()
246        && let Some(send_raw) = obj.get("send")
247    {
248        steps.push(WebsocketStep::Send {
249            text: parse_send_text(send_raw)?,
250        });
251        let timeout_seconds = parse_optional_u64(
252            "websocket request .receiveTimeoutSeconds",
253            obj.get("receiveTimeoutSeconds"),
254        )?;
255        let expect = parse_expect(obj.get("expect"), "websocket request .expect")?;
256        steps.push(WebsocketStep::Receive {
257            timeout_seconds,
258            expect,
259        });
260    }
261
262    if steps.is_empty() {
263        anyhow::bail!("websocket request requires at least one step (or top-level send)");
264    }
265
266    let expect = parse_expect(obj.get("expect"), "websocket request .expect")?;
267
268    Ok(WebsocketRequest {
269        url,
270        headers,
271        connect_timeout_seconds,
272        steps,
273        expect,
274        raw,
275    })
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    /// Explicit `steps` are all parsed and the URL is kept.
    #[test]
    fn websocket_schema_parses_steps_request() {
        let doc = serde_json::json!({
            "url": "ws://127.0.0.1:9001/ws",
            "steps": [
                { "type": "send", "text": "{\"ping\":true}" },
                { "type": "receive", "timeoutSeconds": 2, "expect": { "jq": ".ok == true" } },
                { "type": "close" }
            ]
        });

        let parsed = parse_websocket_request_json(doc).unwrap();

        assert_eq!(parsed.url.as_deref(), Some("ws://127.0.0.1:9001/ws"));
        assert_eq!(parsed.steps.len(), 3);
    }

    /// The legacy top-level `send` expands into a send step plus a receive step.
    #[test]
    fn websocket_schema_legacy_send_builds_receive_step() {
        let doc = serde_json::json!({
            "send": {"ping": true},
            "expect": {"jq": ".ok == true"}
        });

        let parsed = parse_websocket_request_json(doc).unwrap();

        assert_eq!(parsed.steps.len(), 2);
    }

    /// A document with neither `steps` nor `send` is rejected.
    #[test]
    fn websocket_schema_rejects_empty_steps() {
        let error = parse_websocket_request_json(serde_json::json!({})).unwrap_err();
        let rendered = format!("{error:#}");
        assert!(
            rendered.contains("requires at least one step"),
            "unexpected error: {rendered}"
        );
    }

    /// `load` reads a request file from disk and parses its steps.
    #[test]
    fn websocket_schema_load_reads_file() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("health.ws.json");
        let contents = serde_json::to_vec_pretty(&serde_json::json!({
            "url": "ws://127.0.0.1:9001/ws",
            "steps": [
                { "type": "send", "text": "ping" },
                { "type": "receive", "expect": {"textContains": "ok"} }
            ]
        }))
        .unwrap();
        std::fs::write(&file, contents).unwrap();

        let loaded = WebsocketRequestFile::load(&file).unwrap();

        assert_eq!(loaded.request.steps.len(), 2);
    }
}