Skip to main content

api_testing_core/websocket/
runner.rs

1use anyhow::Context;
2use serde::Serialize;
3use tungstenite::client::IntoClientRequest;
4use tungstenite::http::{HeaderName, HeaderValue};
5use tungstenite::{Message, connect};
6
7use crate::Result;
8use crate::websocket::schema::{WebsocketExpect, WebsocketRequestFile, WebsocketStep};
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
11pub struct WebsocketTranscriptEntry {
12    pub direction: String,
13    pub payload: String,
14}
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct WebsocketExecutedRequest {
18    pub target: String,
19    pub transcript: Vec<WebsocketTranscriptEntry>,
20    pub last_received: Option<String>,
21}
22
23fn parse_message_text(message: Message) -> String {
24    match message {
25        Message::Text(t) => t.to_string(),
26        Message::Binary(b) => String::from_utf8_lossy(&b).to_string(),
27        Message::Ping(b) => format!("<PING:{}>", String::from_utf8_lossy(&b)),
28        Message::Pong(b) => format!("<PONG:{}>", String::from_utf8_lossy(&b)),
29        Message::Close(frame) => match frame {
30            Some(f) => format!("<CLOSE:{}:{}>", f.code, f.reason),
31            None => "<CLOSE>".to_string(),
32        },
33        Message::Frame(_) => "<FRAME>".to_string(),
34    }
35}
36
37fn apply_expect(expect: Option<&WebsocketExpect>, text: &str, path: &str) -> Result<()> {
38    if let Some(expect) = expect {
39        crate::websocket::expect::evaluate_text_expect(expect, text, path)?;
40    }
41    Ok(())
42}
43
44pub fn execute_websocket_request(
45    request_file: &WebsocketRequestFile,
46    target_override: &str,
47    bearer_token: Option<&str>,
48) -> Result<WebsocketExecutedRequest> {
49    let target = if !target_override.trim().is_empty() {
50        target_override.trim().to_string()
51    } else if let Some(url) = request_file.request.url.as_deref() {
52        url.to_string()
53    } else {
54        anyhow::bail!("websocket target URL is empty (set request.url or pass --url/--env)");
55    };
56
57    let mut request = target
58        .as_str()
59        .into_client_request()
60        .context("invalid websocket target URL")?;
61
62    for (key, value) in &request_file.request.headers {
63        let header_name = HeaderName::from_bytes(key.as_bytes())
64            .with_context(|| format!("invalid websocket header name: {key}"))?;
65        let header_value = HeaderValue::from_str(value)
66            .with_context(|| format!("invalid websocket header value for {key}"))?;
67        request.headers_mut().insert(header_name, header_value);
68    }
69
70    if let Some(token) = bearer_token {
71        let has_auth = request_file
72            .request
73            .headers
74            .iter()
75            .any(|(k, _)| k.eq_ignore_ascii_case("authorization"));
76        if !has_auth {
77            request.headers_mut().insert(
78                HeaderName::from_static("authorization"),
79                HeaderValue::from_str(&format!("Bearer {token}"))
80                    .context("invalid bearer token for Authorization header")?,
81            );
82        }
83    }
84
85    let (mut socket, _resp) = connect(request)
86        .with_context(|| format!("failed to connect websocket target '{target}'"))?;
87
88    if let Some(connect_timeout_seconds) = request_file.request.connect_timeout_seconds {
89        let _ = connect_timeout_seconds;
90    }
91
92    let mut transcript: Vec<WebsocketTranscriptEntry> = Vec::new();
93    let mut last_received: Option<String> = None;
94
95    for (idx, step) in request_file.request.steps.iter().enumerate() {
96        match step {
97            WebsocketStep::Send { text } => {
98                socket
99                    .send(Message::Text(text.clone().into()))
100                    .with_context(|| format!("websocket send failed at step {idx}"))?;
101                transcript.push(WebsocketTranscriptEntry {
102                    direction: "send".to_string(),
103                    payload: text.clone(),
104                });
105            }
106            WebsocketStep::Receive {
107                timeout_seconds,
108                expect,
109            } => {
110                let _ = timeout_seconds;
111                let message = socket
112                    .read()
113                    .with_context(|| format!("websocket receive failed at step {idx}"))?;
114                let text = parse_message_text(message);
115                apply_expect(
116                    expect.as_ref(),
117                    &text,
118                    &format!("websocket steps[{idx}].expect"),
119                )?;
120                transcript.push(WebsocketTranscriptEntry {
121                    direction: "receive".to_string(),
122                    payload: text.clone(),
123                });
124                last_received = Some(text);
125            }
126            WebsocketStep::Close => {
127                let _ = socket.close(None);
128                transcript.push(WebsocketTranscriptEntry {
129                    direction: "close".to_string(),
130                    payload: String::new(),
131                });
132            }
133        }
134    }
135
136    Ok(WebsocketExecutedRequest {
137        target,
138        transcript,
139        last_received,
140    })
141}
142
#[cfg(test)]
mod tests {
    use std::net::TcpListener;
    use std::thread;

    use pretty_assertions::assert_eq;
    use tempfile::TempDir;
    use tungstenite::Message;

    use super::*;
    use crate::websocket::schema::WebsocketRequestFile;

    /// Starts a single-connection websocket server on an ephemeral localhost port.
    ///
    /// Text frames are echoed back, except `ping` (after trimming), which is
    /// answered with `{"ok":true}`. The server thread exits after a close frame or
    /// a read error. Returns the `ws://` URL and the server thread handle.
    fn spawn_echo_server() -> (String, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind websocket listener");
        let addr = listener.local_addr().expect("listener addr");

        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept websocket stream");
            let mut ws = tungstenite::accept(stream).expect("accept websocket handshake");
            loop {
                match ws.read() {
                    Ok(Message::Text(text)) => {
                        let response = if text.trim() == "ping" {
                            "{\"ok\":true}".to_string()
                        } else {
                            text.to_string()
                        };
                        ws.send(Message::Text(response.into()))
                            .expect("send response");
                    }
                    Ok(Message::Close(_)) => {
                        let _ = ws.close(None);
                        break;
                    }
                    Ok(_) => {}
                    Err(_) => break,
                }
            }
        });

        (format!("ws://{addr}"), handle)
    }

    /// Builds an expected transcript entry.
    fn entry(direction: &str, payload: &str) -> WebsocketTranscriptEntry {
        WebsocketTranscriptEntry {
            direction: direction.to_string(),
            payload: payload.to_string(),
        }
    }

    #[test]
    fn websocket_runner_executes_send_receive_steps() {
        let tmp = TempDir::new().expect("tmp");
        let request_path = tmp.path().join("echo.ws.json");

        let (url, handle) = spawn_echo_server();
        let expected_target = url.clone();

        std::fs::write(
            &request_path,
            serde_json::to_vec_pretty(&serde_json::json!({
                "url": url,
                "steps": [
                    {"type": "send", "text": "ping"},
                    {"type": "receive", "expect": {"jq": ".ok == true"}},
                    {"type": "close"}
                ]
            }))
            .expect("serialize request"),
        )
        .expect("write request");

        let loaded = WebsocketRequestFile::load(&request_path).expect("load request");
        let executed = execute_websocket_request(&loaded, "", None).expect("execute websocket");

        // Pin every transcript entry (direction and payload), not just the count,
        // so regressions in the close step or payload recording are caught.
        assert_eq!(executed.target, expected_target);
        assert_eq!(
            executed.transcript,
            vec![
                entry("send", "ping"),
                entry("receive", "{\"ok\":true}"),
                entry("close", ""),
            ]
        );
        assert_eq!(executed.last_received.as_deref(), Some("{\"ok\":true}"));

        handle.join().expect("join websocket server");
    }
}