api_testing_core/websocket/
runner.rs1use anyhow::Context;
2use serde::Serialize;
3use tungstenite::client::IntoClientRequest;
4use tungstenite::http::{HeaderName, HeaderValue};
5use tungstenite::{Message, connect};
6
7use crate::Result;
8use crate::websocket::schema::{WebsocketExpect, WebsocketRequestFile, WebsocketStep};
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
11pub struct WebsocketTranscriptEntry {
12 pub direction: String,
13 pub payload: String,
14}
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct WebsocketExecutedRequest {
18 pub target: String,
19 pub transcript: Vec<WebsocketTranscriptEntry>,
20 pub last_received: Option<String>,
21}
22
23fn parse_message_text(message: Message) -> String {
24 match message {
25 Message::Text(t) => t.to_string(),
26 Message::Binary(b) => String::from_utf8_lossy(&b).to_string(),
27 Message::Ping(b) => format!("<PING:{}>", String::from_utf8_lossy(&b)),
28 Message::Pong(b) => format!("<PONG:{}>", String::from_utf8_lossy(&b)),
29 Message::Close(frame) => match frame {
30 Some(f) => format!("<CLOSE:{}:{}>", f.code, f.reason),
31 None => "<CLOSE>".to_string(),
32 },
33 Message::Frame(_) => "<FRAME>".to_string(),
34 }
35}
36
37fn apply_expect(expect: Option<&WebsocketExpect>, text: &str, path: &str) -> Result<()> {
38 if let Some(expect) = expect {
39 crate::websocket::expect::evaluate_text_expect(expect, text, path)?;
40 }
41 Ok(())
42}
43
44pub fn execute_websocket_request(
45 request_file: &WebsocketRequestFile,
46 target_override: &str,
47 bearer_token: Option<&str>,
48) -> Result<WebsocketExecutedRequest> {
49 let target = if !target_override.trim().is_empty() {
50 target_override.trim().to_string()
51 } else if let Some(url) = request_file.request.url.as_deref() {
52 url.to_string()
53 } else {
54 anyhow::bail!("websocket target URL is empty (set request.url or pass --url/--env)");
55 };
56
57 let mut request = target
58 .as_str()
59 .into_client_request()
60 .context("invalid websocket target URL")?;
61
62 for (key, value) in &request_file.request.headers {
63 let header_name = HeaderName::from_bytes(key.as_bytes())
64 .with_context(|| format!("invalid websocket header name: {key}"))?;
65 let header_value = HeaderValue::from_str(value)
66 .with_context(|| format!("invalid websocket header value for {key}"))?;
67 request.headers_mut().insert(header_name, header_value);
68 }
69
70 if let Some(token) = bearer_token {
71 let has_auth = request_file
72 .request
73 .headers
74 .iter()
75 .any(|(k, _)| k.eq_ignore_ascii_case("authorization"));
76 if !has_auth {
77 request.headers_mut().insert(
78 HeaderName::from_static("authorization"),
79 HeaderValue::from_str(&format!("Bearer {token}"))
80 .context("invalid bearer token for Authorization header")?,
81 );
82 }
83 }
84
85 let (mut socket, _resp) = connect(request)
86 .with_context(|| format!("failed to connect websocket target '{target}'"))?;
87
88 if let Some(connect_timeout_seconds) = request_file.request.connect_timeout_seconds {
89 let _ = connect_timeout_seconds;
90 }
91
92 let mut transcript: Vec<WebsocketTranscriptEntry> = Vec::new();
93 let mut last_received: Option<String> = None;
94
95 for (idx, step) in request_file.request.steps.iter().enumerate() {
96 match step {
97 WebsocketStep::Send { text } => {
98 socket
99 .send(Message::Text(text.clone().into()))
100 .with_context(|| format!("websocket send failed at step {idx}"))?;
101 transcript.push(WebsocketTranscriptEntry {
102 direction: "send".to_string(),
103 payload: text.clone(),
104 });
105 }
106 WebsocketStep::Receive {
107 timeout_seconds,
108 expect,
109 } => {
110 let _ = timeout_seconds;
111 let message = socket
112 .read()
113 .with_context(|| format!("websocket receive failed at step {idx}"))?;
114 let text = parse_message_text(message);
115 apply_expect(
116 expect.as_ref(),
117 &text,
118 &format!("websocket steps[{idx}].expect"),
119 )?;
120 transcript.push(WebsocketTranscriptEntry {
121 direction: "receive".to_string(),
122 payload: text.clone(),
123 });
124 last_received = Some(text);
125 }
126 WebsocketStep::Close => {
127 let _ = socket.close(None);
128 transcript.push(WebsocketTranscriptEntry {
129 direction: "close".to_string(),
130 payload: String::new(),
131 });
132 }
133 }
134 }
135
136 Ok(WebsocketExecutedRequest {
137 target,
138 transcript,
139 last_received,
140 })
141}
142
#[cfg(test)]
mod tests {
    use std::net::TcpListener;
    use std::thread;

    use pretty_assertions::assert_eq;
    use tempfile::TempDir;
    use tungstenite::Message;

    use super::*;
    use crate::websocket::schema::WebsocketRequestFile;

    /// Start a single-connection websocket server on an ephemeral local port.
    ///
    /// A text frame of `ping` (surrounding whitespace ignored) is answered with
    /// `{"ok":true}`; any other text frame is echoed back. The server answers
    /// the client's close frame and then stops. It also stops on any read error.
    fn spawn_echo_server() -> (String, thread::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind websocket listener");
        let addr = listener.local_addr().expect("listener addr");

        let server = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept websocket stream");
            let mut ws = tungstenite::accept(stream).expect("accept websocket handshake");
            while let Ok(incoming) = ws.read() {
                match incoming {
                    Message::Text(body) => {
                        let reply = if body.trim() == "ping" {
                            String::from("{\"ok\":true}")
                        } else {
                            body.to_string()
                        };
                        ws.send(Message::Text(reply.into()))
                            .expect("send response");
                    }
                    Message::Close(_) => {
                        let _ = ws.close(None);
                        break;
                    }
                    _ => {}
                }
            }
        });

        (format!("ws://{addr}"), server)
    }

    #[test]
    fn websocket_runner_executes_send_receive_steps() {
        let (url, server) = spawn_echo_server();

        let workdir = TempDir::new().expect("tmp");
        let request_path = workdir.path().join("echo.ws.json");
        let request_json = serde_json::json!({
            "url": url,
            "steps": [
                {"type": "send", "text": "ping"},
                {"type": "receive", "expect": {"jq": ".ok == true"}},
                {"type": "close"}
            ]
        });
        let encoded = serde_json::to_vec_pretty(&request_json).expect("serialize request");
        std::fs::write(&request_path, encoded).expect("write request");

        let loaded = WebsocketRequestFile::load(&request_path).expect("load request");
        let outcome = execute_websocket_request(&loaded, "", None).expect("execute websocket");

        assert_eq!(outcome.transcript.len(), 3);
        assert_eq!(outcome.transcript[0].direction, "send");
        assert_eq!(outcome.transcript[1].direction, "receive");
        assert_eq!(outcome.last_received.as_deref(), Some("{\"ok\":true}"));

        server.join().expect("join websocket server");
    }
}