1use crate::{two_hours_from_now, Artifact, BackendIds, GitHubClient};
5use anyhow::{anyhow, Result};
6use azure_core::Etag;
7use azure_storage_blobs::prelude::BlobClient;
8use futures::stream::StreamExt as _;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashSet, VecDeque};
11use std::future::Future;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14
15#[allow(async_fn_in_trait)]
16pub trait QueueConnection {
17 type Blob: QueueBlob;
18
19 async fn get_blob(&self, backend_ids: BackendIds, key: &str) -> Result<Self::Blob>;
20 async fn create_blob(&self, key: &str) -> Result<Self::Blob>;
21 async fn list(&self) -> Result<Vec<Artifact>>;
22}
23
24pub enum ReadResponse {
25 Data { data: Vec<u8>, etag: Etag },
26 NoData,
27 AuthenticationFailed,
28}
29
30#[allow(async_fn_in_trait)]
31pub trait QueueBlob: Send + Sync + 'static {
32 async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<ReadResponse>;
33 fn write(&self, data: Vec<u8>) -> impl Future<Output = Result<()>> + Send;
34}
35
36impl QueueConnection for GitHubClient {
37 type Blob = BlobClient;
38
39 async fn get_blob(&self, backend_ids: BackendIds, key: &str) -> Result<Self::Blob> {
40 self.start_download(backend_ids, key).await
41 }
42
43 async fn create_blob(&self, key: &str) -> Result<Self::Blob> {
44 let blob = self.start_upload(key, Some(two_hours_from_now())).await?;
45 blob.put_append_blob().await?;
46 self.finish_upload(key, 0).await?;
47 Ok(blob)
48 }
49
50 async fn list(&self) -> Result<Vec<Artifact>> {
51 GitHubClient::list(self).await
52 }
53}
54
55impl QueueBlob for BlobClient {
56 async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<ReadResponse> {
57 let mut builder = self.get().range(index..);
58
59 if let Some(etag) = etag {
60 builder = builder.if_match(azure_core::request_options::IfMatchCondition::NotMatch(
61 etag.to_string(),
62 ));
63 }
64
65 let mut stream = builder.into_stream();
66 let resp = stream
67 .next()
68 .await
69 .ok_or_else(|| anyhow!("missing read response"))?;
70 match resp {
71 Ok(resp) => {
72 let msg = resp.data.collect().await?;
73 Ok(ReadResponse::Data {
74 data: msg.to_vec(),
75 etag: resp.blob.properties.etag,
76 })
77 }
78 Err(err) => {
79 use azure_core::{error::ErrorKind, StatusCode};
80
81 match err.kind() {
82 ErrorKind::HttpResponse {
83 status: StatusCode::NotModified,
84 error_code: Some(error_code),
85 } if error_code == "ConditionNotMet" => {
86 return Ok(ReadResponse::NoData);
87 }
88 ErrorKind::HttpResponse {
89 status: StatusCode::RequestedRangeNotSatisfiable,
90 error_code: Some(error_code),
91 } if error_code == "InvalidRange" => {
92 return Ok(ReadResponse::NoData);
93 }
94 ErrorKind::HttpResponse {
95 status: StatusCode::Forbidden,
96 error_code: Some(error_code),
97 } if error_code == "AuthenticationFailed" => {
98 return Ok(ReadResponse::AuthenticationFailed);
99 }
100 _ => {}
101 }
102 Err(err.into())
103 }
104 }
105 }
106
107 async fn write(&self, to_send: Vec<u8>) -> Result<()> {
108 self.append_block(to_send).await?;
109 Ok(())
110 }
111}
112
113#[derive(Serialize, Deserialize, PartialEq, Eq, Copy, Clone, Debug)]
114enum MessageHeader {
115 KeepAlive,
116 Payload { size: usize },
117 Shutdown,
118}
119
/// Read timeout used for queues created by `connect`/`accept_one` (one minute).
/// Writers send keep-alives at a quarter of this interval.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(60_000);
121
122pub struct GitHubReadQueue<ConnT: QueueConnection = GitHubClient> {
123 conn: Arc<ConnT>,
124 blob: ConnT::Blob,
125 index: usize,
126 etag: Option<Etag>,
127 pending: VecDeque<Option<Vec<u8>>>,
128 read_timeout: Duration,
129 backend_ids: BackendIds,
130 key: String,
131}
132
133impl<ConnT> GitHubReadQueue<ConnT>
134where
135 ConnT: QueueConnection,
136{
137 async fn new(
138 conn: Arc<ConnT>,
139 read_timeout: Duration,
140 backend_ids: BackendIds,
141 key: &str,
142 ) -> Result<Self> {
143 let blob = conn.get_blob(backend_ids.clone(), key).await?;
144 Ok(Self {
145 conn,
146 blob,
147 index: 0,
148 etag: None,
149 pending: Default::default(),
150 read_timeout,
151 backend_ids,
152 key: key.into(),
153 })
154 }
155
156 async fn maybe_read_msg(&mut self) -> Result<Option<Vec<u8>>> {
157 let (msg, etag) = match self.blob.read(self.index, &self.etag).await? {
158 ReadResponse::Data { data, etag } => (data, etag),
159 ReadResponse::NoData => return Ok(None),
160 ReadResponse::AuthenticationFailed => {
161 self.blob = self
162 .conn
163 .get_blob(self.backend_ids.clone(), &self.key)
164 .await?;
165 return Ok(None);
166 }
167 };
168
169 self.etag = Some(etag);
170 self.index += msg.len();
171 Ok(Some(msg))
172 }
173
174 pub async fn read_msg(&mut self) -> Result<Option<Vec<u8>>> {
175 if let Some(msg) = self.pending.pop_front() {
176 return Ok(msg);
177 }
178
179 let mut read_start = Instant::now();
180 loop {
181 if let Some(res) = self.maybe_read_msg().await? {
182 let mut r = &res[..];
183 while !r.is_empty() {
184 let header: MessageHeader = bincode::deserialize_from(&mut r)?;
185 match header {
186 MessageHeader::KeepAlive => {
187 read_start = Instant::now();
188 }
189 MessageHeader::Payload { size } => {
190 let payload = r[..size].to_vec();
191 r = &r[size..];
192 self.pending.push_back(Some(payload));
193 }
194 MessageHeader::Shutdown => {
195 self.pending.push_back(None);
196 }
197 }
198 }
199 }
200
201 if let Some(msg) = self.pending.pop_front() {
202 return Ok(msg);
203 }
204
205 if read_start.elapsed() > self.read_timeout {
206 return Err(anyhow!("GitHub queue read timeout"));
207 }
208 }
209 }
210}
211
212async fn send_keep_alive(duration: Duration, blob: Arc<impl QueueBlob>) {
213 loop {
214 tokio::time::sleep(duration).await;
215 let _ = blob
216 .write(bincode::serialize(&MessageHeader::KeepAlive).unwrap())
217 .await;
218 }
219}
220
221pub struct GitHubWriteQueue<BlobT = BlobClient> {
222 blob: Arc<BlobT>,
223 keep_alive: tokio::task::AbortHandle,
224 keep_alive_duration: Duration,
225}
226
227impl<BlobT: QueueBlob> GitHubWriteQueue<BlobT> {
228 async fn new<ConnT>(conn: &ConnT, keep_alive_duration: Duration, key: &str) -> Result<Self>
229 where
230 ConnT: QueueConnection<Blob = BlobT>,
231 {
232 let blob = Arc::new(conn.create_blob(key).await?);
233 let keep_alive =
234 tokio::task::spawn(send_keep_alive(keep_alive_duration, blob.clone())).abort_handle();
235 Ok(Self {
236 blob,
237 keep_alive,
238 keep_alive_duration,
239 })
240 }
241
242 pub async fn write_msg(&mut self, data: &[u8]) -> Result<()> {
243 let mut to_send = bincode::serialize(&MessageHeader::Payload { size: data.len() }).unwrap();
244 to_send.extend(data);
245 self.blob.write(to_send).await?;
246
247 self.keep_alive.abort();
248 self.keep_alive =
249 tokio::task::spawn(send_keep_alive(self.keep_alive_duration, self.blob.clone()))
250 .abort_handle();
251
252 Ok(())
253 }
254
255 pub async fn write_many_msgs(&mut self, messages: &[Vec<u8>]) -> Result<()> {
256 let mut to_send = vec![];
257 for data in messages {
258 to_send
259 .extend(bincode::serialize(&MessageHeader::Payload { size: data.len() }).unwrap());
260 to_send.extend(data);
261 }
262 self.blob.write(to_send).await?;
263
264 self.keep_alive.abort();
265 self.keep_alive =
266 tokio::task::spawn(send_keep_alive(self.keep_alive_duration, self.blob.clone()))
267 .abort_handle();
268
269 Ok(())
270 }
271
272 pub async fn shut_down(&mut self) -> Result<()> {
273 self.keep_alive.abort();
274 self.blob
275 .write(bincode::serialize(&MessageHeader::Shutdown).unwrap())
276 .await?;
277 Ok(())
278 }
279}
280
281impl<BlobT> Drop for GitHubWriteQueue<BlobT> {
282 fn drop(&mut self) {
283 self.keep_alive.abort();
284 }
285}
286
287async fn wait_for_artifact(conn: &impl QueueConnection, key: &str) -> Result<()> {
288 while !conn.list().await?.iter().any(|a| a.name == key) {}
289 Ok(())
290}
291
292pub struct GitHubQueue<ConnT: QueueConnection = GitHubClient> {
293 read: GitHubReadQueue<ConnT>,
294 write: GitHubWriteQueue<ConnT::Blob>,
295}
296
297impl<ConnT> GitHubQueue<ConnT>
298where
299 ConnT: QueueConnection,
300{
301 async fn new(
302 conn: Arc<ConnT>,
303 read_timeout: Duration,
304 read_backend_ids: BackendIds,
305 read_key: &str,
306 write_key: &str,
307 ) -> Result<Self> {
308 Ok(Self {
309 write: GitHubWriteQueue::new(&*conn, read_timeout / 4, write_key).await?,
310 read: GitHubReadQueue::new(conn, read_timeout, read_backend_ids, read_key).await?,
311 })
312 }
313
314 async fn maybe_connect(conn: Arc<ConnT>, id: &str) -> Result<Option<Self>> {
315 let artifacts = conn.list().await?;
316 if let Some(listener) = artifacts.iter().find(|a| a.name == format!("{id}-listen")) {
317 let Artifact {
318 name, backend_ids, ..
319 } = listener;
320 let key = name.strip_suffix("-listen").unwrap();
321 let self_id = uuid::Uuid::new_v4().to_string();
322
323 let write_key = format!("{self_id}-{key}-up");
324 let write = GitHubWriteQueue::new(&*conn, DEFAULT_READ_TIMEOUT / 4, &write_key).await?;
325
326 let read_key = format!("{self_id}-{key}-down");
327 wait_for_artifact(&*conn, &read_key).await?;
328 let read =
329 GitHubReadQueue::new(conn, DEFAULT_READ_TIMEOUT, backend_ids.clone(), &read_key)
330 .await?;
331
332 Ok(Some(Self { write, read }))
333 } else {
334 Ok(None)
335 }
336 }
337
338 pub async fn connect(conn: ConnT, id: &str) -> Result<Self> {
339 let conn = Arc::new(conn);
340 loop {
341 if let Some(socket) = Self::maybe_connect(conn.clone(), id).await? {
342 return Ok(socket);
343 }
344 }
345 }
346
347 pub async fn read_msg(&mut self) -> Result<Option<Vec<u8>>> {
348 self.read.read_msg().await
349 }
350
351 pub async fn write_msg(&mut self, data: &[u8]) -> Result<()> {
352 self.write.write_msg(data).await
353 }
354
355 pub async fn shut_down(&mut self) -> Result<()> {
356 self.write.shut_down().await
357 }
358
359 pub fn into_split(self) -> (GitHubReadQueue<ConnT>, GitHubWriteQueue<ConnT::Blob>) {
360 (self.read, self.write)
361 }
362}
363
364pub struct GitHubQueueAcceptor<ConnT = GitHubClient> {
365 id: String,
366 accepted: HashSet<String>,
367 conn: Arc<ConnT>,
368}
369
370impl<ConnT> GitHubQueueAcceptor<ConnT>
371where
372 ConnT: QueueConnection,
373{
374 pub async fn new(conn: ConnT, id: &str) -> Result<Self> {
375 let key = format!("{id}-listen");
376 let conn = Arc::new(conn);
377 conn.create_blob(&key).await?;
378 Ok(Self {
379 id: id.into(),
380 accepted: HashSet::new(),
381 conn,
382 })
383 }
384
385 async fn maybe_accept_one(&mut self) -> Result<Option<GitHubQueue<ConnT>>> {
386 let artifacts = self.conn.list().await?;
387 if let Some(connected) = artifacts.iter().find(|a| {
388 a.name.ends_with(&format!("{}-up", self.id)) && !self.accepted.contains(&a.name)
389 }) {
390 let Artifact {
391 name, backend_ids, ..
392 } = connected;
393 let key = name.strip_suffix("-up").unwrap();
394 let socket = GitHubQueue::new(
395 self.conn.clone(),
396 DEFAULT_READ_TIMEOUT,
397 backend_ids.clone(),
398 &format!("{key}-up"),
399 &format!("{key}-down"),
400 )
401 .await?;
402 self.accepted.insert(name.into());
403 Ok(Some(socket))
404 } else {
405 Ok(None)
406 }
407 }
408
409 pub async fn accept_one(&mut self) -> Result<GitHubQueue<ConnT>> {
410 loop {
411 if let Some(socket) = self.maybe_accept_one().await? {
412 return Ok(socket);
413 }
414 }
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::bail;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory stand-in for `GitHubClient`: a shared map from blob name to blob.
    #[derive(Clone, Default)]
    struct FakeConnection {
        blobs: Arc<Mutex<HashMap<String, FakeBlob>>>,
    }

    /// In-memory append blob; clones share the same underlying buffer.
    #[derive(Clone, Default)]
    struct FakeBlob {
        data: Arc<Mutex<Vec<u8>>>,
    }

    impl FakeBlob {
        fn len(&self) -> usize {
            self.data.lock().unwrap().len()
        }

        fn data(&self) -> Vec<u8> {
            self.data.lock().unwrap().clone()
        }
    }

    /// Backend ids used for every fake artifact.
    fn b_ids() -> BackendIds {
        BackendIds {
            workflow_run_backend_id: "b1".into(),
            workflow_job_run_backend_id: "b2".into(),
        }
    }

    impl QueueConnection for FakeConnection {
        type Blob = FakeBlob;

        async fn get_blob(&self, backend_ids: BackendIds, key: &str) -> Result<Self::Blob> {
            tokio::task::yield_now().await;

            assert_eq!(backend_ids, b_ids());
            Ok(self
                .blobs
                .lock()
                .unwrap()
                .get(key)
                .ok_or_else(|| anyhow!("blob not found"))?
                .clone())
        }

        async fn create_blob(&self, key: &str) -> Result<Self::Blob> {
            tokio::task::yield_now().await;

            let mut blobs = self.blobs.lock().unwrap();

            if blobs.contains_key(key) {
                bail!("blob already exists");
            }
            let new_blob = FakeBlob::default();
            blobs.insert(key.into(), new_blob.clone());
            Ok(new_blob)
        }

        async fn list(&self) -> Result<Vec<Artifact>> {
            tokio::task::yield_now().await;

            Ok(self
                .blobs
                .lock()
                .unwrap()
                .iter()
                .map(|(name, blob)| Artifact {
                    name: name.clone(),
                    backend_ids: b_ids(),
                    size: blob.len().try_into().unwrap(),
                    database_id: 1.into(),
                })
                .collect())
        }
    }

    impl QueueBlob for FakeBlob {
        /// Mimics the conditional read: the etag is a SHA-256 of the contents, and a matching
        /// etag yields `NoData`.
        async fn read(&self, index: usize, etag: &Option<Etag>) -> Result<ReadResponse> {
            use sha2::Digest as _;

            tokio::task::yield_now().await;

            let data = self.data.lock().unwrap();

            let mut hasher = sha2::Sha256::new();
            hasher.update(&data[..]);
            let actual_etag: Etag = maelstrom_base::Sha256Digest::new(hasher.finalize().into())
                .to_string()
                .into();

            if let Some(not_etag) = etag {
                if not_etag == &actual_etag {
                    return Ok(ReadResponse::NoData);
                }
            }

            if !data.is_empty() {
                assert!(index < data.len());
            }
            Ok(ReadResponse::Data {
                data: data[index..].to_vec(),
                etag: actual_etag,
            })
        }

        async fn write(&self, data: Vec<u8>) -> Result<()> {
            tokio::task::yield_now().await;

            self.data.lock().unwrap().extend(data);
            Ok(())
        }
    }

    const SHORT_DURATION: Duration = Duration::from_millis(100);
    const FOREVER: Duration = Duration::from_secs(u64::MAX);

    #[tokio::test]
    async fn read_single_msg() {
        let conn = FakeConnection::default();
        let b = conn.create_blob("foo").await.unwrap();
        let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
            .await
            .unwrap();

        b.write(bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap())
            .await
            .unwrap();
        let sent_msg = vec![1, 2, 3, 4, 5];
        b.write(sent_msg.clone()).await.unwrap();

        let read_msg = queue.read_msg().await.unwrap().unwrap();
        assert_eq!(read_msg, sent_msg);
    }

    #[tokio::test]
    async fn read_multiple_msgs() {
        let conn = FakeConnection::default();
        let b = conn.create_blob("foo").await.unwrap();
        let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
            .await
            .unwrap();

        let sent_msg = vec![1, 2, 3, 4, 5];
        for _ in 0..3 {
            b.write(bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap())
                .await
                .unwrap();
            b.write(sent_msg.clone()).await.unwrap();
        }

        for _ in 0..3 {
            let read_msg = queue.read_msg().await.unwrap().unwrap();
            assert_eq!(read_msg, sent_msg);
        }
    }

    #[tokio::test]
    async fn read_multiple_msgs_interleaved() {
        let conn = FakeConnection::default();
        let b = conn.create_blob("foo").await.unwrap();
        let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
            .await
            .unwrap();

        let sent_msg = vec![1, 2, 3, 4, 5];
        for _ in 0..3 {
            b.write(bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap())
                .await
                .unwrap();
            b.write(sent_msg.clone()).await.unwrap();

            let read_msg = queue.read_msg().await.unwrap().unwrap();
            assert_eq!(read_msg, sent_msg);
        }
    }

    #[tokio::test]
    async fn read_ignores_keep_alive_msgs() {
        let conn = FakeConnection::default();
        let b = conn.create_blob("foo").await.unwrap();
        let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
            .await
            .unwrap();

        let sent_msg = vec![1, 2, 3, 4, 5];
        for _ in 0..3 {
            b.write(bincode::serialize(&MessageHeader::KeepAlive).unwrap())
                .await
                .unwrap();
            b.write(bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap())
                .await
                .unwrap();
            b.write(sent_msg.clone()).await.unwrap();
        }

        for _ in 0..3 {
            let read_msg = queue.read_msg().await.unwrap().unwrap();
            assert_eq!(read_msg, sent_msg);
        }
    }

    #[tokio::test]
    async fn read_with_shutdown() {
        let conn = FakeConnection::default();
        let b = conn.create_blob("foo").await.unwrap();
        let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
            .await
            .unwrap();

        let sent_msg = vec![1, 2, 3, 4, 5];
        b.write(bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap())
            .await
            .unwrap();
        b.write(sent_msg.clone()).await.unwrap();
        b.write(bincode::serialize(&MessageHeader::Shutdown).unwrap())
            .await
            .unwrap();

        let read_msg = queue.read_msg().await.unwrap().unwrap();
        assert_eq!(read_msg, sent_msg);
        assert_eq!(queue.read_msg().await.unwrap(), None);
    }

    #[tokio::test]
    async fn read_timeout() {
        let conn = FakeConnection::default();
        let _ = conn.create_blob("foo").await.unwrap();
        let mut queue = GitHubReadQueue::new(Arc::new(conn), SHORT_DURATION, b_ids(), "foo")
            .await
            .unwrap();

        queue.read_msg().await.unwrap_err();
    }

    #[tokio::test]
    async fn write_msg() {
        let conn = FakeConnection::default();
        let mut queue = GitHubWriteQueue::new(&conn, FOREVER, "foo").await.unwrap();
        let sent = [1, 2, 3, 4, 5];
        queue.write_msg(&sent[..]).await.unwrap();

        let mut expected = bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap();
        expected.extend(sent);

        let b = conn.get_blob(b_ids(), "foo").await.unwrap();
        assert_eq!(b.data(), expected);
    }

    #[tokio::test]
    async fn write_many_msgs() {
        let conn = FakeConnection::default();
        let mut queue = GitHubWriteQueue::new(&conn, FOREVER, "foo").await.unwrap();
        let sent = vec![1, 2, 3, 4, 5];
        queue.write_many_msgs(&vec![sent.clone(); 3]).await.unwrap();

        let mut expected = vec![];
        for _ in 0..3 {
            expected.extend(bincode::serialize(&MessageHeader::Payload { size: 5 }).unwrap());
            expected.extend_from_slice(&sent);
        }

        let b = conn.get_blob(b_ids(), "foo").await.unwrap();
        assert_eq!(b.data(), expected);
    }

    #[tokio::test]
    async fn write_shutdown() {
        let conn = FakeConnection::default();
        let mut queue = GitHubWriteQueue::new(&conn, FOREVER, "foo").await.unwrap();
        queue.shut_down().await.unwrap();

        let expected = bincode::serialize(&MessageHeader::Shutdown).unwrap();

        let b = conn.get_blob(b_ids(), "foo").await.unwrap();
        assert_eq!(b.data(), expected);
    }

    #[tokio::test]
    async fn keep_alive() {
        let conn = FakeConnection::default();
        let queue = GitHubWriteQueue::new(&conn, Duration::from_micros(1), "foo")
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        drop(queue);

        let b = conn.get_blob(b_ids(), "foo").await.unwrap();
        let data = b.data();
        let mut cursor = &data[..];

        let mut keep_alive_count = 0;
        while !cursor.is_empty() {
            let header: MessageHeader = bincode::deserialize_from(&mut cursor).unwrap();
            assert_eq!(header, MessageHeader::KeepAlive);
            keep_alive_count += 1;
        }

        assert!(keep_alive_count > 50, "{keep_alive_count}");
    }

    #[tokio::test]
    async fn accept_and_connect() {
        let conn = FakeConnection::default();

        let their_conn = conn.clone();
        tokio::task::spawn(async move {
            let mut acceptor = GitHubQueueAcceptor::new(their_conn, "foo").await.unwrap();
            let mut queue_b = acceptor.accept_one().await.unwrap();
            queue_b.write_msg(&b"hello"[..]).await.unwrap();
        });

        let mut queue_a = GitHubQueue::connect(conn, "foo").await.unwrap();
        let msg = queue_a.read_msg().await.unwrap().unwrap();
        assert_eq!(msg, b"hello");
    }

    /// Integration-test actor: accepts two connections and plays ping-pong on each.
    async fn acceptor(client: GitHubClient) {
        let mut acceptor = GitHubQueueAcceptor::new(client, "foo").await.unwrap();

        let mut handles = vec![];
        for _ in 0..2 {
            let mut queue = acceptor.accept_one().await.unwrap();
            handles.push(tokio::task::spawn(async move {
                for _ in 0..3 {
                    queue.write_msg(&b"ping"[..]).await.unwrap();
                    let msg = queue.read_msg().await.unwrap().unwrap();
                    assert_eq!(msg, b"pong");
                }
                queue.shut_down().await.unwrap();
            }));
        }

        for h in handles {
            h.await.unwrap();
        }
    }

    /// Integration-test actor: answers every "ping" with "pong" until the acceptor shuts down.
    async fn connector(client: GitHubClient) {
        let mut sock = GitHubQueue::connect(client, "foo").await.unwrap();
        while let Some(msg) = sock.read_msg().await.unwrap() {
            assert_eq!(msg, b"ping");
            sock.write_msg(&b"pong"[..]).await.unwrap();
        }
    }

    #[tokio::test]
    async fn real_github_integration_test() {
        let Some(client) = crate::client::tests::client_factory() else {
            println!("skipping due to missing GitHub credentials");
            return;
        };
        println!("test found GitHub credentials");

        match &std::env::var("TEST_ACTOR").unwrap()[..] {
            "1" => acceptor(client).await,
            "2" => connector(client).await,
            "3" => connector(client).await,
            _ => panic!("unknown test actor"),
        }
    }
}