// api_testing_core/websocket/schema.rs

1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3
4use anyhow::Context;
5
6use crate::Result;
7
/// Expectation criteria parsed from an `expect` object.
///
/// When built by `parse_expect`, both fields hold trimmed, non-empty strings,
/// and at least one of them is `Some`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WebsocketExpect {
    /// jq filter taken from the `jq` key.
    pub jq: Option<String>,
    /// Substring taken from `textContains` (or its alias `contains`).
    pub text_contains: Option<String>,
}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub enum WebsocketStep {
16    Send {
17        text: String,
18    },
19    Receive {
20        timeout_seconds: Option<u64>,
21        expect: Option<WebsocketExpect>,
22    },
23    Close,
24}
25
26#[derive(Debug, Clone, PartialEq)]
27pub struct WebsocketRequest {
28    pub url: Option<String>,
29    pub headers: Vec<(String, String)>,
30    pub connect_timeout_seconds: Option<u64>,
31    pub steps: Vec<WebsocketStep>,
32    pub expect: Option<WebsocketExpect>,
33    pub raw: serde_json::Value,
34}
35
36#[derive(Debug, Clone, PartialEq)]
37pub struct WebsocketRequestFile {
38    pub path: PathBuf,
39    pub request: WebsocketRequest,
40}
41
42impl WebsocketRequestFile {
43    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
44        let path = path.as_ref();
45        let bytes = std::fs::read(path)
46            .with_context(|| format!("read websocket request file: {}", path.display()))?;
47        let raw: serde_json::Value = serde_json::from_slice(&bytes).map_err(|_| {
48            anyhow::anyhow!(
49                "websocket request file is not valid JSON: {}",
50                path.display()
51            )
52        })?;
53        let request = parse_websocket_request_json(raw)?;
54        Ok(Self {
55            path: std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf()),
56            request,
57        })
58    }
59}
60
61fn scalar_to_string(value: &serde_json::Value) -> Result<String> {
62    match value {
63        serde_json::Value::String(s) => Ok(s.clone()),
64        serde_json::Value::Number(n) => Ok(n.to_string()),
65        serde_json::Value::Bool(b) => Ok(b.to_string()),
66        serde_json::Value::Null => Ok(String::new()),
67        serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
68            anyhow::bail!("headers values must be scalar")
69        }
70    }
71}
72
73fn parse_optional_u64(path_label: &str, raw: Option<&serde_json::Value>) -> Result<Option<u64>> {
74    match raw {
75        None | Some(serde_json::Value::Null) => Ok(None),
76        Some(serde_json::Value::Number(n)) => n
77            .as_u64()
78            .map(Some)
79            .ok_or_else(|| anyhow::anyhow!("{path_label} must be a positive integer")),
80        Some(serde_json::Value::String(s)) => {
81            let s = s.trim();
82            if s.is_empty() {
83                Ok(None)
84            } else {
85                Ok(Some(s.parse::<u64>().with_context(|| {
86                    format!("{path_label} is not a positive integer: {s}")
87                })?))
88            }
89        }
90        _ => anyhow::bail!("{path_label} must be a positive integer"),
91    }
92}
93
94fn parse_expect(
95    raw: Option<&serde_json::Value>,
96    path_label: &str,
97) -> Result<Option<WebsocketExpect>> {
98    let Some(raw) = raw else {
99        return Ok(None);
100    };
101    if raw.is_null() {
102        return Ok(None);
103    }
104
105    let obj = raw
106        .as_object()
107        .with_context(|| format!("{path_label} must be an object"))?;
108
109    let jq = obj
110        .get("jq")
111        .and_then(|v| v.as_str())
112        .map(str::trim)
113        .filter(|s| !s.is_empty())
114        .map(ToString::to_string);
115    let text_contains = obj
116        .get("textContains")
117        .or_else(|| obj.get("contains"))
118        .and_then(|v| v.as_str())
119        .map(str::trim)
120        .filter(|s| !s.is_empty())
121        .map(ToString::to_string);
122
123    if jq.is_none() && text_contains.is_none() {
124        return Ok(None);
125    }
126
127    Ok(Some(WebsocketExpect { jq, text_contains }))
128}
129
130fn parse_send_text(raw: &serde_json::Value) -> Result<String> {
131    match raw {
132        serde_json::Value::String(s) => Ok(s.clone()),
133        serde_json::Value::Object(_)
134        | serde_json::Value::Array(_)
135        | serde_json::Value::Number(_)
136        | serde_json::Value::Bool(_)
137        | serde_json::Value::Null => {
138            serde_json::to_string(raw).context("failed to serialize websocket send payload to text")
139        }
140    }
141}
142
143fn parse_steps(raw_steps: Option<&serde_json::Value>) -> Result<Vec<WebsocketStep>> {
144    let raw_steps = raw_steps.context("websocket request .steps is required")?;
145    let arr = raw_steps
146        .as_array()
147        .context("websocket request .steps must be an array")?;
148    if arr.is_empty() {
149        anyhow::bail!("websocket request .steps must include at least one step");
150    }
151
152    let mut out = Vec::new();
153    for (idx, raw_step) in arr.iter().enumerate() {
154        let obj = raw_step
155            .as_object()
156            .with_context(|| format!("websocket request .steps[{idx}] must be an object"))?;
157
158        let step_type = obj
159            .get("type")
160            .and_then(|v| v.as_str())
161            .unwrap_or_default()
162            .trim()
163            .to_ascii_lowercase();
164
165        match step_type.as_str() {
166            "send" => {
167                let send_raw = obj
168                    .get("text")
169                    .or_else(|| obj.get("json"))
170                    .or_else(|| obj.get("payload"))
171                    .with_context(|| {
172                        format!(
173                            "websocket request .steps[{idx}] send step requires text/json/payload"
174                        )
175                    })?;
176                out.push(WebsocketStep::Send {
177                    text: parse_send_text(send_raw)?,
178                });
179            }
180            "receive" => {
181                let timeout_seconds = parse_optional_u64(
182                    &format!("websocket request .steps[{idx}].timeoutSeconds"),
183                    obj.get("timeoutSeconds"),
184                )?;
185                let expect = parse_expect(
186                    obj.get("expect"),
187                    &format!("websocket request .steps[{idx}].expect"),
188                )?;
189                out.push(WebsocketStep::Receive {
190                    timeout_seconds,
191                    expect,
192                });
193            }
194            "close" => out.push(WebsocketStep::Close),
195            _ => {
196                anyhow::bail!(
197                    "websocket request .steps[{idx}] has unsupported type '{}'",
198                    step_type
199                );
200            }
201        }
202    }
203
204    Ok(out)
205}
206
207pub fn parse_websocket_request_json(raw: serde_json::Value) -> Result<WebsocketRequest> {
208    let obj = raw
209        .as_object()
210        .context("websocket request file must be a JSON object")?;
211
212    let url = obj
213        .get("url")
214        .and_then(|v| v.as_str())
215        .map(str::trim)
216        .filter(|s| !s.is_empty())
217        .map(ToString::to_string);
218
219    let mut headers: Vec<(String, String)> = Vec::new();
220    if let Some(v) = obj.get("headers")
221        && !v.is_null()
222    {
223        let m = v
224            .as_object()
225            .context("websocket request .headers must be an object")?;
226        let mut sorted: BTreeMap<String, String> = BTreeMap::new();
227        for (k, raw_v) in m {
228            let key = k.trim();
229            if key.is_empty() {
230                continue;
231            }
232            let value = scalar_to_string(raw_v)?;
233            if !value.trim().is_empty() {
234                sorted.insert(key.to_string(), value);
235            }
236        }
237        headers.extend(sorted);
238    }
239
240    let connect_timeout_seconds = parse_optional_u64(
241        "websocket request .connectTimeoutSeconds",
242        obj.get("connectTimeoutSeconds"),
243    )?;
244
245    let steps = parse_steps(obj.get("steps"))?;
246
247    let expect = parse_expect(obj.get("expect"), "websocket request .expect")?;
248
249    Ok(WebsocketRequest {
250        url,
251        headers,
252        connect_timeout_seconds,
253        steps,
254        expect,
255        raw,
256    })
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use tempfile::TempDir;

    #[test]
    fn websocket_schema_parses_steps_request() {
        let req = parse_websocket_request_json(serde_json::json!({
            "url": "ws://127.0.0.1:9001/ws",
            "steps": [
                { "type": "send", "text": "{\"ping\":true}" },
                { "type": "receive", "timeoutSeconds": 2, "expect": { "jq": ".ok == true" } },
                { "type": "close" }
            ]
        }))
        .unwrap();

        assert_eq!(req.steps.len(), 3);
        assert_eq!(req.url.as_deref(), Some("ws://127.0.0.1:9001/ws"));
    }

    #[test]
    fn websocket_schema_rejects_missing_steps() {
        let err = parse_websocket_request_json(serde_json::json!({})).unwrap_err();
        assert!(format!("{err:#}").contains(".steps is required"));
    }

    #[test]
    fn websocket_schema_rejects_empty_steps() {
        let err = parse_websocket_request_json(serde_json::json!({"steps": []})).unwrap_err();
        assert!(format!("{err:#}").contains("must include at least one step"));
    }

    #[test]
    fn websocket_schema_load_reads_file() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("health.ws.json");
        std::fs::write(
            &path,
            serde_json::to_vec_pretty(&serde_json::json!({
                "url": "ws://127.0.0.1:9001/ws",
                "steps": [
                    { "type": "send", "text": "ping" },
                    { "type": "receive", "expect": {"textContains": "ok"} }
                ]
            }))
            .unwrap(),
        )
        .unwrap();

        let loaded = WebsocketRequestFile::load(&path).unwrap();
        assert_eq!(loaded.request.steps.len(), 2);
    }

    // Header names are trimmed, blank names/values dropped, scalars rendered,
    // and the result sorted by name. A blank URL is treated as absent.
    #[test]
    fn websocket_schema_normalizes_headers_and_url() {
        let req = parse_websocket_request_json(serde_json::json!({
            "url": "   ",
            "headers": {
                "X-B": 2,
                " X-A ": "a",
                "X-Empty": "  ",
                "X-Null": null,
                "  ": "ignored",
                "X-Bool": true
            },
            "steps": [{ "type": "close" }]
        }))
        .unwrap();

        assert_eq!(req.url, None);
        assert_eq!(
            req.headers,
            vec![
                ("X-A".to_string(), "a".to_string()),
                ("X-B".to_string(), "2".to_string()),
                ("X-Bool".to_string(), "true".to_string()),
            ]
        );
    }

    #[test]
    fn websocket_schema_rejects_non_scalar_header_value() {
        let err = parse_websocket_request_json(serde_json::json!({
            "headers": { "X": [1] },
            "steps": [{ "type": "close" }]
        }))
        .unwrap_err();
        assert!(format!("{err:#}").contains("headers values must be scalar"));
    }

    // Step types are case/whitespace-insensitive, `json` payloads are serialized,
    // numeric strings are accepted for timeouts, and `contains` aliases `textContains`.
    #[test]
    fn websocket_schema_parses_step_variants_and_timeouts() {
        let req = parse_websocket_request_json(serde_json::json!({
            "connectTimeoutSeconds": " 5 ",
            "steps": [
                { "type": " SEND ", "json": { "ping": true } },
                { "type": "Receive", "timeoutSeconds": "3", "expect": { "contains": " pong " } },
                { "type": "close" }
            ]
        }))
        .unwrap();

        assert_eq!(req.connect_timeout_seconds, Some(5));
        assert_eq!(
            req.steps,
            vec![
                WebsocketStep::Send {
                    text: "{\"ping\":true}".to_string(),
                },
                WebsocketStep::Receive {
                    timeout_seconds: Some(3),
                    expect: Some(WebsocketExpect {
                        jq: None,
                        text_contains: Some("pong".to_string()),
                    }),
                },
                WebsocketStep::Close,
            ]
        );
    }

    #[test]
    fn websocket_schema_rejects_unsupported_step_type() {
        let err = parse_websocket_request_json(serde_json::json!({
            "steps": [{ "type": "ping" }]
        }))
        .unwrap_err();
        assert!(format!("{err:#}").contains("unsupported type 'ping'"));
    }

    #[test]
    fn websocket_schema_rejects_negative_connect_timeout() {
        let err = parse_websocket_request_json(serde_json::json!({
            "connectTimeoutSeconds": -1,
            "steps": [{ "type": "close" }]
        }))
        .unwrap_err();
        assert!(format!("{err:#}").contains(".connectTimeoutSeconds must be a positive integer"));
    }

    #[test]
    fn websocket_schema_load_reports_invalid_json() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("bad.ws.json");
        std::fs::write(&path, "{ not json").unwrap();

        let err = WebsocketRequestFile::load(&path).unwrap_err();
        assert!(format!("{err:#}").contains("is not valid JSON"));
    }
}