concord_client/
queue_client.rs

1use bytes::Bytes;
2use futures::{SinkExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use tokio_tungstenite::tungstenite::{self, Utf8Bytes};
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, warn};
7
8use crate::{
9    api_err,
10    error::ApiError,
11    model::{AgentId, ApiToken, ProcessId, SessionToken, USER_AGENT_VALUE},
12};
13
/// Identifier carried by every request and echoed back in the matching
/// response, used to pair a reply with the request that triggered it.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct CorrelationId(i64);
16
/// Source of [`CorrelationId`] values backed by a shared atomic counter.
///
/// Clones share the same counter because the counter lives behind an `Arc`.
#[derive(Clone, Default)]
struct CorrelationIdGenerator {
    // the next value to hand out; starts at the `AtomicI64` default (0)
    v: std::sync::Arc<std::sync::atomic::AtomicI64>,
}
21
22impl CorrelationIdGenerator {
23    fn next(&self) -> CorrelationId {
24        let id = self.v.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
25        CorrelationId(id)
26    }
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30#[serde(tag = "messageType", rename_all = "SCREAMING_SNAKE_CASE")]
31enum Message {
32    CommandRequest {
33        #[serde(rename = "correlationId")]
34        correlation_id: CorrelationId,
35        #[serde(rename = "agentId")]
36        agent_id: AgentId,
37    },
38    CommandResponse {
39        #[serde(rename = "correlationId")]
40        correlation_id: CorrelationId,
41        // TODO
42    },
43    ProcessRequest {
44        #[serde(rename = "correlationId")]
45        correlation_id: CorrelationId,
46        capabilities: serde_json::Value,
47    },
48    ProcessResponse {
49        #[serde(rename = "correlationId")]
50        correlation_id: CorrelationId,
51        #[serde(rename = "sessionToken")]
52        session_token: SessionToken,
53        #[serde(rename = "processId")]
54        process_id: ProcessId,
55        // TODO imports
56    },
57}
58
/// Result of [`QueueClient::next_command`].
#[derive(Debug)]
pub struct CommandResponse {
    /// Correlation ID of the request this response answers.
    pub correlation_id: CorrelationId,
    // TODO
}
64
/// Result of [`QueueClient::next_process`].
#[derive(Debug)]
pub struct ProcessResponse {
    /// Correlation ID of the request this response answers.
    pub correlation_id: CorrelationId,
    /// Session token taken from the server's `ProcessResponse` message.
    pub session_token: SessionToken,
    /// Process ID taken from the server's `ProcessResponse` message.
    pub process_id: ProcessId,
}
71
/// Outcome delivered to a caller waiting for a reply: either the message
/// received from the server or the error that prevented it.
type Reply = Result<Message, ApiError>;
73
/// Pairs the correlation ID of an outgoing request with the one-shot channel
/// on which the caller waits for the corresponding [`Reply`].
struct Responder(CorrelationId, tokio::sync::oneshot::Sender<Reply>);
75
/// Shared table of responders still waiting for a reply, keyed by the
/// correlation ID of the request. Clones refer to the same table.
#[derive(Default, Clone)]
struct ResponseQueue(std::sync::Arc<tokio::sync::Mutex<std::collections::HashMap<CorrelationId, Responder>>>);
78
79impl ResponseQueue {
80    async fn lock(&self) -> tokio::sync::MutexGuard<std::collections::HashMap<CorrelationId, Responder>> {
81        self.0.lock().await
82    }
83}
84
/// A WebSocket frame queued for the writer task, optionally accompanied by
/// the responder that should receive the reply (or the write error).
struct MessageToSend {
    // the frame to write to the socket
    msg: tungstenite::Message,
    // `None` for frames that expect no reply (ping/pong)
    resp: Option<Responder>,
}
89
90impl MessageToSend {
91    fn ping() -> Self {
92        MessageToSend {
93            msg: tungstenite::Message::Ping(Bytes::new()),
94            resp: None,
95        }
96    }
97
98    fn pong() -> Self {
99        MessageToSend {
100            msg: tungstenite::Message::Pong(Bytes::new()),
101            resp: None,
102        }
103    }
104
105    fn text(text: String, resp: Responder) -> Self {
106        MessageToSend {
107            msg: tungstenite::Message::Text(text.into()),
108            resp: Some(resp),
109        }
110    }
111}
112
/// Settings used by [`QueueClient::connect`].
pub struct Config {
    /// Sent in the `X-Concord-Agent-Id` handshake header and in command requests.
    pub agent_id: AgentId,
    /// Address of the WebSocket endpoint to connect to.
    pub uri: http::Uri,
    /// Sent in the `Authorization` handshake header.
    pub api_token: ApiToken,
    /// Sent as the `capabilities` field of process requests.
    pub capabilities: serde_json::Value,
    /// Period between the ping frames sent to the server.
    pub ping_interval: std::time::Duration,
}
120
/// Client for the server's WebSocket queue.
///
/// Created with [`QueueClient::connect`]; the background tasks started there
/// are stopped when the client is dropped.
pub struct QueueClient {
    // copied from `Config`; sent in command requests
    agent_id: AgentId,
    // copied from `Config`; sent in process requests
    capabilities: serde_json::Value,
    // hands outgoing frames to the writer task
    tx: tokio::sync::mpsc::Sender<MessageToSend>,
    // cancelled on drop to stop the background tasks
    cancellation_token: CancellationToken,
    // produces the correlation ID of each request
    correlation_id_gen: CorrelationIdGenerator,
}
128
129impl QueueClient {
130    pub async fn connect(config: &Config) -> Result<Self, ApiError> {
131        let req = QueueClient::create_connect_request(config)?;
132        let (ws_stream, _) = tokio_tungstenite::connect_async(req).await?;
133
134        let (mut ws_write, mut ws_read) = ws_stream.split();
135
136        // a channel to communicate between tasks
137        let (tx, mut rx) = tokio::sync::mpsc::channel::<MessageToSend>(1);
138
139        // used to match responses to requests using correlation IDs
140        let response_queue = ResponseQueue::default();
141
142        let cancellation_token = CancellationToken::new();
143
144        // task to send pings
145        let token = cancellation_token.clone();
146        let _ping_task = {
147            let tx = tx.clone();
148            let ping_interval = config.ping_interval;
149            tokio::spawn(async move {
150                let mut interval = tokio::time::interval(ping_interval);
151                loop {
152                    tokio::select! {
153                        _ = token.cancelled() => {
154                            debug!("Stopping ping task...");
155                            break
156                        }
157                        _ = interval.tick() => {
158                            debug!("Sending a ping...");
159                            if let Err(e) = tx.send(MessageToSend::ping()).await {
160                                warn!("Failed to send a ping message to the server: {e}");
161                            }
162                        }
163                    }
164                }
165            })
166        };
167
168        // task to send messages to the server
169        let token = cancellation_token.clone();
170        let _write_task = {
171            let response_queue = response_queue.clone();
172            tokio::spawn(async move {
173                loop {
174                    tokio::select! {
175                        _ = token.cancelled() => {
176                            debug!("Stopping message sender...");
177                            break
178                        },
179                        msg = rx.recv() => {
180                            if let Some(MessageToSend { msg, resp }) = msg {
181                                match ws_write.send(msg).await {
182                                    Ok(_) => {
183                                        // message sent successfully, register the responder if needed
184                                        if let Some(Responder(correlation_id, channel)) = resp {
185                                            let mut response_queue = response_queue.lock().await;
186                                            let responder = Responder(correlation_id, channel);
187                                            response_queue.insert(correlation_id, responder);
188                                        }
189                                    }
190                                    Err(e) => {
191                                        // message failed to send, notify the responder if needed
192                                        let err = format!("Write error: {e}");
193                                        warn!("{}", err);
194                                        if let Some(Responder(_, channel)) = resp {
195                                            if channel.send(Err(ApiError::simple(&err))).is_err() {
196                                                warn!("Responder error (most likely a bug)");
197                                            }
198                                        }
199                                    }
200                                }
201                            } else {
202                                break
203                            }
204                        }
205                    }
206                }
207            })
208        };
209
210        // task to receive messages from the server
211        let token = cancellation_token.clone();
212        let _read_task = {
213            let tx = tx.clone();
214            tokio::spawn(async move {
215                loop {
216                    tokio::select! {
217                        _ = token.cancelled() => {
218                            debug!("Stopping message received...");
219                            break
220                        }
221                        msg = ws_read.next() => {
222                            match msg {
223                                Some(Ok(msg)) => match msg {
224                                    tungstenite::Message::Ping(_) => {
225                                        // respond to pings
226                                        if let Err(e) = tx.send(MessageToSend::pong()).await {
227                                            warn!("Failed to send a pong response to the server: {e}");
228                                        }
229                                    }
230                                    tungstenite::Message::Pong(_) => {
231                                        // log pongs
232                                        debug!("Received a pong");
233                                    }
234                                    tungstenite::Message::Text(text) => {
235                                        debug!("Received message: {}", text);
236                                        QueueClient::handle_text_message(text, &response_queue).await;
237                                    }
238                                    _ => {
239                                        // log and ignore bad messages
240                                        warn!("Unexpected message (possibly a bug): {msg:?}");
241                                    }
242                                },
243                                Some(Err(e)) => {
244                                    // complain about network errors and stop
245                                    warn!("Read error: {e}");
246                                    token.cancel();
247                                }
248                                None => break,
249                            }
250                        }
251                    }
252                }
253            })
254        };
255
256        Ok(QueueClient {
257            agent_id: config.agent_id,
258            capabilities: config.capabilities.clone(),
259            tx,
260            cancellation_token,
261            correlation_id_gen: Default::default(),
262        })
263    }
264
265    pub async fn next_command(&self) -> Result<CommandResponse, ApiError> {
266        let correlation_id = self.correlation_id_gen.next();
267
268        let msg = Message::CommandRequest {
269            correlation_id,
270            agent_id: self.agent_id,
271        };
272
273        match self.send_and_wait_for_reply(correlation_id, msg).await {
274            Ok(Message::CommandResponse {
275                correlation_id: reply_correlation_id,
276            }) => {
277                if correlation_id == reply_correlation_id {
278                    Ok(CommandResponse { correlation_id })
279                } else {
280                    api_err!("Unexpected correlation ID: {reply_correlation_id:?}")
281                }
282            }
283            Ok(msg) => api_err!("Unexpected message: {msg:?}"),
284            Err(e) => api_err!("Error while parsing message: {e}"),
285        }
286    }
287
288    pub async fn next_process(&self) -> Result<ProcessResponse, ApiError> {
289        let correlation_id = self.correlation_id_gen.next();
290
291        let msg = Message::ProcessRequest {
292            correlation_id,
293            capabilities: serde_json::json!(&self.capabilities),
294        };
295
296        match self.send_and_wait_for_reply(correlation_id, msg).await {
297            Ok(Message::ProcessResponse {
298                correlation_id: reply_correlation_id,
299                session_token,
300                process_id,
301            }) => {
302                if correlation_id == reply_correlation_id {
303                    Ok(ProcessResponse {
304                        correlation_id,
305                        session_token,
306                        process_id,
307                    })
308                } else {
309                    api_err!("Unexpected correlation ID: {reply_correlation_id:?}")
310                }
311            }
312            Ok(msg) => api_err!("Unexpected message: {msg:?}"),
313            Err(e) => api_err!("Error while parsing message: {e}"),
314        }
315    }
316
317    async fn send_and_wait_for_reply(&self, correlation_id: CorrelationId, msg: Message) -> Reply {
318        debug!("Sending message {msg:?} and waiting for reply");
319
320        let json = serde_json::to_string(&msg)?;
321
322        let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel::<Reply>();
323
324        let msg = MessageToSend::text(json, Responder(correlation_id, reply_sender));
325        if let Err(e) = self.tx.send(msg).await {
326            return api_err!("Send error: {e}");
327        }
328
329        match reply_receiver.await {
330            Ok(reply) => reply,
331            Err(e) => api_err!("Error while receiving reply: {e}"),
332        }
333    }
334
335    async fn handle_text_message(text: Utf8Bytes, message_queue: &ResponseQueue) {
336        match serde_json::from_str::<Message>(&text) {
337            Ok(
338                cmd @ Message::CommandResponse { correlation_id, .. }
339                | cmd @ Message::ProcessResponse { correlation_id, .. },
340            ) => {
341                let mut message_queue = message_queue.lock().await;
342                if let Some(Responder(_, responder)) = message_queue.remove(&correlation_id) {
343                    if responder.send(Ok(cmd)).is_err() {
344                        warn!("Responder error (most likely a bug)");
345                    }
346                } else {
347                    warn!("No responder registered for correlation_id={correlation_id:?} (possibly a bug).")
348                }
349            }
350            Ok(msg) => {
351                // log and ignore bad messages
352                warn!("Unexpected message body (most likely a bug): {:?}", msg);
353            }
354            Err(e) => {
355                // complain about parsing errors
356                warn!("Error while parsing message (possibly a bug): {e}");
357            }
358        }
359    }
360
361    fn create_connect_request(
362        Config {
363            uri,
364            api_token,
365            agent_id,
366            ..
367        }: &Config,
368    ) -> Result<http::Request<()>, ApiError> {
369        let host = format!(
370            "{}:{}",
371            uri.host().unwrap_or("localhost"),
372            uri.port_u16().unwrap_or(8001)
373        );
374
375        let ws_key = tungstenite::handshake::client::generate_key();
376
377        use http::{Request, header};
378        Request::builder()
379            .uri(uri.clone())
380            .header(header::HOST, host)
381            .header(header::AUTHORIZATION, api_token)
382            .header(header::CONNECTION, "Upgrade")
383            .header(header::UPGRADE, "websocket")
384            .header(header::SEC_WEBSOCKET_VERSION, "13")
385            .header(header::SEC_WEBSOCKET_KEY, ws_key)
386            .header(header::USER_AGENT, USER_AGENT_VALUE)
387            .header("X-Concord-Agent-Id", agent_id)
388            .header("X-Concord-Agent", USER_AGENT_VALUE)
389            .body(())
390            .map_err(|e| ApiError {
391                message: e.to_string(),
392            })
393    }
394}
395
/// Dropping the client cancels the token shared with the background tasks
/// spawned by `connect`, which makes them stop.
impl Drop for QueueClient {
    fn drop(&mut self) {
        self.cancellation_token.cancel();
    }
}