// pulse_client/duplex.rs

//! B-114 — bidirectional duplex channel for synchronous decision agents.
//!
//! Opens ONE WebSocket to `/api/pulse/agents/{id}/duplex`: events stream IN
//! and the agent's correlated outputs come back OUT on the same connection,
//! matched by a correlation id. This replaces the two-connection
//! publish-then-poll pattern for decision microservices (fraud, pricing, A/B
//! assignment).
//!
//! The duplex endpoint runs on the Pulse WebSocket port (REST port + 1 by
//! convention); [`derive_ws_url`] derives it from the client's `base_url`.
//!
//! # Example
//!
//! ```no_run
//! use pulse_client::PulseClient;
//! use serde_json::json;
//!
//! # async fn run(client: &PulseClient) -> Result<(), pulse_client::PulseError> {
//! let mut ch = client.duplex("fraud-detector").await?;
//! let cid = ch.send(&json!({ "amount": 5000 }), Some("tx-1")).await?;
//! let output = ch.recv().await?;
//! assert_eq!(output.correlation_id.as_deref(), Some(cid.as_str()));
//! ch.close().await?;
//! # Ok(())
//! # }
//! ```

28use std::sync::atomic::{AtomicU64, Ordering};
29use std::time::{SystemTime, UNIX_EPOCH};
30
31use futures_util::{SinkExt, StreamExt};
32use serde_json::{json, Value};
33use tokio::net::TcpStream;
34use tokio_tungstenite::tungstenite::Message;
35use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
36
37use crate::error::PulseError;
38
/// Process-wide monotonic counter feeding `generate_correlation_id`.
///
/// Every generated id takes the next counter value, so ids are unique within
/// this process across all channels. The millisecond timestamp that
/// `generate_correlation_id` prepends only lowers the chance of collisions
/// with ids minted by other processes.
static CORRELATION_COUNTER: AtomicU64 = AtomicU64::new(0);

43/// Builds the duplex WebSocket URL from the client's REST `base_url`.
44///
45/// `http`→`ws` / `https`→`wss`, host unchanged, port → REST port + 1 (the
46/// Pulse WebSocket server convention). The JWT, when set, rides as a `token`
47/// query param (the server reads it from the upgrade request line).
48///
49/// Mirrors `pulse_client._duplex.derive_ws_url` in the Python SDK.
50pub fn derive_ws_url(base_url: &str, agent_id: &str, token: Option<&str>) -> String {
51    // Split scheme.
52    let (scheme, rest) = match base_url.split_once("://") {
53        Some((s, r)) => (s, r),
54        None => ("http", base_url),
55    };
56    let ws_scheme = if scheme.eq_ignore_ascii_case("https") {
57        "wss"
58    } else {
59        "ws"
60    };
61
62    // The authority is everything up to the first '/' (path), '?' (query) or
63    // '#' (fragment) — Pulse base URLs are bare origins, but be defensive.
64    let authority_end = rest.find(['/', '?', '#']).unwrap_or(rest.len());
65    let authority = &rest[..authority_end];
66
67    // Strip userinfo if present (`user:pass@host:port`).
68    let host_port = authority
69        .rsplit_once('@')
70        .map(|(_, hp)| hp)
71        .unwrap_or(authority);
72
73    let netloc = match host_port.rsplit_once(':') {
74        // host:port → bump the port by one (the WS server convention).
75        Some((h, p)) if !h.is_empty() => match p.parse::<u32>() {
76            Ok(port) => format!("{h}:{}", port + 1),
77            // Not a numeric port (e.g. an IPv6 literal) → leave untouched.
78            Err(_) => host_port.to_string(),
79        },
80        // No explicit port → host unchanged (cannot bump an absent port).
81        _ if host_port.is_empty() => "localhost".to_string(),
82        _ => host_port.to_string(),
83    };
84
85    let path = format!("/api/pulse/agents/{}/duplex", encode_segment(agent_id));
86    match token {
87        Some(t) if !t.is_empty() => {
88            format!("{ws_scheme}://{netloc}{path}?token={}", encode_query(t))
89        }
90        _ => format!("{ws_scheme}://{netloc}{path}"),
91    }
92}
93
94/// An agent output event received over the duplex channel.
95///
96/// `event` is the agent's output event (`id` / `topic` / `type` / `key` /
97/// `payload`); `correlation_id` identifies the input that produced it (the id
98/// returned by the matching [`DuplexChannel::send`]).
99#[derive(Debug, Clone)]
100pub struct DuplexOutput {
101    /// The agent's output event JSON.
102    pub event: Value,
103    /// Correlation id matching the input that produced this output, if the
104    /// server supplied one.
105    pub correlation_id: Option<String>,
106}
107
108/// An open duplex session.
109///
110/// [`send`](Self::send) publishes an event to the agent's input topic and
111/// returns its correlation id; [`recv`](Self::recv) returns the next output
112/// event the agent produced. Ack / pong / connected frames are consumed
113/// transparently by [`recv`](Self::recv).
114pub struct DuplexChannel {
115    url: String,
116    ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
117}
118
119impl DuplexChannel {
120    /// Connect and complete the duplex handshake.
121    ///
122    /// The server sends a `connected` frame first (or `error` + close for an
123    /// unknown agent / disabled duplex). An `error` first frame surfaces as
124    /// [`PulseError::Validation`].
125    pub(crate) async fn connect(url: String) -> Result<Self, PulseError> {
126        let (mut ws, _resp) = connect_async(&url)
127            .await
128            .map_err(|e| PulseError::Duplex(format!("connect {url}: {e}")))?;
129
130        // Read the first frame — `connected` (proceed) or `error` (abort).
131        let first = read_json_frame(&mut ws, &url).await?;
132        if first.get("type").and_then(Value::as_str) == Some("error") {
133            // Best-effort close, then surface the server's error payload.
134            let _ = ws.close(None).await;
135            let body = first.get("error").cloned().or(Some(first));
136            return Err(PulseError::Validation { path: url, body });
137        }
138        Ok(Self { url, ws })
139    }
140
141    /// Publish `payload` to the agent's input topic.
142    ///
143    /// Returns the correlation id (generated when `correlation_id` is `None`)
144    /// that the matching output will carry.
145    pub async fn send(
146        &mut self,
147        payload: &Value,
148        correlation_id: Option<&str>,
149    ) -> Result<String, PulseError> {
150        let cid = match correlation_id {
151            Some(c) if !c.is_empty() => c.to_string(),
152            _ => generate_correlation_id(),
153        };
154        let frame = json!({
155            "type": "send",
156            "correlationId": cid,
157            "payload": payload,
158        });
159        let text = serde_json::to_string(&frame)?;
160        self.ws
161            .send(Message::text(text))
162            .await
163            .map_err(|e| PulseError::Duplex(format!("send on {}: {e}", self.url)))?;
164        Ok(cid)
165    }
166
167    /// Return the next agent output event.
168    ///
169    /// Skips `ack` / `pong` / `connected` frames transparently. An `error`
170    /// frame surfaces as [`PulseError::Validation`].
171    pub async fn recv(&mut self) -> Result<DuplexOutput, PulseError> {
172        loop {
173            let msg = read_json_frame(&mut self.ws, &self.url).await?;
174            match msg.get("type").and_then(Value::as_str) {
175                Some("output") => {
176                    let event = match msg.get("event") {
177                        Some(Value::Object(_)) => msg.get("event").cloned().unwrap_or(Value::Null),
178                        Some(other) => json!({ "value": other }),
179                        None => Value::Null,
180                    };
181                    let correlation_id = msg
182                        .get("correlationId")
183                        .and_then(Value::as_str)
184                        .map(str::to_string);
185                    return Ok(DuplexOutput {
186                        event,
187                        correlation_id,
188                    });
189                }
190                Some("error") => {
191                    let body = msg.get("error").cloned().or(Some(msg));
192                    return Err(PulseError::Validation {
193                        path: self.url.clone(),
194                        body,
195                    });
196                }
197                // ack / pong / connected / anything else → skip
198                _ => continue,
199            }
200        }
201    }
202
203    /// Close the channel cleanly.
204    pub async fn close(mut self) -> Result<(), PulseError> {
205        self.ws
206            .close(None)
207            .await
208            .map_err(|e| PulseError::Duplex(format!("close {}: {e}", self.url)))
209    }
210}
211
212impl std::fmt::Debug for DuplexChannel {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        f.debug_struct("DuplexChannel")
215            .field("url", &self.url)
216            .finish()
217    }
218}
219
220/// Read frames until a JSON text/binary frame arrives, skipping ping/pong and
221/// surfacing a clear error on close / transport failure.
222async fn read_json_frame(
223    ws: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
224    url: &str,
225) -> Result<Value, PulseError> {
226    loop {
227        match ws.next().await {
228            Some(Ok(Message::Text(text))) => {
229                return serde_json::from_str(&text).map_err(PulseError::Json);
230            }
231            Some(Ok(Message::Binary(bytes))) => {
232                return serde_json::from_slice(&bytes).map_err(PulseError::Json);
233            }
234            // Ping/pong are handled by the library's auto-pong; skip explicitly.
235            Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => continue,
236            Some(Ok(Message::Close(frame))) => {
237                return Err(PulseError::Duplex(format!(
238                    "{url} closed by server: {frame:?}"
239                )));
240            }
241            Some(Ok(Message::Frame(_))) => continue,
242            Some(Err(e)) => return Err(PulseError::Duplex(format!("{url}: {e}"))),
243            None => {
244                return Err(PulseError::Duplex(format!(
245                    "{url}: connection closed before a frame arrived"
246                )))
247            }
248        }
249    }
250}
251
252/// Generates a unique correlation id without pulling in the `uuid` crate:
253/// `<millis-since-epoch>-<process-monotonic-counter>`.
254fn generate_correlation_id() -> String {
255    let millis = SystemTime::now()
256        .duration_since(UNIX_EPOCH)
257        .map(|d| d.as_millis())
258        .unwrap_or(0);
259    let n = CORRELATION_COUNTER.fetch_add(1, Ordering::Relaxed);
260    format!("pulse-{millis:x}-{n:x}")
261}
262
/// Percent-encodes one URL path segment.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z a-z 0-9 - _ . ~`) are copied
/// through. Every other byte of the UTF-8 encoding — including `/`, so the
/// result stays a single segment — becomes `%XX` with uppercase hex digits.
/// Uses the same unreserved set as `resources::encode_path`.
fn encode_segment(segment: &str) -> String {
    segment
        .bytes()
        .fold(String::with_capacity(segment.len()), |mut encoded, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push_str(&format!("%{byte:02X}"));
            }
            encoded
        })
}

277/// Percent-encode a query-param value (same unreserved set, encodes `&`/`=`/etc).
278fn encode_query(value: &str) -> String {
279    encode_segment(value)
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn derive_http_to_ws_bumps_port() {
        let url = derive_ws_url("http://localhost:9090", "fraud", Some("ey.jwt"));
        assert_eq!(
            url,
            "ws://localhost:9091/api/pulse/agents/fraud/duplex?token=ey.jwt"
        );
    }

    #[test]
    fn derive_https_to_wss() {
        let url = derive_ws_url("https://pulse.example.com:443", "pricing", None);
        assert_eq!(
            url,
            "wss://pulse.example.com:444/api/pulse/agents/pricing/duplex"
        );
    }

    #[test]
    fn derive_default_port_when_absent() {
        // No explicit port → host unchanged (cannot bump an absent port).
        let url = derive_ws_url("http://localhost", "ab", None);
        assert_eq!(url, "ws://localhost/api/pulse/agents/ab/duplex");
    }

    #[test]
    fn derive_encodes_agent_id_and_token() {
        let url = derive_ws_url("http://h:1000", "a/b c", Some("a=b&c"));
        assert_eq!(
            url,
            "ws://h:1001/api/pulse/agents/a%2Fb%20c/duplex?token=a%3Db%26c"
        );
    }

    #[test]
    fn derive_strips_userinfo_path_query_and_fragment() {
        let url = derive_ws_url("http://user:pw@h:8080/base?x=1#frag", "a", None);
        assert_eq!(url, "ws://h:8081/api/pulse/agents/a/duplex");
    }

    #[test]
    fn derive_handles_ipv6_literals() {
        assert_eq!(
            derive_ws_url("http://[::1]:9090", "a", None),
            "ws://[::1]:9091/api/pulse/agents/a/duplex"
        );
        // No port: the trailing `1]` is not numeric, so the literal is kept.
        assert_eq!(
            derive_ws_url("http://[::1]", "a", None),
            "ws://[::1]/api/pulse/agents/a/duplex"
        );
    }

    #[test]
    fn derive_defaults_scheme_and_host() {
        assert_eq!(
            derive_ws_url("h:1000", "a", None),
            "ws://h:1001/api/pulse/agents/a/duplex"
        );
        assert_eq!(
            derive_ws_url("HTTPS://h:1", "a", None),
            "wss://h:2/api/pulse/agents/a/duplex"
        );
        assert_eq!(
            derive_ws_url("", "a", None),
            "ws://localhost/api/pulse/agents/a/duplex"
        );
    }

    #[test]
    fn derive_omits_empty_token() {
        assert_eq!(
            derive_ws_url("http://h:1", "a", Some("")),
            "ws://h:2/api/pulse/agents/a/duplex"
        );
    }

    #[test]
    fn encode_segment_escapes_utf8_and_keeps_unreserved() {
        assert_eq!(encode_segment("é"), "%C3%A9");
        assert_eq!(encode_segment("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(encode_segment(""), "");
    }

    #[test]
    fn generated_ids_are_unique() {
        let a = generate_correlation_id();
        let b = generate_correlation_id();
        assert_ne!(a, b);
        assert!(a.starts_with("pulse-"));
    }

    #[test]
    fn generated_ids_are_prefix_hex_millis_hex_counter() {
        let id = generate_correlation_id();
        let parts: Vec<&str> = id.split('-').collect();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], "pulse");
        assert!(u128::from_str_radix(parts[1], 16).is_ok());
        assert!(u64::from_str_radix(parts[2], 16).is_ok());
    }
}