1use std::sync::atomic::{AtomicU64, Ordering};
29use std::time::{SystemTime, UNIX_EPOCH};
30
31use futures_util::{SinkExt, StreamExt};
32use serde_json::{json, Value};
33use tokio::net::TcpStream;
34use tokio_tungstenite::tungstenite::Message;
35use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
36
37use crate::error::PulseError;
38
/// Process-wide sequence number appended to generated correlation ids, so ids
/// produced within the same millisecond still differ from one another.
static CORRELATION_COUNTER: AtomicU64 = AtomicU64::new(0);
42
43pub fn derive_ws_url(base_url: &str, agent_id: &str, token: Option<&str>) -> String {
51 let (scheme, rest) = match base_url.split_once("://") {
53 Some((s, r)) => (s, r),
54 None => ("http", base_url),
55 };
56 let ws_scheme = if scheme.eq_ignore_ascii_case("https") {
57 "wss"
58 } else {
59 "ws"
60 };
61
62 let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
65 let authority = &rest[..authority_end];
66
67 let host_port = authority
69 .rsplit_once('@')
70 .map(|(_, hp)| hp)
71 .unwrap_or(authority);
72
73 let netloc = match host_port.rsplit_once(':') {
74 Some((h, p)) if !h.is_empty() => match p.parse::<u32>() {
76 Ok(port) => format!("{h}:{}", port + 1),
77 Err(_) => host_port.to_string(),
79 },
80 _ if host_port.is_empty() => "localhost".to_string(),
82 _ => host_port.to_string(),
83 };
84
85 let path = format!("/api/pulse/agents/{}/duplex", encode_segment(agent_id));
86 match token {
87 Some(t) if !t.is_empty() => {
88 format!("{ws_scheme}://{netloc}{path}?token={}", encode_query(t))
89 }
90 _ => format!("{ws_scheme}://{netloc}{path}"),
91 }
92}
93
/// One `"output"` frame received from the server, as returned by
/// [`DuplexChannel::recv`].
#[derive(Debug, Clone)]
pub struct DuplexOutput {
    /// The frame's `event` member: a JSON object is passed through unchanged, any
    /// other value is wrapped as `{"value": …}`, and a missing member becomes `null`.
    pub event: Value,
    /// The frame's `correlationId`, when it is present and is a JSON string.
    pub correlation_id: Option<String>,
}
107
/// An open duplex WebSocket connection that exchanges JSON frames with the server.
pub struct DuplexChannel {
    /// URL the socket was opened with; reused in error messages and `Debug` output.
    url: String,
    /// The underlying WebSocket stream (plain TCP or TLS).
    ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
}
118
119impl DuplexChannel {
120 pub(crate) async fn connect(url: String) -> Result<Self, PulseError> {
126 let (mut ws, _resp) = connect_async(&url)
127 .await
128 .map_err(|e| PulseError::Duplex(format!("connect {url}: {e}")))?;
129
130 let first = read_json_frame(&mut ws, &url).await?;
132 if first.get("type").and_then(Value::as_str) == Some("error") {
133 let _ = ws.close(None).await;
135 let body = first.get("error").cloned().or(Some(first));
136 return Err(PulseError::Validation { path: url, body });
137 }
138 Ok(Self { url, ws })
139 }
140
141 pub async fn send(
146 &mut self,
147 payload: &Value,
148 correlation_id: Option<&str>,
149 ) -> Result<String, PulseError> {
150 let cid = match correlation_id {
151 Some(c) if !c.is_empty() => c.to_string(),
152 _ => generate_correlation_id(),
153 };
154 let frame = json!({
155 "type": "send",
156 "correlationId": cid,
157 "payload": payload,
158 });
159 let text = serde_json::to_string(&frame)?;
160 self.ws
161 .send(Message::text(text))
162 .await
163 .map_err(|e| PulseError::Duplex(format!("send on {}: {e}", self.url)))?;
164 Ok(cid)
165 }
166
167 pub async fn recv(&mut self) -> Result<DuplexOutput, PulseError> {
172 loop {
173 let msg = read_json_frame(&mut self.ws, &self.url).await?;
174 match msg.get("type").and_then(Value::as_str) {
175 Some("output") => {
176 let event = match msg.get("event") {
177 Some(Value::Object(_)) => msg.get("event").cloned().unwrap_or(Value::Null),
178 Some(other) => json!({ "value": other }),
179 None => Value::Null,
180 };
181 let correlation_id = msg
182 .get("correlationId")
183 .and_then(Value::as_str)
184 .map(str::to_string);
185 return Ok(DuplexOutput {
186 event,
187 correlation_id,
188 });
189 }
190 Some("error") => {
191 let body = msg.get("error").cloned().or(Some(msg));
192 return Err(PulseError::Validation {
193 path: self.url.clone(),
194 body,
195 });
196 }
197 _ => continue,
199 }
200 }
201 }
202
203 pub async fn close(mut self) -> Result<(), PulseError> {
205 self.ws
206 .close(None)
207 .await
208 .map_err(|e| PulseError::Duplex(format!("close {}: {e}", self.url)))
209 }
210}
211
212impl std::fmt::Debug for DuplexChannel {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("DuplexChannel")
215 .field("url", &self.url)
216 .finish()
217 }
218}
219
220async fn read_json_frame(
223 ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
224 url: &str,
225) -> Result<Value, PulseError> {
226 loop {
227 match ws.next().await {
228 Some(Ok(Message::Text(text))) => {
229 return serde_json::from_str(&text).map_err(PulseError::Json);
230 }
231 Some(Ok(Message::Binary(bytes))) => {
232 return serde_json::from_slice(&bytes).map_err(PulseError::Json);
233 }
234 Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => continue,
236 Some(Ok(Message::Close(frame))) => {
237 return Err(PulseError::Duplex(format!(
238 "{url} closed by server: {frame:?}"
239 )));
240 }
241 Some(Ok(Message::Frame(_))) => continue,
242 Some(Err(e)) => return Err(PulseError::Duplex(format!("{url}: {e}"))),
243 None => {
244 return Err(PulseError::Duplex(format!(
245 "{url}: connection closed before a frame arrived"
246 )))
247 }
248 }
249 }
250}
251
252fn generate_correlation_id() -> String {
255 let millis = SystemTime::now()
256 .duration_since(UNIX_EPOCH)
257 .map(|d| d.as_millis())
258 .unwrap_or(0);
259 let n = CORRELATION_COUNTER.fetch_add(1, Ordering::Relaxed);
260 format!("pulse-{millis:x}-{n:x}")
261}
262
/// Percent-encodes `segment` byte by byte.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied through unchanged; every
/// other byte (including each byte of a multi-byte UTF-8 sequence) is written as
/// `%XX` with uppercase hexadecimal digits.
fn encode_segment(segment: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(segment.len());
    for byte in segment.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            // High nibble first, then low nibble; both indices are always below 16.
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
276
/// Percent-encodes `value` for use as a query-string value.
///
/// Delegates to [`encode_segment`], so exactly the same set of bytes is escaped.
fn encode_query(value: &str) -> String {
    encode_segment(value)
}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// `http` maps to `ws`, the explicit port moves up by one, and the token is appended.
    #[test]
    fn derive_http_to_ws_bumps_port() {
        let expected = "ws://localhost:9091/api/pulse/agents/fraud/duplex?token=ey.jwt";
        assert_eq!(
            derive_ws_url("http://localhost:9090", "fraud", Some("ey.jwt")),
            expected
        );
    }

    /// `https` maps to `wss`, and no query string is added without a token.
    #[test]
    fn derive_https_to_wss() {
        let expected = "wss://pulse.example.com:444/api/pulse/agents/pricing/duplex";
        assert_eq!(
            derive_ws_url("https://pulse.example.com:443", "pricing", None),
            expected
        );
    }

    /// A base URL without a port yields a WebSocket URL without a port.
    #[test]
    fn derive_default_port_when_absent() {
        let expected = "ws://localhost/api/pulse/agents/ab/duplex";
        assert_eq!(derive_ws_url("http://localhost", "ab", None), expected);
    }

    /// Reserved characters in the agent id and in the token are percent-encoded.
    #[test]
    fn derive_encodes_agent_id_and_token() {
        let expected = "ws://h:1001/api/pulse/agents/a%2Fb%20c/duplex?token=a%3Db%26c";
        assert_eq!(
            derive_ws_url("http://h:1000", "a/b c", Some("a=b&c")),
            expected
        );
    }

    /// Two consecutive ids differ and carry the `pulse-` prefix.
    #[test]
    fn generated_ids_are_unique() {
        let first = generate_correlation_id();
        let second = generate_correlation_id();
        assert!(first.starts_with("pulse-"));
        assert_ne!(first, second);
    }
}