1use bytes::Bytes;
2use futures::{SinkExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use tokio_tungstenite::tungstenite::{self, Utf8Bytes};
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, warn};
7
8use crate::{
9 api_err,
10 error::ApiError,
11 model::{AgentId, ApiToken, ProcessId, SessionToken, USER_AGENT_VALUE},
12};
13
14#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
15pub struct CorrelationId(i64);
16
/// Source of [`CorrelationId`]s.
///
/// The counter lives behind an `Arc`, so clones of the generator share it and
/// never hand out the same id twice. The default generator starts at 0.
#[derive(Default, Clone)]
struct CorrelationIdGenerator {
    v: std::sync::Arc<std::sync::atomic::AtomicI64>,
}
21
22impl CorrelationIdGenerator {
23 fn next(&self) -> CorrelationId {
24 let id = self.v.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
25 CorrelationId(id)
26 }
27}
28
29#[derive(Debug, Serialize, Deserialize)]
30#[serde(tag = "messageType", rename_all = "SCREAMING_SNAKE_CASE")]
31enum Message {
32 CommandRequest {
33 #[serde(rename = "correlationId")]
34 correlation_id: CorrelationId,
35 #[serde(rename = "agentId")]
36 agent_id: AgentId,
37 },
38 CommandResponse {
39 #[serde(rename = "correlationId")]
40 correlation_id: CorrelationId,
41 },
43 ProcessRequest {
44 #[serde(rename = "correlationId")]
45 correlation_id: CorrelationId,
46 capabilities: serde_json::Value,
47 },
48 ProcessResponse {
49 #[serde(rename = "correlationId")]
50 correlation_id: CorrelationId,
51 #[serde(rename = "sessionToken")]
52 session_token: SessionToken,
53 #[serde(rename = "processId")]
54 process_id: ProcessId,
55 },
57}
58
59#[derive(Debug)]
60pub struct CommandResponse {
61 pub correlation_id: CorrelationId,
62 }
64
65#[derive(Debug)]
66pub struct ProcessResponse {
67 pub correlation_id: CorrelationId,
68 pub session_token: SessionToken,
69 pub process_id: ProcessId,
70}
71
72type Reply = Result<Message, ApiError>;
73
74struct Responder(CorrelationId, tokio::sync::oneshot::Sender<Reply>);
75
76#[derive(Default, Clone)]
77struct ResponseQueue(std::sync::Arc<tokio::sync::Mutex<std::collections::HashMap<CorrelationId, Responder>>>);
78
79impl ResponseQueue {
80 async fn lock(&self) -> tokio::sync::MutexGuard<std::collections::HashMap<CorrelationId, Responder>> {
81 self.0.lock().await
82 }
83}
84
85struct MessageToSend {
86 msg: tungstenite::Message,
87 resp: Option<Responder>,
88}
89
90impl MessageToSend {
91 fn ping() -> Self {
92 MessageToSend {
93 msg: tungstenite::Message::Ping(Bytes::new()),
94 resp: None,
95 }
96 }
97
98 fn pong() -> Self {
99 MessageToSend {
100 msg: tungstenite::Message::Pong(Bytes::new()),
101 resp: None,
102 }
103 }
104
105 fn text(text: String, resp: Responder) -> Self {
106 MessageToSend {
107 msg: tungstenite::Message::Text(text.into()),
108 resp: Some(resp),
109 }
110 }
111}
112
113pub struct Config {
114 pub agent_id: AgentId,
115 pub uri: http::Uri,
116 pub api_token: ApiToken,
117 pub capabilities: serde_json::Value,
118 pub ping_interval: std::time::Duration,
119}
120
121pub struct QueueClient {
122 agent_id: AgentId,
123 capabilities: serde_json::Value,
124 tx: tokio::sync::mpsc::Sender<MessageToSend>,
125 cancellation_token: CancellationToken,
126 correlation_id_gen: CorrelationIdGenerator,
127}
128
129impl QueueClient {
130 pub async fn connect(config: &Config) -> Result<Self, ApiError> {
131 let req = QueueClient::create_connect_request(config)?;
132 let (ws_stream, _) = tokio_tungstenite::connect_async(req).await?;
133
134 let (mut ws_write, mut ws_read) = ws_stream.split();
135
136 let (tx, mut rx) = tokio::sync::mpsc::channel::<MessageToSend>(1);
138
139 let response_queue = ResponseQueue::default();
141
142 let cancellation_token = CancellationToken::new();
143
144 let token = cancellation_token.clone();
146 let _ping_task = {
147 let tx = tx.clone();
148 let ping_interval = config.ping_interval;
149 tokio::spawn(async move {
150 let mut interval = tokio::time::interval(ping_interval);
151 loop {
152 tokio::select! {
153 _ = token.cancelled() => {
154 debug!("Stopping ping task...");
155 break
156 }
157 _ = interval.tick() => {
158 debug!("Sending a ping...");
159 if let Err(e) = tx.send(MessageToSend::ping()).await {
160 warn!("Failed to send a ping message to the server: {e}");
161 }
162 }
163 }
164 }
165 })
166 };
167
168 let token = cancellation_token.clone();
170 let _write_task = {
171 let response_queue = response_queue.clone();
172 tokio::spawn(async move {
173 loop {
174 tokio::select! {
175 _ = token.cancelled() => {
176 debug!("Stopping message sender...");
177 break
178 },
179 msg = rx.recv() => {
180 if let Some(MessageToSend { msg, resp }) = msg {
181 match ws_write.send(msg).await {
182 Ok(_) => {
183 if let Some(Responder(correlation_id, channel)) = resp {
185 let mut response_queue = response_queue.lock().await;
186 let responder = Responder(correlation_id, channel);
187 response_queue.insert(correlation_id, responder);
188 }
189 }
190 Err(e) => {
191 let err = format!("Write error: {e}");
193 warn!("{}", err);
194 if let Some(Responder(_, channel)) = resp {
195 if channel.send(Err(ApiError::simple(&err))).is_err() {
196 warn!("Responder error (most likely a bug)");
197 }
198 }
199 }
200 }
201 } else {
202 break
203 }
204 }
205 }
206 }
207 })
208 };
209
210 let token = cancellation_token.clone();
212 let _read_task = {
213 let tx = tx.clone();
214 tokio::spawn(async move {
215 loop {
216 tokio::select! {
217 _ = token.cancelled() => {
218 debug!("Stopping message received...");
219 break
220 }
221 msg = ws_read.next() => {
222 match msg {
223 Some(Ok(msg)) => match msg {
224 tungstenite::Message::Ping(_) => {
225 if let Err(e) = tx.send(MessageToSend::pong()).await {
227 warn!("Failed to send a pong response to the server: {e}");
228 }
229 }
230 tungstenite::Message::Pong(_) => {
231 debug!("Received a pong");
233 }
234 tungstenite::Message::Text(text) => {
235 debug!("Received message: {}", text);
236 QueueClient::handle_text_message(text, &response_queue).await;
237 }
238 _ => {
239 warn!("Unexpected message (possibly a bug): {msg:?}");
241 }
242 },
243 Some(Err(e)) => {
244 warn!("Read error: {e}");
246 token.cancel();
247 }
248 None => break,
249 }
250 }
251 }
252 }
253 })
254 };
255
256 Ok(QueueClient {
257 agent_id: config.agent_id,
258 capabilities: config.capabilities.clone(),
259 tx,
260 cancellation_token,
261 correlation_id_gen: Default::default(),
262 })
263 }
264
265 pub async fn next_command(&self) -> Result<CommandResponse, ApiError> {
266 let correlation_id = self.correlation_id_gen.next();
267
268 let msg = Message::CommandRequest {
269 correlation_id,
270 agent_id: self.agent_id,
271 };
272
273 match self.send_and_wait_for_reply(correlation_id, msg).await {
274 Ok(Message::CommandResponse {
275 correlation_id: reply_correlation_id,
276 }) => {
277 if correlation_id == reply_correlation_id {
278 Ok(CommandResponse { correlation_id })
279 } else {
280 api_err!("Unexpected correlation ID: {reply_correlation_id:?}")
281 }
282 }
283 Ok(msg) => api_err!("Unexpected message: {msg:?}"),
284 Err(e) => api_err!("Error while parsing message: {e}"),
285 }
286 }
287
288 pub async fn next_process(&self) -> Result<ProcessResponse, ApiError> {
289 let correlation_id = self.correlation_id_gen.next();
290
291 let msg = Message::ProcessRequest {
292 correlation_id,
293 capabilities: serde_json::json!(&self.capabilities),
294 };
295
296 match self.send_and_wait_for_reply(correlation_id, msg).await {
297 Ok(Message::ProcessResponse {
298 correlation_id: reply_correlation_id,
299 session_token,
300 process_id,
301 }) => {
302 if correlation_id == reply_correlation_id {
303 Ok(ProcessResponse {
304 correlation_id,
305 session_token,
306 process_id,
307 })
308 } else {
309 api_err!("Unexpected correlation ID: {reply_correlation_id:?}")
310 }
311 }
312 Ok(msg) => api_err!("Unexpected message: {msg:?}"),
313 Err(e) => api_err!("Error while parsing message: {e}"),
314 }
315 }
316
317 async fn send_and_wait_for_reply(&self, correlation_id: CorrelationId, msg: Message) -> Reply {
318 debug!("Sending message {msg:?} and waiting for reply");
319
320 let json = serde_json::to_string(&msg)?;
321
322 let (reply_sender, reply_receiver) = tokio::sync::oneshot::channel::<Reply>();
323
324 let msg = MessageToSend::text(json, Responder(correlation_id, reply_sender));
325 if let Err(e) = self.tx.send(msg).await {
326 return api_err!("Send error: {e}");
327 }
328
329 match reply_receiver.await {
330 Ok(reply) => reply,
331 Err(e) => api_err!("Error while receiving reply: {e}"),
332 }
333 }
334
335 async fn handle_text_message(text: Utf8Bytes, message_queue: &ResponseQueue) {
336 match serde_json::from_str::<Message>(&text) {
337 Ok(
338 cmd @ Message::CommandResponse { correlation_id, .. }
339 | cmd @ Message::ProcessResponse { correlation_id, .. },
340 ) => {
341 let mut message_queue = message_queue.lock().await;
342 if let Some(Responder(_, responder)) = message_queue.remove(&correlation_id) {
343 if responder.send(Ok(cmd)).is_err() {
344 warn!("Responder error (most likely a bug)");
345 }
346 } else {
347 warn!("No responder registered for correlation_id={correlation_id:?} (possibly a bug).")
348 }
349 }
350 Ok(msg) => {
351 warn!("Unexpected message body (most likely a bug): {:?}", msg);
353 }
354 Err(e) => {
355 warn!("Error while parsing message (possibly a bug): {e}");
357 }
358 }
359 }
360
361 fn create_connect_request(
362 Config {
363 uri,
364 api_token,
365 agent_id,
366 ..
367 }: &Config,
368 ) -> Result<http::Request<()>, ApiError> {
369 let host = format!(
370 "{}:{}",
371 uri.host().unwrap_or("localhost"),
372 uri.port_u16().unwrap_or(8001)
373 );
374
375 let ws_key = tungstenite::handshake::client::generate_key();
376
377 use http::{Request, header};
378 Request::builder()
379 .uri(uri.clone())
380 .header(header::HOST, host)
381 .header(header::AUTHORIZATION, api_token)
382 .header(header::CONNECTION, "Upgrade")
383 .header(header::UPGRADE, "websocket")
384 .header(header::SEC_WEBSOCKET_VERSION, "13")
385 .header(header::SEC_WEBSOCKET_KEY, ws_key)
386 .header(header::USER_AGENT, USER_AGENT_VALUE)
387 .header("X-Concord-Agent-Id", agent_id)
388 .header("X-Concord-Agent", USER_AGENT_VALUE)
389 .body(())
390 .map_err(|e| ApiError {
391 message: e.to_string(),
392 })
393 }
394}
395
396impl Drop for QueueClient {
397 fn drop(&mut self) {
398 self.cancellation_token.cancel();
399 }
400}